This is an automated email from the ASF dual-hosted git repository. maxgekk pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 68531ada34d [SPARK-40374][SQL] Migrate type check failures of type creators onto error classes 68531ada34d is described below commit 68531ada34db72d352c39396f85458a8370af812 Author: panbingkun <pbk1...@gmail.com> AuthorDate: Wed Nov 2 14:51:36 2022 +0300 [SPARK-40374][SQL] Migrate type check failures of type creators onto error classes ### What changes were proposed in this pull request? This PR replaces `TypeCheckFailure` with `DataTypeMismatch` in the type checks of the complex type creator expressions, including: 1. CreateMap (3): https://github.com/apache/spark/blob/1431975723d8df30a25b2333eddcfd0bb6c57677/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala#L205-L214 2. CreateNamedStruct (3): https://github.com/apache/spark/blob/1431975723d8df30a25b2333eddcfd0bb6c57677/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala#L445-L457 3. UpdateFields (2): https://github.com/apache/spark/blob/1431975723d8df30a25b2333eddcfd0bb6c57677/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala#L670-L673 ### Why are the changes needed? Migration onto error classes unifies Spark SQL error messages. ### Does this PR introduce _any_ user-facing change? Yes. The PR changes user-facing error messages. ### How was this patch tested? 1. Added new unit tests (UT). 2. Updated existing unit tests. 3. Passed GitHub Actions (GA). Closes #38463 from panbingkun/SPARK-40374. 
Authored-by: panbingkun <pbk1...@gmail.com> Signed-off-by: Max Gekk <max.g...@gmail.com> --- core/src/main/resources/error/error-classes.json | 20 ++++++ .../catalyst/expressions/complexTypeCreator.scala | 72 ++++++++++++++----- .../analysis/ExpressionTypeCheckingSuite.scala | 83 ++++++++++++++++------ .../catalyst/expressions/ComplexTypeSuite.scala | 47 ++++++++++++ .../main/scala/org/apache/spark/sql/Column.scala | 2 +- .../apache/spark/sql/ColumnExpressionSuite.scala | 82 ++++++++++++++++----- 6 files changed, 250 insertions(+), 56 deletions(-) diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json index fe2cd3a44bb..7ec5e11a206 100644 --- a/core/src/main/resources/error/error-classes.json +++ b/core/src/main/resources/error/error-classes.json @@ -138,6 +138,11 @@ "Unable to convert column <name> of type <type> to JSON." ] }, + "CANNOT_DROP_ALL_FIELDS" : { + "message" : [ + "Cannot drop all fields in struct." + ] + }, "CAST_WITHOUT_SUGGESTION" : { "message" : [ "cannot cast <srcType> to <targetType>." @@ -155,6 +160,21 @@ "To convert values from <srcType> to <targetType>, you can use the functions <functionNames> instead." ] }, + "CREATE_MAP_KEY_DIFF_TYPES" : { + "message" : [ + "The given keys of function <functionName> should all be the same type, but they are <dataType>." + ] + }, + "CREATE_MAP_VALUE_DIFF_TYPES" : { + "message" : [ + "The given values of function <functionName> should all be the same type, but they are <dataType>." + ] + }, + "CREATE_NAMED_STRUCT_WITHOUT_FOLDABLE_STRING" : { + "message" : [ + "Only foldable `STRING` expressions are allowed to appear at odd position, but they are <inputExprs>." + ] + }, "DATA_DIFF_TYPES" : { "message" : [ "Input to <functionName> should all be the same type, but it's <dataType>." 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 27d4f506ac8..97c882fd176 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -22,6 +22,8 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{FUNC_ALIAS, FunctionBuilder} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch +import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser @@ -202,16 +204,30 @@ case class CreateMap(children: Seq[Expression], useStringTypeWhenEmpty: Boolean) override def checkInputDataTypes(): TypeCheckResult = { if (children.size % 2 != 0) { - TypeCheckResult.TypeCheckFailure( - s"$prettyName expects a positive even number of arguments.") + DataTypeMismatch( + errorSubClass = "WRONG_NUM_ARGS", + messageParameters = Map( + "functionName" -> toSQLId(prettyName), + "expectedNum" -> "2n (n > 0)", + "actualNum" -> children.length.toString + ) + ) } else if (!TypeCoercion.haveSameType(keys.map(_.dataType))) { - TypeCheckResult.TypeCheckFailure( - "The given keys of function map should all be the same type, but they are " + - keys.map(_.dataType.catalogString).mkString("[", ", ", "]")) + DataTypeMismatch( + errorSubClass = "CREATE_MAP_KEY_DIFF_TYPES", + messageParameters = Map( + "functionName" -> toSQLId(prettyName), + "dataType" -> 
keys.map(key => toSQLType(key.dataType)).mkString("[", ", ", "]") + ) + ) } else if (!TypeCoercion.haveSameType(values.map(_.dataType))) { - TypeCheckResult.TypeCheckFailure( - "The given values of function map should all be the same type, but they are " + - values.map(_.dataType.catalogString).mkString("[", ", ", "]")) + DataTypeMismatch( + errorSubClass = "CREATE_MAP_VALUE_DIFF_TYPES", + messageParameters = Map( + "functionName" -> toSQLId(prettyName), + "dataType" -> values.map(value => toSQLType(value.dataType)).mkString("[", ", ", "]") + ) + ) } else { TypeUtils.checkForMapKeyType(dataType.keyType) } @@ -444,17 +460,32 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression with override def checkInputDataTypes(): TypeCheckResult = { if (children.size % 2 != 0) { - TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.") + DataTypeMismatch( + errorSubClass = "WRONG_NUM_ARGS", + messageParameters = Map( + "functionName" -> toSQLId(prettyName), + "expectedNum" -> "2n (n > 0)", + "actualNum" -> children.length.toString + ) + ) } else { val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType) if (invalidNames.nonEmpty) { - TypeCheckResult.TypeCheckFailure( - s"Only foldable ${StringType.catalogString} expressions are allowed to appear at odd" + - s" position, got: ${invalidNames.mkString(",")}") + DataTypeMismatch( + errorSubClass = "CREATE_NAMED_STRUCT_WITHOUT_FOLDABLE_STRING", + messageParameters = Map( + "inputExprs" -> invalidNames.map(toSQLExpr(_)).mkString("[", ", ", "]") + ) + ) } else if (!names.contains(null)) { TypeCheckResult.TypeCheckSuccess } else { - TypeCheckResult.TypeCheckFailure("Field name should not be null") + DataTypeMismatch( + errorSubClass = "UNEXPECTED_NULL", + messageParameters = Map( + "exprName" -> nameExprs.map(toSQLExpr).mkString("[", ", ", "]") + ) + ) } } } @@ -668,10 +699,19 @@ case class UpdateFields(structExpr: Expression, fieldOps: 
Seq[StructFieldsOperat override def checkInputDataTypes(): TypeCheckResult = { val dataType = structExpr.dataType if (!dataType.isInstanceOf[StructType]) { - TypeCheckResult.TypeCheckFailure("struct argument should be struct type, got: " + - dataType.catalogString) + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> "1", + "requiredType" -> toSQLType(StructType), + "inputSql" -> toSQLExpr(structExpr), + "inputType" -> toSQLType(structExpr.dataType)) + ) } else if (newExprs.isEmpty) { - TypeCheckResult.TypeCheckFailure("cannot drop all fields in struct") + DataTypeMismatch( + errorSubClass = "CANNOT_DROP_ALL_FIELDS", + messageParameters = Map.empty + ) } else { TypeCheckResult.TypeCheckSuccess } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 83139ab719f..eb2ebce3a5f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -40,6 +40,10 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer $"arrayField".array(StringType), Symbol("mapField").map(StringType, LongType)) + private def analysisException(expr: Expression): AnalysisException = { + intercept[AnalysisException](assertSuccess(expr)) + } + def assertError(expr: Expression, errorMessage: String): Unit = { val e = intercept[AnalysisException] { assertSuccess(expr) @@ -522,29 +526,68 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer } test("check types for CreateNamedStruct") { - assertError( - CreateNamedStruct(Seq("a", "b", 2.0)), "even number of arguments") - assertError( - CreateNamedStruct(Seq(1, "a", "b", 2.0)), - "Only foldable string expressions are 
allowed to appear at odd position") - assertError( - CreateNamedStruct(Seq($"a".string.at(0), "a", "b", 2.0)), - "Only foldable string expressions are allowed to appear at odd position") - assertError( - CreateNamedStruct(Seq(Literal.create(null, StringType), "a")), - "Field name should not be null") + checkError( + exception = analysisException(CreateNamedStruct(Seq("a", "b", 2.0))), + errorClass = "DATATYPE_MISMATCH.WRONG_NUM_ARGS", + parameters = Map( + "sqlExpr" -> "\"named_struct(a, b, 2.0)\"", + "functionName" -> "`named_struct`", + "expectedNum" -> "2n (n > 0)", + "actualNum" -> "3") + ) + checkError( + exception = analysisException(CreateNamedStruct(Seq(1, "a", "b", 2.0))), + errorClass = "DATATYPE_MISMATCH.CREATE_NAMED_STRUCT_WITHOUT_FOLDABLE_STRING", + parameters = Map( + "sqlExpr" -> "\"named_struct(1, a, b, 2.0)\"", + "inputExprs" -> "[\"1\"]") + ) + checkError( + exception = analysisException(CreateNamedStruct(Seq($"a".string.at(0), "a", "b", 2.0))), + errorClass = "DATATYPE_MISMATCH.CREATE_NAMED_STRUCT_WITHOUT_FOLDABLE_STRING", + parameters = Map( + "sqlExpr" -> "\"named_struct(boundreference(), a, b, 2.0)\"", + "inputExprs" -> "[\"boundreference()\"]") + ) + checkError( + exception = analysisException(CreateNamedStruct(Seq(Literal.create(null, StringType), "a"))), + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_NULL", + parameters = Map( + "sqlExpr" -> "\"named_struct(NULL, a)\"", + "exprName" -> "[\"NULL\"]") + ) } test("check types for CreateMap") { - assertError(CreateMap(Seq("a", "b", 2.0)), "even number of arguments") - assertError( - CreateMap(Seq($"intField", $"stringField", - $"booleanField", $"stringField")), - "keys of function map should all be the same type") - assertError( - CreateMap(Seq($"stringField", $"intField", - $"stringField", $"booleanField")), - "values of function map should all be the same type") + checkError( + exception = analysisException(CreateMap(Seq("a", "b", 2.0))), + errorClass = "DATATYPE_MISMATCH.WRONG_NUM_ARGS", + 
parameters = Map( + "sqlExpr" -> "\"map(a, b, 2.0)\"", + "functionName" -> "`map`", + "expectedNum" -> "2n (n > 0)", + "actualNum" -> "3") + ) + checkError( + exception = analysisException(CreateMap(Seq(Literal(1), + Literal("a"), Literal(true), Literal("b")))), + errorClass = "DATATYPE_MISMATCH.CREATE_MAP_KEY_DIFF_TYPES", + parameters = Map( + "sqlExpr" -> "\"map(1, a, true, b)\"", + "functionName" -> "`map`", + "dataType" -> "[\"INT\", \"BOOLEAN\"]" + ) + ) + checkError( + exception = analysisException(CreateMap(Seq(Literal("a"), + Literal(1), Literal("b"), Literal(true)))), + errorClass = "DATATYPE_MISMATCH.CREATE_MAP_VALUE_DIFF_TYPES", + parameters = Map( + "sqlExpr" -> "\"map(a, 1, b, true)\"", + "functionName" -> "`map`", + "dataType" -> "[\"INT\", \"BOOLEAN\"]" + ) + ) } test("check types for ROUND/BROUND") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index fb6a23e3d77..f1f781b7137 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.util._ @@ -314,6 +315,40 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { assert(errorSubClass == "INVALID_MAP_KEY_TYPE") assert(messageParameters === Map("keyType" -> "\"MAP<INT, INT>\"")) } + + // expects a positive even number of 
arguments + val map3 = CreateMap(Seq(Literal(1), Literal(2), Literal(3))) + assert(map3.checkInputDataTypes() == + DataTypeMismatch( + errorSubClass = "WRONG_NUM_ARGS", + messageParameters = Map( + "functionName" -> "`map`", + "expectedNum" -> "2n (n > 0)", + "actualNum" -> "3") + ) + ) + + // The given keys of function map should all be the same type + val map4 = CreateMap(Seq(Literal(1), Literal(2), Literal('a'), Literal(3))) + assert(map4.checkInputDataTypes() == + DataTypeMismatch( + errorSubClass = "CREATE_MAP_KEY_DIFF_TYPES", + messageParameters = Map( + "functionName" -> "`map`", + "dataType" -> "[\"INT\", \"STRING\"]") + ) + ) + + // The given values of function map should all be the same type + val map5 = CreateMap(Seq(Literal(1), Literal(2), Literal(3), Literal('a'))) + assert(map5.checkInputDataTypes() == + DataTypeMismatch( + errorSubClass = "CREATE_MAP_VALUE_DIFF_TYPES", + messageParameters = Map( + "functionName" -> "`map`", + "dataType" -> "[\"INT\", \"STRING\"]") + ) + ) } test("MapFromArrays") { @@ -397,6 +432,18 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { create_row(UTF8String.fromString("x"), 2.0)) checkEvaluation(CreateNamedStruct(Seq("a", Literal.create(null, IntegerType))), create_row(null)) + + // expects a positive even number of arguments + val namedStruct1 = CreateNamedStruct(Seq(Literal(1), Literal(2), Literal(3))) + assert(namedStruct1.checkInputDataTypes() == + DataTypeMismatch( + errorSubClass = "WRONG_NUM_ARGS", + messageParameters = Map( + "functionName" -> "`named_struct`", + "expectedNum" -> "2n (n > 0)", + "actualNum" -> "3") + ) + ) } test("test dsl for complex type") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 554f6a34b17..3c9f3e58cec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -940,7 +940,7 @@ class Column(val 
expr: Expression) extends Logging { * * val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col") * df.select($"struct_col".dropFields("a", "b")) - * // result: org.apache.spark.sql.AnalysisException: cannot resolve 'update_fields(update_fields(`struct_col`))' due to data type mismatch: cannot drop all fields in struct + * // result: org.apache.spark.sql.AnalysisException: [DATATYPE_MISMATCH.CANNOT_DROP_ALL_FIELDS] Cannot resolve "update_fields(struct_col, dropfield(), dropfield())" due to data type mismatch: Cannot drop all fields in struct.; * * val df = sql("SELECT CAST(NULL AS struct<a:int,b:int>) struct_col") * df.select($"struct_col".dropFields("b")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index d7ebb900388..718405bd8ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -1027,9 +1027,18 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false)))) test("withField should throw an exception if called on a non-StructType column") { - intercept[AnalysisException] { - testData.withColumn("key", $"key".withField("a", lit(2))) - }.getMessage should include("struct argument should be struct type, got: int") + checkError( + exception = intercept[AnalysisException] { + testData.withColumn("key", $"key".withField("a", lit(2))) + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "sqlExpr" -> "\"update_fields(key, WithField(2))\"", + "paramIndex" -> "1", + "inputSql" -> "\"key\"", + "inputType" -> "\"INT\"", + "requiredType" -> "\"STRUCT\"") + ) } test("withField should throw an exception if either fieldName or col argument are null") { @@ -1063,9 +1072,18 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("withField should throw an 
exception if intermediate field is not a struct") { - intercept[AnalysisException] { - structLevel1.withColumn("a", $"a".withField("b.a", lit(2))) - }.getMessage should include("struct argument should be struct type, got: int") + checkError( + exception = intercept[AnalysisException] { + structLevel1.withColumn("a", $"a".withField("b.a", lit(2))) + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "sqlExpr" -> "\"update_fields(a.b, WithField(2))\"", + "paramIndex" -> "1", + "inputSql" -> "\"a.b\"", + "inputType" -> "\"INT\"", + "requiredType" -> "\"STRUCT\"") + ) } test("withField should throw an exception if intermediate field reference is ambiguous") { @@ -1792,9 +1810,18 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("dropFields should throw an exception if called on a non-StructType column") { - intercept[AnalysisException] { - testData.withColumn("key", $"key".dropFields("a")) - }.getMessage should include("struct argument should be struct type, got: int") + checkError( + exception = intercept[AnalysisException] { + testData.withColumn("key", $"key".dropFields("a")) + }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "sqlExpr" -> "\"update_fields(key, dropfield())\"", + "paramIndex" -> "1", + "inputSql" -> "\"key\"", + "inputType" -> "\"INT\"", + "requiredType" -> "\"STRUCT\"") + ) } test("dropFields should throw an exception if fieldName argument is null") { @@ -1820,9 +1847,18 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("dropFields should throw an exception if intermediate field is not a struct") { - intercept[AnalysisException] { - structLevel1.withColumn("a", $"a".dropFields("b.a")) - }.getMessage should include("struct argument should be struct type, got: int") + checkError( + exception = intercept[AnalysisException] { + structLevel1.withColumn("a", $"a".dropFields("b.a")) + }, + errorClass = 
"DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "sqlExpr" -> "\"update_fields(a.b, dropfield())\"", + "paramIndex" -> "1", + "inputSql" -> "\"a.b\"", + "inputType" -> "\"INT\"", + "requiredType" -> "\"STRUCT\"") + ) } test("dropFields should throw an exception if intermediate field reference is ambiguous") { @@ -1877,9 +1913,13 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("dropFields should throw an exception if no fields will be left in struct") { - intercept[AnalysisException] { - structLevel1.withColumn("a", $"a".dropFields("a", "b", "c")) - }.getMessage should include("cannot drop all fields in struct") + checkError( + exception = intercept[AnalysisException] { + structLevel1.withColumn("a", $"a".dropFields("a", "b", "c")) + }, + errorClass = "DATATYPE_MISMATCH.CANNOT_DROP_ALL_FIELDS", + parameters = Map("sqlExpr" -> "\"update_fields(a, dropfield(), dropfield(), dropfield())\"") + ) } test("dropFields should drop field with no name in struct") { @@ -2144,10 +2184,14 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { .select($"struct_col".dropFields("b", "c")), Row(Row(1))) - intercept[AnalysisException] { - sql("SELECT named_struct('a', 1, 'b', 2) struct_col") - .select($"struct_col".dropFields("a", "b")) - }.getMessage should include("cannot drop all fields in struct") + checkError( + exception = intercept[AnalysisException] { + sql("SELECT named_struct('a', 1, 'b', 2) struct_col") + .select($"struct_col".dropFields("a", "b")) + }, + errorClass = "DATATYPE_MISMATCH.CANNOT_DROP_ALL_FIELDS", + parameters = Map("sqlExpr" -> "\"update_fields(struct_col, dropfield(), dropfield())\"") + ) checkAnswer( sql("SELECT CAST(NULL AS struct<a:int,b:int>) struct_col") --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org