This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 977a189b1 fix: Fall back to Spark for `trunc` / `date_trunc` functions when format string is unsupported, or is not a literal value (#2634)
977a189b1 is described below

commit 977a189b1dda912b4f2fb3893f8340ceb7c12aa0
Author: Andy Grove <[email protected]>
AuthorDate: Thu Oct 30 14:58:22 2025 -0600

    fix: Fall back to Spark for `trunc` / `date_trunc` functions when format string is unsupported, or is not a literal value (#2634)
---
 .github/workflows/pr_build_linux.yml               |   1 +
 .github/workflows/pr_build_macos.yml               |   1 +
 .../main/scala/org/apache/comet/CometConf.scala    |   4 +
 .../scala/org/apache/comet/serde/datetime.scala    |  54 +++++++++
 .../org/apache/comet/CometExpressionSuite.scala    |  30 ++---
 .../comet/CometTemporalExpressionSuite.scala       | 125 +++++++++++++++++++++
 6 files changed, 202 insertions(+), 13 deletions(-)
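
For context, a minimal sketch of the user-visible effect (not part of the commit; assumes a local Spark session with the Comet jar on the classpath and that org.apache.spark.CometPlugin is the plugin class for your build):

    import org.apache.spark.sql.SparkSession

    // Sketch: with this fix, an unsupported trunc format no longer reaches Comet's
    // native code; Comet falls back to Spark, which returns NULL for invalid formats.
    val spark = SparkSession.builder()
      .master("local[1]")
      .config("spark.plugins", "org.apache.spark.CometPlugin")
      .getOrCreate()

    spark.sql("SELECT trunc(DATE'2025-10-30', 'year')").show()  // supported: runs natively
    spark.sql("SELECT trunc(DATE'2025-10-30', 'bogus')").show() // unsupported: falls back, yields NULL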

diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml
index 118b756ee..2867f61da 100644
--- a/.github/workflows/pr_build_linux.yml
+++ b/.github/workflows/pr_build_linux.yml
@@ -141,6 +141,7 @@ jobs:
               org.apache.spark.CometPluginsDefaultSuite
               org.apache.spark.CometPluginsNonOverrideSuite
               org.apache.spark.CometPluginsUnifiedModeOverrideSuite
+              org.apache.comet.CometTemporalExpressionSuite
               org.apache.spark.sql.CometTPCDSQuerySuite
               org.apache.spark.sql.CometTPCDSQueryTestSuite
               org.apache.spark.sql.CometTPCHQuerySuite
diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml
index 465533834..0fd1cb606 100644
--- a/.github/workflows/pr_build_macos.yml
+++ b/.github/workflows/pr_build_macos.yml
@@ -106,6 +106,7 @@ jobs:
               org.apache.spark.CometPluginsDefaultSuite
               org.apache.spark.CometPluginsNonOverrideSuite
               org.apache.spark.CometPluginsUnifiedModeOverrideSuite
+              org.apache.comet.CometTemporalExpressionSuite
               org.apache.spark.sql.CometTPCDSQuerySuite
               org.apache.spark.sql.CometTPCDSQueryTestSuite
               org.apache.spark.sql.CometTPCHQuerySuite
diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala
index e3ff30eb4..d48d14972 100644
--- a/common/src/main/scala/org/apache/comet/CometConf.scala
+++ b/common/src/main/scala/org/apache/comet/CometConf.scala
@@ -740,6 +740,10 @@ object CometConf extends ShimCometConf {
     s"${CometConf.COMET_EXPR_CONFIG_PREFIX}.$name.allowIncompatible"
   }
 
+  def getExprAllowIncompatConfigKey(exprClass: Class[_]): String = {
+    s"${CometConf.COMET_EXPR_CONFIG_PREFIX}.${exprClass.getSimpleName}.allowIncompatible"
+  }
+
   def getBooleanConf(name: String, defaultValue: Boolean, conf: SQLConf): Boolean = {
     conf.getConfString(name, defaultValue.toString).toLowerCase(Locale.ROOT) == "true"
   }
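
The new helper derives a per-expression opt-in key from the Spark expression class. A usage sketch (hedged: the rendered key below assumes COMET_EXPR_CONFIG_PREFIX is "spark.comet.expression", and `spark` is an active SparkSession):

    import org.apache.spark.sql.catalyst.expressions.TruncDate
    import org.apache.comet.CometConf

    // Builds e.g. "spark.comet.expression.TruncDate.allowIncompatible"
    val key = CometConf.getExprAllowIncompatConfigKey(classOf[TruncDate])
    // Opt in to Comet's incompatible native behavior for this one expression only
    spark.conf.set(key, "true")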
diff --git a/spark/src/main/scala/org/apache/comet/serde/datetime.scala b/spark/src/main/scala/org/apache/comet/serde/datetime.scala
index 267b2b441..ef2b0f793 100644
--- a/spark/src/main/scala/org/apache/comet/serde/datetime.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/datetime.scala
@@ -19,8 +19,11 @@
 
 package org.apache.comet.serde
 
+import java.util.Locale
+
 import org.apache.spark.sql.catalyst.expressions.{Attribute, DateAdd, DateSub, DayOfMonth, DayOfWeek, DayOfYear, GetDateField, Hour, Literal, Minute, Month, Quarter, Second, TruncDate, TruncTimestamp, WeekDay, WeekOfYear, Year}
 import org.apache.spark.sql.types.{DateType, IntegerType}
+import org.apache.spark.unsafe.types.UTF8String
 
 import org.apache.comet.CometSparkSessionExtensions.withInfo
 import org.apache.comet.serde.CometGetDateField.CometGetDateField
@@ -256,6 +259,24 @@ object CometDateAdd extends CometScalarFunction[DateAdd]("date_add")
 object CometDateSub extends CometScalarFunction[DateSub]("date_sub")
 
 object CometTruncDate extends CometExpressionSerde[TruncDate] {
+
+  val supportedFormats: Seq[String] =
+    Seq("year", "yyyy", "yy", "quarter", "mon", "month", "mm", "week")
+
+  override def getSupportLevel(expr: TruncDate): SupportLevel = {
+    expr.format match {
+      case Literal(fmt: UTF8String, _) =>
+        if (supportedFormats.contains(fmt.toString.toLowerCase(Locale.ROOT))) {
+          Compatible()
+        } else {
+          Unsupported(Some(s"Format $fmt is not supported"))
+        }
+      case _ =>
+        Incompatible(
+          Some("Invalid format strings will throw an exception instead of 
returning NULL"))
+    }
+  }
+
   override def convert(
       expr: TruncDate,
       inputs: Seq[Attribute],
@@ -274,6 +295,39 @@ object CometTruncDate extends CometExpressionSerde[TruncDate] {
 }
 
 object CometTruncTimestamp extends CometExpressionSerde[TruncTimestamp] {
+
+  val supportedFormats: Seq[String] =
+    Seq(
+      "year",
+      "yyyy",
+      "yy",
+      "quarter",
+      "mon",
+      "month",
+      "mm",
+      "week",
+      "day",
+      "dd",
+      "hour",
+      "minute",
+      "second",
+      "millisecond",
+      "microsecond")
+
+  override def getSupportLevel(expr: TruncTimestamp): SupportLevel = {
+    expr.format match {
+      case Literal(fmt: UTF8String, _) =>
+        if (supportedFormats.contains(fmt.toString.toLowerCase(Locale.ROOT))) {
+          Compatible()
+        } else {
+          Unsupported(Some(s"Format $fmt is not supported"))
+        }
+      case _ =>
+        Incompatible(
+          Some("Invalid format strings will throw an exception instead of 
returning NULL"))
+    }
+  }
+
   override def convert(
       expr: TruncTimestamp,
       inputs: Seq[Attribute],
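
The `getSupportLevel` overrides above follow a three-way contract: Compatible() lets the expression run natively, Unsupported(reason) always falls back to Spark, and Incompatible(reason) falls back unless the user sets the expression's allowIncompatible config. A self-contained model of that dispatch (a sketch; the SupportLevel types here only mirror the Comet serde API, they are not its definitions):

    import java.util.Locale

    // Simplified stand-ins for Comet's SupportLevel hierarchy
    sealed trait SupportLevel
    case class Compatible() extends SupportLevel
    case class Unsupported(reason: Option[String]) extends SupportLevel
    case class Incompatible(reason: Option[String]) extends SupportLevel

    val supportedFormats = Seq("year", "yyyy", "yy", "quarter", "mon", "month", "mm", "week")

    // Mirrors CometTruncDate.getSupportLevel: Some(fmt) stands in for a literal format
    def supportLevel(format: Option[String]): SupportLevel = format match {
      case Some(fmt) if supportedFormats.contains(fmt.toLowerCase(Locale.ROOT)) =>
        Compatible()                                       // literal, known format: run natively
      case Some(fmt) =>
        Unsupported(Some(s"Format $fmt is not supported")) // literal but unknown: always fall back
      case None =>
        // non-literal format: fall back unless <Expr>.allowIncompatible is set
        Incompatible(Some("Invalid format strings will throw an exception instead of returning NULL"))
    }

    assert(supportLevel(Some("YEAR")) == Compatible())
    assert(supportLevel(Some("bogus")).isInstanceOf[Unsupported])
    assert(supportLevel(None).isInstanceOf[Incompatible])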
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 54d336221..7b6ed1945 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -30,7 +30,7 @@ import org.scalatest.Tag
 
 import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
-import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, Literal, TruncDate, TruncTimestamp}
 import org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps
 import org.apache.spark.sql.comet.{CometColumnarToRowExec, CometProjectExec, CometWindowExec}
 import org.apache.spark.sql.execution.{InputAdapter, ProjectExec, SparkPlan, WholeStageCodegenExec}
@@ -706,11 +706,13 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
         val path = new Path(dir.toURI.toString, "date_trunc_with_format.parquet")
         makeDateTimeWithFormatTable(path, dictionaryEnabled = dictionaryEnabled, numRows)
         withParquetTable(path.toString, "dateformattbl") {
-          checkSparkAnswerAndOperator(
-            "SELECT " +
-              "dateformat, _7, " +
-              "trunc(_7, dateformat) " +
-              " from dateformattbl ")
+          withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[TruncDate]) -> "true") {
+            checkSparkAnswerAndOperator(
+              "SELECT " +
+                "dateformat, _7, " +
+                "trunc(_7, dateformat) " +
+                " from dateformattbl ")
+          }
         }
       }
     }
@@ -787,13 +789,15 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   }
 
   test("date_trunc with format array") {
-    withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
-      val numRows = 1000
-      Seq(true, false).foreach { dictionaryEnabled =>
-        withTempDir { dir =>
-          val path = new Path(dir.toURI.toString, "timestamp_trunc_with_format.parquet")
-          makeDateTimeWithFormatTable(path, dictionaryEnabled = dictionaryEnabled, numRows)
-          withParquetTable(path.toString, "timeformattbl") {
+    val numRows = 1000
+    Seq(true, false).foreach { dictionaryEnabled =>
+      withTempDir { dir =>
+        val path = new Path(dir.toURI.toString, "timestamp_trunc_with_format.parquet")
+        makeDateTimeWithFormatTable(path, dictionaryEnabled = dictionaryEnabled, numRows)
+        withParquetTable(path.toString, "timeformattbl") {
+          withSQLConf(
+            CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",
+            CometConf.getExprAllowIncompatConfigKey(classOf[TruncTimestamp]) -> "true") {
             checkSparkAnswerAndOperator(
               "SELECT " +
                 "format, _0, _1, _2, _3, _4, _5, " +
diff --git a/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala
new file mode 100644
index 000000000..9a23c76d8
--- /dev/null
+++ b/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet
+
+import scala.util.Random
+
+import org.apache.spark.sql.{CometTestBase, SaveMode}
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
+
+import org.apache.comet.serde.{CometTruncDate, CometTruncTimestamp}
+import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator}
+
+class CometTemporalExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
+
+  test("trunc (TruncDate)") {
+    val supportedFormats = CometTruncDate.supportedFormats
+    val unsupportedFormats = Seq("invalid")
+
+    val r = new Random(42)
+    val schema = StructType(
+      Seq(
+        StructField("c0", DataTypes.DateType, true),
+        StructField("c1", DataTypes.StringType, true)))
+    val df = FuzzDataGenerator.generateDataFrame(r, spark, schema, 1000, DataGenOptions())
+
+    df.createOrReplaceTempView("tbl")
+
+    for (format <- supportedFormats) {
+      checkSparkAnswerAndOperator(s"SELECT c0, trunc(c0, '$format') from tbl 
order by c0, c1")
+    }
+    for (format <- unsupportedFormats) {
+      // Comet should fall back to Spark for unsupported or invalid formats
+      checkSparkAnswerAndFallbackReason(
+        s"SELECT c0, trunc(c0, '$format') from tbl order by c0, c1",
+        s"Format $format is not supported")
+    }
+
+    // Comet should fall back to Spark if format is not a literal
+    checkSparkAnswerAndFallbackReason(
+      "SELECT c0, trunc(c0, c1) from tbl order by c0, c1",
+      "Invalid format strings will throw an exception instead of returning 
NULL")
+  }
+
+  test("date_trunc (TruncTimestamp) - reading from DataFrame") {
+    val supportedFormats = CometTruncTimestamp.supportedFormats
+    val unsupportedFormats = Seq("invalid")
+
+    createTimestampTestData.createOrReplaceTempView("tbl")
+
+    // TODO test fails with non-UTC timezone
+    // https://github.com/apache/datafusion-comet/issues/2649
+    withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
+      for (format <- supportedFormats) {
+        checkSparkAnswerAndOperator(s"SELECT c0, date_trunc('$format', c0) 
from tbl order by c0")
+      }
+      for (format <- unsupportedFormats) {
+        // Comet should fall back to Spark for unsupported or invalid formats
+        checkSparkAnswerAndFallbackReason(
+          s"SELECT c0, date_trunc('$format', c0) from tbl order by c0",
+          s"Format $format is not supported")
+      }
+      // Comet should fall back to Spark if format is not a literal
+      checkSparkAnswerAndFallbackReason(
+        "SELECT c0, date_trunc(fmt, c0) from tbl order by c0, fmt",
+        "Invalid format strings will throw an exception instead of returning 
NULL")
+    }
+  }
+
+  test("date_trunc (TruncTimestamp) - reading from Parquet") {
+    val supportedFormats = CometTruncTimestamp.supportedFormats
+    val unsupportedFormats = Seq("invalid")
+
+    withTempDir { path =>
+      createTimestampTestData.write.mode(SaveMode.Overwrite).parquet(path.toString)
+      spark.read.parquet(path.toString).createOrReplaceTempView("tbl")
+
+      // TODO test fails with non-UTC timezone
+      // https://github.com/apache/datafusion-comet/issues/2649
+      withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
+        for (format <- supportedFormats) {
+          checkSparkAnswerAndOperator(
+            s"SELECT c0, date_trunc('$format', c0) from tbl order by c0")
+        }
+        for (format <- unsupportedFormats) {
+          // Comet should fall back to Spark for unsupported or invalid formats
+          checkSparkAnswerAndFallbackReason(
+            s"SELECT c0, date_trunc('$format', c0) from tbl order by c0",
+            s"Format $format is not supported")
+        }
+        // Comet should fall back to Spark if format is not a literal
+        checkSparkAnswerAndFallbackReason(
+          "SELECT c0, date_trunc(fmt, c0) from tbl order by c0, fmt",
+          "Invalid format strings will throw an exception instead of returning 
NULL")
+      }
+    }
+  }
+
+  private def createTimestampTestData = {
+    val r = new Random(42)
+    val schema = StructType(
+      Seq(
+        StructField("c0", DataTypes.TimestampType, true),
+        StructField("fmt", DataTypes.StringType, true)))
+    FuzzDataGenerator.generateDataFrame(r, spark, schema, 1000, DataGenOptions())
+  }
+}
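
Outside the test harness, the fallback reasons asserted above can be inspected directly. A sketch, assuming `spark.comet.explainFallback.enabled` is the config that reports why Comet declined an expression (as in current Comet releases):

    // Sketch: surface the fallback reason for an unsupported format
    spark.conf.set("spark.comet.explainFallback.enabled", "true")
    spark.sql("SELECT trunc(DATE'2025-10-30', 'bogus')").collect()
    // Expect a fallback notice mentioning: "Format bogus is not supported"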


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
