This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f718b025d87 [SPARK-43802][SQL] Fix codegen for unhex and unbase64 with 
failOnError=true
f718b025d87 is described below

commit f718b025d87ae3726210c60ff71cb34917b32f51
Author: Adam Binford <adam...@gmail.com>
AuthorDate: Fri May 26 20:37:14 2023 +0300

    [SPARK-43802][SQL] Fix codegen for unhex and unbase64 with failOnError=true
    
    ### What changes were proposed in this pull request?
    
    Fixes an error, introduced in https://github.com/apache/spark/pull/37483, in the
    generated code for the unhex and unbase64 expressions when failOnError is enabled.
    
    ### Why are the changes needed?
    
    Codegen fails and Spark falls back to interpreted evaluation:
    ```
    Caused by: org.codehaus.commons.compiler.CompileException: File 
'generated.java', Line 47, Column 1: failed to compile: 
org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 47, 
Column 1: Unknown variable or type "BASE64"
    ```
    in the code block:
    ```
    /* 107 */         if 
(!org.apache.spark.sql.catalyst.expressions.UnBase64.isValidBase64(project_value_1))
 {
    /* 108 */           throw 
QueryExecutionErrors.invalidInputInConversionError(
    /* 109 */             ((org.apache.spark.sql.types.BinaryType$) 
references[1] /* to */),
    /* 110 */             project_value_1,
    /* 111 */             BASE64,
    /* 112 */             "try_to_binary");
    /* 113 */         }
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
    Bug fix.
    
    ### How was this patch tested?
    
    Extended the existing tests to evaluate these expressions with failOnError
    enabled, so that this codegen path is exercised.
    
    Closes #41317 from Kimahriman/bug-to-binary-codegen.
    
    Authored-by: Adam Binford <adam...@gmail.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 .../sql/catalyst/expressions/mathExpressions.scala |  3 +-
 .../catalyst/expressions/stringExpressions.scala   |  3 +-
 .../expressions/MathExpressionsSuite.scala         |  3 ++
 .../expressions/StringExpressionsSuite.scala       |  4 +-
 .../sql/errors/QueryExecutionErrorsSuite.scala     | 46 ++++++++++++++++------
 5 files changed, 43 insertions(+), 16 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index dcc821a24ea..add59a38b72 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -1172,14 +1172,13 @@ case class Unhex(child: Expression, failOnError: 
Boolean = false)
     nullSafeCodeGen(ctx, ev, c => {
       val hex = Hex.getClass.getName.stripSuffix("$")
       val maybeFailOnErrorCode = if (failOnError) {
-        val format = UTF8String.fromString("BASE64");
         val binaryType = ctx.addReferenceObj("to", BinaryType, 
BinaryType.getClass.getName)
         s"""
            |if (${ev.value} == null) {
            |  throw QueryExecutionErrors.invalidInputInConversionError(
            |    $binaryType,
            |    $c,
-           |    $format,
+           |    UTF8String.fromString("HEX"),
            |    "try_to_binary");
            |}
            |""".stripMargin
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 347dff0f4c4..03596ac40b1 100755
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -2472,14 +2472,13 @@ case class UnBase64(child: Expression, failOnError: 
Boolean = false)
     nullSafeCodeGen(ctx, ev, child => {
       val maybeValidateInputCode = if (failOnError) {
         val unbase64 = UnBase64.getClass.getName.stripSuffix("$")
-        val format = UTF8String.fromString("BASE64");
         val binaryType = ctx.addReferenceObj("to", BinaryType, 
BinaryType.getClass.getName)
         s"""
            |if (!$unbase64.isValidBase64($child)) {
            |  throw QueryExecutionErrors.invalidInputInConversionError(
            |    $binaryType,
            |    $child,
-           |    $format,
+           |    UTF8String.fromString("BASE64"),
            |    "try_to_binary");
            |}
        """.stripMargin
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
index 437f7ddee01..823a6d2ce86 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
@@ -615,6 +615,9 @@ class MathExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     checkEvaluation(Unhex(Literal("GG")), null)
     checkEvaluation(Unhex(Literal("123")), Array[Byte](1, 35))
     checkEvaluation(Unhex(Literal("12345")), Array[Byte](1, 35, 69))
+
+    // failOnError
+    checkEvaluation(Unhex(Literal("12345"), true), Array[Byte](1, 35, 69))
     // scalastyle:off
     // Turn off scala style for non-ascii chars
     checkEvaluation(Unhex(Literal("E4B889E9878DE79A84")), 
"δΈ‰ι‡ηš„".getBytes(StandardCharsets.UTF_8))
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index a27af7d2439..f320012d131 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -468,7 +468,9 @@ class StringExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     checkEvaluation(Base64(UnBase64(Literal("AQIDBA=="))), "AQIDBA==", 
create_row("abdef"))
     checkEvaluation(Base64(UnBase64(Literal(""))), "", create_row("abdef"))
     checkEvaluation(Base64(UnBase64(Literal.create(null, StringType))), null, 
create_row("abdef"))
-    checkEvaluation(Base64(UnBase64(a)), "AQIDBA==", create_row("AQIDBA=="))
+
+    // failOnError
+    checkEvaluation(Base64(UnBase64(a, true)), "AQIDBA==", 
create_row("AQIDBA=="))
 
     checkEvaluation(Base64(b), "AQIDBA==", create_row(bytes))
     checkEvaluation(Base64(b), "", create_row(Array.empty[Byte]))
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
index 4bfab92ccb1..c37722133cb 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, 
Dataset, QueryTest, R
 import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.analysis.{Parameter, UnresolvedGenerator}
 import org.apache.spark.sql.catalyst.expressions.{Grouping, Literal, RowNumber}
+import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
 import org.apache.spark.sql.catalyst.expressions.objects.InitializeJavaBean
 import org.apache.spark.sql.catalyst.util.BadRecordException
@@ -57,17 +58,40 @@ class QueryExecutionErrorsSuite
 
   import testImplicits._
 
-  test("CONVERSION_INVALID_INPUT: to_binary conversion function") {
-    checkError(
-      exception = intercept[SparkIllegalArgumentException] {
-        sql("select to_binary('???', 'base64')").collect()
-      },
-      errorClass = "CONVERSION_INVALID_INPUT",
-      parameters = Map(
-        "str" -> "'???'",
-        "fmt" -> "'BASE64'",
-        "targetType" -> "\"BINARY\"",
-        "suggestion" -> "`try_to_binary`"))
+  test("CONVERSION_INVALID_INPUT: to_binary conversion function base64") {
+    for (codegenMode <- Seq(CODEGEN_ONLY, NO_CODEGEN)) {
+      withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode.toString) {
+        val exception = intercept[SparkException] {
+          Seq(("???")).toDF("a").selectExpr("to_binary(a, 'base64')").collect()
+        }.getCause.asInstanceOf[SparkIllegalArgumentException]
+        checkError(
+          exception,
+          errorClass = "CONVERSION_INVALID_INPUT",
+          parameters = Map(
+            "str" -> "'???'",
+            "fmt" -> "'BASE64'",
+            "targetType" -> "\"BINARY\"",
+            "suggestion" -> "`try_to_binary`"))
+      }
+    }
+  }
+
+  test("CONVERSION_INVALID_INPUT: to_binary conversion function hex") {
+    for (codegenMode <- Seq(CODEGEN_ONLY, NO_CODEGEN)) {
+      withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode.toString) {
+        val exception = intercept[SparkException] {
+          Seq(("???")).toDF("a").selectExpr("to_binary(a, 'hex')").collect()
+        }.getCause.asInstanceOf[SparkIllegalArgumentException]
+        checkError(
+          exception,
+          errorClass = "CONVERSION_INVALID_INPUT",
+          parameters = Map(
+            "str" -> "'???'",
+            "fmt" -> "'HEX'",
+            "targetType" -> "\"BINARY\"",
+            "suggestion" -> "`try_to_binary`"))
+      }
+    }
   }
 
   private def getAesInputs(): (DataFrame, DataFrame) = {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to