This is an automated email from the ASF dual-hosted git repository. maxgekk pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new c3f8c973d44 [SPARK-41174][CORE][SQL] Propagate an error class to users for invalid `format` of `to_binary()` c3f8c973d44 is described below commit c3f8c973d448b4d9be7502985aededdd7b81d164 Author: yangjie01 <yangji...@baidu.com> AuthorDate: Wed Nov 23 17:25:06 2022 +0300 [SPARK-41174][CORE][SQL] Propagate an error class to users for invalid `format` of `to_binary()` ### What changes were proposed in this pull request? This PR overrides the `checkInputDataTypes()` method of the `ToBinary` function to propagate an error class to users for invalid `format`. ### Why are the changes needed? Migration onto error classes unifies Spark SQL error messages. ### Does this PR introduce _any_ user-facing change? Yes. The PR changes user-facing error messages. ### How was this patch tested? Pass GitHub Actions Closes #38737 from LuciferYang/SPARK-41174. Authored-by: yangjie01 <yangji...@baidu.com> Signed-off-by: Max Gekk <max.g...@gmail.com> --- core/src/main/resources/error/error-classes.json | 5 ++ .../catalyst/expressions/stringExpressions.scala | 85 +++++++++++++++------- .../expressions/StringExpressionsSuite.scala | 15 ++++ .../sql-tests/inputs/string-functions.sql | 4 + .../results/ansi/string-functions.sql.out | 70 +++++++++++++++--- .../sql-tests/results/string-functions.sql.out | 70 +++++++++++++++--- 6 files changed, 204 insertions(+), 45 deletions(-) diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json index afe08f044c7..5bac5ae71f2 100644 --- a/core/src/main/resources/error/error-classes.json +++ b/core/src/main/resources/error/error-classes.json @@ -234,6 +234,11 @@ "Input to the function <functionName> cannot contain elements of the \"MAP\" type. In Spark, same maps may have different hashcode, thus hash expressions are prohibited on \"MAP\" elements. To restore previous behavior set \"spark.sql.legacy.allowHashOnMapType\" to \"true\"." 
] }, + "INVALID_ARG_VALUE" : { + "message" : [ + "The <inputName> value must to be a <requireType> literal of <validValues>, but got <inputValue>." + ] + }, "INVALID_JSON_MAP_KEY_TYPE" : { "message" : [ "Input schema <schema> can only contain STRING as a key type for a MAP." diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 60b56f4fef7..3a1db2ce1b8 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -2620,39 +2620,30 @@ case class ToBinary( nullOnInvalidFormat: Boolean = false) extends RuntimeReplaceable with ImplicitCastInputTypes { - override lazy val replacement: Expression = format.map { f => - assert(f.foldable && (f.dataType == StringType || f.dataType == NullType)) + @transient lazy val fmt: String = format.map { f => val value = f.eval() if (value == null) { - Literal(null, BinaryType) + null } else { - value.asInstanceOf[UTF8String].toString.toLowerCase(Locale.ROOT) match { - case "hex" => Unhex(expr, failOnError = true) - case "utf-8" | "utf8" => Encode(expr, Literal("UTF-8")) - case "base64" => UnBase64(expr, failOnError = true) - case _ if nullOnInvalidFormat => Literal(null, BinaryType) - case other => throw QueryCompilationErrors.invalidStringLiteralParameter( - "to_binary", - "format", - other, - Some( - "The value has to be a case-insensitive string literal of " + - "'hex', 'utf-8', 'utf8', or 'base64'.")) - } + value.asInstanceOf[UTF8String].toString.toLowerCase(Locale.ROOT) + } + }.getOrElse("hex") + + override lazy val replacement: Expression = if (fmt == null) { + Literal(null, BinaryType) + } else { + fmt match { + case "hex" => Unhex(expr, failOnError = true) + case "utf-8" | "utf8" => Encode(expr, Literal("UTF-8")) + case "base64" 
=> UnBase64(expr, failOnError = true) + case _ => Literal(null, BinaryType) } - }.getOrElse(Unhex(expr, failOnError = true)) + } def this(expr: Expression) = this(expr, None, false) def this(expr: Expression, format: Expression) = - this(expr, Some({ - // We perform this check in the constructor to make it eager and not go through type coercion. - if (format.foldable && (format.dataType == StringType || format.dataType == NullType)) { - format - } else { - throw QueryCompilationErrors.requireLiteralParameter("to_binary", "format", "string") - } - }), false) + this(expr, Some(format), false) override def prettyName: String = "to_binary" @@ -2660,6 +2651,50 @@ case class ToBinary( override def inputTypes: Seq[AbstractDataType] = children.map(_ => StringType) + override def checkInputDataTypes(): TypeCheckResult = { + def isValidFormat: Boolean = { + fmt == null || Set("hex", "utf-8", "utf8", "base64").contains(fmt) + } + format match { + case Some(f) => + if (f.foldable && (f.dataType == StringType || f.dataType == NullType)) { + if (isValidFormat || nullOnInvalidFormat) { + super.checkInputDataTypes() + } else { + DataTypeMismatch( + errorSubClass = "INVALID_ARG_VALUE", + messageParameters = Map( + "inputName" -> "fmt", + "requireType" -> s"case-insensitive ${toSQLType(StringType)}", + "validValues" -> "'hex', 'utf-8', 'utf8', or 'base64'", + "inputValue" -> toSQLValue(fmt, StringType) + ) + ) + } + } else if (!f.foldable) { + DataTypeMismatch( + errorSubClass = "NON_FOLDABLE_INPUT", + messageParameters = Map( + "inputName" -> "fmt", + "inputType" -> toSQLType(StringType), + "inputExpr" -> toSQLExpr(f) + ) + ) + } else { + DataTypeMismatch( + errorSubClass = "INVALID_ARG_VALUE", + messageParameters = Map( + "inputName" -> "fmt", + "requireType" -> s"case-insensitive ${toSQLType(StringType)}", + "validValues" -> "'hex', 'utf-8', 'utf8', or 'base64'", + "inputValue" -> toSQLValue(f.eval(), f.dataType) + ) + ) + } + case _ => super.checkInputDataTypes() + } + } + 
override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): Expression = { if (format.isDefined) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 0585578571a..f0b320db3a5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -1256,6 +1256,21 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { ) } + test("ToBinary: fails analysis if fmt is not foldable") { + val wrongFmt = AttributeReference("invalidFormat", StringType)() + val toBinaryExpr = ToBinary(Literal("abc"), Some(wrongFmt)) + assert(toBinaryExpr.checkInputDataTypes() == + DataTypeMismatch( + errorSubClass = "NON_FOLDABLE_INPUT", + messageParameters = Map( + "inputName" -> "fmt", + "inputType" -> toSQLType(wrongFmt.dataType), + "inputExpr" -> toSQLExpr(wrongFmt) + ) + ) + ) + } + test("ToNumber: negative tests (the input string does not match the format string)") { Seq( // The input contained more thousands separators than the format string. 
diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index cb18c547b61..39c57e6efa2 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -225,3 +225,7 @@ select to_binary(null, cast(null as string)); -- invalid format select to_binary('abc', 1); select to_binary('abc', 'invalidFormat'); +CREATE TEMPORARY VIEW fmtTable(fmtField) AS SELECT * FROM VALUES ('invalidFormat'); +SELECT to_binary('abc', fmtField) FROM fmtTable; +-- Clean up +DROP VIEW IF EXISTS fmtTable; diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out index 3ab49c14bef..5a0479996b9 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out @@ -1610,11 +1610,13 @@ struct<> -- !query output org.apache.spark.sql.AnalysisException { - "errorClass" : "_LEGACY_ERROR_TEMP_1100", + "errorClass" : "DATATYPE_MISMATCH.INVALID_ARG_VALUE", "messageParameters" : { - "argName" : "format", - "funcName" : "to_binary", - "requiredType" : "string" + "inputName" : "fmt", + "inputValue" : "'1'", + "requireType" : "case-insensitive \"STRING\"", + "sqlExpr" : "\"to_binary(abc, 1)\"", + "validValues" : "'hex', 'utf-8', 'utf8', or 'base64'" }, "queryContext" : [ { "objectType" : "", @@ -1633,11 +1635,59 @@ struct<> -- !query output org.apache.spark.sql.AnalysisException { - "errorClass" : "_LEGACY_ERROR_TEMP_1101", + "errorClass" : "DATATYPE_MISMATCH.INVALID_ARG_VALUE", "messageParameters" : { - "argName" : "format", - "endingMsg" : " The value has to be a case-insensitive string literal of 'hex', 'utf-8', 'utf8', or 'base64'.", - "funcName" : "to_binary", - "invalidValue" : "invalidformat" - } + "inputName" : 
"fmt", + "inputValue" : "'invalidformat'", + "requireType" : "case-insensitive \"STRING\"", + "sqlExpr" : "\"to_binary(abc, invalidFormat)\"", + "validValues" : "'hex', 'utf-8', 'utf8', or 'base64'" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 40, + "fragment" : "to_binary('abc', 'invalidFormat')" + } ] } + + +-- !query +CREATE TEMPORARY VIEW fmtTable(fmtField) AS SELECT * FROM VALUES ('invalidFormat') +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT to_binary('abc', fmtField) FROM fmtTable +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + "messageParameters" : { + "inputExpr" : "\"fmtField\"", + "inputName" : "fmt", + "inputType" : "\"STRING\"", + "sqlExpr" : "\"to_binary(abc, fmtField)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 33, + "fragment" : "to_binary('abc', fmtField)" + } ] +} + + +-- !query +DROP VIEW IF EXISTS fmtTable +-- !query schema +struct<> +-- !query output + diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 2ea5cefa38d..36814275cd7 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -1542,11 +1542,13 @@ struct<> -- !query output org.apache.spark.sql.AnalysisException { - "errorClass" : "_LEGACY_ERROR_TEMP_1100", + "errorClass" : "DATATYPE_MISMATCH.INVALID_ARG_VALUE", "messageParameters" : { - "argName" : "format", - "funcName" : "to_binary", - "requiredType" : "string" + "inputName" : "fmt", + "inputValue" : "'1'", + "requireType" : "case-insensitive \"STRING\"", + "sqlExpr" : "\"to_binary(abc, 1)\"", + "validValues" : "'hex', 'utf-8', 'utf8', or 'base64'" }, "queryContext" : [ { 
"objectType" : "", @@ -1565,11 +1567,59 @@ struct<> -- !query output org.apache.spark.sql.AnalysisException { - "errorClass" : "_LEGACY_ERROR_TEMP_1101", + "errorClass" : "DATATYPE_MISMATCH.INVALID_ARG_VALUE", "messageParameters" : { - "argName" : "format", - "endingMsg" : " The value has to be a case-insensitive string literal of 'hex', 'utf-8', 'utf8', or 'base64'.", - "funcName" : "to_binary", - "invalidValue" : "invalidformat" - } + "inputName" : "fmt", + "inputValue" : "'invalidformat'", + "requireType" : "case-insensitive \"STRING\"", + "sqlExpr" : "\"to_binary(abc, invalidFormat)\"", + "validValues" : "'hex', 'utf-8', 'utf8', or 'base64'" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 40, + "fragment" : "to_binary('abc', 'invalidFormat')" + } ] } + + +-- !query +CREATE TEMPORARY VIEW fmtTable(fmtField) AS SELECT * FROM VALUES ('invalidFormat') +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT to_binary('abc', fmtField) FROM fmtTable +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + "messageParameters" : { + "inputExpr" : "\"fmtField\"", + "inputName" : "fmt", + "inputType" : "\"STRING\"", + "sqlExpr" : "\"to_binary(abc, fmtField)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 33, + "fragment" : "to_binary('abc', fmtField)" + } ] +} + + +-- !query +DROP VIEW IF EXISTS fmtTable +-- !query schema +struct<> +-- !query output + --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org