Github user cloud-fan commented on a diff in the pull request: https://github.com/apache/spark/pull/14753#discussion_r76022712 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala --- @@ -389,3 +389,145 @@ abstract class DeclarativeAggregate def right: AttributeReference = inputAggBufferAttributes(aggBufferAttributes.indexOf(a)) } } + +/** + * Aggregation function which allows **arbitrary** user-defined java object to be used as internal + * aggregation buffer object. + * + * {{{ + * aggregation buffer for normal aggregation function `avg` + * | + * v + * +--------------+---------------+-----------------------------------+ + * | sum1 (Long) | count1 (Long) | generic user-defined java objects | + * +--------------+---------------+-----------------------------------+ + * ^ + * | + * Aggregation buffer object for `TypedImperativeAggregate` aggregation function + * }}} + * + * Work flow (Partial mode aggregate at Mapper side, and Final mode aggregate at Reducer side): + * + * Stage 1: Partial aggregate at Mapper side: + * + * 1. The framework calls `createAggregationBuffer(): T` to create an empty internal aggregation + * buffer object. + * 2. Upon each input row, the framework calls + * `update(buffer: T, input: InternalRow): Unit` to update the aggregation buffer object T. + * 3. After processing all rows of current group (group by key), the framework will serialize + * aggregation buffer object T to storage format (Array[Byte]) and persist the Array[Byte] + * to disk if needed. + * 4. The framework moves on to next group, until all groups have been processed. + * + * Shuffling exchange data to Reducer tasks... + * + * Stage 2: Final mode aggregate at Reducer side: + * + * 1. The framework calls `createAggregationBuffer(): T` to create an empty internal aggregation + * buffer object (type T) for merging. + * 2. 
For each aggregation output of Stage 1, the framework de-serializes the storage + * format (Array[Byte]) and produces one input aggregation object (type T). + * 3. For each input aggregation object, the framework calls `merge(buffer: T, input: T): Unit` + * to merge the input aggregation object into aggregation buffer object. + * 4. After processing all input aggregation objects of current group (group by key), the framework + * calls method `eval(buffer: T)` to generate the final output for this group. + * 5. The framework moves on to next group, until all groups have been processed. + * + * NOTE: SQL with TypedImperativeAggregate functions is planned in sort based aggregation, + * instead of hash based aggregation, as TypedImperativeAggregate uses BinaryType as aggregation + * buffer's storage format, which is not supported by hash based aggregation. Hash based + * aggregation only supports aggregation buffers of mutable types (like LongType, IntType that have + * fixed length and can be mutated in place in UnsafeRow). + */ +abstract class TypedImperativeAggregate[T] extends ImperativeAggregate { + + /** + * Creates an empty aggregation buffer object. This is called before processing each key group + * (group by key). + * + * @return an aggregation buffer object + */ + def createAggregationBuffer(): T + + /** + * In-place updates the aggregation buffer object with an input row. buffer = buffer + input. + * This is typically called when doing Partial or Complete mode aggregation. + * + * @param buffer The aggregation buffer object. + * @param input an input row + */ + def update(buffer: T, input: InternalRow): Unit + + /** + * Merges an input aggregation object into aggregation buffer object. buffer = buffer + input. + * This is typically called when doing PartialMerge or Final mode aggregation. + * + * @param buffer the aggregation buffer object used to store the aggregation result. + * @param input an input aggregation object.
Input aggregation object can be produced by + * de-serializing the partial aggregate's output from Mapper side. + */ + def merge(buffer: T, input: T): Unit + + /** + * Generates the final aggregation result value for current key group with the aggregation buffer + * object. + * + * @param buffer aggregation buffer object. + * @return The aggregation result of current key group + */ + def eval(buffer: T): Any + + /** Serializes the aggregation buffer object T to Array[Byte] */ + def serialize(buffer: T): Array[Byte] + + /** De-serializes the serialized format Array[Byte], and produces aggregation buffer object T */ + def deserialize(storageFormat: Array[Byte]): T + + final override def initialize(buffer: MutableRow): Unit = { + val bufferObject = createAggregationBuffer() + buffer.update(mutableAggBufferOffset, bufferObject) + } + + final override def update(buffer: MutableRow, input: InternalRow): Unit = { + val bufferObject = getField[T](buffer, mutableAggBufferOffset) + update(bufferObject, input) + } + + final override def merge(buffer: MutableRow, inputBuffer: InternalRow): Unit = { + val bufferObject = getField[T](buffer, mutableAggBufferOffset) + // The inputBuffer stores serialized aggregation buffer object produced by partial aggregate + val inputObject = deserialize(getField[Array[Byte]](inputBuffer, inputAggBufferOffset)) + merge(bufferObject, inputObject) + } + + final override def eval(buffer: InternalRow): Any = { + val bufferObject = getField[T](buffer, mutableAggBufferOffset) + eval(bufferObject) + } + + private[this] val anyObjectType = ObjectType(classOf[AnyRef]) + private def getField[U](input: InternalRow, fieldIndex: Int): U = { + input.get(fieldIndex, anyObjectType).asInstanceOf[U] + } + + final override lazy val aggBufferAttributes: Seq[AttributeReference] = { + // Underlying storage type for the aggregation buffer object + Seq(AttributeReference("buf", BinaryType)()) + } + + final override lazy val inputAggBufferAttributes: 
Seq[AttributeReference] = + aggBufferAttributes.map(_.newInstance()) + + final override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) + + /** + * In-place replaces the aggregation buffer object stored at buffer's index + * `mutableAggBufferOffset`, with SparkSQL internally supported underlying storage format. + * + * The framework calls this method every time after updating/merging one group (group by key). --- End diff -- `... every time before we shuffle the buffer object`?
--- If your project is set up for it, you can reply to this email and have your reply appear on GitHub as well. If your project does not have this feature enabled and wishes so, or if the feature is enabled but not working, please contact infrastructure at infrastruct...@apache.org or file a JIRA ticket with INFRA. --- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org