This is an automated email from the ASF dual-hosted git repository.
comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 1f75eda09 chore: Implement date_trunc as ScalarUDFImpl (#1880)
1f75eda09 is described below
commit 1f75eda09890a30903bfd9a7e02c2287588b8d76
Author: Leung Ming <[email protected]>
AuthorDate: Tue Jun 17 01:49:15 2025 +0800
chore: Implement date_trunc as ScalarUDFImpl (#1880)
---
native/core/src/execution/planner.rs | 17 ++--
native/proto/src/proto/expr.proto | 6 --
native/spark-expr/src/datetime_funcs/date_trunc.rs | 92 ++++++++--------------
native/spark-expr/src/datetime_funcs/mod.rs | 2 +-
native/spark-expr/src/lib.rs | 2 +-
.../org/apache/comet/serde/QueryPlanSerde.scala | 18 +----
6 files changed, 43 insertions(+), 94 deletions(-)
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 5d5d39635..09853b6d4 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -66,6 +66,7 @@ use datafusion::{
};
use datafusion_comet_spark_expr::{
create_comet_physical_fun, create_negate_expr, SparkBitwiseCount, SparkBitwiseNot,
+ SparkDateTrunc,
};
use crate::execution::operators::ExecutionError::GeneralError;
@@ -105,10 +106,10 @@ use datafusion_comet_proto::{
};
use datafusion_comet_spark_expr::{
ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Contains, Correlation, Covariance,
- CreateNamedStruct, DateTruncExpr, EndsWith, GetArrayStructFields, GetStructField, HourExpr,
- IfExpr, Like, ListExtract, MinuteExpr, NormalizeNaNAndZero, RLike, SecondExpr,
- SparkCastOptions, StartsWith, Stddev, StringSpaceExpr, SubstringExpr, SumDecimal,
- TimestampTruncExpr, ToJson, UnboundColumn, Variance,
+ CreateNamedStruct, EndsWith, GetArrayStructFields, GetStructField, HourExpr, IfExpr, Like,
+ ListExtract, MinuteExpr, NormalizeNaNAndZero, RLike, SecondExpr, SparkCastOptions, StartsWith,
+ Stddev, StringSpaceExpr, SubstringExpr, SumDecimal, TimestampTruncExpr, ToJson, UnboundColumn,
+ Variance,
};
use datafusion_spark::function::math::expm1::SparkExpm1;
use itertools::Itertools;
@@ -158,6 +159,7 @@ impl PhysicalPlanner {
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkExpm1::default()));
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitwiseNot::default()));
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitwiseCount::default()));
+
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkDateTrunc::default()));
Self {
exec_context_id: TEST_EXEC_CONTEXT_ID,
session_ctx,
@@ -475,13 +477,6 @@ impl PhysicalPlanner {
Ok(Arc::new(SecondExpr::new(child, timezone)))
}
- ExprStruct::TruncDate(expr) => {
- let child =
- self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?;
- let format = self.create_expr(expr.format.as_ref().unwrap(), input_schema)?;
-
- Ok(Arc::new(DateTruncExpr::new(child, format)))
- }
ExprStruct::TruncTimestamp(expr) => {
let child =
self.create_expr(expr.child.as_ref().unwrap(),
Arc::clone(&input_schema))?;
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index d74e675f7..4a1f6eb4f 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -70,7 +70,6 @@ message Expr {
BinaryExpr bitwiseShiftLeft = 43;
IfExpr if = 44;
NormalizeNaNAndZero normalize_nan_and_zero = 45;
- TruncDate truncDate = 46;
TruncTimestamp truncTimestamp = 47;
Abs abs = 49;
Subquery subquery = 50;
@@ -344,11 +343,6 @@ message IfExpr {
Expr false_expr = 3;
}
-message TruncDate {
- Expr child = 1;
- Expr format = 2;
-}
-
message TruncTimestamp {
Expr format = 1;
Expr child = 2;
diff --git a/native/spark-expr/src/datetime_funcs/date_trunc.rs b/native/spark-expr/src/datetime_funcs/date_trunc.rs
index 1f91ba64b..861f5a2ae 100644
--- a/native/spark-expr/src/datetime_funcs/date_trunc.rs
+++ b/native/spark-expr/src/datetime_funcs/date_trunc.rs
@@ -15,76 +15,58 @@
// specific language governing permissions and limitations
// under the License.
-use arrow::datatypes::{DataType, Schema};
-use arrow::record_batch::RecordBatch;
-use datafusion::common::{DataFusionError, ScalarValue::Utf8};
-use datafusion::logical_expr::ColumnarValue;
-use datafusion::physical_expr::PhysicalExpr;
-use std::hash::Hash;
-use std::{
- any::Any,
- fmt::{Debug, Display, Formatter},
- sync::Arc,
+use arrow::datatypes::DataType;
+use datafusion::common::{utils::take_function_args, DataFusionError, Result, ScalarValue::Utf8};
+use datafusion::logical_expr::{
+ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};
+use std::any::Any;
use crate::kernels::temporal::{date_trunc_array_fmt_dyn, date_trunc_dyn};
-#[derive(Debug, Eq)]
-pub struct DateTruncExpr {
- /// An array with DataType::Date32
- child: Arc<dyn PhysicalExpr>,
- /// Scalar UTF8 string matching the valid values in Spark SQL:
https://spark.apache.org/docs/latest/api/sql/index.html#trunc
- format: Arc<dyn PhysicalExpr>,
+#[derive(Debug)]
+pub struct SparkDateTrunc {
+ signature: Signature,
+ aliases: Vec<String>,
}
-impl Hash for DateTruncExpr {
- fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
- self.child.hash(state);
- self.format.hash(state);
- }
-}
-impl PartialEq for DateTruncExpr {
- fn eq(&self, other: &Self) -> bool {
- self.child.eq(&other.child) && self.format.eq(&other.format)
- }
-}
-
-impl DateTruncExpr {
- pub fn new(child: Arc<dyn PhysicalExpr>, format: Arc<dyn PhysicalExpr>) -> Self {
- DateTruncExpr { child, format }
+impl SparkDateTrunc {
+ pub fn new() -> Self {
+ Self {
+ signature: Signature::exact(
+ vec![DataType::Date32, DataType::Utf8],
+ Volatility::Immutable,
+ ),
+ aliases: vec![],
+ }
}
}
-impl Display for DateTruncExpr {
- fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
- write!(
- f,
- "DateTrunc [child:{}, format: {}]",
- self.child, self.format
- )
+impl Default for SparkDateTrunc {
+ fn default() -> Self {
+ Self::new()
}
}
-impl PhysicalExpr for DateTruncExpr {
+impl ScalarUDFImpl for SparkDateTrunc {
fn as_any(&self) -> &dyn Any {
self
}
- fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
- unimplemented!()
+ fn name(&self) -> &str {
+ "date_trunc"
}
- fn data_type(&self, input_schema: &Schema) -> datafusion::common::Result<DataType> {
- self.child.data_type(input_schema)
+ fn signature(&self) -> &Signature {
+ &self.signature
}
- fn nullable(&self, _: &Schema) -> datafusion::common::Result<bool> {
- Ok(true)
+ fn return_type(&self, _: &[DataType]) -> Result<DataType> {
+ Ok(DataType::Date32)
}
- fn evaluate(&self, batch: &RecordBatch) -> datafusion::common::Result<ColumnarValue> {
- let date = self.child.evaluate(batch)?;
- let format = self.format.evaluate(batch)?;
+ fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+ let [date, format] = take_function_args(self.name(), args.args)?;
match (date, format) {
(ColumnarValue::Array(date), ColumnarValue::Scalar(Utf8(Some(format)))) => {
let result = date_trunc_dyn(&date, format)?;
@@ -101,17 +83,7 @@ impl PhysicalExpr for DateTruncExpr {
}
}
- fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
- vec![&self.child]
- }
-
- fn with_new_children(
- self: Arc<Self>,
- children: Vec<Arc<dyn PhysicalExpr>>,
- ) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
- Ok(Arc::new(DateTruncExpr::new(
- Arc::clone(&children[0]),
- Arc::clone(&self.format),
- )))
+ fn aliases(&self) -> &[String] {
+ &self.aliases
}
}
diff --git a/native/spark-expr/src/datetime_funcs/mod.rs b/native/spark-expr/src/datetime_funcs/mod.rs
index 1f4d42728..e0baa1fce 100644
--- a/native/spark-expr/src/datetime_funcs/mod.rs
+++ b/native/spark-expr/src/datetime_funcs/mod.rs
@@ -23,7 +23,7 @@ mod second;
mod timestamp_trunc;
pub use date_arithmetic::{spark_date_add, spark_date_sub};
-pub use date_trunc::DateTruncExpr;
+pub use date_trunc::SparkDateTrunc;
pub use hour::HourExpr;
pub use minute::MinuteExpr;
pub use second::SecondExpr;
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index ae8e639b3..c2aac93e2 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -60,7 +60,7 @@ pub use conversion_funcs::*;
pub use comet_scalar_funcs::create_comet_physical_fun;
pub use datetime_funcs::{
- spark_date_add, spark_date_sub, DateTruncExpr, HourExpr, MinuteExpr, SecondExpr,
+ spark_date_add, spark_date_sub, HourExpr, MinuteExpr, SecondExpr, SparkDateTrunc,
TimestampTruncExpr,
};
pub use error::{SparkError, SparkResult};
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 13bea457d..90a90e773 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1044,21 +1044,9 @@ object QueryPlanSerde extends Logging with CometExprShim {
case TruncDate(child, format) =>
val childExpr = exprToProtoInternal(child, inputs, binding)
val formatExpr = exprToProtoInternal(format, inputs, binding)
-
- if (childExpr.isDefined && formatExpr.isDefined) {
- val builder = ExprOuterClass.TruncDate.newBuilder()
- builder.setChild(childExpr.get)
- builder.setFormat(formatExpr.get)
-
- Some(
- ExprOuterClass.Expr
- .newBuilder()
- .setTruncDate(builder)
- .build())
- } else {
- withInfo(expr, child, format)
- None
- }
+ val optExpr =
+ scalarFunctionExprToProtoWithReturnType("date_trunc", DateType, childExpr, formatExpr)
+ optExprWithInfo(optExpr, expr, child, format)
case TruncTimestamp(format, child, timeZoneId) =>
val childExpr = exprToProtoInternal(child, inputs, binding)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]