Repository: spark Updated Branches: refs/heads/master 7026ee23e -> 8a7db8a60
[SPARK-18980][SQL] implement Aggregator with TypedImperativeAggregate ## What changes were proposed in this pull request? Currently we implement `Aggregator` with `DeclarativeAggregate`, which serializes/deserializes the buffer object every time an input row is processed. This PR implements `Aggregator` with `TypedImperativeAggregate` and avoids serializing/deserializing the buffer object repeatedly. The benchmark shows a roughly 2x speedup. For simple buffer objects that don't need serialization, we still use `DeclarativeAggregate` to avoid a performance regression. ## How was this patch tested? N/A Author: Wenchen Fan <[email protected]> Closes #16383 from cloud-fan/aggregator. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8a7db8a6 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8a7db8a6 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8a7db8a6 Branch: refs/heads/master Commit: 8a7db8a608a9e27b10f205cc1b4ed5f2c3e83799 Parents: 7026ee2 Author: Wenchen Fan <[email protected]> Authored: Mon Dec 26 22:10:20 2016 +0800 Committer: Wenchen Fan <[email protected]> Committed: Mon Dec 26 22:10:20 2016 +0800 ---------------------------------------------------------------------- .../aggregate/ApproximatePercentile.scala | 6 +- .../aggregate/CountMinSketchAgg.scala | 6 +- .../expressions/aggregate/Percentile.scala | 10 +- .../expressions/aggregate/interfaces.scala | 23 ++- .../scala/org/apache/spark/sql/Column.scala | 8 +- .../aggregate/TypedAggregateExpression.scala | 185 ++++++++++++++++--- .../org/apache/spark/sql/DatasetBenchmark.scala | 12 +- .../sql/TypedImperativeAggregateSuite.scala | 6 +- .../org/apache/spark/sql/hive/hiveUDFs.scala | 6 +- .../sql/hive/execution/TestingTypedCount.scala | 6 +- 10 files changed, 212 insertions(+), 56 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/8a7db8a6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index 0e71442..18b7f95 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -120,16 +120,18 @@ case class ApproximatePercentile( new PercentileDigest(relativeError) } - override def update(buffer: PercentileDigest, inputRow: InternalRow): Unit = { + override def update(buffer: PercentileDigest, inputRow: InternalRow): PercentileDigest = { val value = child.eval(inputRow) // Ignore empty rows, for example: percentile_approx(null) if (value != null) { buffer.add(value.asInstanceOf[Double]) } + buffer } - override def merge(buffer: PercentileDigest, other: PercentileDigest): Unit = { + override def merge(buffer: PercentileDigest, other: PercentileDigest): PercentileDigest = { buffer.merge(other) + buffer } override def eval(buffer: PercentileDigest): Any = {
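For context, a minimal sketch of the kind of `Aggregator` this change targets (illustrative only; `TypedAverage` and `Average` are not part of this patch): the buffer is a case class, so its encoder is not flat and, before this patch, the buffer object had to be converted to and from the aggregation buffer row on every input. With this patch such an aggregator should be planned through `TypedImperativeAggregate` (the `ComplexTypedAggregateExpression` introduced below) and keep its buffer object live between rows.

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// The buffer is a case class, i.e. not a "simple" flat buffer, so under the
// old DeclarativeAggregate-based implementation it was serialized and
// deserialized for every input row.
case class Average(sum: Double, count: Long)

object TypedAverage extends Aggregator[Double, Average, Double] {
  def zero: Average = Average(0.0, 0L)
  def reduce(b: Average, a: Double): Average = Average(b.sum + a, b.count + 1L)
  def merge(b1: Average, b2: Average): Average = Average(b1.sum + b2.sum, b1.count + b2.count)
  def finish(b: Average): Double = if (b.count == 0L) Double.NaN else b.sum / b.count
  def bufferEncoder: Encoder[Average] = Encoders.product[Average]
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

Used as `ds.select(TypedAverage.toColumn)` on a `Dataset[Double]`. An aggregator whose buffer is a single primitive (for example a running Long sum) still takes the `DeclarativeAggregate` path, which keeps whole-stage codegen.

http://git-wip-us.apache.org/repos/asf/spark/blob/8a7db8a6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala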
---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala index 612c198..dae88c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala @@ -90,7 +90,7 @@ case class CountMinSketchAgg( CountMinSketch.create(eps, confidence, seed) } - override def update(buffer: CountMinSketch, input: InternalRow): Unit = { + override def update(buffer: CountMinSketch, input: InternalRow): CountMinSketch = { val value = child.eval(input) // Ignore empty rows if (value != null) { @@ -101,10 +101,12 @@ case class CountMinSketchAgg( case _ => buffer.add(value) } } + buffer } - override def merge(buffer: CountMinSketch, input: CountMinSketch): Unit = { + override def merge(buffer: CountMinSketch, input: CountMinSketch): CountMinSketch = { buffer.mergeInPlace(input) + buffer } override def eval(buffer: CountMinSketch): Any = serialize(buffer) http://git-wip-us.apache.org/repos/asf/spark/blob/8a7db8a6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala index 2f68195..2f4d68d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala @@ -123,19 +123,25 @@ case class Percentile( new OpenHashMap[Number, Long]() } - override def update(buffer: OpenHashMap[Number, Long], input: InternalRow): Unit = { + override def update( + buffer: OpenHashMap[Number, Long], + input: InternalRow): OpenHashMap[Number, Long] = { val key = child.eval(input).asInstanceOf[Number] // Null values are ignored in counts map. 
if (key != null) { buffer.changeValue(key, 1L, _ + 1L) } + buffer } - override def merge(buffer: OpenHashMap[Number, Long], other: OpenHashMap[Number, Long]): Unit = { + override def merge( + buffer: OpenHashMap[Number, Long], + other: OpenHashMap[Number, Long]): OpenHashMap[Number, Long] = { other.foreach { case (key, count) => buffer.changeValue(key, count, _ + count) } + buffer } override def eval(buffer: OpenHashMap[Number, Long]): Any = { http://git-wip-us.apache.org/repos/asf/spark/blob/8a7db8a6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 7397b60..8e63fba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -155,7 +155,7 @@ case class AggregateExpression( * Code which accepts [[AggregateFunction]] instances should be prepared to handle both types of * aggregate functions. */ -sealed abstract class AggregateFunction extends Expression { +abstract class AggregateFunction extends Expression { /** An aggregate function is not foldable. */ final override def foldable: Boolean = false @@ -471,23 +471,29 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate { def createAggregationBuffer(): T /** - * In-place updates the aggregation buffer object with an input row. buffer = buffer + input. + * Updates the aggregation buffer object with an input row and returns a new buffer object. For + * performance, the function may do in-place update and return it instead of constructing new + * buffer object. + * * This is typically called when doing Partial or Complete mode aggregation. * * @param buffer The aggregation buffer object. * @param input an input row */ - def update(buffer: T, input: InternalRow): Unit + def update(buffer: T, input: InternalRow): T /** - * Merges an input aggregation object into aggregation buffer object. buffer = buffer + input. + * Merges an input aggregation object into aggregation buffer object and returns a new buffer + * object. For performance, the function may do in-place merge and return it instead of + * constructing new buffer object. + * * This is typically called when doing PartialMerge or Final mode aggregation. * * @param buffer the aggregation buffer object used to store the aggregation result. * @param input an input aggregation object. Input aggregation object can be produced by * de-serializing the partial aggregate's output from Mapper side. 
*/ - def merge(buffer: T, input: T): Unit + def merge(buffer: T, input: T): T /** * Generates the final aggregation result value for current key group with the aggregation buffer @@ -505,19 +511,18 @@ abstract class TypedImperativeAggregate[T] extends ImperativeAggregate { def deserialize(storageFormat: Array[Byte]): T final override def initialize(buffer: InternalRow): Unit = { - val bufferObject = createAggregationBuffer() - buffer.update(mutableAggBufferOffset, bufferObject) + buffer(mutableAggBufferOffset) = createAggregationBuffer() } final override def update(buffer: InternalRow, input: InternalRow): Unit = { - update(getBufferObject(buffer), input) + buffer(mutableAggBufferOffset) = update(getBufferObject(buffer), input) } final override def merge(buffer: InternalRow, inputBuffer: InternalRow): Unit = { val bufferObject = getBufferObject(buffer) // The inputBuffer stores serialized aggregation buffer object produced by partial aggregate val inputObject = deserialize(inputBuffer.getBinary(inputAggBufferOffset)) - merge(bufferObject, inputObject) + buffer(mutableAggBufferOffset) = merge(bufferObject, inputObject) } final override def eval(buffer: InternalRow): Any = { http://git-wip-us.apache.org/repos/asf/spark/blob/8a7db8a6/sql/core/src/main/scala/org/apache/spark/sql/Column.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index e99d786..a3f581f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -75,10 +75,10 @@ class TypedColumn[-T, U]( val unresolvedDeserializer = UnresolvedDeserializer(inputEncoder.deserializer, inputAttributes) val newExpr = expr transform { case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty => - ta.copy( - inputDeserializer = Some(unresolvedDeserializer), - inputClass = Some(inputEncoder.clsTag.runtimeClass), - inputSchema = Some(inputEncoder.schema)) + ta.withInputInfo( + deser = unresolvedDeserializer, + cls = inputEncoder.clsTag.runtimeClass, + schema = inputEncoder.schema) } new TypedColumn[T, U](newExpr, encoder) } http://git-wip-us.apache.org/repos/asf/spark/blob/8a7db8a6/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 9911c0b..4146bf3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -20,10 +20,12 @@ package org.apache.spark.sql.execution.aggregate import scala.language.existentials import org.apache.spark.sql.Encoder -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate +import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, DeclarativeAggregate, TypedImperativeAggregate} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.types._ @@ -33,9 +35,6 @@ object TypedAggregateExpression { aggregator: Aggregator[_, BUF, OUT]): TypedAggregateExpression = { val bufferEncoder = encoderFor[BUF] val bufferSerializer = bufferEncoder.namedExpressions - val bufferDeserializer = UnresolvedDeserializer( - bufferEncoder.deserializer, - bufferSerializer.map(_.toAttribute)) val outputEncoder = encoderFor[OUT] val outputType = if (outputEncoder.flat) { @@ -44,24 +43,78 @@ object TypedAggregateExpression { outputEncoder.schema } - new TypedAggregateExpression( - aggregator.asInstanceOf[Aggregator[Any, Any, Any]], - None, - None, - None, - bufferSerializer, - bufferDeserializer, - outputEncoder.serializer, - outputEncoder.deserializer.dataType, - outputType, - !outputEncoder.flat || outputEncoder.schema.head.nullable) + // Checks if the buffer object is simple, i.e. the buffer encoder is flat and the serializer + // expression is an alias of `BoundReference`, which means the buffer object doesn't need + // serialization. + val isSimpleBuffer = { + bufferSerializer.head match { + case Alias(_: BoundReference, _) if bufferEncoder.flat => true + case _ => false + } + } + + // If the buffer object is simple, use `SimpleTypedAggregateExpression`, which supports whole + // stage codegen. + if (isSimpleBuffer) { + val bufferDeserializer = UnresolvedDeserializer( + bufferEncoder.deserializer, + bufferSerializer.map(_.toAttribute)) + + SimpleTypedAggregateExpression( + aggregator.asInstanceOf[Aggregator[Any, Any, Any]], + None, + None, + None, + bufferSerializer, + bufferDeserializer, + outputEncoder.serializer, + outputEncoder.deserializer.dataType, + outputType, + !outputEncoder.flat || outputEncoder.schema.head.nullable) + } else { + ComplexTypedAggregateExpression( + aggregator.asInstanceOf[Aggregator[Any, Any, Any]], + None, + None, + None, + bufferSerializer, + bufferEncoder.resolveAndBind().deserializer, + outputEncoder.serializer, + outputType, + !outputEncoder.flat || outputEncoder.schema.head.nullable) + } } } /** * A helper class to hook [[Aggregator]] into the aggregation system. */ -case class TypedAggregateExpression( +trait TypedAggregateExpression extends AggregateFunction { + + def aggregator: Aggregator[Any, Any, Any] + + def inputDeserializer: Option[Expression] + def inputClass: Option[Class[_]] + def inputSchema: Option[StructType] + + def withInputInfo(deser: Expression, cls: Class[_], schema: StructType): TypedAggregateExpression + + override def toString: String = { + val input = inputDeserializer match { + case Some(UnresolvedDeserializer(deserializer, _)) => deserializer.dataType.simpleString + case Some(deserializer) => deserializer.dataType.simpleString + case _ => "unknown" + } + + s"$nodeName($input)" + } + + override def nodeName: String = aggregator.getClass.getSimpleName.stripSuffix("$") +} + +// TODO: merge these 2 implementations once we refactor the `AggregateFunction` interface. 
+ +case class SimpleTypedAggregateExpression( aggregator: Aggregator[Any, Any, Any], inputDeserializer: Option[Expression], inputClass: Option[Class[_]], @@ -71,7 +124,8 @@ case class TypedAggregateExpression( outputSerializer: Seq[Expression], outputExternalType: DataType, dataType: DataType, - nullable: Boolean) extends DeclarativeAggregate with NonSQLExpression { + nullable: Boolean) + extends DeclarativeAggregate with TypedAggregateExpression with NonSQLExpression { override def deterministic: Boolean = true @@ -143,15 +197,96 @@ case class TypedAggregateExpression( } } - override def toString: String = { - val input = inputDeserializer match { - case Some(UnresolvedDeserializer(deserializer, _)) => deserializer.dataType.simpleString - case Some(deserializer) => deserializer.dataType.simpleString - case _ => "unknown" + override def withInputInfo( + deser: Expression, + cls: Class[_], + schema: StructType): TypedAggregateExpression = { + copy(inputDeserializer = Some(deser), inputClass = Some(cls), inputSchema = Some(schema)) + } +} + +case class ComplexTypedAggregateExpression( + aggregator: Aggregator[Any, Any, Any], + inputDeserializer: Option[Expression], + inputClass: Option[Class[_]], + inputSchema: Option[StructType], + bufferSerializer: Seq[NamedExpression], + bufferDeserializer: Expression, + outputSerializer: Seq[Expression], + dataType: DataType, + nullable: Boolean, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends TypedImperativeAggregate[Any] with TypedAggregateExpression with NonSQLExpression { + + override def deterministic: Boolean = true + + override def children: Seq[Expression] = inputDeserializer.toSeq + + override lazy val resolved: Boolean = inputDeserializer.isDefined && childrenResolved + + override def references: AttributeSet = AttributeSet(inputDeserializer.toSeq) + + override def createAggregationBuffer(): Any = aggregator.zero + + private lazy val inputRowToObj = GenerateSafeProjection.generate(inputDeserializer.get :: Nil) + + override def update(buffer: Any, input: InternalRow): Any = { + val inputObj = inputRowToObj(input).get(0, ObjectType(classOf[Any])) + if (inputObj != null) { + aggregator.reduce(buffer, inputObj) + } else { + buffer + } + } + + override def merge(buffer: Any, input: Any): Any = { + aggregator.merge(buffer, input) + } + + private lazy val resultObjToRow = dataType match { + case _: StructType => + UnsafeProjection.create(CreateStruct(outputSerializer)) + case _ => + assert(outputSerializer.length == 1) + UnsafeProjection.create(outputSerializer.head) + } + + override def eval(buffer: Any): Any = { + val resultObj = aggregator.finish(buffer) + if (resultObj == null) { + null + } else { + resultObjToRow(InternalRow(resultObj)).get(0, dataType) } + } - s"$nodeName($input)" + private lazy val bufferObjToRow = UnsafeProjection.create(bufferSerializer) + + override def serialize(buffer: Any): Array[Byte] = { + bufferObjToRow(InternalRow(buffer)).getBytes } - override def nodeName: String = aggregator.getClass.getSimpleName.stripSuffix("$") + private lazy val bufferRow = new UnsafeRow(bufferSerializer.length) + private lazy val bufferRowToObject = GenerateSafeProjection.generate(bufferDeserializer :: Nil) + + override def deserialize(storageFormat: Array[Byte]): Any = { + bufferRow.pointTo(storageFormat, storageFormat.length) + bufferRowToObject(bufferRow).get(0, ObjectType(classOf[Any])) + } + + override def withNewMutableAggBufferOffset( + newMutableAggBufferOffset: Int): ComplexTypedAggregateExpression = + 
copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset( + newInputAggBufferOffset: Int): ComplexTypedAggregateExpression = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def withInputInfo( + deser: Expression, + cls: Class[_], + schema: StructType): TypedAggregateExpression = { + copy(inputDeserializer = Some(deser), inputClass = Some(cls), inputSchema = Some(schema)) + } } http://git-wip-us.apache.org/repos/asf/spark/blob/8a7db8a6/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala index c11605d..66d94d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala @@ -192,14 +192,14 @@ object DatasetBenchmark { benchmark2.run() /* - OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 3.10.0-327.18.2.el7.x86_64 - Intel Xeon E3-12xx v2 (Ivy Bridge) + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.12.1 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz aggregate: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ - RDD sum 1420 / 1523 70.4 14.2 1.0X - DataFrame sum 31 / 49 3214.3 0.3 45.6X - Dataset sum using Aggregator 3216 / 3257 31.1 32.2 0.4X - Dataset complex Aggregator 7948 / 8461 12.6 79.5 0.2X + RDD sum 1913 / 1942 52.3 19.1 1.0X + DataFrame sum 46 / 61 2157.7 0.5 41.3X + Dataset sum using Aggregator 4656 / 4758 21.5 46.6 0.4X + Dataset complex Aggregator 6636 / 7039 15.1 66.4 0.3X */ benchmark3.run() } http://git-wip-us.apache.org/repos/asf/spark/blob/8a7db8a6/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala index 70c3951..b76f168 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala @@ -240,7 +240,7 @@ object TypedImperativeAggregateSuite { new MaxValue(Int.MinValue) } - override def update(buffer: MaxValue, input: InternalRow): Unit = { + override def update(buffer: MaxValue, input: InternalRow): MaxValue = { child.eval(input) match { case inputValue: Int => if (inputValue > buffer.value) { @@ -249,13 +249,15 @@ object TypedImperativeAggregateSuite { } case null => // skip } + buffer } - override def merge(bufferMax: MaxValue, inputMax: MaxValue): Unit = { + override def merge(bufferMax: MaxValue, inputMax: MaxValue): MaxValue = { if (inputMax.value > bufferMax.value) { bufferMax.value = inputMax.value bufferMax.isValueSet = bufferMax.isValueSet || inputMax.isValueSet } + bufferMax } override def eval(bufferMax: MaxValue): Any = { http://git-wip-us.apache.org/repos/asf/spark/blob/8a7db8a6/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 26dc372..fcefd69 100644 
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -397,17 +397,19 @@ private[hive] case class HiveUDAFFunction( @transient private lazy val inputProjection = UnsafeProjection.create(children) - override def update(buffer: AggregationBuffer, input: InternalRow): Unit = { + override def update(buffer: AggregationBuffer, input: InternalRow): AggregationBuffer = { partial1ModeEvaluator.iterate( buffer, wrap(inputProjection(input), inputWrappers, cached, inputDataTypes)) + buffer } - override def merge(buffer: AggregationBuffer, input: AggregationBuffer): Unit = { + override def merge(buffer: AggregationBuffer, input: AggregationBuffer): AggregationBuffer = { // The 2nd argument of the Hive `GenericUDAFEvaluator.merge()` method is an input aggregation // buffer in the 3rd format mentioned in the ScalaDoc of this class. Originally, Hive converts // this `AggregationBuffer`s into this format before shuffling partial aggregation results, and // calls `GenericUDAFEvaluator.terminatePartial()` to do the conversion. partial2ModeEvaluator.merge(buffer, partial1ModeEvaluator.terminatePartial(input)) + buffer } override def eval(buffer: AggregationBuffer): Any = { http://git-wip-us.apache.org/repos/asf/spark/blob/8a7db8a6/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala index d27287b..aaf1db6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala @@ -46,14 +46,16 @@ case class TestingTypedCount( override def createAggregationBuffer(): State = TestingTypedCount.State(0L) - override def update(buffer: State, input: InternalRow): Unit = { + override def update(buffer: State, input: InternalRow): State = { if (child.eval(input) != null) { buffer.count += 1 } + buffer } - override def merge(buffer: State, input: State): Unit = { + override def merge(buffer: State, input: State): State = { buffer.count += input.count + buffer } override def eval(buffer: State): Any = buffer.count
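As a wrap-up, the effect of the new `update`/`merge` contract (the implementation returns the buffer object and the framework writes it back into the mutable aggregation buffer) can be summarized with a rough plain-Scala sketch. This is not Spark code; `ser`/`deser` stand in for the buffer encoder's row conversions and all names are illustrative.

// Old shape: the buffer lives inside a row, so every input row pays a
// deserialize + serialize round trip through the buffer encoder.
def perRowRoundTrip[B](rows: Iterator[Int], zero: B, reduce: (B, Int) => B,
    ser: B => Array[Byte], deser: Array[Byte] => B): Array[Byte] =
  rows.foldLeft(ser(zero))((bytes, row) => ser(reduce(deser(bytes), row)))

// New shape: the buffer object stays live across rows and is serialized once,
// when the partial aggregation result has to be shuffled.
def objectKeptLive[B](rows: Iterator[Int], zero: B, reduce: (B, Int) => B,
    ser: B => Array[Byte]): Array[Byte] =
  ser(rows.foldLeft(zero)(reduce))

For N input rows the first shape performs N round trips versus one in the second, which is where the roughly 2x speedup reported in the description comes from.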
