This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 5825db81e00 [SPARK-42052][SQL] Codegen Support for HiveSimpleUDF
5825db81e00 is described below

commit 5825db81e0059a4895b4f59d57dec67b0bc618b4
Author: panbingkun <pbk1...@gmail.com>
AuthorDate: Wed Mar 22 12:14:58 2023 +0800

    [SPARK-42052][SQL] Codegen Support for HiveSimpleUDF

    ### What changes were proposed in this pull request?
    - As a subtask of [SPARK-42050](https://issues.apache.org/jira/browse/SPARK-42050), this PR adds codegen support for HiveSimpleUDF.
    - Extract a `HiveUDFEvaluatorBase` class for the behavior shared by HiveSimpleUDFEvaluator and HiveGenericUDFEvaluator.

    ### Why are the changes needed?
    - Improve codegen coverage and performance.
    - Follow up on https://github.com/apache/spark/pull/39949 and make the code more concise.

    ### Does this PR introduce _any_ user-facing change?
    No.

    ### How was this patch tested?
    Added new unit tests; passed GA.

    Closes #40397 from panbingkun/refactor_HiveSimpleUDF.

    Authored-by: panbingkun <pbk1...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../apache/spark/sql/hive/hiveUDFEvaluators.scala  | 148 +++++++++++++++++++++
 .../scala/org/apache/spark/sql/hive/hiveUDFs.scala | 147 ++++++--------------
 .../spark/sql/hive/execution/HiveUDFSuite.scala    |  42 ++++++
 3 files changed, 232 insertions(+), 105 deletions(-)
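For context, a Hive "simple" UDF is a class extending org.apache.hadoop.hive.ql.exec.UDF whose evaluate method Hive resolves by reflection from the SQL argument types; this patch lets Spark compile calls to such functions instead of always falling back to interpreted evaluation. A minimal sketch of the kind of UDF involved (hypothetical class and names, not part of this patch):

    import org.apache.hadoop.hive.ql.exec.UDF

    // Toy simple UDF: Hive picks this evaluate method by reflection.
    class MyUpperUDF extends UDF {
      def evaluate(s: String): String =
        if (s == null) null else s.toUpperCase
    }

After CREATE FUNCTION my_upper AS 'MyUpperUDF', a query such as SELECT my_upper(v) FROM t can now participate in whole-stage codegen.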
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala
new file mode 100644
index 00000000000..094f8ba7a0f
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import scala.collection.JavaConverters._
+
+import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, UDF}
+import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper
+import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory}
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper
+import org.apache.spark.sql.types.DataType
+
+abstract class HiveUDFEvaluatorBase[UDFType <: AnyRef](
+    funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
+  extends HiveInspectors with Serializable {
+
+  @transient
+  lazy val function = funcWrapper.createFunction[UDFType]()
+
+  @transient
+  lazy val isUDFDeterministic = {
+    val udfType = function.getClass.getAnnotation(classOf[HiveUDFType])
+    udfType != null && udfType.deterministic() && !udfType.stateful()
+  }
+
+  def returnType: DataType
+
+  def setArg(index: Int, arg: Any): Unit
+
+  def doEvaluate(): Any
+
+  final def evaluate(): Any = {
+    try {
+      doEvaluate()
+    } catch {
+      case e: Throwable =>
+        throw QueryExecutionErrors.failedExecuteUserDefinedFunctionError(
+          s"${funcWrapper.functionClassName}",
+          s"${children.map(_.dataType.catalogString).mkString(", ")}",
+          s"${returnType.catalogString}",
+          e)
+    }
+  }
+}
+
+class HiveSimpleUDFEvaluator(
+    funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
+  extends HiveUDFEvaluatorBase[UDF](funcWrapper, children) {
+
+  @transient
+  lazy val method = function.getResolver.
+    getEvalMethod(children.map(_.dataType.toTypeInfo).asJava)
+
+  @transient
+  private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray
+
+  @transient
+  private lazy val arguments = children.map(toInspector).toArray
+
+  // Create parameter converters
+  @transient
+  private lazy val conversionHelper = new ConversionHelper(method, arguments)
+
+  @transient
+  private lazy val inputs: Array[AnyRef] = new Array[AnyRef](children.length)
+
+  override def returnType: DataType = javaTypeToDataType(method.getGenericReturnType)
+
+  override def setArg(index: Int, arg: Any): Unit = {
+    inputs(index) = wrappers(index)(arg).asInstanceOf[AnyRef]
+  }
+
+  @transient
+  private lazy val unwrapper: Any => Any =
+    unwrapperFor(ObjectInspectorFactory.getReflectionObjectInspector(
+      method.getGenericReturnType, ObjectInspectorOptions.JAVA))
+
+  override def doEvaluate(): Any = {
+    val ret = FunctionRegistry.invoke(
+      method,
+      function,
+      conversionHelper.convertIfNecessary(inputs: _*): _*)
+    unwrapper(ret)
+  }
+}
+
+class HiveGenericUDFEvaluator(
+    funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
+  extends HiveUDFEvaluatorBase[GenericUDF](funcWrapper, children) {
+
+  @transient
+  private lazy val argumentInspectors = children.map(toInspector)
+
+  @transient
+  lazy val returnInspector = {
+    function.initializeAndFoldConstants(argumentInspectors.toArray)
+  }
+
+  @transient
+  private lazy val deferredObjects: Array[DeferredObject] = argumentInspectors.zip(children).map {
+    case (inspect, child) => new DeferredObjectAdapter(inspect, child.dataType)
+  }.toArray[DeferredObject]
+
+  @transient
+  private lazy val unwrapper: Any => Any = unwrapperFor(returnInspector)
+
+  override def returnType: DataType = inspectorToDataType(returnInspector)
+
+  def setArg(index: Int, arg: Any): Unit =
+    deferredObjects(index).asInstanceOf[DeferredObjectAdapter].set(arg)
+
+  override def doEvaluate(): Any = unwrapper(function.evaluate(deferredObjects))
+}
+
+// Adapter from Catalyst ExpressionResult to Hive DeferredObject
+private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataType)
+  extends DeferredObject with HiveInspectors {
+
+  private val wrapper = wrapperFor(oi, dataType)
+  private var func: Any = _
+  def set(func: Any): Unit = {
+    this.func = func
+  }
+  override def prepare(i: Int): Unit = {}
+  override def get(): AnyRef = wrapper(func).asInstanceOf[AnyRef]
+}
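The two evaluators above share one calling convention, which is what lets a single codegen template drive both: push each child value in with setArg, then call evaluate(), and HiveUDFEvaluatorBase.evaluate converts any failure into the FAILED_EXECUTE_UDF error. A minimal sketch of that contract (assuming a funcWrapper, children, and an input row taken from an enclosing expression):

    // Sketch of how callers drive an evaluator; mirrors HiveSimpleUDF.eval below.
    val evaluator = new HiveSimpleUDFEvaluator(funcWrapper, children)
    children.zipWithIndex.foreach { case (child, idx) =>
      evaluator.setArg(idx, child.eval(input)) // wraps the Catalyst value for Hive
    }
    val result = evaluator.evaluate() // invokes the UDF, unwraps the result, maps errors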
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 98b2258ea13..b07a1b717e7 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -23,15 +23,10 @@ import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.hadoop.hive.ql.exec._
-import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}
 import org.apache.hadoop.hive.ql.udf.generic._
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper
 import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector, ObjectInspectorFactory}
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions
 
-import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -49,56 +44,26 @@ private[hive] case class HiveSimpleUDF(
     name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
   extends Expression
   with HiveInspectors
-  with CodegenFallback
-  with Logging
   with UserDefinedExpression {
 
-  override lazy val deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic)
-
-  override def nullable: Boolean = true
-
-  @transient
-  lazy val function = funcWrapper.createFunction[UDF]()
-
-  @transient
-  private lazy val method =
-    function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo).asJava)
-
-  @transient
-  private lazy val arguments = children.map(toInspector).toArray
-
-  @transient
-  private lazy val isUDFDeterministic = {
-    val udfType = function.getClass.getAnnotation(classOf[HiveUDFType])
-    udfType != null && udfType.deterministic() && !udfType.stateful()
-  }
+  private lazy val evaluator = new HiveSimpleUDFEvaluator(funcWrapper, children)
 
-  override def foldable: Boolean = isUDFDeterministic && children.forall(_.foldable)
+  override lazy val deterministic: Boolean =
+    evaluator.isUDFDeterministic && children.forall(_.deterministic)
 
-  // Create parameter converters
-  @transient
-  private lazy val conversionHelper = new ConversionHelper(method, arguments)
+  override def nullable: Boolean = true
 
-  override lazy val dataType = javaTypeToDataType(method.getGenericReturnType)
+  override def foldable: Boolean = evaluator.isUDFDeterministic && children.forall(_.foldable)
 
-  @transient
-  private lazy val wrappers = children.map(x => wrapperFor(toInspector(x), x.dataType)).toArray
-
-  @transient
-  lazy val unwrapper = unwrapperFor(ObjectInspectorFactory.getReflectionObjectInspector(
-    method.getGenericReturnType, ObjectInspectorOptions.JAVA))
-
-  @transient
-  private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length)
+  override lazy val dataType: DataType = javaTypeToDataType(evaluator.method.getGenericReturnType)
 
   // TODO: Finish input output types.
   override def eval(input: InternalRow): Any = {
-    val inputs = wrap(children.map(_.eval(input)), wrappers, cached)
-    val ret = FunctionRegistry.invoke(
-      method,
-      function,
-      conversionHelper.convertIfNecessary(inputs : _*): _*)
-    unwrapper(ret)
+    children.zipWithIndex.map {
+      case (child, idx) => evaluator.setArg(idx, child.eval(input))
+    }
+    evaluator.evaluate()
   }
 
   override def toString: String = {
@@ -111,19 +76,37 @@ private[hive] case class HiveSimpleUDF(
   override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
     copy(children = newChildren)
-}
 
-// Adapter from Catalyst ExpressionResult to Hive DeferredObject
-private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataType)
-  extends DeferredObject with HiveInspectors {
+  protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val refEvaluator = ctx.addReferenceObj("evaluator", evaluator)
+    val evals = children.map(_.genCode(ctx))
+
+    val setValues = evals.zipWithIndex.map {
+      case (eval, i) =>
+        s"""
+           |if (${eval.isNull}) {
+           |  $refEvaluator.setArg($i, null);
+           |} else {
+           |  $refEvaluator.setArg($i, ${eval.value});
+           |}
+           |""".stripMargin
+    }
 
-  private val wrapper = wrapperFor(oi, dataType)
-  private var func: Any = _
-  def set(func: Any): Unit = {
-    this.func = func
+    val resultType = CodeGenerator.boxedType(dataType)
+    val resultTerm = ctx.freshName("result")
+    ev.copy(code =
+      code"""
+         |${evals.map(_.code).mkString("\n")}
+         |${setValues.mkString("\n")}
+         |$resultType $resultTerm = ($resultType) $refEvaluator.evaluate();
+         |boolean ${ev.isNull} = $resultTerm == null;
+         |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+         |if (!${ev.isNull}) {
+         |  ${ev.value} = $resultTerm;
+         |}
+         |""".stripMargin
+    )
   }
-  override def prepare(i: Int): Unit = {}
-  override def get(): AnyRef = wrapper(func).asInstanceOf[AnyRef]
 }
 
 private[hive] case class HiveGenericUDF(
@@ -135,9 +118,9 @@ private[hive] case class HiveGenericUDF(
   override def nullable: Boolean = true
 
   override lazy val deterministic: Boolean =
-    isUDFDeterministic && children.forall(_.deterministic)
+    evaluator.isUDFDeterministic && children.forall(_.deterministic)
 
-  override def foldable: Boolean = isUDFDeterministic &&
+  override def foldable: Boolean = evaluator.isUDFDeterministic &&
     evaluator.returnInspector.isInstanceOf[ConstantObjectInspector]
 
   override lazy val dataType: DataType = inspectorToDataType(evaluator.returnInspector)
@@ -145,12 +128,6 @@ private[hive] case class HiveGenericUDF(
   @transient
   private lazy val evaluator = new HiveGenericUDFEvaluator(funcWrapper, children)
 
-  @transient
-  private val isUDFDeterministic = {
-    val udfType = evaluator.function.getClass.getAnnotation(classOf[HiveUDFType])
-    udfType != null && udfType.deterministic() && !udfType.stateful()
-  }
-
   override def eval(input: InternalRow): Any = {
     children.zipWithIndex.map {
       case (child, idx) => evaluator.setArg(idx, child.eval(input))
@@ -188,18 +165,8 @@ private[hive] case class HiveGenericUDF(
       code"""
          |${evals.map(_.code).mkString("\n")}
         |${setValues.mkString("\n")}
-         |$resultType $resultTerm = null;
-         |boolean ${ev.isNull} = false;
-         |try {
-         |  $resultTerm = ($resultType) $refEvaluator.evaluate();
-         |  ${ev.isNull} = $resultTerm == null;
-         |} catch (Throwable e) {
-         |  throw QueryExecutionErrors.failedExecuteUserDefinedFunctionError(
-         |    "${funcWrapper.functionClassName}",
-         |    "${children.map(_.dataType.catalogString).mkString(", ")}",
-         |    "${dataType.catalogString}",
-         |    e);
-         |}
+         |$resultType $resultTerm = ($resultType) $refEvaluator.evaluate();
+         |boolean ${ev.isNull} = $resultTerm == null;
          |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
          |if (!${ev.isNull}) {
          |  ${ev.value} = $resultTerm;
@@ -209,36 +176,6 @@ private[hive] case class HiveGenericUDF(
   }
 }
 
-class HiveGenericUDFEvaluator(
-    funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
-  extends HiveInspectors
-  with Serializable {
-
-  @transient
-  lazy val function = funcWrapper.createFunction[GenericUDF]()
-
-  @transient
-  private lazy val argumentInspectors = children.map(toInspector)
-
-  @transient
-  lazy val returnInspector = {
-    function.initializeAndFoldConstants(argumentInspectors.toArray)
-  }
-
-  @transient
-  private lazy val deferredObjects: Array[DeferredObject] = argumentInspectors.zip(children).map {
-    case (inspect, child) => new DeferredObjectAdapter(inspect, child.dataType)
-  }.toArray[DeferredObject]
-
-  @transient
-  private lazy val unwrapper: Any => Any = unwrapperFor(returnInspector)
-
-  def setArg(index: Int, arg: Any): Unit =
-    deferredObjects(index).asInstanceOf[DeferredObjectAdapter].set(arg)
-
-  def evaluate(): Any = unwrapper(function.evaluate(deferredObjects))
-}
-
 /**
  * Converts a Hive Generic User Defined Table Generating Function (UDTF) to a
  * `Generator`. Note that the semantics of Generators do not allow
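With this change the generated Java for both UDF kinds is reduced to setting arguments and calling evaluate() on the referenced evaluator; the try/catch that was previously emitted inline now lives once in HiveUDFEvaluatorBase.evaluate. Rendered as Scala, the generated block for a single string argument behaves roughly like the following (illustrative only; the real output is Java, and names like arg0Value stand in for the fresh names CodegenContext picks):

    // Scala rendering of the generated evaluation block (illustrative).
    refEvaluator.setArg(0, if (arg0IsNull) null else arg0Value)
    val result = refEvaluator.evaluate().asInstanceOf[java.lang.String]
    val isNull = result == null
    val value = if (isNull) null else result // defaultValue, then overwrite if non-null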
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index baa25843d48..8fb9209f9cb 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -24,6 +24,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.hive.ql.exec.UDF
+import org.apache.hadoop.hive.ql.metadata.HiveException
 import org.apache.hadoop.hive.ql.udf.{UDAFPercentile, UDFType}
 import org.apache.hadoop.hive.ql.udf.generic._
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject
@@ -743,6 +744,38 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
       }
     }
   }
+
+  test("SPARK-42052: HiveSimpleUDF Codegen Support") {
+    withUserDefinedFunction("CodeGenHiveSimpleUDF" -> false) {
+      sql(s"CREATE FUNCTION CodeGenHiveSimpleUDF AS '${classOf[UDFStringString].getName}'")
+      withTable("HiveSimpleUDFTable") {
+        sql(s"create table HiveSimpleUDFTable as select 'Spark SQL' as v")
+        val df = sql("SELECT CodeGenHiveSimpleUDF('Hello', v) from HiveSimpleUDFTable")
+        val plan = df.queryExecution.executedPlan
+        assert(plan.isInstanceOf[WholeStageCodegenExec])
+        checkAnswer(df, Seq(Row("Hello Spark SQL")))
+      }
+    }
+  }
+
+  test("SPARK-42052: HiveSimpleUDF Codegen Support w/ execution failure") {
+    withUserDefinedFunction("CodeGenHiveSimpleUDF" -> false) {
+      sql(s"CREATE FUNCTION CodeGenHiveSimpleUDF AS '${classOf[SimpleUDFAssertTrue].getName}'")
+      withTable("HiveSimpleUDFTable") {
+        sql(s"create table HiveSimpleUDFTable as select false as v")
+        val df = sql("SELECT CodeGenHiveSimpleUDF(v) from HiveSimpleUDFTable")
+        checkError(
+          exception =
+            intercept[SparkException](df.collect()).getCause.asInstanceOf[SparkException],
+          errorClass = "FAILED_EXECUTE_UDF",
+          parameters = Map(
+            "functionName" -> s"${classOf[SimpleUDFAssertTrue].getName}",
+            "signature" -> "boolean",
+            "result" -> "boolean"
+          )
+        )
+      }
+    }
+  }
 }
 
 class TestPair(x: Int, y: Int) extends Writable with Serializable {
@@ -844,3 +877,12 @@ class ListFiles extends UDF {
     if (fileArray != null) Arrays.asList(fileArray: _*) else new ArrayList[String]()
   }
 }
+
+class SimpleUDFAssertTrue extends UDF {
+  def evaluate(condition: Boolean): Boolean = {
+    if (!condition) {
+      throw new HiveException("ASSERT_TRUE(): assertion failed.");
+    }
+    condition
+  }
+}
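To see the new code path outside the test suite, one can check that a simple-UDF call now ends up inside WholeStageCodegenExec (a hypothetical spark-shell session mirroring the new test; the function, class, and table names are made up):

    // Assumes a Hive-enabled SparkSession, a table tbl(v STRING),
    // and the MyUpperUDF sketched earlier on the classpath.
    import org.apache.spark.sql.execution.WholeStageCodegenExec

    spark.sql("CREATE FUNCTION my_upper AS 'MyUpperUDF'")
    val df = spark.sql("SELECT my_upper(v) FROM tbl")
    assert(df.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])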
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org