This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 11cce7e7338 [SPARK-40768][SQL] Migrate type check failures of bloom_filter_agg() onto error classes
11cce7e7338 is described below

commit 11cce7e73380231ec7c94096655e3d98ce7e635d
Author: lvshaokang <lvshaoka...@gmail.com>
AuthorDate: Fri Oct 21 10:06:44 2022 +0500

    [SPARK-40768][SQL] Migrate type check failures of bloom_filter_agg() onto error classes
    
    ### What changes were proposed in this pull request?
    
    In the PR, I propose to use error classes for type check failures in the Bloom filter aggregate expressions.
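
    For illustration, the migration follows the pattern below (a minimal sketch based on the `numBits` check in `BloomFilterAggregate`; see the diff for the full set of checks):

    ```
    // Before: a free-form, plain-text failure message.
    TypeCheckFailure("The number of bits must be a positive value " +
      s" (current value = $numBits)")

    // After: a structured DataTypeMismatch with a named error subclass and
    // message parameters that are rendered through error-classes.json.
    DataTypeMismatch(
      errorSubClass = "VALUE_OUT_OF_RANGE",
      messageParameters = Map(
        "exprName" -> "numBits",
        "valueRange" -> "[0, positive]",
        "currentValue" -> toSQLValue(numBits, LongType)
      )
    )
    ```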
    
    ### Why are the changes needed?
    
    Migration onto error classes unifies Spark SQL error messages.
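
    The message templates live in core/src/main/resources/error/error-classes.json, keyed by error class and subclass. After this change, the BLOOM_FILTER_WRONG_TYPE entry under DATATYPE_MISMATCH reads:

    ```
    "BLOOM_FILTER_WRONG_TYPE" : {
      "message" : [
        "Input to function <functionName> should have been <expectedLeft> followed by value with <expectedRight>, but it's [<actual>]."
      ]
    }
    ```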
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. The PR changes user-facing error messages.
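
    For example, for a mistyped first argument to bloom_filter_agg(), the old and new messages look roughly like:

    ```
    -- Before
    Input to function bloom_filter_agg should have been a bigint value
    followed with two bigint size arguments, but it's [decimal(2,1), bigint, bigint]

    -- After (error class DATATYPE_MISMATCH.BLOOM_FILTER_WRONG_TYPE)
    Input to function `bloom_filter_agg` should have been "BINARY" followed
    by value with "BIGINT", but it's ["DECIMAL(2,1)", "BIGINT", "BIGINT"].
    ```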
    
    ### How was this patch tested?
    
    ```
    build/sbt "sql/testOnly *SQLQueryTestSuite"
    build/sbt "test:testOnly org.apache.spark.SparkThrowableSuite"
    build/sbt "test:testOnly *BloomFilterAggregateQuerySuite"
    ```
    
    Closes #38315 from lvshaokang/SPARK-40768.
    
    Authored-by: lvshaokang <lvshaoka...@gmail.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 core/src/main/resources/error/error-classes.json   |   2 +-
 .../expressions/BloomFilterMightContain.scala      |   3 +-
 .../aggregate/BloomFilterAggregate.scala           |  61 +++++++--
 .../spark/sql/BloomFilterAggregateQuerySuite.scala | 144 ++++++++++++++++++---
 4 files changed, 179 insertions(+), 31 deletions(-)

diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json
index 0cfb6861c77..1e9519dd89a 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -108,7 +108,7 @@
       },
       "BLOOM_FILTER_WRONG_TYPE" : {
         "message" : [
-          "Input to function <functionName> should have been <expectedLeft> 
followed by a value with <expectedRight>, but it's [<actualLeft>, 
<actualRight>]."
+          "Input to function <functionName> should have been <expectedLeft> 
followed by value with <expectedRight>, but it's [<actual>]."
         ]
       },
       "CANNOT_CONVERT_TO_JSON" : {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala
index 5cb19d36b80..b2273b6a6d1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala
@@ -76,8 +76,7 @@ case class BloomFilterMightContain(
             "functionName" -> toSQLId(prettyName),
             "expectedLeft" -> toSQLType(BinaryType),
             "expectedRight" -> toSQLType(LongType),
-            "actualLeft" -> toSQLType(left.dataType),
-            "actualRight" -> toSQLType(right.dataType)
+            "actual" -> Seq(left.dataType, 
right.dataType).map(toSQLType).mkString(", ")
           )
         )
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala
index c734bca3ef8..5b78c5b5228 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLExpr, toSQLId, toSQLType, toSQLValue}
 import org.apache.spark.sql.catalyst.trees.TernaryLike
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -63,28 +64,66 @@ case class BloomFilterAggregate(
   override def checkInputDataTypes(): TypeCheckResult = {
     (first.dataType, second.dataType, third.dataType) match {
       case (_, NullType, _) | (_, _, NullType) =>
-        TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as size arguments")
+        DataTypeMismatch(
+          errorSubClass = "UNEXPECTED_NULL",
+          messageParameters = Map(
+            "exprName" -> "estimatedNumItems or numBits"
+          )
+        )
       case (LongType, LongType, LongType) =>
         if (!estimatedNumItemsExpression.foldable) {
-          TypeCheckFailure("The estimated number of items provided must be a 
constant literal")
+          DataTypeMismatch(
+            errorSubClass = "NON_FOLDABLE_INPUT",
+            messageParameters = Map(
+              "inputName" -> "estimatedNumItems",
+              "inputType" -> toSQLType(estimatedNumItemsExpression.dataType),
+              "inputExpr" -> toSQLExpr(estimatedNumItemsExpression)
+            )
+          )
         } else if (estimatedNumItems <= 0L) {
-          TypeCheckFailure("The estimated number of items must be a positive 
value " +
-            s" (current value = $estimatedNumItems)")
+          DataTypeMismatch(
+            errorSubClass = "VALUE_OUT_OF_RANGE",
+            messageParameters = Map(
+              "exprName" -> "estimatedNumItems",
+              "valueRange" -> s"[0, positive]",
+              "currentValue" -> toSQLValue(estimatedNumItems, LongType)
+            )
+          )
         } else if (!numBitsExpression.foldable) {
-          TypeCheckFailure("The number of bits provided must be a constant 
literal")
+          DataTypeMismatch(
+            errorSubClass = "NON_FOLDABLE_INPUT",
+            messageParameters = Map(
+              "inputName" -> "numBitsExpression",
+              "inputType" -> toSQLType(numBitsExpression.dataType),
+              "inputExpr" -> toSQLExpr(numBitsExpression)
+            )
+          )
         } else if (numBits <= 0L) {
-          TypeCheckFailure("The number of bits must be a positive value " +
-            s" (current value = $numBits)")
+          DataTypeMismatch(
+            errorSubClass = "VALUE_OUT_OF_RANGE",
+            messageParameters = Map(
+              "exprName" -> "numBits",
+              "valueRange" -> s"[0, positive]",
+              "currentValue" -> toSQLValue(numBits, LongType)
+            )
+          )
         } else {
           require(estimatedNumItems <=
             SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS))
          require(numBits <= SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS))
           TypeCheckSuccess
         }
-      case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " +
-        s"been a ${LongType.simpleString} value followed with two ${LongType.simpleString} size " +
-        s"arguments, but it's [${first.dataType.catalogString}, " +
-        s"${second.dataType.catalogString}, ${third.dataType.catalogString}]")
+      case _ =>
+        DataTypeMismatch(
+          errorSubClass = "BLOOM_FILTER_WRONG_TYPE",
+          messageParameters = Map(
+            "functionName" -> toSQLId(prettyName),
+            "expectedLeft" -> toSQLType(BinaryType),
+            "expectedRight" -> toSQLType(LongType),
+            "actual" -> Seq(first.dataType, second.dataType, third.dataType)
+              .map(toSQLType).mkString(", ")
+          )
+        )
     }
   }
   override def nullable: Boolean = true
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala
index 6a22414db00..cf5d4c8c1e9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala
@@ -19,11 +19,13 @@ package org.apache.spark.sql
 
 import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.Cast.toSQLValue
 import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.LongType
 
 /**
  * Query tests for the Bloom filter aggregate and filter function.
@@ -62,8 +64,8 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession {
     val table = "bloom_filter_test"
     for (numEstimatedItems <- Seq(Long.MinValue, -10L, 0L, 4096L, 4194304L, Long.MaxValue,
       conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS))) {
-      for (numBits <- Seq(Long.MinValue, -10L, 0L, 4096L, 4194304L, Long.MaxValue,
-        conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS))) {
+      for ((numBits, index) <- Seq(Long.MinValue, -10L, 0L, 4096L, 4194304L, Long.MaxValue,
+        conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS)).zipWithIndex) {
         val sqlString = s"""
                            |SELECT every(might_contain(
                            |            (SELECT bloom_filter_agg(col,
@@ -87,13 +89,57 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession {
             val exception = intercept[AnalysisException] {
               spark.sql(sqlString)
             }
-            assert(exception.getMessage.contains(
-              "The estimated number of items must be a positive value"))
+            val stop = numEstimatedItems match {
+              case Long.MinValue => Seq(169, 152, 150, 153, 156, 168, 157)
+              case -10L => Seq(152, 135, 133, 136, 139, 151, 140)
+              case 0L => Seq(150, 133, 131, 134, 137, 149, 138)
+            }
+            checkError(
+              exception = exception,
+              errorClass = "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE",
+              parameters = Map(
+                "exprName" -> "estimatedNumItems",
+                "valueRange" -> "[0, positive]",
+                "currentValue" -> toSQLValue(numEstimatedItems, LongType),
+                "sqlExpr" -> (s""""bloom_filter_agg(col, 
CAST($numEstimatedItems AS BIGINT), """ +
+                  s"""CAST($numBits AS BIGINT))"""")
+              ),
+              context = ExpectedContext(
+                fragment = "bloom_filter_agg(col,\n" +
+                  s"              cast($numEstimatedItems as long),\n" +
+                  s"              cast($numBits as long))",
+                start = 49,
+                stop = stop(index)
+              )
+            )
           } else if (numBits <= 0) {
             val exception = intercept[AnalysisException] {
               spark.sql(sqlString)
             }
-            assert(exception.getMessage.contains("The number of bits must be a 
positive value"))
+            val stop = numEstimatedItems match {
+              case 4096L => Seq(153, 136, 134)
+              case 4194304L => Seq(156, 139, 137)
+              case Long.MaxValue => Seq(168, 151, 149)
+              case 4000000 => Seq(156, 139, 137)
+            }
+            checkError(
+              exception = exception,
+              errorClass = "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE",
+              parameters = Map(
+                "exprName" -> "numBits",
+                "valueRange" -> "[0, positive]",
+                "currentValue" -> toSQLValue(numBits, LongType),
+                "sqlExpr" -> (s""""bloom_filter_agg(col, 
CAST($numEstimatedItems AS BIGINT), """ +
+                  s"""CAST($numBits AS BIGINT))"""")
+              ),
+              context = ExpectedContext(
+                fragment = "bloom_filter_agg(col,\n" +
+                  s"              cast($numEstimatedItems as long),\n" +
+                  s"              cast($numBits as long))",
+                start = 49,
+                stop = stop(index)
+              )
+            )
           } else {
             checkAnswer(spark.sql(sqlString), Row(true, false))
           }
@@ -109,8 +155,22 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession {
         |FROM values (1.2), (2.5) as t(a)"""
         .stripMargin)
     }
-    assert(exception1.getMessage.contains(
-      "Input to function bloom_filter_agg should have been a bigint value"))
+    checkError(
+      exception = exception1,
+      errorClass = "DATATYPE_MISMATCH.BLOOM_FILTER_WRONG_TYPE",
+      parameters = Map(
+        "functionName" -> "`bloom_filter_agg`",
+        "sqlExpr" -> "\"bloom_filter_agg(a, 1000000, 8388608)\"",
+        "expectedLeft" -> "\"BINARY\"",
+        "expectedRight" -> "\"BIGINT\"",
+        "actual" -> "\"DECIMAL(2,1)\", \"BIGINT\", \"BIGINT\""
+      ),
+      context = ExpectedContext(
+        fragment = "bloom_filter_agg(a)",
+        start = 8,
+        stop = 26
+      )
+    )
 
     val exception2 = intercept[AnalysisException] {
       spark.sql("""
@@ -118,8 +178,22 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession {
         |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)"""
         .stripMargin)
     }
-    assert(exception2.getMessage.contains(
-      "function bloom_filter_agg should have been a bigint value followed with 
two bigint"))
+    checkError(
+      exception = exception2,
+      errorClass = "DATATYPE_MISMATCH.BLOOM_FILTER_WRONG_TYPE",
+      parameters = Map(
+        "functionName" -> "`bloom_filter_agg`",
+        "sqlExpr" -> "\"bloom_filter_agg(a, 2, (2 * 8))\"",
+        "expectedLeft" -> "\"BINARY\"",
+        "expectedRight" -> "\"BIGINT\"",
+        "actual" -> "\"BIGINT\", \"INT\", \"BIGINT\""
+      ),
+      context = ExpectedContext(
+        fragment = "bloom_filter_agg(a, 2)",
+        start = 8,
+        stop = 29
+      )
+    )
 
     val exception3 = intercept[AnalysisException] {
       spark.sql("""
@@ -127,8 +201,22 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession {
         |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)"""
         .stripMargin)
     }
-    assert(exception3.getMessage.contains(
-      "function bloom_filter_agg should have been a bigint value followed with 
two bigint"))
+    checkError(
+      exception = exception3,
+      errorClass = "DATATYPE_MISMATCH.BLOOM_FILTER_WRONG_TYPE",
+      parameters = Map(
+        "functionName" -> "`bloom_filter_agg`",
+        "sqlExpr" -> "\"bloom_filter_agg(a, CAST(2 AS BIGINT), 5)\"",
+        "expectedLeft" -> "\"BINARY\"",
+        "expectedRight" -> "\"BIGINT\"",
+        "actual" -> "\"BIGINT\", \"BIGINT\", \"INT\""
+      ),
+      context = ExpectedContext(
+        fragment = "bloom_filter_agg(a, cast(2 as long), 5)",
+        start = 8,
+        stop = 46
+      )
+    )
 
     val exception4 = intercept[AnalysisException] {
       spark.sql("""
@@ -136,7 +224,19 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession {
         |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)"""
         .stripMargin)
     }
-    assert(exception4.getMessage.contains("Null typed values cannot be used as size arguments"))
+    checkError(
+      exception = exception4,
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_NULL",
+      parameters = Map(
+        "exprName" -> "estimatedNumItems or numBits",
+        "sqlExpr" -> "\"bloom_filter_agg(a, NULL, 5)\""
+      ),
+      context = ExpectedContext(
+        fragment = "bloom_filter_agg(a, null, 5)",
+        start = 8,
+        stop = 35
+      )
+    )
 
     val exception5 = intercept[AnalysisException] {
       spark.sql("""
@@ -144,7 +244,19 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession {
         |FROM values (cast(1 as long)), (cast(2 as long)) as t(a)"""
         .stripMargin)
     }
-    assert(exception5.getMessage.contains("Null typed values cannot be used as size arguments"))
+    checkError(
+      exception = exception5,
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_NULL",
+      parameters = Map(
+        "exprName" -> "estimatedNumItems or numBits",
+        "sqlExpr" -> "\"bloom_filter_agg(a, 5, NULL)\""
+      ),
+      context = ExpectedContext(
+        fragment = "bloom_filter_agg(a, 5, null)",
+        start = 8,
+        stop = 35
+      )
+    )
   }
 
   test("Test that might_contain errors out disallowed input value types") {
@@ -160,8 +272,7 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession {
         "functionName" -> "`might_contain`",
         "expectedLeft" -> "\"BINARY\"",
         "expectedRight" -> "\"BIGINT\"",
-        "actualLeft" -> "\"DECIMAL(2,1)\"",
-        "actualRight" -> "\"BIGINT\""
+        "actual" -> "\"DECIMAL(2,1)\", \"BIGINT\""
       ),
       context = ExpectedContext(
         fragment = "might_contain(1.0, 1L)",
@@ -182,8 +293,7 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession {
         "functionName" -> "`might_contain`",
         "expectedLeft" -> "\"BINARY\"",
         "expectedRight" -> "\"BIGINT\"",
-        "actualLeft" -> "\"VOID\"",
-        "actualRight" -> "\"DECIMAL(1,1)\""
+        "actual" -> "\"VOID\", \"DECIMAL(1,1)\""
       ),
       context = ExpectedContext(
         fragment = "might_contain(NULL, 0.1)",

