diff --git a/native/spark-expr/src/datetime_funcs/extract_date_part.rs b/native/spark-expr/src/datetime_funcs/extract_date_part.rs index acb7d2266e..e7c50b12b0 100644 --- a/native/spark-expr/src/datetime_funcs/extract_date_part.rs +++ b/native/spark-expr/src/datetime_funcs/extract_date_part.rs @@ -18,7 +18,7 @@ use crate::utils::array_with_timezone; use arrow::compute::{date_part, DatePart}; use arrow::datatypes::{DataType, TimeUnit::Microsecond}; -use datafusion::common::{internal_datafusion_err, DataFusionError}; +use datafusion::common::{internal_datafusion_err, ScalarValue}; use datafusion::logical_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; @@ -86,9 +86,26 @@ macro_rules! extract_date_part { let result = date_part(&array, DatePart::$date_part_variant)?; Ok(ColumnarValue::Array(result)) } - _ => Err(DataFusionError::Execution( - concat!($fn_name, "(scalar) should be fold in Spark JVM side.").to_string(), - )), + [ColumnarValue::Scalar(scalar)] => { + // When Spark's ConstantFolding is disabled, literal-only expressions like + // hour can reach the native engine as scalar inputs. + // Instead of failing and requiring JVM folding, we evaluate the scalar + // natively by broadcasting it to a single-element array and then + // converting the result back to a scalar. 
+ let array = scalar.clone().to_array_of_size(1)?; + let array = array_with_timezone( + array, + self.timezone.clone(), + Some(&DataType::Timestamp( + Microsecond, + Some(self.timezone.clone().into()), + )), + )?; + let result = date_part(&array, DatePart::$date_part_variant)?; + let scalar_result = ScalarValue::try_from_array(&result, 0)?; + + Ok(ColumnarValue::Scalar(scalar_result)) + } } } diff --git a/native/spark-expr/src/datetime_funcs/unix_timestamp.rs b/native/spark-expr/src/datetime_funcs/unix_timestamp.rs index c4f1576293..fb1cdfd5ee 100644 --- a/native/spark-expr/src/datetime_funcs/unix_timestamp.rs +++ b/native/spark-expr/src/datetime_funcs/unix_timestamp.rs @@ -19,7 +19,7 @@ use crate::utils::array_with_timezone; use arrow::array::{Array, AsArray, PrimitiveArray}; use arrow::compute::cast; use arrow::datatypes::{DataType, Int64Type, TimeUnit::Microsecond}; -use datafusion::common::{internal_datafusion_err, DataFusionError}; +use datafusion::common::{internal_datafusion_err, DataFusionError, ScalarValue}; use datafusion::logical_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; @@ -77,79 +77,28 @@ impl ScalarUDFImpl for SparkUnixTimestamp { .map_err(|_| internal_datafusion_err!("unix_timestamp expects exactly one argument"))?; match args { - [ColumnarValue::Array(array)] => match array.data_type() { - DataType::Timestamp(_, _) => { - let is_utc = self.timezone == "UTC"; - let array = if is_utc - && matches!(array.data_type(), DataType::Timestamp(Microsecond, Some(tz)) if tz.as_ref() == "UTC") - { - array - } else { - array_with_timezone( - array, - self.timezone.clone(), - Some(&DataType::Timestamp(Microsecond, Some("UTC".into()))), - )? 
- }; - - let timestamp_array = - array.as_primitive::<TimestampMicrosecondType>(); - - let result: PrimitiveArray<Int64Type> = if timestamp_array.null_count() == 0 { - timestamp_array - .values() - .iter() - .map(|&micros| micros / MICROS_PER_SECOND) - .collect() - } else { - timestamp_array - .iter() - .map(|v| v.map(|micros| div_floor(micros, MICROS_PER_SECOND))) - .collect() - }; - - Ok(ColumnarValue::Array(Arc::new(result))) - } - DataType::Date32 => { - let timestamp_array = cast(&array, &DataType::Timestamp(Microsecond, None))?; - - let is_utc = self.timezone == "UTC"; - let array = if is_utc { - timestamp_array - } else { - array_with_timezone( - timestamp_array, - self.timezone.clone(), - Some(&DataType::Timestamp(Microsecond, Some("UTC".into()))), - )? - }; - - let timestamp_array = - array.as_primitive::<TimestampMicrosecondType>(); - - let result: PrimitiveArray<Int64Type> = if timestamp_array.null_count() == 0 { - timestamp_array - .values() - .iter() - .map(|&micros| micros / MICROS_PER_SECOND) - .collect() - } else { - timestamp_array - .iter() - .map(|v| v.map(|micros| div_floor(micros, MICROS_PER_SECOND))) - .collect() - }; - - Ok(ColumnarValue::Array(Arc::new(result))) - } - _ => Err(DataFusionError::Execution(format!( - "unix_timestamp does not support input type: {:?}", - array.data_type() - ))), - }, - _ => Err(DataFusionError::Execution( - "unix_timestamp(scalar) should be fold in Spark JVM side.".to_string(), - )), + [ColumnarValue::Array(array)] => self.eval_array(&array), + [ColumnarValue::Scalar(scalar)] => { + // When Spark's ConstantFolding is disabled, literal-only expressions like + // unix_timestamp can reach the native engine + // as scalar inputs. Evaluate the scalar natively by broadcasting it to a + // single-element array and converting the result back to a scalar.
+ let array = scalar.clone().to_array_of_size(1)?; + let result = self.eval_array(&array)?; + + let result_array = match result { + ColumnarValue::Array(array) => array, + ColumnarValue::Scalar(_) => { + return Err(DataFusionError::Internal( + "unix_timestamp: expected array result from eval_array".to_string(), + )) + } + }; + + let scalar_result = ScalarValue::try_from_array(&result_array, 0)?; + + Ok(ColumnarValue::Scalar(scalar_result)) + } } } @@ -158,6 +107,82 @@ impl ScalarUDFImpl for SparkUnixTimestamp { } } +impl SparkUnixTimestamp { + fn eval_array(&self, array: &Arc<dyn Array>) -> datafusion::common::Result<ColumnarValue> { + match array.data_type() { + DataType::Timestamp(_, _) => { + let is_utc = self.timezone == "UTC"; + let array = if is_utc + && matches!(array.data_type(), DataType::Timestamp(Microsecond, Some(tz)) if tz.as_ref() == "UTC") + { + Arc::clone(array) + } else { + array_with_timezone( + Arc::clone(array), + self.timezone.clone(), + Some(&DataType::Timestamp(Microsecond, Some("UTC".into()))), + )? + }; + + let timestamp_array = + array.as_primitive::<TimestampMicrosecondType>(); + + let result: PrimitiveArray<Int64Type> = if timestamp_array.null_count() == 0 { + timestamp_array + .values() + .iter() + .map(|&micros| micros / MICROS_PER_SECOND) + .collect() + } else { + timestamp_array + .iter() + .map(|v| v.map(|micros| div_floor(micros, MICROS_PER_SECOND))) + .collect() + }; + + Ok(ColumnarValue::Array(Arc::new(result))) + } + DataType::Date32 => { + let timestamp_array = + cast(array.as_ref(), &DataType::Timestamp(Microsecond, None))?; + + let is_utc = self.timezone == "UTC"; + let array = if is_utc { + timestamp_array + } else { + array_with_timezone( + timestamp_array, + self.timezone.clone(), + Some(&DataType::Timestamp(Microsecond, Some("UTC".into()))), + )?
+ }; + + let timestamp_array = + array.as_primitive::<TimestampMicrosecondType>(); + + let result: PrimitiveArray<Int64Type> = if timestamp_array.null_count() == 0 { + timestamp_array + .values() + .iter() + .map(|&micros| micros / MICROS_PER_SECOND) + .collect() + } else { + timestamp_array + .iter() + .map(|v| v.map(|micros| div_floor(micros, MICROS_PER_SECOND))) + .collect() + }; + + Ok(ColumnarValue::Array(Arc::new(result))) + } + _ => Err(DataFusionError::Execution(format!( + "unix_timestamp does not support input type: {:?}", + array.data_type() + ))), + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/spark/src/test/resources/sql-tests/expressions/datetime/hour.sql b/spark/src/test/resources/sql-tests/expressions/datetime/hour.sql index 551dd60c06..325d7fe264 100644 --- a/spark/src/test/resources/sql-tests/expressions/datetime/hour.sql +++ b/spark/src/test/resources/sql-tests/expressions/datetime/hour.sql @@ -27,5 +27,5 @@ query SELECT hour(ts) FROM test_hour -- literal arguments -query ignore(https://github.com/apache/datafusion-comet/issues/3336) +query SELECT hour(timestamp('2024-01-15 00:00:00')), hour(timestamp('2024-01-15 12:30:45')), hour(timestamp('2024-01-15 23:59:59')) diff --git a/spark/src/test/resources/sql-tests/expressions/datetime/minute.sql b/spark/src/test/resources/sql-tests/expressions/datetime/minute.sql index 0b75084352..8b9dc544c5 100644 --- a/spark/src/test/resources/sql-tests/expressions/datetime/minute.sql +++ b/spark/src/test/resources/sql-tests/expressions/datetime/minute.sql @@ -27,5 +27,5 @@ query SELECT minute(ts) FROM test_minute -- literal arguments -query ignore(https://github.com/apache/datafusion-comet/issues/3336) +query SELECT minute(timestamp('2024-01-15 10:00:00')), minute(timestamp('2024-01-15 10:30:00')), minute(timestamp('2024-01-15 10:59:59')) diff --git a/spark/src/test/resources/sql-tests/expressions/datetime/second.sql b/spark/src/test/resources/sql-tests/expressions/datetime/second.sql index 5db4e9e743..99ee008232 100644 ---
a/spark/src/test/resources/sql-tests/expressions/datetime/second.sql +++ b/spark/src/test/resources/sql-tests/expressions/datetime/second.sql @@ -27,5 +27,5 @@ query SELECT second(ts) FROM test_second -- literal arguments -query ignore(https://github.com/apache/datafusion-comet/issues/3336) +query SELECT second(timestamp('2024-01-15 10:30:00')), second(timestamp('2024-01-15 10:30:30')), second(timestamp('2024-01-15 10:30:59')) diff --git a/spark/src/test/resources/sql-tests/expressions/datetime/unix_timestamp.sql b/spark/src/test/resources/sql-tests/expressions/datetime/unix_timestamp.sql index fa4d9df07f..bbd8d107fd 100644 --- a/spark/src/test/resources/sql-tests/expressions/datetime/unix_timestamp.sql +++ b/spark/src/test/resources/sql-tests/expressions/datetime/unix_timestamp.sql @@ -27,5 +27,5 @@ query SELECT unix_timestamp(ts) FROM test_unix_ts -- literal arguments -query ignore(https://github.com/apache/datafusion-comet/issues/3336) +query SELECT unix_timestamp(timestamp('1970-01-01 00:00:00')), unix_timestamp(timestamp('2024-06-15 10:30:45'))