Repository: spark
Updated Branches:
  refs/heads/master 66a99f4a4 -> 27daf6bcd


[SPARK-17949][SQL] A JVM object based aggregate operator

## What changes were proposed in this pull request?

This PR adds a new hash-based aggregate operator named 
`ObjectHashAggregateExec` that supports `TypedImperativeAggregate`, which may 
use arbitrary Java objects as aggregation states. Please refer to the [design 
doc](https://issues.apache.org/jira/secure/attachment/12834260/%5BDesign%20Doc%5D%20Support%20for%20Arbitrary%20Aggregation%20States.pdf)
 attached in [SPARK-17949](https://issues.apache.org/jira/browse/SPARK-17949) 
for more details about it.

The major benefit of this operator is better performance when evaluating 
`TypedImperativeAggregate` functions, especially when there are relatively few 
distinct groups. Functions like Hive UDAFs, `collect_list`, and `collect_set` 
may also benefit from this after being migrated to `TypedImperativeAggregate`.
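
As an illustration of the kind of query that benefits, here is a minimal sketch 
(assuming a `SparkSession` named `spark`; the table and bucket names are 
illustrative, not part of this PR) that exercises the new operator through the 
Spark native `percentile_approx`, which is implemented as a 
`TypedImperativeAggregate`:

```scala
// percentile_approx is a TypedImperativeAggregate, so with the feature flag
// enabled (the default) the planner picks ObjectHashAggregateExec for this query.
spark.range(1000000).createOrReplaceTempView("t")

val grouped = spark.sql(
  "SELECT CAST(id / 250000 AS BIGINT) AS bucket, percentile_approx(id, 0.5) AS median " +
  "FROM t GROUP BY CAST(id / 250000 AS BIGINT)")

grouped.explain()  // the physical plan should contain an ObjectHashAggregate node
grouped.show()
```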

The following feature flag is introduced to enable or disable the new aggregate 
operator:
- Name: `spark.sql.execution.useObjectHashAggregateExec`
- Default value: `true`

We can also configure the sort-based fallback threshold via the following SQL 
option (a short configuration sketch follows this list):
- Name: `spark.sql.objectHashAggregate.sortBased.fallbackThreshold`
- Default value: 128

  The operator falls back to sort-based aggregation once more than 128 distinct 
groups accumulate in the aggregation hash map. This number is intentionally kept 
small to avoid GC problems, since the aggregation buffers of this operator may 
contain arbitrary Java objects.

  This may be improved by implementing size tracking for this operator, but 
that can be done in a separate PR.
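
A minimal sketch of toggling these two settings (assuming a `SparkSession` named 
`spark`; the values shown are the defaults):

```scala
// Enable or disable the new operator (default: true).
spark.conf.set("spark.sql.execution.useObjectHashAggregateExec", "true")
// Row-count threshold for falling back to sort-based aggregation (default: 128).
spark.conf.set("spark.sql.objectHashAggregate.sortBased.fallbackThreshold", "128")

// The same settings can also be changed per session via SQL:
spark.sql("SET spark.sql.objectHashAggregate.sortBased.fallbackThreshold=128")
```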

Code generation and size tracking are planned to be implemented in follow-up 
PRs.
## Benchmark results
### `ObjectHashAggregateExec` vs `SortAggregateExec`

The first benchmark compares `ObjectHashAggregateExec` and `SortAggregateExec` 
by evaluating `typed_count`, a testing `TypedImperativeAggregate` version of 
the SQL `count` function.
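
The benchmark included in this PR wraps the test-only `TestingTypedCount` 
function into a `Column`; a sketch of the helper and how it is invoked (assuming 
a `SparkSession` named `spark`):

```scala
import org.apache.spark.sql.Column
import org.apache.spark.sql.hive.execution.TestingTypedCount

// Wraps the test-only TypedImperativeAggregate into a Column, as the benchmark does.
def typed_count(column: Column): Column =
  Column(TestingTypedCount(column.expr).toAggregateExpression())

import spark.implicits._

val N = 1024L * 1024 * 100
val df = spark.range(N)
df.groupBy($"id" < (N / 2)).agg(typed_count($"id")).collect()
```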

```
Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.5
Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz

object agg v.s. sort agg:                Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
sort agg w/ group by                        31251 / 31908          3.4         298.0       1.0X
object agg w/ group by w/o fallback           6903 / 7141         15.2          65.8       4.5X
object agg w/ group by w/ fallback          20945 / 21613          5.0         199.7       1.5X
sort agg w/o group by                         4734 / 5463         22.1          45.2       6.6X
object agg w/o group by w/o fallback          4310 / 4529         24.3          41.1       7.3X
```

The next benchmark compares `ObjectHashAggregateExec` and `SortAggregateExec` 
by evaluating the Spark native version of `percentile_approx`.

Note that `percentile_approx` is such a heavy aggregate function that the 
bottleneck of this benchmark is evaluating the aggregate function itself rather 
than the aggregate operator (I couldn't run a large-scale benchmark on my 
laptop). That's why the results are so close and look counter-intuitive: 
aggregation with grouping is even faster than aggregation without grouping.

```
Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.5
Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz

object agg v.s. sort agg:                Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
sort agg w/ group by                          3418 / 3530          0.6        1630.0       1.0X
object agg w/ group by w/o fallback           3210 / 3314          0.7        1530.7       1.1X
object agg w/ group by w/ fallback            3419 / 3511          0.6        1630.1       1.0X
sort agg w/o group by                         4336 / 4499          0.5        2067.3       0.8X
object agg w/o group by w/o fallback          4271 / 4372          0.5        2036.7       0.8X
```
### Hive UDAF vs Spark AF

This benchmark compares the following two kinds of aggregate functions:
- "hive udaf": Hive implementation of `percentile_approx`, without partial 
aggregation supports, evaluated using `SortAggregateExec`.
- "spark af": Spark native implementation of `percentile_approx`, with partial 
aggregation support, evaluated using `ObjectHashAggregateExec`

The performance differences are mostly due to the faster implementation and the 
partial aggregation support of the Spark native version of `percentile_approx`.

This benchmark basically shows the performance differences between the worst 
case, where an aggregate function without partial aggregation support is 
evaluated using `SortAggregateExec`, and the best case, where a 
`TypedImperativeAggregate` with partial aggregation support is evaluated using 
`ObjectHashAggregateExec`.
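
For reference, here is a sketch of the two code paths being compared (assuming a 
Hive-enabled `SparkSession` named `spark`). The benchmark in this PR registers 
the Hive UDAF through `HiveSessionCatalog`; the `CREATE TEMPORARY FUNCTION` 
statement below is an equivalent, illustrative way to do the same:

```scala
// Register Hive's GenericUDAFPercentileApprox as a temporary function.
spark.sql(
  "CREATE TEMPORARY FUNCTION hive_percentile_approx AS " +
  "'org.apache.hadoop.hive.ql.udf.generic.GenericUDAFPercentileApprox'")

spark.range(2 << 15).createOrReplaceTempView("t")

// Worst case: Hive UDAF, no partial aggregation, evaluated by SortAggregateExec.
spark.sql(
  "SELECT hive_percentile_approx(id, 0.5) FROM t GROUP BY CAST(id / 16384 AS BIGINT)").collect()

// Best case: Spark native TypedImperativeAggregate with partial aggregation,
// evaluated by ObjectHashAggregateExec.
spark.sql(
  "SELECT percentile_approx(id, 0.5) FROM t GROUP BY CAST(id / 16384 AS BIGINT)").collect()
```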

```
Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.5
Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz

hive udaf vs spark af:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
hive udaf w/o group by                        5326 / 5408          0.0       81264.2       1.0X
spark af w/o group by                           93 /  111          0.7        1415.6      57.4X
hive udaf w/ group by                         3804 / 3946          0.0       58050.1       1.4X
spark af w/ group by w/o fallback               71 /   90          0.9        1085.7      74.8X
spark af w/ group by w/ fallback                98 /  111          0.7        1501.6      54.1X
```
### Real-world benchmark

We also ran a relatively large benchmark using a real-world query involving 
`percentile_approx`:
- Hive UDAF implementation, sort-based aggregation, w/o partial aggregation 
support

  24.77 minutes
- Native implementation, sort-based aggregation, w/ partial aggregation support

  4.64 minutes
- Native implementation, object hash aggregator, w/ partial aggregation support

  1.80 minutes
## How was this patch tested?

New unit tests and randomized test cases are added in 
`ObjectHashAggregateSuite` and `SortBasedAggregationStoreSuite`.

Author: Cheng Lian <l...@databricks.com>

Closes #15590 from liancheng/obj-hash-agg.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/27daf6bc
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/27daf6bc
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/27daf6bc

Branch: refs/heads/master
Commit: 27daf6bcde782ed3e0f0d951c90c8040fd47e985
Parents: 66a99f4
Author: Cheng Lian <l...@databricks.com>
Authored: Thu Nov 3 09:34:51 2016 -0700
Committer: Yin Huai <yh...@databricks.com>
Committed: Thu Nov 3 09:34:51 2016 -0700

----------------------------------------------------------------------
 .../sql/execution/aggregate/AggUtils.scala      |  31 +-
 .../aggregate/ObjectAggregationIterator.scala   | 323 ++++++++++++++
 .../aggregate/ObjectAggregationMap.scala        | 110 +++++
 .../aggregate/ObjectHashAggregateExec.scala     | 155 +++++++
 .../org/apache/spark/sql/internal/SQLConf.scala |  22 +
 .../sql/TypedImperativeAggregateSuite.scala     |   6 +-
 .../SortBasedAggregationStoreSuite.scala        | 141 ++++++
 .../ObjectHashAggregateExecBenchmark.scala      | 230 ++++++++++
 .../execution/ObjectHashAggregateSuite.scala    | 433 +++++++++++++++++++
 .../sql/hive/execution/TestingTypedCount.scala  |  87 ++++
 10 files changed, 1527 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/27daf6bc/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index 4fbb9d5..3c8ef1a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, 
StateStoreSaveExec}
+import org.apache.spark.sql.internal.SQLConf
 
 /**
  * Utility functions used by the query planner to convert our plan to new 
aggregation code path.
@@ -66,14 +67,28 @@ object AggUtils {
         resultExpressions = resultExpressions,
         child = child)
     } else {
-      SortAggregateExec(
-        requiredChildDistributionExpressions = 
requiredChildDistributionExpressions,
-        groupingExpressions = groupingExpressions,
-        aggregateExpressions = aggregateExpressions,
-        aggregateAttributes = aggregateAttributes,
-        initialInputBufferOffset = initialInputBufferOffset,
-        resultExpressions = resultExpressions,
-        child = child)
+      val objectHashEnabled = child.sqlContext.conf.useObjectHashAggregation
+      val useObjectHash = 
ObjectHashAggregateExec.supportsAggregate(aggregateExpressions)
+
+      if (objectHashEnabled && useObjectHash) {
+        ObjectHashAggregateExec(
+          requiredChildDistributionExpressions = 
requiredChildDistributionExpressions,
+          groupingExpressions = groupingExpressions,
+          aggregateExpressions = aggregateExpressions,
+          aggregateAttributes = aggregateAttributes,
+          initialInputBufferOffset = initialInputBufferOffset,
+          resultExpressions = resultExpressions,
+          child = child)
+      } else {
+        SortAggregateExec(
+          requiredChildDistributionExpressions = 
requiredChildDistributionExpressions,
+          groupingExpressions = groupingExpressions,
+          aggregateExpressions = aggregateExpressions,
+          aggregateAttributes = aggregateAttributes,
+          initialInputBufferOffset = initialInputBufferOffset,
+          resultExpressions = resultExpressions,
+          child = child)
+      }
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/27daf6bc/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
new file mode 100644
index 0000000..3c7b9ee
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
@@ -0,0 +1,323 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.codegen.{BaseOrdering, 
GenerateOrdering}
+import org.apache.spark.sql.execution.UnsafeKVExternalSorter
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.unsafe.KVIterator
+import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
+
+class ObjectAggregationIterator(
+    outputAttributes: Seq[Attribute],
+    groupingExpressions: Seq[NamedExpression],
+    aggregateExpressions: Seq[AggregateExpression],
+    aggregateAttributes: Seq[Attribute],
+    initialInputBufferOffset: Int,
+    resultExpressions: Seq[NamedExpression],
+    newMutableProjection: (Seq[Expression], Seq[Attribute]) => 
MutableProjection,
+    originalInputAttributes: Seq[Attribute],
+    inputRows: Iterator[InternalRow],
+    fallbackCountThreshold: Int)
+  extends AggregationIterator(
+    groupingExpressions,
+    originalInputAttributes,
+    aggregateExpressions,
+    aggregateAttributes,
+    initialInputBufferOffset,
+    resultExpressions,
+    newMutableProjection) with Logging {
+
+  // Indicates whether we have fallen back to sort-based aggregation or not.
+  private[this] var sortBased: Boolean = false
+
+  private[this] var aggBufferIterator: Iterator[AggregationBufferEntry] = _
+
+  // Hacking the aggregation mode to call AggregateFunction.merge to merge two 
aggregation buffers
+  private val mergeAggregationBuffers: (InternalRow, InternalRow) => Unit = {
+    val newExpressions = aggregateExpressions.map {
+      case agg @ AggregateExpression(_, Partial, _, _) =>
+        agg.copy(mode = PartialMerge)
+      case agg @ AggregateExpression(_, Complete, _, _) =>
+        agg.copy(mode = Final)
+      case other => other
+    }
+    val newFunctions = initializeAggregateFunctions(newExpressions, 0)
+    val newInputAttributes = newFunctions.flatMap(_.inputAggBufferAttributes)
+    generateProcessRow(newExpressions, newFunctions, newInputAttributes)
+  }
+
+  // A safe projection used to deep-clone input rows to prevent false sharing.
+  private[this] val safeProjection: Projection =
+    FromUnsafeProjection(outputAttributes.map(_.dataType))
+
+  /**
+   * Start processing input rows.
+   */
+  processInputs()
+
+  override final def hasNext: Boolean = {
+    aggBufferIterator.hasNext
+  }
+
+  override final def next(): UnsafeRow = {
+    val entry = aggBufferIterator.next()
+    generateOutput(entry.groupingKey, entry.aggregationBuffer)
+  }
+
+  /**
+   * Generate an output row when there is no input and there is no grouping 
expression.
+   */
+  def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
+    if (groupingExpressions.isEmpty) {
+      val defaultAggregationBuffer = createNewAggregationBuffer()
+      generateOutput(UnsafeRow.createFromByteArray(0, 0), 
defaultAggregationBuffer)
+    } else {
+      throw new IllegalStateException(
+        "This method should not be called when groupingExpressions is not 
empty.")
+    }
+  }
+
+  // Creates a new aggregation buffer and initializes buffer values. This 
function should only be
+  // called under two cases:
+  //
+  //  - when creating aggregation buffer for a new group in the hash map, and
+  //  - when creating the re-used buffer for sort-based aggregation
+  private def createNewAggregationBuffer(): SpecificInternalRow = {
+    val bufferFieldTypes = 
aggregateFunctions.flatMap(_.aggBufferAttributes.map(_.dataType))
+    val buffer = new SpecificInternalRow(bufferFieldTypes)
+    initAggregationBuffer(buffer)
+    buffer
+  }
+
+  private def initAggregationBuffer(buffer: SpecificInternalRow): Unit = {
+    // Initializes declarative aggregates' buffer values
+    expressionAggInitialProjection.target(buffer)(EmptyRow)
+    // Initializes imperative aggregates' buffer values
+    aggregateFunctions.collect { case f: ImperativeAggregate => f 
}.foreach(_.initialize(buffer))
+  }
+
+  private def getAggregationBufferByKey(
+    hashMap: ObjectAggregationMap, groupingKey: UnsafeRow): InternalRow = {
+    var aggBuffer = hashMap.getAggregationBuffer(groupingKey)
+
+    if (aggBuffer == null) {
+      aggBuffer = createNewAggregationBuffer()
+      hashMap.putAggregationBuffer(groupingKey.copy(), aggBuffer)
+    }
+
+    aggBuffer
+  }
+
+  // This function is used to read and process input rows. When processing input rows, it first
+  // uses hash-based aggregation by putting groups and their buffers in `hashMap`. If `hashMap`
+  // grows too large, its contents are sorted and spilled into an external sorter, and the remaining
+  // input rows are fed into a second sorter. At last, the two sorted streams are merged together
+  // for sort-based aggregation.
+  private def processInputs(): Unit = {
+    // In-memory map to store aggregation buffer for hash-based aggregation.
+    val hashMap = new ObjectAggregationMap()
+
+    // If the in-memory map is unable to hold all aggregation buffers, fall back to sort-based
+    // aggregation backed by sorted physical storage.
+    var sortBasedAggregationStore: SortBasedAggregator = null
+
+    if (groupingExpressions.isEmpty) {
+      // If there are no grouping expressions, we can just reuse the same buffer
+      // over and over again.
+      val groupingKey = groupingProjection.apply(null)
+      val buffer: InternalRow = getAggregationBufferByKey(hashMap, groupingKey)
+      while (inputRows.hasNext) {
+        val newInput = safeProjection(inputRows.next())
+        processRow(buffer, newInput)
+      }
+    } else {
+      while (inputRows.hasNext && !sortBased) {
+        val newInput = safeProjection(inputRows.next())
+        val groupingKey = groupingProjection.apply(newInput)
+        val buffer: InternalRow = getAggregationBufferByKey(hashMap, 
groupingKey)
+        processRow(buffer, newInput)
+
+        // If the hash map gets too large, spill its sorted contents and fall back to
+        // sort-based aggregation.
+        if (hashMap.size >= fallbackCountThreshold) {
+          logInfo(
+            s"Aggregation hash map reaches threshold " +
+              s"capacity ($fallbackCountThreshold entries), spilling and 
falling back to sort" +
+              s" based aggregation. You may change the threshold by adjust 
option " +
+              SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key
+          )
+
+          // Falls back to sort-based aggregation
+          sortBased = true
+
+        }
+      }
+
+      if (sortBased) {
+        val sortIteratorFromHashMap = hashMap
+          .dumpToExternalSorter(groupingAttributes, aggregateFunctions)
+          .sortedIterator()
+        sortBasedAggregationStore = new SortBasedAggregator(
+          sortIteratorFromHashMap,
+          StructType.fromAttributes(originalInputAttributes),
+          StructType.fromAttributes(groupingAttributes),
+          processRow,
+          mergeAggregationBuffers,
+          createNewAggregationBuffer())
+
+        while (inputRows.hasNext) {
+          // NOTE: The input row is always UnsafeRow
+          val unsafeInputRow = inputRows.next().asInstanceOf[UnsafeRow]
+          val groupingKey = groupingProjection.apply(unsafeInputRow)
+          sortBasedAggregationStore.addInput(groupingKey, unsafeInputRow)
+        }
+      }
+    }
+
+    if (sortBased) {
+      aggBufferIterator = sortBasedAggregationStore.destructiveIterator()
+    } else {
+      aggBufferIterator = hashMap.iterator
+    }
+  }
+}
+
+/**
+ * A class used to handle sort-based aggregation, used together with 
[[ObjectHashAggregateExec]].
+ *
+ * @param initialAggBufferIterator iterator that points to sorted input 
aggregation buffers
+ * @param inputSchema  The schema of input row
+ * @param groupingSchema The schema of grouping key
+ * @param processRow  Function to update the aggregation buffer with input rows
+ * @param mergeAggregationBuffers Function used to merge the input aggregation 
buffers into existing
+ *                                aggregation buffers
+ * @param makeEmptyAggregationBuffer Creates an empty aggregation buffer
+ *
+ * @todo Try to eliminate this class by refactor and reuse code paths in 
[[SortAggregateExec]].
+ */
+class SortBasedAggregator(
+    initialAggBufferIterator: KVIterator[UnsafeRow, UnsafeRow],
+    inputSchema: StructType,
+    groupingSchema: StructType,
+    processRow: (InternalRow, InternalRow) => Unit,
+    mergeAggregationBuffers: (InternalRow, InternalRow) => Unit,
+    makeEmptyAggregationBuffer: => InternalRow) {
+
+  // External sorter used to sort the input (grouping key + input row) by grouping key.
+  private val inputSorter = createExternalSorterForInput()
+  private val groupingKeyOrdering: BaseOrdering = 
GenerateOrdering.create(groupingSchema)
+
+  def addInput(groupingKey: UnsafeRow, inputRow: UnsafeRow): Unit = {
+    inputSorter.insertKV(groupingKey, inputRow)
+  }
+
+  /**
+   * Returns a destructive iterator of AggregationBufferEntry.
+   * Notice: it is illegal to call any method after `destructiveIterator()` 
has been called.
+   */
+  def destructiveIterator(): Iterator[AggregationBufferEntry] = {
+    new Iterator[AggregationBufferEntry] {
+      val inputIterator = inputSorter.sortedIterator()
+      var hasNextInput: Boolean = inputIterator.next()
+      var hasNextAggBuffer: Boolean = initialAggBufferIterator.next()
+      private var result: AggregationBufferEntry = _
+      private var groupingKey: UnsafeRow = _
+
+      override def hasNext(): Boolean = {
+        result != null || findNextSortedGroup()
+      }
+
+      override def next(): AggregationBufferEntry = {
+        val returnResult = result
+        result = null
+        returnResult
+      }
+
+      // Two-way merges initialAggBufferIterator and inputIterator
+      private def findNextSortedGroup(): Boolean = {
+        if (hasNextInput || hasNextAggBuffer) {
+          // Find the smaller key between inputIterator and initialAggBufferIterator
+          groupingKey = findGroupingKey()
+          result = new AggregationBufferEntry(groupingKey, 
makeEmptyAggregationBuffer)
+
+          // Firstly, update the aggregation buffer with input rows.
+          while (hasNextInput &&
+            groupingKeyOrdering.compare(inputIterator.getKey, groupingKey) == 
0) {
+            processRow(result.aggregationBuffer, inputIterator.getValue)
+            hasNextInput = inputIterator.next()
+          }
+
+          // Secondly, merge the aggregation buffer with existing aggregation 
buffers.
+          // NOTE: the ordering of these two while-blocks matters; mergeAggregationBuffers() should
+          // be called after calling processRow.
+          while (hasNextAggBuffer &&
+            groupingKeyOrdering.compare(initialAggBufferIterator.getKey, 
groupingKey) == 0) {
+            mergeAggregationBuffers(result.aggregationBuffer, 
initialAggBufferIterator.getValue)
+            hasNextAggBuffer = initialAggBufferIterator.next()
+          }
+
+          true
+        } else {
+          false
+        }
+      }
+
+      private def findGroupingKey(): UnsafeRow = {
+        var newGroupingKey: UnsafeRow = null
+        if (!hasNextInput) {
+          newGroupingKey = initialAggBufferIterator.getKey
+        } else if (!hasNextAggBuffer) {
+          newGroupingKey = inputIterator.getKey
+        } else {
+          val compareResult =
+            groupingKeyOrdering.compare(inputIterator.getKey, 
initialAggBufferIterator.getKey)
+          if (compareResult <= 0) {
+            newGroupingKey = inputIterator.getKey
+          } else {
+            newGroupingKey = initialAggBufferIterator.getKey
+          }
+        }
+
+        if (groupingKey == null) {
+          groupingKey = newGroupingKey.copy()
+        } else {
+          groupingKey.copyFrom(newGroupingKey)
+        }
+        groupingKey
+      }
+    }
+  }
+
+  private def createExternalSorterForInput(): UnsafeKVExternalSorter = {
+    new UnsafeKVExternalSorter(
+      groupingSchema,
+      inputSchema,
+      SparkEnv.get.blockManager,
+      SparkEnv.get.serializerManager,
+      TaskContext.get().taskMemoryManager().pageSizeBytes,
+      SparkEnv.get.conf.getLong(
+        "spark.shuffle.spill.numElementsForceSpillThreshold",
+        UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD),
+      null
+    )
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/27daf6bc/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala
new file mode 100644
index 0000000..f2d4f6c
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import java.{util => ju}
+
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection, 
UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, 
TypedImperativeAggregate}
+import org.apache.spark.sql.execution.UnsafeKVExternalSorter
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
+
+/**
+ * An aggregation map that supports using safe `SpecificInternalRow`s as aggregation buffers, so
+ * that we can store arbitrary Java objects as aggregate function states in the aggregation
+ * buffers. This class is only used together with [[ObjectHashAggregateExec]].
+ */
+class ObjectAggregationMap() {
+  private[this] val hashMap = new ju.LinkedHashMap[UnsafeRow, InternalRow]
+
+  def getAggregationBuffer(groupingKey: UnsafeRow): InternalRow = {
+    hashMap.get(groupingKey)
+  }
+
+  def putAggregationBuffer(groupingKey: UnsafeRow, aggBuffer: InternalRow): 
Unit = {
+    hashMap.put(groupingKey, aggBuffer)
+  }
+
+  def size: Int = hashMap.size()
+
+  def iterator: Iterator[AggregationBufferEntry] = {
+    val iter = hashMap.entrySet().iterator()
+    new Iterator[AggregationBufferEntry] {
+
+      override def hasNext: Boolean = {
+        iter.hasNext
+      }
+      override def next(): AggregationBufferEntry = {
+        val entry = iter.next()
+        new AggregationBufferEntry(entry.getKey, entry.getValue)
+      }
+    }
+  }
+
+  /**
+   * Dumps all entries into a newly created external sorter, clears the hash 
map, and returns the
+   * external sorter.
+   */
+  def dumpToExternalSorter(
+      groupingAttributes: Seq[Attribute],
+      aggregateFunctions: Seq[AggregateFunction]): UnsafeKVExternalSorter = {
+    val aggBufferAttributes = aggregateFunctions.flatMap(_.aggBufferAttributes)
+    val sorter = new UnsafeKVExternalSorter(
+      StructType.fromAttributes(groupingAttributes),
+      StructType.fromAttributes(aggBufferAttributes),
+      SparkEnv.get.blockManager,
+      SparkEnv.get.serializerManager,
+      TaskContext.get().taskMemoryManager().pageSizeBytes,
+      SparkEnv.get.conf.getLong(
+        "spark.shuffle.spill.numElementsForceSpillThreshold",
+        UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD),
+      null
+    )
+
+    val mapIterator = iterator
+    val unsafeAggBufferProjection =
+      UnsafeProjection.create(aggBufferAttributes.map(_.dataType).toArray)
+
+    while (mapIterator.hasNext) {
+      val entry = mapIterator.next()
+      aggregateFunctions.foreach {
+        case agg: TypedImperativeAggregate[_] =>
+          agg.serializeAggregateBufferInPlace(entry.aggregationBuffer)
+        case _ =>
+      }
+
+      sorter.insertKV(
+        entry.groupingKey,
+        unsafeAggBufferProjection(entry.aggregationBuffer)
+      )
+    }
+
+    hashMap.clear()
+    sorter
+  }
+
+  def clear(): Unit = {
+    hashMap.clear()
+  }
+}
+
+// Stores the grouping key and aggregation buffer
+class AggregationBufferEntry(var groupingKey: UnsafeRow, var 
aggregationBuffer: InternalRow)

http://git-wip-us.apache.org/repos/asf/spark/blob/27daf6bc/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
new file mode 100644
index 0000000..3fcb7ec
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.util.Utils
+
+/**
+ * A hash-based aggregate operator that supports [[TypedImperativeAggregate]] 
functions that may
+ * use arbitrary JVM objects as aggregation states.
+ *
+ * Similar to [[HashAggregateExec]], this operator also falls back to 
sort-based aggregation when
+ * the size of the internal hash map exceeds the threshold. The differences 
are:
+ *
+ *  - It uses safe rows as aggregation buffer since it must support JVM 
objects as aggregation
+ *    states.
+ *
+ *  - It tracks the entry count of the hash map instead of the byte size to decide when to fall
+ *    back. This is because it's hard to estimate the accurate size of arbitrary JVM objects in a
+ *    lightweight way.
+ *
+ *  - Once it falls back to sort-based aggregation, this operator feeds all of the remaining input
+ *    rows into external sorters instead of building more hash maps as [[HashAggregateExec]] does.
+ *    This is because keeping too many JVM object aggregation states around can be dangerous for GC.
+ *
+ *  - CodeGen is not supported yet.
+ *
+ * This operator may be turned off by setting the following SQL configuration 
to `false`:
+ * {{{
+ *   spark.sql.execution.useObjectHashAggregateExec
+ * }}}
+ * The fallback threshold can be configured by tuning:
+ * {{{
+ *   spark.sql.objectHashAggregate.sortBased.fallbackThreshold
+ * }}}
+ */
+case class ObjectHashAggregateExec(
+    requiredChildDistributionExpressions: Option[Seq[Expression]],
+    groupingExpressions: Seq[NamedExpression],
+    aggregateExpressions: Seq[AggregateExpression],
+    aggregateAttributes: Seq[Attribute],
+    initialInputBufferOffset: Int,
+    resultExpressions: Seq[NamedExpression],
+    child: SparkPlan)
+  extends UnaryExecNode {
+
+  private[this] val aggregateBufferAttributes = {
+    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+  }
+
+  override lazy val allAttributes: AttributeSeq =
+    child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++
+      
aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
+
+  override lazy val metrics = Map(
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output 
rows")
+  )
+
+  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+  override def producedAttributes: AttributeSet =
+    AttributeSet(aggregateAttributes) ++
+    
AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
+    AttributeSet(aggregateBufferAttributes)
+
+  override def requiredChildDistribution: List[Distribution] = {
+    requiredChildDistributionExpressions match {
+      case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
+      case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
+      case None => UnspecifiedDistribution :: Nil
+    }
+  }
+
+  protected override def doExecute(): RDD[InternalRow] = attachTree(this, 
"execute") {
+    val numOutputRows = longMetric("numOutputRows")
+    val fallbackCountThreshold = 
sqlContext.conf.objectAggSortBasedFallbackThreshold
+
+    child.execute().mapPartitionsInternal { iter =>
+      val hasInput = iter.hasNext
+      if (!hasInput && groupingExpressions.nonEmpty) {
+        // This is a grouped aggregate and the input kvIterator is empty,
+        // so return an empty kvIterator.
+        Iterator.empty
+      } else {
+        val aggregationIterator =
+          new ObjectAggregationIterator(
+            child.output,
+            groupingExpressions,
+            aggregateExpressions,
+            aggregateAttributes,
+            initialInputBufferOffset,
+            resultExpressions,
+            (expressions, inputSchema) =>
+              newMutableProjection(expressions, inputSchema, 
subexpressionEliminationEnabled),
+            child.output,
+            iter,
+            fallbackCountThreshold)
+        if (!hasInput && groupingExpressions.isEmpty) {
+          numOutputRows += 1
+          
Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput())
+        } else {
+          aggregationIterator
+        }
+      }
+    }
+  }
+
+  override def verboseString: String = toString(verbose = true)
+
+  override def simpleString: String = toString(verbose = false)
+
+  private def toString(verbose: Boolean): String = {
+    val allAggregateExpressions = aggregateExpressions
+    val keyString = Utils.truncatedString(groupingExpressions, "[", ", ", "]")
+    val functionString = Utils.truncatedString(allAggregateExpressions, "[", 
", ", "]")
+    val outputString = Utils.truncatedString(output, "[", ", ", "]")
+    if (verbose) {
+      s"ObjectHashAggregate(keys=$keyString, functions=$functionString, 
output=$outputString)"
+    } else {
+      s"ObjectHashAggregate(keys=$keyString, functions=$functionString)"
+    }
+  }
+}
+
+object ObjectHashAggregateExec {
+  def supportsAggregate(aggregateExpressions: Seq[AggregateExpression]): 
Boolean = {
+    aggregateExpressions.map(_.aggregateFunction).exists {
+      case _: TypedImperativeAggregate[_] => true
+      case _ => false
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/27daf6bc/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 7b8ed65..71f3a67 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -526,6 +526,24 @@ object SQLConf {
       .stringConf
       .createWithDefault(classOf[ManifestFileCommitProtocol].getName)
 
+  val OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD =
+    
SQLConfigBuilder("spark.sql.objectHashAggregate.sortBased.fallbackThreshold")
+      .internal()
+      .doc("In the case of ObjectHashAggregateExec, when the size of the 
in-memory hash map " +
+        "grows too large, we will fall back to sort-based aggregation. This 
option sets a row " +
+        "count threshold for the size of the hash map.")
+      .intConf
+      // We are trying to be conservative and use a relatively small default 
count threshold here
+      // since the state object of some TypedImperativeAggregate function can 
be quite large (e.g.
+      // percentile_approx).
+      .createWithDefault(128)
+
+  val USE_OBJECT_HASH_AGG = 
SQLConfigBuilder("spark.sql.execution.useObjectHashAggregateExec")
+    .internal()
+    .doc("Decides if we use ObjectHashAggregateExec")
+    .booleanConf
+    .createWithDefault(true)
+
   val FILE_SINK_LOG_DELETION = 
SQLConfigBuilder("spark.sql.streaming.fileSink.log.deletion")
     .internal()
     .doc("Whether to delete the expired log files in file stream sink.")
@@ -769,6 +787,10 @@ private[sql] class SQLConf extends Serializable with 
CatalystConf with Logging {
 
   def enableTwoLevelAggMap: Boolean = getConf(ENABLE_TWOLEVEL_AGG_MAP)
 
+  def useObjectHashAggregation: Boolean = getConf(USE_OBJECT_HASH_AGG)
+
+  def objectAggSortBasedFallbackThreshold: Int = 
getConf(OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD)
+
   def variableSubstituteEnabled: Boolean = getConf(VARIABLE_SUBSTITUTE_ENABLED)
 
   def variableSubstituteDepth: Int = getConf(VARIABLE_SUBSTITUTE_DEPTH)

http://git-wip-us.apache.org/repos/asf/spark/blob/27daf6bc/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
index ffa26f1..0759915 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
@@ -23,7 +23,7 @@ import 
org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, 
GenericInternalRow, SpecificInternalRow}
 import 
org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
-import org.apache.spark.sql.execution.aggregate.SortAggregateExec
+import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
@@ -87,11 +87,11 @@ class TypedImperativeAggregateSuite extends QueryTest with 
SharedSQLContext {
 
   test("dataframe aggregate with object aggregate buffer, should not use 
HashAggregate") {
     val df = data.toDF("a", "b")
-    val max = new TypedMax($"a".expr)
+    val max = TypedMax($"a".expr)
 
     // Always uses SortAggregateExec
     val sparkPlan = 
df.select(Column(max.toAggregateExpression())).queryExecution.sparkPlan
-    assert(sparkPlan.isInstanceOf[SortAggregateExec])
+    assert(!sparkPlan.isInstanceOf[HashAggregateExec])
   }
 
   test("dataframe aggregate with object aggregate buffer, no group by") {

http://git-wip-us.apache.org/repos/asf/spark/blob/27daf6bc/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala
new file mode 100644
index 0000000..bc9cb6e
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import java.util.Properties
+
+import scala.collection.mutable
+
+import org.apache.spark._
+import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+import org.apache.spark.unsafe.KVIterator
+
+class SortBasedAggregationStoreSuite  extends SparkFunSuite with 
LocalSparkContext {
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    val conf = new SparkConf()
+    sc = new SparkContext("local[2, 4]", "test", conf)
+    val taskManager = new TaskMemoryManager(new TestMemoryManager(conf), 0)
+    TaskContext.setTaskContext(new TaskContextImpl(0, 0, 0, 0, taskManager, 
new Properties, null))
+  }
+
+  override def afterAll(): Unit = TaskContext.unset()
+
+  private val rand = new java.util.Random()
+
+  // In this test, the aggregator is XOR checksum.
+  test("merge input kv iterator and aggregation buffer iterator") {
+
+    val inputSchema = StructType(Seq(StructField("a", IntegerType), 
StructField("b", IntegerType)))
+    val groupingSchema = StructType(Seq(StructField("b", IntegerType)))
+
+    // Schema: a: Int, b: Int
+    val inputRow: UnsafeRow = createUnsafeRow(2)
+
+    // Schema: group: Int
+    val group: UnsafeRow = createUnsafeRow(1)
+
+    val expected = new mutable.HashMap[Int, Int]()
+    val hashMap = new ObjectAggregationMap
+    (0 to 5000).foreach { _ =>
+      randomKV(inputRow, group)
+
+      // XOR aggregate on first column of input row
+      expected.put(group.getInt(0), expected.getOrElse(group.getInt(0), 0) ^ 
inputRow.getInt(0))
+      if (hashMap.getAggregationBuffer(group) == null) {
+        hashMap.putAggregationBuffer(group.copy, createNewAggregationBuffer())
+      }
+      updateInputRow(hashMap.getAggregationBuffer(group), inputRow)
+    }
+
+    val store = new SortBasedAggregator(
+      createSortedAggBufferIterator(hashMap),
+      inputSchema,
+      groupingSchema,
+      updateInputRow,
+      mergeAggBuffer,
+      createNewAggregationBuffer)
+
+    (5000 to 100000).foreach { _ =>
+      randomKV(inputRow, group)
+      // XOR aggregate on first column of input row
+      expected.put(group.getInt(0), expected.getOrElse(group.getInt(0), 0) ^ 
inputRow.getInt(0))
+      store.addInput(group, inputRow)
+    }
+
+    val iter = store.destructiveIterator()
+    while(iter.hasNext) {
+      val agg = iter.next()
+      assert(agg.aggregationBuffer.getInt(0) == 
expected(agg.groupingKey.getInt(0)))
+    }
+  }
+
+  private def createNewAggregationBuffer(): InternalRow = {
+    val buffer = createUnsafeRow(1)
+    buffer.setInt(0, 0)
+    buffer
+  }
+
+  private def updateInputRow: (InternalRow, InternalRow) => Unit = {
+    (buffer: InternalRow, input: InternalRow) => {
+      buffer.setInt(0, buffer.getInt(0) ^ input.getInt(0))
+    }
+  }
+
+  private def mergeAggBuffer: (InternalRow, InternalRow) => Unit = 
updateInputRow
+
+  private def createUnsafeRow(numOfField: Int): UnsafeRow = {
+    val buffer: Array[Byte] = new Array(1024)
+    val row: UnsafeRow = new UnsafeRow(numOfField)
+    row.pointTo(buffer, 1024)
+    row
+  }
+
+  private def randomKV(inputRow: UnsafeRow, group: UnsafeRow): Unit = {
+    inputRow.setInt(0, rand.nextInt(100000))
+    inputRow.setInt(1, rand.nextInt(10000))
+    group.setInt(0, inputRow.getInt(1) % 100)
+  }
+
+  def createSortedAggBufferIterator(
+      hashMap: ObjectAggregationMap): KVIterator[UnsafeRow, UnsafeRow] = {
+
+    val sortedIterator = 
hashMap.iterator.toList.sortBy(_.groupingKey.getInt(0)).iterator
+    new KVIterator[UnsafeRow, UnsafeRow] {
+      var key: UnsafeRow = null
+      var value: UnsafeRow = null
+      override def next: Boolean = {
+        if (sortedIterator.hasNext) {
+          val kv = sortedIterator.next()
+          key = kv.groupingKey
+          value = kv.aggregationBuffer.asInstanceOf[UnsafeRow]
+          true
+        } else {
+          false
+        }
+      }
+      override def getKey(): UnsafeRow = key
+      override def getValue(): UnsafeRow = value
+      override def close(): Unit = Unit
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/27daf6bc/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala
 
b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala
new file mode 100644
index 0000000..197110f
--- /dev/null
+++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.benchmark
+
+import scala.concurrent.duration._
+
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFPercentileApprox
+
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.expressions.{ExpressionInfo, Literal}
+import 
org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile
+import org.apache.spark.sql.hive.HiveSessionCatalog
+import org.apache.spark.sql.hive.execution.TestingTypedCount
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.LongType
+import org.apache.spark.util.Benchmark
+
+class ObjectHashAggregateExecBenchmark extends BenchmarkBase with 
TestHiveSingleton {
+  ignore("Hive UDAF vs Spark AF") {
+    val N = 2 << 15
+
+    val benchmark = new Benchmark(
+      name = "hive udaf vs spark af",
+      valuesPerIteration = N,
+      minNumIters = 5,
+      warmupTime = 5.seconds,
+      minTime = 10.seconds,
+      outputPerIteration = true
+    )
+
+    registerHiveFunction("hive_percentile_approx", 
classOf[GenericUDAFPercentileApprox])
+
+    sparkSession.range(N).createOrReplaceTempView("t")
+
+    benchmark.addCase("hive udaf w/o group by") { _ =>
+      sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "false")
+      sparkSession.sql("SELECT hive_percentile_approx(id, 0.5) FROM 
t").collect()
+    }
+
+    benchmark.addCase("spark af w/o group by") { _ =>
+      sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "true")
+      sparkSession.sql("SELECT percentile_approx(id, 0.5) FROM t").collect()
+    }
+
+    benchmark.addCase("hive udaf w/ group by") { _ =>
+      sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "false")
+      sparkSession.sql(
+        s"SELECT hive_percentile_approx(id, 0.5) FROM t GROUP BY CAST(id / ${N 
/ 4} AS BIGINT)"
+      ).collect()
+    }
+
+    benchmark.addCase("spark af w/ group by w/o fallback") { _ =>
+      sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "true")
+      sparkSession.sql(
+        s"SELECT percentile_approx(id, 0.5) FROM t GROUP BY CAST(id / ${N / 4} 
AS BIGINT)"
+      ).collect()
+    }
+
+    benchmark.addCase("spark af w/ group by w/ fallback") { _ =>
+      sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "true")
+      
sparkSession.conf.set(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key, "2")
+      sparkSession.sql(
+        s"SELECT percentile_approx(id, 0.5) FROM t GROUP BY CAST(id / ${N / 4} 
AS BIGINT)"
+      ).collect()
+    }
+
+    benchmark.run()
+
+    /*
+    Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.5
+    Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
+
+    hive udaf vs spark af:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    ------------------------------------------------------------------------------------------------
+    hive udaf w/o group by                        5326 / 5408          0.0       81264.2       1.0X
+    spark af w/o group by                           93 /  111          0.7        1415.6      57.4X
+    hive udaf w/ group by                         3804 / 3946          0.0       58050.1       1.4X
+    spark af w/ group by w/o fallback               71 /   90          0.9        1085.7      74.8X
+    spark af w/ group by w/ fallback                98 /  111          0.7        1501.6      54.1X
+     */
+  }
+
+  ignore("ObjectHashAggregateExec vs SortAggregateExec - typed_count") {
+    val N: Long = 1024 * 1024 * 100
+
+    val benchmark = new Benchmark(
+      name = "object agg v.s. sort agg",
+      valuesPerIteration = N,
+      minNumIters = 1,
+      warmupTime = 10.seconds,
+      minTime = 45.seconds,
+      outputPerIteration = true
+    )
+
+    import sparkSession.implicits._
+
+    def typed_count(column: Column): Column =
+      Column(TestingTypedCount(column.expr).toAggregateExpression())
+
+    val df = sparkSession.range(N)
+
+    benchmark.addCase("sort agg w/ group by") { _ =>
+      sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "false")
+      df.groupBy($"id" < (N / 2)).agg(typed_count($"id")).collect()
+    }
+
+    benchmark.addCase("object agg w/ group by w/o fallback") { _ =>
+      sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "true")
+      df.groupBy($"id" < (N / 2)).agg(typed_count($"id")).collect()
+    }
+
+    benchmark.addCase("object agg w/ group by w/ fallback") { _ =>
+      sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "true")
+      
sparkSession.conf.set(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key, "2")
+      df.groupBy($"id" < (N / 2)).agg(typed_count($"id")).collect()
+    }
+
+    benchmark.addCase("sort agg w/o group by") { _ =>
+      sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "false")
+      df.select(typed_count($"id")).collect()
+    }
+
+    benchmark.addCase("object agg w/o group by w/o fallback") { _ =>
+      sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "true")
+      df.select(typed_count($"id")).collect()
+    }
+
+    benchmark.run()
+
+    /*
+    Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.5
+    Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
+
+    object agg v.s. sort agg:                Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    ------------------------------------------------------------------------------------------------
+    sort agg w/ group by                        31251 / 31908          3.4         298.0       1.0X
+    object agg w/ group by w/o fallback           6903 / 7141         15.2          65.8       4.5X
+    object agg w/ group by w/ fallback          20945 / 21613          5.0         199.7       1.5X
+    sort agg w/o group by                         4734 / 5463         22.1          45.2       6.6X
+    object agg w/o group by w/o fallback          4310 / 4529         24.3          41.1       7.3X
+     */
+  }
+
+  ignore("ObjectHashAggregateExec vs SortAggregateExec - percentile_approx") {
+    val N = 2 << 20
+
+    val benchmark = new Benchmark(
+      name = "object agg v.s. sort agg",
+      valuesPerIteration = N,
+      minNumIters = 5,
+      warmupTime = 15.seconds,
+      minTime = 45.seconds,
+      outputPerIteration = true
+    )
+
+    import sparkSession.implicits._
+
+    val df = sparkSession.range(N).coalesce(1)
+
+    benchmark.addCase("sort agg w/ group by") { _ =>
+      sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "false")
+      df.groupBy($"id" / (N / 4) cast LongType).agg(percentile_approx($"id", 
0.5)).collect()
+    }
+
+    benchmark.addCase("object agg w/ group by w/o fallback") { _ =>
+      sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "true")
+      df.groupBy($"id" / (N / 4) cast LongType).agg(percentile_approx($"id", 
0.5)).collect()
+    }
+
+    benchmark.addCase("object agg w/ group by w/ fallback") { _ =>
+      sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "true")
+      
sparkSession.conf.set(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key, "2")
+      df.groupBy($"id" / (N / 4) cast LongType).agg(percentile_approx($"id", 
0.5)).collect()
+    }
+
+    benchmark.addCase("sort agg w/o group by") { _ =>
+      sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "false")
+      df.select(percentile_approx($"id", 0.5)).collect()
+    }
+
+    benchmark.addCase("object agg w/o group by w/o fallback") { _ =>
+      sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "true")
+      df.select(percentile_approx($"id", 0.5)).collect()
+    }
+
+    benchmark.run()
+
+    /*
+    Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.5
+    Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
+
+    object agg v.s. sort agg:                Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    ------------------------------------------------------------------------------------------------
+    sort agg w/ group by                          3418 / 3530          0.6        1630.0       1.0X
+    object agg w/ group by w/o fallback           3210 / 3314          0.7        1530.7       1.1X
+    object agg w/ group by w/ fallback            3419 / 3511          0.6        1630.1       1.0X
+    sort agg w/o group by                         4336 / 4499          0.5        2067.3       0.8X
+    object agg w/o group by w/o fallback          4271 / 4372          0.5        2036.7       0.8X
+     */
+  }
+
+  private def registerHiveFunction(functionName: String, clazz: Class[_]): 
Unit = {
+    val sessionCatalog = 
sparkSession.sessionState.catalog.asInstanceOf[HiveSessionCatalog]
+    val builder = sessionCatalog.makeFunctionBuilder(functionName, 
clazz.getName)
+    val info = new ExpressionInfo(clazz.getName, functionName)
+    sessionCatalog.createTempFunction(functionName, info, builder, 
ignoreIfExists = false)
+  }
+
+  private def percentile_approx(
+      column: Column, percentage: Double, isDistinct: Boolean = false): Column 
= {
+    val approxPercentile = new ApproximatePercentile(column.expr, 
Literal(percentage))
+    Column(approxPercentile.toAggregateExpression(isDistinct))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/27daf6bc/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
new file mode 100644
index 0000000..527626b
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala
@@ -0,0 +1,433 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import scala.util.Random
+
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax
+import org.scalatest.Matchers._
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction
+import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, ExpressionInfo, Literal}
+import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.hive.HiveSessionCatalog
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.types._
+
+class ObjectHashAggregateSuite
+  extends QueryTest
+  with SQLTestUtils
+  with TestHiveSingleton
+  with ExpressionEvalHelper {
+
+  import testImplicits._
+
+  test("typed_count without grouping keys") {
+    val df = Seq((1: Integer, 2), (null, 2), (3: Integer, 4)).toDF("a", "b")
+
+    checkAnswer(
+      df.coalesce(1).select(typed_count($"a")),
+      Seq(Row(2))
+    )
+  }
+
+  test("typed_count without grouping keys and empty input") {
+    val df = Seq.empty[(Integer, Int)].toDF("a", "b")
+
+    checkAnswer(
+      df.coalesce(1).select(typed_count($"a")),
+      Seq(Row(0))
+    )
+  }
+
+  test("typed_count with grouping keys") {
+    val df = Seq((1: Integer, 1), (null, 1), (2: Integer, 2)).toDF("a", "b")
+
+    checkAnswer(
+      df.coalesce(1).groupBy($"b").agg(typed_count($"a")),
+      Seq(
+        Row(1, 1),
+        Row(2, 1))
+    )
+  }
+
+  test("typed_count fallback to sort-based aggregation") {
+    withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "2") {
+      val df = Seq(
+        (null, 1),
+        (null, 1),
+        (1: Integer, 1),
+        (2: Integer, 2),
+        (2: Integer, 2),
+        (2: Integer, 2)
+      ).toDF("a", "b")
+
+      checkAnswer(
+        df.coalesce(1).groupBy($"b").agg(typed_count($"a")),
+        Seq(Row(1, 1), Row(2, 3))
+      )
+    }
+  }
+
+  test("random input data types") {
+    val dataTypes = Seq(
+      // Integral types
+      ByteType, ShortType, IntegerType, LongType,
+
+      // Fractional types
+      FloatType, DoubleType,
+
+      // Decimal types
+      DecimalType(25, 5), DecimalType(6, 5),
+
+      // Datetime types
+      DateType, TimestampType,
+
+      // Complex types
+      ArrayType(IntegerType),
+      MapType(DoubleType, LongType),
+      new StructType()
+        .add("f1", FloatType, nullable = true)
+        .add("f2", ArrayType(BooleanType), nullable = true),
+
+      // UDT
+      new UDT.MyDenseVectorUDT(),
+
+      // Others
+      StringType,
+      BinaryType, NullType, BooleanType
+    )
+
+    dataTypes.sliding(2, 1).map(_.toSeq).foreach { dataTypes =>
+      // Schema used to generate random input data.
+      val schemaForGenerator = StructType(dataTypes.zipWithIndex.map {
+        case (fieldType, index) =>
+          StructField(s"col_$index", fieldType, nullable = true)
+      })
+
+      // Schema of the DataFrame to be tested.
+      val schema = StructType(
+        StructField("id", IntegerType, nullable = false) +: 
schemaForGenerator.fields
+      )
+
+      logInfo(s"Testing schema:\n${schema.treeString}")
+
+      // Creates a DataFrame for the schema with random data.
+      val data = generateRandomRows(schemaForGenerator)
+      val df = spark.createDataFrame(spark.sparkContext.parallelize(data, 1), schema)
+      val aggFunctions = schema.fieldNames.map(f => typed_count(col(f)))
+
+      checkAnswer(
+        df.agg(aggFunctions.head, aggFunctions.tail: _*),
+        Row.fromSeq(data.map(_.toSeq).transpose.map(_.count(_ != null): Long))
+      )
+
+      checkAnswer(
+        df.groupBy($"id" % 4 as 'mod).agg(aggFunctions.head, 
aggFunctions.tail: _*),
+        data.groupBy(_.getInt(0) % 4).map { case (key, value) =>
+          key -> Row.fromSeq(value.map(_.toSeq).transpose.map(_.count(_ != null): Long))
+        }.toSeq.map {
+          case (key, value) => Row.fromSeq(key +: value.toSeq)
+        }
+      )
+
+      withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "5") {
+        checkAnswer(
+          df.agg(aggFunctions.head, aggFunctions.tail: _*),
+          Row.fromSeq(data.map(_.toSeq).transpose.map(_.count(_ != null): Long))
+        )
+      }
+    }
+  }
+
+  private def percentile_approx(
+      column: Column, percentage: Double, isDistinct: Boolean = false): Column = {
+    val approxPercentile = new ApproximatePercentile(column.expr, Literal(percentage))
+    Column(approxPercentile.toAggregateExpression(isDistinct))
+  }
+
+  private def typed_count(column: Column): Column =
+    Column(TestingTypedCount(column.expr).toAggregateExpression())
+
+  // Generates 50 random rows for a given schema.
+  private def generateRandomRows(schemaForGenerator: StructType): Seq[Row] = {
+    val dataGenerator = RandomDataGenerator.forType(
+      dataType = schemaForGenerator,
+      nullable = true,
+      new Random(System.nanoTime())
+    ).getOrElse {
+      fail(s"Failed to create data generator for schema $schemaForGenerator")
+    }
+
+    (1 to 50).map { i =>
+      dataGenerator() match {
+        case row: Row => Row.fromSeq(i +: row.toSeq)
+        case null => Row.fromSeq(i +: Seq.fill(schemaForGenerator.length)(null))
+        case other => fail(
+          s"Row or null is expected to be generated, " +
+            s"but a ${other.getClass.getCanonicalName} is generated."
+        )
+      }
+    }
+  }
+
+  makeRandomizedTests()
+
+  private def makeRandomizedTests(): Unit = {
+    // A TypedImperativeAggregate function
+    val typed = percentile_approx($"c0", 0.5)
+
+    // A Hive UDAF without partial aggregation support
+    val withoutPartial = {
+      registerHiveFunction("hive_max", classOf[GenericUDAFMax])
+      function("hive_max", $"c1")
+    }
+
+    // A Spark SQL native aggregate function with partial aggregation support that can be executed
+    // by the Tungsten `HashAggregateExec`
+    val withPartialUnsafe = max($"c2")
+
+    // A Spark SQL native aggregate function with partial aggregation support that cannot be
+    // executed by the Tungsten `HashAggregateExec`
+    val withPartialSafe = max($"c3")
+
+    // A Spark SQL native distinct aggregate function
+    val withDistinct = countDistinct($"c4")
+
+    val allAggs = Seq(
+      "typed" -> typed,
+      "without partial" -> withoutPartial,
+      "with partial + unsafe" -> withPartialUnsafe,
+      "with partial + safe" -> withPartialSafe,
+      "with distinct" -> withDistinct
+    )
+
+    val builtinNumericTypes = Seq(
+      // Integral types
+      ByteType, ShortType, IntegerType, LongType,
+
+      // Fractional types
+      FloatType, DoubleType
+    )
+
+    val numericTypes = builtinNumericTypes ++ Seq(
+      // Decimal types
+      DecimalType(25, 5), DecimalType(6, 5)
+    )
+
+    val dateTimeTypes = Seq(DateType, TimestampType)
+
+    val arrayType = ArrayType(IntegerType)
+
+    val structType = new StructType()
+      .add("f1", FloatType, nullable = true)
+      .add("f2", ArrayType(BooleanType), nullable = true)
+
+    val mapType = MapType(DoubleType, LongType)
+
+    val complexTypes = Seq(arrayType, mapType, structType)
+
+    val orderedComplexType = Seq(arrayType, structType)
+
+    val orderedTypes = numericTypes ++ dateTimeTypes ++ orderedComplexType ++ Seq(
+      StringType, BinaryType, NullType, BooleanType
+    )
+
+    val udt = new UDT.MyDenseVectorUDT()
+
+    val fixedLengthTypes = builtinNumericTypes ++ Seq(BooleanType, NullType)
+
+    val varLenTypes = complexTypes ++ Seq(StringType, BinaryType, udt)
+
+    val varLenOrderedTypes = varLenTypes.intersect(orderedTypes)
+
+    val allTypes = orderedTypes :+ udt
+
+    val seed = System.nanoTime()
+    val random = new Random(seed)
+
+    logInfo(s"Using random seed $seed")
+
+    // Generates a random schema for the randomized data generator
+    val schema = new StructType()
+      .add("c0", numericTypes(random.nextInt(numericTypes.length)), nullable = 
true)
+      .add("c1", orderedTypes(random.nextInt(orderedTypes.length)), nullable = 
true)
+      .add("c2", fixedLengthTypes(random.nextInt(fixedLengthTypes.length)), 
nullable = true)
+      .add("c3", 
varLenOrderedTypes(random.nextInt(varLenOrderedTypes.length)), nullable = true)
+      .add("c4", allTypes(random.nextInt(allTypes.length)), nullable = true)
+
+    logInfo(
+      s"""Using the following random schema to generate all the randomized 
aggregation tests:
+         |
+         |${schema.treeString}
+       """.stripMargin
+    )
+
+    // Builds a randomly generated DataFrame
+    val schemaWithId = StructType(StructField("id", IntegerType, nullable = false) +: schema.fields)
+    val data = generateRandomRows(schema)
+    val df = spark.createDataFrame(spark.sparkContext.parallelize(data, 1), schemaWithId)
+
+    // Tests all combinations of 1 to 5 of the aggregate functions defined above
+    (1 to allAggs.length) foreach { i =>
+      allAggs.combinations(i) foreach { targetAggs =>
+        val (names, aggs) = targetAggs.unzip
+
+        // Tests aggregation w/ and w/o grouping keys
+        Seq(true, false).foreach { withGroupingKeys =>
+
+          // Tests aggregation with empty and non-empty input rows
+          Seq(true, false).foreach { emptyInput =>
+
+            // Builds the aggregation to be tested according to different configurations
+            def doAggregation(df: DataFrame): DataFrame = {
+              val baseDf = if (emptyInput) {
+                val emptyRows = spark.sparkContext.parallelize(Seq.empty[Row], 1)
+                spark.createDataFrame(emptyRows, schemaWithId)
+              } else {
+                df
+              }
+
+              if (withGroupingKeys) {
+                baseDf
+                  .groupBy($"id" % 10 as "group")
+                  .agg(aggs.head, aggs.tail: _*)
+                  .orderBy("group")
+              } else {
+                baseDf.agg(aggs.head, aggs.tail: _*)
+              }
+            }
+
+            // Currently Spark SQL doesn't support evaluating a distinct aggregate function
+            // together with aggregate functions without partial aggregation support.
+            if (!(aggs.contains(withoutPartial) && aggs.contains(withDistinct))) {
+              test(
+                s"randomized aggregation test - " +
+                  s"${names.mkString("[", ", ", "]")} - " +
+                  s"${if (withGroupingKeys) "with" else "without"} grouping 
keys - " +
+                  s"with ${if (emptyInput) "empty" else "non-empty"} input"
+              ) {
+                var expected: Seq[Row] = null
+                var actual1: Seq[Row] = null
+                var actual2: Seq[Row] = null
+
+                // Disables `ObjectHashAggregateExec` to obtain a standard answer
+                withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") {
+                  val aggDf = doAggregation(df)
+
+                  if (aggs.intersect(Seq(withoutPartial, withPartialSafe, typed)).nonEmpty) {
+                    assert(containsSortAggregateExec(aggDf))
+                    assert(!containsObjectHashAggregateExec(aggDf))
+                    assert(!containsHashAggregateExec(aggDf))
+                  } else {
+                    assert(!containsSortAggregateExec(aggDf))
+                    assert(!containsObjectHashAggregateExec(aggDf))
+                    assert(containsHashAggregateExec(aggDf))
+                  }
+
+                  expected = aggDf.collect().toSeq
+                }
+
+                // Enables `ObjectHashAggregateExec`
+                withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") {
+                  val aggDf = doAggregation(df)
+
+                  if (aggs.contains(typed) && !aggs.contains(withoutPartial)) {
+                    assert(!containsSortAggregateExec(aggDf))
+                    assert(containsObjectHashAggregateExec(aggDf))
+                    assert(!containsHashAggregateExec(aggDf))
+                  } else if (aggs.intersect(Seq(withoutPartial, withPartialSafe)).nonEmpty) {
+                    assert(containsSortAggregateExec(aggDf))
+                    assert(!containsObjectHashAggregateExec(aggDf))
+                    assert(!containsHashAggregateExec(aggDf))
+                  } else {
+                    assert(!containsSortAggregateExec(aggDf))
+                    assert(!containsObjectHashAggregateExec(aggDf))
+                    assert(containsHashAggregateExec(aggDf))
+                  }
+
+                  // Disables sort-based aggregation fallback (we only generate 50 rows, so 100 is
+                  // big enough) to obtain a result to be checked.
+                  withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "100") {
+                    actual1 = aggDf.collect().toSeq
+                  }
+
+                  // Enables sort-based aggregation fallback to obtain another result to be checked.
+                  withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "3") {
+                    // Here we are not reusing `aggDf` because the physical plan in `aggDf` is
+                    // cached and won't be re-planned using the new fallback threshold.
+                    actual2 = doAggregation(df).collect().toSeq
+                  }
+                }
+
+                doubleSafeCheckRows(actual1, expected, 1e-4)
+                doubleSafeCheckRows(actual2, expected, 1e-4)
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
+  private def containsSortAggregateExec(df: DataFrame): Boolean = {
+    df.queryExecution.executedPlan.collectFirst {
+      case _: SortAggregateExec => ()
+    }.nonEmpty
+  }
+
+  private def containsObjectHashAggregateExec(df: DataFrame): Boolean = {
+    df.queryExecution.executedPlan.collectFirst {
+      case _: ObjectHashAggregateExec => ()
+    }.nonEmpty
+  }
+
+  private def containsHashAggregateExec(df: DataFrame): Boolean = {
+    df.queryExecution.executedPlan.collectFirst {
+      case _: HashAggregateExec => ()
+    }.nonEmpty
+  }
+
+  private def doubleSafeCheckRows(actual: Seq[Row], expected: Seq[Row], tolerance: Double): Unit = {
+    assert(actual.length == expected.length)
+    actual.zip(expected).foreach { case (lhs: Row, rhs: Row) =>
+      assert(lhs.length == rhs.length)
+      lhs.toSeq.zip(rhs.toSeq).foreach {
+        case (a: Double, b: Double) => checkResult(a, b +- tolerance)
+        case (a, b) => checkResult(a, b)
+      }
+    }
+  }
+
+  private def registerHiveFunction(functionName: String, clazz: Class[_]): Unit = {
+    val sessionCatalog = spark.sessionState.catalog.asInstanceOf[HiveSessionCatalog]
+    val builder = sessionCatalog.makeFunctionBuilder(functionName, clazz.getName)
+    val info = new ExpressionInfo(clazz.getName, functionName)
+    sessionCatalog.createTempFunction(functionName, info, builder, ignoreIfExists = false)
+  }
+
+  private def function(name: String, args: Column*): Column = {
+    Column(UnresolvedFunction(FunctionIdentifier(name), args.map(_.expr), isDistinct = false))
+  }
+}
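
For reference, the `contains*AggregateExec` helpers in the suite boil down to a small
plan inspection that can also be done ad hoc. Here is a hedged sketch along the same
lines, with a hypothetical DataFrame `df`:

```
// A hedged sketch, not part of the patch: check which aggregate operator the planner
// picked for a hypothetical DataFrame `df`, mirroring the suite's helpers above.
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}

val plan = df.queryExecution.executedPlan
val usesObjectHashAgg = plan.collectFirst { case _: ObjectHashAggregateExec => () }.nonEmpty
val usesSortAgg = plan.collectFirst { case _: SortAggregateExec => () }.nonEmpty
val usesHashAgg = plan.collectFirst { case _: HashAggregateExec => () }.nonEmpty
```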

http://git-wip-us.apache.org/repos/asf/spark/blob/27daf6bc/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala
new file mode 100644
index 0000000..a3d48d9
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
+import org.apache.spark.sql.hive.execution.TestingTypedCount.State
+import org.apache.spark.sql.types._
+
+@ExpressionDescription(
+  usage = "_FUNC_(expr) - A testing aggregate function resembles COUNT " +
+          "but implements ObjectAggregateFunction.")
+case class TestingTypedCount(
+    child: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0)
+  extends TypedImperativeAggregate[TestingTypedCount.State] {
+
+  def this(child: Expression) = this(child, 0, 0)
+
+  override def children: Seq[Expression] = child :: Nil
+
+  override def dataType: DataType = LongType
+
+  override def nullable: Boolean = false
+
+  override val supportsPartial: Boolean = true
+
+  override def createAggregationBuffer(): State = TestingTypedCount.State(0L)
+
+  override def update(buffer: State, input: InternalRow): Unit = {
+    if (child.eval(input) != null) {
+      buffer.count += 1
+    }
+  }
+
+  override def merge(buffer: State, input: State): Unit = {
+    buffer.count += input.count
+  }
+
+  override def eval(buffer: State): Any = buffer.count
+
+  override def serialize(buffer: State): Array[Byte] = {
+    val byteStream = new ByteArrayOutputStream()
+    val dataStream = new DataOutputStream(byteStream)
+    dataStream.writeLong(buffer.count)
+    byteStream.toByteArray
+  }
+
+  override def deserialize(storageFormat: Array[Byte]): State = {
+    val byteStream = new ByteArrayInputStream(storageFormat)
+    val dataStream = new DataInputStream(byteStream)
+    TestingTypedCount.State(dataStream.readLong())
+  }
+
+  override def inputTypes: Seq[AbstractDataType] = AnyDataType :: Nil
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override val prettyName: String = "typed_count"
+}
+
+object TestingTypedCount {
+  case class State(var count: Long)
+}
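
To round off, a hedged usage sketch of the testing aggregate above, along the lines of
the `typed_count` helper in `ObjectHashAggregateSuite`. It assumes the `TestingTypedCount`
class is on the classpath and that a `SparkSession` named `spark` already exists; the
column name "a" is hypothetical.

```
// A hedged sketch, not part of the patch: invoking TestingTypedCount the same way the
// typed_count helper in ObjectHashAggregateSuite does.
import org.apache.spark.sql.Column

import spark.implicits._

val typedCount = new Column(TestingTypedCount($"a".expr).toAggregateExpression())

// Nulls are skipped in update(), so the result should be 2.
Seq(Some(1), None, Some(3)).toDF("a").select(typedCount).show()
```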

