Repository: spark Updated Branches: refs/heads/master e8167768c -> f38594fc5
[SPARK-25768][SQL] fix constant argument expecting UDAFs ## What changes were proposed in this pull request? Without this PR, some UDAFs like `GenericUDAFPercentileApprox` can throw an exception because they expect a constant parameter (object inspector) for a particular argument. The exception is thrown because the `toPrettySQL` call in the `ResolveAliases` analyzer rule transforms a `Literal` parameter into a `PrettyAttribute`, which is then transformed into an `ObjectInspector` instead of a `ConstantObjectInspector`. The exception comes from the `getEvaluator` method of `GenericUDAFPercentileApprox`, which actually shouldn't be called during the `toPrettySQL` transformation. It is called because of the non-lazy fields in `HiveUDAFFunction`. This PR makes all fields of `HiveUDAFFunction` lazy. ## How was this patch tested? Added a new unit test. Closes #22766 from peter-toth/SPARK-25768. Authored-by: Peter Toth <peter.t...@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f38594fc Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f38594fc Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f38594fc Branch: refs/heads/master Commit: f38594fc561208e17af80d17acf8da362b91fca4 Parents: e816776 Author: Peter Toth <peter.t...@gmail.com> Authored: Fri Oct 19 21:17:14 2018 +0800 Committer: Wenchen Fan <wenc...@databricks.com> Committed: Fri Oct 19 21:17:14 2018 +0800 ---------------------------------------------------------------------- .../org/apache/spark/sql/hive/hiveUDFs.scala | 53 +++++++++++--------- .../spark/sql/hive/execution/HiveUDFSuite.scala | 14 ++++++ 2 files changed, 42 insertions(+), 25 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/f38594fc/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala 
---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 68af99e..4a84509 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -340,39 +340,40 @@ private[hive] case class HiveUDAFFunction( resolver.getEvaluator(parameterInfo) } - // The UDAF evaluator used to consume raw input rows and produce partial aggregation results. - @transient - private lazy val partial1ModeEvaluator = newEvaluator() + private case class HiveEvaluator( + evaluator: GenericUDAFEvaluator, + objectInspector: ObjectInspector) + // The UDAF evaluator used to consume raw input rows and produce partial aggregation results. // Hive `ObjectInspector` used to inspect partial aggregation results. @transient - private val partialResultInspector = partial1ModeEvaluator.init( - GenericUDAFEvaluator.Mode.PARTIAL1, - inputInspectors - ) + private lazy val partial1HiveEvaluator = { + val evaluator = newEvaluator() + HiveEvaluator(evaluator, evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputInspectors)) + } // The UDAF evaluator used to merge partial aggregation results. @transient private lazy val partial2ModeEvaluator = { val evaluator = newEvaluator() - evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, Array(partialResultInspector)) + evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, Array(partial1HiveEvaluator.objectInspector)) evaluator } // Spark SQL data type of partial aggregation results @transient - private lazy val partialResultDataType = inspectorToDataType(partialResultInspector) + private lazy val partialResultDataType = + inspectorToDataType(partial1HiveEvaluator.objectInspector) // The UDAF evaluator used to compute the final result from a partial aggregation result objects. 
- @transient - private lazy val finalModeEvaluator = newEvaluator() - // Hive `ObjectInspector` used to inspect the final aggregation result object. @transient - private val returnInspector = finalModeEvaluator.init( - GenericUDAFEvaluator.Mode.FINAL, - Array(partialResultInspector) - ) + private lazy val finalHiveEvaluator = { + val evaluator = newEvaluator() + HiveEvaluator( + evaluator, + evaluator.init(GenericUDAFEvaluator.Mode.FINAL, Array(partial1HiveEvaluator.objectInspector))) + } // Wrapper functions used to wrap Spark SQL input arguments into Hive specific format. @transient @@ -381,7 +382,7 @@ private[hive] case class HiveUDAFFunction( // Unwrapper function used to unwrap final aggregation result objects returned by Hive UDAFs into // Spark SQL specific format. @transient - private lazy val resultUnwrapper = unwrapperFor(returnInspector) + private lazy val resultUnwrapper = unwrapperFor(finalHiveEvaluator.objectInspector) @transient private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length) @@ -391,7 +392,7 @@ private[hive] case class HiveUDAFFunction( override def nullable: Boolean = true - override lazy val dataType: DataType = inspectorToDataType(returnInspector) + override lazy val dataType: DataType = inspectorToDataType(finalHiveEvaluator.objectInspector) override def prettyName: String = name @@ -401,13 +402,13 @@ private[hive] case class HiveUDAFFunction( } override def createAggregationBuffer(): AggregationBuffer = - partial1ModeEvaluator.getNewAggregationBuffer + partial1HiveEvaluator.evaluator.getNewAggregationBuffer @transient private lazy val inputProjection = UnsafeProjection.create(children) override def update(buffer: AggregationBuffer, input: InternalRow): AggregationBuffer = { - partial1ModeEvaluator.iterate( + partial1HiveEvaluator.evaluator.iterate( buffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes)) buffer } @@ -417,12 +418,12 @@ private[hive] case class HiveUDAFFunction( // buffer in the 3rd 
format mentioned in the ScalaDoc of this class. Originally, Hive converts // this `AggregationBuffer`s into this format before shuffling partial aggregation results, and // calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion. - partial2ModeEvaluator.merge(buffer, partial1ModeEvaluator.terminatePartial(input)) + partial2ModeEvaluator.merge(buffer, partial1HiveEvaluator.evaluator.terminatePartial(input)) buffer } override def eval(buffer: AggregationBuffer): Any = { - resultUnwrapper(finalModeEvaluator.terminate(buffer)) + resultUnwrapper(finalHiveEvaluator.evaluator.terminate(buffer)) } override def serialize(buffer: AggregationBuffer): Array[Byte] = { @@ -439,9 +440,10 @@ private[hive] case class HiveUDAFFunction( // Helper class used to de/serialize Hive UDAF `AggregationBuffer` objects private class AggregationBufferSerDe { - private val partialResultUnwrapper = unwrapperFor(partialResultInspector) + private val partialResultUnwrapper = unwrapperFor(partial1HiveEvaluator.objectInspector) - private val partialResultWrapper = wrapperFor(partialResultInspector, partialResultDataType) + private val partialResultWrapper = + wrapperFor(partial1HiveEvaluator.objectInspector, partialResultDataType) private val projection = UnsafeProjection.create(Array(partialResultDataType)) @@ -451,7 +453,8 @@ private[hive] case class HiveUDAFFunction( // `GenericUDAFEvaluator.terminatePartial()` converts an `AggregationBuffer` into an object // that can be inspected by the `ObjectInspector` returned by `GenericUDAFEvaluator.init()`. // Then we can unwrap it to a Spark SQL value. 
- mutableRow.update(0, partialResultUnwrapper(partial1ModeEvaluator.terminatePartial(buffer))) + mutableRow.update(0, partialResultUnwrapper( + partial1HiveEvaluator.evaluator.terminatePartial(buffer))) val unsafeRow = projection(mutableRow) val bytes = ByteBuffer.allocate(unsafeRow.getSizeInBytes) unsafeRow.writeTo(bytes) http://git-wip-us.apache.org/repos/asf/spark/blob/f38594fc/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 6198d49..a6fc744 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -638,6 +638,20 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { Row(3) :: Row(3) :: Nil) } } + + test("SPARK-25768 constant argument expecting Hive UDF") { + withTempView("inputTable") { + spark.range(10).createOrReplaceTempView("inputTable") + withUserDefinedFunction("testGenericUDAFPercentileApprox" -> false) { + val numFunc = spark.catalog.listFunctions().count() + sql(s"CREATE FUNCTION testGenericUDAFPercentileApprox AS '" + + s"${classOf[GenericUDAFPercentileApprox].getName}'") + checkAnswer( + sql("SELECT testGenericUDAFPercentileApprox(id, 0.5) FROM inputTable"), + Seq(Row(4.0))) + } + } + } } class TestPair(x: Int, y: Int) extends Writable with Serializable { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org