This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 977a189b1 fix: Fall back to Spark for `trunc` / `date_trunc` functions
when format string is unsupported, or is not a literal value (#2634)
977a189b1 is described below
commit 977a189b1dda912b4f2fb3893f8340ceb7c12aa0
Author: Andy Grove <[email protected]>
AuthorDate: Thu Oct 30 14:58:22 2025 -0600
fix: Fall back to Spark for `trunc` / `date_trunc` functions when format
string is unsupported, or is not a literal value (#2634)
---
.github/workflows/pr_build_linux.yml | 1 +
.github/workflows/pr_build_macos.yml | 1 +
.../main/scala/org/apache/comet/CometConf.scala | 4 +
.../scala/org/apache/comet/serde/datetime.scala | 54 +++++++++
.../org/apache/comet/CometExpressionSuite.scala | 30 ++---
.../comet/CometTemporalExpressionSuite.scala | 125 +++++++++++++++++++++
6 files changed, 202 insertions(+), 13 deletions(-)
diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml
index 118b756ee..2867f61da 100644
--- a/.github/workflows/pr_build_linux.yml
+++ b/.github/workflows/pr_build_linux.yml
@@ -141,6 +141,7 @@ jobs:
org.apache.spark.CometPluginsDefaultSuite
org.apache.spark.CometPluginsNonOverrideSuite
org.apache.spark.CometPluginsUnifiedModeOverrideSuite
+ org.apache.comet.CometTemporalExpressionSuite
org.apache.spark.sql.CometTPCDSQuerySuite
org.apache.spark.sql.CometTPCDSQueryTestSuite
org.apache.spark.sql.CometTPCHQuerySuite
diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml
index 465533834..0fd1cb606 100644
--- a/.github/workflows/pr_build_macos.yml
+++ b/.github/workflows/pr_build_macos.yml
@@ -106,6 +106,7 @@ jobs:
org.apache.spark.CometPluginsDefaultSuite
org.apache.spark.CometPluginsNonOverrideSuite
org.apache.spark.CometPluginsUnifiedModeOverrideSuite
+ org.apache.comet.CometTemporalExpressionSuite
org.apache.spark.sql.CometTPCDSQuerySuite
org.apache.spark.sql.CometTPCDSQueryTestSuite
org.apache.spark.sql.CometTPCHQuerySuite
diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala
index e3ff30eb4..d48d14972 100644
--- a/common/src/main/scala/org/apache/comet/CometConf.scala
+++ b/common/src/main/scala/org/apache/comet/CometConf.scala
@@ -740,6 +740,10 @@ object CometConf extends ShimCometConf {
s"${CometConf.COMET_EXPR_CONFIG_PREFIX}.$name.allowIncompatible"
}
+ def getExprAllowIncompatConfigKey(exprClass: Class[_]): String = {
+ s"${CometConf.COMET_EXPR_CONFIG_PREFIX}.${exprClass.getSimpleName}.allowIncompatible"
+ }
+
def getBooleanConf(name: String, defaultValue: Boolean, conf: SQLConf):
Boolean = {
conf.getConfString(name, defaultValue.toString).toLowerCase(Locale.ROOT)
== "true"
}
diff --git a/spark/src/main/scala/org/apache/comet/serde/datetime.scala b/spark/src/main/scala/org/apache/comet/serde/datetime.scala
index 267b2b441..ef2b0f793 100644
--- a/spark/src/main/scala/org/apache/comet/serde/datetime.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/datetime.scala
@@ -19,8 +19,11 @@
package org.apache.comet.serde
+import java.util.Locale
+
import org.apache.spark.sql.catalyst.expressions.{Attribute, DateAdd, DateSub,
DayOfMonth, DayOfWeek, DayOfYear, GetDateField, Hour, Literal, Minute, Month,
Quarter, Second, TruncDate, TruncTimestamp, WeekDay, WeekOfYear, Year}
import org.apache.spark.sql.types.{DateType, IntegerType}
+import org.apache.spark.unsafe.types.UTF8String
import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.serde.CometGetDateField.CometGetDateField
@@ -256,6 +259,24 @@ object CometDateAdd extends CometScalarFunction[DateAdd]("date_add")
object CometDateSub extends CometScalarFunction[DateSub]("date_sub")
object CometTruncDate extends CometExpressionSerde[TruncDate] {
+
+ val supportedFormats: Seq[String] =
+ Seq("year", "yyyy", "yy", "quarter", "mon", "month", "mm", "week")
+
+ override def getSupportLevel(expr: TruncDate): SupportLevel = {
+ expr.format match {
+ case Literal(fmt: UTF8String, _) =>
+ if (supportedFormats.contains(fmt.toString.toLowerCase(Locale.ROOT))) {
+ Compatible()
+ } else {
+ Unsupported(Some(s"Format $fmt is not supported"))
+ }
+ case _ =>
+ Incompatible(
+ Some("Invalid format strings will throw an exception instead of returning NULL"))
+ }
+ }
+
override def convert(
expr: TruncDate,
inputs: Seq[Attribute],
@@ -274,6 +295,39 @@ object CometTruncDate extends CometExpressionSerde[TruncDate] {
}
object CometTruncTimestamp extends CometExpressionSerde[TruncTimestamp] {
+
+ val supportedFormats: Seq[String] =
+ Seq(
+ "year",
+ "yyyy",
+ "yy",
+ "quarter",
+ "mon",
+ "month",
+ "mm",
+ "week",
+ "day",
+ "dd",
+ "hour",
+ "minute",
+ "second",
+ "millisecond",
+ "microsecond")
+
+ override def getSupportLevel(expr: TruncTimestamp): SupportLevel = {
+ expr.format match {
+ case Literal(fmt: UTF8String, _) =>
+ if (supportedFormats.contains(fmt.toString.toLowerCase(Locale.ROOT))) {
+ Compatible()
+ } else {
+ Unsupported(Some(s"Format $fmt is not supported"))
+ }
+ case _ =>
+ Incompatible(
+ Some("Invalid format strings will throw an exception instead of returning NULL"))
+ }
+ }
+
override def convert(
expr: TruncTimestamp,
inputs: Seq[Attribute],
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 54d336221..7b6ed1945 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -30,7 +30,7 @@ import org.scalatest.Tag
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
-import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, Literal,
TruncDate, TruncTimestamp}
import org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps
import org.apache.spark.sql.comet.{CometColumnarToRowExec, CometProjectExec,
CometWindowExec}
import org.apache.spark.sql.execution.{InputAdapter, ProjectExec, SparkPlan,
WholeStageCodegenExec}
@@ -706,11 +706,13 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
val path = new Path(dir.toURI.toString,
"date_trunc_with_format.parquet")
makeDateTimeWithFormatTable(path, dictionaryEnabled =
dictionaryEnabled, numRows)
withParquetTable(path.toString, "dateformattbl") {
- checkSparkAnswerAndOperator(
- "SELECT " +
- "dateformat, _7, " +
- "trunc(_7, dateformat) " +
- " from dateformattbl ")
+ withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[TruncDate]) -> "true") {
+ checkSparkAnswerAndOperator(
+ "SELECT " +
+ "dateformat, _7, " +
+ "trunc(_7, dateformat) " +
+ " from dateformattbl ")
+ }
}
}
}
@@ -787,13 +789,15 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
test("date_trunc with format array") {
- withSQLConf(CometConf.COMET_EXPR_ALLOW_INCOMPATIBLE.key -> "true") {
- val numRows = 1000
- Seq(true, false).foreach { dictionaryEnabled =>
- withTempDir { dir =>
- val path = new Path(dir.toURI.toString,
"timestamp_trunc_with_format.parquet")
- makeDateTimeWithFormatTable(path, dictionaryEnabled =
dictionaryEnabled, numRows)
- withParquetTable(path.toString, "timeformattbl") {
+ val numRows = 1000
+ Seq(true, false).foreach { dictionaryEnabled =>
+ withTempDir { dir =>
+ val path = new Path(dir.toURI.toString,
"timestamp_trunc_with_format.parquet")
+ makeDateTimeWithFormatTable(path, dictionaryEnabled =
dictionaryEnabled, numRows)
+ withParquetTable(path.toString, "timeformattbl") {
+ withSQLConf(
+ CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",
+ CometConf.getExprAllowIncompatConfigKey(classOf[TruncTimestamp])
-> "true") {
checkSparkAnswerAndOperator(
"SELECT " +
"format, _0, _1, _2, _3, _4, _5, " +
diff --git a/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala
new file mode 100644
index 000000000..9a23c76d8
--- /dev/null
+++ b/spark/src/test/scala/org/apache/comet/CometTemporalExpressionSuite.scala
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet
+
+import scala.util.Random
+
+import org.apache.spark.sql.{CometTestBase, SaveMode}
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
+
+import org.apache.comet.serde.{CometTruncDate, CometTruncTimestamp}
+import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator}
+
+class CometTemporalExpressionSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
+
+ test("trunc (TruncDate)") {
+ val supportedFormats = CometTruncDate.supportedFormats
+ val unsupportedFormats = Seq("invalid")
+
+ val r = new Random(42)
+ val schema = StructType(
+ Seq(
+ StructField("c0", DataTypes.DateType, true),
+ StructField("c1", DataTypes.StringType, true)))
+ val df = FuzzDataGenerator.generateDataFrame(r, spark, schema, 1000,
DataGenOptions())
+
+ df.createOrReplaceTempView("tbl")
+
+ for (format <- supportedFormats) {
+ checkSparkAnswerAndOperator(s"SELECT c0, trunc(c0, '$format') from tbl order by c0, c1")
+ }
+ for (format <- unsupportedFormats) {
+ // Comet should fall back to Spark for unsupported or invalid formats
+ checkSparkAnswerAndFallbackReason(
+ s"SELECT c0, trunc(c0, '$format') from tbl order by c0, c1",
+ s"Format $format is not supported")
+ }
+
+ // Comet should fall back to Spark if format is not a literal
+ checkSparkAnswerAndFallbackReason(
+ "SELECT c0, trunc(c0, c1) from tbl order by c0, c1",
+ "Invalid format strings will throw an exception instead of returning NULL")
+ }
+
+ test("date_trunc (TruncTimestamp) - reading from DataFrame") {
+ val supportedFormats = CometTruncTimestamp.supportedFormats
+ val unsupportedFormats = Seq("invalid")
+
+ createTimestampTestData.createOrReplaceTempView("tbl")
+
+ // TODO test fails with non-UTC timezone
+ // https://github.com/apache/datafusion-comet/issues/2649
+ withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
+ for (format <- supportedFormats) {
+ checkSparkAnswerAndOperator(s"SELECT c0, date_trunc('$format', c0) from tbl order by c0")
+ }
+ for (format <- unsupportedFormats) {
+ // Comet should fall back to Spark for unsupported or invalid formats
+ checkSparkAnswerAndFallbackReason(
+ s"SELECT c0, date_trunc('$format', c0) from tbl order by c0",
+ s"Format $format is not supported")
+ }
+ // Comet should fall back to Spark if format is not a literal
+ checkSparkAnswerAndFallbackReason(
+ "SELECT c0, date_trunc(fmt, c0) from tbl order by c0, fmt",
+ "Invalid format strings will throw an exception instead of returning NULL")
+ }
+ }
+
+ test("date_trunc (TruncTimestamp) - reading from Parquet") {
+ val supportedFormats = CometTruncTimestamp.supportedFormats
+ val unsupportedFormats = Seq("invalid")
+
+ withTempDir { path =>
+ createTimestampTestData.write.mode(SaveMode.Overwrite).parquet(path.toString)
+ spark.read.parquet(path.toString).createOrReplaceTempView("tbl")
+
+ // TODO test fails with non-UTC timezone
+ // https://github.com/apache/datafusion-comet/issues/2649
+ withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
+ for (format <- supportedFormats) {
+ checkSparkAnswerAndOperator(
+ s"SELECT c0, date_trunc('$format', c0) from tbl order by c0")
+ }
+ for (format <- unsupportedFormats) {
+ // Comet should fall back to Spark for unsupported or invalid formats
+ checkSparkAnswerAndFallbackReason(
+ s"SELECT c0, date_trunc('$format', c0) from tbl order by c0",
+ s"Format $format is not supported")
+ }
+ // Comet should fall back to Spark if format is not a literal
+ checkSparkAnswerAndFallbackReason(
+ "SELECT c0, date_trunc(fmt, c0) from tbl order by c0, fmt",
+ "Invalid format strings will throw an exception instead of returning NULL")
+ }
+ }
+ }
+
+ private def createTimestampTestData = {
+ val r = new Random(42)
+ val schema = StructType(
+ Seq(
+ StructField("c0", DataTypes.TimestampType, true),
+ StructField("fmt", DataTypes.StringType, true)))
+ FuzzDataGenerator.generateDataFrame(r, spark, schema, 1000,
DataGenOptions())
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]