This is an automated email from the ASF dual-hosted git repository.
mbutrovich pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 81e19be4b chore: Add unit tests for `CometExecRule` (#2863)
81e19be4b is described below
commit 81e19be4bfdf058f04f8168e084a002838f9d419
Author: Andy Grove <[email protected]>
AuthorDate: Wed Dec 10 08:01:23 2025 -0700
chore: Add unit tests for `CometExecRule` (#2863)
---
.github/workflows/pr_build_linux.yml | 1 +
.github/workflows/pr_build_macos.yml | 1 +
.../main/scala/org/apache/comet/CometConf.scala | 16 ++
.../org/apache/spark/sql/comet/operators.scala | 16 +-
.../apache/comet/rules/CometExecRuleSuite.scala | 232 +++++++++++++++++++++
5 files changed, 265 insertions(+), 1 deletion(-)
diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml
index 0fd28cb58..59be24370 100644
--- a/.github/workflows/pr_build_linux.yml
+++ b/.github/workflows/pr_build_linux.yml
@@ -141,6 +141,7 @@ jobs:
org.apache.spark.CometPluginsDefaultSuite
org.apache.spark.CometPluginsNonOverrideSuite
org.apache.spark.CometPluginsUnifiedModeOverrideSuite
+ org.apache.comet.rules.CometExecRuleSuite
org.apache.spark.sql.CometTPCDSQuerySuite
org.apache.spark.sql.CometTPCDSQueryTestSuite
org.apache.spark.sql.CometTPCHQuerySuite
diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml
index e915fa74a..4cee7395b 100644
--- a/.github/workflows/pr_build_macos.yml
+++ b/.github/workflows/pr_build_macos.yml
@@ -106,6 +106,7 @@ jobs:
org.apache.spark.CometPluginsDefaultSuite
org.apache.spark.CometPluginsNonOverrideSuite
org.apache.spark.CometPluginsUnifiedModeOverrideSuite
+ org.apache.comet.rules.CometExecRuleSuite
org.apache.spark.sql.CometTPCDSQuerySuite
org.apache.spark.sql.CometTPCDSQueryTestSuite
org.apache.spark.sql.CometTPCHQuerySuite
diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala
index 647572fc0..7eaae6055 100644
--- a/common/src/main/scala/org/apache/comet/CometConf.scala
+++ b/common/src/main/scala/org/apache/comet/CometConf.scala
@@ -654,6 +654,22 @@ object CometConf extends ShimCometConf {
.booleanConf
.createWithDefault(COMET_SCHEMA_EVOLUTION_ENABLED_DEFAULT)
+ val COMET_ENABLE_PARTIAL_HASH_AGGREGATE: ConfigEntry[Boolean] =
+ conf("spark.comet.testing.aggregate.partialMode.enabled")
+ .internal()
+ .category(CATEGORY_TESTING)
+ .doc("This setting is used in unit tests")
+ .booleanConf
+ .createWithDefault(true)
+
+ val COMET_ENABLE_FINAL_HASH_AGGREGATE: ConfigEntry[Boolean] =
+ conf("spark.comet.testing.aggregate.finalMode.enabled")
+ .internal()
+ .category(CATEGORY_TESTING)
+ .doc("This setting is used in unit tests")
+ .booleanConf
+ .createWithDefault(true)
+
val COMET_SPARK_TO_ARROW_ENABLED: ConfigEntry[Boolean] =
conf("spark.comet.sparkToColumnar.enabled")
.category(CATEGORY_TESTING)
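
[Editor's note] The two new entries above are internal, testing-category flags that default to true, so normal query plans are unaffected; a test flips one of them to force Comet to reject a particular aggregate mode. A minimal sketch of the intended usage, mirroring the suite added later in this commit (withSQLConf and CometTestBase come from the existing Comet test utilities):

    // Inside a suite extending CometTestBase: mark final-mode hash aggregates
    // as unsupported so the rule must leave the final aggregate in Spark.
    withSQLConf(CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.key -> "false") {
      // build a plan, apply CometExecRule, and assert that no mixed
      // Spark/Comet partial-final aggregate pair survives
    }
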
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index b64fb14ac..0a435e5b7 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -31,7 +31,7 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Expression, ExpressionSet, Generator, NamedExpression, SortOrder}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode, Final, Partial}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode, Final, Partial, PartialMerge}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
@@ -1234,6 +1234,20 @@ object CometHashAggregateExec
override def enabledConfig: Option[ConfigEntry[Boolean]] = Some(
CometConf.COMET_EXEC_AGGREGATE_ENABLED)
+ override def getSupportLevel(op: HashAggregateExec): SupportLevel = {
+    // some unit tests need to disable partial or final hash aggregate support to test that
+    // CometExecRule does not allow mixed Spark/Comet aggregates
+    if (!CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.get(op.conf) &&
+      op.aggregateExpressions.exists(expr => expr.mode == Partial || expr.mode == PartialMerge)) {
+ return Unsupported(Some("Partial aggregates disabled via test config"))
+ }
+ if (!CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.get(op.conf) &&
+ op.aggregateExpressions.exists(_.mode == Final)) {
+ return Unsupported(Some("Final aggregates disabled via test config"))
+ }
+ Compatible()
+ }
+
override def convert(
aggregate: HashAggregateExec,
builder: Operator.Builder,
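
[Editor's note] For context on how this hook is consumed: getSupportLevel lets an operator handler veto conversion per plan node, and the rule keeps the original Spark operator when it sees Unsupported. The sketch below is a hypothetical caller shape, not CometExecRule's actual code; Compatible and Unsupported are the SupportLevel cases visible in the diff, while convertToComet is an illustrative name:

    // Hypothetical caller shape (not the real CometExecRule implementation):
    CometHashAggregateExec.getSupportLevel(aggExec) match {
      case Unsupported(reason) =>
        aggExec // fall back: leave the Spark HashAggregateExec in place
      case _ =>
        convertToComet(aggExec) // illustrative helper: produce the Comet equivalent
    }
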
diff --git a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala
new file mode 100644
index 000000000..cf6f8918f
--- /dev/null
+++ b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.rules
+
+import scala.util.Random
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.comet._
+import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.adaptive.QueryStageExec
+import org.apache.spark.sql.execution.aggregate.HashAggregateExec
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
+
+import org.apache.comet.CometConf
+import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator}
+
+/**
+ * Test suite specifically for CometExecRule transformation logic: the rule's ability to
+ * transform Spark operators to Comet operators, its fallback mechanisms, configuration
+ * handling, and edge cases.
+ */
+class CometExecRuleSuite extends CometTestBase {
+
+ /** Helper method to apply CometExecRule and return the transformed plan */
+ private def applyCometExecRule(plan: SparkPlan): SparkPlan = {
+ CometExecRule(spark).apply(stripAQEPlan(plan))
+ }
+
+ /** Create a test data frame that is used in all tests */
+ private def createTestDataFrame = {
+ val testSchema = new StructType(
+ Array(
+ StructField("id", DataTypes.IntegerType, nullable = true),
+ StructField("name", DataTypes.StringType, nullable = true)))
+    FuzzDataGenerator.generateDataFrame(new Random(42), spark, testSchema, 100, DataGenOptions())
+ }
+
+ /** Create a SparkPlan from the specified SQL with Comet disabled */
+ private def createSparkPlan(spark: SparkSession, sql: String): SparkPlan = {
+ var sparkPlan: SparkPlan = null
+ withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+ val df = spark.sql(sql)
+ sparkPlan = df.queryExecution.executedPlan
+ }
+ sparkPlan
+ }
+
+  /** Count occurrences of the specified operator in the plan */
+ private def countOperators(plan: SparkPlan, opClass: Class[_]): Int = {
+ stripAQEPlan(plan).collect {
+ case stage: QueryStageExec =>
+ countOperators(stage.plan, opClass)
+ case op if op.getClass.isAssignableFrom(opClass) => 1
+ }.sum
+ }
+
+  test(
+    "CometExecRule should apply basic operator transformations, but only when Comet is enabled") {
+ withTempView("test_data") {
+ createTestDataFrame.createOrReplaceTempView("test_data")
+
+ val sparkPlan =
+ createSparkPlan(spark, "SELECT id, id * 2 as doubled FROM test_data
WHERE id % 2 == 0")
+
+ // Count original Spark operators
+ assert(countOperators(sparkPlan, classOf[ProjectExec]) == 1)
+ assert(countOperators(sparkPlan, classOf[FilterExec]) == 1)
+
+ for (cometEnabled <- Seq(true, false)) {
+ withSQLConf(
+ CometConf.COMET_ENABLED.key -> cometEnabled.toString,
+ CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") {
+
+ val transformedPlan = applyCometExecRule(sparkPlan)
+
+ if (cometEnabled) {
+ assert(countOperators(transformedPlan, classOf[ProjectExec]) == 0)
+ assert(countOperators(transformedPlan, classOf[FilterExec]) == 0)
+            assert(countOperators(transformedPlan, classOf[CometProjectExec]) == 1)
+            assert(countOperators(transformedPlan, classOf[CometFilterExec]) == 1)
+ } else {
+ assert(countOperators(transformedPlan, classOf[ProjectExec]) == 1)
+ assert(countOperators(transformedPlan, classOf[FilterExec]) == 1)
+            assert(countOperators(transformedPlan, classOf[CometProjectExec]) == 0)
+            assert(countOperators(transformedPlan, classOf[CometFilterExec]) == 0)
+ }
+ }
+ }
+ }
+ }
+
+ test("CometExecRule should apply hash aggregate transformations") {
+ withTempView("test_data") {
+ createTestDataFrame.createOrReplaceTempView("test_data")
+
+ val sparkPlan =
+ createSparkPlan(spark, "SELECT COUNT(*), SUM(id) FROM test_data GROUP
BY (id % 3)")
+
+ // Count original Spark operators
+      val originalHashAggCount = countOperators(sparkPlan, classOf[HashAggregateExec])
+ assert(originalHashAggCount == 2)
+
+      withSQLConf(CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") {
+ val transformedPlan = applyCometExecRule(sparkPlan)
+
+        assert(countOperators(transformedPlan, classOf[HashAggregateExec]) == 0)
+ assert(
+ countOperators(
+ transformedPlan,
+ classOf[CometHashAggregateExec]) == originalHashAggCount)
+ }
+ }
+ }
+
+ // TODO this test exposes the bug described in
+ // https://github.com/apache/datafusion-comet/issues/1389
+ ignore("CometExecRule should not allow Comet partial and Spark final hash
aggregate") {
+ withTempView("test_data") {
+ createTestDataFrame.createOrReplaceTempView("test_data")
+
+ val sparkPlan =
+ createSparkPlan(spark, "SELECT COUNT(*), SUM(id) FROM test_data GROUP
BY (id % 3)")
+
+ // Count original Spark operators
+      val originalHashAggCount = countOperators(sparkPlan, classOf[HashAggregateExec])
+ assert(originalHashAggCount == 2)
+
+ withSQLConf(
+ CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.key -> "false",
+ CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") {
+ val transformedPlan = applyCometExecRule(sparkPlan)
+
+        // if the final aggregate cannot be converted to Comet, then neither should be
+ assert(
+          countOperators(transformedPlan, classOf[HashAggregateExec]) == originalHashAggCount)
+        assert(countOperators(transformedPlan, classOf[CometHashAggregateExec]) == 0)
+ }
+ }
+ }
+
+ test("CometExecRule should not allow Spark partial and Comet final hash
aggregate") {
+ withTempView("test_data") {
+ createTestDataFrame.createOrReplaceTempView("test_data")
+
+ val sparkPlan =
+ createSparkPlan(spark, "SELECT COUNT(*), SUM(id) FROM test_data GROUP
BY (id % 3)")
+
+ // Count original Spark operators
+      val originalHashAggCount = countOperators(sparkPlan, classOf[HashAggregateExec])
+ assert(originalHashAggCount == 2)
+
+ withSQLConf(
+ CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.key -> "false",
+ CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") {
+ val transformedPlan = applyCometExecRule(sparkPlan)
+
+        // if the partial aggregate cannot be converted to Comet, then neither should be
+ assert(
+          countOperators(transformedPlan, classOf[HashAggregateExec]) == originalHashAggCount)
+        assert(countOperators(transformedPlan, classOf[CometHashAggregateExec]) == 0)
+ }
+ }
+ }
+
+ test("CometExecRule should apply broadcast exchange transformations") {
+ withTempView("test_data") {
+ createTestDataFrame.createOrReplaceTempView("test_data")
+
+ val sparkPlan = createSparkPlan(
+ spark,
+ "SELECT /*+ BROADCAST(b) */ a.id, b.name FROM test_data a JOIN
test_data b ON a.id = b.id")
+
+ // Count original Spark operators
+ val originalBroadcastExchangeCount =
+ countOperators(sparkPlan, classOf[BroadcastExchangeExec])
+ assert(originalBroadcastExchangeCount == 1)
+
+      withSQLConf(CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") {
+ val transformedPlan = applyCometExecRule(sparkPlan)
+
+        assert(countOperators(transformedPlan, classOf[BroadcastExchangeExec]) == 0)
+ assert(
+ countOperators(
+ transformedPlan,
+            classOf[CometBroadcastExchangeExec]) == originalBroadcastExchangeCount)
+ }
+ }
+ }
+
+ test("CometExecRule should apply shuffle exchange transformations") {
+ withTempView("test_data") {
+ createTestDataFrame.createOrReplaceTempView("test_data")
+
+ val sparkPlan =
+ createSparkPlan(spark, "SELECT id, COUNT(*) FROM test_data GROUP BY id
ORDER BY id")
+
+ // Count original Spark operators
+      val originalShuffleExchangeCount = countOperators(sparkPlan, classOf[ShuffleExchangeExec])
+ assert(originalShuffleExchangeCount == 2)
+
+      withSQLConf(CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") {
+ val transformedPlan = applyCometExecRule(sparkPlan)
+
+        assert(countOperators(transformedPlan, classOf[ShuffleExchangeExec]) == 0)
+ assert(
+ countOperators(
+ transformedPlan,
+ classOf[CometShuffleExchangeExec]) == originalShuffleExchangeCount)
+ }
+ }
+ }
+
+}
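
[Editor's note] Both CI workflows above run the new suite by its fully qualified name. To run it locally, something like the following should work, assuming the scalatest-maven-plugin "suites" property that the Comet build uses for targeted test runs:

    mvn test -Dsuites=org.apache.comet.rules.CometExecRuleSuite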