GitHub user rxin commented on a diff in the pull request:

    https://github.com/apache/spark/pull/7954#discussion_r36387320
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala ---
    @@ -0,0 +1,663 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.sql.execution.aggregate
    +
    +import org.apache.spark.unsafe.KVIterator
    +import org.apache.spark.{Logging, SparkEnv, TaskContext}
    +import org.apache.spark.sql.catalyst.expressions._
    +import org.apache.spark.sql.catalyst.expressions.aggregate._
    +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
    +import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap}
    +import org.apache.spark.sql.types.StructType
    +
    +/**
    + * An iterator used to evaluate aggregate functions. It operates on [[UnsafeRow]]s.
    + *
    + * This iterator first uses hash-based aggregation to process input rows. It uses
    + * a hash map to store groups and their corresponding aggregation buffers. If
    + * this map cannot allocate memory from [[org.apache.spark.shuffle.ShuffleMemoryManager]],
    + * it switches to sort-based aggregation. The switch has the following steps:
    + *  - Step 1: Sort all entries of the hash map based on values of grouping expressions and
    + *            spill them to disk.
    + *  - Step 2: Create an external sorter based on the spilled sorted map entries.
    + *  - Step 3: Redirect all input rows to the external sorter.
    + *  - Step 4: Get a sorted [[KVIterator]] from the external sorter.
    + *  - Step 5: Initialize sort-based aggregation.
    + * From then on, this iterator performs sort-based aggregation.
    + *
    + * The code of this class is organized as follows:
    + *  - Part 1: Initializing aggregate functions.
    + *  - Part 2: Methods and fields used by setting aggregation buffer values,
    + *            processing input rows from inputIter, and generating output
    + *            rows.
    + *  - Part 3: Methods and fields used by hash-based aggregation.
    + *  - Part 4: The function used to switch this iterator from hash-based
    + *            aggregation to sort-based aggregation.
    + *  - Part 5: Methods and fields used by sort-based aggregation.
    + *  - Part 6: Loads input and processes input rows.
    + *  - Part 7: Public methods of this iterator.
    + *  - Part 8: A utility function used to generate a result when there is no
    + *            input and there is no grouping expression.
    + *
    + * @param groupingExpressions
    + *   expressions for grouping keys
    + * @param nonCompleteAggregateExpressions
    + *   [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Partial]],
    + *   [[PartialMerge]], or [[Final]].
    + * @param completeAggregateExpressions
    + *   [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Complete]].
    + * @param initialInputBufferOffset
    + *   If this iterator is used to handle functions with mode [[PartialMerge]] or [[Final]],
    + *   the input rows have the format of `grouping keys + aggregation buffer`.
    + *   This offset indicates the starting position of the aggregation buffer in an input row.
    + * @param resultExpressions
    + *   expressions for generating output rows.
    + * @param newMutableProjection
    + *   the function used to create mutable projections.
    + * @param originalInputAttributes
    + *   attributes representing input rows from `inputIter`.
    + * @param inputIter
    + *   the iterator containing input [[UnsafeRow]]s.
    + */
    +class TungstenAggregationIterator(
    +    groupingExpressions: Seq[NamedExpression],
    +    nonCompleteAggregateExpressions: Seq[AggregateExpression2],
    +    completeAggregateExpressions: Seq[AggregateExpression2],
    +    initialInputBufferOffset: Int,
    +    resultExpressions: Seq[NamedExpression],
    +    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
    +    originalInputAttributes: Seq[Attribute],
    +    inputIter: Iterator[UnsafeRow],
    +    testFallbackStartsAt: Option[Int])
    +  extends Iterator[UnsafeRow] with Logging {
    +
    +  ///////////////////////////////////////////////////////////////////////////
    +  // Part 1: Initializing aggregate functions.
    +  ///////////////////////////////////////////////////////////////////////////
    +
    +  // A Seq containing all AggregateExpressions.
    +  // It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final
    +  // are at the beginning of allAggregateExpressions.
    +  private[this] val allAggregateExpressions: Seq[AggregateExpression2] =
    +    nonCompleteAggregateExpressions ++ completeAggregateExpressions
    +
    +  // Check to make sure we do not have more than two modes in our AggregateExpressions.
    +  // If we have, users are hitting a bug and we throw an IllegalStateException.
    +  if (allAggregateExpressions.map(_.mode).distinct.length > 2) {
    +    throw new IllegalStateException(
    +      s"$allAggregateExpressions should have no more than 2 kinds of modes.")
    +  }
    +
    +  //
    +  // The modes of AggregateExpressions. Right now, we can handle the following modes:
    +  //  - Partial-only:
    +  //      All AggregateExpressions have the mode of Partial.
    +  //      For this case, aggregationMode is (Some(Partial), None).
    +  //  - PartialMerge-only:
    +  //      All AggregateExpressions have the mode of PartialMerge.
    +  //      For this case, aggregationMode is (Some(PartialMerge), None).
    +  //  - Final-only:
    +  //      All AggregateExpressions have the mode of Final.
    +  //      For this case, aggregationMode is (Some(Final), None).
    +  //  - Final-Complete:
    +  //      Some AggregateExpressions have the mode of Final and
    +  //      others have the mode of Complete. For this case,
    +  //      aggregationMode is (Some(Final), Some(Complete)).
    +  //  - Complete-only:
    +  //      nonCompleteAggregateExpressions is empty and we have AggregateExpressions
    +  //      with mode Complete in completeAggregateExpressions. For this case,
    +  //      aggregationMode is (None, Some(Complete)).
    +  //  - Grouping-only:
    +  //      There is no AggregateExpression. For this case, aggregationMode is (None, None).
    +  //
    +  private[this] var aggregationMode: (Option[AggregateMode], Option[AggregateMode]) = {
    +    nonCompleteAggregateExpressions.map(_.mode).distinct.headOption ->
    +      completeAggregateExpressions.map(_.mode).distinct.headOption
    +  }
    +
    +  // All aggregate functions. TungstenAggregationIterator only handles AlgebraicAggregates.
    +  // If any function is not an AlgebraicAggregate, we throw an
    +  // IllegalStateException.
    +  private[this] val allAggregateFunctions: Array[AlgebraicAggregate] = {
    +    if (!allAggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate])) {
    +      throw new IllegalStateException(
    +        "Only AlgebraicAggregates should be passed in TungstenAggregationIterator.")
    +    }
    +
    +    allAggregateExpressions
    +      .map(_.aggregateFunction.asInstanceOf[AlgebraicAggregate])
    +      .toArray
    +  }
    +
    +  ///////////////////////////////////////////////////////////////////////////
    +  // Part 2: Methods and fields used by setting aggregation buffer values,
    +  //         processing input rows from inputIter, and generating output
    +  //         rows.
    +  ///////////////////////////////////////////////////////////////////////////
    +
    +  // The projection used to initialize buffer values.
    +  private[this] val algebraicInitialProjection: MutableProjection = {
    +    val initExpressions = allAggregateFunctions.flatMap(_.initialValues)
    +    newMutableProjection(initExpressions, Nil)()
    +  }
    +
    +  // Creates a new aggregation buffer and initializes buffer values.
    +  // This function should be called at most three times (when we create the hash map,
    +  // when we switch to sort-based aggregation, and when we create the re-used buffer for
    +  // sort-based aggregation).
    +  private def createNewBuffer(): UnsafeRow = {
    +    val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes)
    +    val bufferRowSize: Int = bufferSchema.length
    +
    +    val genericMutableBuffer = new GenericMutableRow(bufferRowSize)
    +    val unsafeProjection =
    +      UnsafeProjection.create(bufferSchema.map(_.dataType))
    +    val buffer = unsafeProjection.apply(genericMutableBuffer)
    +    algebraicInitialProjection.target(buffer)(EmptyRow)
    +    buffer
    +  }
    +
    +  // Creates a function used to process a row based on the given inputAttributes.
    +  private def generateProcessRow(
    +      inputAttributes: Seq[Attribute]): (UnsafeRow, UnsafeRow) => Unit = {
    +
    +    val aggregationBufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes)
    +    val aggregationBufferSchema = StructType.fromAttributes(aggregationBufferAttributes)
    +    val inputSchema = StructType.fromAttributes(inputAttributes)
    +    val unsafeRowJoiner =
    +      GenerateUnsafeRowJoiner.create(aggregationBufferSchema, inputSchema)
    +
    +    aggregationMode match {
    +      // Partial-only
    +      case (Some(Partial), None) =>
    +        val updateExpressions = allAggregateFunctions.flatMap(_.updateExpressions)
    +        val algebraicUpdateProjection =
    +          newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
    +
    +        (currentBuffer: UnsafeRow, row: UnsafeRow) => {
    +          algebraicUpdateProjection.target(currentBuffer)
    +          algebraicUpdateProjection(unsafeRowJoiner.join(currentBuffer, row))
    --- End diff --
    
    This will create too much memory copying, which might explain the slowdown.
    I was thinking about doing the unsafe row joining only if we are directly
    outputting the rows into an exchange (i.e. partial aggregation).
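
To make the copying concern concrete, here is a minimal, self-contained sketch. It is not the code from this PR and does not use Spark classes: `Row`, `JoinedView`, `joinByCopy`, and `copiedValues` are hypothetical stand-ins, and a "row" is just an array of longs. It compares materializing `buffer ++ input` for every input row against reading both rows through a view that copies nothing.

```scala
// Hypothetical stand-ins for illustration only; none of these are Spark APIs.
object JoinCopySketch {
  type Row = Array[Long]

  // Counts how many values were copied while joining rows.
  var copiedValues: Long = 0L

  // Strategy A: build a fresh joined row (buffer followed by input) per input row.
  def joinByCopy(buffer: Row, input: Row): Row = {
    val joined = new Array[Long](buffer.length + input.length)
    System.arraycopy(buffer, 0, joined, 0, buffer.length)
    System.arraycopy(input, 0, joined, buffer.length, input.length)
    copiedValues += joined.length
    joined
  }

  // Strategy B: a read-only view over two rows; indexing is redirected, nothing is copied.
  final class JoinedView(var left: Row, var right: Row) {
    def apply(i: Int): Long =
      if (i < left.length) left(i) else right(i - left.length)
  }

  // Running sum where the joined layout is [sum, value]; copies on every input row.
  def sumWithCopy(inputs: Seq[Row]): Long = {
    val buffer: Row = Array(0L)
    inputs.foreach { in =>
      val joined = joinByCopy(buffer, in)
      buffer(0) = joined(0) + joined(1)
    }
    buffer(0)
  }

  // Same running sum, but the update reads through the view instead of a joined copy.
  def sumWithView(inputs: Seq[Row]): Long = {
    val buffer: Row = Array(0L)
    val view = new JoinedView(buffer, Array.empty[Long])
    inputs.foreach { in =>
      view.right = in
      buffer(0) = view(0) + view(1)
    }
    buffer(0)
  }
}
```

Both functions produce the same sum, but in this sketch the copying variant moves `buffer.length + input.length` values for every input row, while the view variant moves none. Whether the real operator behaves the same way would need to be confirmed by profiling.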

