GitHub user liancheng commented on a diff in the pull request:

    https://github.com/apache/spark/pull/15590#discussion_r84760919
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala ---
    @@ -0,0 +1,323 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.sql.execution.aggregate
    +
    +import org.apache.spark.{SparkEnv, TaskContext}
    +import org.apache.spark.internal.Logging
    +import org.apache.spark.sql.catalyst.InternalRow
    +import org.apache.spark.sql.catalyst.expressions._
    +import org.apache.spark.sql.catalyst.expressions.aggregate._
    +import org.apache.spark.sql.catalyst.expressions.codegen.{BaseOrdering, 
GenerateOrdering}
    +import org.apache.spark.sql.execution.UnsafeKVExternalSorter
    +import org.apache.spark.sql.internal.SQLConf
    +import org.apache.spark.sql.types.StructType
    +import org.apache.spark.unsafe.KVIterator
    +import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
    +
    +class ObjectAggregationIterator(
    +    outputAttributes: Seq[Attribute],
    +    groupingExpressions: Seq[NamedExpression],
    +    aggregateExpressions: Seq[AggregateExpression],
    +    aggregateAttributes: Seq[Attribute],
    +    initialInputBufferOffset: Int,
    +    resultExpressions: Seq[NamedExpression],
    +    newMutableProjection: (Seq[Expression], Seq[Attribute]) => 
MutableProjection,
    +    originalInputAttributes: Seq[Attribute],
    +    inputRows: Iterator[InternalRow],
    +    fallbackCountThreshold: Int)
    +  extends AggregationIterator(
    +    groupingExpressions,
    +    originalInputAttributes,
    +    aggregateExpressions,
    +    aggregateAttributes,
    +    initialInputBufferOffset,
    +    resultExpressions,
    +    newMutableProjection) with Logging {
    +
    +  // Indicates whether we have fallen back to sort-based aggregation or 
not.
    +  private[this] var sortBased: Boolean = false
    +
    +  private[this] var aggBufferIterator: Iterator[AggregationBufferEntry] = _
    +
    +  // Hacking the aggregation mode to call AggregateFunction.merge to merge 
two aggregation buffers
    +  private val mergeAggregationBuffers: (InternalRow, InternalRow) => Unit 
= {
    +    val newExpressions = aggregateExpressions.map {
    +      case agg @ AggregateExpression(_, Partial, _, _) =>
    +        agg.copy(mode = PartialMerge)
    +      case agg @ AggregateExpression(_, Complete, _, _) =>
    +        agg.copy(mode = Final)
    +      case other => other
    +    }
    +    val newFunctions = initializeAggregateFunctions(newExpressions, 0)
    +    val newInputAttributes = 
newFunctions.flatMap(_.inputAggBufferAttributes)
    +    generateProcessRow(newExpressions, newFunctions, newInputAttributes)
    +  }
    +
    +  // A safe projection used to do deep clone of input rows to prevent 
false sharing.
    +  private[this] val safeProjection: Projection =
    +    FromUnsafeProjection(outputAttributes.map(_.dataType))
    +
    +  /**
    +   * Start processing input rows.
    +   */
    +  processInputs()
    +
    +  override final def hasNext: Boolean = {
    +    aggBufferIterator.hasNext
    +  }
    +
    +  override final def next(): UnsafeRow = {
    +    val entry = aggBufferIterator.next()
    +    generateOutput(entry.groupingKey, entry.aggregationBuffer)
    +  }
    +
    +  /**
    +   * Generate an output row when there is no input and there is no 
grouping expression.
    +   */
    +  def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
    +    if (groupingExpressions.isEmpty) {
    +      val defaultAggregationBuffer = createNewAggregationBuffer()
    +      generateOutput(UnsafeRow.createFromByteArray(0, 0), 
defaultAggregationBuffer)
    +    } else {
    +      throw new IllegalStateException(
    +        "This method should not be called when groupingExpressions is not 
empty.")
    +    }
    +  }
    +
    +  // Creates a new aggregation buffer and initializes buffer values. This 
function should only be
    +  // called under two cases:
    +  //
    +  //  - when creating aggregation buffer for a new group in the hash map, 
and
    +  //  - when creating the re-used buffer for sort-based aggregation
    +  private def createNewAggregationBuffer(): SpecificInternalRow = {
    +    val bufferFieldTypes = 
aggregateFunctions.flatMap(_.aggBufferAttributes.map(_.dataType))
    +    val buffer = new SpecificInternalRow(bufferFieldTypes)
    +    initAggregationBuffer(buffer)
    +    buffer
    +  }
    +
    +  private def initAggregationBuffer(buffer: SpecificInternalRow): Unit = {
    +    // Initializes declarative aggregates' buffer values
    +    expressionAggInitialProjection.target(buffer)(EmptyRow)
    +    // Initializes imperative aggregates' buffer values
    +    aggregateFunctions.collect { case f: ImperativeAggregate => f 
}.foreach(_.initialize(buffer))
    +  }
    +
    +  private def getAggregationBufferByKey(
    +    hashMap: ObjectAggregationMap, groupingKey: UnsafeRow): InternalRow = {
    +    var aggBuffer = hashMap.getAggregationBuffer(groupingKey)
    +
    +    if (aggBuffer == null) {
    +      aggBuffer = createNewAggregationBuffer()
    +      hashMap.putAggregationBuffer(groupingKey.copy(), aggBuffer)
    +    }
    +
    +    aggBuffer
    +  }
    +
    +  // This function is used to read and process input rows. When processing 
input rows, it first uses
    +  // hash-based aggregation by putting groups and their buffers in 
`hashMap`. If `hashMap` grows too
    +  // large, it sorts the contents, spills them to disk, and creates a new 
map. At last, all sorted
    +  // spills are merged together for sort-based aggregation.
    +  private def processInputs(): Unit = {
    +    // In-memory map to store aggregation buffer for hash-based 
aggregation.
    +    val hashMap = new ObjectAggregationMap()
    +
    +    // If in-memory map is unable to stores all aggregation buffer, 
fallback to sort-based
    +    // aggregation backed by sorted physical storage.
    +    var sortBasedAggregationStore: SortBasedAggregationStore = null
    +
    +    if (groupingExpressions.isEmpty) {
    +      // If there is no grouping expressions, we can just reuse the same 
buffer over and over again.
    +      val groupingKey = groupingProjection.apply(null)
    +      val buffer: InternalRow = getAggregationBufferByKey(hashMap, 
groupingKey)
    +      while (inputRows.hasNext) {
    +        val newInput = safeProjection(inputRows.next())
    +        processRow(buffer, newInput)
    +      }
    +    } else {
    +      while (inputRows.hasNext && !sortBased) {
    +        val newInput = safeProjection(inputRows.next())
    +        val groupingKey = groupingProjection.apply(newInput)
    +        val buffer: InternalRow = getAggregationBufferByKey(hashMap, 
groupingKey)
    +        processRow(buffer, newInput)
    +
    +        // The the hash map gets too large, makes a sorted spill and clear 
the map.
    +        if (hashMap.size >= fallbackCountThreshold) {
    +          logInfo(
    +            s"Aggregation hash map reaches threshold " +
    +              s"capacity ($fallbackCountThreshold entries), spilling and 
falling back to sort" +
    +              s" based aggregation. You may change the threshold by adjust 
option " +
    +              SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key
    +          )
    +
    +          // Falls back to sort-based aggregation
    +          sortBased = true
    +
    +        }
    +      }
    +
    +      if (sortBased) {
    +        val sortIteratorFromHashMap = hashMap
    +          .dumpToExternalSorter(groupingAttributes, aggregateFunctions)
    +          .sortedIterator()
    +        sortBasedAggregationStore = new SortBasedAggregationStore(
    +          sortIteratorFromHashMap,
    +          StructType.fromAttributes(originalInputAttributes),
    +          StructType.fromAttributes(groupingAttributes),
    +          processRow,
    +          mergeAggregationBuffers,
    +          createNewAggregationBuffer())
    +
    +        while (inputRows.hasNext) {
    +          // NOTE: The input row is always UnsafeRow
    +          val unsafeInputRow = inputRows.next().asInstanceOf[UnsafeRow]
    +          val groupingKey = groupingProjection.apply(unsafeInputRow)
    +          sortBasedAggregationStore.addInput(groupingKey, unsafeInputRow)
    +        }
    +      }
    +    }
    +
    +    if (sortBased) {
    +      aggBufferIterator = sortBasedAggregationStore.destructiveIterator()
    +    } else {
    +      aggBufferIterator = hashMap.iterator
    +    }
    +  }
    +}
    +
    +/**
    + * Aggregation store used to do sort-based aggregation.
    + *
    + * @param initialAggBufferIterator iterator that points to sorted input 
aggregation buffers. The
    + *                                 aggregation buffers in this iterator 
will be merged to
    + *                                 SortBasedAggregationStore.
    + * @param inputSchema  The schema of input row
    + * @param groupingSchema The schema of grouping key
    + * @param processRow  Function to update the aggregation buffer with input 
rows.
    + * @param mergeAggregationBuffers Function to merge the aggregation buffer 
with input aggregation
    + *                               buffer.
    + * @param makeEmptyAggregationBuffer Creates an empty aggregation buffer
    + */
    +class SortBasedAggregationStore(
    --- End diff --
    
    Yeah, that's a bad name. Renamed it to `SortBasedAggregator`.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled but would like it, or if the feature is enabled but not working, please
contact infrastructure at infrastruct...@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to