This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 2c299d572 chore: Various improvements to `checkSparkAnswer*` methods 
in `CometTestBase` (#2656)
2c299d572 is described below

commit 2c299d5728ac925a801617198878da397247d946
Author: Andy Grove <[email protected]>
AuthorDate: Wed Oct 29 10:04:03 2025 -0600

    chore: Various improvements to `checkSparkAnswer*` methods in 
`CometTestBase` (#2656)
---
 .../main/scala/org/apache/comet/CometConf.scala    |   7 +
 docs/source/user-guide/latest/compatibility.md     |   2 +-
 docs/source/user-guide/latest/configs.md           |   2 +
 .../org/apache/comet/ExtendedExplainInfo.scala     |   4 +-
 .../org/apache/comet/rules/CometExecRule.scala     |   2 +-
 .../apache/comet/CometArrayExpressionSuite.scala   |   7 +-
 .../org/apache/comet/CometExpressionSuite.scala    |  48 ++-
 .../apache/comet/CometStringExpressionSuite.scala  |   4 +-
 .../apache/comet/exec/CometAggregateSuite.scala    |   6 +-
 .../org/apache/comet/exec/CometExecSuite.scala     |  16 +-
 .../scala/org/apache/spark/sql/CometTestBase.scala | 382 +++++++++++++--------
 11 files changed, 301 insertions(+), 179 deletions(-)

diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala 
b/common/src/main/scala/org/apache/comet/CometConf.scala
index 7b55932c7..e3ff30eb4 100644
--- a/common/src/main/scala/org/apache/comet/CometConf.scala
+++ b/common/src/main/scala/org/apache/comet/CometConf.scala
@@ -702,6 +702,13 @@ object CometConf extends ShimCometConf {
       .bytesConf(ByteUnit.BYTE)
       .createWithDefault(100L * 1024 * 1024 * 1024) // 100 GB
 
+  val COMET_STRICT_TESTING: ConfigEntry[Boolean] = 
conf(s"$COMET_PREFIX.testing.strict")
+    .category(CATEGORY_TESTING)
+    .doc("Experimental option to enable strict testing, which will fail tests 
that could be " +
+      "more comprehensive, such as checking for a specific fallback reason")
+    .booleanConf
+    .createWithDefault(sys.env.getOrElse("ENABLE_COMET_STRICT_TESTING", 
"false").toBoolean)
+
   /** Create a config to enable a specific operator */
   private def createExecEnabledConfig(
       exec: String,
diff --git a/docs/source/user-guide/latest/compatibility.md 
b/docs/source/user-guide/latest/compatibility.md
index 6c3bab59d..ac2be802d 100644
--- a/docs/source/user-guide/latest/compatibility.md
+++ b/docs/source/user-guide/latest/compatibility.md
@@ -97,7 +97,7 @@ because they are handled well in Spark (e.g., 
`SQLOrderingUtil.compareFloats`).
 functions of arrow-rs used by DataFusion do not normalize NaN and zero (e.g., 
[arrow::compute::kernels::cmp::eq](https://docs.rs/arrow/latest/arrow/compute/kernels/cmp/fn.eq.html#)).
 So Comet will add additional normalization expression of NaN and zero for 
comparison.
 
-Sorting on floating-point data types (or complex types containing 
floating-point values) is not compatible with 
+Sorting on floating-point data types (or complex types containing 
floating-point values) is not compatible with
 Spark if the data contains both zero and negative zero. This is likely an edge 
case that is not of concern for many users
 and sorting on floating-point data can be enabled by setting 
`spark.comet.expression.SortOrder.allowIncompatible=true`.
 
diff --git a/docs/source/user-guide/latest/configs.md 
b/docs/source/user-guide/latest/configs.md
index bc8fdcb98..77bdaf1cc 100644
--- a/docs/source/user-guide/latest/configs.md
+++ b/docs/source/user-guide/latest/configs.md
@@ -130,6 +130,7 @@ These settings can be used to determine which parts of the 
plan are accelerated
 | `spark.comet.exec.onHeap.enabled` | Whether to allow Comet to run in on-heap 
mode. Required for running Spark SQL tests. | false |
 | `spark.comet.exec.onHeap.memoryPool` | The type of memory pool to be used 
for Comet native execution when running Spark in on-heap mode. Available pool 
types are `greedy`, `fair_spill`, `greedy_task_shared`, 
`fair_spill_task_shared`, `greedy_global`, `fair_spill_global`, and 
`unbounded`. | greedy_task_shared |
 | `spark.comet.memoryOverhead` | The amount of additional memory to be 
allocated per executor process for Comet, in MiB, when running Spark in on-heap 
mode. | 1024 MiB |
+| `spark.comet.testing.strict` | Experimental option to enable strict testing, 
which will fail tests that could be more comprehensive, such as checking for a 
specific fallback reason | false |
 <!--END:CONFIG_TABLE-->
 
 ## Enabling or Disabling Individual Operators
@@ -274,6 +275,7 @@ These settings can be used to determine which parts of the 
plan are accelerated
 | `spark.comet.expression.ShiftRight.enabled` | Enable Comet acceleration for 
`ShiftRight` | true |
 | `spark.comet.expression.Signum.enabled` | Enable Comet acceleration for 
`Signum` | true |
 | `spark.comet.expression.Sin.enabled` | Enable Comet acceleration for `Sin` | 
true |
+| `spark.comet.expression.SortOrder.enabled` | Enable Comet acceleration for 
`SortOrder` | true |
 | `spark.comet.expression.SparkPartitionID.enabled` | Enable Comet 
acceleration for `SparkPartitionID` | true |
 | `spark.comet.expression.Sqrt.enabled` | Enable Comet acceleration for `Sqrt` 
| true |
 | `spark.comet.expression.StartsWith.enabled` | Enable Comet acceleration for 
`StartsWith` | true |
diff --git a/spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala 
b/spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala
index 1e514956d..b3f6579bd 100644
--- a/spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala
+++ b/spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala
@@ -38,12 +38,12 @@ class ExtendedExplainInfo extends ExtendedExplainGenerator {
     if (CometConf.COMET_EXPLAIN_VERBOSE_ENABLED.get()) {
       generateVerboseExtendedInfo(plan)
     } else {
-      val info = extensionInfo(plan)
+      val info = getFallbackReasons(plan)
       info.toSeq.sorted.mkString("\n").trim
     }
   }
 
-  private[comet] def extensionInfo(node: TreeNode[_]): Set[String] = {
+  def getFallbackReasons(node: TreeNode[_]): Set[String] = {
     var info = mutable.Seq[String]()
     val sorted = sortup(node)
     sorted.foreach { p =>
diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala 
b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
index 7a2f8a04f..47e21b24c 100644
--- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
+++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
@@ -649,7 +649,7 @@ case class CometExecRule(session: SparkSession) extends 
Rule[SparkPlan] {
       // config is enabled)
       if (CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.get()) {
         val info = new ExtendedExplainInfo()
-        if (info.extensionInfo(newPlan).nonEmpty) {
+        if (info.getFallbackReasons(newPlan).nonEmpty) {
           logWarning(
             "Comet cannot execute some parts of this plan natively " +
               s"(set ${CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key}=false " +
diff --git 
a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala 
b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
index c346dc2e9..3239bc020 100644
--- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
@@ -128,12 +128,9 @@ class CometArrayExpressionSuite extends CometTestBase with 
AdaptiveSparkPlanHelp
           .createOrReplaceTempView("t2")
         val expectedFallbackReasons = HashSet(
           "data type not supported: 
ArrayType(StructType(StructField(_1,BooleanType,true),StructField(_2,ByteType,true)),false)")
-        // note that checkExtended is disabled here due to an unrelated issue
-        // https://github.com/apache/datafusion-comet/issues/1313
-        checkSparkAnswerAndCompareExplainPlan(
+        checkSparkAnswerAndFallbackReasons(
           sql("SELECT array_remove(a, b) FROM t2"),
-          expectedFallbackReasons,
-          checkExplainString = false)
+          expectedFallbackReasons)
       }
     }
   }
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala 
b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 9aecf55b3..c36d4979f 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -846,10 +846,16 @@ class CometExpressionSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
                 "SECOND",
                 "MILLISECOND",
                 "MICROSECOND").foreach { format =>
-                checkSparkAnswer(
-                  "SELECT " +
-                    s"date_trunc('$format', ts )" +
-                    " from int96timetbl")
+                val sql = "SELECT " +
+                  s"date_trunc('$format', ts )" +
+                  " from int96timetbl"
+
+                if (conversionEnabled) {
+                  // plugin is disabled if PARQUET_INT96_TIMESTAMP_CONVERSION 
is true
+                  checkSparkAnswer(sql)
+                } else {
+                  checkSparkAnswerAndOperator(sql)
+                }
               }
             }
           }
@@ -978,7 +984,7 @@ class CometExpressionSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
       sql(s"create table $table(id int, name varchar(20)) using parquet")
       sql(s"insert into $table values(1,'James Smith')")
       val query = sql(s"select cast(id as string) from $table")
-      val (_, cometPlan) = checkSparkAnswer(query)
+      val (_, cometPlan) = checkSparkAnswerAndOperator(query)
       val project = cometPlan
         .asInstanceOf[WholeStageCodegenExec]
         .child
@@ -1343,17 +1349,19 @@ class CometExpressionSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
           "sin",
           "sqrt",
           "tan")) {
-        val df = checkSparkAnswerWithTol(s"SELECT $expr(_1), $expr(_2) FROM 
tbl")
-        val cometProjectExecs = collect(df.queryExecution.executedPlan) {
-          case op: CometProjectExec => op
+        val (_, cometPlan) =
+          checkSparkAnswerAndOperatorWithTol(sql(s"SELECT $expr(_1), $expr(_2) 
FROM tbl"))
+        val cometProjectExecs = collect(cometPlan) { case op: CometProjectExec 
=>
+          op
         }
         assert(cometProjectExecs.length == 1, expr)
       }
       // expressions with two args
       for (expr <- Seq("atan2", "pow")) {
-        val df = checkSparkAnswerWithTol(s"SELECT $expr(_1, _2) FROM tbl")
-        val cometProjectExecs = collect(df.queryExecution.executedPlan) {
-          case op: CometProjectExec => op
+        val (_, cometPlan) =
+          checkSparkAnswerAndOperatorWithTol(sql(s"SELECT $expr(_1, _2) FROM 
tbl"))
+        val cometProjectExecs = collect(cometPlan) { case op: CometProjectExec 
=>
+          op
         }
         assert(cometProjectExecs.length == 1, expr)
       }
@@ -1364,8 +1372,8 @@ class CometExpressionSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
     val testValuesRepeated = doubleValues.flatMap(v => Seq.fill(1000)(v))
     for (withDictionary <- Seq(true, false)) {
       withParquetTable(testValuesRepeated.map(n => (n, n)), "tbl", 
withDictionary) {
-        val df = checkSparkAnswerWithTol(s"SELECT $expr(_1) FROM tbl")
-        val projections = collect(df.queryExecution.executedPlan) { case p: 
CometProjectExec =>
+        val (_, cometPlan) = checkSparkAnswerAndOperatorWithTol(sql(s"SELECT 
$expr(_1) FROM tbl"))
+        val projections = collect(cometPlan) { case p: CometProjectExec =>
           p
         }
         assert(projections.length == 1)
@@ -1381,10 +1389,12 @@ class CometExpressionSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
     }
     withParquetTable(Seq(0, 1, 2).map(n => (n, n)), "tbl") {
       val sql = "select _1+_2 from tbl"
-      val (_, cometPlan) = checkSparkAnswer(sql)
+      val (_, cometPlan) = checkSparkAnswerAndOperator(sql)
       assert(0 == countSparkProjectExec(cometPlan))
       withSQLConf(CometConf.getExprEnabledConfigKey("Add") -> "false") {
-        val (_, cometPlan) = checkSparkAnswer(sql)
+        val (_, cometPlan) = checkSparkAnswerAndFallbackReason(
+          sql,
+          "Expression support is disabled. Set 
spark.comet.expression.Add.enabled=true to enable it.")
         assert(1 == countSparkProjectExec(cometPlan))
       }
     }
@@ -1401,7 +1411,7 @@ class CometExpressionSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
       val (_, cometPlan) = checkSparkAnswer(sql)
       assert(1 == countSparkProjectExec(cometPlan))
       withSQLConf(CometConf.getExprAllowIncompatConfigKey("InitCap") -> 
"true") {
-        val (_, cometPlan) = checkSparkAnswer(sql)
+        val (_, cometPlan) = checkSparkAnswerAndOperator(sql)
         assert(0 == countSparkProjectExec(cometPlan))
       }
     }
@@ -1677,7 +1687,7 @@ class CometExpressionSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
               s"SELECT * FROM $table WHERE name in ('Smith', 'Brown', NULL)")
 
             // TODO: why with not in, the plan is only `LocalTableScan`?
-            checkSparkAnswer(s"SELECT * FROM $table WHERE id not in (1)")
+            checkSparkAnswerAndOperator(s"SELECT * FROM $table WHERE id not in 
(1)")
             checkSparkAnswer(s"SELECT * FROM $table WHERE name not in 
('Smith', 'Brown', NULL)")
           }
         }
@@ -2005,7 +2015,7 @@ class CometExpressionSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
             val expected = test._2
             val df = sql(qry)
             df.collect() // force an execution
-            checkSparkAnswerAndCompareExplainPlan(df, expected)
+            checkSparkAnswerAndFallbackReasons(df, expected)
           })
       }
     }
@@ -2030,7 +2040,7 @@ class CometExpressionSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
             val expected = test._2
             val df = sql(qry)
             df.collect() // force an execution
-            checkSparkAnswerAndCompareExplainPlan(df, expected)
+            checkSparkAnswerAndFallbackReasons(df, expected)
           })
       }
     }
diff --git 
a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala 
b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala
index a63aba8da..3d08c01a7 100644
--- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala
@@ -85,7 +85,7 @@ class CometStringExpressionSuite extends CometTestBase {
           if (isLiteralStr && isLiteralLen && isLiteralPad) {
             // all arguments are literal, so Spark constant folding will kick 
in
             // and pad function will not be evaluated by Comet
-            checkSparkAnswer(sql)
+            checkSparkAnswerAndOperator(sql)
           } else if (isLiteralStr) {
             checkSparkAnswerAndFallbackReason(
               sql,
@@ -135,7 +135,7 @@ class CometStringExpressionSuite extends CometTestBase {
           if (isLiteralStr && isLiteralLen && isLiteralPad) {
             // all arguments are literal, so Spark constant folding will kick 
in
             // and pad function will not be evaluated by Comet
-            checkSparkAnswer(sql)
+            checkSparkAnswerAndOperator(sql)
           } else {
             // Comet will fall back to Spark because the plan contains a 
staticinvoke instruction
             // which is not supported
diff --git 
a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala 
b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index 211cc16d0..f38dfb8ea 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -1105,7 +1105,7 @@ class CometAggregateSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
                 "SELECT _g2, AVG(_7) FROM tbl GROUP BY _g2",
                 expectedNumOfCometAggregates)
 
-              checkSparkAnswerWithTol("SELECT _g3, AVG(_8) FROM tbl GROUP BY 
_g3")
+              checkSparkAnswerWithTolerance("SELECT _g3, AVG(_8) FROM tbl 
GROUP BY _g3")
               assert(getNumCometHashAggregate(
                 sql("SELECT _g3, AVG(_8) FROM tbl GROUP BY _g3")) == 
expectedNumOfCometAggregates)
 
@@ -1117,7 +1117,7 @@ class CometAggregateSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
                 "SELECT AVG(_7) FROM tbl",
                 expectedNumOfCometAggregates)
 
-              checkSparkAnswerWithTol("SELECT AVG(_8) FROM tbl")
+              checkSparkAnswerWithTolerance("SELECT AVG(_8) FROM tbl")
               assert(getNumCometHashAggregate(
                 sql("SELECT AVG(_8) FROM tbl")) == 
expectedNumOfCometAggregates)
 
@@ -1505,7 +1505,7 @@ class CometAggregateSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
       numAggregates: Int,
       absTol: Double = 1e-6): Unit = {
     val df = sql(query)
-    checkSparkAnswerWithTol(df, absTol)
+    checkSparkAnswerWithTolerance(df, absTol)
     val actualNumAggregates = getNumCometHashAggregate(df)
     assert(
       actualNumAggregates == numAggregates,
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala 
b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index af2730b8f..56174c7fc 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -2076,16 +2076,16 @@ class CometExecSuite extends CometTestBase {
                     List(s"COUNT(_$col)", s"MAX(_$col)", s"MIN(_$col)", 
s"SUM(_$col)")
                   aggregateFunctions.foreach { function =>
                     val df1 = sql(s"SELECT $function OVER() FROM tbl")
-                    checkSparkAnswerWithTol(df1, 1e-6)
+                    checkSparkAnswerWithTolerance(df1, 1e-6)
 
                     val df2 = sql(s"SELECT $function OVER(order by _2) FROM 
tbl")
-                    checkSparkAnswerWithTol(df2, 1e-6)
+                    checkSparkAnswerWithTolerance(df2, 1e-6)
 
                     val df3 = sql(s"SELECT $function OVER(order by _2 desc) 
FROM tbl")
-                    checkSparkAnswerWithTol(df3, 1e-6)
+                    checkSparkAnswerWithTolerance(df3, 1e-6)
 
                     val df4 = sql(s"SELECT $function OVER(partition by _2 
order by _2) FROM tbl")
-                    checkSparkAnswerWithTol(df4, 1e-6)
+                    checkSparkAnswerWithTolerance(df4, 1e-6)
                   }
                 }
 
@@ -2093,16 +2093,16 @@ class CometExecSuite extends CometTestBase {
                 val aggregateFunctionsWithoutSum = List("COUNT(_12)", 
"MAX(_12)", "MIN(_12)")
                 aggregateFunctionsWithoutSum.foreach { function =>
                   val df1 = sql(s"SELECT $function OVER() FROM tbl")
-                  checkSparkAnswerWithTol(df1, 1e-6)
+                  checkSparkAnswerWithTolerance(df1, 1e-6)
 
                   val df2 = sql(s"SELECT $function OVER(order by _2) FROM tbl")
-                  checkSparkAnswerWithTol(df2, 1e-6)
+                  checkSparkAnswerWithTolerance(df2, 1e-6)
 
                   val df3 = sql(s"SELECT $function OVER(order by _2 desc) FROM 
tbl")
-                  checkSparkAnswerWithTol(df3, 1e-6)
+                  checkSparkAnswerWithTolerance(df3, 1e-6)
 
                   val df4 = sql(s"SELECT $function OVER(partition by _2 order 
by _2) FROM tbl")
-                  checkSparkAnswerWithTol(df4, 1e-6)
+                  checkSparkAnswerWithTolerance(df4, 1e-6)
                 }
               }
             }
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala 
b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index fe2edc705..1854edf59 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -24,11 +24,10 @@ import java.util.concurrent.atomic.AtomicInteger
 import scala.concurrent.duration._
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.TypeTag
-import scala.util.Try
+import scala.util.{Success, Try}
 
 import org.scalatest.BeforeAndAfterEach
 
-import org.apache.commons.lang3.StringUtils
 import org.apache.hadoop.fs.Path
 import org.apache.parquet.column.ParquetProperties
 import org.apache.parquet.example.data.Group
@@ -102,10 +101,226 @@ abstract class CometTestBase
     }
   }
 
+  protected def internalCheckSparkAnswer(
+      df: => DataFrame,
+      assertCometNative: Boolean,
+      includeClasses: Seq[Class[_]] = Seq.empty,
+      excludedClasses: Seq[Class[_]] = Seq.empty,
+      withTol: Option[Double] = None): (SparkPlan, SparkPlan) = {
+
+    var expected: Array[Row] = Array.empty
+    var sparkPlan = null.asInstanceOf[SparkPlan]
+    withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+      val dfSpark = datasetOfRows(spark, df.logicalPlan)
+      expected = dfSpark.collect()
+      sparkPlan = dfSpark.queryExecution.executedPlan
+    }
+    val dfComet = datasetOfRows(spark, df.logicalPlan)
+
+    if (withTol.isDefined) {
+      checkAnswerWithTolerance(dfComet, expected, withTol.get)
+    } else {
+      checkAnswer(dfComet, expected)
+    }
+
+    if (assertCometNative) {
+      checkCometOperators(stripAQEPlan(df.queryExecution.executedPlan), 
excludedClasses: _*)
+    } else {
+      if (CometConf.COMET_STRICT_TESTING.get()) {
+        if (findFirstNonCometOperator(
+            stripAQEPlan(df.queryExecution.executedPlan),
+            excludedClasses: _*).isEmpty) {
+          fail("Plan was fully native in Comet. Call 
checkSparkAnswerAndOperator instead.")
+        }
+      }
+    }
+
+    if (includeClasses.nonEmpty) {
+      checkPlanContains(stripAQEPlan(df.queryExecution.executedPlan), 
includeClasses: _*)
+    }
+
+    (sparkPlan, dfComet.queryExecution.executedPlan)
+  }
+
+  /**
+   * Check that the query returns the correct results when Comet is enabled, 
but do not check if
+   * Comet accelerated any operators
+   */
+  protected def checkSparkAnswer(query: String): (SparkPlan, SparkPlan) = {
+    internalCheckSparkAnswer(sql(query), assertCometNative = false)
+  }
+
+  /**
+   * Check that the query returns the correct results when Comet is enabled, 
but do not check if
+   * Comet accelerated any operators
+   */
+  protected def checkSparkAnswer(df: => DataFrame): (SparkPlan, SparkPlan) = {
+    internalCheckSparkAnswer(df, assertCometNative = false)
+  }
+
+  /**
+   * Check that the query returns the correct results when Comet is enabled, 
but do not check if
+   * Comet accelerated any operators
+   *
+   * Use the provided `absTol` when comparing floating-point results.
+   */
+  protected def checkSparkAnswerWithTolerance(
+      query: String,
+      absTol: Double = 1e-6): (SparkPlan, SparkPlan) = {
+    checkSparkAnswerWithTolerance(sql(query), absTol)
+  }
+
+  /**
+   * Check that the query returns the correct results when Comet is enabled, 
but do not check if
+   * Comet accelerated any operators
+   *
+   * Use the provided `absTol` when comparing floating-point results.
+   */
+  protected def checkSparkAnswerWithTolerance(
+      df: => DataFrame,
+      absTol: Double): (SparkPlan, SparkPlan) = {
+    internalCheckSparkAnswer(df, assertCometNative = false, withTol = 
Some(absTol))
+  }
+
+  /**
+   * Check that the query returns the correct results when Comet is enabled 
and that Comet
+   * replaced all possible operators except for those specified in the 
excluded list.
+   */
+  protected def checkSparkAnswerAndOperator(
+      query: String,
+      excludedClasses: Class[_]*): (SparkPlan, SparkPlan) = {
+    checkSparkAnswerAndOperator(sql(query), excludedClasses: _*)
+  }
+
+  /**
+   * Check that the query returns the correct results when Comet is enabled 
and that Comet
+   * replaced all possible operators except for those specified in the 
excluded list.
+   */
+  protected def checkSparkAnswerAndOperator(
+      df: => DataFrame,
+      excludedClasses: Class[_]*): (SparkPlan, SparkPlan) = {
+    internalCheckSparkAnswer(
+      df,
+      assertCometNative = true,
+      excludedClasses = Seq(excludedClasses: _*))
+  }
+
+  /**
+   * Check that the query returns the correct results when Comet is enabled 
and that Comet
+   * replaced all possible operators except for those specified in the 
excluded list.
+   *
+   * Also check that the plan included all operators specified in 
`includeClasses`.
+   */
+  protected def checkSparkAnswerAndOperator(
+      df: => DataFrame,
+      includeClasses: Seq[Class[_]],
+      excludedClasses: Class[_]*): (SparkPlan, SparkPlan) = {
+    internalCheckSparkAnswer(
+      df,
+      assertCometNative = true,
+      includeClasses,
+      excludedClasses = Seq(excludedClasses: _*))
+  }
+
+  /**
+   * Check that the query returns the correct results when Comet is enabled 
and that Comet
+   * replaced all possible operators except for those specified in the 
excluded list.
+   *
+   * No specific operators are required in the plan (`includeClasses` is empty).
+   *
+   * Use the provided `tol` when comparing floating-point results.
+   */
+  protected def checkSparkAnswerAndOperatorWithTol(
+      df: => DataFrame,
+      tol: Double = 1e-6): (SparkPlan, SparkPlan) = {
+    checkSparkAnswerAndOperatorWithTol(df, tol, Seq.empty)
+  }
+
+  /**
+   * Check that the query returns the correct results when Comet is enabled 
and that Comet
+   * replaced all possible operators except for those specified in the 
excluded list.
+   *
+   * Also check that the plan included all operators specified in 
`includeClasses`.
+   *
+   * Use the provided `tol` when comparing floating-point results.
+   */
+  protected def checkSparkAnswerAndOperatorWithTol(
+      df: => DataFrame,
+      tol: Double,
+      includeClasses: Seq[Class[_]],
+      excludedClasses: Class[_]*): (SparkPlan, SparkPlan) = {
+    internalCheckSparkAnswer(
+      df,
+      assertCometNative = true,
+      includeClasses = Seq(includeClasses: _*),
+      excludedClasses = Seq(excludedClasses: _*),
+      withTol = Some(tol))
+  }
+
+  /** Check for the correct results as well as the expected fallback reason */
+  protected def checkSparkAnswerAndFallbackReason(
+      query: String,
+      fallbackReason: String): (SparkPlan, SparkPlan) = {
+    checkSparkAnswerAndFallbackReasons(sql(query), Set(fallbackReason))
+  }
+
+  /** Check for the correct results as well as the expected fallback reason */
+  protected def checkSparkAnswerAndFallbackReason(
+      df: => DataFrame,
+      fallbackReason: String): (SparkPlan, SparkPlan) = {
+    checkSparkAnswerAndFallbackReasons(df, Set(fallbackReason))
+  }
+
+  /** Check for the correct results as well as the expected fallback reasons */
+  protected def checkSparkAnswerAndFallbackReasons(
+      df: => DataFrame,
+      fallbackReasons: Set[String]): (SparkPlan, SparkPlan) = {
+    val (sparkPlan, cometPlan) = internalCheckSparkAnswer(df, 
assertCometNative = false)
+    val explainInfo = new ExtendedExplainInfo()
+    val actualFallbacks = explainInfo.getFallbackReasons(cometPlan)
+    for (reason <- fallbackReasons) {
+      if (!actualFallbacks.exists(_.contains(reason))) {
+        if (actualFallbacks.isEmpty) {
+          fail(
+            s"Expected fallback reason '$reason' but no fallback reasons were 
found. Explain: ${explainInfo
+                .generateVerboseExtendedInfo(cometPlan)}")
+        } else {
+          fail(
+            s"Expected fallback reason '$reason' not found in 
[${actualFallbacks.mkString(", ")}]")
+        }
+      }
+    }
+    (sparkPlan, cometPlan)
+  }
+
+  /**
+   * Try executing the query against Spark and Comet and return the results or 
the exception.
+   *
+   * This method does not check that Comet replaced any operators or that the 
results match in the
+   * case where the query is successful against both Spark and Comet.
+   */
+  protected def checkSparkMaybeThrows(
+      df: => DataFrame): (Option[Throwable], Option[Throwable]) = {
+    var expected: Try[Array[Row]] = null
+    withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+      expected = Try(datasetOfRows(spark, df.logicalPlan).collect())
+    }
+    val actual = Try(datasetOfRows(spark, df.logicalPlan).collect())
+
+    (expected, actual) match {
+      case (Success(_), Success(_)) =>
+        // TODO compare results and confirm that they match
+        // https://github.com/apache/datafusion-comet/issues/2657
+        (None, None)
+      case _ =>
+        (expected.failed.toOption, actual.failed.toOption)
+    }
+  }
+
   /**
    * A helper function for comparing Comet DataFrame with Spark result using 
absolute tolerance.
    */
-  protected def checkAnswerWithTol(
+  private def checkAnswerWithTolerance(
       dataFrame: DataFrame,
       expectedAnswer: Seq[Row],
       absTol: Double): Unit = {
@@ -115,14 +330,14 @@ abstract class CometTestBase
       s"actual num rows ${actualAnswer.length} != expected num of rows 
${expectedAnswer.length}")
 
     actualAnswer.zip(expectedAnswer).foreach { case (actualRow, expectedRow) =>
-      checkAnswerWithTol(actualRow, expectedRow, absTol)
+      checkAnswerWithTolerance(actualRow, expectedRow, absTol)
     }
   }
 
   /**
    * Compares two answers and makes sure the answer is within absTol of the 
expected result.
    */
-  protected def checkAnswerWithTol(
+  private def checkAnswerWithTolerance(
       actualAnswer: Row,
       expectedAnswer: Row,
       absTol: Double): Unit = {
@@ -133,6 +348,14 @@ abstract class CometTestBase
     require(absTol > 0 && absTol <= 1e-6, s"absTol $absTol is out of range (0, 
1e-6]")
 
     actualAnswer.toSeq.zip(expectedAnswer.toSeq).foreach {
+      case (actual: Float, expected: Float) =>
+        if (actual.isInfinity || expected.isInfinity) {
+          assert(actual.isInfinity == expected.isInfinity, s"actual answer 
$actual != $expected")
+        } else if (!actual.isNaN && !expected.isNaN) {
+          assert(
+            math.abs(actual - expected) < absTol,
+            s"actual answer $actual not within $absTol of correct answer 
$expected")
+        }
       case (actual: Double, expected: Double) =>
         if (actual.isInfinity || expected.isInfinity) {
           assert(actual.isInfinity == expected.isInfinity, s"actual answer 
$actual != $expected")
@@ -146,71 +369,20 @@ abstract class CometTestBase
     }
   }
 
-  protected def checkSparkAnswer(query: String): (SparkPlan, SparkPlan) = {
-    checkSparkAnswer(sql(query))
-  }
-
-  /**
-   * Check the answer of a Comet SQL query with Spark result.
-   * @param df
-   *   The DataFrame of the query.
-   * @return
-   *   A tuple of the SparkPlan of the query and the SparkPlan of the Comet 
query.
-   */
-  protected def checkSparkAnswer(df: => DataFrame): (SparkPlan, SparkPlan) = {
-    var expected: Array[Row] = Array.empty
-    var sparkPlan = null.asInstanceOf[SparkPlan]
-    withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
-      val dfSpark = datasetOfRows(spark, df.logicalPlan)
-      expected = dfSpark.collect()
-      sparkPlan = dfSpark.queryExecution.executedPlan
+  protected def checkCometOperators(plan: SparkPlan, excludedClasses: 
Class[_]*): Unit = {
+    findFirstNonCometOperator(plan, excludedClasses: _*) match {
+      case Some(op) =>
+        assert(
+          false,
+          s"Expected only Comet native operators, but found ${op.nodeName}.\n" 
+
+            s"plan: ${new 
ExtendedExplainInfo().generateVerboseExtendedInfo(plan)}")
+      case _ =>
     }
-    val dfComet = datasetOfRows(spark, df.logicalPlan)
-    checkAnswer(dfComet, expected)
-    (sparkPlan, dfComet.queryExecution.executedPlan)
   }
 
-  /** Check for the correct results as well as the expected fallback reason */
-  def checkSparkAnswerAndFallbackReason(sql: String, fallbackReason: String): 
Unit = {
-    val (_, cometPlan) = checkSparkAnswer(sql)
-    val explain = new 
ExtendedExplainInfo().generateVerboseExtendedInfo(cometPlan)
-    assert(explain.contains(fallbackReason))
-  }
-
-  protected def checkSparkAnswerAndOperator(query: String, excludedClasses: 
Class[_]*): Unit = {
-    checkSparkAnswerAndOperator(sql(query), excludedClasses: _*)
-  }
-
-  protected def checkSparkAnswerAndOperator(
-      df: => DataFrame,
-      excludedClasses: Class[_]*): Unit = {
-    checkSparkAnswerAndOperator(df, Seq.empty, excludedClasses: _*)
-  }
-
-  protected def checkSparkAnswerAndOperator(
-      df: => DataFrame,
-      includeClasses: Seq[Class[_]],
-      excludedClasses: Class[_]*): Unit = {
-    checkCometOperators(stripAQEPlan(df.queryExecution.executedPlan), 
excludedClasses: _*)
-    checkPlanContains(stripAQEPlan(df.queryExecution.executedPlan), 
includeClasses: _*)
-    checkSparkAnswer(df)
-  }
-
-  protected def checkSparkAnswerAndOperatorWithTol(df: => DataFrame, tol: 
Double = 1e-6): Unit = {
-    checkSparkAnswerAndOperatorWithTol(df, tol, Seq.empty)
-  }
-
-  protected def checkSparkAnswerAndOperatorWithTol(
-      df: => DataFrame,
-      tol: Double,
-      includeClasses: Seq[Class[_]],
-      excludedClasses: Class[_]*): Unit = {
-    checkCometOperators(stripAQEPlan(df.queryExecution.executedPlan), 
excludedClasses: _*)
-    checkPlanContains(stripAQEPlan(df.queryExecution.executedPlan), 
includeClasses: _*)
-    checkSparkAnswerWithTol(df, tol)
-  }
-
-  protected def checkCometOperators(plan: SparkPlan, excludedClasses: 
Class[_]*): Unit = {
+  protected def findFirstNonCometOperator(
+      plan: SparkPlan,
+      excludedClasses: Class[_]*): Option[SparkPlan] = {
     val wrapped = wrapCometSparkToColumnar(plan)
     wrapped.foreach {
       case _: CometNativeScanExec | _: CometScanExec | _: CometBatchScanExec =>
@@ -220,17 +392,14 @@ abstract class CometTestBase
       case _: CometExec | _: CometShuffleExchangeExec =>
       case _: CometBroadcastExchangeExec =>
       case _: WholeStageCodegenExec | _: ColumnarToRowExec | _: InputAdapter =>
-      case op =>
-        if (!excludedClasses.exists(c => c.isAssignableFrom(op.getClass))) {
-          assert(
-            false,
-            s"Expected only Comet native operators, but found 
${op.nodeName}.\n" +
-              s"plan: ${new 
ExtendedExplainInfo().generateVerboseExtendedInfo(plan)}")
-        }
+      case op if !excludedClasses.exists(c => c.isAssignableFrom(op.getClass)) 
=>
+        return Some(op)
+      case _ =>
     }
+    None
   }
 
-  protected def checkPlanContains(plan: SparkPlan, includePlans: Class[_]*): 
Unit = {
+  private def checkPlanContains(plan: SparkPlan, includePlans: Class[_]*): 
Unit = {
     includePlans.foreach { case planClass =>
       if (plan.find(op => planClass.isAssignableFrom(op.getClass)).isEmpty) {
         assert(
@@ -249,69 +418,6 @@ abstract class CometTestBase
     }
   }
 
-  /**
-   * Check the answer of a Comet SQL query with Spark result using absolute 
tolerance.
-   */
-  protected def checkSparkAnswerWithTol(query: String, absTol: Double = 1e-6): 
DataFrame = {
-    checkSparkAnswerWithTol(sql(query), absTol)
-  }
-
-  /**
-   * Check the answer of a Comet DataFrame with Spark result using absolute 
tolerance.
-   */
-  protected def checkSparkAnswerWithTol(df: => DataFrame, absTol: Double): 
DataFrame = {
-    var expected: Array[Row] = Array.empty
-    withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
-      val dfSpark = datasetOfRows(spark, df.logicalPlan)
-      expected = dfSpark.collect()
-    }
-    val dfComet = datasetOfRows(spark, df.logicalPlan)
-    checkAnswerWithTol(dfComet, expected, absTol: Double)
-    dfComet
-  }
-
-  protected def checkSparkMaybeThrows(
-      df: => DataFrame): (Option[Throwable], Option[Throwable]) = {
-    var expected: Option[Throwable] = None
-    withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
-      expected = Try(datasetOfRows(spark, 
df.logicalPlan).collect()).failed.toOption
-    }
-    val actual = Try(datasetOfRows(spark, 
df.logicalPlan).collect()).failed.toOption
-    (expected, actual)
-  }
-
-  protected def checkSparkAnswerAndCompareExplainPlan(
-      df: DataFrame,
-      expectedInfo: Set[String],
-      checkExplainString: Boolean = true): Unit = {
-    var expected: Array[Row] = Array.empty
-    var dfSpark: Dataset[Row] = null
-    withSQLConf(CometConf.COMET_ENABLED.key -> "false", 
EXTENDED_EXPLAIN_PROVIDERS_KEY -> "") {
-      dfSpark = datasetOfRows(spark, df.logicalPlan)
-      expected = dfSpark.collect()
-    }
-    val dfComet = datasetOfRows(spark, df.logicalPlan)
-    checkAnswer(dfComet, expected)
-    if (checkExplainString) {
-      val diff = StringUtils.difference(
-        dfSpark.queryExecution.explainString(ExtendedMode),
-        dfComet.queryExecution.explainString(ExtendedMode))
-      if (supportsExtendedExplainInfo(dfSpark.queryExecution)) {
-        for (info <- expectedInfo) {
-          if (!diff.contains(info)) {
-            fail(s"Extended explain diff did not contain [$info]. Diff: 
$diff.")
-          }
-        }
-      }
-    }
-    val extendedInfo =
-      new 
ExtendedExplainInfo().generateExtendedInfo(dfComet.queryExecution.executedPlan)
-    val expectedStr = expectedInfo.toSeq.sorted.mkString("\n")
-    if (!extendedInfo.equalsIgnoreCase(expectedStr)) {
-      fail(s"$extendedInfo != $expectedStr (case-insensitive comparison)")
-    }
-  }
-
   private var _spark: SparkSessionType = _
   override protected implicit def spark: SparkSessionType = _spark
   protected implicit def sqlContext: SQLContext = _spark.sqlContext


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to