This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5825db81e00 [SPARK-42052][SQL] Codegen Support for HiveSimpleUDF
5825db81e00 is described below

commit 5825db81e0059a4895b4f59d57dec67b0bc618b4
Author: panbingkun <pbk1...@gmail.com>
AuthorDate: Wed Mar 22 12:14:58 2023 +0800

    [SPARK-42052][SQL] Codegen Support for HiveSimpleUDF
    
    ### What changes were proposed in this pull request?
    - As a subtask of [SPARK-42050](https://issues.apache.org/jira/browse/SPARK-42050),
      this PR adds codegen support for `HiveSimpleUDF`.
    - Extract a `HiveUDFEvaluatorBase` class for the behavior shared by
      `HiveSimpleUDFEvaluator` and `HiveGenericUDFEvaluator`.
    
    ### Why are the changes needed?
    - Improve codegen coverage and performance.
    - Follows up https://github.com/apache/spark/pull/39949 and makes the code more
      concise.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Added new unit tests.
    Passed GitHub Actions (GA).
    
    Closes #40397 from panbingkun/refactor_HiveSimpleUDF.
    
    Authored-by: panbingkun <pbk1...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../apache/spark/sql/hive/hiveUDFEvaluators.scala  | 148 +++++++++++++++++++++
 .../scala/org/apache/spark/sql/hive/hiveUDFs.scala | 147 ++++++--------------
 .../spark/sql/hive/execution/HiveUDFSuite.scala    |  42 ++++++
 3 files changed, 232 insertions(+), 105 deletions(-)

diff --git 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala
new file mode 100644
index 00000000000..094f8ba7a0f
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import scala.collection.JavaConverters._
+
+import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, UDF}
+import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper
+import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, 
ObjectInspectorFactory}
+import 
org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper
+import org.apache.spark.sql.types.DataType
+
+/**
+ * Common plumbing for evaluating a Hive UDF inside Spark.
+ *
+ * Lazily creates the Hive function from `funcWrapper`, derives determinism from the
+ * function class's `UDFType` annotation, and turns any failure raised while evaluating
+ * into a `failedExecuteUserDefinedFunctionError`. Subclasses provide argument binding
+ * ([[setArg]]), the actual invocation ([[doEvaluate]]) and the Catalyst result type
+ * ([[returnType]]).
+ *
+ * @tparam UDFType the Hive function class produced by `funcWrapper`
+ */
+abstract class HiveUDFEvaluatorBase[UDFType <: AnyRef](
+    funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
+  extends HiveInspectors with Serializable {
+
+  /** The Hive function instance; built on first use and never serialized. */
+  @transient
+  lazy val function = funcWrapper.createFunction[UDFType]()
+
+  /** True only if the UDF class is annotated as deterministic and not stateful. */
+  @transient
+  lazy val isUDFDeterministic = {
+    val annotation = function.getClass.getAnnotation(classOf[HiveUDFType])
+    annotation != null && annotation.deterministic() && !annotation.stateful()
+  }
+
+  /** Catalyst type of the value produced by [[evaluate]]. */
+  def returnType: DataType
+
+  /** Binds the Catalyst value `arg` to the argument at position `index`. */
+  def setArg(index: Int, arg: Any): Unit
+
+  /** Invokes the UDF on the currently bound arguments; may throw anything. */
+  def doEvaluate(): Any
+
+  /**
+   * Invokes the UDF on the currently bound arguments.
+   *
+   * Any `Throwable` escaping [[doEvaluate]] is rethrown as the error built by
+   * `QueryExecutionErrors.failedExecuteUserDefinedFunctionError`, which carries the
+   * function class name, the input types and the return type.
+   */
+  final def evaluate(): Any =
+    try {
+      doEvaluate()
+    } catch {
+      case e: Throwable => throw failureFor(e)
+    }
+
+  // Builds the error reported when the UDF invocation fails with `cause`.
+  private def failureFor(cause: Throwable) = {
+    val functionName = s"${funcWrapper.functionClassName}"
+    val inputTypes = s"${children.map(_.dataType.catalogString).mkString(", ")}"
+    val outputType = s"${returnType.catalogString}"
+    QueryExecutionErrors.failedExecuteUserDefinedFunctionError(
+      functionName, inputTypes, outputType, cause)
+  }
+}
+
+/**
+ * Evaluator for a Hive "simple" `UDF`.
+ *
+ * Resolves the `evaluate` method that matches the children's Hive type infos, converts
+ * the wrapped arguments with Hive's `ConversionHelper`, invokes the method through
+ * `FunctionRegistry.invoke` and unwraps the Java result to its Catalyst form.
+ */
+class HiveSimpleUDFEvaluator(
+    funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
+  extends HiveUDFEvaluatorBase[UDF](funcWrapper, children) {
+
+  /** The UDF method selected by the UDF's resolver for the children's types. */
+  @transient
+  lazy val method =
+    function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo).asJava)
+
+  // One object inspector per child, shared by the value wrappers and the converter so
+  // each inspector is only built once.
+  @transient
+  private lazy val arguments = children.map(toInspector).toArray
+
+  // Catalyst -> Hive value wrappers, one per argument position.
+  @transient
+  private lazy val wrappers = children.zip(arguments).map {
+    case (child, inspector) => wrapperFor(inspector, child.dataType)
+  }.toArray
+
+  // Converts wrapped arguments to the parameter types `method` expects.
+  @transient
+  private lazy val conversionHelper = new ConversionHelper(method, arguments)
+
+  // Reused buffer holding the wrapped argument values between setArg and doEvaluate.
+  @transient
+  private lazy val inputs: Array[AnyRef] = new Array[AnyRef](children.length)
+
+  // Converts the Java value returned by `method` back to Catalyst form.
+  @transient
+  private lazy val unwrapper: Any => Any =
+    unwrapperFor(ObjectInspectorFactory.getReflectionObjectInspector(
+      method.getGenericReturnType, ObjectInspectorOptions.JAVA))
+
+  override def returnType: DataType = javaTypeToDataType(method.getGenericReturnType)
+
+  override def setArg(index: Int, arg: Any): Unit = {
+    inputs(index) = wrappers(index)(arg).asInstanceOf[AnyRef]
+  }
+
+  override def doEvaluate(): Any = {
+    val ret = FunctionRegistry.invoke(
+      method,
+      function,
+      conversionHelper.convertIfNecessary(inputs: _*): _*)
+    unwrapper(ret)
+  }
+}
+
+/**
+ * Evaluator for a Hive `GenericUDF`.
+ *
+ * Initializes the function with the children's object inspectors (folding constants),
+ * passes arguments through `DeferredObjectAdapter`s and unwraps the result with the
+ * inspector returned by initialization.
+ */
+class HiveGenericUDFEvaluator(
+    funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
+  extends HiveUDFEvaluatorBase[GenericUDF](funcWrapper, children) {
+
+  @transient
+  private lazy val argumentInspectors = children.map(toInspector)
+
+  /** Inspector for the UDF's output, obtained by initializing the function. */
+  @transient
+  lazy val returnInspector = {
+    function.initializeAndFoldConstants(argumentInspectors.toArray)
+  }
+
+  // One adapter per argument position, each holding the value bound by setArg.
+  @transient
+  private lazy val deferredObjects: Array[DeferredObject] =
+    argumentInspectors.zip(children).map {
+      case (inspect, child) => new DeferredObjectAdapter(inspect, child.dataType)
+    }.toArray[DeferredObject]
+
+  @transient
+  private lazy val unwrapper: Any => Any = unwrapperFor(returnInspector)
+
+  override def returnType: DataType = inspectorToDataType(returnInspector)
+
+  // `override` matches HiveSimpleUDFEvaluator.setArg and pins the base-class contract.
+  override def setArg(index: Int, arg: Any): Unit =
+    deferredObjects(index).asInstanceOf[DeferredObjectAdapter].set(arg)
+
+  override def doEvaluate(): Any = unwrapper(function.evaluate(deferredObjects))
+}
+
+/**
+ * Adapts a Catalyst value to Hive's `DeferredObject` interface.
+ *
+ * `set` stores the raw Catalyst value; `get` converts it with the Hive wrapper built
+ * for `oi` and `dataType` every time it is called.
+ */
+private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataType)
+  extends DeferredObject with HiveInspectors {
+
+  private val wrapper = wrapperFor(oi, dataType)
+
+  // Most recent Catalyst value bound through `set`.
+  private var current: Any = _
+
+  def set(func: Any): Unit = { current = func }
+
+  // Nothing to prepare: conversion happens on demand in `get`.
+  override def prepare(i: Int): Unit = {}
+
+  override def get(): AnyRef = wrapper(current).asInstanceOf[AnyRef]
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 98b2258ea13..b07a1b717e7 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -23,15 +23,10 @@ import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.hadoop.hive.ql.exec._
-import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}
 import org.apache.hadoop.hive.ql.udf.generic._
 import 
org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper
 import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, 
ObjectInspector, ObjectInspectorFactory}
-import 
org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions
 
-import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -49,56 +44,26 @@ private[hive] case class HiveSimpleUDF(
     name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
   extends Expression
   with HiveInspectors
-  with CodegenFallback
-  with Logging
   with UserDefinedExpression {
 
-  override lazy val deterministic: Boolean = isUDFDeterministic && 
children.forall(_.deterministic)
-
-  override def nullable: Boolean = true
-
-  @transient
-  lazy val function = funcWrapper.createFunction[UDF]()
-
-  @transient
-  private lazy val method =
-    
function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo).asJava)
-
-  @transient
-  private lazy val arguments = children.map(toInspector).toArray
-
   @transient
-  private lazy val isUDFDeterministic = {
-    val udfType = function.getClass.getAnnotation(classOf[HiveUDFType])
-    udfType != null && udfType.deterministic() && !udfType.stateful()
-  }
+  private lazy val evaluator = new HiveSimpleUDFEvaluator(funcWrapper, 
children)
 
-  override def foldable: Boolean = isUDFDeterministic && 
children.forall(_.foldable)
+  override lazy val deterministic: Boolean =
+    evaluator.isUDFDeterministic && children.forall(_.deterministic)
 
-  // Create parameter converters
-  @transient
-  private lazy val conversionHelper = new ConversionHelper(method, arguments)
+  override def nullable: Boolean = true
 
-  override lazy val dataType = javaTypeToDataType(method.getGenericReturnType)
+  override def foldable: Boolean = evaluator.isUDFDeterministic && 
children.forall(_.foldable)
 
-  @transient
-  private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), 
x.dataType)).toArray
-
-  @transient
-  lazy val unwrapper = 
unwrapperFor(ObjectInspectorFactory.getReflectionObjectInspector(
-    method.getGenericReturnType, ObjectInspectorOptions.JAVA))
-
-  @transient
-  private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length)
+  override lazy val dataType: DataType = 
javaTypeToDataType(evaluator.method.getGenericReturnType)
 
   // TODO: Finish input output types.
   override def eval(input: InternalRow): Any = {
-    val inputs = wrap(children.map(_.eval(input)), wrappers, cached)
-    val ret = FunctionRegistry.invoke(
-      method,
-      function,
-      conversionHelper.convertIfNecessary(inputs : _*): _*)
-    unwrapper(ret)
+    children.zipWithIndex.map {
+      case (child, idx) => evaluator.setArg(idx, child.eval(input))
+    }
+    evaluator.evaluate()
   }
 
   override def toString: String = {
@@ -111,19 +76,37 @@ private[hive] case class HiveSimpleUDF(
 
   override protected def withNewChildrenInternal(newChildren: 
IndexedSeq[Expression]): Expression =
     copy(children = newChildren)
-}
 
-// Adapter from Catalyst ExpressionResult to Hive DeferredObject
-private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: 
DataType)
-  extends DeferredObject with HiveInspectors {
+  protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val refEvaluator = ctx.addReferenceObj("evaluator", evaluator)
+    val evals = children.map(_.genCode(ctx))
+
+    val setValues = evals.zipWithIndex.map {
+      case (eval, i) =>
+        s"""
+           |if (${eval.isNull}) {
+           |  $refEvaluator.setArg($i, null);
+           |} else {
+           |  $refEvaluator.setArg($i, ${eval.value});
+           |}
+           |""".stripMargin
+    }
 
-  private val wrapper = wrapperFor(oi, dataType)
-  private var func: Any = _
-  def set(func: Any): Unit = {
-    this.func = func
+    val resultType = CodeGenerator.boxedType(dataType)
+    val resultTerm = ctx.freshName("result")
+    ev.copy(code =
+      code"""
+         |${evals.map(_.code).mkString("\n")}
+         |${setValues.mkString("\n")}
+         |$resultType $resultTerm = ($resultType) $refEvaluator.evaluate();
+         |boolean ${ev.isNull} = $resultTerm == null;
+         |${CodeGenerator.javaType(dataType)} ${ev.value} = 
${CodeGenerator.defaultValue(dataType)};
+         |if (!${ev.isNull}) {
+         |  ${ev.value} = $resultTerm;
+         |}
+         |""".stripMargin
+    )
   }
-  override def prepare(i: Int): Unit = {}
-  override def get(): AnyRef = wrapper(func).asInstanceOf[AnyRef]
 }
 
 private[hive] case class HiveGenericUDF(
@@ -135,9 +118,9 @@ private[hive] case class HiveGenericUDF(
   override def nullable: Boolean = true
 
   override lazy val deterministic: Boolean =
-    isUDFDeterministic && children.forall(_.deterministic)
+    evaluator.isUDFDeterministic && children.forall(_.deterministic)
 
-  override def foldable: Boolean = isUDFDeterministic &&
+  override def foldable: Boolean = evaluator.isUDFDeterministic &&
     evaluator.returnInspector.isInstanceOf[ConstantObjectInspector]
 
   override lazy val dataType: DataType = 
inspectorToDataType(evaluator.returnInspector)
@@ -145,12 +128,6 @@ private[hive] case class HiveGenericUDF(
   @transient
   private lazy val evaluator = new HiveGenericUDFEvaluator(funcWrapper, 
children)
 
-  @transient
-  private val isUDFDeterministic = {
-    val udfType = 
evaluator.function.getClass.getAnnotation(classOf[HiveUDFType])
-    udfType != null && udfType.deterministic() && !udfType.stateful()
-  }
-
   override def eval(input: InternalRow): Any = {
     children.zipWithIndex.map {
       case (child, idx) => evaluator.setArg(idx, child.eval(input))
@@ -188,18 +165,8 @@ private[hive] case class HiveGenericUDF(
       code"""
          |${evals.map(_.code).mkString("\n")}
          |${setValues.mkString("\n")}
-         |$resultType $resultTerm = null;
-         |boolean ${ev.isNull} = false;
-         |try {
-         |  $resultTerm = ($resultType) $refEvaluator.evaluate();
-         |  ${ev.isNull} = $resultTerm == null;
-         |} catch (Throwable e) {
-         |  throw QueryExecutionErrors.failedExecuteUserDefinedFunctionError(
-         |    "${funcWrapper.functionClassName}",
-         |    "${children.map(_.dataType.catalogString).mkString(", ")}",
-         |    "${dataType.catalogString}",
-         |    e);
-         |}
+         |$resultType $resultTerm = ($resultType) $refEvaluator.evaluate();
+         |boolean ${ev.isNull} = $resultTerm == null;
          |${CodeGenerator.javaType(dataType)} ${ev.value} = 
${CodeGenerator.defaultValue(dataType)};
          |if (!${ev.isNull}) {
          |  ${ev.value} = $resultTerm;
@@ -209,36 +176,6 @@ private[hive] case class HiveGenericUDF(
   }
 }
 
-class HiveGenericUDFEvaluator(
-    funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
-  extends HiveInspectors
-  with Serializable {
-
-  @transient
-  lazy val function = funcWrapper.createFunction[GenericUDF]()
-
-  @transient
-  private lazy val argumentInspectors = children.map(toInspector)
-
-  @transient
-  lazy val returnInspector = {
-    function.initializeAndFoldConstants(argumentInspectors.toArray)
-  }
-
-  @transient
-  private lazy val deferredObjects: Array[DeferredObject] = 
argumentInspectors.zip(children).map {
-    case (inspect, child) => new DeferredObjectAdapter(inspect, child.dataType)
-  }.toArray[DeferredObject]
-
-  @transient
-  private lazy val unwrapper: Any => Any = unwrapperFor(returnInspector)
-
-  def setArg(index: Int, arg: Any): Unit =
-    deferredObjects(index).asInstanceOf[DeferredObjectAdapter].set(arg)
-
-  def evaluate(): Any = unwrapper(function.evaluate(deferredObjects))
-}
-
 /**
  * Converts a Hive Generic User Defined Table Generating Function (UDTF) to a
  * `Generator`. Note that the semantics of Generators do not allow
diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index baa25843d48..8fb9209f9cb 100644
--- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -24,6 +24,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.hive.ql.exec.UDF
+import org.apache.hadoop.hive.ql.metadata.HiveException
 import org.apache.hadoop.hive.ql.udf.{UDAFPercentile, UDFType}
 import org.apache.hadoop.hive.ql.udf.generic._
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject
@@ -743,6 +744,38 @@ class HiveUDFSuite extends QueryTest with 
TestHiveSingleton with SQLTestUtils {
       }
     }
   }
+
+  test("SPARK-42052: HiveSimpleUDF Codegen Support") {
+    withUserDefinedFunction("CodeGenHiveSimpleUDF" -> false) {
+      val udfClass = classOf[UDFStringString].getName
+      sql(s"CREATE FUNCTION CodeGenHiveSimpleUDF AS '$udfClass'")
+      withTable("HiveSimpleUDFTable") {
+        sql("create table HiveSimpleUDFTable as select 'Spark SQL' as v")
+        val df = sql("SELECT CodeGenHiveSimpleUDF('Hello', v) from HiveSimpleUDFTable")
+        // The projection holding the UDF must run under whole-stage codegen.
+        assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
+        checkAnswer(df, Seq(Row("Hello Spark SQL")))
+      }
+    }
+  }
+
+  test("SPARK-42052: HiveSimpleUDF Codegen Support w/ execution failure") {
+    withUserDefinedFunction("CodeGenHiveSimpleUDF" -> false) {
+      val udfClass = classOf[SimpleUDFAssertTrue].getName
+      sql(s"CREATE FUNCTION CodeGenHiveSimpleUDF AS '$udfClass'")
+      withTable("HiveSimpleUDFTable") {
+        sql("create table HiveSimpleUDFTable as select false as v")
+        val df = sql("SELECT CodeGenHiveSimpleUDF(v) from HiveSimpleUDFTable")
+        // Without this check the test could pass on the interpreted path alone; make sure
+        // the failure below surfaces through whole-stage generated code.
+        assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
+        checkError(
+          exception =
+            intercept[SparkException](df.collect()).getCause.asInstanceOf[SparkException],
+          errorClass = "FAILED_EXECUTE_UDF",
+          parameters = Map(
+            "functionName" -> udfClass,
+            "signature" -> "boolean",
+            "result" -> "boolean"
+          )
+        )
+      }
+    }
+  }
 }
 
 class TestPair(x: Int, y: Int) extends Writable with Serializable {
@@ -844,3 +877,12 @@ class ListFiles extends UDF {
     if (fileArray != null) Arrays.asList(fileArray: _*) else new 
ArrayList[String]()
   }
 }
+
+/** Test-only Hive simple UDF: returns its argument, or throws a HiveException if false. */
+class SimpleUDFAssertTrue extends UDF {
+  def evaluate(condition: Boolean): Boolean = {
+    if (condition) {
+      condition
+    } else {
+      throw new HiveException("ASSERT_TRUE(): assertion failed.")
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to