This is an automated email from the ASF dual-hosted git repository. maxgekk pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new f718b025d87 [SPARK-43802][SQL] Fix codegen for unhex and unbase64 with failOnError=true f718b025d87 is described below commit f718b025d87ae3726210c60ff71cb34917b32f51 Author: Adam Binford <adam...@gmail.com> AuthorDate: Fri May 26 20:37:14 2023 +0300 [SPARK-43802][SQL] Fix codegen for unhex and unbase64 with failOnError=true ### What changes were proposed in this pull request? Fixes an error with codegen for unhex and unbase64 expressions when failOnError is enabled, introduced in https://github.com/apache/spark/pull/37483. ### Why are the changes needed? Codegen fails and Spark falls back to interpreted evaluation: ``` Caused by: org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 47, Column 1: failed to compile: org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 47, Column 1: Unknown variable or type "BASE64" ``` in the code block: ``` /* 107 */ if (!org.apache.spark.sql.catalyst.expressions.UnBase64.isValidBase64(project_value_1)) { /* 108 */ throw QueryExecutionErrors.invalidInputInConversionError( /* 109 */ ((org.apache.spark.sql.types.BinaryType$) references[1] /* to */), /* 110 */ project_value_1, /* 111 */ BASE64, /* 112 */ "try_to_binary"); /* 113 */ } ``` ### Does this PR introduce _any_ user-facing change? Bug fix. ### How was this patch tested? Added to the existing tests to evaluate an expression with failOnError enabled, exercising that path of the codegen. Closes #41317 from Kimahriman/bug-to-binary-codegen. 
Authored-by: Adam Binford <adam...@gmail.com> Signed-off-by: Max Gekk <max.g...@gmail.com> --- .../sql/catalyst/expressions/mathExpressions.scala | 3 +- .../catalyst/expressions/stringExpressions.scala | 3 +- .../expressions/MathExpressionsSuite.scala | 3 ++ .../expressions/StringExpressionsSuite.scala | 4 +- .../sql/errors/QueryExecutionErrorsSuite.scala | 46 ++++++++++++++++------ 5 files changed, 43 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index dcc821a24ea..add59a38b72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1172,14 +1172,13 @@ case class Unhex(child: Expression, failOnError: Boolean = false) nullSafeCodeGen(ctx, ev, c => { val hex = Hex.getClass.getName.stripSuffix("$") val maybeFailOnErrorCode = if (failOnError) { - val format = UTF8String.fromString("BASE64"); val binaryType = ctx.addReferenceObj("to", BinaryType, BinaryType.getClass.getName) s""" |if (${ev.value} == null) { | throw QueryExecutionErrors.invalidInputInConversionError( | $binaryType, | $c, - | $format, + | UTF8String.fromString("HEX"), | "try_to_binary"); |} |""".stripMargin diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 347dff0f4c4..03596ac40b1 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -2472,14 +2472,13 @@ case class UnBase64(child: Expression, failOnError: Boolean = false) nullSafeCodeGen(ctx, ev, child => { val 
maybeValidateInputCode = if (failOnError) { val unbase64 = UnBase64.getClass.getName.stripSuffix("$") - val format = UTF8String.fromString("BASE64"); val binaryType = ctx.addReferenceObj("to", BinaryType, BinaryType.getClass.getName) s""" |if (!$unbase64.isValidBase64($child)) { | throw QueryExecutionErrors.invalidInputInConversionError( | $binaryType, | $child, - | $format, + | UTF8String.fromString("BASE64"), | "try_to_binary"); |} """.stripMargin diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index 437f7ddee01..823a6d2ce86 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -615,6 +615,9 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Unhex(Literal("GG")), null) checkEvaluation(Unhex(Literal("123")), Array[Byte](1, 35)) checkEvaluation(Unhex(Literal("12345")), Array[Byte](1, 35, 69)) + + // failOnError + checkEvaluation(Unhex(Literal("12345"), true), Array[Byte](1, 35, 69)) // scalastyle:off // Turn off scala style for non-ascii chars checkEvaluation(Unhex(Literal("E4B889E9878DE79A84")), "δΈιη".getBytes(StandardCharsets.UTF_8)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index a27af7d2439..f320012d131 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -468,7 +468,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { 
checkEvaluation(Base64(UnBase64(Literal("AQIDBA=="))), "AQIDBA==", create_row("abdef")) checkEvaluation(Base64(UnBase64(Literal(""))), "", create_row("abdef")) checkEvaluation(Base64(UnBase64(Literal.create(null, StringType))), null, create_row("abdef")) - checkEvaluation(Base64(UnBase64(a)), "AQIDBA==", create_row("AQIDBA==")) + + // failOnError + checkEvaluation(Base64(UnBase64(a, true)), "AQIDBA==", create_row("AQIDBA==")) checkEvaluation(Base64(b), "AQIDBA==", create_row(bytes)) checkEvaluation(Base64(b), "", create_row(Array.empty[Byte])) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala index 4bfab92ccb1..c37722133cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, QueryTest, R import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.{Parameter, UnresolvedGenerator} import org.apache.spark.sql.catalyst.expressions.{Grouping, Literal, RowNumber} +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.expressions.objects.InitializeJavaBean import org.apache.spark.sql.catalyst.util.BadRecordException @@ -57,17 +58,40 @@ class QueryExecutionErrorsSuite import testImplicits._ - test("CONVERSION_INVALID_INPUT: to_binary conversion function") { - checkError( - exception = intercept[SparkIllegalArgumentException] { - sql("select to_binary('???', 'base64')").collect() - }, - errorClass = "CONVERSION_INVALID_INPUT", - parameters = Map( - "str" -> "'???'", - "fmt" -> "'BASE64'", - "targetType" -> "\"BINARY\"", - "suggestion" -> "`try_to_binary`")) 
+ test("CONVERSION_INVALID_INPUT: to_binary conversion function base64") { + for (codegenMode <- Seq(CODEGEN_ONLY, NO_CODEGEN)) { + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode.toString) { + val exception = intercept[SparkException] { + Seq(("???")).toDF("a").selectExpr("to_binary(a, 'base64')").collect() + }.getCause.asInstanceOf[SparkIllegalArgumentException] + checkError( + exception, + errorClass = "CONVERSION_INVALID_INPUT", + parameters = Map( + "str" -> "'???'", + "fmt" -> "'BASE64'", + "targetType" -> "\"BINARY\"", + "suggestion" -> "`try_to_binary`")) + } + } + } + + test("CONVERSION_INVALID_INPUT: to_binary conversion function hex") { + for (codegenMode <- Seq(CODEGEN_ONLY, NO_CODEGEN)) { + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode.toString) { + val exception = intercept[SparkException] { + Seq(("???")).toDF("a").selectExpr("to_binary(a, 'hex')").collect() + }.getCause.asInstanceOf[SparkIllegalArgumentException] + checkError( + exception, + errorClass = "CONVERSION_INVALID_INPUT", + parameters = Map( + "str" -> "'???'", + "fmt" -> "'HEX'", + "targetType" -> "\"BINARY\"", + "suggestion" -> "`try_to_binary`")) + } + } } private def getAesInputs(): (DataFrame, DataFrame) = { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org