This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new e77998ad2 fix: checkSparkMaybeThrows should compare Spark and Comet
results in success case (#2728)
e77998ad2 is described below
commit e77998ad2228b758d0a93588c57325e56bda83bc
Author: Andy Grove <[email protected]>
AuthorDate: Mon Nov 10 18:42:04 2025 -0700
fix: checkSparkMaybeThrows should compare Spark and Comet results in
success case (#2728)
---
.../org/apache/comet/expressions/CometCast.scala | 2 ++
.../apache/comet/CometBitwiseExpressionSuite.scala | 2 +-
.../scala/org/apache/comet/CometCastSuite.scala | 14 ++++++-------
.../org/apache/comet/CometExpressionSuite.scala | 24 +++++++++++-----------
.../apache/comet/CometMathExpressionSuite.scala | 2 +-
.../scala/org/apache/spark/sql/CometTestBase.scala | 6 +++---
6 files changed, 26 insertions(+), 24 deletions(-)
diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index 80166fe7b..ee7216686 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -116,6 +116,8 @@ object CometCast extends CometExpressionSerde[Cast] with
CometExprShim {
(fromType, toType) match {
case (dt: ArrayType, _: ArrayType) if dt.elementType == NullType =>
Compatible()
+ case (dt: ArrayType, DataTypes.StringType) if dt.elementType ==
DataTypes.BinaryType =>
+ Incompatible()
case (dt: ArrayType, DataTypes.StringType) =>
isSupported(dt.elementType, DataTypes.StringType, timeZoneId, evalMode)
case (dt: ArrayType, dt1: ArrayType) =>
diff --git
a/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala
b/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala
index 02c003ede..99a57b157 100644
--- a/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala
@@ -76,7 +76,7 @@ class CometBitwiseExpressionSuite extends CometTestBase with
AdaptiveSparkPlanHe
test("bitwise_get - throws exceptions") {
def checkSparkAndCometEqualThrows(query: String): Unit = {
- checkSparkMaybeThrows(sql(query)) match {
+ checkSparkAnswerMaybeThrows(sql(query)) match {
case (Some(sparkExc), Some(cometExc)) =>
assert(sparkExc.getMessage == cometExc.getMessage)
case _ => fail("Exception should be thrown")
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index f047fc052..6e9b4ec70 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType,
ByteType, DataType, DataTypes, DecimalType, IntegerType, LongType, ShortType,
StringType, StructField, StructType}
+import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType,
DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType,
StructField, StructType}
import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus
import org.apache.comet.expressions.{CometCast, CometEvalMode}
@@ -1035,7 +1035,7 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
test("cast between decimals with negative precision") {
// cast to negative scale
- checkSparkMaybeThrows(
+ checkSparkAnswerMaybeThrows(
spark.sql("select a, cast(a as DECIMAL(10,-4)) from t order by a"))
match {
case (expected, actual) =>
assert(expected.contains("PARSE_SYNTAX_ERROR") ===
actual.contains("PARSE_SYNTAX_ERROR"))
@@ -1062,11 +1062,11 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
IntegerType,
LongType,
ShortType,
- // FloatType,
- // DoubleType,
+ // FloatType,
+ // DoubleType,
+ // BinaryType
DecimalType(10, 2),
- DecimalType(38, 18),
- BinaryType).foreach { dt =>
+ DecimalType(38, 18)).foreach { dt =>
val input = generateArrays(100, dt)
castTest(input, StringType, hasIncompatibleType =
hasIncompatibleType(input.schema))
}
@@ -1272,7 +1272,7 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
// cast() should throw exception on invalid inputs when ansi mode is
enabled
val df = data.withColumn("converted", col("a").cast(toType))
- checkSparkMaybeThrows(df) match {
+ checkSparkAnswerMaybeThrows(df) match {
case (None, None) =>
// neither system threw an exception
case (None, Some(e)) =>
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index e4fb08101..fc45d2cb3 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -312,7 +312,7 @@ class CometExpressionSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled =
dictionaryEnabled, 10000)
withParquetTable(path.toString, "tbl") {
val (sparkErr, cometErr) =
- checkSparkMaybeThrows(sql(s"SELECT _20 + ${Int.MaxValue} FROM
tbl"))
+ checkSparkAnswerMaybeThrows(sql(s"SELECT _20 + ${Int.MaxValue}
FROM tbl"))
if (isSpark40Plus) {
assert(sparkErr.get.getMessage.contains("EXPRESSION_DECODING_FAILED"))
} else {
@@ -359,7 +359,7 @@ class CometExpressionSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled =
dictionaryEnabled, 10000)
withParquetTable(path.toString, "tbl") {
val (sparkErr, cometErr) =
- checkSparkMaybeThrows(sql(s"SELECT _20 - ${Int.MaxValue} FROM
tbl"))
+ checkSparkAnswerMaybeThrows(sql(s"SELECT _20 - ${Int.MaxValue}
FROM tbl"))
if (isSpark40Plus) {
assert(sparkErr.get.getMessage.contains("EXPRESSION_DECODING_FAILED"))
assert(cometErr.get.getMessage.contains("EXPRESSION_DECODING_FAILED"))
@@ -2022,7 +2022,7 @@ class CometExpressionSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
val expectedDivideByZeroError =
"[DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate
divisor being 0 and return NULL instead."
- checkSparkMaybeThrows(sql(query)) match {
+ checkSparkAnswerMaybeThrows(sql(query)) match {
case (Some(sparkException), Some(cometException)) =>
assert(sparkException.getMessage.contains(expectedDivideByZeroError))
assert(cometException.getMessage.contains(expectedDivideByZeroError))
@@ -2174,7 +2174,7 @@ class CometExpressionSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
}
def checkOverflow(query: String, dtype: String): Unit = {
- checkSparkMaybeThrows(sql(query)) match {
+ checkSparkAnswerMaybeThrows(sql(query)) match {
case (Some(sparkException), Some(cometException)) =>
assert(sparkException.getMessage.contains(dtype + " overflow"))
assert(cometException.getMessage.contains(dtype + " overflow"))
@@ -2700,7 +2700,7 @@ class CometExpressionSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
test("ListExtract") {
def assertBothThrow(df: DataFrame): Unit = {
- checkSparkMaybeThrows(df) match {
+ checkSparkAnswerMaybeThrows(df) match {
case (Some(_), Some(_)) => ()
case (spark, comet) =>
fail(
@@ -2850,7 +2850,7 @@ class CometExpressionSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
| from tbl
| """.stripMargin)
- checkSparkMaybeThrows(res) match {
+ checkSparkAnswerMaybeThrows(res) match {
case (Some(sparkExc), Some(cometExc)) =>
assert(cometExc.getMessage.contains(ARITHMETIC_OVERFLOW_EXCEPTION_MSG))
assert(sparkExc.getMessage.contains("overflow"))
@@ -2869,7 +2869,7 @@ class CometExpressionSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
| _1 - _2
| from tbl
| """.stripMargin)
- checkSparkMaybeThrows(res) match {
+ checkSparkAnswerMaybeThrows(res) match {
case (Some(sparkExc), Some(cometExc)) =>
assert(cometExc.getMessage.contains(ARITHMETIC_OVERFLOW_EXCEPTION_MSG))
assert(sparkExc.getMessage.contains("overflow"))
@@ -2889,7 +2889,7 @@ class CometExpressionSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
| from tbl
| """.stripMargin)
- checkSparkMaybeThrows(res) match {
+ checkSparkAnswerMaybeThrows(res) match {
case (Some(sparkExc), Some(cometExc)) =>
assert(cometExc.getMessage.contains(ARITHMETIC_OVERFLOW_EXCEPTION_MSG))
assert(sparkExc.getMessage.contains("overflow"))
@@ -2909,7 +2909,7 @@ class CometExpressionSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
| from tbl
| """.stripMargin)
- checkSparkMaybeThrows(res) match {
+ checkSparkAnswerMaybeThrows(res) match {
case (Some(sparkExc), Some(cometExc)) =>
assert(cometExc.getMessage.contains(DIVIDE_BY_ZERO_EXCEPTION_MSG))
assert(sparkExc.getMessage.contains("Division by zero"))
@@ -2929,7 +2929,7 @@ class CometExpressionSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
| from tbl
| """.stripMargin)
- checkSparkMaybeThrows(res) match {
+ checkSparkAnswerMaybeThrows(res) match {
case (Some(sparkExc), Some(cometExc)) =>
assert(cometExc.getMessage.contains(DIVIDE_BY_ZERO_EXCEPTION_MSG))
assert(sparkExc.getMessage.contains("Division by zero"))
@@ -2950,7 +2950,7 @@ class CometExpressionSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
| from tbl
| """.stripMargin)
- checkSparkMaybeThrows(res) match {
+ checkSparkAnswerMaybeThrows(res) match {
case (Some(sparkException), Some(cometException)) =>
assert(sparkException.getMessage.contains(DIVIDE_BY_ZERO_EXCEPTION_MSG))
assert(cometException.getMessage.contains(DIVIDE_BY_ZERO_EXCEPTION_MSG))
@@ -2985,7 +2985,7 @@ class CometExpressionSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
Seq(true, false).foreach { ansi =>
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansi.toString) {
val res = spark.sql(s"SELECT round(_1, $scale) from tbl")
- checkSparkMaybeThrows(res) match {
+ checkSparkAnswerMaybeThrows(res) match {
case (Some(sparkException), Some(cometException)) =>
assert(sparkException.getMessage.contains("ARITHMETIC_OVERFLOW"))
assert(cometException.getMessage.contains("ARITHMETIC_OVERFLOW"))
diff --git
a/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala
b/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala
index c95047a0e..8ea4a9c88 100644
--- a/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometMathExpressionSuite.scala
@@ -56,7 +56,7 @@ class CometMathExpressionSuite extends CometTestBase with
AdaptiveSparkPlanHelpe
withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
for (field <- df.schema.fields) {
val col = field.name
- checkSparkMaybeThrows(sql(s"SELECT $col, abs($col) FROM tbl ORDER BY
$col")) match {
+ checkSparkAnswerMaybeThrows(sql(s"SELECT $col, abs($col) FROM tbl
ORDER BY $col")) match {
case (Some(sparkExc), Some(cometExc)) =>
val cometErrorPattern =
""".+[ARITHMETIC_OVERFLOW].+overflow. If necessary set
"spark.sql.ansi.enabled" to "false" to bypass this error.""".r
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 7c9ca6ea0..591f75acb 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -306,7 +306,7 @@ abstract class CometTestBase
* This method does not check that Comet replaced any operators or that the
results match in the
* case where the query is successful against both Spark and Comet.
*/
- protected def checkSparkMaybeThrows(
+ protected def checkSparkAnswerMaybeThrows(
df: => DataFrame): (Option[Throwable], Option[Throwable]) = {
var expected: Try[Array[Row]] = null
withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
@@ -316,8 +316,8 @@ abstract class CometTestBase
(expected, actual) match {
case (Success(_), Success(_)) =>
- // TODO compare results and confirm that they match
- // https://github.com/apache/datafusion-comet/issues/2657
+ // compare results and confirm that they match
+ checkSparkAnswer(df)
(None, None)
case _ =>
(expected.failed.toOption, actual.failed.toOption)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]