This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 625f76dae0d [SPARK-40760][SQL] Migrate type check failures of interval expressions onto error classes
625f76dae0d is described below

commit 625f76dae0d9581428d6c5c4b58bf2958957c8c8
Author: Max Gekk <max.g...@gmail.com>
AuthorDate: Sun Oct 23 13:32:34 2022 +0500

    [SPARK-40760][SQL] Migrate type check failures of interval expressions onto error classes
    
    ### What changes were proposed in this pull request?
    In the PR, I propose to add new error sub-classes of the error class `DATATYPE_MISMATCH`, and use them for type check failures of some interval expressions.
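    
    As an illustration (not part of the patch), here is a minimal Scala sketch of the new reporting style, assuming `spark-catalyst` is on the classpath; the `checkNumEndpoints` helper is hypothetical and only mirrors the endpoint-count check in `ApproxCountDistinctForIntervals`. Instead of a free-form `TypeCheckFailure` string, the check returns a structured `DataTypeMismatch` whose `errorSubClass` and `messageParameters` fill in a message template from `error-classes.json`:
    ```scala
    import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
    import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}

    // Hypothetical helper: fewer than two endpoints cannot form an interval,
    // so report the new DATATYPE_MISMATCH.WRONG_NUM_ENDPOINTS sub-class with
    // the parameter that fills the <actualNumber> placeholder in the template.
    def checkNumEndpoints(numEndpoints: Int): TypeCheckResult = {
      if (numEndpoints < 2) {
        DataTypeMismatch(
          errorSubClass = "WRONG_NUM_ENDPOINTS",
          messageParameters = Map("actualNumber" -> numEndpoints.toString))
      } else {
        TypeCheckSuccess
      }
    }
    ```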
    
    ### Why are the changes needed?
    Migration onto error classes unifies Spark SQL error messages and improves the searchability of errors.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. The PR changes user-facing error messages.
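    
    For example, constructing intervals from a single endpoint changes the reported failure roughly as follows (the "before" text is the removed `TypeCheckFailure` message; the "after" text is the new template rendered with `actualNumber = 1`, reported under the error class `DATATYPE_MISMATCH.WRONG_NUM_ENDPOINTS`):
    ```
    Before: The number of endpoints must be >= 2 to construct intervals
    After:  The number of endpoints must be >= 2 to construct intervals
            but the actual number is 1.
    ```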
    
    ### How was this patch tested?
    By running the affected test suites:
    ```
    $ build/sbt "test:testOnly *AnalysisSuite"
    $ build/sbt "test:testOnly *ExpressionTypeCheckingSuite"
    $ build/sbt "test:testOnly *ApproxCountDistinctForIntervalsSuite"
    ```
    
    Closes #38237 from MaxGekk/type-check-fails-interval-exprs.
    
    Authored-by: Max Gekk <max.g...@gmail.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 core/src/main/resources/error/error-classes.json   |  5 +++
 .../ApproxCountDistinctForIntervals.scala          | 31 +++++++++++---
 .../catalyst/expressions/aggregate/Average.scala   |  2 +-
 .../sql/catalyst/expressions/aggregate/Sum.scala   |  2 +-
 .../apache/spark/sql/catalyst/util/TypeUtils.scala | 20 +++++----
 .../apache/spark/sql/types/AbstractDataType.scala  |  9 ++++
 .../sql/catalyst/analysis/AnalysisSuite.scala      | 50 ++++++++++++++--------
 .../analysis/ExpressionTypeCheckingSuite.scala     | 26 +++++++++--
 .../ApproxCountDistinctForIntervalsSuite.scala     | 21 ++++++---
 9 files changed, 123 insertions(+), 43 deletions(-)

diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json
index 5f4db145479..0f9b665718c 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -263,6 +263,11 @@
           "The <exprName> must be between <valueRange> (current value = 
<currentValue>)"
         ]
       },
+      "WRONG_NUM_ENDPOINTS" : {
+        "message" : [
+          "The number of endpoints must be >= 2 to construct intervals but the 
actual number is <actualNumber>."
+        ]
+      },
       "WRONG_NUM_PARAMS" : {
         "message" : [
           "The <functionName> requires <expectedNum> parameters but the actual 
number is <actualNum>."
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
index f3bf251ba0b..0be4e4aa465 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
@@ -21,10 +21,11 @@ import java.util
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, GenericInternalRow}
 import org.apache.spark.sql.catalyst.trees.BinaryLike
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, HyperLogLogPlusPlusHelper}
+import org.apache.spark.sql.errors.QueryErrorsBase
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.Platform
 
@@ -49,7 +50,10 @@ case class ApproxCountDistinctForIntervals(
     relativeSD: Double = 0.05,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends TypedImperativeAggregate[Array[Long]] with ExpectsInputTypes with BinaryLike[Expression] {
+  extends TypedImperativeAggregate[Array[Long]]
+  with ExpectsInputTypes
+  with BinaryLike[Expression]
+  with QueryErrorsBase {
 
   def this(child: Expression, endpointsExpression: Expression, relativeSD: Expression) = {
     this(
@@ -77,19 +81,32 @@ case class ApproxCountDistinctForIntervals(
     if (defaultCheck.isFailure) {
       defaultCheck
     } else if (!endpointsExpression.foldable) {
-      TypeCheckFailure("The endpoints provided must be constant literals")
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> "endpointsExpression",
+          "inputType" -> toSQLType(endpointsExpression.dataType)))
     } else {
       endpointsExpression.dataType match {
         case ArrayType(_: NumericType | DateType | TimestampType | TimestampNTZType |
            _: AnsiIntervalType, _) =>
           if (endpoints.length < 2) {
-            TypeCheckFailure("The number of endpoints must be >= 2 to 
construct intervals")
+            DataTypeMismatch(
+              errorSubClass = "WRONG_NUM_ENDPOINTS",
+              messageParameters = Map("actualNumber" -> 
endpoints.length.toString))
           } else {
             TypeCheckSuccess
           }
-        case _ =>
-          TypeCheckFailure("Endpoints require (numeric or timestamp or date or 
timestamp_ntz or " +
-            "interval year to month or interval day to second) type")
+        case inputType =>
+          val requiredElemTypes = toSQLType(TypeCollection(
+            NumericType, DateType, TimestampType, TimestampNTZType, AnsiIntervalType))
+          DataTypeMismatch(
+            errorSubClass = "UNEXPECTED_INPUT_TYPE",
+            messageParameters = Map(
+              "paramIndex" -> "2",
+              "requiredType" -> s"ARRAY OF $requiredElemTypes",
+              "inputSql" -> toSQLExpr(endpointsExpression),
+              "inputType" -> toSQLType(inputType)))
       }
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index ae644e9d663..ce9fa0575f2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -54,7 +54,7 @@ case class Average(
     Seq(TypeCollection(NumericType, YearMonthIntervalType, DayTimeIntervalType))
 
   override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForAnsiIntervalOrNumericType(child.dataType, "average")
+    TypeUtils.checkForAnsiIntervalOrNumericType(child)
 
   override def nullable: Boolean = true
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index 432d4b40b4a..2c892903437 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -67,7 +67,7 @@ case class Sum(
     Seq(TypeCollection(NumericType, YearMonthIntervalType, DayTimeIntervalType))
 
   override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForAnsiIntervalOrNumericType(child.dataType, prettyName)
+    TypeUtils.checkForAnsiIntervalOrNumericType(child)
 
   final override val nodePatterns: Seq[TreePattern] = Seq(SUM)
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index 7cb471d14bd..0bb5d29c5c4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -19,15 +19,14 @@ package org.apache.spark.sql.catalyst.util
 
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
-import org.apache.spark.sql.catalyst.expressions.Cast.toSQLType
-import org.apache.spark.sql.catalyst.expressions.RowOrdering
-import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.catalyst.expressions.{Expression, RowOrdering}
+import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
 import org.apache.spark.sql.types._
 
 /**
  * Functions to help with checking for valid data types and value comparison of various types.
  */
-object TypeUtils {
+object TypeUtils extends QueryErrorsBase {
 
   def checkForOrderingExpr(dt: DataType, caller: String): TypeCheckResult = {
     if (RowOrdering.isOrderable(dt)) {
@@ -70,13 +69,18 @@ object TypeUtils {
     }
   }
 
-  def checkForAnsiIntervalOrNumericType(
-      dt: DataType, funcName: String): TypeCheckResult = dt match {
+  def checkForAnsiIntervalOrNumericType(input: Expression): TypeCheckResult = input.dataType match {
     case _: AnsiIntervalType | NullType =>
       TypeCheckResult.TypeCheckSuccess
     case dt if dt.isInstanceOf[NumericType] => TypeCheckResult.TypeCheckSuccess
-    case other => TypeCheckResult.TypeCheckFailure(
-      s"function $funcName requires numeric or interval types, not 
${other.catalogString}")
+    case other =>
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_INPUT_TYPE",
+        messageParameters = Map(
+          "paramIndex" -> "1",
+          "requiredType" -> Seq(NumericType, 
AnsiIntervalType).map(toSQLType).mkString(" or "),
+          "inputSql" -> toSQLExpr(input),
+          "inputType" -> toSQLType(other)))
   }
 
   def getNumeric(t: DataType, exactNumericRequired: Boolean = false): Numeric[Any] = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index ebcf35a0674..294fb13e48c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -233,3 +233,12 @@ private[sql] abstract class DatetimeType extends AtomicType
  * The interval type which conforms to the ANSI SQL standard.
  */
 private[sql] abstract class AnsiIntervalType extends AtomicType
+
+private[spark] object AnsiIntervalType extends AbstractDataType {
+  override private[sql] def simpleString: String = "ANSI interval"
+
+  override private[sql] def acceptsType(other: DataType): Boolean =
+    other.isInstanceOf[AnsiIntervalType]
+
+  override private[sql] def defaultConcreteType: DataType = DayTimeIntervalType()
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 3036742c83f..6f0e6ef0c11 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -1163,25 +1163,39 @@ class AnalysisSuite extends AnalysisTest with Matchers {
   }
 
   test("SPARK-38118: Func(wrong_type) in the HAVING clause should throw data 
mismatch error") {
-    assertAnalysisError(parsePlan(
-      s"""
-         |WITH t as (SELECT true c)
-         |SELECT t.c
-         |FROM t
-         |GROUP BY t.c
-         |HAVING mean(t.c) > 0d""".stripMargin),
-      Seq(s"cannot resolve 'mean(t.c)' due to data type mismatch"),
-      false)
+    assertAnalysisErrorClass(
+      inputPlan = parsePlan(
+        s"""
+           |WITH t as (SELECT true c)
+           |SELECT t.c
+           |FROM t
+           |GROUP BY t.c
+           |HAVING mean(t.c) > 0d""".stripMargin),
+      expectedErrorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      expectedMessageParameters = Map(
+        "sqlExpr" -> "\"mean(c)\"",
+        "paramIndex" -> "1",
+        "inputSql" -> "\"c\"",
+        "inputType" -> "\"BOOLEAN\"",
+        "requiredType" -> "\"NUMERIC\" or \"ANSI INTERVAL\""),
+      caseSensitive = false)
 
-    assertAnalysisError(parsePlan(
-      s"""
-         |WITH t as (SELECT true c, false d)
-         |SELECT (t.c AND t.d) c
-         |FROM t
-         |GROUP BY t.c, t.d
-         |HAVING mean(c) > 0d""".stripMargin),
-      Seq(s"cannot resolve 'mean(t.c)' due to data type mismatch"),
-      false)
+    assertAnalysisErrorClass(
+      inputPlan = parsePlan(
+        s"""
+           |WITH t as (SELECT true c, false d)
+           |SELECT (t.c AND t.d) c
+           |FROM t
+           |GROUP BY t.c, t.d
+           |HAVING mean(c) > 0d""".stripMargin),
+      expectedErrorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      expectedMessageParameters = Map(
+        "sqlExpr" -> "\"mean(c)\"",
+        "paramIndex" -> "1",
+        "inputSql" -> "\"c\"",
+        "inputType" -> "\"BOOLEAN\"",
+        "requiredType" -> "\"NUMERIC\" or \"ANSI INTERVAL\""),
+      caseSensitive = false)
 
     assertAnalysisErrorClass(
       inputPlan = parsePlan(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 991721a55ca..b41f627bac9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -396,9 +396,29 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer
         "dataType" -> "\"MAP<STRING, BIGINT>\""
       )
     )
-    assertError(Sum($"booleanField"), "function sum requires numeric or interval types")
-    assertError(Average($"booleanField"),
-      "function average requires numeric or interval types")
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        assertSuccess(Sum($"booleanField"))
+      },
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> "\"sum(booleanField)\"",
+        "paramIndex" -> "1",
+        "inputSql" -> "\"booleanField\"",
+        "inputType" -> "\"BOOLEAN\"",
+        "requiredType" -> "\"NUMERIC\" or \"ANSI INTERVAL\""))
+    checkError(
+      exception = intercept[AnalysisException] {
+        assertSuccess(Average($"booleanField"))
+      },
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> "\"avg(booleanField)\"",
+        "paramIndex" -> "1",
+        "inputSql" -> "\"booleanField\"",
+        "inputType" -> "\"BOOLEAN\"",
+        "requiredType" -> "\"NUMERIC\" or \"ANSI INTERVAL\""))
   }
 
   test("check types for others") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala
index d00193c4f3b..bb99e1c1e8e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala
@@ -22,7 +22,7 @@ import java.time.LocalDateTime
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, CreateArray, Literal, SpecificInternalRow}
 import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils}
 import org.apache.spark.sql.types._
@@ -48,20 +48,31 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
       AttributeReference("a", DoubleType)(),
       endpointsExpression = CreateArray(Seq(AttributeReference("b", 
DoubleType)())))
     assert(wrongEndpoints.checkInputDataTypes() ==
-      TypeCheckFailure("The endpoints provided must be constant literals"))
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> "endpointsExpression",
+          "inputType" -> "\"ARRAY<DOUBLE>\"")))
 
     wrongEndpoints = ApproxCountDistinctForIntervals(
       AttributeReference("a", DoubleType)(),
       endpointsExpression = CreateArray(Array(10L).map(Literal(_))))
     assert(wrongEndpoints.checkInputDataTypes() ==
-      TypeCheckFailure("The number of endpoints must be >= 2 to construct 
intervals"))
+      DataTypeMismatch("WRONG_NUM_ENDPOINTS", Map("actualNumber" -> "1")))
 
     wrongEndpoints = ApproxCountDistinctForIntervals(
       AttributeReference("a", DoubleType)(),
       endpointsExpression = CreateArray(Array("foobar").map(Literal(_))))
+    // scalastyle:off line.size.limit
     assert(wrongEndpoints.checkInputDataTypes() ==
-      TypeCheckFailure("Endpoints require (numeric or timestamp or date or 
timestamp_ntz or " +
-        "interval year to month or interval day to second) type"))
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_INPUT_TYPE",
+        messageParameters = Map(
+          "paramIndex" -> "2",
+          "requiredType" -> "ARRAY OF (\"NUMERIC\" or \"DATE\" or 
\"TIMESTAMP\" or \"TIMESTAMP_NTZ\" or \"ANSI INTERVAL\")",
+          "inputSql" -> "\"array(foobar)\"",
+          "inputType" -> "\"ARRAY<STRING>\"")))
+    // scalastyle:on line.size.limit
   }
 
   /** Create an ApproxCountDistinctForIntervals instance and an input and output buffer. */

