This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new e77998ad2 fix: checkSparkMaybeThrows should compare Spark and Comet 
results in success case (#2728)
e77998ad2 is described below

commit e77998ad2228b758d0a93588c57325e56bda83bc
Author: Andy Grove <[email protected]>
AuthorDate: Mon Nov 10 18:42:04 2025 -0700

    fix: checkSparkMaybeThrows should compare Spark and Comet results in 
success case (#2728)
---
 .../org/apache/comet/expressions/CometCast.scala   |  2 ++
 .../apache/comet/CometBitwiseExpressionSuite.scala |  2 +-
 .../scala/org/apache/comet/CometCastSuite.scala    | 14 ++++++-------
 .../org/apache/comet/CometExpressionSuite.scala    | 24 +++++++++++-----------
 .../apache/comet/CometMathExpressionSuite.scala    |  2 +-
 .../scala/org/apache/spark/sql/CometTestBase.scala |  6 +++---
 6 files changed, 26 insertions(+), 24 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala 
b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index 80166fe7b..ee7216686 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -116,6 +116,8 @@ object CometCast extends CometExpressionSerde[Cast] with 
CometExprShim {
 
     (fromType, toType) match {
       case (dt: ArrayType, _: ArrayType) if dt.elementType == NullType => 
Compatible()
+      case (dt: ArrayType, DataTypes.StringType) if dt.elementType == 
DataTypes.BinaryType =>
+        Incompatible()
       case (dt: ArrayType, DataTypes.StringType) =>
         isSupported(dt.elementType, DataTypes.StringType, timeZoneId, evalMode)
       case (dt: ArrayType, dt1: ArrayType) =>
diff --git 
a/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala 
b/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala
index 02c003ede..99a57b157 100644
--- a/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala
@@ -76,7 +76,7 @@ class CometBitwiseExpressionSuite extends CometTestBase with 
AdaptiveSparkPlanHe
 
   test("bitwise_get - throws exceptions") {
     def checkSparkAndCometEqualThrows(query: String): Unit = {
-      checkSparkMaybeThrows(sql(query)) match {
+      checkSparkAnswerMaybeThrows(sql(query)) match {
         case (Some(sparkExc), Some(cometExc)) =>
           assert(sparkExc.getMessage == cometExc.getMessage)
         case _ => fail("Exception should be thrown")
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala 
b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index f047fc052..6e9b4ec70 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, 
ByteType, DataType, DataTypes, DecimalType, IntegerType, LongType, ShortType, 
StringType, StructField, StructType}
+import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, 
DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType, 
StructField, StructType}
 
 import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus
 import org.apache.comet.expressions.{CometCast, CometEvalMode}
@@ -1035,7 +1035,7 @@ class CometCastSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
 
   test("cast between decimals with negative precision") {
     // cast to negative scale
-    checkSparkMaybeThrows(
+    checkSparkAnswerMaybeThrows(
       spark.sql("select a, cast(a as DECIMAL(10,-4)) from t order by a")) 
match {
       case (expected, actual) =>
         assert(expected.contains("PARSE_SYNTAX_ERROR") === 
actual.contains("PARSE_SYNTAX_ERROR"))
@@ -1062,11 +1062,11 @@ class CometCastSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
       IntegerType,
       LongType,
       ShortType,
-      //      FloatType,
-      //      DoubleType,
+      // FloatType,
+      // DoubleType,
+      // BinaryType
       DecimalType(10, 2),
-      DecimalType(38, 18),
-      BinaryType).foreach { dt =>
+      DecimalType(38, 18)).foreach { dt =>
       val input = generateArrays(100, dt)
       castTest(input, StringType, hasIncompatibleType = 
hasIncompatibleType(input.schema))
     }
@@ -1272,7 +1272,7 @@ class CometCastSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
 
           // cast() should throw exception on invalid inputs when ansi mode is 
enabled
           val df = data.withColumn("converted", col("a").cast(toType))
-          checkSparkMaybeThrows(df) match {
+          checkSparkAnswerMaybeThrows(df) match {
             case (None, None) =>
             // neither system threw an exception
             case (None, Some(e)) =>
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala 
b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index e4fb08101..fc45d2cb3 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -312,7 +312,7 @@ class CometExpressionSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
         makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = 
dictionaryEnabled, 10000)
         withParquetTable(path.toString, "tbl") {
           val (sparkErr, cometErr) =
-            checkSparkMaybeThrows(sql(s"SELECT _20 + ${Int.MaxValue} FROM 
tbl"))
+            checkSparkAnswerMaybeThrows(sql(s"SELECT _20 + ${Int.MaxValue} 
FROM tbl"))
           if (isSpark40Plus) {
             
assert(sparkErr.get.getMessage.contains("EXPRESSION_DECODING_FAILED"))
           } else {
@@ -359,7 +359,7 @@ class CometExpressionSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
         makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = 
dictionaryEnabled, 10000)
         withParquetTable(path.toString, "tbl") {
           val (sparkErr, cometErr) =
-            checkSparkMaybeThrows(sql(s"SELECT _20 - ${Int.MaxValue} FROM 
tbl"))
+            checkSparkAnswerMaybeThrows(sql(s"SELECT _20 - ${Int.MaxValue} 
FROM tbl"))
           if (isSpark40Plus) {
             
assert(sparkErr.get.getMessage.contains("EXPRESSION_DECODING_FAILED"))
             
assert(cometErr.get.getMessage.contains("EXPRESSION_DECODING_FAILED"))
@@ -2022,7 +2022,7 @@ class CometExpressionSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
       val expectedDivideByZeroError =
         "[DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate 
divisor being 0 and return NULL instead."
 
-      checkSparkMaybeThrows(sql(query)) match {
+      checkSparkAnswerMaybeThrows(sql(query)) match {
         case (Some(sparkException), Some(cometException)) =>
           assert(sparkException.getMessage.contains(expectedDivideByZeroError))
           assert(cometException.getMessage.contains(expectedDivideByZeroError))
@@ -2174,7 +2174,7 @@ class CometExpressionSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
     }
 
     def checkOverflow(query: String, dtype: String): Unit = {
-      checkSparkMaybeThrows(sql(query)) match {
+      checkSparkAnswerMaybeThrows(sql(query)) match {
         case (Some(sparkException), Some(cometException)) =>
           assert(sparkException.getMessage.contains(dtype + " overflow"))
           assert(cometException.getMessage.contains(dtype + " overflow"))
@@ -2700,7 +2700,7 @@ class CometExpressionSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
 
   test("ListExtract") {
     def assertBothThrow(df: DataFrame): Unit = {
-      checkSparkMaybeThrows(df) match {
+      checkSparkAnswerMaybeThrows(df) match {
         case (Some(_), Some(_)) => ()
         case (spark, comet) =>
           fail(
@@ -2850,7 +2850,7 @@ class CometExpressionSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
                               |  from tbl
                               |  """.stripMargin)
 
-        checkSparkMaybeThrows(res) match {
+        checkSparkAnswerMaybeThrows(res) match {
           case (Some(sparkExc), Some(cometExc)) =>
             
assert(cometExc.getMessage.contains(ARITHMETIC_OVERFLOW_EXCEPTION_MSG))
             assert(sparkExc.getMessage.contains("overflow"))
@@ -2869,7 +2869,7 @@ class CometExpressionSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
                               |  _1 - _2
                               |  from tbl
                               |  """.stripMargin)
-        checkSparkMaybeThrows(res) match {
+        checkSparkAnswerMaybeThrows(res) match {
           case (Some(sparkExc), Some(cometExc)) =>
             
assert(cometExc.getMessage.contains(ARITHMETIC_OVERFLOW_EXCEPTION_MSG))
             assert(sparkExc.getMessage.contains("overflow"))
@@ -2889,7 +2889,7 @@ class CometExpressionSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
                               |  from tbl
                               |  """.stripMargin)
 
-        checkSparkMaybeThrows(res) match {
+        checkSparkAnswerMaybeThrows(res) match {
           case (Some(sparkExc), Some(cometExc)) =>
             
assert(cometExc.getMessage.contains(ARITHMETIC_OVERFLOW_EXCEPTION_MSG))
             assert(sparkExc.getMessage.contains("overflow"))
@@ -2909,7 +2909,7 @@ class CometExpressionSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
                               |  from tbl
                               |  """.stripMargin)
 
-        checkSparkMaybeThrows(res) match {
+        checkSparkAnswerMaybeThrows(res) match {
           case (Some(sparkExc), Some(cometExc)) =>
             assert(cometExc.getMessage.contains(DIVIDE_BY_ZERO_EXCEPTION_MSG))
             assert(sparkExc.getMessage.contains("Division by zero"))
@@ -2929,7 +2929,7 @@ class CometExpressionSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
                               |  from tbl
                               |  """.stripMargin)
 
-        checkSparkMaybeThrows(res) match {
+        checkSparkAnswerMaybeThrows(res) match {
           case (Some(sparkExc), Some(cometExc)) =>
             assert(cometExc.getMessage.contains(DIVIDE_BY_ZERO_EXCEPTION_MSG))
             assert(sparkExc.getMessage.contains("Division by zero"))
@@ -2950,7 +2950,7 @@ class CometExpressionSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
                 |  from tbl
                 |  """.stripMargin)
 
-          checkSparkMaybeThrows(res) match {
+          checkSparkAnswerMaybeThrows(res) match {
             case (Some(sparkException), Some(cometException)) =>
               
assert(sparkException.getMessage.contains(DIVIDE_BY_ZERO_EXCEPTION_MSG))
               
assert(cometException.getMessage.contains(DIVIDE_BY_ZERO_EXCEPTION_MSG))
@@ -2985,7 +2985,7 @@ class CometExpressionSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
           Seq(true, false).foreach { ansi =>
             withSQLConf(SQLConf.ANSI_ENABLED.key -> ansi.toString) {
               val res = spark.sql(s"SELECT round(_1, $scale) from tbl")
-              checkSparkMaybeThrows(res) match {
+              checkSparkAnswerMaybeThrows(res) match {
                 case (Some(sparkException), Some(cometException)) =>
                   
assert(sparkException.getMessage.contains("ARITHMETIC_OVERFLOW"))
                   
assert(cometException.getMessage.contains("ARITHMETIC_OVERFLOW"))
diff --git 
a/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala 
b/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala
index c95047a0e..8ea4a9c88 100644
--- a/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala
@@ -56,7 +56,7 @@ class CometMathExpressionSuite extends CometTestBase with 
AdaptiveSparkPlanHelpe
     withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
       for (field <- df.schema.fields) {
         val col = field.name
-        checkSparkMaybeThrows(sql(s"SELECT $col, abs($col) FROM tbl ORDER BY 
$col")) match {
+        checkSparkAnswerMaybeThrows(sql(s"SELECT $col, abs($col) FROM tbl 
ORDER BY $col")) match {
           case (Some(sparkExc), Some(cometExc)) =>
             val cometErrorPattern =
               """.+[ARITHMETIC_OVERFLOW].+overflow. If necessary set 
"spark.sql.ansi.enabled" to "false" to bypass this error.""".r
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala 
b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 7c9ca6ea0..591f75acb 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -306,7 +306,7 @@ abstract class CometTestBase
    * This method does not check that Comet replaced any operators. In the case where the query
    * is successful against both Spark and Comet, the results are compared via checkSparkAnswer.
    */
-  protected def checkSparkMaybeThrows(
+  protected def checkSparkAnswerMaybeThrows(
       df: => DataFrame): (Option[Throwable], Option[Throwable]) = {
     var expected: Try[Array[Row]] = null
     withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
@@ -316,8 +316,8 @@ abstract class CometTestBase
 
     (expected, actual) match {
       case (Success(_), Success(_)) =>
-        // TODO compare results and confirm that they match
-        // https://github.com/apache/datafusion-comet/issues/2657
+        // compare results and confirm that they match
+        checkSparkAnswer(df)
         (None, None)
       case _ =>
         (expected.failed.toOption, actual.failed.toOption)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to