Repository: spark
Updated Branches:
  refs/heads/master 7026ee23e -> 8a7db8a60


[SPARK-18980][SQL] implement Aggregator with TypedImperativeAggregate

## What changes were proposed in this pull request?

Currently we implement `Aggregator` with `DeclarativeAggregate`, which
serializes/deserializes the buffer object every time an input row is processed.

This PR implements `Aggregator` with `TypedImperativeAggregate`, which keeps the
buffer as a Java object and avoids repeated serialization/deserialization. The
benchmark shows about a 2x speedup.

For simple buffer objects that don't need serialization, we still use
`DeclarativeAggregate` to avoid a performance regression.
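
For illustration, an `Aggregator` with a non-trivial buffer (the kind this PR moves to the imperative path) might look like the sketch below. This example is not part of the patch; the name and the choice of a Kryo buffer encoder are ours. Its buffer serializer is not a plain `BoundReference` alias, so it should be planned as a `TypedImperativeAggregate`, and the map stays a live object between input rows:

```scala
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Hypothetical Aggregator with a Map buffer. Before this PR the map would be
// round-tripped through its serialized form for every input row; with
// TypedImperativeAggregate it is only serialized at partial-aggregate
// (shuffle/spill) boundaries.
object DistinctWordCount extends Aggregator[String, Map[String, Long], Int] {
  def zero: Map[String, Long] = Map.empty
  def reduce(buf: Map[String, Long], word: String): Map[String, Long] =
    buf.updated(word, buf.getOrElse(word, 0L) + 1L)
  def merge(b1: Map[String, Long], b2: Map[String, Long]): Map[String, Long] =
    b2.foldLeft(b1) { case (m, (k, v)) => m.updated(k, m.getOrElse(k, 0L) + v) }
  def finish(buf: Map[String, Long]): Int = buf.size
  def bufferEncoder: Encoder[Map[String, Long]] = Encoders.kryo[Map[String, Long]]
  def outputEncoder: Encoder[Int] = Encoders.scalaInt
}
```

It would be used as `ds.select(DistinctWordCount.toColumn)` on a `Dataset[String]`.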

## How was this patch tested?

N/A

Author: Wenchen Fan <[email protected]>

Closes #16383 from cloud-fan/aggregator.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8a7db8a6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8a7db8a6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8a7db8a6

Branch: refs/heads/master
Commit: 8a7db8a608a9e27b10f205cc1b4ed5f2c3e83799
Parents: 7026ee2
Author: Wenchen Fan <[email protected]>
Authored: Mon Dec 26 22:10:20 2016 +0800
Committer: Wenchen Fan <[email protected]>
Committed: Mon Dec 26 22:10:20 2016 +0800

----------------------------------------------------------------------
 .../aggregate/ApproximatePercentile.scala       |   6 +-
 .../aggregate/CountMinSketchAgg.scala           |   6 +-
 .../expressions/aggregate/Percentile.scala      |  10 +-
 .../expressions/aggregate/interfaces.scala      |  23 ++-
 .../scala/org/apache/spark/sql/Column.scala     |   8 +-
 .../aggregate/TypedAggregateExpression.scala    | 185 ++++++++++++++++---
 .../org/apache/spark/sql/DatasetBenchmark.scala |  12 +-
 .../sql/TypedImperativeAggregateSuite.scala     |   6 +-
 .../org/apache/spark/sql/hive/hiveUDFs.scala    |   6 +-
 .../sql/hive/execution/TestingTypedCount.scala  |   6 +-
 10 files changed, 212 insertions(+), 56 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8a7db8a6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
index 0e71442..18b7f95 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
@@ -120,16 +120,18 @@ case class ApproximatePercentile(
     new PercentileDigest(relativeError)
   }
 
-  override def update(buffer: PercentileDigest, inputRow: InternalRow): Unit = {
+  override def update(buffer: PercentileDigest, inputRow: InternalRow): PercentileDigest = {
     val value = child.eval(inputRow)
     // Ignore empty rows, for example: percentile_approx(null)
     if (value != null) {
       buffer.add(value.asInstanceOf[Double])
     }
+    buffer
   }
 
-  override def merge(buffer: PercentileDigest, other: PercentileDigest): Unit = {
+  override def merge(buffer: PercentileDigest, other: PercentileDigest): PercentileDigest = {
     buffer.merge(other)
+    buffer
   }
 
   override def eval(buffer: PercentileDigest): Any = {

http://git-wip-us.apache.org/repos/asf/spark/blob/8a7db8a6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
index 612c198..dae88c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
@@ -90,7 +90,7 @@ case class CountMinSketchAgg(
     CountMinSketch.create(eps, confidence, seed)
   }
 
-  override def update(buffer: CountMinSketch, input: InternalRow): Unit = {
+  override def update(buffer: CountMinSketch, input: InternalRow): CountMinSketch = {
     val value = child.eval(input)
     // Ignore empty rows
     if (value != null) {
@@ -101,10 +101,12 @@ case class CountMinSketchAgg(
         case _ => buffer.add(value)
       }
     }
+    buffer
   }
 
-  override def merge(buffer: CountMinSketch, input: CountMinSketch): Unit = {
+  override def merge(buffer: CountMinSketch, input: CountMinSketch): CountMinSketch = {
     buffer.mergeInPlace(input)
+    buffer
   }
 
   override def eval(buffer: CountMinSketch): Any = serialize(buffer)

http://git-wip-us.apache.org/repos/asf/spark/blob/8a7db8a6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
index 2f68195..2f4d68d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
@@ -123,19 +123,25 @@ case class Percentile(
     new OpenHashMap[Number, Long]()
   }
 
-  override def update(buffer: OpenHashMap[Number, Long], input: InternalRow): Unit = {
+  override def update(
+      buffer: OpenHashMap[Number, Long],
+      input: InternalRow): OpenHashMap[Number, Long] = {
     val key = child.eval(input).asInstanceOf[Number]
 
     // Null values are ignored in counts map.
     if (key != null) {
       buffer.changeValue(key, 1L, _ + 1L)
     }
+    buffer
   }
 
-  override def merge(buffer: OpenHashMap[Number, Long], other: OpenHashMap[Number, Long]): Unit = {
+  override def merge(
+      buffer: OpenHashMap[Number, Long],
+      other: OpenHashMap[Number, Long]): OpenHashMap[Number, Long] = {
     other.foreach { case (key, count) =>
       buffer.changeValue(key, count, _ + count)
     }
+    buffer
   }
 
   override def eval(buffer: OpenHashMap[Number, Long]): Any = {

http://git-wip-us.apache.org/repos/asf/spark/blob/8a7db8a6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 7397b60..8e63fba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -155,7 +155,7 @@ case class AggregateExpression(
 * Code which accepts [[AggregateFunction]] instances should be prepared to handle both types of
  * aggregate functions.
  */
-sealed abstract class AggregateFunction extends Expression {
+abstract class AggregateFunction extends Expression {
 
   /** An aggregate function is not foldable. */
   final override def foldable: Boolean = false
@@ -471,23 +471,29 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
   def createAggregationBuffer(): T
 
   /**
-   * In-place updates the aggregation buffer object with an input row. buffer = buffer + input.
+   * Updates the aggregation buffer object with an input row and returns a new buffer object. For
+   * performance, the function may do an in-place update and return it instead of constructing a
+   * new buffer object.
+   *
    * This is typically called when doing Partial or Complete mode aggregation.
    *
    * @param buffer The aggregation buffer object.
    * @param input an input row
    */
-  def update(buffer: T, input: InternalRow): Unit
+  def update(buffer: T, input: InternalRow): T
 
   /**
-   * Merges an input aggregation object into aggregation buffer object. buffer = buffer + input.
+   * Merges an input aggregation object into the aggregation buffer object and returns a new
+   * buffer object. For performance, the function may do an in-place merge and return it instead
+   * of constructing a new buffer object.
+   *
   * This is typically called when doing PartialMerge or Final mode aggregation.
    *
   * @param buffer the aggregation buffer object used to store the aggregation result.
   * @param input an input aggregation object. Input aggregation object can be produced by
   *              de-serializing the partial aggregate's output from Mapper side.
    */
-  def merge(buffer: T, input: T): Unit
+  def merge(buffer: T, input: T): T
 
   /**
   * Generates the final aggregation result value for current key group with the aggregation buffer
@@ -505,19 +511,18 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
   def deserialize(storageFormat: Array[Byte]): T
 
   final override def initialize(buffer: InternalRow): Unit = {
-    val bufferObject = createAggregationBuffer()
-    buffer.update(mutableAggBufferOffset, bufferObject)
+    buffer(mutableAggBufferOffset) = createAggregationBuffer()
   }
 
   final override def update(buffer: InternalRow, input: InternalRow): Unit = {
-    update(getBufferObject(buffer), input)
+    buffer(mutableAggBufferOffset) = update(getBufferObject(buffer), input)
   }
 
  final override def merge(buffer: InternalRow, inputBuffer: InternalRow): Unit = {
     val bufferObject = getBufferObject(buffer)
    // The inputBuffer stores serialized aggregation buffer object produced by partial aggregate
     val inputObject = deserialize(inputBuffer.getBinary(inputAggBufferOffset))
-    merge(bufferObject, inputObject)
+    buffer(mutableAggBufferOffset) = merge(bufferObject, inputObject)
   }
 
   final override def eval(buffer: InternalRow): Any = {
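
As a concrete example of the new contract (hypothetical, not from this patch), the sketch below returns a fresh buffer from `update`/`merge` rather than mutating in place, which the `T`-returning signatures now permit; the final overrides above write the returned object back into the mutable row:

```scala
import java.nio.ByteBuffer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
import org.apache.spark.sql.types.{DataType, LongType}

// Hypothetical aggregate: a typed sum whose buffer is an immutable Long.
// Under the old Unit-returning contract the buffer had to be mutable; now
// update/merge can simply return the new buffer value.
case class TypedLongSum(
    child: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0)
  extends TypedImperativeAggregate[Long] {

  override def createAggregationBuffer(): Long = 0L

  // Returns a new buffer instead of mutating: Long is immutable.
  override def update(buffer: Long, input: InternalRow): Long = {
    val v = child.eval(input)
    if (v != null) buffer + v.asInstanceOf[Long] else buffer
  }

  override def merge(buffer: Long, input: Long): Long = buffer + input

  override def eval(buffer: Long): Any = buffer

  // Serialized form used only across partial-aggregate boundaries.
  override def serialize(buffer: Long): Array[Byte] =
    ByteBuffer.allocate(8).putLong(buffer).array()

  override def deserialize(bytes: Array[Byte]): Long =
    ByteBuffer.wrap(bytes).getLong

  override def withNewMutableAggBufferOffset(offset: Int): TypedLongSum =
    copy(mutableAggBufferOffset = offset)

  override def withNewInputAggBufferOffset(offset: Int): TypedLongSum =
    copy(inputAggBufferOffset = offset)

  override def children: Seq[Expression] = child :: Nil
  override def nullable: Boolean = false
  override def dataType: DataType = LongType
}
```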

http://git-wip-us.apache.org/repos/asf/spark/blob/8a7db8a6/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index e99d786..a3f581f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -75,10 +75,10 @@ class TypedColumn[-T, U](
    val unresolvedDeserializer = UnresolvedDeserializer(inputEncoder.deserializer, inputAttributes)
     val newExpr = expr transform {
       case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty =>
-        ta.copy(
-          inputDeserializer = Some(unresolvedDeserializer),
-          inputClass = Some(inputEncoder.clsTag.runtimeClass),
-          inputSchema = Some(inputEncoder.schema))
+        ta.withInputInfo(
+          deser = unresolvedDeserializer,
+          cls = inputEncoder.clsTag.runtimeClass,
+          schema = inputEncoder.schema)
     }
     new TypedColumn[T, U](newExpr, encoder)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/8a7db8a6/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index 9911c0b..4146bf3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -20,10 +20,12 @@ package org.apache.spark.sql.execution.aggregate
 import scala.language.existentials
 
 import org.apache.spark.sql.Encoder
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
 import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, DeclarativeAggregate, TypedImperativeAggregate}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
 import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.types._
@@ -33,9 +35,6 @@ object TypedAggregateExpression {
       aggregator: Aggregator[_, BUF, OUT]): TypedAggregateExpression = {
     val bufferEncoder = encoderFor[BUF]
     val bufferSerializer = bufferEncoder.namedExpressions
-    val bufferDeserializer = UnresolvedDeserializer(
-      bufferEncoder.deserializer,
-      bufferSerializer.map(_.toAttribute))
 
     val outputEncoder = encoderFor[OUT]
     val outputType = if (outputEncoder.flat) {
@@ -44,24 +43,78 @@ object TypedAggregateExpression {
       outputEncoder.schema
     }
 
-    new TypedAggregateExpression(
-      aggregator.asInstanceOf[Aggregator[Any, Any, Any]],
-      None,
-      None,
-      None,
-      bufferSerializer,
-      bufferDeserializer,
-      outputEncoder.serializer,
-      outputEncoder.deserializer.dataType,
-      outputType,
-      !outputEncoder.flat || outputEncoder.schema.head.nullable)
+    // Checks if the buffer object is simple, i.e. the buffer encoder is flat and the serializer
+    // expression is an alias of `BoundReference`, which means the buffer object doesn't need
+    // serialization.
+    val isSimpleBuffer = {
+      bufferSerializer.head match {
+        case Alias(_: BoundReference, _) if bufferEncoder.flat => true
+        case _ => false
+      }
+    }
+
+    // If the buffer object is simple, use `SimpleTypedAggregateExpression`, which supports whole
+    // stage codegen.
+    if (isSimpleBuffer) {
+      val bufferDeserializer = UnresolvedDeserializer(
+        bufferEncoder.deserializer,
+        bufferSerializer.map(_.toAttribute))
+
+      SimpleTypedAggregateExpression(
+        aggregator.asInstanceOf[Aggregator[Any, Any, Any]],
+        None,
+        None,
+        None,
+        bufferSerializer,
+        bufferDeserializer,
+        outputEncoder.serializer,
+        outputEncoder.deserializer.dataType,
+        outputType,
+        !outputEncoder.flat || outputEncoder.schema.head.nullable)
+    } else {
+      ComplexTypedAggregateExpression(
+        aggregator.asInstanceOf[Aggregator[Any, Any, Any]],
+        None,
+        None,
+        None,
+        bufferSerializer,
+        bufferEncoder.resolveAndBind().deserializer,
+        outputEncoder.serializer,
+        outputType,
+        !outputEncoder.flat || outputEncoder.schema.head.nullable)
+    }
   }
 }
 
 /**
  * A helper class to hook [[Aggregator]] into the aggregation system.
  */
-case class TypedAggregateExpression(
+trait TypedAggregateExpression extends AggregateFunction {
+
+  def aggregator: Aggregator[Any, Any, Any]
+
+  def inputDeserializer: Option[Expression]
+  def inputClass: Option[Class[_]]
+  def inputSchema: Option[StructType]
+
+  def withInputInfo(deser: Expression, cls: Class[_], schema: StructType): TypedAggregateExpression
+
+  override def toString: String = {
+    val input = inputDeserializer match {
+      case Some(UnresolvedDeserializer(deserializer, _)) => deserializer.dataType.simpleString
+      case Some(deserializer) => deserializer.dataType.simpleString
+      case _ => "unknown"
+    }
+
+    s"$nodeName($input)"
+  }
+
+  override def nodeName: String = aggregator.getClass.getSimpleName.stripSuffix("$")
+}
+
+// TODO: merge these 2 implementations once we refactor the `AggregateFunction` interface.
+
+case class SimpleTypedAggregateExpression(
     aggregator: Aggregator[Any, Any, Any],
     inputDeserializer: Option[Expression],
     inputClass: Option[Class[_]],
@@ -71,7 +124,8 @@ case class TypedAggregateExpression(
     outputSerializer: Seq[Expression],
     outputExternalType: DataType,
     dataType: DataType,
-    nullable: Boolean) extends DeclarativeAggregate with NonSQLExpression {
+    nullable: Boolean)
+  extends DeclarativeAggregate with TypedAggregateExpression with NonSQLExpression {
 
   override def deterministic: Boolean = true
 
@@ -143,15 +197,96 @@ case class TypedAggregateExpression(
     }
   }
 
-  override def toString: String = {
-    val input = inputDeserializer match {
-      case Some(UnresolvedDeserializer(deserializer, _)) => deserializer.dataType.simpleString
-      case Some(deserializer) => deserializer.dataType.simpleString
-      case _ => "unknown"
+  override def withInputInfo(
+      deser: Expression,
+      cls: Class[_],
+      schema: StructType): TypedAggregateExpression = {
+    copy(inputDeserializer = Some(deser), inputClass = Some(cls), inputSchema = Some(schema))
+  }
+}
+
+case class ComplexTypedAggregateExpression(
+    aggregator: Aggregator[Any, Any, Any],
+    inputDeserializer: Option[Expression],
+    inputClass: Option[Class[_]],
+    inputSchema: Option[StructType],
+    bufferSerializer: Seq[NamedExpression],
+    bufferDeserializer: Expression,
+    outputSerializer: Seq[Expression],
+    dataType: DataType,
+    nullable: Boolean,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0)
+  extends TypedImperativeAggregate[Any] with TypedAggregateExpression with NonSQLExpression {
+
+  override def deterministic: Boolean = true
+
+  override def children: Seq[Expression] = inputDeserializer.toSeq
+
+  override lazy val resolved: Boolean = inputDeserializer.isDefined && childrenResolved
+
+  override def references: AttributeSet = AttributeSet(inputDeserializer.toSeq)
+
+  override def createAggregationBuffer(): Any = aggregator.zero
+
+  private lazy val inputRowToObj = GenerateSafeProjection.generate(inputDeserializer.get :: Nil)
+
+  override def update(buffer: Any, input: InternalRow): Any = {
+    val inputObj = inputRowToObj(input).get(0, ObjectType(classOf[Any]))
+    if (inputObj != null) {
+      aggregator.reduce(buffer, inputObj)
+    } else {
+      buffer
+    }
+  }
+
+  override def merge(buffer: Any, input: Any): Any = {
+    aggregator.merge(buffer, input)
+  }
+
+  private lazy val resultObjToRow = dataType match {
+    case _: StructType =>
+      UnsafeProjection.create(CreateStruct(outputSerializer))
+    case _ =>
+      assert(outputSerializer.length == 1)
+      UnsafeProjection.create(outputSerializer.head)
+  }
+
+  override def eval(buffer: Any): Any = {
+    val resultObj = aggregator.finish(buffer)
+    if (resultObj == null) {
+      null
+    } else {
+      resultObjToRow(InternalRow(resultObj)).get(0, dataType)
     }
+  }
 
-    s"$nodeName($input)"
+  private lazy val bufferObjToRow = UnsafeProjection.create(bufferSerializer)
+
+  override def serialize(buffer: Any): Array[Byte] = {
+    bufferObjToRow(InternalRow(buffer)).getBytes
   }
 
-  override def nodeName: String = aggregator.getClass.getSimpleName.stripSuffix("$")
+  private lazy val bufferRow = new UnsafeRow(bufferSerializer.length)
+  private lazy val bufferRowToObject = GenerateSafeProjection.generate(bufferDeserializer :: Nil)
+
+  override def deserialize(storageFormat: Array[Byte]): Any = {
+    bufferRow.pointTo(storageFormat, storageFormat.length)
+    bufferRowToObject(bufferRow).get(0, ObjectType(classOf[Any]))
+  }
+
+  override def withNewMutableAggBufferOffset(
+      newMutableAggBufferOffset: Int): ComplexTypedAggregateExpression =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(
+      newInputAggBufferOffset: Int): ComplexTypedAggregateExpression =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def withInputInfo(
+      deser: Expression,
+      cls: Class[_],
+      schema: StructType): TypedAggregateExpression = {
+    copy(inputDeserializer = Some(deser), inputClass = Some(cls), inputSchema = Some(schema))
+  }
 }
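
To illustrate the dispatch in the factory above (our reading, with hypothetical aggregators): a bare primitive buffer should satisfy `isSimpleBuffer` and keep the codegen-friendly declarative path, while a case-class buffer should not, and therefore takes the new imperative path:

```scala
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Buffer is a bare Long: its flat encoder's serializer should be an alias of
// a BoundReference, so `isSimpleBuffer` holds and this is planned as
// SimpleTypedAggregateExpression (DeclarativeAggregate, whole-stage codegen).
object SimpleSum extends Aggregator[Long, Long, Long] {
  def zero: Long = 0L
  def reduce(b: Long, a: Long): Long = b + a
  def merge(b1: Long, b2: Long): Long = b1 + b2
  def finish(b: Long): Long = b
  def bufferEncoder: Encoder[Long] = Encoders.scalaLong
  def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

// Buffer is a case class: its serializer is not a plain BoundReference alias,
// so this should be planned as ComplexTypedAggregateExpression
// (TypedImperativeAggregate). reduce/merge mutate in place and return the
// buffer, which the new update/merge contract explicitly allows.
case class SumCount(var sum: Long, var count: Long)

object AverageAgg extends Aggregator[Long, SumCount, Double] {
  def zero: SumCount = SumCount(0L, 0L)
  def reduce(b: SumCount, a: Long): SumCount = { b.sum += a; b.count += 1; b }
  def merge(b1: SumCount, b2: SumCount): SumCount = {
    b1.sum += b2.sum; b1.count += b2.count; b1
  }
  def finish(b: SumCount): Double =
    if (b.count == 0) 0.0 else b.sum.toDouble / b.count
  def bufferEncoder: Encoder[SumCount] = Encoders.product[SumCount]
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
```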

http://git-wip-us.apache.org/repos/asf/spark/blob/8a7db8a6/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
index c11605d..66d94d6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
@@ -192,14 +192,14 @@ object DatasetBenchmark {
     benchmark2.run()
 
     /*
-    OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 3.10.0-327.18.2.el7.x86_64
-    Intel Xeon E3-12xx v2 (Ivy Bridge)
+    Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.12.1
+    Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
    aggregate:                               Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
    ------------------------------------------------------------------------------------------------
-    RDD sum                                       1420 / 1523         70.4          14.2       1.0X
-    DataFrame sum                                   31 /   49       3214.3           0.3      45.6X
-    Dataset sum using Aggregator                  3216 / 3257         31.1          32.2       0.4X
-    Dataset complex Aggregator                    7948 / 8461         12.6          79.5       0.2X
+    RDD sum                                       1913 / 1942         52.3          19.1       1.0X
+    DataFrame sum                                   46 /   61       2157.7           0.5      41.3X
+    Dataset sum using Aggregator                  4656 / 4758         21.5          46.6       0.4X
+    Dataset complex Aggregator                    6636 / 7039         15.1          66.4       0.3X
     */
     benchmark3.run()
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/8a7db8a6/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
index 70c3951..b76f168 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
@@ -240,7 +240,7 @@ object TypedImperativeAggregateSuite {
       new MaxValue(Int.MinValue)
     }
 
-    override def update(buffer: MaxValue, input: InternalRow): Unit = {
+    override def update(buffer: MaxValue, input: InternalRow): MaxValue = {
       child.eval(input) match {
         case inputValue: Int =>
           if (inputValue > buffer.value) {
@@ -249,13 +249,15 @@ object TypedImperativeAggregateSuite {
           }
         case null => // skip
       }
+      buffer
     }
 
-    override def merge(bufferMax: MaxValue, inputMax: MaxValue): Unit = {
+    override def merge(bufferMax: MaxValue, inputMax: MaxValue): MaxValue = {
       if (inputMax.value > bufferMax.value) {
         bufferMax.value = inputMax.value
         bufferMax.isValueSet = bufferMax.isValueSet || inputMax.isValueSet
       }
+      bufferMax
     }
 
     override def eval(bufferMax: MaxValue): Any = {

http://git-wip-us.apache.org/repos/asf/spark/blob/8a7db8a6/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 26dc372..fcefd69 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -397,17 +397,19 @@ private[hive] case class HiveUDAFFunction(
   @transient
   private lazy val inputProjection = UnsafeProjection.create(children)
 
-  override def update(buffer: AggregationBuffer, input: InternalRow): Unit = {
+  override def update(buffer: AggregationBuffer, input: InternalRow): AggregationBuffer = {
     partial1ModeEvaluator.iterate(
      buffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes))
+    buffer
   }
 
-  override def merge(buffer: AggregationBuffer, input: AggregationBuffer): Unit = {
+  override def merge(buffer: AggregationBuffer, input: AggregationBuffer): AggregationBuffer = {
    // The 2nd argument of the Hive `GenericUDAFEvaluator.merge()` method is an input aggregation
    // buffer in the 3rd format mentioned in the ScalaDoc of this class. Originally, Hive converts
    // this `AggregationBuffer`s into this format before shuffling partial aggregation results, and
    // calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion.
    partial2ModeEvaluator.merge(buffer, partial1ModeEvaluator.terminatePartial(input))
+    buffer
   }
 
   override def eval(buffer: AggregationBuffer): Any = {

http://git-wip-us.apache.org/repos/asf/spark/blob/8a7db8a6/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala
index d27287b..aaf1db6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala
@@ -46,14 +46,16 @@ case class TestingTypedCount(
 
   override def createAggregationBuffer(): State = TestingTypedCount.State(0L)
 
-  override def update(buffer: State, input: InternalRow): Unit = {
+  override def update(buffer: State, input: InternalRow): State = {
     if (child.eval(input) != null) {
       buffer.count += 1
     }
+    buffer
   }
 
-  override def merge(buffer: State, input: State): Unit = {
+  override def merge(buffer: State, input: State): State = {
     buffer.count += input.count
+    buffer
   }
 
   override def eval(buffer: State): Any = buffer.count

