This is an automated email from the ASF dual-hosted git repository. yao pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 34fb40892e3 [SPARK-42051][SQL] Codegen Support for HiveGenericUDF 34fb40892e3 is described below commit 34fb40892e3b5680afbba59e5fd7b10e9a9a7d15 Author: Kent Yao <y...@apache.org> AuthorDate: Wed Feb 1 09:56:36 2023 +0800 [SPARK-42051][SQL] Codegen Support for HiveGenericUDF ### What changes were proposed in this pull request? As a subtask of SPARK-42050, this PR adds Codegen Support for `HiveGenericUDF` ### Why are the changes needed? improve codegen coverage and performance ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? new UT added Closes #39555 from yaooqinn/SPARK-42051. Authored-by: Kent Yao <y...@apache.org> Signed-off-by: Kent Yao <y...@apache.org> --- .../scala/org/apache/spark/sql/hive/hiveUDFs.scala | 62 ++++++++++++++++++---- .../spark/sql/hive/execution/HiveUDFSuite.scala | 34 +++++++++++- 2 files changed, 86 insertions(+), 10 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index a950c1a1783..32ade60e20d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -35,7 +35,8 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.types._ @@ -120,19 +121,18 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataTyp extends DeferredObject with HiveInspectors { private val wrapper = wrapperFor(oi, dataType) - private var func: () => Any = _ - def set(func: () => Any): Unit = { + private var func: Any = _ + def set(func: Any): Unit = { this.func = func } override def prepare(i: Int): Unit = {} - override def get(): AnyRef = wrapper(func()).asInstanceOf[AnyRef] + override def get(): AnyRef = wrapper(func).asInstanceOf[AnyRef] } private[hive] case class HiveGenericUDF( name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors - with CodegenFallback with Logging with UserDefinedExpression { @@ -154,8 +154,9 @@ private[hive] case class HiveGenericUDF( function.initializeAndFoldConstants(argumentInspectors.toArray) } + // Visible for codegen @transient - private lazy val unwrapper = unwrapperFor(returnInspector) + lazy val unwrapper: Any => Any = unwrapperFor(returnInspector) @transient private lazy val isUDFDeterministic = { @@ -163,9 +164,10 @@ private[hive] case class HiveGenericUDF( udfType != null && udfType.deterministic() && !udfType.stateful() } + // Visible for codegen @transient - private lazy val deferredObjects = argumentInspectors.zip(children).map { case (inspect, child) => - new DeferredObjectAdapter(inspect, child.dataType) + lazy val deferredObjects: Array[DeferredObject] = argumentInspectors.zip(children).map { + case (inspect, child) => new DeferredObjectAdapter(inspect, child.dataType) }.toArray[DeferredObject] override lazy val dataType: DataType = inspectorToDataType(returnInspector) @@ -178,7 +180,7 @@ private[hive] case class HiveGenericUDF( while (i < length) { val idx = i deferredObjects(i).asInstanceOf[DeferredObjectAdapter] - .set(() => children(idx).eval(input)) + .set(children(idx).eval(input)) i += 1 } unwrapper(function.evaluate(deferredObjects)) @@ -192,6 +194,48 @@ private[hive] case class HiveGenericUDF( override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(children = newChildren) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val refTerm = ctx.addReferenceObj("this", this) + val childrenEvals = children.map(_.genCode(ctx)) + + val setDeferredObjects = childrenEvals.zipWithIndex.map { + case (eval, i) => + val deferredObjectAdapterClz = classOf[DeferredObjectAdapter].getCanonicalName + s""" + |if (${eval.isNull}) { + | (($deferredObjectAdapterClz) $refTerm.deferredObjects()[$i]).set(null); + |} else { + | (($deferredObjectAdapterClz) $refTerm.deferredObjects()[$i]).set(${eval.value}); + |} + |""".stripMargin + } + + val resultType = CodeGenerator.boxedType(dataType) + val resultTerm = ctx.freshName("result") + ev.copy(code = + code""" + |${childrenEvals.map(_.code).mkString("\n")} + |${setDeferredObjects.mkString("\n")} + |$resultType $resultTerm = null; + |boolean ${ev.isNull} = false; + |try { + | $resultTerm = ($resultType) $refTerm.unwrapper().apply( + | $refTerm.function().evaluate($refTerm.deferredObjects())); + | ${ev.isNull} = $resultTerm == null; + |} catch (Throwable e) { + | throw QueryExecutionErrors.failedExecuteUserDefinedFunctionError( + | "${funcWrapper.functionClassName}", + | "${children.map(_.dataType.catalogString).mkString(", ")}", + | "${dataType.catalogString}", + | e); + |} + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |if (!${ev.isNull}) { + | ${ev.value} = $resultTerm; + |} + |""".stripMargin + ) + } } /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index f494232502f..baa25843d48 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -32,9 +32,10 @@ import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectIns import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory import org.apache.hadoop.io.{LongWritable, Writable} -import org.apache.spark.{SparkFiles, TestUtils} +import org.apache.spark.{SparkException, SparkFiles, TestUtils} import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.functions.max import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf @@ -711,6 +712,37 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } } } + + test("SPARK-42051: HiveGenericUDF Codegen Support") { + withUserDefinedFunction("CodeGenHiveGenericUDF" -> false) { + sql(s"CREATE FUNCTION CodeGenHiveGenericUDF AS '${classOf[GenericUDFMaskHash].getName}'") + withTable("HiveGenericUDFTable") { + sql(s"create table HiveGenericUDFTable as select 'Spark SQL' as v") + val df = sql("SELECT CodeGenHiveGenericUDF(v) from HiveGenericUDFTable") + val plan = df.queryExecution.executedPlan + assert(plan.isInstanceOf[WholeStageCodegenExec]) + checkAnswer(df, Seq(Row("14ab8df5135825bc9f5ff7c30609f02f"))) + } + } + } + + test("SPARK-42051: HiveGenericUDF Codegen Support w/ execution failure") { + withUserDefinedFunction("CodeGenHiveGenericUDF" -> false) { + sql(s"CREATE FUNCTION CodeGenHiveGenericUDF AS '${classOf[GenericUDFAssertTrue].getName}'") + withTable("HiveGenericUDFTable") { + sql(s"create table HiveGenericUDFTable as select false as v") + val df = sql("SELECT CodeGenHiveGenericUDF(v) from HiveGenericUDFTable") + val e = intercept[SparkException](df.collect()).getCause.asInstanceOf[SparkException] + checkError( + e, + "FAILED_EXECUTE_UDF", + parameters = Map( + "functionName" -> s"${classOf[GenericUDFAssertTrue].getName}", + "signature" -> "boolean", + "result" -> "void")) + } + } + } } class TestPair(x: Int, y: Int) extends Writable with Serializable { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org