Github user liancheng commented on a diff in the pull request: https://github.com/apache/spark/pull/15590#discussion_r84760919 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala --- @@ -0,0 +1,323 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.codegen.{BaseOrdering, GenerateOrdering} +import org.apache.spark.sql.execution.UnsafeKVExternalSorter +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.KVIterator +import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter + +class ObjectAggregationIterator( + outputAttributes: Seq[Attribute], + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection, + originalInputAttributes: Seq[Attribute], + inputRows: Iterator[InternalRow], + fallbackCountThreshold: Int) + extends AggregationIterator( + groupingExpressions, + originalInputAttributes, + aggregateExpressions, + aggregateAttributes, + initialInputBufferOffset, + resultExpressions, + newMutableProjection) with Logging { + + // Indicates whether we have fallen back to sort-based aggregation or not. 
+ private[this] var sortBased: Boolean = false + + private[this] var aggBufferIterator: Iterator[AggregationBufferEntry] = _ + + // Hacking the aggregation mode to call AggregateFunction.merge to merge two aggregation buffers + private val mergeAggregationBuffers: (InternalRow, InternalRow) => Unit = { + val newExpressions = aggregateExpressions.map { + case agg @ AggregateExpression(_, Partial, _, _) => + agg.copy(mode = PartialMerge) + case agg @ AggregateExpression(_, Complete, _, _) => + agg.copy(mode = Final) + case other => other + } + val newFunctions = initializeAggregateFunctions(newExpressions, 0) + val newInputAttributes = newFunctions.flatMap(_.inputAggBufferAttributes) + generateProcessRow(newExpressions, newFunctions, newInputAttributes) + } + + // A safe projection used to do deep clone of input rows to prevent false sharing. + private[this] val safeProjection: Projection = + FromUnsafeProjection(outputAttributes.map(_.dataType)) + + /** + * Start processing input rows. + */ + processInputs() + + override final def hasNext: Boolean = { + aggBufferIterator.hasNext + } + + override final def next(): UnsafeRow = { + val entry = aggBufferIterator.next() + generateOutput(entry.groupingKey, entry.aggregationBuffer) + } + + /** + * Generate an output row when there is no input and there is no grouping expression. + */ + def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = { + if (groupingExpressions.isEmpty) { + val defaultAggregationBuffer = createNewAggregationBuffer() + generateOutput(UnsafeRow.createFromByteArray(0, 0), defaultAggregationBuffer) + } else { + throw new IllegalStateException( + "This method should not be called when groupingExpressions is not empty.") + } + } + + // Creates a new aggregation buffer and initializes buffer values. 
This function should only be + // called under two cases: + // + // - when creating aggregation buffer for a new group in the hash map, and + // - when creating the re-used buffer for sort-based aggregation + private def createNewAggregationBuffer(): SpecificInternalRow = { + val bufferFieldTypes = aggregateFunctions.flatMap(_.aggBufferAttributes.map(_.dataType)) + val buffer = new SpecificInternalRow(bufferFieldTypes) + initAggregationBuffer(buffer) + buffer + } + + private def initAggregationBuffer(buffer: SpecificInternalRow): Unit = { + // Initializes declarative aggregates' buffer values + expressionAggInitialProjection.target(buffer)(EmptyRow) + // Initializes imperative aggregates' buffer values + aggregateFunctions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer)) + } + + private def getAggregationBufferByKey( + hashMap: ObjectAggregationMap, groupingKey: UnsafeRow): InternalRow = { + var aggBuffer = hashMap.getAggregationBuffer(groupingKey) + + if (aggBuffer == null) { + aggBuffer = createNewAggregationBuffer() + hashMap.putAggregationBuffer(groupingKey.copy(), aggBuffer) + } + + aggBuffer + } + + // This function is used to read and process input rows. When processing input rows, it first uses + // hash-based aggregation by putting groups and their buffers in `hashMap`. If `hashMap` grows too + // large, it sorts the contents, spills them to disk, and creates a new map. Finally, all sorted + // spills are merged together for sort-based aggregation. + private def processInputs(): Unit = { + // In-memory map to store aggregation buffer for hash-based aggregation. + val hashMap = new ObjectAggregationMap() + + // If the in-memory map is unable to store all aggregation buffers, fall back to sort-based + // aggregation backed by sorted physical storage. 
+ var sortBasedAggregationStore: SortBasedAggregationStore = null + + if (groupingExpressions.isEmpty) { + // If there are no grouping expressions, we can just reuse the same buffer over and over again. + val groupingKey = groupingProjection.apply(null) + val buffer: InternalRow = getAggregationBufferByKey(hashMap, groupingKey) + while (inputRows.hasNext) { + val newInput = safeProjection(inputRows.next()) + processRow(buffer, newInput) + } + } else { + while (inputRows.hasNext && !sortBased) { + val newInput = safeProjection(inputRows.next()) + val groupingKey = groupingProjection.apply(newInput) + val buffer: InternalRow = getAggregationBufferByKey(hashMap, groupingKey) + processRow(buffer, newInput) + + // If the hash map gets too large, make a sorted spill and clear the map. + if (hashMap.size >= fallbackCountThreshold) { + logInfo( + s"Aggregation hash map reaches threshold " + + s"capacity ($fallbackCountThreshold entries), spilling and falling back to sort" + + s" based aggregation. 
You may change the threshold by adjust option " + + SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key + ) + + // Falls back to sort-based aggregation + sortBased = true + + } + } + + if (sortBased) { + val sortIteratorFromHashMap = hashMap + .dumpToExternalSorter(groupingAttributes, aggregateFunctions) + .sortedIterator() + sortBasedAggregationStore = new SortBasedAggregationStore( + sortIteratorFromHashMap, + StructType.fromAttributes(originalInputAttributes), + StructType.fromAttributes(groupingAttributes), + processRow, + mergeAggregationBuffers, + createNewAggregationBuffer()) + + while (inputRows.hasNext) { + // NOTE: The input row is always UnsafeRow + val unsafeInputRow = inputRows.next().asInstanceOf[UnsafeRow] + val groupingKey = groupingProjection.apply(unsafeInputRow) + sortBasedAggregationStore.addInput(groupingKey, unsafeInputRow) + } + } + } + + if (sortBased) { + aggBufferIterator = sortBasedAggregationStore.destructiveIterator() + } else { + aggBufferIterator = hashMap.iterator + } + } +} + +/** + * Aggregation store used to do sort-based aggregation. + * + * @param initialAggBufferIterator iterator that points to sorted input aggregation buffers. The + * aggregation buffers in this iterator will be merged to + * SortBasedAggregationStore. + * @param inputSchema The schema of input row + * @param groupingSchema The schema of grouping key + * @param processRow Function to update the aggregation buffer with input rows. + * @param mergeAggregationBuffers Function to merge the aggregation buffer with input aggregation + * buffer. + * @param makeEmptyAggregationBuffer Creates an empty aggregation buffer + */ +class SortBasedAggregationStore( --- End diff -- Yea, that's a bad name. Renamed to `SortBasedAggregator`.
--- If your project is set up for it, you can reply to this email and have your reply appear on GitHub as well. If your project does not have this feature enabled and wishes so, or if the feature is enabled but not working, please contact infrastructure at infrastruct...@apache.org or file a JIRA ticket with INFRA. --- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org