This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new f82a5e75cfb  [SPARK-40749][SQL] Migrate type check failures of generators onto error classes
f82a5e75cfb is described below

commit f82a5e75cfbc0d5dea249029354737e811765e6a
Author: panbingkun <pbk1...@gmail.com>
AuthorDate: Fri Nov 4 12:21:35 2022 +0300

    [SPARK-40749][SQL] Migrate type check failures of generators onto error classes

    ### What changes were proposed in this pull request?
    This PR aims to:
    A. Check error classes in GeneratorFunctionSuite by using checkError().
    B. Replace TypeCheckFailure with DataTypeMismatch in the type checks of the generator expressions, including:
    1. Stack (3): https://github.com/apache/spark/blob/1431975723d8df30a25b2333eddcfd0bb6c57677/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala#L163-L170
    2. ExplodeBase (1): https://github.com/apache/spark/blob/1431975723d8df30a25b2333eddcfd0bb6c57677/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala#L299
    3. Inline (1): https://github.com/apache/spark/blob/1431975723d8df30a25b2333eddcfd0bb6c57677/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala#L441

    ### Why are the changes needed?
    Migration onto error classes unifies Spark SQL error messages.

    ### Does this PR introduce _any_ user-facing change?
    Yes. The PR changes user-facing error messages.

    ### How was this patch tested?
    1. Add new UTs.
    2. Update existing UTs.
    3. Pass GA.

    Closes #38482 from panbingkun/SPARK-40749.

    Authored-by: panbingkun <pbk1...@gmail.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
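
For readers skimming the diff below: the change swaps free-form TypeCheckResult.TypeCheckFailure strings for structured DataTypeMismatch payloads whose messages are rendered from templates in error-classes.json. A minimal, self-contained Scala sketch of that pattern follows; the simplified definitions are illustrative stand-ins, not Spark's actual classes:

    // Simplified stand-ins for Spark's TypeCheckResult hierarchy (illustrative only).
    sealed trait TypeCheckResult
    case object TypeCheckSuccess extends TypeCheckResult
    // Old style: a single free-form message string baked into the check site.
    case class TypeCheckFailure(message: String) extends TypeCheckResult
    // New style: an error subclass plus named parameters; the error framework
    // fills them into the message template from error-classes.json.
    case class DataTypeMismatch(
        errorSubClass: String,
        messageParameters: Map[String, String]) extends TypeCheckResult

    object MigrationSketch extends App {
      val before: TypeCheckResult =
        TypeCheckFailure("The number of rows must be a positive constant integer.")

      val after: TypeCheckResult = DataTypeMismatch(
        errorSubClass = "NON_FOLDABLE_INPUT",
        messageParameters = Map(
          "inputName" -> "n",
          "inputType" -> "\"INT\"",
          "inputExpr" -> "\"n\""))

      println(before)
      println(after)
    }

The structured form is what makes the per-case branches in the Stack/ExplodeBase/Inline checks below possible: each failure mode gets its own subclass and parameter set instead of one shared string.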
---
 core/src/main/resources/error/error-classes.json   |   5 +
 .../sql/catalyst/expressions/generators.scala      |  74 ++++++--
 .../analysis/ExpressionTypeCheckingSuite.scala     |  28 ++-
 .../apache/spark/sql/GeneratorFunctionSuite.scala  | 198 +++++++++++++++------
 4 files changed, 236 insertions(+), 69 deletions(-)

diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json
index 7fc806752be..f4b7874217a 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -335,6 +335,11 @@
         "The lower bound of a window frame must be <comparison> to the upper bound."
       ]
     },
+    "STACK_COLUMN_DIFF_TYPES" : {
+      "message" : [
+        "The data type of the column (<columnIndex>) do not have the same type: <leftType> (<leftParamIndex>) <> <rightType> (<rightParamIndex>)."
+      ]
+    },
     "UNEXPECTED_CLASS_TYPE" : {
       "message" : [
         "class <className> not found"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index d305b4d3700..1d60dd3795e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -22,6 +22,8 @@ import scala.collection.mutable
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
+import org.apache.spark.sql.catalyst.expressions.Cast._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.trees.TreePattern.{GENERATOR, TreePattern}
@@ -160,16 +162,54 @@ case class Stack(children: Seq[Expression]) extends Generator {
   override def checkInputDataTypes(): TypeCheckResult = {
     if (children.length <= 1) {
-      TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least 2 arguments.")
-    } else if (children.head.dataType != IntegerType || !children.head.foldable || numRows < 1) {
-      TypeCheckResult.TypeCheckFailure("The number of rows must be a positive constant integer.")
+      DataTypeMismatch(
+        errorSubClass = "WRONG_NUM_ARGS",
+        messageParameters = Map(
+          "functionName" -> toSQLId(prettyName),
+          "expectedNum" -> "> 1",
+          "actualNum" -> children.length.toString)
+      )
+    } else if (children.head.dataType != IntegerType) {
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_INPUT_TYPE",
+        messageParameters = Map(
+          "paramIndex" -> "1",
+          "requiredType" -> toSQLType(IntegerType),
+          "inputSql" -> toSQLExpr(children.head),
+          "inputType" -> toSQLType(children.head.dataType))
+      )
+    } else if (!children.head.foldable) {
+      DataTypeMismatch(
+        errorSubClass = "NON_FOLDABLE_INPUT",
+        messageParameters = Map(
+          "inputName" -> "n",
+          "inputType" -> toSQLType(IntegerType),
+          "inputExpr" -> toSQLExpr(children.head)
+        )
+      )
+    } else if (numRows < 1) {
+      DataTypeMismatch(
+        errorSubClass = "VALUE_OUT_OF_RANGE",
+        messageParameters = Map(
+          "exprName" -> toSQLId("n"),
+          "valueRange" -> s"(0, ${Int.MaxValue}]",
+          "currentValue" -> toSQLValue(numRows, children.head.dataType)
+        )
+      )
     } else {
       for (i <- 1 until children.length) {
         val j = (i - 1) % numFields
         if (children(i).dataType != elementSchema.fields(j).dataType) {
-          return TypeCheckResult.TypeCheckFailure(
-            s"Argument ${j + 1} (${elementSchema.fields(j).dataType.catalogString}) != " +
-              s"Argument $i (${children(i).dataType.catalogString})")
+          return DataTypeMismatch(
+            errorSubClass = "STACK_COLUMN_DIFF_TYPES",
+            messageParameters = Map(
+              "columnIndex" -> j.toString,
+              "leftParamIndex" -> (j + 1).toString,
+              "leftType" -> toSQLType(elementSchema.fields(j).dataType),
+              "rightParamIndex" -> i.toString,
+              "rightType" -> toSQLType(children(i).dataType)
+            )
+          )
         }
       }
       TypeCheckResult.TypeCheckSuccess
@@ -296,9 +336,14 @@ abstract class ExplodeBase extends UnaryExpression with CollectionGenerator with
     case _: ArrayType | _: MapType =>
       TypeCheckResult.TypeCheckSuccess
     case _ =>
-      TypeCheckResult.TypeCheckFailure(
-        "input to function explode should be array or map type, " +
-          s"not ${child.dataType.catalogString}")
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_INPUT_TYPE",
+        messageParameters = Map(
+          "paramIndex" -> "1",
+          "requiredType" -> toSQLType(TypeCollection(ArrayType, MapType)),
+          "inputSql" -> toSQLExpr(child),
+          "inputType" -> toSQLType(child.dataType))
+      )
   }

   // hive-compatible default alias for explode function ("col" for array, "key", "value" for map)
@@ -438,9 +483,14 @@ case class Inline(child: Expression) extends UnaryExpression with CollectionGene
     case ArrayType(st: StructType, _) =>
       TypeCheckResult.TypeCheckSuccess
     case _ =>
-      TypeCheckResult.TypeCheckFailure(
-        s"input to function $prettyName should be array of struct type, " +
-          s"not ${child.dataType.catalogString}")
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_INPUT_TYPE",
+        messageParameters = Map(
+          "paramIndex" -> "1",
+          "requiredType" -> toSQLType("ARRAY<STRUCT>"),
+          "inputSql" -> toSQLExpr(child),
+          "inputType" -> toSQLType(child.dataType))
+      )
   }

   override def elementSchema: StructType = child.dataType match {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index eb2ebce3a5f..f656131c8e7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -519,10 +519,30 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite with SQLHelper with Quer
         "expectedNum" -> "> 0",
         "actualNum" -> "0"))

-    assertError(Explode($"intField"),
-      "input to function explode should be array or map type")
-    assertError(PosExplode($"intField"),
-      "input to function explode should be array or map type")
+    checkError(
+      exception = intercept[AnalysisException] {
+        assertSuccess(Explode($"intField"))
+      },
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> "\"explode(intField)\"",
+        "paramIndex" -> "1",
+        "inputSql" -> "\"intField\"",
+        "inputType" -> "\"INT\"",
+        "requiredType" -> "(\"ARRAY\" or \"MAP\")"))
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        assertSuccess(PosExplode($"intField"))
+      },
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> "\"posexplode(intField)\"",
+        "paramIndex" -> "1",
+        "inputSql" -> "\"intField\"",
+        "inputType" -> "\"INT\"",
+        "requiredType" -> "(\"ARRAY\" or \"MAP\")")
+    )
   }

   test("check types for CreateNamedStruct") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
index 3fb66f08cea..abec582d43a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
@@ -55,36 +55,101 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession {
       Row(1, 2) :: Row(3, null) :: Row(1, 2) :: Row(3, null) :: Nil)

     // The first argument must be a positive constant integer.
-    val m = intercept[AnalysisException] {
-      df.selectExpr("stack(1.1, 1, 2, 3)")
-    }.getMessage
-    assert(m.contains("The number of rows must be a positive constant integer."))
-    val m2 = intercept[AnalysisException] {
-      df.selectExpr("stack(-1, 1, 2, 3)")
-    }.getMessage
-    assert(m2.contains("The number of rows must be a positive constant integer."))
+    checkError(
+      exception = intercept[AnalysisException] {
+        df.selectExpr("stack(1.1, 1, 2, 3)")
+      },
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> "\"stack(1.1, 1, 2, 3)\"",
+        "paramIndex" -> "1",
+        "inputSql" -> "\"1.1\"",
+        "inputType" -> "\"DECIMAL(2,1)\"",
+        "requiredType" -> "\"INT\""),
+      context = ExpectedContext(
+        fragment = "stack(1.1, 1, 2, 3)",
+        start = 0,
+        stop = 18
+      )
+    )
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        df.selectExpr("stack(-1, 1, 2, 3)")
+      },
+      errorClass = "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE",
+      parameters = Map(
+        "sqlExpr" -> "\"stack(-1, 1, 2, 3)\"",
+        "exprName" -> "`n`",
+        "valueRange" -> "(0, 2147483647]",
+        "currentValue" -> "-1"),
+      context = ExpectedContext(
+        fragment = "stack(-1, 1, 2, 3)",
+        start = 0,
+        stop = 17
+      )
+    )

     // The data for the same column should have the same type.
-    val m3 = intercept[AnalysisException] {
-      df.selectExpr("stack(2, 1, '2.2')")
-    }.getMessage
-    assert(m3.contains("data type mismatch: Argument 1 (int) != Argument 2 (string)"))
+    checkError(
+      exception = intercept[AnalysisException] {
+        df.selectExpr("stack(2, 1, '2.2')")
+      },
+      errorClass = "DATATYPE_MISMATCH.STACK_COLUMN_DIFF_TYPES",
+      parameters = Map(
+        "sqlExpr" -> "\"stack(2, 1, 2.2)\"",
+        "columnIndex" -> "0",
+        "leftParamIndex" -> "1",
+        "leftType" -> "\"INT\"",
+        "rightParamIndex" -> "2",
+        "rightType" -> "\"STRING\""),
+      context = ExpectedContext(
+        fragment = "stack(2, 1, '2.2')",
+        start = 0,
+        stop = 17
+      )
+    )

     // stack on column data
     val df2 = Seq((2, 1, 2, 3)).toDF("n", "a", "b", "c")
     checkAnswer(df2.selectExpr("stack(2, a, b, c)"), Row(1, 2) :: Row(3, null) :: Nil)

-    val m4 = intercept[AnalysisException] {
-      df2.selectExpr("stack(n, a, b, c)")
-    }.getMessage
-    assert(m4.contains("The number of rows must be a positive constant integer."))
+    checkError(
+      exception = intercept[AnalysisException] {
+        df2.selectExpr("stack(n, a, b, c)")
+      },
+      errorClass = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT",
+      parameters = Map(
+        "sqlExpr" -> "\"stack(n, a, b, c)\"",
+        "inputName" -> "n",
+        "inputType" -> "\"INT\"",
+        "inputExpr" -> "\"n\""),
+      context = ExpectedContext(
+        fragment = "stack(n, a, b, c)",
+        start = 0,
+        stop = 16
+      )
+    )

     val df3 = Seq((2, 1, 2.0)).toDF("n", "a", "b")
-    val m5 = intercept[AnalysisException] {
-      df3.selectExpr("stack(2, a, b)")
-    }.getMessage
-    assert(m5.contains("data type mismatch: Argument 1 (int) != Argument 2 (double)"))
-
+    checkError(
+      exception = intercept[AnalysisException] {
+        df3.selectExpr("stack(2, a, b)")
+      },
+      errorClass = "DATATYPE_MISMATCH.STACK_COLUMN_DIFF_TYPES",
+      parameters = Map(
+        "sqlExpr" -> "\"stack(2, a, b)\"",
+        "columnIndex" -> "0",
+        "leftParamIndex" -> "1",
+        "leftType" -> "\"INT\"",
+        "rightParamIndex" -> "2",
+        "rightType" -> "\"DOUBLE\""),
+      context = ExpectedContext(
+        fragment = "stack(2, a, b)",
+        start = 0,
+        stop = 13
+      )
+    )
   }

   test("single explode") {
@@ -218,10 +283,18 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession {
   }

   test("inline raises exception on array of null type") {
-    val m = intercept[AnalysisException] {
-      spark.range(2).select(inline(array()))
-    }.getMessage
-    assert(m.contains("data type mismatch"))
+    checkError(
+      exception = intercept[AnalysisException] {
+        spark.range(2).select(inline(array()))
+      },
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      parameters = Map(
+        "sqlExpr" -> "\"inline(array())\"",
+        "paramIndex" -> "1",
+        "inputSql" -> "\"array()\"",
+        "inputType" -> "\"ARRAY<VOID>\"",
+        "requiredType" -> "\"ARRAY<STRUCT>\"")
+    )
   }

   test("inline with empty table") {
@@ -250,20 +323,30 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession {
       Row(1, 2) :: Row(1, 2) :: Nil)

     // Spark think [struct<a:int>, struct<b:int>] is heterogeneous due to name difference.
-    val m = intercept[AnalysisException] {
-      df.select(inline(array(struct('a), struct('b))))
-    }.getMessage
-    assert(m.contains("data type mismatch"))
+    checkError(
+      exception = intercept[AnalysisException] {
+        df.select(inline(array(struct('a), struct('b))))
+      },
+      errorClass = "DATATYPE_MISMATCH.DATA_DIFF_TYPES",
+      parameters = Map(
+        "sqlExpr" -> "\"array(struct(a), struct(b))\"",
+        "functionName" -> "`array`",
+        "dataType" -> "(\"STRUCT<a: INT>\" or \"STRUCT<b: INT>\")"))

     checkAnswer(
       df.select(inline(array(struct('a), struct('b.alias("a"))))),
       Row(1) :: Row(2) :: Nil)

     // Spark think [struct<a:int>, struct<col1:int>] is heterogeneous due to name difference.
-    val m2 = intercept[AnalysisException] {
-      df.select(inline(array(struct('a), struct(lit(2)))))
-    }.getMessage
-    assert(m2.contains("data type mismatch"))
+    checkError(
+      exception = intercept[AnalysisException] {
+        df.select(inline(array(struct('a), struct(lit(2)))))
+      },
+      errorClass = "DATATYPE_MISMATCH.DATA_DIFF_TYPES",
+      parameters = Map(
+        "sqlExpr" -> "\"array(struct(a), struct(2))\"",
+        "functionName" -> "`array`",
+        "dataType" -> "(\"STRUCT<a: INT>\" or \"STRUCT<col1: INT>\")"))

     checkAnswer(
       df.select(inline(array(struct('a), struct(lit(2).alias("a"))))),
@@ -330,30 +413,39 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession {
       Row(1, 2) :: Row(1, 3) :: Nil
     )

-    val msg1 = intercept[AnalysisException] {
-      sql("select 1 + explode(array(min(c2), max(c2))) from t1 group by c1")
-    }.getMessage
-    assert(msg1.contains("The generator is not supported: nested in expressions"))
-
-    val msg2 = intercept[AnalysisException] {
-      sql(
-        """select
-          | explode(array(min(c2), max(c2))),
-          | posexplode(array(min(c2), max(c2)))
-          |from t1 group by c1
-        """.stripMargin)
-    }.getMessage
-    assert(msg2.contains("The generator is not supported: " +
-      "only one generator allowed per aggregate clause"))
+    checkError(
+      exception = intercept[AnalysisException] {
+        sql("select 1 + explode(array(min(c2), max(c2))) from t1 group by c1")
+      },
+      errorClass = "UNSUPPORTED_GENERATOR.NESTED_IN_EXPRESSIONS",
+      parameters = Map(
+        "expression" -> "\"(1 + explode(array(min(c2), max(c2))))\""))
+
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        sql(
+          """select
+            | explode(array(min(c2), max(c2))),
+            | posexplode(array(min(c2), max(c2)))
+            |from t1 group by c1""".stripMargin)
+      },
+      errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR",
+      parameters = Map(
+        "clause" -> "aggregate",
+        "num" -> "2",
+        "generators" -> ("\"explode(array(min(c2), max(c2)))\", " +
+          "\"posexplode(array(min(c2), max(c2)))\"")))
     }
   }

   test("SPARK-30998: Unsupported nested inner generators") {
-    val errMsg = intercept[AnalysisException] {
-      sql("SELECT array(array(1, 2), array(3)) v").select(explode(explode($"v"))).collect
-    }.getMessage
-    assert(errMsg.contains("The generator is not supported: " +
"""nested in expressions "explode(explode(v))"""")) + checkError( + exception = intercept[AnalysisException] { + sql("SELECT array(array(1, 2), array(3)) v").select(explode(explode($"v"))).collect + }, + errorClass = "UNSUPPORTED_GENERATOR.NESTED_IN_EXPRESSIONS", + parameters = Map("expression" -> "\"explode(explode(v))\"")) } test("SPARK-30997: generators in aggregate expressions for dataframe") { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org