[SPARK-9630] [SQL] Clean up new aggregate operators (SPARK-9240 follow up)

This is the follow-up of https://github.com/apache/spark/pull/7813. It renames 
`UnsafeHybridAggregationIterator` to `TungstenAggregationIterator` and makes it 
work only with `UnsafeRow`s. It also adds a `TungstenAggregate` operator that 
uses `TungstenAggregationIterator`, and makes `SortBasedAggregate` (renamed 
from `Aggregate`) work only with safe rows.
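
As a rough sketch of the planning-time choice between the two operators (illustrative only: `chooseAggregateOperator` is a hypothetical helper, though the two support checks are the ones the old `Aggregate` operator used):

    import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
    import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction2, AlgebraicAggregate}
    import org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap
    import org.apache.spark.sql.types.StructType

    // Sketch: use TungstenAggregate only when every aggregate function is algebraic
    // and both the aggregation-buffer and group-key schemas are UnsafeRow-compatible;
    // otherwise fall back to SortBasedAggregate.
    def chooseAggregateOperator(
        aggregateFunctions: Seq[AggregateFunction2],
        aggregationBufferSchema: StructType,
        groupKeySchema: StructType): String = {
      val allAlgebraic = aggregateFunctions.forall(_.isInstanceOf[AlgebraicAggregate])
      val schemaSupportsUnsafe =
        UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) &&
          UnsafeProjection.canSupport(groupKeySchema)
      if (allAlgebraic && schemaSupportsUnsafe) "TungstenAggregate" else "SortBasedAggregate"
    }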

Author: Yin Huai <yh...@databricks.com>

Closes #7954 from yhuai/agg-followUp and squashes the following commits:

4d2f4fc [Yin Huai] Add comments and free map.
0d7ddb9 [Yin Huai] Add TungstenAggregationQueryWithControlledFallbackSuite to test the fallback process.
91d69c2 [Yin Huai] Rename UnsafeHybridAggregationIterator to TungstenAggregationIterator and make it only work with UnsafeRow.

(cherry picked from commit 3504bf3aa9f7b75c0985f04ce2944833d8c5b5bd)
Signed-off-by: Reynold Xin <r...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/272e8834
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/272e8834
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/272e8834

Branch: refs/heads/branch-1.5
Commit: 272e88342540328a24702f07a730b156657bd3be
Parents: 9806872
Author: Yin Huai <yh...@databricks.com>
Authored: Thu Aug 6 15:04:44 2015 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Thu Aug 6 15:04:53 2015 -0700

----------------------------------------------------------------------
 .../expressions/aggregate/functions.scala       |  14 +-
 .../spark/sql/execution/SparkStrategies.scala   |   3 +-
 .../sql/execution/UnsafeRowSerializer.scala     |  20 +-
 .../sql/execution/aggregate/Aggregate.scala     | 182 -----
 .../aggregate/SortBasedAggregate.scala          | 103 +++
 .../SortBasedAggregationIterator.scala          |  26 -
 .../execution/aggregate/TungstenAggregate.scala | 102 +++
 .../aggregate/TungstenAggregationIterator.scala | 667 +++++++++++++++++++
 .../UnsafeHybridAggregationIterator.scala       | 372 -----------
 .../spark/sql/execution/aggregate/utils.scala   | 260 ++++++--
 .../org/apache/spark/sql/SQLQuerySuite.scala    |   2 +-
 .../hive/execution/AggregationQuerySuite.scala  | 104 ++-
 12 files changed, 1192 insertions(+), 663 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/272e8834/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
index 88fb516..a73024d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
@@ -31,8 +31,11 @@ case class Average(child: Expression) extends AlgebraicAggregate {
   override def dataType: DataType = resultType
 
   // Expected input data type.
-  // TODO: Once we remove the old code path, we can use our analyzer to cast NullType
-  // to the default data type of the NumericType.
+  // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) with the
+  // new version at planning time (after the analysis phase). For now, NullType is accepted here
+  // so that queries like `select avg(null)` resolve.
+  // Once we remove the old aggregate functions, we can use the analyzer to cast NullType to the
+  // default data type of the NumericType; NullType will then no longer be needed here.
   override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))
 
   private val resultType = child.dataType match {
@@ -256,12 +259,19 @@ case class Sum(child: Expression) extends AlgebraicAggregate {
   override def dataType: DataType = resultType
 
   // Expected input data type.
+  // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) with the
+  // new version at planning time (after the analysis phase). For now, NullType is accepted here
+  // so that queries like `select sum(null)` resolve.
+  // Once we remove the old aggregate functions, we can use the analyzer to cast NullType to the
+  // default data type of the NumericType; NullType will then no longer be needed here.
   override def inputTypes: Seq[AbstractDataType] =
     Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType))
 
   private val resultType = child.dataType match {
     case DecimalType.Fixed(precision, scale) =>
       DecimalType.bounded(precision + 10, scale)
+    // TODO: Remove this line once we remove the NullType from inputTypes.
+    case NullType => IntegerType
     case _ => child.dataType
   }
 

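A minimal illustration of why NullType shows up in inputTypes above (assumes a SQLContext named `sqlContext`):

    // Without NullType in inputTypes, these untyped nulls would not resolve.
    // Both aggregates return NULL; for sum, the `case NullType => IntegerType`
    // branch above gives the otherwise untyped result a concrete type.
    sqlContext.sql("select avg(null), sum(null)").show()
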
http://git-wip-us.apache.org/repos/asf/spark/blob/272e8834/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index a730ffb..c5aaebe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -191,8 +191,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
             // aggregate function to the corresponding attribute of the function.
             val aggregateFunctionMap = aggregateExpressions.map { agg =>
               val aggregateFunction = agg.aggregateFunction
+              val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
               (aggregateFunction, agg.isDistinct) ->
-                Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
+                (aggregateFunction -> attribute)
             }.toMap
 
             val (functionsWithDistinct, functionsWithoutDistinct) =

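After this change, each `(aggregateFunction, isDistinct)` key of `aggregateFunctionMap` maps to a pair of the function and its generated attribute instead of a bare attribute, roughly:

    // Sketch of the map's new shape:
    // Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)]
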
http://git-wip-us.apache.org/repos/asf/spark/blob/272e8834/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
index 16498da..39f8f99 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution
 
-import java.io.{DataInputStream, DataOutputStream, OutputStream, InputStream}
+import java.io._
 import java.nio.ByteBuffer
 
 import scala.reflect.ClassTag
@@ -58,11 +58,26 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
    */
   override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream {
     private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096)
+    // When `out` is backed by ChainedBufferOutputStream, we will get an
+    // UnsupportedOperationException when we call dOut.writeInt because it internally calls
+    // ChainedBufferOutputStream's write(b: Int), which is not supported.
+    // To work around this issue, we create an array for storing the int value.
+    // To reproduce the problem, use dOut.writeInt(row.getSizeInBytes) and
+    // run SparkSqlSerializer2SortMergeShuffleSuite.
+    private[this] var intBuffer: Array[Byte] = new Array[Byte](4)
     private[this] val dOut: DataOutputStream = new DataOutputStream(out)
 
     override def writeValue[T: ClassTag](value: T): SerializationStream = {
       val row = value.asInstanceOf[UnsafeRow]
-      dOut.writeInt(row.getSizeInBytes)
+      val size = row.getSizeInBytes
+      // This part is based on DataOutputStream's writeInt.
+      // It is for dOut.writeInt(row.getSizeInBytes).
+      intBuffer(0) = ((size >>> 24) & 0xFF).toByte
+      intBuffer(1) = ((size >>> 16) & 0xFF).toByte
+      intBuffer(2) = ((size >>> 8) & 0xFF).toByte
+      intBuffer(3) = ((size >>> 0) & 0xFF).toByte
+      dOut.write(intBuffer, 0, 4)
+
       row.writeToStream(out, writeBuffer)
       this
     }
@@ -90,6 +105,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
 
     override def close(): Unit = {
       writeBuffer = null
+      intBuffer = null
       dOut.writeInt(EOF)
       dOut.close()
     }

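A stand-alone sketch of the workaround above: encoding the int big-endian by hand produces exactly the bytes DataOutputStream.writeInt would, while only ever calling the underlying stream's array-based write method:

    import java.io.{ByteArrayOutputStream, DataOutputStream}

    val size = 123456789
    // Manual big-endian encoding, mirroring the intBuffer logic in writeValue.
    val intBuffer = new Array[Byte](4)
    intBuffer(0) = ((size >>> 24) & 0xFF).toByte
    intBuffer(1) = ((size >>> 16) & 0xFF).toByte
    intBuffer(2) = ((size >>> 8) & 0xFF).toByte
    intBuffer(3) = (size & 0xFF).toByte

    // DataOutputStream.writeInt writes the same four bytes.
    val baos = new ByteArrayOutputStream()
    new DataOutputStream(baos).writeInt(size)
    assert(java.util.Arrays.equals(baos.toByteArray, intBuffer))
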
http://git-wip-us.apache.org/repos/asf/spark/blob/272e8834/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala
deleted file mode 100644
index cf568dc..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/Aggregate.scala
+++ /dev/null
@@ -1,182 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.aggregate
-
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.errors._
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, 
ClusteredDistribution, AllTuples, Distribution}
-import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, 
SparkPlan, UnaryNode}
-import org.apache.spark.sql.types.StructType
-
-/**
- * An Aggregate Operator used to evaluate [[AggregateFunction2]]. Based on the 
data types
- * of the grouping expressions and aggregate functions, it determines if it 
uses
- * sort-based aggregation and hybrid (hash-based with sort-based as the 
fallback) to
- * process input rows.
- */
-case class Aggregate(
-    requiredChildDistributionExpressions: Option[Seq[Expression]],
-    groupingExpressions: Seq[NamedExpression],
-    nonCompleteAggregateExpressions: Seq[AggregateExpression2],
-    nonCompleteAggregateAttributes: Seq[Attribute],
-    completeAggregateExpressions: Seq[AggregateExpression2],
-    completeAggregateAttributes: Seq[Attribute],
-    initialInputBufferOffset: Int,
-    resultExpressions: Seq[NamedExpression],
-    child: SparkPlan)
-  extends UnaryNode {
-
-  private[this] val allAggregateExpressions =
-    nonCompleteAggregateExpressions ++ completeAggregateExpressions
-
-  private[this] val hasNonAlgebricAggregateFunctions =
-    
!allAggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate])
-
-  // Use the hybrid iterator if (1) unsafe is enabled, (2) the schemata of
-  // grouping key and aggregation buffer is supported; and (3) all
-  // aggregate functions are algebraic.
-  private[this] val supportsHybridIterator: Boolean = {
-    val aggregationBufferSchema: StructType =
-      StructType.fromAttributes(
-        allAggregateExpressions.flatMap(_.aggregateFunction.bufferAttributes))
-    val groupKeySchema: StructType =
-      StructType.fromAttributes(groupingExpressions.map(_.toAttribute))
-
-    val schemaSupportsUnsafe: Boolean =
-      
UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema)
 &&
-        UnsafeProjection.canSupport(groupKeySchema)
-
-    // TODO: Use the hybrid iterator for non-algebric aggregate functions.
-    sqlContext.conf.unsafeEnabled && schemaSupportsUnsafe && 
!hasNonAlgebricAggregateFunctions
-  }
-
-  // We need to use sorted input if we have grouping expressions, and
-  // we cannot use the hybrid iterator or the hybrid is disabled.
-  private[this] val requiresSortedInput: Boolean = {
-    groupingExpressions.nonEmpty && !supportsHybridIterator
-  }
-
-  override def canProcessUnsafeRows: Boolean = 
!hasNonAlgebricAggregateFunctions
-
-  // If result expressions' data types are all fixed length, we generate 
unsafe rows
-  // (We have this requirement instead of check the result of 
UnsafeProjection.canSupport
-  // is because we use a mutable projection to generate the result).
-  override def outputsUnsafeRows: Boolean = {
-    // resultExpressions.map(_.dataType).forall(UnsafeRow.isFixedLength)
-    // TODO: Supports generating UnsafeRows. We can just re-enable the line 
above and fix
-    // any issue we get.
-    false
-  }
-
-  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
-
-  override def requiredChildDistribution: List[Distribution] = {
-    requiredChildDistributionExpressions match {
-      case Some(exprs) if exprs.length == 0 => AllTuples :: Nil
-      case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: 
Nil
-      case None => UnspecifiedDistribution :: Nil
-    }
-  }
-
-  override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
-    if (requiresSortedInput) {
-      // TODO: We should not sort the input rows if they are just in reversed 
order.
-      groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
-    } else {
-      Seq.fill(children.size)(Nil)
-    }
-  }
-
-  override def outputOrdering: Seq[SortOrder] = {
-    if (requiresSortedInput) {
-      // It is possible that the child.outputOrdering starts with the required
-      // ordering expressions (e.g. we require [a] as the sort expression and 
the
-      // child's outputOrdering is [a, b]). We can only guarantee the output 
rows
-      // are sorted by values of groupingExpressions.
-      groupingExpressions.map(SortOrder(_, Ascending))
-    } else {
-      Nil
-    }
-  }
-
-  protected override def doExecute(): RDD[InternalRow] = attachTree(this, 
"execute") {
-    child.execute().mapPartitions { iter =>
-      // Because the constructor of an aggregation iterator will read at least 
the first row,
-      // we need to get the value of iter.hasNext first.
-      val hasInput = iter.hasNext
-      val useHybridIterator =
-        hasInput &&
-          supportsHybridIterator &&
-          groupingExpressions.nonEmpty
-      if (useHybridIterator) {
-        UnsafeHybridAggregationIterator.createFromInputIterator(
-          groupingExpressions,
-          nonCompleteAggregateExpressions,
-          nonCompleteAggregateAttributes,
-          completeAggregateExpressions,
-          completeAggregateAttributes,
-          initialInputBufferOffset,
-          resultExpressions,
-          newMutableProjection _,
-          child.output,
-          iter,
-          outputsUnsafeRows)
-      } else {
-        if (!hasInput && groupingExpressions.nonEmpty) {
-          // This is a grouped aggregate and the input iterator is empty,
-          // so return an empty iterator.
-          Iterator[InternalRow]()
-        } else {
-          val outputIter = 
SortBasedAggregationIterator.createFromInputIterator(
-            groupingExpressions,
-            nonCompleteAggregateExpressions,
-            nonCompleteAggregateAttributes,
-            completeAggregateExpressions,
-            completeAggregateAttributes,
-            initialInputBufferOffset,
-            resultExpressions,
-            newMutableProjection _ ,
-            newProjection _,
-            child.output,
-            iter,
-            outputsUnsafeRows)
-          if (!hasInput && groupingExpressions.isEmpty) {
-            // There is no input and there is no grouping expressions.
-            // We need to output a single row as the output.
-            
Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput())
-          } else {
-            outputIter
-          }
-        }
-      }
-    }
-  }
-
-  override def simpleString: String = {
-    val iterator = if (supportsHybridIterator && groupingExpressions.nonEmpty) 
{
-      classOf[UnsafeHybridAggregationIterator].getSimpleName
-    } else {
-      classOf[SortBasedAggregationIterator].getSimpleName
-    }
-
-    s"""NewAggregate with $iterator ${groupingExpressions} 
${allAggregateExpressions}"""
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/272e8834/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
new file mode 100644
index 0000000..ad428ad
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution}
+import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode}
+import org.apache.spark.sql.types.StructType
+
+case class SortBasedAggregate(
+    requiredChildDistributionExpressions: Option[Seq[Expression]],
+    groupingExpressions: Seq[NamedExpression],
+    nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+    nonCompleteAggregateAttributes: Seq[Attribute],
+    completeAggregateExpressions: Seq[AggregateExpression2],
+    completeAggregateAttributes: Seq[Attribute],
+    initialInputBufferOffset: Int,
+    resultExpressions: Seq[NamedExpression],
+    child: SparkPlan)
+  extends UnaryNode {
+
+  override def outputsUnsafeRows: Boolean = false
+
+  override def canProcessUnsafeRows: Boolean = false
+
+  override def canProcessSafeRows: Boolean = true
+
+  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+  override def requiredChildDistribution: List[Distribution] = {
+    requiredChildDistributionExpressions match {
+      case Some(exprs) if exprs.length == 0 => AllTuples :: Nil
+      case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil
+      case None => UnspecifiedDistribution :: Nil
+    }
+  }
+
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
+    groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
+  }
+
+  override def outputOrdering: Seq[SortOrder] = {
+    groupingExpressions.map(SortOrder(_, Ascending))
+  }
+
+  protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
+    child.execute().mapPartitions { iter =>
+      // Because the constructor of an aggregation iterator will read at least the first row,
+      // we need to get the value of iter.hasNext first.
+      val hasInput = iter.hasNext
+      if (!hasInput && groupingExpressions.nonEmpty) {
+        // This is a grouped aggregate and the input iterator is empty,
+        // so return an empty iterator.
+        Iterator[InternalRow]()
+      } else {
+        val outputIter = SortBasedAggregationIterator.createFromInputIterator(
+          groupingExpressions,
+          nonCompleteAggregateExpressions,
+          nonCompleteAggregateAttributes,
+          completeAggregateExpressions,
+          completeAggregateAttributes,
+          initialInputBufferOffset,
+          resultExpressions,
+          newMutableProjection _,
+          newProjection _,
+          child.output,
+          iter,
+          outputsUnsafeRows)
+        if (!hasInput && groupingExpressions.isEmpty) {
+          // There is no input and there are no grouping expressions.
+          // We need to output a single row as the output.
+          Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput())
+        } else {
+          outputIter
+        }
+      }
+    }
+  }
+
+  override def simpleString: String = {
+    val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions
+    s"""SortBasedAggregate ${groupingExpressions} ${allAggregateExpressions}"""
+  }
+}

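The two empty-input branches in doExecute above match the usual SQL semantics (illustrative queries; `sqlContext` and the empty table `t` are hypothetical):

    // Grouped aggregate over empty input: zero output rows.
    sqlContext.sql("select key, count(*) from t group by key").collect() // Array()
    // Ungrouped aggregate over empty input: exactly one output row.
    sqlContext.sql("select count(*) from t").collect() // Array([0])
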
http://git-wip-us.apache.org/repos/asf/spark/blob/272e8834/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
index 40f6bff..67ebafd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
@@ -204,31 +204,5 @@ object SortBasedAggregationIterator {
       newMutableProjection,
       outputsUnsafeRows)
   }
-
-  def createFromKVIterator(
-      groupingKeyAttributes: Seq[Attribute],
-      valueAttributes: Seq[Attribute],
-      inputKVIterator: KVIterator[InternalRow, InternalRow],
-      nonCompleteAggregateExpressions: Seq[AggregateExpression2],
-      nonCompleteAggregateAttributes: Seq[Attribute],
-      completeAggregateExpressions: Seq[AggregateExpression2],
-      completeAggregateAttributes: Seq[Attribute],
-      initialInputBufferOffset: Int,
-      resultExpressions: Seq[NamedExpression],
-      newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => 
MutableProjection),
-      outputsUnsafeRows: Boolean): SortBasedAggregationIterator = {
-    new SortBasedAggregationIterator(
-      groupingKeyAttributes,
-      valueAttributes,
-      inputKVIterator,
-      nonCompleteAggregateExpressions,
-      nonCompleteAggregateAttributes,
-      completeAggregateExpressions,
-      completeAggregateAttributes,
-      initialInputBufferOffset,
-      resultExpressions,
-      newMutableProjection,
-      outputsUnsafeRows)
-  }
   // scalastyle:on
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/272e8834/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
new file mode 100644
index 0000000..5a0b4d4
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution}
+import org.apache.spark.sql.execution.{UnaryNode, SparkPlan}
+
+case class TungstenAggregate(
+    requiredChildDistributionExpressions: Option[Seq[Expression]],
+    groupingExpressions: Seq[NamedExpression],
+    nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+    completeAggregateExpressions: Seq[AggregateExpression2],
+    initialInputBufferOffset: Int,
+    resultExpressions: Seq[NamedExpression],
+    child: SparkPlan)
+  extends UnaryNode {
+
+  override def outputsUnsafeRows: Boolean = true
+
+  override def canProcessUnsafeRows: Boolean = true
+
+  override def canProcessSafeRows: Boolean = false
+
+  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+  override def requiredChildDistribution: List[Distribution] = {
+    requiredChildDistributionExpressions match {
+      case Some(exprs) if exprs.length == 0 => AllTuples :: Nil
+      case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil
+      case None => UnspecifiedDistribution :: Nil
+    }
+  }
+
+  // This is for testing. We force TungstenAggregationIterator to fall back to sort-based
+  // aggregation once it has processed a given number of input rows.
+  private val testFallbackStartsAt: Option[Int] = {
+    sqlContext.getConf("spark.sql.TungstenAggregate.testFallbackStartsAt", null) match {
+      case null | "" => None
+      case fallbackStartsAt => Some(fallbackStartsAt.toInt)
+    }
+  }
+
+  protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
+    child.execute().mapPartitions { iter =>
+      val hasInput = iter.hasNext
+      if (!hasInput && groupingExpressions.nonEmpty) {
+        // This is a grouped aggregate and the input iterator is empty,
+        // so return an empty iterator.
+        Iterator.empty.asInstanceOf[Iterator[UnsafeRow]]
+      } else {
+        val aggregationIterator =
+          new TungstenAggregationIterator(
+            groupingExpressions,
+            nonCompleteAggregateExpressions,
+            completeAggregateExpressions,
+            initialInputBufferOffset,
+            resultExpressions,
+            newMutableProjection,
+            child.output,
+            iter.asInstanceOf[Iterator[UnsafeRow]],
+            testFallbackStartsAt)
+
+        if (!hasInput && groupingExpressions.isEmpty) {
+          Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput())
+        } else {
+          aggregationIterator
+        }
+      }
+    }
+  }
+
+  override def simpleString: String = {
+    val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions
+
+    testFallbackStartsAt match {
+      case None => s"TungstenAggregate ${groupingExpressions} ${allAggregateExpressions}"
+      case Some(fallbackStartsAt) =>
+        s"TungstenAggregateWithControlledFallback ${groupingExpressions} " +
+          s"${allAggregateExpressions} fallbackStartsAt=$fallbackStartsAt"
+    }
+  }
+}

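How the test-only fallback knob is meant to be exercised (a sketch; the real coverage lives in the TungstenAggregationQueryWithControlledFallbackSuite added by this patch):

    // Force TungstenAggregationIterator to switch to sort-based aggregation
    // after two input rows, making the fallback path deterministic in tests.
    sqlContext.setConf("spark.sql.TungstenAggregate.testFallbackStartsAt", "2")
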
http://git-wip-us.apache.org/repos/asf/spark/blob/272e8834/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
new file mode 100644
index 0000000..b9d44aa
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -0,0 +1,667 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.unsafe.KVIterator
+import org.apache.spark.{Logging, SparkEnv, TaskContext}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
+import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * An iterator used to evaluate aggregate functions. It operates on [[UnsafeRow]]s.
+ *
+ * This iterator first uses hash-based aggregation to process input rows. It uses
+ * a hash map to store groups and their corresponding aggregation buffers. If
+ * this map cannot allocate memory from [[org.apache.spark.shuffle.ShuffleMemoryManager]],
+ * it switches to sort-based aggregation. The switch has the following steps:
+ *  - Step 1: Sort all entries of the hash map based on values of grouping expressions and
+ *            spill them to disk.
+ *  - Step 2: Create an external sorter based on the spilled sorted map entries.
+ *  - Step 3: Redirect all input rows to the external sorter.
+ *  - Step 4: Get a sorted [[KVIterator]] from the external sorter.
+ *  - Step 5: Initialize sort-based aggregation.
+ * From then on, this iterator works as a sort-based aggregation iterator.
+ *
+ * The code of this class is organized as follows:
+ *  - Part 1: Initializing aggregate functions.
+ *  - Part 2: Methods and fields used for setting aggregation buffer values,
+ *            processing input rows from inputIter, and generating output
+ *            rows.
+ *  - Part 3: Methods and fields used by hash-based aggregation.
+ *  - Part 4: The function used to switch this iterator from hash-based
+ *            aggregation to sort-based aggregation.
+ *  - Part 5: Methods and fields used by sort-based aggregation.
+ *  - Part 6: Loads and processes input rows.
+ *  - Part 7: Public methods of this iterator.
+ *  - Part 8: A utility function used to generate a result when there is no
+ *            input and there is no grouping expression.
+ *
+ * @param groupingExpressions
+ *   expressions for grouping keys.
+ * @param nonCompleteAggregateExpressions
+ *   [[AggregateExpression2]]s containing [[AggregateFunction2]]s with mode [[Partial]],
+ *   [[PartialMerge]], or [[Final]].
+ * @param completeAggregateExpressions
+ *   [[AggregateExpression2]]s containing [[AggregateFunction2]]s with mode [[Complete]].
+ * @param initialInputBufferOffset
+ *   If this iterator is used to handle functions with mode [[PartialMerge]] or [[Final]],
+ *   the input rows have the format of `grouping keys + aggregation buffer`.
+ *   This offset indicates the starting position of the aggregation buffer in an input row.
+ * @param resultExpressions
+ *   expressions for generating output rows.
+ * @param newMutableProjection
+ *   the function used to create mutable projections.
+ * @param originalInputAttributes
+ *   attributes representing input rows from `inputIter`.
+ * @param inputIter
+ *   the iterator containing input [[UnsafeRow]]s.
+ */
+class TungstenAggregationIterator(
+    groupingExpressions: Seq[NamedExpression],
+    nonCompleteAggregateExpressions: Seq[AggregateExpression2],
+    completeAggregateExpressions: Seq[AggregateExpression2],
+    initialInputBufferOffset: Int,
+    resultExpressions: Seq[NamedExpression],
+    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+    originalInputAttributes: Seq[Attribute],
+    inputIter: Iterator[UnsafeRow],
+    testFallbackStartsAt: Option[Int])
+  extends Iterator[UnsafeRow] with Logging {
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Part 1: Initializing aggregate functions.
+  ///////////////////////////////////////////////////////////////////////////
+
+  // A Seq containing all AggregateExpressions.
+  // It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final
+  // are at the beginning of the allAggregateExpressions.
+  private[this] val allAggregateExpressions: Seq[AggregateExpression2] =
+    nonCompleteAggregateExpressions ++ completeAggregateExpressions
+
+  // Check to make sure we do not have more than two distinct modes in our AggregateExpressions.
+  // If we do, users are hitting a bug and we throw an IllegalStateException.
+  if (allAggregateExpressions.map(_.mode).distinct.length > 2) {
+    throw new IllegalStateException(
+      s"$allAggregateExpressions should have no more than 2 kinds of modes.")
+  }
+
+  //
+  // The modes of AggregateExpressions. Right now, we can handle the following modes:
+  //  - Partial-only:
+  //      All AggregateExpressions have the mode of Partial.
+  //      For this case, aggregationMode is (Some(Partial), None).
+  //  - PartialMerge-only:
+  //      All AggregateExpressions have the mode of PartialMerge.
+  //      For this case, aggregationMode is (Some(PartialMerge), None).
+  //  - Final-only:
+  //      All AggregateExpressions have the mode of Final.
+  //      For this case, aggregationMode is (Some(Final), None).
+  //  - Final-Complete:
+  //      Some AggregateExpressions have the mode of Final and
+  //      others have the mode of Complete. For this case,
+  //      aggregationMode is (Some(Final), Some(Complete)).
+  //  - Complete-only:
+  //      nonCompleteAggregateExpressions is empty and we have AggregateExpressions
+  //      with mode Complete in completeAggregateExpressions. For this case,
+  //      aggregationMode is (None, Some(Complete)).
+  //  - Grouping-only:
+  //      There is no AggregateExpression. For this case, aggregationMode is (None, None).
+  //
+  private[this] var aggregationMode: (Option[AggregateMode], Option[AggregateMode]) = {
+    nonCompleteAggregateExpressions.map(_.mode).distinct.headOption ->
+      completeAggregateExpressions.map(_.mode).distinct.headOption
+  }
+
+  // All aggregate functions. TungstenAggregationIterator only handles AlgebraicAggregates.
+  // If there is any function that is not an AlgebraicAggregate, we throw an
+  // IllegalStateException.
+  private[this] val allAggregateFunctions: Array[AlgebraicAggregate] = {
+    if (!allAggregateExpressions.forall(_.aggregateFunction.isInstanceOf[AlgebraicAggregate])) {
+      throw new IllegalStateException(
+        "Only AlgebraicAggregates should be passed in TungstenAggregationIterator.")
+    }
+
+    allAggregateExpressions
+      .map(_.aggregateFunction.asInstanceOf[AlgebraicAggregate])
+      .toArray
+  }
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Part 2: Methods and fields used for setting aggregation buffer values,
+  //         processing input rows from inputIter, and generating output
+  //         rows.
+  ///////////////////////////////////////////////////////////////////////////
+
+  // The projection used to initialize buffer values.
+  private[this] val algebraicInitialProjection: MutableProjection = {
+    val initExpressions = allAggregateFunctions.flatMap(_.initialValues)
+    newMutableProjection(initExpressions, Nil)()
+  }
+
+  // Creates a new aggregation buffer and initializes buffer values.
+  // This function should be called at most three times (when we create the hash map,
+  // when we switch to sort-based aggregation, and when we create the re-used buffer for
+  // sort-based aggregation).
+  private def createNewAggregationBuffer(): UnsafeRow = {
+    val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes)
+    val bufferRowSize: Int = bufferSchema.length
+
+    val genericMutableBuffer = new GenericMutableRow(bufferRowSize)
+    val unsafeProjection =
+      UnsafeProjection.create(bufferSchema.map(_.dataType))
+    val buffer = unsafeProjection.apply(genericMutableBuffer)
+    algebraicInitialProjection.target(buffer)(EmptyRow)
+    buffer
+  }
+
+  // Creates a function used to process a row based on the given inputAttributes.
+  private def generateProcessRow(
+      inputAttributes: Seq[Attribute]): (UnsafeRow, UnsafeRow) => Unit = {
+
+    val aggregationBufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes)
+    val aggregationBufferSchema = StructType.fromAttributes(aggregationBufferAttributes)
+    val inputSchema = StructType.fromAttributes(inputAttributes)
+    val unsafeRowJoiner =
+      GenerateUnsafeRowJoiner.create(aggregationBufferSchema, inputSchema)
+
+    aggregationMode match {
+      // Partial-only
+      case (Some(Partial), None) =>
+        val updateExpressions = allAggregateFunctions.flatMap(_.updateExpressions)
+        val algebraicUpdateProjection =
+          newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
+
+        (currentBuffer: UnsafeRow, row: UnsafeRow) => {
+          algebraicUpdateProjection.target(currentBuffer)
+          algebraicUpdateProjection(unsafeRowJoiner.join(currentBuffer, row))
+        }
+
+      // PartialMerge-only or Final-only
+      case (Some(PartialMerge), None) | (Some(Final), None) =>
+        val mergeExpressions = allAggregateFunctions.flatMap(_.mergeExpressions)
+        // This projection is used to merge buffer values for all AlgebraicAggregates.
+        val algebraicMergeProjection =
+          newMutableProjection(
+            mergeExpressions,
+            aggregationBufferAttributes ++ inputAttributes)()
+
+        (currentBuffer: UnsafeRow, row: UnsafeRow) => {
+          // Process all algebraic aggregate functions.
+          algebraicMergeProjection.target(currentBuffer)
+          algebraicMergeProjection(unsafeRowJoiner.join(currentBuffer, row))
+        }
+
+      // Final-Complete
+      case (Some(Final), Some(Complete)) =>
+        val nonCompleteAggregateFunctions: Array[AlgebraicAggregate] =
+          allAggregateFunctions.take(nonCompleteAggregateExpressions.length)
+        val completeAggregateFunctions: Array[AlgebraicAggregate] =
+          allAggregateFunctions.takeRight(completeAggregateExpressions.length)
+
+        val completeOffsetExpressions =
+          Seq.fill(completeAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp)
+        val mergeExpressions =
+          nonCompleteAggregateFunctions.flatMap(_.mergeExpressions) ++ completeOffsetExpressions
+        val finalAlgebraicMergeProjection =
+          newMutableProjection(
+            mergeExpressions,
+            aggregationBufferAttributes ++ inputAttributes)()
+
+        // We do not touch buffer values of aggregate functions with the Final mode.
+        val finalOffsetExpressions =
+          Seq.fill(nonCompleteAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp)
+        val updateExpressions =
+          finalOffsetExpressions ++ completeAggregateFunctions.flatMap(_.updateExpressions)
+        val completeAlgebraicUpdateProjection =
+          newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
+
+        (currentBuffer: UnsafeRow, row: UnsafeRow) => {
+          val input = unsafeRowJoiner.join(currentBuffer, row)
+          // For all aggregate functions with mode Complete, update the given currentBuffer.
+          completeAlgebraicUpdateProjection.target(currentBuffer)(input)
+
+          // For all aggregate functions with mode Final, merge buffer values in row to
+          // currentBuffer.
+          finalAlgebraicMergeProjection.target(currentBuffer)(input)
+        }
+
+      // Complete-only
+      case (None, Some(Complete)) =>
+        val completeAggregateFunctions: Array[AlgebraicAggregate] =
+          allAggregateFunctions.takeRight(completeAggregateExpressions.length)
+
+        val updateExpressions =
+          completeAggregateFunctions.flatMap(_.updateExpressions)
+        val completeAlgebraicUpdateProjection =
+          newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)()
+
+        (currentBuffer: UnsafeRow, row: UnsafeRow) => {
+          completeAlgebraicUpdateProjection.target(currentBuffer)
+          // For all aggregate functions with mode Complete, update the given currentBuffer.
+          completeAlgebraicUpdateProjection(unsafeRowJoiner.join(currentBuffer, row))
+        }
+
+      // Grouping only.
+      case (None, None) => (currentBuffer: UnsafeRow, row: UnsafeRow) => {}
+
+      case other =>
+        throw new IllegalStateException(
+          s"${aggregationMode} should not be passed into TungstenAggregationIterator.")
+    }
+  }
+
+  // Creates a function used to generate output rows.
+  private def generateResultProjection(): (UnsafeRow, UnsafeRow) => UnsafeRow = {
+
+    val groupingAttributes = groupingExpressions.map(_.toAttribute)
+    val groupingKeySchema = StructType.fromAttributes(groupingAttributes)
+    val bufferAttributes = allAggregateFunctions.flatMap(_.bufferAttributes)
+    val bufferSchema = StructType.fromAttributes(bufferAttributes)
+    val unsafeRowJoiner = GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema)
+
+    aggregationMode match {
+      // Partial-only or PartialMerge-only: every output row is basically the values of
+      // the grouping expressions and the corresponding aggregation buffer.
+      case (Some(Partial), None) | (Some(PartialMerge), None) =>
+        (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
+          unsafeRowJoiner.join(currentGroupingKey, currentBuffer)
+        }
+
+      // Final-only, Complete-only and Final-Complete: an output row is generated based on
+      // resultExpressions.
+      case (Some(Final), None) | (Some(Final) | None, Some(Complete)) =>
+        val resultProjection =
+          UnsafeProjection.create(resultExpressions, groupingAttributes ++ bufferAttributes)
+
+        (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
+          resultProjection(unsafeRowJoiner.join(currentGroupingKey, currentBuffer))
+        }
+
+      // Grouping-only: an output row is generated from values of grouping expressions.
+      case (None, None) =>
+        val resultProjection =
+          UnsafeProjection.create(resultExpressions, groupingAttributes)
+
+        (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => {
+          resultProjection(currentGroupingKey)
+        }
+
+      case other =>
+        throw new IllegalStateException(
+          s"${aggregationMode} should not be passed into TungstenAggregationIterator.")
+    }
+  }
+
+  // An UnsafeProjection used to extract grouping keys from the input rows.
+  private[this] val groupProjection =
+    UnsafeProjection.create(groupingExpressions, originalInputAttributes)
+
+  // A function used to process an input row. Its first argument is the aggregation buffer
+  // and the second argument is the input row.
+  private[this] var processRow: (UnsafeRow, UnsafeRow) => Unit =
+    generateProcessRow(originalInputAttributes)
+
+  // A function used to generate output rows based on the grouping keys (first argument)
+  // and the corresponding aggregation buffer (second argument).
+  private[this] var generateOutput: (UnsafeRow, UnsafeRow) => UnsafeRow =
+    generateResultProjection()
+
+  // An aggregation buffer containing initial buffer values. It is used to
+  // initialize other aggregation buffers.
+  private[this] val initialAggregationBuffer: UnsafeRow = createNewAggregationBuffer()
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Part 3: Methods and fields used by hash-based aggregation.
+  ///////////////////////////////////////////////////////////////////////////
+
+  // This is the hash map used for hash-based aggregation. It is backed by an
+  // UnsafeFixedWidthAggregationMap and it is used to store
+  // all groups and their corresponding aggregation buffers for hash-based aggregation.
+  private[this] val hashMap = new UnsafeFixedWidthAggregationMap(
+    initialAggregationBuffer,
+    StructType.fromAttributes(allAggregateFunctions.flatMap(_.bufferAttributes)),
+    StructType.fromAttributes(groupingExpressions.map(_.toAttribute)),
+    TaskContext.get.taskMemoryManager(),
+    SparkEnv.get.shuffleMemoryManager,
+    1024 * 16, // initial capacity
+    SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m"),
+    false // disable tracking of performance metrics
+  )
+
+  // The function used to read and process input rows. When processing input rows,
+  // it first uses hash-based aggregation by putting groups and their buffers in
+  // hashMap. If we could not allocate more memory for the map, we switch to
+  // sort-based aggregation (by calling switchToSortBasedAggregation).
+  private def processInputs(): Unit = {
+    while (!sortBased && inputIter.hasNext) {
+      val newInput = inputIter.next()
+      val groupingKey = groupProjection.apply(newInput)
+      val buffer: UnsafeRow = hashMap.getAggregationBuffer(groupingKey)
+      if (buffer == null) {
+        // buffer == null means that we could not allocate more memory.
+        // Now, we need to spill the map and switch to sort-based aggregation.
+        switchToSortBasedAggregation(groupingKey, newInput)
+      } else {
+        processRow(buffer, newInput)
+      }
+    }
+  }
+
+  // This function is only used for testing. It is basically the same as processInputs except
+  // that it switches to sort-based aggregation after `fallbackStartsAt` input rows have
+  // been processed.
+  private def processInputsWithControlledFallback(fallbackStartsAt: Int): Unit = {
+    var i = 0
+    while (!sortBased && inputIter.hasNext) {
+      val newInput = inputIter.next()
+      val groupingKey = groupProjection.apply(newInput)
+      val buffer: UnsafeRow = if (i < fallbackStartsAt) {
+        hashMap.getAggregationBuffer(groupingKey)
+      } else {
+        null
+      }
+      if (buffer == null) {
+        // buffer == null means that we could not allocate more memory.
+        // Now, we need to spill the map and switch to sort-based aggregation.
+        switchToSortBasedAggregation(groupingKey, newInput)
+      } else {
+        processRow(buffer, newInput)
+      }
+      i += 1
+    }
+  }
+
+  // The iterator created from hashMap. It is used to generate output rows when we
+  // are using hash-based aggregation.
+  private[this] var aggregationBufferMapIterator: KVIterator[UnsafeRow, UnsafeRow] = null
+
+  // Indicates if aggregationBufferMapIterator still has key-value pairs.
+  private[this] var mapIteratorHasNext: Boolean = false
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Part 4: The function used to switch this iterator from hash-based
+  // aggregation to sort-based aggregation.
+  ///////////////////////////////////////////////////////////////////////////
+
+  private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: UnsafeRow): Unit = {
+    logInfo("falling back to sort based aggregation.")
+    // Step 1: Get the ExternalSorter containing sorted entries of the map.
+    val externalSorter: UnsafeKVExternalSorter = hashMap.destructAndCreateExternalSorter()
+
+    // Step 2: Free the memory used by the map.
+    hashMap.free()
+
+    // Step 3: If we have aggregate functions with mode Partial or Complete,
+    // we need to process input rows to get aggregation buffers, so that
+    // the sort-based aggregation iterator can later merge them.
+    // If aggregate functions have mode Final or PartialMerge,
+    // we just need to project the aggregation buffer from an input row.
+    val needsProcess = aggregationMode match {
+      case (Some(Partial), None) => true
+      case (None, Some(Complete)) => true
+      case (Some(Final), Some(Complete)) => true
+      case _ => false
+    }
+
+    if (needsProcess) {
+      // First, we create a buffer.
+      val buffer = createNewAggregationBuffer()
+
+      // Process firstKey and firstInput.
+      // Initialize buffer.
+      buffer.copyFrom(initialAggregationBuffer)
+      processRow(buffer, firstInput)
+      externalSorter.insertKV(firstKey, buffer)
+
+      // Process the rest of input rows.
+      while (inputIter.hasNext) {
+        val newInput = inputIter.next()
+        val groupingKey = groupProjection.apply(newInput)
+        buffer.copyFrom(initialAggregationBuffer)
+        processRow(buffer, newInput)
+        externalSorter.insertKV(groupingKey, buffer)
+      }
+    } else {
+      // When needsProcess is false, the format of input rows is groupingKey + aggregation buffer.
+      // We need to project the aggregation buffer part from an input row.
+      val buffer = createNewAggregationBuffer()
+      // The originalInputAttributes are using cloneBufferAttributes. So, we need to use
+      // allAggregateFunctions.flatMap(_.cloneBufferAttributes).
+      val bufferExtractor = newMutableProjection(
+        allAggregateFunctions.flatMap(_.cloneBufferAttributes),
+        originalInputAttributes)()
+      bufferExtractor.target(buffer)
+
+      // Insert firstKey and its buffer.
+      bufferExtractor(firstInput)
+      externalSorter.insertKV(firstKey, buffer)
+
+      // Insert the rest of input rows.
+      while (inputIter.hasNext) {
+        val newInput = inputIter.next()
+        val groupingKey = groupProjection.apply(newInput)
+        bufferExtractor(newInput)
+        externalSorter.insertKV(groupingKey, buffer)
+      }
+    }
+
+    // Step 4: Set aggregationMode, processRow, and generateOutput for sort-based
+    // aggregation.
+    val newAggregationMode = aggregationMode match {
+      case (Some(Partial), None) => (Some(PartialMerge), None)
+      case (None, Some(Complete)) => (Some(Final), None)
+      case (Some(Final), Some(Complete)) => (Some(Final), None)
+      case other => other
+    }
+    aggregationMode = newAggregationMode
+
+    // The values of the KVIterator returned by the externalSorter will just be
+    // aggregation buffers. Hence, we use cloneBufferAttributes here.
+    val newInputAttributes: Seq[Attribute] =
+      allAggregateFunctions.flatMap(_.cloneBufferAttributes)
+
+    // Set up new processRow and generateOutput.
+    processRow = generateProcessRow(newInputAttributes)
+    generateOutput = generateResultProjection()
+
+    // Step 5: Get the sorted iterator from the externalSorter.
+    sortedKVIterator = externalSorter.sortedIterator()
+
+    // Step 6: Pre-load the first key-value pair from the sorted iterator to make
+    // hasNext idempotent.
+    sortedInputHasNewGroup = sortedKVIterator.next()
+
+    // Copy the first key and value (aggregation buffer).
+    if (sortedInputHasNewGroup) {
+      val key = sortedKVIterator.getKey
+      val value = sortedKVIterator.getValue
+      nextGroupingKey = key.copy()
+      currentGroupingKey = key.copy()
+      firstRowInNextGroup = value.copy()
+    }
+
+    // Step 7: set sortBased to true.
+    sortBased = true
+  }
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Part 5: Methods and fields used by sort-based aggregation.
+  ///////////////////////////////////////////////////////////////////////////
+
+  // Indicates if we are using sort-based aggregation. Because we first try to use
+  // hash-based aggregation, its initial value is false.
+  private[this] var sortBased: Boolean = false
+
+  // The KVIterator containing input rows for the sort-based aggregation. It will be
+  // set in switchToSortBasedAggregation when we switch to sort-based aggregation.
+  private[this] var sortedKVIterator: UnsafeKVExternalSorter#KVSorterIterator = null
+
+  // The grouping key of the current group.
+  private[this] var currentGroupingKey: UnsafeRow = null
+
+  // The grouping key of next group.
+  private[this] var nextGroupingKey: UnsafeRow = null
+
+  // The first row of next group.
+  private[this] var firstRowInNextGroup: UnsafeRow = null
+
+  // Indicates if we have a new group of rows from the sorted input iterator.
+  private[this] var sortedInputHasNewGroup: Boolean = false
+
+  // The aggregation buffer used by the sort-based aggregation.
+  private[this] val sortBasedAggregationBuffer: UnsafeRow = createNewAggregationBuffer()
+
+  // Processes rows in the current group. It will stop when it finds a new group.
+  private def processCurrentSortedGroup(): Unit = {
+    // First, we need to copy nextGroupingKey to currentGroupingKey.
+    currentGroupingKey.copyFrom(nextGroupingKey)
+    // Now, we will start to find all rows belonging to this group.
+    // We create a variable to track if we see the next group.
+    var findNextPartition = false
+    // firstRowInNextGroup is the first row of this group. We first process it.
+    processRow(sortBasedAggregationBuffer, firstRowInNextGroup)
+
+    // The search will stop when we see the next group or there is no
+    // input row left in the iter.
+    // Pre-load the first key-value pair so that the condition of the while loop
+    // has no side effect (we do not trigger loading a new key-value pair
+    // when we evaluate the condition).
+    var hasNext = sortedKVIterator.next()
+    while (!findNextPartition && hasNext) {
+      // Get the grouping key and value (aggregation buffer).
+      val groupingKey = sortedKVIterator.getKey
+      val inputAggregationBuffer = sortedKVIterator.getValue
+
+      // Check if the current row belongs to the current group.
+      if (currentGroupingKey.equals(groupingKey)) {
+        processRow(sortBasedAggregationBuffer, inputAggregationBuffer)
+
+        hasNext = sortedKVIterator.next()
+      } else {
+        // We find a new group.
+        findNextPartition = true
+        // copyFrom will fail when
+        nextGroupingKey.copyFrom(groupingKey) // = groupingKey.copy()
+        firstRowInNextGroup.copyFrom(inputAggregationBuffer) // = 
inputAggregationBuffer.copy()
+
+      }
+    }
+    // We have not seen a new group. It means that there is no new row in the input
+    // iter. The current group is the last group of the sortedKVIterator.
+    if (!findNextPartition) {
+      sortedInputHasNewGroup = false
+      sortedKVIterator.close()
+    }
+  }
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Part 6: Loads input rows and sets up aggregationBufferMapIterator if we
+  //         have not switched to sort-based aggregation.
+  ///////////////////////////////////////////////////////////////////////////
+
+  // Starts to process input rows.
+  testFallbackStartsAt match {
+    case None =>
+      processInputs()
+    case Some(fallbackStartsAt) =>
+      // This is the testing path. processInputsWithControlledFallback is the same
+      // as processInputs except that it switches to sort-based aggregation after
+      // `fallbackStartsAt` input rows have been processed.
+      processInputsWithControlledFallback(fallbackStartsAt)
+  }
+
+  // If we did not switch to sort-based aggregation in processInputs,
+  // we pre-load the first key-value pair from the map (to make hasNext idempotent).
+  if (!sortBased) {
+    // First, set aggregationBufferMapIterator.
+    aggregationBufferMapIterator = hashMap.iterator()
+    // Pre-load the first key-value pair from the aggregationBufferMapIterator.
+    mapIteratorHasNext = aggregationBufferMapIterator.next()
+    // If the map is empty, we just free it.
+    if (!mapIteratorHasNext) {
+      hashMap.free()
+    }
+  }
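The pre-loading above exists because the map iterator's next() both advances
the cursor and reports availability, so calling it from hasNext would consume
input. A minimal sketch of the same idiom (hypothetical cursor type, not part
of this patch):

    // Wrap an advance-and-report cursor into a standard Scala Iterator whose
    // hasNext is side-effect free.
    class PreloadedIterator[T](cursor: () => Option[T]) extends Iterator[T] {
      private[this] var lookahead: Option[T] = cursor() // pre-load the first element
      override def hasNext: Boolean = lookahead.isDefined // idempotent, no I/O
      override def next(): T = {
        val result = lookahead.getOrElse(throw new NoSuchElementException)
        lookahead = cursor() // pre-load the next element
        result
      }
    }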
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Part 7: Iterator's public methods.
+  ///////////////////////////////////////////////////////////////////////////
+
+  override final def hasNext: Boolean = {
+    (sortBased && sortedInputHasNewGroup) || (!sortBased && mapIteratorHasNext)
+  }
+
+  override final def next(): UnsafeRow = {
+    if (hasNext) {
+      if (sortBased) {
+        // Process the current group.
+        processCurrentSortedGroup()
+        // Generate output row for the current group.
+        val outputRow = generateOutput(currentGroupingKey, sortBasedAggregationBuffer)
+        // Initialize buffer values for the next group.
+        sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer)
+
+        outputRow
+      } else {
+        // We did not fall back to sort-based aggregation.
+        val result =
+          generateOutput(
+            aggregationBufferMapIterator.getKey,
+            aggregationBufferMapIterator.getValue)
+
+        // Pre-load the next key-value pair from aggregationBufferMapIterator to
+        // make hasNext idempotent.
+        mapIteratorHasNext = aggregationBufferMapIterator.next()
+
+        if (!mapIteratorHasNext) {
+          // If there is no more input from aggregationBufferMapIterator, we copy
+          // the current result because it is backed by the map's memory, which we
+          // are about to free.
+          val resultCopy = result.copy()
+          // Then, we free the map.
+          hashMap.free()
+
+          resultCopy
+        } else {
+          result
+        }
+      }
+    } else {
+      // no more result
+      throw new NoSuchElementException
+    }
+  }
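Because hasNext is side-effect free and the last row is copied before the map
is freed, callers can drain the iterator with the usual Scala idiom. A
hypothetical usage sketch (assuming a constructed iterator `iter`; not part of
this patch):

    // Rows may be reused across next() calls, so copy them if they are retained.
    val rows = scala.collection.mutable.ArrayBuffer.empty[UnsafeRow]
    while (iter.hasNext) {
      rows += iter.next().copy()
    }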
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Part 8: A utility function used to generate an output row when there is no
+  // input and there is no grouping expression.
+  ///////////////////////////////////////////////////////////////////////////
+  def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
+    if (groupingExpressions.isEmpty) {
+      sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer)
+      // We create an output row and copy it, so that we can then free the map.
+      val resultCopy =
+        generateOutput(UnsafeRow.createFromByteArray(0, 0), sortBasedAggregationBuffer).copy()
+      hashMap.free()
+      resultCopy
+    } else {
+      throw new IllegalStateException(
+        "This method should not be called when groupingExpressions is not empty.")
+    }
+  }
+}
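outputForEmptyGroupingKeyWithoutInput covers the corner case where a global
aggregate (no GROUP BY) runs over empty input: SQL semantics still require a
single output row, e.g. SELECT count(*) FROM empty_table returns one row with 0.
A sketch of how a caller might use it (hypothetical driver logic, assuming
TungstenAggregationIterator is an Iterator[UnsafeRow] as its next()/hasNext
above suggest; this is not the actual TungstenAggregate operator):

    def globalAggregateOutput(
        iter: TungstenAggregationIterator,
        hasInput: Boolean,
        hasGroupingExprs: Boolean): Iterator[UnsafeRow] = {
      if (!hasInput && !hasGroupingExprs) {
        // A global aggregate over empty input still emits exactly one row.
        Iterator(iter.outputForEmptyGroupingKeyWithoutInput())
      } else {
        // Grouped aggregation over empty input emits no rows at all.
        iter
      }
    }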

http://git-wip-us.apache.org/repos/asf/spark/blob/272e8834/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala
deleted file mode 100644
index b465787..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UnsafeHybridAggregationIterator.scala
+++ /dev/null
@@ -1,372 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.aggregate
-
-import org.apache.spark.unsafe.KVIterator
-import org.apache.spark.{SparkEnv, TaskContext}
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.execution.{UnsafeKVExternalSorter, UnsafeFixedWidthAggregationMap}
-import org.apache.spark.sql.types.StructType
-
-/**
- * An iterator used to evaluate [[AggregateFunction2]].
- * It first tries to use in-memory hash-based aggregation. If we cannot allocate more
- * space for the hash map, we spill the sorted map entries, free the map, and then
- * switch to sort-based aggregation.
- */
-class UnsafeHybridAggregationIterator(
-    groupingKeyAttributes: Seq[Attribute],
-    valueAttributes: Seq[Attribute],
-    inputKVIterator: KVIterator[UnsafeRow, InternalRow],
-    nonCompleteAggregateExpressions: Seq[AggregateExpression2],
-    nonCompleteAggregateAttributes: Seq[Attribute],
-    completeAggregateExpressions: Seq[AggregateExpression2],
-    completeAggregateAttributes: Seq[Attribute],
-    initialInputBufferOffset: Int,
-    resultExpressions: Seq[NamedExpression],
-    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
-    outputsUnsafeRows: Boolean)
-  extends AggregationIterator(
-    groupingKeyAttributes,
-    valueAttributes,
-    nonCompleteAggregateExpressions,
-    nonCompleteAggregateAttributes,
-    completeAggregateExpressions,
-    completeAggregateAttributes,
-    initialInputBufferOffset,
-    resultExpressions,
-    newMutableProjection,
-    outputsUnsafeRows) {
-
-  require(groupingKeyAttributes.nonEmpty)
-
-  ///////////////////////////////////////////////////////////////////////////
-  // Unsafe Aggregation buffers
-  ///////////////////////////////////////////////////////////////////////////
-
-  // This is the Unsafe Aggregation Map used to store all buffers.
-  private[this] val buffers = new UnsafeFixedWidthAggregationMap(
-    newBuffer,
-    StructType.fromAttributes(allAggregateFunctions.flatMap(_.bufferAttributes)),
-    StructType.fromAttributes(groupingKeyAttributes),
-    TaskContext.get.taskMemoryManager(),
-    SparkEnv.get.shuffleMemoryManager,
-    1024 * 16, // initial capacity
-    SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m"),
-    false // disable tracking of performance metrics
-  )
-
-  override protected def newBuffer: UnsafeRow = {
-    val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes)
-    val bufferRowSize: Int = bufferSchema.length
-
-    val genericMutableBuffer = new GenericMutableRow(bufferRowSize)
-    val unsafeProjection =
-      UnsafeProjection.create(bufferSchema.map(_.dataType))
-    val buffer = unsafeProjection.apply(genericMutableBuffer)
-    initializeBuffer(buffer)
-    buffer
-  }
-
-  ///////////////////////////////////////////////////////////////////////////
-  // Methods and variables related to switching to sort-based aggregation
-  ///////////////////////////////////////////////////////////////////////////
-  private[this] var sortBased = false
-
-  private[this] var sortBasedAggregationIterator: SortBasedAggregationIterator = _
-
-  // The value part of the input KV iterator is used to store original input values of
-  // aggregate functions, we need to convert them to aggregation buffers.
-  private def processOriginalInput(
-      firstKey: UnsafeRow,
-      firstValue: InternalRow): KVIterator[UnsafeRow, UnsafeRow] = {
-    new KVIterator[UnsafeRow, UnsafeRow] {
-      private[this] var isFirstRow = true
-
-      private[this] var groupingKey: UnsafeRow = _
-
-      private[this] val buffer: UnsafeRow = newBuffer
-
-      override def next(): Boolean = {
-        initializeBuffer(buffer)
-        if (isFirstRow) {
-          isFirstRow = false
-          groupingKey = firstKey
-          processRow(buffer, firstValue)
-
-          true
-        } else if (inputKVIterator.next()) {
-          groupingKey = inputKVIterator.getKey()
-          val value = inputKVIterator.getValue()
-          processRow(buffer, value)
-
-          true
-        } else {
-          false
-        }
-      }
-
-      override def getKey(): UnsafeRow = {
-        groupingKey
-      }
-
-      override def getValue(): UnsafeRow = {
-        buffer
-      }
-
-      override def close(): Unit = {
-        // Do nothing.
-      }
-    }
-  }
-
-  // The value of the input KV Iterator has the format of groupingExprs + aggregation buffer.
-  // We need to project the aggregation buffer out.
-  private def projectInputBufferToUnsafe(
-      firstKey: UnsafeRow,
-      firstValue: InternalRow): KVIterator[UnsafeRow, UnsafeRow] = {
-    new KVIterator[UnsafeRow, UnsafeRow] {
-      private[this] var isFirstRow = true
-
-      private[this] var groupingKey: UnsafeRow = _
-
-      private[this] val bufferSchema = allAggregateFunctions.flatMap(_.bufferAttributes)
-
-      private[this] val value: UnsafeRow = {
-        val genericMutableRow = new GenericMutableRow(bufferSchema.length)
-        UnsafeProjection.create(bufferSchema.map(_.dataType)).apply(genericMutableRow)
-      }
-
-      private[this] val projectInputBuffer = {
-        newMutableProjection(bufferSchema, valueAttributes)().target(value)
-      }
-
-      override def next(): Boolean = {
-        if (isFirstRow) {
-          isFirstRow = false
-          groupingKey = firstKey
-          projectInputBuffer(firstValue)
-
-          true
-        } else if (inputKVIterator.next()) {
-          groupingKey = inputKVIterator.getKey()
-          projectInputBuffer(inputKVIterator.getValue())
-
-          true
-        } else {
-          false
-        }
-      }
-
-      override def getKey(): UnsafeRow = {
-        groupingKey
-      }
-
-      override def getValue(): UnsafeRow = {
-        value
-      }
-
-      override def close(): Unit = {
-        // Do nothing.
-      }
-    }
-  }
-
-  /**
-   * We need to fall back to sort based aggregation because we do not have enough memory
-   * for our in-memory hash map (i.e. `buffers`).
-   */
-  private def switchToSortBasedAggregation(
-      currentGroupingKey: UnsafeRow,
-      currentRow: InternalRow): Unit = {
-    logInfo("falling back to sort based aggregation.")
-
-    // Step 1: Get the ExternalSorter containing entries of the map.
-    val externalSorter = buffers.destructAndCreateExternalSorter()
-
-    // Step 2: Free the memory used by the map.
-    buffers.free()
-
-    // Step 3: If we have aggregate function with mode Partial or Complete,
-    // we need to process them to get aggregation buffer.
-    // So, later in the sort-based aggregation iterator, we can do merge.
-    // If aggregate functions are with mode Final and PartialMerge,
-    // we just need to project the aggregation buffer from the input.
-    val needsProcess = aggregationMode match {
-      case (Some(Partial), None) => true
-      case (None, Some(Complete)) => true
-      case (Some(Final), Some(Complete)) => true
-      case _ => false
-    }
-
-    val processedIterator = if (needsProcess) {
-      processOriginalInput(currentGroupingKey, currentRow)
-    } else {
-      // The input value's format is groupingExprs + buffer.
-      // We need to project the buffer part out.
-      projectInputBufferToUnsafe(currentGroupingKey, currentRow)
-    }
-
-    // Step 4: Redirect processedIterator to externalSorter.
-    while (processedIterator.next()) {
-      externalSorter.insertKV(processedIterator.getKey(), processedIterator.getValue())
-    }
-
-    // Step 5: Get the sorted iterator from the externalSorter.
-    val sortedKVIterator: UnsafeKVExternalSorter#KVSorterIterator = externalSorter.sortedIterator()
-
-    // Step 6: We now create a SortBasedAggregationIterator based on sortedKVIterator.
-    // For a aggregate function with mode Partial, its mode in the SortBasedAggregationIterator
-    // will be PartialMerge. For a aggregate function with mode Complete,
-    // its mode in the SortBasedAggregationIterator will be Final.
-    val newNonCompleteAggregateExpressions = allAggregateExpressions.map {
-        case AggregateExpression2(func, Partial, isDistinct) =>
-          AggregateExpression2(func, PartialMerge, isDistinct)
-        case AggregateExpression2(func, Complete, isDistinct) =>
-          AggregateExpression2(func, Final, isDistinct)
-        case other => other
-      }
-    val newNonCompleteAggregateAttributes =
-      nonCompleteAggregateAttributes ++ completeAggregateAttributes
-
-    val newValueAttributes =
-      allAggregateExpressions.flatMap(_.aggregateFunction.cloneBufferAttributes)
-
-    sortBasedAggregationIterator = SortBasedAggregationIterator.createFromKVIterator(
-      groupingKeyAttributes = groupingKeyAttributes,
-      valueAttributes = newValueAttributes,
-      inputKVIterator = sortedKVIterator.asInstanceOf[KVIterator[InternalRow, InternalRow]],
-      nonCompleteAggregateExpressions = newNonCompleteAggregateExpressions,
-      nonCompleteAggregateAttributes = newNonCompleteAggregateAttributes,
-      completeAggregateExpressions = Nil,
-      completeAggregateAttributes = Nil,
-      initialInputBufferOffset = 0,
-      resultExpressions = resultExpressions,
-      newMutableProjection = newMutableProjection,
-      outputsUnsafeRows = outputsUnsafeRows)
-  }
-
-  ///////////////////////////////////////////////////////////////////////////
-  // Methods used to initialize this iterator.
-  ///////////////////////////////////////////////////////////////////////////
-
-  /** Starts to read input rows and falls back to sort-based aggregation if necessary. */
-  protected def initialize(): Unit = {
-    var hasNext = inputKVIterator.next()
-    while (!sortBased && hasNext) {
-      val groupingKey = inputKVIterator.getKey()
-      val currentRow = inputKVIterator.getValue()
-      val buffer = buffers.getAggregationBuffer(groupingKey)
-      if (buffer == null) {
-        // buffer == null means that we could not allocate more memory.
-        // Now, we need to spill the map and switch to sort-based aggregation.
-        switchToSortBasedAggregation(groupingKey, currentRow)
-        sortBased = true
-      } else {
-        processRow(buffer, currentRow)
-        hasNext = inputKVIterator.next()
-      }
-    }
-  }
-
-  // This is the starting point of this iterator.
-  initialize()
-
-  // Creates the iterator for the Hash Aggregation Map after we have populated
-  // contents of that map.
-  private[this] val aggregationBufferMapIterator = buffers.iterator()
-
-  private[this] var _mapIteratorHasNext = false
-
-  // Pre-load the first key-value pair from the map to make hasNext idempotent.
-  if (!sortBased) {
-    _mapIteratorHasNext = aggregationBufferMapIterator.next()
-    // If the map is empty, we just free it.
-    if (!_mapIteratorHasNext) {
-      buffers.free()
-    }
-  }
-
-  ///////////////////////////////////////////////////////////////////////////
-  // Iterator's public methods
-  ///////////////////////////////////////////////////////////////////////////
-
-  override final def hasNext: Boolean = {
-    (sortBased && sortBasedAggregationIterator.hasNext) || (!sortBased && _mapIteratorHasNext)
-  }
-
-
-  override final def next(): InternalRow = {
-    if (hasNext) {
-      if (sortBased) {
-        sortBasedAggregationIterator.next()
-      } else {
-        // We did not fall back to the sort-based aggregation.
-        val result =
-          generateOutput(
-            aggregationBufferMapIterator.getKey,
-            aggregationBufferMapIterator.getValue)
-        // Pre-load next key-value pair form aggregationBufferMapIterator.
-        _mapIteratorHasNext = aggregationBufferMapIterator.next()
-
-        if (!_mapIteratorHasNext) {
-          val resultCopy = result.copy()
-          buffers.free()
-          resultCopy
-        } else {
-          result
-        }
-      }
-    } else {
-      // no more result
-      throw new NoSuchElementException
-    }
-  }
-}
-
-object UnsafeHybridAggregationIterator {
-  // scalastyle:off
-  def createFromInputIterator(
-      groupingExprs: Seq[NamedExpression],
-      nonCompleteAggregateExpressions: Seq[AggregateExpression2],
-      nonCompleteAggregateAttributes: Seq[Attribute],
-      completeAggregateExpressions: Seq[AggregateExpression2],
-      completeAggregateAttributes: Seq[Attribute],
-      initialInputBufferOffset: Int,
-      resultExpressions: Seq[NamedExpression],
-      newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
-      inputAttributes: Seq[Attribute],
-      inputIter: Iterator[InternalRow],
-      outputsUnsafeRows: Boolean): UnsafeHybridAggregationIterator = {
-    new UnsafeHybridAggregationIterator(
-      groupingExprs.map(_.toAttribute),
-      inputAttributes,
-      AggregationIterator.unsafeKVIterator(groupingExprs, inputAttributes, inputIter),
-      nonCompleteAggregateExpressions,
-      nonCompleteAggregateAttributes,
-      completeAggregateExpressions,
-      completeAggregateAttributes,
-      initialInputBufferOffset,
-      resultExpressions,
-      newMutableProjection,
-      outputsUnsafeRows)
-  }
-  // scalastyle:on
-}
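The mode rewrite in Step 6 of the deleted switchToSortBasedAggregation (kept,
in spirit, by the new TungstenAggregationIterator) is the heart of the fallback
contract: once raw input rows have been folded into aggregation buffers, the
sort-based phase must merge buffers instead of re-processing input. A
standalone sketch of that rewrite (hypothetical Mode ADT mirroring the modes of
AggregateExpression2; not part of this patch):

    sealed trait Mode
    case object Partial extends Mode
    case object PartialMerge extends Mode
    case object Complete extends Mode
    case object Final extends Mode

    def rewriteForFallback(mode: Mode): Mode = mode match {
      case Partial  => PartialMerge // inputs were already partially aggregated
      case Complete => Final        // complete-mode inputs became buffers too
      case other    => other        // Final and PartialMerge are unchanged
    }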

