This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 74668e2bf14 [SPARK-40751][SQL] Migrate type check failures of high order functions onto error classes
74668e2bf14 is described below

commit 74668e2bf14760dbc60509f7736f410c09084697
Author: panbingkun <pbk1...@gmail.com>
AuthorDate: Thu Oct 27 13:47:54 2022 +0300

    [SPARK-40751][SQL] Migrate type check failures of high order functions onto error classes

    ### What changes were proposed in this pull request?
    This PR aims to replace `TypeCheckFailure` with `DataTypeMismatch` in the type checks of the higher-order function expressions, which include:
    1. ArraySort (2): https://github.com/apache/spark/blob/1431975723d8df30a25b2333eddcfd0bb6c57677/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala#L403-L407
    2. ArrayAggregate (1): https://github.com/apache/spark/blob/1431975723d8df30a25b2333eddcfd0bb6c57677/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala#L807
    3. MapZipWith (1): https://github.com/apache/spark/blob/1431975723d8df30a25b2333eddcfd0bb6c57677/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala#L1028

    ### Why are the changes needed?
    Migration onto error classes unifies Spark SQL error messages.

    ### Does this PR introduce _any_ user-facing change?
    Yes. The PR changes user-facing error messages.

    ### How was this patch tested?
    - Updated existing UTs.
    - Pass GA.

    Closes #38359 from panbingkun/SPARK-40751.

    Authored-by: panbingkun <pbk1...@gmail.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 core/src/main/resources/error/error-classes.json   | 10 +++
 .../scala/org/apache/spark/SparkFunSuite.scala     |  6 ++
 .../expressions/higherOrderFunctions.scala         | 43 ++++++++---
 .../expressions/HigherOrderFunctionsSuite.scala    | 18 +++++
 .../results/typeCoercion/native/mapZipWith.sql.out | 35 ++++++++-
 .../apache/spark/sql/DataFrameFunctionsSuite.scala | 89 +++++++++++++++++-----
 6 files changed, 171 insertions(+), 30 deletions(-)

diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json
index d72eeece82e..015d86171d7 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -200,6 +200,11 @@
       "The <functionName> accepts only arrays of pair structs, but <childExpr> is of <childType>."
     ]
   },
+  "MAP_ZIP_WITH_DIFF_TYPES" : {
+    "message" : [
+      "Input to the <functionName> should have been two maps with compatible key types, but it's [<leftType>, <rightType>]."
+    ]
+  },
   "NON_FOLDABLE_INPUT" : {
     "message" : [
       "the input <inputName> should be a foldable <inputType> expression; however, got <inputExpr>."
@@ -275,6 +280,11 @@
       "The <exprName> must not be null"
     ]
   },
+  "UNEXPECTED_RETURN_TYPE" : {
+    "message" : [
+      "The <functionName> requires return <expectedType> type, but the actual is <actualType> type."
+    ]
+  },
   "UNEXPECTED_STATIC_METHOD" : {
     "message" : [
       "cannot find a static method <methodName> that matches the argument types in <className>"
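The two JSON entries above are message templates: a `DataTypeMismatch` result names an error subclass and supplies `messageParameters` that are substituted into the `<...>` placeholders. The following is a minimal, self-contained sketch of that lookup-and-substitution idea, using hypothetical stand-ins rather than the actual Catalyst classes:

```scala
// Hypothetical, simplified model of the migration -- not the real Catalyst types.
object ErrorClassSketch extends App {
  sealed trait TypeCheckResult
  case object TypeCheckSuccess extends TypeCheckResult
  // Old style: an opaque, free-form message string.
  case class TypeCheckFailure(message: String) extends TypeCheckResult
  // New style: a machine-readable subclass plus named parameters that are
  // substituted into the matching template in error-classes.json.
  case class DataTypeMismatch(
      errorSubClass: String,
      messageParameters: Map[String, String]) extends TypeCheckResult

  // Naive renderer: replaces each <name> placeholder with its parameter value.
  def render(template: String, params: Map[String, String]): String =
    params.foldLeft(template) { case (msg, (k, v)) => msg.replace(s"<$k>", v) }

  println(render(
    "The <functionName> requires return <expectedType> type, but the actual is <actualType> type.",
    Map(
      "functionName" -> "`lambdafunction`",
      "expectedType" -> "\"INT\"",
      "actualType" -> "\"STRING\"")))
  // The `lambdafunction` requires return "INT" type, but the actual is "STRING" type.
}
```

The design gain over the old free-form strings is that callers can match on the error class and parameters programmatically instead of grepping message text.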
diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
index 46b62d879cf..7a08de9c181 100644
--- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
@@ -370,6 +370,12 @@ abstract class SparkFunSuite
     checkError(exception, errorClass, sqlState, parameters, false, Array(context))

+  protected def checkErrorMatchPVals(
+      exception: SparkThrowable,
+      errorClass: String,
+      parameters: Map[String, String]): Unit =
+    checkError(exception, errorClass, None, parameters, matchPVals = true)
+
   protected def checkErrorMatchPVals(
       exception: SparkThrowable,
       errorClass: String,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index 98513fb5ddd..b59860ff181 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -24,6 +24,8 @@ import scala.collection.mutable

 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedException}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
+import org.apache.spark.sql.catalyst.expressions.Cast._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.trees.{BinaryLike, QuaternaryLike, TernaryLike}
 import org.apache.spark.sql.catalyst.trees.TreePattern._
@@ -400,11 +402,25 @@ case class ArraySort(
         if (function.dataType == IntegerType) {
           TypeCheckResult.TypeCheckSuccess
         } else {
-          TypeCheckResult.TypeCheckFailure("Return type of the given function has to be " +
-            "IntegerType")
+          DataTypeMismatch(
+            errorSubClass = "UNEXPECTED_RETURN_TYPE",
+            messageParameters = Map(
+              "functionName" -> toSQLId(function.prettyName),
+              "expectedType" -> toSQLType(IntegerType),
+              "actualType" -> toSQLType(function.dataType)
+            )
+          )
         }
       case _ =>
-        TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.")
+        DataTypeMismatch(
+          errorSubClass = "UNEXPECTED_INPUT_TYPE",
+          messageParameters = Map(
+            "paramIndex" -> "1",
+            "requiredType" -> toSQLType(ArrayType),
+            "inputSql" -> toSQLExpr(argument),
+            "inputType" -> toSQLType(argument.dataType)
+          )
+        )
     }
     case failure => failure
   }
@@ -804,9 +820,13 @@ case class ArrayAggregate(
       case TypeCheckResult.TypeCheckSuccess =>
         if (!DataType.equalsStructurally(
             zero.dataType, merge.dataType, ignoreNullability = true)) {
-          TypeCheckResult.TypeCheckFailure(
-            s"argument 3 requires ${zero.dataType.simpleString} type, " +
-              s"however, '${merge.sql}' is of ${merge.dataType.catalogString} type.")
+          DataTypeMismatch(
+            errorSubClass = "UNEXPECTED_INPUT_TYPE",
+            messageParameters = Map(
+              "paramIndex" -> "3",
+              "requiredType" -> toSQLType(zero.dataType),
+              "inputSql" -> toSQLExpr(merge),
+              "inputType" -> toSQLType(merge.dataType)))
         } else {
           TypeCheckResult.TypeCheckSuccess
         }
@@ -1025,9 +1045,14 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression)
       if (leftKeyType.sameType(rightKeyType)) {
         TypeUtils.checkForOrderingExpr(leftKeyType, prettyName)
       } else {
-        TypeCheckResult.TypeCheckFailure(s"The input to function $prettyName should have " +
-          s"been two ${MapType.simpleString}s with compatible key types, but the key types are " +
-          s"[${leftKeyType.catalogString}, ${rightKeyType.catalogString}].")
+        DataTypeMismatch(
+          errorSubClass = "MAP_ZIP_WITH_DIFF_TYPES",
+          messageParameters = Map(
+            "functionName" -> toSQLId(prettyName),
+            "leftType" -> toSQLType(leftKeyType),
+            "rightType" -> toSQLType(rightKeyType)
+          )
+        )
       }
     case failure => failure
   }
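From the user's side, the first of these checks means the `array_sort` comparator lambda must return an integer (negative, zero, or positive). A hedged reproduction of both the passing and failing cases; it assumes a local SparkSession named `spark`, and the exact failure text is rendered from the UNEXPECTED_RETURN_TYPE template above:

```scala
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.functions._

object ArraySortComparatorDemo extends App {
  val spark = SparkSession.builder().master("local[*]").appName("demo").getOrCreate()
  import spark.implicits._

  val df = Seq(Seq(3, 1, 2)).toDF("xs")

  // Valid: the comparator returns an integer column, as ArraySort requires.
  df.select(array_sort($"xs", (x, y) => y - x)).show()

  // Invalid: the comparator returns a string, so analysis fails with
  // DATATYPE_MISMATCH.UNEXPECTED_RETURN_TYPE instead of the old
  // "Return type of the given function has to be IntegerType" text.
  try df.select(array_sort($"xs", (x, y) => lit("hello"))).show()
  catch { case e: AnalysisException => println(e.getMessage) }

  spark.stop()
}
```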
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
index a6546d8a5db..5f62dc97086 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions

 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
+import org.apache.spark.sql.catalyst.expressions.Cast._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._

@@ -859,4 +861,20 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
       Seq(1, 1, 2, 3))
     }
   }
+
+  test("Return type of the given function has to be IntegerType") {
+    val comparator = {
+      val comp = ArraySort.comparator _
+      (left: Expression, right: Expression) => Literal.create("hello", StringType)
+    }
+
+    val result = arraySort(Literal.create(Seq(3, 1, 1, 2)), comparator).checkInputDataTypes()
+    assert(result == DataTypeMismatch(
+      errorSubClass = "UNEXPECTED_RETURN_TYPE",
+      messageParameters = Map(
+        "functionName" -> toSQLId("lambdafunction"),
+        "expectedType" -> toSQLType(IntegerType),
+        "actualType" -> toSQLType(StringType)
+      )))
+  }
 }
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out
index 2f176951df8..09c6e10f762 100644
--- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out
@@ -82,8 +82,22 @@ FROM various_maps
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve 'map_zip_with(various_maps.decimal_map1, various_maps.decimal_map2, lambdafunction(struct(k, v1, v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7
-
+{
+  "errorClass" : "DATATYPE_MISMATCH.MAP_ZIP_WITH_DIFF_TYPES",
+  "messageParameters" : {
+    "functionName" : "`map_zip_with`",
+    "leftType" : "\"DECIMAL(36,0)\"",
+    "rightType" : "\"DECIMAL(36,35)\"",
+    "sqlExpr" : "\"map_zip_with(decimal_map1, decimal_map2, lambdafunction(struct(k, v1, v2), k, v1, v2))\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 81,
+    "fragment" : "map_zip_with(decimal_map1, decimal_map2, (k, v1, v2) -> struct(k, v1, v2))"
+  } ]
+}

 -- !query
 SELECT map_zip_with(decimal_map1, int_map, (k, v1, v2) -> struct(k, v1, v2)) m
@@ -110,7 +124,22 @@ FROM various_maps
 struct<>
 -- !query output
 org.apache.spark.sql.AnalysisException
-cannot resolve 'map_zip_with(various_maps.decimal_map2, various_maps.int_map, lambdafunction(struct(k, v1, v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,35), int].; line 1 pos 7
+{
+  "errorClass" : "DATATYPE_MISMATCH.MAP_ZIP_WITH_DIFF_TYPES",
+  "messageParameters" : {
+    "functionName" : "`map_zip_with`",
+    "leftType" : "\"DECIMAL(36,35)\"",
+    "rightType" : "\"INT\"",
+    "sqlExpr" : "\"map_zip_with(decimal_map2, int_map, lambdafunction(struct(k, v1, v2), k, v1, v2))\""
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 76,
+    "fragment" : "map_zip_with(decimal_map2, int_map, (k, v1, v2) -> struct(k, v1, v2))"
+  } ]
+}

 -- !query
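The regenerated golden file shows the new structured JSON output in place of the old free-form message. A hedged way to reproduce the same failure outside the test harness; it assumes a local SparkSession, and relies on the fact that DECIMAL(36,0) and DECIMAL(36,35) cannot be widened to a common key type within the 38-digit decimal precision limit (which is why coercion gives up and the check fires):

```scala
import org.apache.spark.sql.{AnalysisException, SparkSession}

object MapZipWithKeyTypesDemo extends App {
  val spark = SparkSession.builder().master("local[*]").appName("demo").getOrCreate()

  // Compatible key types: both maps are keyed by INT, so the keys can be unioned.
  spark.sql(
    "SELECT map_zip_with(map(1, 'a'), map(1, 'b'), (k, v1, v2) -> concat(v1, v2))").show()

  // Incompatible key types: widening DECIMAL(36,0) and DECIMAL(36,35) would need
  // more than 38 digits, so analysis fails with
  // DATATYPE_MISMATCH.MAP_ZIP_WITH_DIFF_TYPES, as in the golden file above.
  try spark.sql(
    """SELECT map_zip_with(map(CAST(1 AS DECIMAL(36,0)), 'a'),
      |                    map(CAST(1 AS DECIMAL(36,35)), 'b'),
      |                    (k, v1, v2) -> concat(v1, v2))""".stripMargin).show()
  catch { case e: AnalysisException => println(e.getMessage) }

  spark.stop()
}
```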
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 85877c97ed5..3f02429fe62 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -533,6 +533,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
     )
   }

+  test("The given function only supports array input") {
+    val df = Seq(1, 2, 3).toDF("a")
+    checkErrorMatchPVals(
+      exception = intercept[AnalysisException] {
+        df.select(array_sort(col("a"), (x, y) => x - y))
+      },
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> """"array_sort\(a, lambdafunction\(\(x_\d+ - y_\d+\), x_\d+, y_\d+\)\)"""",
+        "paramIndex" -> "1",
+        "requiredType" -> "\"ARRAY\"",
+        "inputSql" -> "\"a\"",
+        "inputType" -> "\"INT\""
+      ))
+  }
+
   test("sort_array/array_sort functions") {
     val df = Seq(
       (Array[Int](2, 1, 3), Array("b", "c", "a")),
@@ -3492,15 +3508,35 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
         "requiredType" -> "\"ARRAY\""))
     // scalastyle:on line.size.limit

-    val ex4 = intercept[AnalysisException] {
-      df.selectExpr("aggregate(s, 0, (acc, x) -> x)")
-    }
-    assert(ex4.getMessage.contains("data type mismatch: argument 3 requires int type"))
+    // scalastyle:off line.size.limit
+    checkError(
+      exception = intercept[AnalysisException] {
+        df.selectExpr("aggregate(s, 0, (acc, x) -> x)")
+      },
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> """"aggregate(s, 0, lambdafunction(namedlambdavariable(), namedlambdavariable(), namedlambdavariable()), lambdafunction(namedlambdavariable(), namedlambdavariable()))"""",
+        "paramIndex" -> "3",
+        "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable(), namedlambdavariable())\"",
+        "inputType" -> "\"STRING\"",
+        "requiredType" -> "\"INT\""
+      ))
+    // scalastyle:on line.size.limit

-    val ex4a = intercept[AnalysisException] {
-      df.select(aggregate(col("s"), lit(0), (acc, x) => x))
-    }
-    assert(ex4a.getMessage.contains("data type mismatch: argument 3 requires int type"))
+    // scalastyle:off line.size.limit
+    checkError(
+      exception = intercept[AnalysisException] {
+        df.select(aggregate(col("s"), lit(0), (acc, x) => x))
+      },
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> """"aggregate(s, 0, lambdafunction(namedlambdavariable(), namedlambdavariable(), namedlambdavariable()), lambdafunction(namedlambdavariable(), namedlambdavariable()))"""",
+        "paramIndex" -> "3",
+        "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable(), namedlambdavariable())\"",
+        "inputType" -> "\"STRING\"",
+        "requiredType" -> "\"INT\""
+      ))
+    // scalastyle:on line.size.limit

     checkError(
       exception =
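A note on the helper these expectations rely on: lambda variables created through the DataFrame API receive fresh numeric suffixes (x_0, y_1, ...), so the expected `sqlExpr` cannot be compared literally. `checkErrorMatchPVals` (added to SparkFunSuite above) therefore treats each expected parameter value as a regular expression, which is why the expectations escape the metacharacters and use `x_\d+`. A small sketch of that matching idea, with a hypothetical helper rather than the SparkFunSuite implementation:

```scala
// Hypothetical sketch of regex-based parameter matching, in the spirit of
// checkErrorMatchPVals; the real helper lives in SparkFunSuite.
object MatchPValsSketch extends App {
  // Returns true when every expected pattern fully matches the actual value.
  def matchParams(expected: Map[String, String], actual: Map[String, String]): Boolean =
    expected.forall { case (key, pattern) =>
      actual.get(key).exists(v => pattern.r.pattern.matcher(v).matches())
    }

  val expected = Map(
    "sqlExpr" -> """"array_sort\(a, lambdafunction\(\(x_\d+ - y_\d+\), x_\d+, y_\d+\)\)"""",
    "inputType" -> "\"INT\"")
  val actual = Map(
    "sqlExpr" -> "\"array_sort(a, lambdafunction((x_17 - y_18), x_17, y_18))\"",
    "inputType" -> "\"INT\"")

  println(matchParams(expected, actual)) // true: x_17/y_18 match the \d+ patterns
}
```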
-> "3", + "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable(), namedlambdavariable())\"", + "inputType" -> "\"STRING\"", + "requiredType" -> "\"INT\"" + )) + // scalastyle:on line.size.limit checkError( exception = @@ -3570,17 +3606,34 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { } assert(ex1.getMessage.contains("The number of lambda function arguments '2' does not match")) - val ex2 = intercept[AnalysisException] { - df.selectExpr("map_zip_with(mis, mmi, (x, y, z) -> concat(x, y, z))") - } - assert(ex2.getMessage.contains("The input to function map_zip_with should have " + - "been two maps with compatible key types")) + checkError( + exception = intercept[AnalysisException] { + df.selectExpr("map_zip_with(mis, mmi, (x, y, z) -> concat(x, y, z))") + }, + errorClass = "DATATYPE_MISMATCH.MAP_ZIP_WITH_DIFF_TYPES", + parameters = Map( + "sqlExpr" -> "\"map_zip_with(mis, mmi, lambdafunction(concat(x, y, z), x, y, z))\"", + "functionName" -> "`map_zip_with`", + "leftType" -> "\"INT\"", + "rightType" -> "\"MAP<INT, INT>\""), + context = ExpectedContext( + fragment = "map_zip_with(mis, mmi, (x, y, z) -> concat(x, y, z))", + start = 0, + stop = 51)) - val ex2a = intercept[AnalysisException] { - df.select(map_zip_with(df("mis"), col("mmi"), (x, y, z) => concat(x, y, z))) - } - assert(ex2a.getMessage.contains("The input to function map_zip_with should have " + - "been two maps with compatible key types")) + // scalastyle:off line.size.limit + checkError( + exception = intercept[AnalysisException] { + df.select(map_zip_with(df("mis"), col("mmi"), (x, y, z) => concat(x, y, z))) + }, + errorClass = "DATATYPE_MISMATCH.MAP_ZIP_WITH_DIFF_TYPES", + matchPVals = true, + parameters = Map( + "sqlExpr" -> """"map_zip_with\(mis, mmi, lambdafunction\(concat\(x_\d+, y_\d+, z_\d+\), x_\d+, y_\d+, z_\d+\)\)"""", + "functionName" -> "`map_zip_with`", + "leftType" -> "\"INT\"", + "rightType" -> "\"MAP<INT, INT>\"")) + // scalastyle:on line.size.limit checkError( exception = intercept[AnalysisException] { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org