This is an automated email from the ASF dual-hosted git repository.
mbutrovich pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new abcb4992c chore: Add unit tests for `CometScanRule` (#2867)
abcb4992c is described below
commit abcb4992c715bd6a414e012929ffaaa5f60ba8d6
Author: Andy Grove <[email protected]>
AuthorDate: Wed Dec 10 09:56:00 2025 -0700
chore: Add unit tests for `CometScanRule` (#2867)
---
.github/workflows/pr_build_linux.yml | 1 +
.github/workflows/pr_build_macos.yml | 1 +
.../apache/comet/rules/CometScanRuleSuite.scala | 181 +++++++++++++++++++++
3 files changed, 183 insertions(+)
diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml
index 59be24370..e7651ec5f 100644
--- a/.github/workflows/pr_build_linux.yml
+++ b/.github/workflows/pr_build_linux.yml
@@ -141,6 +141,7 @@ jobs:
org.apache.spark.CometPluginsDefaultSuite
org.apache.spark.CometPluginsNonOverrideSuite
org.apache.spark.CometPluginsUnifiedModeOverrideSuite
+ org.apache.comet.rules.CometScanRuleSuite
org.apache.comet.rules.CometExecRuleSuite
org.apache.spark.sql.CometTPCDSQuerySuite
org.apache.spark.sql.CometTPCDSQueryTestSuite
diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml
index 4cee7395b..8fd0aab78 100644
--- a/.github/workflows/pr_build_macos.yml
+++ b/.github/workflows/pr_build_macos.yml
@@ -106,6 +106,7 @@ jobs:
org.apache.spark.CometPluginsDefaultSuite
org.apache.spark.CometPluginsNonOverrideSuite
org.apache.spark.CometPluginsUnifiedModeOverrideSuite
+ org.apache.comet.rules.CometScanRuleSuite
org.apache.comet.rules.CometExecRuleSuite
org.apache.spark.sql.CometTPCDSQuerySuite
org.apache.spark.sql.CometTPCDSQueryTestSuite
diff --git a/spark/src/test/scala/org/apache/comet/rules/CometScanRuleSuite.scala b/spark/src/test/scala/org/apache/comet/rules/CometScanRuleSuite.scala
new file mode 100644
index 000000000..7f54d0b7c
--- /dev/null
+++ b/spark/src/test/scala/org/apache/comet/rules/CometScanRuleSuite.scala
@@ -0,0 +1,181 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.rules
+
+import scala.util.Random
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.comet._
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.adaptive.QueryStageExec
+import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
+
+import org.apache.comet.CometConf
+import org.apache.comet.parquet.CometParquetScan
+import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator}
+
+/**
+ * Test suite specifically for CometScanRule transformation logic.
+ */
+class CometScanRuleSuite extends CometTestBase {
+
+ /** Helper method to apply CometScanRule and return the transformed plan */
+ private def applyCometScanRule(plan: SparkPlan): SparkPlan = {
+ CometScanRule(spark).apply(stripAQEPlan(plan))
+ }
+
+ /** Create a test data frame that is used in all tests */
+ private def createTestDataFrame = {
+ val testSchema = new StructType(
+ Array(
+ StructField("id", DataTypes.IntegerType, nullable = true),
+ StructField("name", DataTypes.StringType, nullable = true)))
+ FuzzDataGenerator.generateDataFrame(new Random(42), spark, testSchema, 100, DataGenOptions())
+ }
+
+ /** Create a SparkPlan from the specified SQL with Comet disabled */
+ private def createSparkPlan(spark: SparkSession, sql: String): SparkPlan = {
+ var sparkPlan: SparkPlan = null
+ withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+ val df = spark.sql(sql)
+ sparkPlan = df.queryExecution.executedPlan
+ }
+ sparkPlan
+ }
+
+ /** Count occurrences of the specified operator in the plan */
+ private def countOperators(plan: SparkPlan, opClass: Class[_]): Int = {
+ stripAQEPlan(plan).collect {
+ case stage: QueryStageExec =>
+ countOperators(stage.plan, opClass)
+ case op if op.getClass.isAssignableFrom(opClass) => 1
+ }.sum
+ }
+
+ test("CometExecRule should replace FileSourceScanExec, but only when Comet
is enabled") {
+ withTempPath { path =>
+ createTestDataFrame.write.parquet(path.toString)
+ withTempView("test_data") {
+ spark.read.parquet(path.toString).createOrReplaceTempView("test_data")
+
+ val sparkPlan =
+ createSparkPlan(spark, "SELECT id, id * 2 as doubled FROM test_data
WHERE id % 2 == 0")
+
+ // Count original Spark operators
+ assert(countOperators(sparkPlan, classOf[FileSourceScanExec]) == 1)
+
+ for (cometEnabled <- Seq(true, false)) {
+ withSQLConf(CometConf.COMET_ENABLED.key -> cometEnabled.toString) {
+
+ val transformedPlan = applyCometScanRule(sparkPlan)
+
+ if (cometEnabled) {
+ assert(countOperators(transformedPlan, classOf[FileSourceScanExec]) == 0)
+ assert(countOperators(transformedPlan, classOf[CometScanExec]) == 1)
+ } else {
+ assert(countOperators(transformedPlan, classOf[FileSourceScanExec]) == 1)
+ assert(countOperators(transformedPlan, classOf[CometScanExec]) == 0)
+ }
+ }
+ }
+ }
+ }
+ }
+
+ test("CometExecRule should replace BatchScanExec, but only when Comet is
enabled") {
+ withTempPath { path =>
+ createTestDataFrame.write.parquet(path.toString)
+ withTempView("test_data") {
+ withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
+ spark.read.parquet(path.toString).createOrReplaceTempView("test_data")
+
+ val sparkPlan =
+ createSparkPlan(
+ spark,
+ "SELECT id, id * 2 as doubled FROM test_data WHERE id % 2 == 0")
+
+ // Count original Spark operators
+ assert(countOperators(sparkPlan, classOf[BatchScanExec]) == 1)
+
+ for (cometEnabled <- Seq(true, false)) {
+ withSQLConf(CometConf.COMET_ENABLED.key -> cometEnabled.toString) {
+
+ val transformedPlan = applyCometScanRule(sparkPlan)
+
+ if (cometEnabled) {
+ assert(countOperators(transformedPlan, classOf[BatchScanExec]) == 0)
+ assert(countOperators(transformedPlan, classOf[CometBatchScanExec]) == 1)
+
+ // CometScanRule should have replaced the underlying scan
+ val scan = transformedPlan.collect { case scan: CometBatchScanExec => scan }.head
+ assert(scan.wrapped.scan.isInstanceOf[CometParquetScan])
+
+ } else {
+ assert(countOperators(transformedPlan, classOf[BatchScanExec]) == 1)
+ assert(countOperators(transformedPlan, classOf[CometBatchScanExec]) == 0)
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ test("CometScanRule should fallback to Spark for unsupported data types in
v1 scan") {
+ withTempPath { path =>
+ // Create test data with a type (ByteType) that is unsupported in some scan modes
+ import org.apache.spark.sql.types._
+ val unsupportedSchema = new StructType(
+ Array(
+ StructField("id", DataTypes.IntegerType, nullable = true),
+ StructField(
+ "value",
+ DataTypes.ByteType,
+ nullable = true
+ ), // Unsupported in some scan modes
+ StructField("name", DataTypes.StringType, nullable = true)))
+
+ val testData = Seq(Row(1, 1.toByte, "test1"), Row(2, -1.toByte, "test2"))
+
+ val df = spark.createDataFrame(spark.sparkContext.parallelize(testData), unsupportedSchema)
+ df.write.parquet(path.toString)
+
+ withTempView("unsupported_data") {
+ spark.read.parquet(path.toString).createOrReplaceTempView("unsupported_data")
+
+ val sparkPlan =
+ createSparkPlan(spark, "SELECT id, value FROM unsupported_data WHERE
id = 1")
+
+ withSQLConf(
+ CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_ICEBERG_COMPAT,
+ CometConf.COMET_SCAN_ALLOW_INCOMPATIBLE.key -> "false") {
+ val transformedPlan = applyCometScanRule(sparkPlan)
+
+ // Should fall back to Spark due to unsupported ByteType in schema
+ assert(countOperators(transformedPlan, classOf[FileSourceScanExec]) == 1)
+ assert(countOperators(transformedPlan, classOf[CometScanExec]) == 0)
+ }
+ }
+ }
+ }
+
+}
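
To run the new suite locally, a minimal sketch: this assumes the scalatest-maven-plugin wiring used for the repository's other suites, where the plugin's standard `suites` property selects a suite by fully qualified name (the `spark` module name is taken from the file path in this diff; the exact invocation is an assumption, not part of this commit):

    mvn test -pl spark -Dsuites=org.apache.comet.rules.CometScanRuleSuite

This mirrors what the updated CI workflows do above, which pass the suite name in their per-job suite lists.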
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]