native/spark-expr/src/datetime_funcs/extract_date_part.rs (25 changes: 21 additions & 4 deletions)
@@ -18,7 +18,7 @@
use crate::utils::array_with_timezone;
use arrow::compute::{date_part, DatePart};
use arrow::datatypes::{DataType, TimeUnit::Microsecond};
use datafusion::common::{internal_datafusion_err, DataFusionError};
use datafusion::common::{internal_datafusion_err, ScalarValue};
use datafusion::logical_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};
@@ -86,9 +86,26 @@ macro_rules! extract_date_part {
let result = date_part(&array, DatePart::$date_part_variant)?;
Ok(ColumnarValue::Array(result))
}
_ => Err(DataFusionError::Execution(
concat!($fn_name, "(scalar) should be fold in Spark JVM side.").to_string(),
)),
[ColumnarValue::Scalar(scalar)] => {
// When Spark's ConstantFolding is disabled, literal-only expressions like
// hour can reach the native engine as scalar inputs.
// Instead of failing and requiring JVM folding, we evaluate the scalar
// natively by broadcasting it to a single-element array and then
// converting the result back to a scalar.
let array = scalar.clone().to_array_of_size(1)?;
let array = array_with_timezone(
array,
self.timezone.clone(),
Some(&DataType::Timestamp(
Microsecond,
Some(self.timezone.clone().into()),
)),
)?;
let result = date_part(&array, DatePart::$date_part_variant)?;
let scalar_result = ScalarValue::try_from_array(&result, 0)?;

Ok(ColumnarValue::Scalar(scalar_result))
}
}
}

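The scalar branch added above follows a small broadcast-and-extract pattern: turn the ScalarValue into a one-element array, run the same Arrow kernel the vectorized path already uses, then read element 0 back out as a scalar. Below is a minimal standalone sketch of that round trip; it omits the array_with_timezone normalization the real code performs, and the function name hour_of_scalar is illustrative rather than part of this PR.

```rust
use arrow::compute::{date_part, DatePart};
use datafusion::common::{Result, ScalarValue};

// Sketch only: extract the hour from a single ScalarValue by broadcasting it
// to a one-element array, reusing the vectorized `date_part` kernel, and
// converting the result back into a scalar. Timezone handling is omitted.
fn hour_of_scalar(scalar: &ScalarValue) -> Result<ScalarValue> {
    // ScalarValue -> one-element ArrayRef
    let array = scalar.clone().to_array_of_size(1)?;
    // Same Arrow kernel the array branch uses
    let hours = date_part(&array, DatePart::Hour)?;
    // Read the single result element back out as a ScalarValue
    ScalarValue::try_from_array(&hours, 0)
}
```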
native/spark-expr/src/datetime_funcs/unix_timestamp.rs (173 changes: 99 additions & 74 deletions)
@@ -19,7 +19,7 @@ use crate::utils::array_with_timezone;
use arrow::array::{Array, AsArray, PrimitiveArray};
use arrow::compute::cast;
use arrow::datatypes::{DataType, Int64Type, TimeUnit::Microsecond};
use datafusion::common::{internal_datafusion_err, DataFusionError};
use datafusion::common::{internal_datafusion_err, DataFusionError, ScalarValue};
use datafusion::logical_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};
@@ -77,79 +77,28 @@ impl ScalarUDFImpl for SparkUnixTimestamp {
.map_err(|_| internal_datafusion_err!("unix_timestamp expects exactly one argument"))?;

match args {
[ColumnarValue::Array(array)] => match array.data_type() {
DataType::Timestamp(_, _) => {
let is_utc = self.timezone == "UTC";
let array = if is_utc
&& matches!(array.data_type(), DataType::Timestamp(Microsecond, Some(tz)) if tz.as_ref() == "UTC")
{
array
} else {
array_with_timezone(
array,
self.timezone.clone(),
Some(&DataType::Timestamp(Microsecond, Some("UTC".into()))),
)?
};

let timestamp_array =
array.as_primitive::<arrow::datatypes::TimestampMicrosecondType>();

let result: PrimitiveArray<Int64Type> = if timestamp_array.null_count() == 0 {
timestamp_array
.values()
.iter()
.map(|&micros| micros / MICROS_PER_SECOND)
.collect()
} else {
timestamp_array
.iter()
.map(|v| v.map(|micros| div_floor(micros, MICROS_PER_SECOND)))
.collect()
};

Ok(ColumnarValue::Array(Arc::new(result)))
}
DataType::Date32 => {
let timestamp_array = cast(&array, &DataType::Timestamp(Microsecond, None))?;

let is_utc = self.timezone == "UTC";
let array = if is_utc {
timestamp_array
} else {
array_with_timezone(
timestamp_array,
self.timezone.clone(),
Some(&DataType::Timestamp(Microsecond, Some("UTC".into()))),
)?
};

let timestamp_array =
array.as_primitive::<arrow::datatypes::TimestampMicrosecondType>();

let result: PrimitiveArray<Int64Type> = if timestamp_array.null_count() == 0 {
timestamp_array
.values()
.iter()
.map(|&micros| micros / MICROS_PER_SECOND)
.collect()
} else {
timestamp_array
.iter()
.map(|v| v.map(|micros| div_floor(micros, MICROS_PER_SECOND)))
.collect()
};

Ok(ColumnarValue::Array(Arc::new(result)))
}
_ => Err(DataFusionError::Execution(format!(
"unix_timestamp does not support input type: {:?}",
array.data_type()
))),
},
_ => Err(DataFusionError::Execution(
"unix_timestamp(scalar) should be fold in Spark JVM side.".to_string(),
)),
[ColumnarValue::Array(array)] => self.eval_array(&array),
[ColumnarValue::Scalar(scalar)] => {
// When Spark's ConstantFolding is disabled, literal-only expressions like
// unix_timestamp can reach the native engine
// as scalar inputs. Evaluate the scalar natively by broadcasting it to a
// single-element array and converting the result back to a scalar.
let array = scalar.clone().to_array_of_size(1)?;
let result = self.eval_array(&array)?;

let result_array = match result {
ColumnarValue::Array(array) => array,
ColumnarValue::Scalar(_) => {
return Err(DataFusionError::Internal(
"unix_timestamp: expected array result from eval_array".to_string(),
))
}
};

let scalar_result = ScalarValue::try_from_array(&result_array, 0)?;

Ok(ColumnarValue::Scalar(scalar_result))
}
}
}

@@ -158,6 +107,82 @@ impl ScalarUDFImpl for SparkUnixTimestamp {
}
}

impl SparkUnixTimestamp {
fn eval_array(&self, array: &Arc<dyn Array>) -> datafusion::common::Result<ColumnarValue> {
match array.data_type() {
DataType::Timestamp(_, _) => {
let is_utc = self.timezone == "UTC";
let array = if is_utc
&& matches!(array.data_type(), DataType::Timestamp(Microsecond, Some(tz)) if tz.as_ref() == "UTC")
{
Arc::clone(array)
} else {
array_with_timezone(
Arc::clone(array),
self.timezone.clone(),
Some(&DataType::Timestamp(Microsecond, Some("UTC".into()))),
)?
};

let timestamp_array =
array.as_primitive::<arrow::datatypes::TimestampMicrosecondType>();

let result: PrimitiveArray<Int64Type> = if timestamp_array.null_count() == 0 {
timestamp_array
.values()
.iter()
.map(|&micros| micros / MICROS_PER_SECOND)
.collect()
} else {
timestamp_array
.iter()
.map(|v| v.map(|micros| div_floor(micros, MICROS_PER_SECOND)))
.collect()
};

Ok(ColumnarValue::Array(Arc::new(result)))
}
DataType::Date32 => {
let timestamp_array =
cast(array.as_ref(), &DataType::Timestamp(Microsecond, None))?;

let is_utc = self.timezone == "UTC";
let array = if is_utc {
timestamp_array
} else {
array_with_timezone(
timestamp_array,
self.timezone.clone(),
Some(&DataType::Timestamp(Microsecond, Some("UTC".into()))),
)?
};

let timestamp_array =
array.as_primitive::<arrow::datatypes::TimestampMicrosecondType>();

let result: PrimitiveArray<Int64Type> = if timestamp_array.null_count() == 0 {
timestamp_array
.values()
.iter()
.map(|&micros| micros / MICROS_PER_SECOND)
.collect()
} else {
timestamp_array
.iter()
.map(|v| v.map(|micros| div_floor(micros, MICROS_PER_SECOND)))
.collect()
};

Ok(ColumnarValue::Array(Arc::new(result)))
}
_ => Err(DataFusionError::Execution(format!(
"unix_timestamp does not support input type: {:?}",
array.data_type()
))),
}
}
}

#[cfg(test)]
mod tests {
use super::*;
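One detail in the eval_array helper above that benefits from a standalone illustration: the nullable path converts microseconds to seconds with div_floor rather than plain `/`, because for pre-epoch (negative) microsecond values Rust's integer division truncates toward zero, while flooring keeps the result on the correct second boundary. The sketch below demonstrates the difference; it assumes div_floor comes from the num-integer crate (the file's actual import lies outside this diff hunk).

```rust
// Sketch: truncating vs. floor division for a pre-epoch timestamp.
// 1969-12-31 23:59:59.5 UTC is -500_000 microseconds since the epoch.
use num_integer::div_floor;

const MICROS_PER_SECOND: i64 = 1_000_000;

fn main() {
    let micros: i64 = -500_000;
    // Rust's `/` truncates toward zero: -500_000 / 1_000_000 == 0
    assert_eq!(micros / MICROS_PER_SECOND, 0);
    // div_floor rounds toward negative infinity: floor(-0.5) == -1
    assert_eq!(div_floor(micros, MICROS_PER_SECOND), -1);
    // std's div_euclid agrees for a positive divisor
    assert_eq!(micros.div_euclid(MICROS_PER_SECOND), -1);
}
```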
@@ -27,5 +27,5 @@ query
SELECT hour(ts) FROM test_hour

-- literal arguments
query ignore(https://github.com/apache/datafusion-comet/issues/3336)
query
SELECT hour(timestamp('2024-01-15 00:00:00')), hour(timestamp('2024-01-15 12:30:45')), hour(timestamp('2024-01-15 23:59:59'))
@@ -27,5 +27,5 @@ query
SELECT minute(ts) FROM test_minute

-- literal arguments
query ignore(https://github.com/apache/datafusion-comet/issues/3336)
query
SELECT minute(timestamp('2024-01-15 10:00:00')), minute(timestamp('2024-01-15 10:30:00')), minute(timestamp('2024-01-15 10:59:59'))
@@ -27,5 +27,5 @@ query
SELECT second(ts) FROM test_second

-- literal arguments
query ignore(https://github.com/apache/datafusion-comet/issues/3336)
query
SELECT second(timestamp('2024-01-15 10:30:00')), second(timestamp('2024-01-15 10:30:30')), second(timestamp('2024-01-15 10:30:59'))
@@ -27,5 +27,5 @@ query
SELECT unix_timestamp(ts) FROM test_unix_ts

-- literal arguments
query ignore(https://github.com/apache/datafusion-comet/issues/3336)
query
SELECT unix_timestamp(timestamp('1970-01-01 00:00:00')), unix_timestamp(timestamp('2024-06-15 10:30:45'))