This is an automated email from the ASF dual-hosted git repository.

yao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 34fb40892e3 [SPARK-42051][SQL] Codegen Support for HiveGenericUDF
34fb40892e3 is described below

commit 34fb40892e3b5680afbba59e5fd7b10e9a9a7d15
Author: Kent Yao <y...@apache.org>
AuthorDate: Wed Feb 1 09:56:36 2023 +0800

    [SPARK-42051][SQL] Codegen Support for HiveGenericUDF
    
    ### What changes were proposed in this pull request?
    
    As a subtask of SPARK-42050, this PR adds Codegen Support for `HiveGenericUDF`
    
    ### Why are the changes needed?
    
    improve codegen coverage and performance
    
    ### Does this PR introduce _any_ user-facing change?
    
    no
    
    ### How was this patch tested?
    
    new UT added
    
    Closes #39555 from yaooqinn/SPARK-42051.
    
    Authored-by: Kent Yao <y...@apache.org>
    Signed-off-by: Kent Yao <y...@apache.org>
---
 .../scala/org/apache/spark/sql/hive/hiveUDFs.scala | 62 ++++++++++++++++++----
 .../spark/sql/hive/execution/HiveUDFSuite.scala    | 34 +++++++++++-
 2 files changed, 86 insertions(+), 10 deletions(-)

diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index a950c1a1783..32ade60e20d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -35,7 +35,8 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
 import org.apache.spark.sql.hive.HiveShim._
 import org.apache.spark.sql.types._
 
@@ -120,19 +121,18 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataTyp
   extends DeferredObject with HiveInspectors {
 
   private val wrapper = wrapperFor(oi, dataType)
-  private var func: () => Any = _
-  def set(func: () => Any): Unit = {
+  private var func: Any = _
+  def set(func: Any): Unit = {
     this.func = func
   }
   override def prepare(i: Int): Unit = {}
-  override def get(): AnyRef = wrapper(func()).asInstanceOf[AnyRef]
+  override def get(): AnyRef = wrapper(func).asInstanceOf[AnyRef]
 }
 
 private[hive] case class HiveGenericUDF(
     name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
   extends Expression
   with HiveInspectors
-  with CodegenFallback
   with Logging
   with UserDefinedExpression {
 
@@ -154,8 +154,9 @@ private[hive] case class HiveGenericUDF(
     function.initializeAndFoldConstants(argumentInspectors.toArray)
   }
 
+  // Visible for codegen
   @transient
-  private lazy val unwrapper = unwrapperFor(returnInspector)
+  lazy val unwrapper: Any => Any = unwrapperFor(returnInspector)
 
   @transient
   private lazy val isUDFDeterministic = {
@@ -163,9 +164,10 @@ private[hive] case class HiveGenericUDF(
     udfType != null && udfType.deterministic() && !udfType.stateful()
   }
 
+  // Visible for codegen
   @transient
-  private lazy val deferredObjects = argumentInspectors.zip(children).map { case (inspect, child) =>
-    new DeferredObjectAdapter(inspect, child.dataType)
+  lazy val deferredObjects: Array[DeferredObject] = argumentInspectors.zip(children).map {
+    case (inspect, child) => new DeferredObjectAdapter(inspect, child.dataType)
   }.toArray[DeferredObject]
 
   override lazy val dataType: DataType = inspectorToDataType(returnInspector)
@@ -178,7 +180,7 @@ private[hive] case class HiveGenericUDF(
     while (i < length) {
       val idx = i
       deferredObjects(i).asInstanceOf[DeferredObjectAdapter]
-        .set(() => children(idx).eval(input))
+        .set(children(idx).eval(input))
       i += 1
     }
     unwrapper(function.evaluate(deferredObjects))
@@ -192,6 +194,48 @@ private[hive] case class HiveGenericUDF(
 
   override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
     copy(children = newChildren)
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val refTerm = ctx.addReferenceObj("this", this)
+    val childrenEvals = children.map(_.genCode(ctx))
+
+    val setDeferredObjects = childrenEvals.zipWithIndex.map {
+      case (eval, i) =>
+        val deferredObjectAdapterClz = classOf[DeferredObjectAdapter].getCanonicalName
+        s"""
+           |if (${eval.isNull}) {
+           |  (($deferredObjectAdapterClz) $refTerm.deferredObjects()[$i]).set(null);
+           |} else {
+           |  (($deferredObjectAdapterClz) $refTerm.deferredObjects()[$i]).set(${eval.value});
+           |}
+           |""".stripMargin
+    }
+
+    val resultType = CodeGenerator.boxedType(dataType)
+    val resultTerm = ctx.freshName("result")
+    ev.copy(code =
+      code"""
+         |${childrenEvals.map(_.code).mkString("\n")}
+         |${setDeferredObjects.mkString("\n")}
+         |$resultType $resultTerm = null;
+         |boolean ${ev.isNull} = false;
+         |try {
+         |  $resultTerm = ($resultType) $refTerm.unwrapper().apply(
+         |    $refTerm.function().evaluate($refTerm.deferredObjects()));
+         |  ${ev.isNull} = $resultTerm == null;
+         |} catch (Throwable e) {
+         |  throw QueryExecutionErrors.failedExecuteUserDefinedFunctionError(
+         |    "${funcWrapper.functionClassName}",
+         |    "${children.map(_.dataType.catalogString).mkString(", ")}",
+         |    "${dataType.catalogString}",
+         |    e);
+         |}
+         |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+         |if (!${ev.isNull}) {
+         |  ${ev.value} = $resultTerm;
+         |}
+         |""".stripMargin
+    )
+  }
 }
 
 /**
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index f494232502f..baa25843d48 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -32,9 +32,10 @@ import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectIns
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
 import org.apache.hadoop.io.{LongWritable, Writable}
 
-import org.apache.spark.{SparkFiles, TestUtils}
+import org.apache.spark.{SparkException, SparkFiles, TestUtils}
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.catalyst.plans.logical.Project
+import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.functions.max
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.internal.SQLConf
@@ -711,6 +712,37 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
       }
     }
   }
+
+  test("SPARK-42051: HiveGenericUDF Codegen Support") {
+    withUserDefinedFunction("CodeGenHiveGenericUDF" -> false) {
+      sql(s"CREATE FUNCTION CodeGenHiveGenericUDF AS '${classOf[GenericUDFMaskHash].getName}'")
+      withTable("HiveGenericUDFTable") {
+        sql(s"create table HiveGenericUDFTable as select 'Spark SQL' as v")
+        val df = sql("SELECT CodeGenHiveGenericUDF(v) from HiveGenericUDFTable")
+        val plan = df.queryExecution.executedPlan
+        assert(plan.isInstanceOf[WholeStageCodegenExec])
+        checkAnswer(df, Seq(Row("14ab8df5135825bc9f5ff7c30609f02f")))
+      }
+    }
+  }
+
+  test("SPARK-42051: HiveGenericUDF Codegen Support w/ execution failure") {
+    withUserDefinedFunction("CodeGenHiveGenericUDF" -> false) {
+      sql(s"CREATE FUNCTION CodeGenHiveGenericUDF AS '${classOf[GenericUDFAssertTrue].getName}'")
+      withTable("HiveGenericUDFTable") {
+        sql(s"create table HiveGenericUDFTable as select false as v")
+        val df = sql("SELECT CodeGenHiveGenericUDF(v) from HiveGenericUDFTable")
+        val e = intercept[SparkException](df.collect()).getCause.asInstanceOf[SparkException]
+        checkError(
+          e,
+          "FAILED_EXECUTE_UDF",
+          parameters = Map(
+            "functionName" -> s"${classOf[GenericUDFAssertTrue].getName}",
+            "signature" -> "boolean",
+            "result" -> "void"))
+      }
+    }
+  }
 }
 
 class TestPair(x: Int, y: Int) extends Writable with Serializable {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to