diff --git a/src/databricks/sql/parameters/native.py b/src/databricks/sql/parameters/native.py index b7c448254..d0fb8d82c 100644 --- a/src/databricks/sql/parameters/native.py +++ b/src/databricks/sql/parameters/native.py @@ -659,7 +659,7 @@ def dbsql_parameter_from_primitive( elif isinstance(value, str): return StringParameter(value=value, name=name) elif isinstance(value, float): - return FloatParameter(value=value, name=name) + return DoubleParameter(value=value, name=name) elif isinstance(value, datetime.datetime): return TimestampParameter(value=value, name=name) elif isinstance(value, datetime.date): diff --git a/tests/unit/test_parameters.py b/tests/unit/test_parameters.py index cf2e24951..e6024cd1a 100644 --- a/tests/unit/test_parameters.py +++ b/tests/unit/test_parameters.py @@ -295,7 +295,7 @@ def test_tspark_param_ordinal(self): (BigIntegerParameter, Primitive.BIGINT), (BooleanParameter, Primitive.BOOL), (DateParameter, Primitive.DATE), - (FloatParameter, Primitive.FLOAT), + (DoubleParameter, Primitive.FLOAT), (VoidParameter, Primitive.NONE), (TimestampParameter, Primitive.TIMESTAMP), (MapParameter, Primitive.MAP), @@ -305,7 +305,7 @@ def test_tspark_param_ordinal(self): def test_inference(self, _type: TDbsqlParameter, prim: Primitive): """This method only tests inferrable types. - Not tested are TinyIntParameter, SmallIntParameter DoubleParameter and TimestampNTZParameter + Not tested are TinyIntParameter, SmallIntParameter, FloatParameter and TimestampNTZParameter """ inferred_type = dbsql_parameter_from_primitive(prim.value)