This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 625f76dae0d  [SPARK-40760][SQL] Migrate type check failures of interval expressions onto error classes
625f76dae0d is described below

commit 625f76dae0d9581428d6c5c4b58bf2958957c8c8
Author: Max Gekk <max.g...@gmail.com>
AuthorDate: Sun Oct 23 13:32:34 2022 +0500

    [SPARK-40760][SQL] Migrate type check failures of interval expressions onto error classes

    ### What changes were proposed in this pull request?
    In the PR, I propose to add new error sub-classes of the error class `DATATYPE_MISMATCH`, and to use them for type check failures of some interval expressions.

    ### Why are the changes needed?
    Migration onto error classes unifies Spark SQL error messages and improves the searchability of errors.

    ### Does this PR introduce _any_ user-facing change?
    Yes. The PR changes user-facing error messages.

    ### How was this patch tested?
    By running the affected test suites:
    ```
    $ build/sbt "test:testOnly *AnalysisSuite"
    $ build/sbt "test:testOnly *ExpressionTypeCheckingSuite"
    $ build/sbt "test:testOnly *ApproxCountDistinctForIntervalsSuite"
    ```

    Closes #38237 from MaxGekk/type-check-fails-interval-exprs.

    Authored-by: Max Gekk <max.g...@gmail.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 core/src/main/resources/error/error-classes.json   |  5 +++
 .../ApproxCountDistinctForIntervals.scala          | 31 +++++++++++---
 .../catalyst/expressions/aggregate/Average.scala   |  2 +-
 .../sql/catalyst/expressions/aggregate/Sum.scala   |  2 +-
 .../apache/spark/sql/catalyst/util/TypeUtils.scala | 20 +++++----
 .../apache/spark/sql/types/AbstractDataType.scala  |  9 ++++
 .../sql/catalyst/analysis/AnalysisSuite.scala      | 50 ++++++++++++++--------
 .../analysis/ExpressionTypeCheckingSuite.scala     | 26 +++++++++--
 .../ApproxCountDistinctForIntervalsSuite.scala     | 21 ++++++---
 9 files changed, 123 insertions(+), 43 deletions(-)

diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json
index 5f4db145479..0f9b665718c 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -263,6 +263,11 @@
         "The <exprName> must be between <valueRange> (current value = <currentValue>)"
       ]
     },
+    "WRONG_NUM_ENDPOINTS" : {
+      "message" : [
+        "The number of endpoints must be >= 2 to construct intervals but the actual number is <actualNumber>."
+      ]
+    },
     "WRONG_NUM_PARAMS" : {
       "message" : [
         "The <functionName> requires <expectedNum> parameters but the actual number is <actualNum>."
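For orientation: each sub-class entry above is a message template, and tokens such as `<actualNumber>` are substituted from the `messageParameters` map supplied at the call site. Below is a minimal, standalone sketch of that substitution; the helper is illustrative only and is not Spark's actual reader of error-classes.json.

```scala
// Minimal sketch of how an error-class message template is rendered.
// Illustrative only; Spark's real error-classes.json machinery differs.
object ErrorTemplateSketch {
  // Replace each <param> token in the template with its value.
  def render(template: String, params: Map[String, String]): String =
    params.foldLeft(template) { case (msg, (name, value)) =>
      msg.replace(s"<$name>", value)
    }

  def main(args: Array[String]): Unit = {
    val template = "The number of endpoints must be >= 2 to construct intervals " +
      "but the actual number is <actualNumber>."
    // Prints the message a user would see for DATATYPE_MISMATCH.WRONG_NUM_ENDPOINTS.
    println(render(template, Map("actualNumber" -> "1")))
  }
}
```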
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
index f3bf251ba0b..0be4e4aa465 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala
@@ -21,10 +21,11 @@ import java.util
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, GenericInternalRow}
 import org.apache.spark.sql.catalyst.trees.BinaryLike
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, HyperLogLogPlusPlusHelper}
+import org.apache.spark.sql.errors.QueryErrorsBase
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.Platform
 
@@ -49,7 +50,10 @@ case class ApproxCountDistinctForIntervals(
     relativeSD: Double = 0.05,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends TypedImperativeAggregate[Array[Long]] with ExpectsInputTypes with BinaryLike[Expression] {
+  extends TypedImperativeAggregate[Array[Long]]
+  with ExpectsInputTypes
+  with BinaryLike[Expression]
+  with QueryErrorsBase {
 
   def this(child: Expression, endpointsExpression: Expression, relativeSD: Expression) = {
     this(
@@ -77,19 +81,32 @@ case class ApproxCountDistinctForIntervals(
     if (defaultCheck.isFailure) {
       defaultCheck
     } else if (!endpointsExpression.foldable) {
-      TypeCheckFailure("The endpoints provided must be constant literals")
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> "endpointsExpression",
+          "inputType" -> toSQLType(endpointsExpression.dataType)))
     } else {
       endpointsExpression.dataType match {
         case ArrayType(_: NumericType | DateType | TimestampType | TimestampNTZType |
          _: AnsiIntervalType, _) =>
           if (endpoints.length < 2) {
-            TypeCheckFailure("The number of endpoints must be >= 2 to construct intervals")
+            DataTypeMismatch(
+              errorSubClass = "WRONG_NUM_ENDPOINTS",
+              messageParameters = Map("actualNumber" -> endpoints.length.toString))
           } else {
             TypeCheckSuccess
          }
-        case _ =>
-          TypeCheckFailure("Endpoints require (numeric or timestamp or date or timestamp_ntz or " +
-            "interval year to month or interval day to second) type")
+        case inputType =>
+          val requiredElemTypes = toSQLType(TypeCollection(
+            NumericType, DateType, TimestampType, TimestampNTZType, AnsiIntervalType))
+          DataTypeMismatch(
+            errorSubClass = "UNEXPECTED_INPUT_TYPE",
+            messageParameters = Map(
+              "paramIndex" -> "2",
+              "requiredType" -> s"ARRAY OF $requiredElemTypes",
+              "inputSql" -> toSQLExpr(endpointsExpression),
+              "inputType" -> toSQLType(inputType)))
      }
    }
  }
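The rewritten check above fails in one of three ways, in order: non-foldable endpoints, a wrongly typed endpoints array, or fewer than two endpoints. A standalone sketch of that decision order follows; `Success`/`Mismatch` are plain stand-ins for Spark's `TypeCheckSuccess`/`DataTypeMismatch`, and all names are illustrative.

```scala
// Standalone sketch of the three-way check above, with plain stand-ins
// for Spark's TypeCheckResult hierarchy.
object EndpointCheckSketch {
  sealed trait Result
  case object Success extends Result
  final case class Mismatch(subClass: String, params: Map[String, String]) extends Result

  // Mirrors the branch order in checkInputDataTypes: foldability first,
  // then element type, then endpoint count.
  def check(foldable: Boolean, elemTypeOk: Boolean, numEndpoints: Int): Result =
    if (!foldable) {
      Mismatch("NON_FOLDABLE_INPUT", Map("inputName" -> "endpointsExpression"))
    } else if (!elemTypeOk) {
      Mismatch("UNEXPECTED_INPUT_TYPE", Map("paramIndex" -> "2"))
    } else if (numEndpoints < 2) {
      Mismatch("WRONG_NUM_ENDPOINTS", Map("actualNumber" -> numEndpoints.toString))
    } else {
      Success
    }

  def main(args: Array[String]): Unit = {
    // One endpoint only: reproduces the WRONG_NUM_ENDPOINTS case.
    println(check(foldable = true, elemTypeOk = true, numEndpoints = 1))
  }
}
```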
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index ae644e9d663..ce9fa0575f2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -54,7 +54,7 @@ case class Average(
     Seq(TypeCollection(NumericType, YearMonthIntervalType, DayTimeIntervalType))
 
   override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForAnsiIntervalOrNumericType(child.dataType, "average")
+    TypeUtils.checkForAnsiIntervalOrNumericType(child)
 
   override def nullable: Boolean = true
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index 432d4b40b4a..2c892903437 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -67,7 +67,7 @@ case class Sum(
     Seq(TypeCollection(NumericType, YearMonthIntervalType, DayTimeIntervalType))
 
   override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForAnsiIntervalOrNumericType(child.dataType, prettyName)
+    TypeUtils.checkForAnsiIntervalOrNumericType(child)
 
   final override val nodePatterns: Seq[TreePattern] = Seq(SUM)
 
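The signature change in both aggregates is the interesting part: each call site now hands the helper the whole child `Expression` rather than just its `DataType`, and the updated helper in TypeUtils (next diff) uses it to fill the `inputSql` parameter via `toSQLExpr`, which a bare `DataType` could not provide. A simplified sketch of the idea, using stand-in types rather than Spark's API:

```scala
// Sketch: keeping the whole expression lets the error quote its SQL fragment.
// Expr and the strings below are stand-ins, not Spark's types.
object InputSqlSketch {
  final case class Expr(sql: String, dataType: String)

  def checkNumericOrInterval(input: Expr): Either[Map[String, String], Unit] =
    input.dataType match {
      case "NUMERIC" | "ANSI INTERVAL" => Right(())
      case other => Left(Map(
        "paramIndex" -> "1",
        "requiredType" -> "\"NUMERIC\" or \"ANSI INTERVAL\"",
        // Only possible because we kept the expression, not just its type:
        "inputSql" -> s"\"${input.sql}\"",
        "inputType" -> s"\"$other\""))
    }

  def main(args: Array[String]): Unit = {
    println(checkNumericOrInterval(Expr("booleanField", "BOOLEAN")))
  }
}
```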
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index 7cb471d14bd..0bb5d29c5c4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -19,15 +19,14 @@ package org.apache.spark.sql.catalyst.util
 
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
-import org.apache.spark.sql.catalyst.expressions.Cast.toSQLType
-import org.apache.spark.sql.catalyst.expressions.RowOrdering
-import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.catalyst.expressions.{Expression, RowOrdering}
+import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
 import org.apache.spark.sql.types._
 
 /**
  * Functions to help with checking for valid data types and value comparison of various types.
  */
-object TypeUtils {
+object TypeUtils extends QueryErrorsBase {
 
   def checkForOrderingExpr(dt: DataType, caller: String): TypeCheckResult = {
     if (RowOrdering.isOrderable(dt)) {
@@ -70,13 +69,18 @@ object TypeUtils {
     }
   }
 
-  def checkForAnsiIntervalOrNumericType(
-      dt: DataType, funcName: String): TypeCheckResult = dt match {
+  def checkForAnsiIntervalOrNumericType(input: Expression): TypeCheckResult = input.dataType match {
     case _: AnsiIntervalType | NullType => TypeCheckResult.TypeCheckSuccess
     case dt if dt.isInstanceOf[NumericType] => TypeCheckResult.TypeCheckSuccess
-    case other => TypeCheckResult.TypeCheckFailure(
-      s"function $funcName requires numeric or interval types, not ${other.catalogString}")
+    case other =>
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_INPUT_TYPE",
+        messageParameters = Map(
+          "paramIndex" -> "1",
+          "requiredType" -> Seq(NumericType, AnsiIntervalType).map(toSQLType).mkString(" or "),
+          "inputSql" -> toSQLExpr(input),
+          "inputType" -> toSQLType(other)))
   }
 
   def getNumeric(t: DataType, exactNumericRequired: Boolean = false): Numeric[Any] = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index ebcf35a0674..294fb13e48c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -233,3 +233,12 @@ private[sql] abstract class DatetimeType extends AtomicType
  * The interval type which conforms to the ANSI SQL standard.
  */
 private[sql] abstract class AnsiIntervalType extends AtomicType
+
+private[spark] object AnsiIntervalType extends AbstractDataType {
+  override private[sql] def simpleString: String = "ANSI interval"
+
+  override private[sql] def acceptsType(other: DataType): Boolean =
+    other.isInstanceOf[AnsiIntervalType]
+
+  override private[sql] def defaultConcreteType: DataType = DayTimeIntervalType()
+}
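The new companion object is what lets `AnsiIntervalType` stand in for both concrete interval types inside a `TypeCollection` (as in ApproxCountDistinctForIntervals and TypeUtils above) and be rendered as a single required type in error messages. A simplified sketch of the pattern, with stand-in objects rather than Spark's type hierarchy:

```scala
// Sketch of the AbstractDataType pattern: a category object that accepts any
// concrete member type. All names here are simplified stand-ins for Spark's.
object AbstractTypeSketch {
  sealed trait ConcreteType
  case object YearMonthInterval extends ConcreteType
  case object DayTimeInterval extends ConcreteType
  case object BooleanType extends ConcreteType

  // Mirrors the new AnsiIntervalType companion object above.
  object AnsiInterval {
    val simpleString = "ANSI interval"
    def acceptsType(other: ConcreteType): Boolean = other match {
      case YearMonthInterval | DayTimeInterval => true
      case _ => false
    }
    // Analogue of defaultConcreteType in the real object.
    def defaultConcreteType: ConcreteType = DayTimeInterval
  }

  def main(args: Array[String]): Unit = {
    println(AnsiInterval.acceptsType(DayTimeInterval)) // true
    println(AnsiInterval.acceptsType(BooleanType))     // false
  }
}
```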
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 3036742c83f..6f0e6ef0c11 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -1163,25 +1163,39 @@ class AnalysisSuite extends AnalysisTest with Matchers {
   }
 
   test("SPARK-38118: Func(wrong_type) in the HAVING clause should throw data mismatch error") {
-    assertAnalysisError(parsePlan(
-      s"""
-         |WITH t as (SELECT true c)
-         |SELECT t.c
-         |FROM t
-         |GROUP BY t.c
-         |HAVING mean(t.c) > 0d""".stripMargin),
-      Seq(s"cannot resolve 'mean(t.c)' due to data type mismatch"),
-      false)
+    assertAnalysisErrorClass(
+      inputPlan = parsePlan(
+        s"""
+           |WITH t as (SELECT true c)
+           |SELECT t.c
+           |FROM t
+           |GROUP BY t.c
+           |HAVING mean(t.c) > 0d""".stripMargin),
+      expectedErrorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      expectedMessageParameters = Map(
+        "sqlExpr" -> "\"mean(c)\"",
+        "paramIndex" -> "1",
+        "inputSql" -> "\"c\"",
+        "inputType" -> "\"BOOLEAN\"",
+        "requiredType" -> "\"NUMERIC\" or \"ANSI INTERVAL\""),
+      caseSensitive = false)
 
-    assertAnalysisError(parsePlan(
-      s"""
-         |WITH t as (SELECT true c, false d)
-         |SELECT (t.c AND t.d) c
-         |FROM t
-         |GROUP BY t.c, t.d
-         |HAVING mean(c) > 0d""".stripMargin),
-      Seq(s"cannot resolve 'mean(t.c)' due to data type mismatch"),
-      false)
+    assertAnalysisErrorClass(
+      inputPlan = parsePlan(
+        s"""
+           |WITH t as (SELECT true c, false d)
+           |SELECT (t.c AND t.d) c
+           |FROM t
+           |GROUP BY t.c, t.d
+           |HAVING mean(c) > 0d""".stripMargin),
+      expectedErrorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      expectedMessageParameters = Map(
+        "sqlExpr" -> "\"mean(c)\"",
+        "paramIndex" -> "1",
+        "inputSql" -> "\"c\"",
+        "inputType" -> "\"BOOLEAN\"",
+        "requiredType" -> "\"NUMERIC\" or \"ANSI INTERVAL\""),
+      caseSensitive = false)
 
     assertAnalysisErrorClass(
       inputPlan = parsePlan(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 991721a55ca..b41f627bac9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -396,9 +396,29 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer
         "dataType" -> "\"MAP<STRING, BIGINT>\""
       )
     )
-    assertError(Sum($"booleanField"), "function sum requires numeric or interval types")
-    assertError(Average($"booleanField"),
-      "function average requires numeric or interval types")
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        assertSuccess(Sum($"booleanField"))
+      },
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> "\"sum(booleanField)\"",
+        "paramIndex" -> "1",
+        "inputSql" -> "\"booleanField\"",
+        "inputType" -> "\"BOOLEAN\"",
+        "requiredType" -> "\"NUMERIC\" or \"ANSI INTERVAL\""))
+    checkError(
+      exception = intercept[AnalysisException] {
+        assertSuccess(Average($"booleanField"))
+      },
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> "\"avg(booleanField)\"",
+        "paramIndex" -> "1",
+        "inputSql" -> "\"booleanField\"",
+        "inputType" -> "\"BOOLEAN\"",
+        "requiredType" -> "\"NUMERIC\" or \"ANSI INTERVAL\""))
   }
 
   test("check types for others") {
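A side benefit visible in both suites above: assertions now target the structured error (class name plus parameter map) instead of message substrings, so the tests survive rewording of the templates; the suite below does the same by comparing `DataTypeMismatch` values directly. A minimal sketch of this assertion style with a stand-in error type (Spark's real helpers are `checkError` and `SparkThrowable`-based exceptions):

```scala
// Sketch of asserting on a structured error instead of a message substring.
// StructuredError is a stand-in, not Spark's exception hierarchy.
object CheckErrorSketch {
  final case class StructuredError(errorClass: String, parameters: Map[String, String])
    extends RuntimeException(s"[$errorClass] $parameters")

  def checkError(thrown: StructuredError, errorClass: String,
      parameters: Map[String, String]): Unit = {
    assert(thrown.errorClass == errorClass, s"expected $errorClass, got ${thrown.errorClass}")
    assert(thrown.parameters == parameters, s"parameter mismatch: ${thrown.parameters}")
  }

  def main(args: Array[String]): Unit = {
    val e = StructuredError("DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
      Map("paramIndex" -> "1", "inputType" -> "\"BOOLEAN\""))
    // Passes: both the error class and the full parameter map match.
    checkError(e, "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
      Map("paramIndex" -> "1", "inputType" -> "\"BOOLEAN\""))
    println("structured assertion passed")
  }
}
```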
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala
index d00193c4f3b..bb99e1c1e8e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala
@@ -22,7 +22,7 @@ import java.time.LocalDateTime
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, CreateArray, Literal, SpecificInternalRow}
 import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils}
 import org.apache.spark.sql.types._
@@ -48,20 +48,31 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite {
       AttributeReference("a", DoubleType)(),
       endpointsExpression = CreateArray(Seq(AttributeReference("b", DoubleType)())))
     assert(wrongEndpoints.checkInputDataTypes() ==
-      TypeCheckFailure("The endpoints provided must be constant literals"))
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> "endpointsExpression",
+          "inputType" -> "\"ARRAY<DOUBLE>\"")))
 
     wrongEndpoints = ApproxCountDistinctForIntervals(
       AttributeReference("a", DoubleType)(),
       endpointsExpression = CreateArray(Array(10L).map(Literal(_))))
     assert(wrongEndpoints.checkInputDataTypes() ==
-      TypeCheckFailure("The number of endpoints must be >= 2 to construct intervals"))
+      DataTypeMismatch("WRONG_NUM_ENDPOINTS", Map("actualNumber" -> "1")))
 
     wrongEndpoints = ApproxCountDistinctForIntervals(
       AttributeReference("a", DoubleType)(),
       endpointsExpression = CreateArray(Array("foobar").map(Literal(_))))
+    // scalastyle:off line.size.limit
     assert(wrongEndpoints.checkInputDataTypes() ==
-      TypeCheckFailure("Endpoints require (numeric or timestamp or date or timestamp_ntz or " +
-        "interval year to month or interval day to second) type"))
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_INPUT_TYPE",
+        messageParameters = Map(
+          "paramIndex" -> "2",
+          "requiredType" -> "ARRAY OF (\"NUMERIC\" or \"DATE\" or \"TIMESTAMP\" or \"TIMESTAMP_NTZ\" or \"ANSI INTERVAL\")",
+          "inputSql" -> "\"array(foobar)\"",
+          "inputType" -> "\"ARRAY<STRING>\"")))
+    // scalastyle:on line.size.limit
   }
 
   /** Create an ApproxCountDistinctForIntervals instance and an input and output buffer. */

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org