Repository: spark
Updated Branches:
  refs/heads/branch-2.3 61b301cc7 -> 353d32804


[SPARK-25768][SQL] fix constant argument expecting UDAFs

## What changes were proposed in this pull request?

Without this PR, some UDAFs like `GenericUDAFPercentileApprox` can throw an 
exception because they expect a constant parameter (object inspector) for a 
particular argument.

The exception is thrown because the `toPrettySQL` call in the `ResolveAliases` 
analyzer rule transforms a `Literal` parameter into a `PrettyAttribute`, which 
is then transformed into a plain `ObjectInspector` instead of a 
`ConstantObjectInspector`. The exception comes from the `getEvaluator` method 
of `GenericUDAFPercentileApprox`, which shouldn't be called during the 
`toPrettySQL` transformation at all. The reason it is called is the non-lazy 
fields in `HiveUDAFFunction`.
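
For illustration, the failing check is of roughly this shape (a simplified sketch, 
not Hive's exact code; `requireConstant` is a hypothetical helper): an argument 
that arrives as a plain `ObjectInspector` rather than a `ConstantObjectInspector` 
fails the match and throws.

```scala
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException
import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector}

object ConstantCheckSketch {
  // Sketch of the validation a constant-expecting UDAF performs in its
  // getEvaluator: the argument (e.g. the percentile) must be backed by a
  // ConstantObjectInspector. An inspector derived from a PrettyAttribute
  // is not constant, so the UDAF throws.
  def requireConstant(arg: ObjectInspector, position: Int): ConstantObjectInspector =
    arg match {
      case c: ConstantObjectInspector => c
      case _ => throw new UDFArgumentTypeException(
        position, "The argument must be a constant, but it is not.")
    }
}
```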

This PR makes all fields of `HiveUDAFFunction` lazy.
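
A minimal sketch of why laziness helps (assumed names, not the actual Spark 
code): an eager `val` runs its initializer as soon as the expression is 
constructed, even during analysis-only work such as `toPrettySQL`, while a 
`lazy val` defers it until the aggregate is first evaluated, when the real 
constant object inspectors are available.

```scala
// Minimal sketch with assumed names, not the actual HiveUDAFFunction code.
class UdafFieldSketch(newEvaluator: () => AnyRef) {
  // Before: an eager val would call newEvaluator() at construction time,
  // so a constant-expecting UDAF throws even when the plan is only being
  // pretty-printed during analysis.
  // private val returnInspector = newEvaluator()

  // After: a lazy val calls newEvaluator() only on first access, i.e. when
  // the aggregate actually runs.
  private lazy val returnInspector = newEvaluator()

  def evaluate(): AnyRef = returnInspector
}
```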

## How was this patch tested?

Added a new unit test.

Closes #22766 from peter-toth/SPARK-25768.

Authored-by: Peter Toth <peter.t...@gmail.com>
Signed-off-by: Wenchen Fan <wenc...@databricks.com>
(cherry picked from commit f38594fc561208e17af80d17acf8da362b91fca4)
Signed-off-by: Wenchen Fan <wenc...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/353d3280
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/353d3280
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/353d3280

Branch: refs/heads/branch-2.3
Commit: 353d328041397762e12acf915967cafab5dcdade
Parents: 61b301c
Author: Peter Toth <peter.t...@gmail.com>
Authored: Fri Oct 19 21:17:14 2018 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Fri Oct 19 21:18:36 2018 +0800

----------------------------------------------------------------------
 .../org/apache/spark/sql/hive/hiveUDFs.scala    | 53 +++++++++++---------
 .../spark/sql/hive/execution/HiveUDFSuite.scala | 14 ++++++
 2 files changed, 42 insertions(+), 25 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/353d3280/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 68af99e..4a84509 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -340,39 +340,40 @@ private[hive] case class HiveUDAFFunction(
     resolver.getEvaluator(parameterInfo)
   }
 
-  // The UDAF evaluator used to consume raw input rows and produce partial aggregation results.
-  @transient
-  private lazy val partial1ModeEvaluator = newEvaluator()
+  private case class HiveEvaluator(
+      evaluator: GenericUDAFEvaluator,
+      objectInspector: ObjectInspector)
 
+  // The UDAF evaluator used to consume raw input rows and produce partial aggregation results.
   // Hive `ObjectInspector` used to inspect partial aggregation results.
   @transient
-  private val partialResultInspector = partial1ModeEvaluator.init(
-    GenericUDAFEvaluator.Mode.PARTIAL1,
-    inputInspectors
-  )
+  private lazy val partial1HiveEvaluator = {
+    val evaluator = newEvaluator()
+    HiveEvaluator(evaluator, evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inputInspectors))
+  }
 
   // The UDAF evaluator used to merge partial aggregation results.
   @transient
   private lazy val partial2ModeEvaluator = {
     val evaluator = newEvaluator()
-    evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, Array(partialResultInspector))
+    evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL2, Array(partial1HiveEvaluator.objectInspector))
     evaluator
   }
 
   // Spark SQL data type of partial aggregation results
   @transient
-  private lazy val partialResultDataType = inspectorToDataType(partialResultInspector)
+  private lazy val partialResultDataType =
+    inspectorToDataType(partial1HiveEvaluator.objectInspector)
 
   // The UDAF evaluator used to compute the final result from a partial aggregation result objects.
-  @transient
-  private lazy val finalModeEvaluator = newEvaluator()
-
   // Hive `ObjectInspector` used to inspect the final aggregation result object.
   @transient
-  private val returnInspector = finalModeEvaluator.init(
-    GenericUDAFEvaluator.Mode.FINAL,
-    Array(partialResultInspector)
-  )
+  private lazy val finalHiveEvaluator = {
+    val evaluator = newEvaluator()
+    HiveEvaluator(
+      evaluator,
+      evaluator.init(GenericUDAFEvaluator.Mode.FINAL, Array(partial1HiveEvaluator.objectInspector)))
+  }
 
   // Wrapper functions used to wrap Spark SQL input arguments into Hive specific format.
   @transient
@@ -381,7 +382,7 @@ private[hive] case class HiveUDAFFunction(
   // Unwrapper function used to unwrap final aggregation result objects returned by Hive UDAFs into
   // Spark SQL specific format.
   @transient
-  private lazy val resultUnwrapper = unwrapperFor(returnInspector)
+  private lazy val resultUnwrapper = unwrapperFor(finalHiveEvaluator.objectInspector)
 
   @transient
   private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length)
@@ -391,7 +392,7 @@ private[hive] case class HiveUDAFFunction(
 
   override def nullable: Boolean = true
 
-  override lazy val dataType: DataType = inspectorToDataType(returnInspector)
+  override lazy val dataType: DataType = inspectorToDataType(finalHiveEvaluator.objectInspector)
 
   override def prettyName: String = name
 
@@ -401,13 +402,13 @@ private[hive] case class HiveUDAFFunction(
   }
 
   override def createAggregationBuffer(): AggregationBuffer =
-    partial1ModeEvaluator.getNewAggregationBuffer
+    partial1HiveEvaluator.evaluator.getNewAggregationBuffer
 
   @transient
   private lazy val inputProjection = UnsafeProjection.create(children)
 
   override def update(buffer: AggregationBuffer, input: InternalRow): AggregationBuffer = {
-    partial1ModeEvaluator.iterate(
+    partial1HiveEvaluator.evaluator.iterate(
       buffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
     buffer
   }
@@ -417,12 +418,12 @@ private[hive] case class HiveUDAFFunction(
     // buffer in the 3rd format mentioned in the ScalaDoc of this class. Originally, Hive converts
     // this `AggregationBuffer`s into this format before shuffling partial aggregation results, and
     // calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion.
-    partial2ModeEvaluator.merge(buffer, partial1ModeEvaluator.terminatePartial(input))
+    partial2ModeEvaluator.merge(buffer, partial1HiveEvaluator.evaluator.terminatePartial(input))
     buffer
   }
 
   override def eval(buffer: AggregationBuffer): Any = {
-    resultUnwrapper(finalModeEvaluator.terminate(buffer))
+    resultUnwrapper(finalHiveEvaluator.evaluator.terminate(buffer))
   }
 
   override def serialize(buffer: AggregationBuffer): Array[Byte] = {
@@ -439,9 +440,10 @@ private[hive] case class HiveUDAFFunction(
 
   // Helper class used to de/serialize Hive UDAF `AggregationBuffer` objects
   private class AggregationBufferSerDe {
-    private val partialResultUnwrapper = unwrapperFor(partialResultInspector)
+    private val partialResultUnwrapper = unwrapperFor(partial1HiveEvaluator.objectInspector)
 
-    private val partialResultWrapper = wrapperFor(partialResultInspector, partialResultDataType)
+    private val partialResultWrapper =
+      wrapperFor(partial1HiveEvaluator.objectInspector, partialResultDataType)
 
     private val projection = UnsafeProjection.create(Array(partialResultDataType))
 
@@ -451,7 +453,8 @@ private[hive] case class HiveUDAFFunction(
       // `GenericUDAFEvaluator.terminatePartial()` converts an `AggregationBuffer` into an object
       // that can be inspected by the `ObjectInspector` returned by `GenericUDAFEvaluator.init()`.
       // Then we can unwrap it to a Spark SQL value.
-      mutableRow.update(0, partialResultUnwrapper(partial1ModeEvaluator.terminatePartial(buffer)))
+      mutableRow.update(0, partialResultUnwrapper(
+        partial1HiveEvaluator.evaluator.terminatePartial(buffer)))
       val unsafeRow = projection(mutableRow)
       val bytes = ByteBuffer.allocate(unsafeRow.getSizeInBytes)
       unsafeRow.writeTo(bytes)

http://git-wip-us.apache.org/repos/asf/spark/blob/353d3280/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index 6198d49..a6fc744 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -638,6 +638,20 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
         Row(3) :: Row(3) :: Nil)
     }
   }
+
+  test("SPARK-25768 constant argument expecting Hive UDF") {
+    withTempView("inputTable") {
+      spark.range(10).createOrReplaceTempView("inputTable")
+      withUserDefinedFunction("testGenericUDAFPercentileApprox" -> false) {
+        val numFunc = spark.catalog.listFunctions().count()
+        sql(s"CREATE FUNCTION testGenericUDAFPercentileApprox AS '" +
+          s"${classOf[GenericUDAFPercentileApprox].getName}'")
+        checkAnswer(
+          sql("SELECT testGenericUDAFPercentileApprox(id, 0.5) FROM 
inputTable"),
+          Seq(Row(4.0)))
+      }
+    }
+  }
 }
 
 class TestPair(x: Int, y: Int) extends Writable with Serializable {

