Repository: spark
Updated Branches:
  refs/heads/master 66a99f4a4 -> 27daf6bcd

[SPARK-17949][SQL] A JVM object based aggregate operator

## What changes were proposed in this pull request?

This PR adds a new hash-based aggregate operator named `ObjectHashAggregateExec` that supports `TypedImperativeAggregate`, which may use arbitrary Java objects as aggregation states. Please refer to the [design doc](https://issues.apache.org/jira/secure/attachment/12834260/%5BDesign%20Doc%5D%20Support%20for%20Arbitrary%20Aggregation%20States.pdf) attached in [SPARK-17949](https://issues.apache.org/jira/browse/SPARK-17949) for more details about it.

The major benefit of this operator is better performance when evaluating `TypedImperativeAggregate` functions, especially when there are relatively few distinct groups. Functions like Hive UDAFs, `collect_list`, and `collect_set` may also benefit from this after being migrated to `TypedImperativeAggregate`.

The following feature flag is introduced to enable or disable the new aggregate operator:

- Name: `spark.sql.execution.useObjectHashAggregateExec`
- Default value: `true`

We can also configure the fallback threshold using the following SQL option:

- Name: `spark.sql.objectHashAggregate.sortBased.fallbackThreshold`
- Default value: 128

  Falls back to sort-based aggregation once 128 distinct groups have accumulated in the aggregation hash map. This number is intentionally made small to avoid GC problems since aggregation buffers of this operator may contain arbitrary Java objects.

  This may be improved by implementing size tracking for this operator, but that can be done in a separate PR.

Code generation and size tracking are planned to be implemented in follow-up PRs.

## Benchmark results

### `ObjectHashAggregateExec` vs `SortAggregateExec`

The first benchmark compares `ObjectHashAggregateExec` and `SortAggregateExec` by evaluating `typed_count`, a testing `TypedImperativeAggregate` version of the SQL `count` function.

```
Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.5
Intel(R) Core(TM) i7-4960HQ CPU 2.60GHz

object agg v.s. sort agg:                Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
sort agg w/ group by                         31251 / 31908          3.4         298.0       1.0X
object agg w/ group by w/o fallback           6903 / 7141          15.2          65.8       4.5X
object agg w/ group by w/ fallback           20945 / 21613          5.0         199.7       1.5X
sort agg w/o group by                         4734 / 5463          22.1          45.2       6.6X
object agg w/o group by w/o fallback          4310 / 4529          24.3          41.1       7.3X
```

The next benchmark compares `ObjectHashAggregateExec` and `SortAggregateExec` by evaluating the Spark native version of `percentile_approx`.

Note that `percentile_approx` is so heavy an aggregate function that the bottleneck of the benchmark is evaluating the aggregate function itself rather than the aggregate operator, since I couldn't run a large-scale benchmark on my laptop. That's why the results are so close and look counter-intuitive (aggregation with grouping is even faster than aggregation without grouping).

```
Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.5
Intel(R) Core(TM) i7-4960HQ CPU 2.60GHz

object agg v.s. sort agg:                Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
sort agg w/ group by                          3418 / 3530          0.6        1630.0       1.0X
object agg w/ group by w/o fallback           3210 / 3314          0.7        1530.7       1.1X
object agg w/ group by w/ fallback            3419 / 3511          0.6        1630.1       1.0X
sort agg w/o group by                         4336 / 4499          0.5        2067.3       0.8X
object agg w/o group by w/o fallback          4271 / 4372          0.5        2036.7       0.8X
```

### Hive UDAF vs Spark AF

This benchmark compares the following two kinds of aggregate functions:

- "hive udaf": Hive implementation of `percentile_approx`, without partial aggregation support, evaluated using `SortAggregateExec`.
- "spark af": Spark native implementation of `percentile_approx`, with partial aggregation support, evaluated using `ObjectHashAggregateExec`

The performance differences are mostly due to faster implementation and partial aggregation support in the Spark native version of `percentile_approx`.

This benchmark basically shows the performance differences between the worst case, where an aggregate function without partial aggregation support is evaluated using `SortAggregateExec`, and the best case, where a `TypedImperativeAggregate` with partial aggregation support is evaluated using `ObjectHashAggregateExec`.

```
Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.5
Intel(R) Core(TM) i7-4960HQ CPU 2.60GHz

hive udaf vs spark af:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
hive udaf w/o group by                        5326 / 5408          0.0       81264.2       1.0X
spark af w/o group by                           93 / 111           0.7        1415.6      57.4X
hive udaf w/ group by                         3804 / 3946          0.0       58050.1       1.4X
spark af w/ group by w/o fallback               71 / 90            0.9        1085.7      74.8X
spark af w/ group by w/ fallback                98 / 111           0.7        1501.6      54.1X
```

### Real world benchmark

We also did a relatively large benchmark using a real world query involving `percentile_approx`:

- Hive UDAF implementation, sort-based aggregation, w/o partial aggregation support

  24.77 minutes

- Native implementation, sort-based aggregation, w/ partial aggregation support

  4.64 minutes

- Native implementation, object hash aggregator, w/ partial aggregation support

  1.80 minutes

## How was this patch tested?

New unit tests and randomized test cases are added in `ObjectAggregateFunctionSuite`.

Author: Cheng Lian <l...@databricks.com>

Closes #15590 from liancheng/obj-hash-agg.

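For readers who want a feel for the fallback behavior described above without reading the whole diff, here is a minimal sketch in plain Scala. It is not Spark code: `FallbackAggregationSketch`, `countByKey`, and the use of `String` keys with `Long` counts are all hypothetical stand-ins for grouping keys and aggregation buffers, and in-memory sorting stands in for the external sorter. It only illustrates the general shape: aggregate in a hash map until the entry count reaches the threshold, then sort both the partial buffers and the remaining raw input by key and two-way merge them.

```scala
import scala.collection.mutable

// Hypothetical, simplified model of the hash-then-sort fallback; not part of the patch.
object FallbackAggregationSketch {

  /** Counts rows per key, switching strategy once `fallbackThreshold` groups are in the map. */
  def countByKey(rows: Iterator[String], fallbackThreshold: Int): Seq[(String, Long)] = {
    val hashMap = mutable.LinkedHashMap.empty[String, Long]
    var sortBased = false

    // Phase 1: hash-based aggregation, tracking entry count rather than byte size.
    while (rows.hasNext && !sortBased) {
      val key = rows.next()
      hashMap(key) = hashMap.getOrElse(key, 0L) + 1L
      if (hashMap.size >= fallbackThreshold) sortBased = true
    }

    if (!sortBased) {
      hashMap.toList
    } else {
      // Phase 2: sort the partial buffers and the remaining raw rows by key.
      val buffers = hashMap.toList.sortBy(_._1).iterator.buffered
      val inputs = rows.toVector.sorted.iterator.buffered
      val result = mutable.ArrayBuffer.empty[(String, Long)]

      while (buffers.hasNext || inputs.hasNext) {
        // The current group is the smaller of the two head keys.
        val key =
          if (!inputs.hasNext) buffers.head._1
          else if (!buffers.hasNext) inputs.head
          else if (inputs.head.compareTo(buffers.head._1) <= 0) inputs.head
          else buffers.head._1

        var count = 0L
        // First fold in raw input rows for this key ...
        while (inputs.hasNext && inputs.head == key) { inputs.next(); count += 1L }
        // ... then merge the pre-aggregated buffer for the same key, if any.
        while (buffers.hasNext && buffers.head._1 == key) { count += buffers.next()._2 }
        result += ((key, count))
      }
      result.toList
    }
  }
}
```

With a large threshold the sketch returns groups in first-seen order; with a small one the output comes back sorted by key, which mirrors why the operator's output order can differ depending on whether the fallback was triggered.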
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/27daf6bc
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/27daf6bc
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/27daf6bc

Branch: refs/heads/master
Commit: 27daf6bcde782ed3e0f0d951c90c8040fd47e985
Parents: 66a99f4
Author: Cheng Lian <l...@databricks.com>
Authored: Thu Nov 3 09:34:51 2016 -0700
Committer: Yin Huai <yh...@databricks.com>
Committed: Thu Nov 3 09:34:51 2016 -0700

----------------------------------------------------------------------
 .../sql/execution/aggregate/AggUtils.scala      |  31 +-
 .../aggregate/ObjectAggregationIterator.scala   | 323 ++++++++++++++
 .../aggregate/ObjectAggregationMap.scala        | 110 +++++
 .../aggregate/ObjectHashAggregateExec.scala     | 155 +++++++
 .../org/apache/spark/sql/internal/SQLConf.scala |  22 +
 .../sql/TypedImperativeAggregateSuite.scala     |   6 +-
 .../SortBasedAggregationStoreSuite.scala        | 141 ++++++
 .../ObjectHashAggregateExecBenchmark.scala      | 230 ++++++++++
 .../execution/ObjectHashAggregateSuite.scala    | 433 +++++++++++++++++++
 .../sql/hive/execution/TestingTypedCount.scala  |  87 ++++
 10 files changed, 1527 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/27daf6bc/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index 4fbb9d5..3c8ef1a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateStoreSaveExec}
+import org.apache.spark.sql.internal.SQLConf
 
 /**
  * Utility functions used by the query planner to convert our plan to new aggregation code path.
@@ -66,14 +67,28 @@ object AggUtils {
         resultExpressions = resultExpressions,
         child = child)
     } else {
-      SortAggregateExec(
-        requiredChildDistributionExpressions = requiredChildDistributionExpressions,
-        groupingExpressions = groupingExpressions,
-        aggregateExpressions = aggregateExpressions,
-        aggregateAttributes = aggregateAttributes,
-        initialInputBufferOffset = initialInputBufferOffset,
-        resultExpressions = resultExpressions,
-        child = child)
+      val objectHashEnabled = child.sqlContext.conf.useObjectHashAggregation
+      val useObjectHash = ObjectHashAggregateExec.supportsAggregate(aggregateExpressions)
+
+      if (objectHashEnabled && useObjectHash) {
+        ObjectHashAggregateExec(
+          requiredChildDistributionExpressions = requiredChildDistributionExpressions,
+          groupingExpressions = groupingExpressions,
+          aggregateExpressions = aggregateExpressions,
+          aggregateAttributes = aggregateAttributes,
+          initialInputBufferOffset = initialInputBufferOffset,
+          resultExpressions = resultExpressions,
+          child = child)
+      } else {
+        SortAggregateExec(
+          requiredChildDistributionExpressions = requiredChildDistributionExpressions,
+          groupingExpressions = groupingExpressions,
+          aggregateExpressions = aggregateExpressions,
+          aggregateAttributes = aggregateAttributes,
+          initialInputBufferOffset = initialInputBufferOffset,
+          resultExpressions = resultExpressions,
+          child = child)
+      }
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/27daf6bc/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
new file mode 100644
index 0000000..3c7b9ee
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
@@ -0,0 +1,323 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.codegen.{BaseOrdering, GenerateOrdering}
+import org.apache.spark.sql.execution.UnsafeKVExternalSorter
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.unsafe.KVIterator
+import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
+
+class ObjectAggregationIterator(
+    outputAttributes: Seq[Attribute],
+    groupingExpressions: Seq[NamedExpression],
+    aggregateExpressions: Seq[AggregateExpression],
+    aggregateAttributes: Seq[Attribute],
+    initialInputBufferOffset: Int,
+    resultExpressions: Seq[NamedExpression],
+    newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection,
+    originalInputAttributes: Seq[Attribute],
+    inputRows: Iterator[InternalRow],
+    fallbackCountThreshold: Int)
+  extends AggregationIterator(
+    groupingExpressions,
+    originalInputAttributes,
+    aggregateExpressions,
+    aggregateAttributes,
+    initialInputBufferOffset,
+    resultExpressions,
+    newMutableProjection) with Logging {
+
+  // Indicates whether we have fallen back to sort-based aggregation or not.
+  private[this] var sortBased: Boolean = false
+
+  private[this] var aggBufferIterator: Iterator[AggregationBufferEntry] = _
+
+  // Hacking the aggregation mode to call AggregateFunction.merge to merge two aggregation buffers
+  private val mergeAggregationBuffers: (InternalRow, InternalRow) => Unit = {
+    val newExpressions = aggregateExpressions.map {
+      case agg @ AggregateExpression(_, Partial, _, _) =>
+        agg.copy(mode = PartialMerge)
+      case agg @ AggregateExpression(_, Complete, _, _) =>
+        agg.copy(mode = Final)
+      case other => other
+    }
+    val newFunctions = initializeAggregateFunctions(newExpressions, 0)
+    val newInputAttributes = newFunctions.flatMap(_.inputAggBufferAttributes)
+    generateProcessRow(newExpressions, newFunctions, newInputAttributes)
+  }
+
+  // A safe projection used to do deep clone of input rows to prevent false sharing.
+  private[this] val safeProjection: Projection =
+    FromUnsafeProjection(outputAttributes.map(_.dataType))
+
+  /**
+   * Start processing input rows.
+   */
+  processInputs()
+
+  override final def hasNext: Boolean = {
+    aggBufferIterator.hasNext
+  }
+
+  override final def next(): UnsafeRow = {
+    val entry = aggBufferIterator.next()
+    generateOutput(entry.groupingKey, entry.aggregationBuffer)
+  }
+
+  /**
+   * Generate an output row when there is no input and there is no grouping expression.
+   */
+  def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
+    if (groupingExpressions.isEmpty) {
+      val defaultAggregationBuffer = createNewAggregationBuffer()
+      generateOutput(UnsafeRow.createFromByteArray(0, 0), defaultAggregationBuffer)
+    } else {
+      throw new IllegalStateException(
+        "This method should not be called when groupingExpressions is not empty.")
+    }
+  }
+
+  // Creates a new aggregation buffer and initializes buffer values. This function should only be
+  // called under two cases:
+  //
+  //  - when creating aggregation buffer for a new group in the hash map, and
+  //  - when creating the re-used buffer for sort-based aggregation
+  private def createNewAggregationBuffer(): SpecificInternalRow = {
+    val bufferFieldTypes = aggregateFunctions.flatMap(_.aggBufferAttributes.map(_.dataType))
+    val buffer = new SpecificInternalRow(bufferFieldTypes)
+    initAggregationBuffer(buffer)
+    buffer
+  }
+
+  private def initAggregationBuffer(buffer: SpecificInternalRow): Unit = {
+    // Initializes declarative aggregates' buffer values
+    expressionAggInitialProjection.target(buffer)(EmptyRow)
+    // Initializes imperative aggregates' buffer values
+    aggregateFunctions.collect { case f: ImperativeAggregate => f }.foreach(_.initialize(buffer))
+  }
+
+  private def getAggregationBufferByKey(
+      hashMap: ObjectAggregationMap, groupingKey: UnsafeRow): InternalRow = {
+    var aggBuffer = hashMap.getAggregationBuffer(groupingKey)
+
+    if (aggBuffer == null) {
+      aggBuffer = createNewAggregationBuffer()
+      hashMap.putAggregationBuffer(groupingKey.copy(), aggBuffer)
+    }
+
+    aggBuffer
+  }
+
+  // This function is used to read and process input rows. When processing input rows, it first uses
+  // hash-based aggregation by putting groups and their buffers in `hashMap`. If `hashMap` grows too
+  // large, it sorts the contents, spills them to disk, and creates a new map. At last, all sorted
+  // spills are merged together for sort-based aggregation.
+  private def processInputs(): Unit = {
+    // In-memory map to store aggregation buffer for hash-based aggregation.
+    val hashMap = new ObjectAggregationMap()
+
+    // If in-memory map is unable to stores all aggregation buffer, fallback to sort-based
+    // aggregation backed by sorted physical storage.
+    var sortBasedAggregationStore: SortBasedAggregator = null
+
+    if (groupingExpressions.isEmpty) {
+      // If there is no grouping expressions, we can just reuse the same buffer over and over again.
+      val groupingKey = groupingProjection.apply(null)
+      val buffer: InternalRow = getAggregationBufferByKey(hashMap, groupingKey)
+      while (inputRows.hasNext) {
+        val newInput = safeProjection(inputRows.next())
+        processRow(buffer, newInput)
+      }
+    } else {
+      while (inputRows.hasNext && !sortBased) {
+        val newInput = safeProjection(inputRows.next())
+        val groupingKey = groupingProjection.apply(newInput)
+        val buffer: InternalRow = getAggregationBufferByKey(hashMap, groupingKey)
+        processRow(buffer, newInput)
+
+        // The the hash map gets too large, makes a sorted spill and clear the map.
+        if (hashMap.size >= fallbackCountThreshold) {
+          logInfo(
+            s"Aggregation hash map reaches threshold " +
+              s"capacity ($fallbackCountThreshold entries), spilling and falling back to sort" +
+              s" based aggregation. You may change the threshold by adjust option " +
+              SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key
+          )
+
+          // Falls back to sort-based aggregation
+          sortBased = true
+
+        }
+      }
+
+      if (sortBased) {
+        val sortIteratorFromHashMap = hashMap
+          .dumpToExternalSorter(groupingAttributes, aggregateFunctions)
+          .sortedIterator()
+        sortBasedAggregationStore = new SortBasedAggregator(
+          sortIteratorFromHashMap,
+          StructType.fromAttributes(originalInputAttributes),
+          StructType.fromAttributes(groupingAttributes),
+          processRow,
+          mergeAggregationBuffers,
+          createNewAggregationBuffer())
+
+        while (inputRows.hasNext) {
+          // NOTE: The input row is always UnsafeRow
+          val unsafeInputRow = inputRows.next().asInstanceOf[UnsafeRow]
+          val groupingKey = groupingProjection.apply(unsafeInputRow)
+          sortBasedAggregationStore.addInput(groupingKey, unsafeInputRow)
+        }
+      }
+    }
+
+    if (sortBased) {
+      aggBufferIterator = sortBasedAggregationStore.destructiveIterator()
+    } else {
+      aggBufferIterator = hashMap.iterator
+    }
+  }
+}
+
+/**
+ * A class used to handle sort-based aggregation, used together with [[ObjectHashAggregateExec]].
+ *
+ * @param initialAggBufferIterator iterator that points to sorted input aggregation buffers
+ * @param inputSchema The schema of input row
+ * @param groupingSchema The schema of grouping key
+ * @param processRow Function to update the aggregation buffer with input rows
+ * @param mergeAggregationBuffers Function used to merge the input aggregation buffers into existing
+ *                                aggregation buffers
+ * @param makeEmptyAggregationBuffer Creates an empty aggregation buffer
+ *
+ * @todo Try to eliminate this class by refactor and reuse code paths in [[SortAggregateExec]].
+ */
+class SortBasedAggregator(
+    initialAggBufferIterator: KVIterator[UnsafeRow, UnsafeRow],
+    inputSchema: StructType,
+    groupingSchema: StructType,
+    processRow: (InternalRow, InternalRow) => Unit,
+    mergeAggregationBuffers: (InternalRow, InternalRow) => Unit,
+    makeEmptyAggregationBuffer: => InternalRow) {
+
+  // external sorter to sort the input (grouping key + input row) with grouping key.
+  private val inputSorter = createExternalSorterForInput()
+  private val groupingKeyOrdering: BaseOrdering = GenerateOrdering.create(groupingSchema)
+
+  def addInput(groupingKey: UnsafeRow, inputRow: UnsafeRow): Unit = {
+    inputSorter.insertKV(groupingKey, inputRow)
+  }
+
+  /**
+   * Returns a destructive iterator of AggregationBufferEntry.
+   * Notice: it is illegal to call any method after `destructiveIterator()` has been called.
+   */
+  def destructiveIterator(): Iterator[AggregationBufferEntry] = {
+    new Iterator[AggregationBufferEntry] {
+      val inputIterator = inputSorter.sortedIterator()
+      var hasNextInput: Boolean = inputIterator.next()
+      var hasNextAggBuffer: Boolean = initialAggBufferIterator.next()
+      private var result: AggregationBufferEntry = _
+      private var groupingKey: UnsafeRow = _
+
+      override def hasNext(): Boolean = {
+        result != null || findNextSortedGroup()
+      }
+
+      override def next(): AggregationBufferEntry = {
+        val returnResult = result
+        result = null
+        returnResult
+      }
+
+      // Two-way merges initialAggBufferIterator and inputIterator
+      private def findNextSortedGroup(): Boolean = {
+        if (hasNextInput || hasNextAggBuffer) {
+          // Find smaller key of the initialAggBufferIterator and initialAggBufferIterator
+          groupingKey = findGroupingKey()
+          result = new AggregationBufferEntry(groupingKey, makeEmptyAggregationBuffer)
+
+          // Firstly, update the aggregation buffer with input rows.
+          while (hasNextInput &&
+            groupingKeyOrdering.compare(inputIterator.getKey, groupingKey) == 0) {
+            processRow(result.aggregationBuffer, inputIterator.getValue)
+            hasNextInput = inputIterator.next()
+          }
+
+          // Secondly, merge the aggregation buffer with existing aggregation buffers.
+          // NOTE: the ordering of these two while-block matter, mergeAggregationBuffer() should
+          // be called after calling processRow.
+          while (hasNextAggBuffer &&
+            groupingKeyOrdering.compare(initialAggBufferIterator.getKey, groupingKey) == 0) {
+            mergeAggregationBuffers(result.aggregationBuffer, initialAggBufferIterator.getValue)
+            hasNextAggBuffer = initialAggBufferIterator.next()
+          }
+
+          true
+        } else {
+          false
+        }
+      }
+
+      private def findGroupingKey(): UnsafeRow = {
+        var newGroupingKey: UnsafeRow = null
+        if (!hasNextInput) {
+          newGroupingKey = initialAggBufferIterator.getKey
+        } else if (!hasNextAggBuffer) {
+          newGroupingKey = inputIterator.getKey
+        } else {
+          val compareResult =
+            groupingKeyOrdering.compare(inputIterator.getKey, initialAggBufferIterator.getKey)
+          if (compareResult <= 0) {
+            newGroupingKey = inputIterator.getKey
+          } else {
+            newGroupingKey = initialAggBufferIterator.getKey
+          }
+        }
+
+        if (groupingKey == null) {
+          groupingKey = newGroupingKey.copy()
+        } else {
+          groupingKey.copyFrom(newGroupingKey)
+        }
+        groupingKey
+      }
+    }
+  }
+
+  private def createExternalSorterForInput(): UnsafeKVExternalSorter = {
+    new UnsafeKVExternalSorter(
+      groupingSchema,
+      inputSchema,
+      SparkEnv.get.blockManager,
+      SparkEnv.get.serializerManager,
+      TaskContext.get().taskMemoryManager().pageSizeBytes,
+      SparkEnv.get.conf.getLong(
+        "spark.shuffle.spill.numElementsForceSpillThreshold",
+        UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD),
+      null
+    )
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/27daf6bc/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala
new file mode 100644
index 0000000..f2d4f6c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import java.{util => ju}
+
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, TypedImperativeAggregate}
+import org.apache.spark.sql.execution.UnsafeKVExternalSorter
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
+
+/**
+ * An aggregation map that supports using safe `SpecificInternalRow`s aggregation buffers, so that
+ * we can support storing arbitrary Java objects as aggregate function states in the aggregation
+ * buffers. This class is only used together with [[ObjectHashAggregateExec]].
+ */
+class ObjectAggregationMap() {
+  private[this] val hashMap = new ju.LinkedHashMap[UnsafeRow, InternalRow]
+
+  def getAggregationBuffer(groupingKey: UnsafeRow): InternalRow = {
+    hashMap.get(groupingKey)
+  }
+
+  def putAggregationBuffer(groupingKey: UnsafeRow, aggBuffer: InternalRow): Unit = {
+    hashMap.put(groupingKey, aggBuffer)
+  }
+
+  def size: Int = hashMap.size()
+
+  def iterator: Iterator[AggregationBufferEntry] = {
+    val iter = hashMap.entrySet().iterator()
+    new Iterator[AggregationBufferEntry] {
+
+      override def hasNext: Boolean = {
+        iter.hasNext
+      }
+      override def next(): AggregationBufferEntry = {
+        val entry = iter.next()
+        new AggregationBufferEntry(entry.getKey, entry.getValue)
+      }
+    }
+  }
+
+  /**
+   * Dumps all entries into a newly created external sorter, clears the hash map, and returns the
+   * external sorter.
+   */
+  def dumpToExternalSorter(
+      groupingAttributes: Seq[Attribute],
+      aggregateFunctions: Seq[AggregateFunction]): UnsafeKVExternalSorter = {
+    val aggBufferAttributes = aggregateFunctions.flatMap(_.aggBufferAttributes)
+    val sorter = new UnsafeKVExternalSorter(
+      StructType.fromAttributes(groupingAttributes),
+      StructType.fromAttributes(aggBufferAttributes),
+      SparkEnv.get.blockManager,
+      SparkEnv.get.serializerManager,
+      TaskContext.get().taskMemoryManager().pageSizeBytes,
+      SparkEnv.get.conf.getLong(
+        "spark.shuffle.spill.numElementsForceSpillThreshold",
+        UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD),
+      null
+    )
+
+    val mapIterator = iterator
+    val unsafeAggBufferProjection =
+      UnsafeProjection.create(aggBufferAttributes.map(_.dataType).toArray)
+
+    while (mapIterator.hasNext) {
+      val entry = mapIterator.next()
+      aggregateFunctions.foreach {
+        case agg: TypedImperativeAggregate[_] =>
+          agg.serializeAggregateBufferInPlace(entry.aggregationBuffer)
+        case _ =>
+      }
+
+      sorter.insertKV(
+        entry.groupingKey,
+        unsafeAggBufferProjection(entry.aggregationBuffer)
+      )
+    }
+
+    hashMap.clear()
+    sorter
+  }
+
+  def clear(): Unit = {
+    hashMap.clear()
+  }
+}
+
+// Stores the grouping key and aggregation buffer
+class AggregationBufferEntry(var groupingKey: UnsafeRow, var aggregationBuffer: InternalRow)

http://git-wip-us.apache.org/repos/asf/spark/blob/27daf6bc/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
new file mode 100644
index 0000000..3fcb7ec
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.util.Utils
+
+/**
+ * A hash-based aggregate operator that supports [[TypedImperativeAggregate]] functions that may
+ * use arbitrary JVM objects as aggregation states.
+ *
+ * Similar to [[HashAggregateExec]], this operator also falls back to sort-based aggregation when
+ * the size of the internal hash map exceeds the threshold. The differences are:
+ *
+ *  - It uses safe rows as aggregation buffer since it must support JVM objects as aggregation
+ *    states.
+ *
+ *  - It tracks entry count of the hash map instead of byte size to decide when we should fall back.
+ *    This is because it's hard to estimate the accurate size of arbitrary JVM objects in a
+ *    lightweight way.
+ *
+ *  - Whenever fallen back to sort-based aggregation, this operator feeds all of the rest input rows
+ *    into external sorters instead of building more hash map(s) as what [[HashAggregateExec]] does.
+ *    This is because having too many JVM object aggregation states floating there can be dangerous
+ *    for GC.
+ *
+ *  - CodeGen is not supported yet.
+ * + * This operator may be turned off by setting the following SQL configuration to `false`: + * {{{ + * spark.sql.execution.useObjectHashAggregateExec + * }}} + * The fallback threshold can be configured by tuning: + * {{{ + * spark.sql.objectHashAggregate.sortBased.fallbackThreshold + * }}} + */ +case class ObjectHashAggregateExec( + requiredChildDistributionExpressions: Option[Seq[Expression]], + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryExecNode { + + private[this] val aggregateBufferAttributes = { + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + } + + override lazy val allAttributes: AttributeSeq = + child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) + + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows") + ) + + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + override def producedAttributes: AttributeSet = + AttributeSet(aggregateAttributes) ++ + AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++ + AttributeSet(aggregateBufferAttributes) + + override def requiredChildDistribution: List[Distribution] = { + requiredChildDistributionExpressions match { + case Some(exprs) if exprs.isEmpty => AllTuples :: Nil + case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil + case None => UnspecifiedDistribution :: Nil + } + } + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + val numOutputRows = longMetric("numOutputRows") + val fallbackCountThreshold = sqlContext.conf.objectAggSortBasedFallbackThreshold + + child.execute().mapPartitionsInternal { iter => + val hasInput = 
iter.hasNext + if (!hasInput && groupingExpressions.nonEmpty) { + // This is a grouped aggregate and the input kvIterator is empty, + // so return an empty kvIterator. + Iterator.empty + } else { + val aggregationIterator = + new ObjectAggregationIterator( + child.output, + groupingExpressions, + aggregateExpressions, + aggregateAttributes, + initialInputBufferOffset, + resultExpressions, + (expressions, inputSchema) => + newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled), + child.output, + iter, + fallbackCountThreshold) + if (!hasInput && groupingExpressions.isEmpty) { + numOutputRows += 1 + Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput()) + } else { + aggregationIterator + } + } + } + } + + override def verboseString: String = toString(verbose = true) + + override def simpleString: String = toString(verbose = false) + + private def toString(verbose: Boolean): String = { + val allAggregateExpressions = aggregateExpressions + val keyString = Utils.truncatedString(groupingExpressions, "[", ", ", "]") + val functionString = Utils.truncatedString(allAggregateExpressions, "[", ", ", "]") + val outputString = Utils.truncatedString(output, "[", ", ", "]") + if (verbose) { + s"ObjectHashAggregate(keys=$keyString, functions=$functionString, output=$outputString)" + } else { + s"ObjectHashAggregate(keys=$keyString, functions=$functionString)" + } + } +} + +object ObjectHashAggregateExec { + def supportsAggregate(aggregateExpressions: Seq[AggregateExpression]): Boolean = { + aggregateExpressions.map(_.aggregateFunction).exists { + case _: TypedImperativeAggregate[_] => true + case _ => false + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/27daf6bc/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 7b8ed65..71f3a67 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -526,6 +526,24 @@ object SQLConf { .stringConf .createWithDefault(classOf[ManifestFileCommitProtocol].getName) + val OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD = + SQLConfigBuilder("spark.sql.objectHashAggregate.sortBased.fallbackThreshold") + .internal() + .doc("In the case of ObjectHashAggregateExec, when the size of the in-memory hash map " + + "grows too large, we will fall back to sort-based aggregation. This option sets a row " + + "count threshold for the size of the hash map.") + .intConf + // We are trying to be conservative and use a relatively small default count threshold here + // since the state object of some TypedImperativeAggregate function can be quite large (e.g. + // percentile_approx). + .createWithDefault(128) + + val USE_OBJECT_HASH_AGG = SQLConfigBuilder("spark.sql.execution.useObjectHashAggregateExec") + .internal() + .doc("Decides if we use ObjectHashAggregateExec") + .booleanConf + .createWithDefault(true) + val FILE_SINK_LOG_DELETION = SQLConfigBuilder("spark.sql.streaming.fileSink.log.deletion") .internal() .doc("Whether to delete the expired log files in file stream sink.") @@ -769,6 +787,10 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def enableTwoLevelAggMap: Boolean = getConf(ENABLE_TWOLEVEL_AGG_MAP) + def useObjectHashAggregation: Boolean = getConf(USE_OBJECT_HASH_AGG) + + def objectAggSortBasedFallbackThreshold: Int = getConf(OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD) + def variableSubstituteEnabled: Boolean = getConf(VARIABLE_SUBSTITUTE_ENABLED) def variableSubstituteDepth: Int = getConf(VARIABLE_SUBSTITUTE_DEPTH) 
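A minimal sketch of how the two settings added to `SQLConf` above might be toggled from a session. The configuration keys are taken from this patch; the local `SparkSession`, the master URL, the application name, and the object name are assumptions made only for illustration.

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical driver object; not part of this patch.
object ObjectHashAggConfSketch {
  def main(args: Array[String]): Unit = {
    // Assumed local session, used only to show where the settings are applied.
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("object-hash-agg-conf-sketch")
      .getOrCreate()

    // Turn the new operator off (its default in this patch is "true").
    spark.conf.set("spark.sql.execution.useObjectHashAggregateExec", "false")

    // Turn it back on and lower the fallback threshold from its default of 128.
    // The value 2 mirrors what the benchmarks in this patch use to force fallback.
    spark.conf.set("spark.sql.execution.useObjectHashAggregateExec", "true")
    spark.conf.set("spark.sql.objectHashAggregate.sortBased.fallbackThreshold", "2")

    spark.stop()
  }
}
```

As the test suite in this patch notes, a DataFrame whose physical plan has already been built is not re-planned when the threshold changes, so the query should be rebuilt after changing these settings.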
http://git-wip-us.apache.org/repos/asf/spark/blob/27daf6bc/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala index ffa26f1..0759915 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericInternalRow, SpecificInternalRow} import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate -import org.apache.spark.sql.execution.aggregate.SortAggregateExec +import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -87,11 +87,11 @@ class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext { test("dataframe aggregate with object aggregate buffer, should not use HashAggregate") { val df = data.toDF("a", "b") - val max = new TypedMax($"a".expr) + val max = TypedMax($"a".expr) // Always uses SortAggregateExec val sparkPlan = df.select(Column(max.toAggregateExpression())).queryExecution.sparkPlan - assert(sparkPlan.isInstanceOf[SortAggregateExec]) + assert(!sparkPlan.isInstanceOf[HashAggregateExec]) } test("dataframe aggregate with object aggregate buffer, no group by") { http://git-wip-us.apache.org/repos/asf/spark/blob/27daf6bc/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala 
---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala new file mode 100644 index 0000000..bc9cb6e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.aggregate + +import java.util.Properties + +import scala.collection.mutable + +import org.apache.spark._ +import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.unsafe.KVIterator + +class SortBasedAggregationStoreSuite extends SparkFunSuite with LocalSparkContext { + + override def beforeAll(): Unit = { + super.beforeAll() + val conf = new SparkConf() + sc = new SparkContext("local[2, 4]", "test", conf) + val taskManager = new TaskMemoryManager(new TestMemoryManager(conf), 0) + TaskContext.setTaskContext(new TaskContextImpl(0, 0, 0, 0, taskManager, new Properties, null)) + } + + override def afterAll(): Unit = TaskContext.unset() + + private val rand = new java.util.Random() + + // In this test, the aggregator is XOR checksum. 
+ test("merge input kv iterator and aggregation buffer iterator") { + + val inputSchema = StructType(Seq(StructField("a", IntegerType), StructField("b", IntegerType))) + val groupingSchema = StructType(Seq(StructField("b", IntegerType))) + + // Schema: a: Int, b: Int + val inputRow: UnsafeRow = createUnsafeRow(2) + + // Schema: group: Int + val group: UnsafeRow = createUnsafeRow(1) + + val expected = new mutable.HashMap[Int, Int]() + val hashMap = new ObjectAggregationMap + (0 to 5000).foreach { _ => + randomKV(inputRow, group) + + // XOR aggregate on first column of input row + expected.put(group.getInt(0), expected.getOrElse(group.getInt(0), 0) ^ inputRow.getInt(0)) + if (hashMap.getAggregationBuffer(group) == null) { + hashMap.putAggregationBuffer(group.copy, createNewAggregationBuffer()) + } + updateInputRow(hashMap.getAggregationBuffer(group), inputRow) + } + + val store = new SortBasedAggregator( + createSortedAggBufferIterator(hashMap), + inputSchema, + groupingSchema, + updateInputRow, + mergeAggBuffer, + createNewAggregationBuffer) + + (5000 to 100000).foreach { _ => + randomKV(inputRow, group) + // XOR aggregate on first column of input row + expected.put(group.getInt(0), expected.getOrElse(group.getInt(0), 0) ^ inputRow.getInt(0)) + store.addInput(group, inputRow) + } + + val iter = store.destructiveIterator() + while(iter.hasNext) { + val agg = iter.next() + assert(agg.aggregationBuffer.getInt(0) == expected(agg.groupingKey.getInt(0))) + } + } + + private def createNewAggregationBuffer(): InternalRow = { + val buffer = createUnsafeRow(1) + buffer.setInt(0, 0) + buffer + } + + private def updateInputRow: (InternalRow, InternalRow) => Unit = { + (buffer: InternalRow, input: InternalRow) => { + buffer.setInt(0, buffer.getInt(0) ^ input.getInt(0)) + } + } + + private def mergeAggBuffer: (InternalRow, InternalRow) => Unit = updateInputRow + + private def createUnsafeRow(numOfField: Int): UnsafeRow = { + val buffer: Array[Byte] = new Array(1024) + val row: 
UnsafeRow = new UnsafeRow(numOfField) + row.pointTo(buffer, 1024) + row + } + + private def randomKV(inputRow: UnsafeRow, group: UnsafeRow): Unit = { + inputRow.setInt(0, rand.nextInt(100000)) + inputRow.setInt(1, rand.nextInt(10000)) + group.setInt(0, inputRow.getInt(1) % 100) + } + + def createSortedAggBufferIterator( + hashMap: ObjectAggregationMap): KVIterator[UnsafeRow, UnsafeRow] = { + + val sortedIterator = hashMap.iterator.toList.sortBy(_.groupingKey.getInt(0)).iterator + new KVIterator[UnsafeRow, UnsafeRow] { + var key: UnsafeRow = null + var value: UnsafeRow = null + override def next: Boolean = { + if (sortedIterator.hasNext) { + val kv = sortedIterator.next() + key = kv.groupingKey + value = kv.aggregationBuffer.asInstanceOf[UnsafeRow] + true + } else { + false + } + } + override def getKey(): UnsafeRow = key + override def getValue(): UnsafeRow = value + override def close(): Unit = Unit + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/27daf6bc/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala new file mode 100644 index 0000000..197110f --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/execution/benchmark/ObjectHashAggregateExecBenchmark.scala @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import scala.concurrent.duration._ + +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFPercentileApprox + +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.expressions.{ExpressionInfo, Literal} +import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile +import org.apache.spark.sql.hive.HiveSessionCatalog +import org.apache.spark.sql.hive.execution.TestingTypedCount +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.LongType +import org.apache.spark.util.Benchmark + +class ObjectHashAggregateExecBenchmark extends BenchmarkBase with TestHiveSingleton { + ignore("Hive UDAF vs Spark AF") { + val N = 2 << 15 + + val benchmark = new Benchmark( + name = "hive udaf vs spark af", + valuesPerIteration = N, + minNumIters = 5, + warmupTime = 5.seconds, + minTime = 10.seconds, + outputPerIteration = true + ) + + registerHiveFunction("hive_percentile_approx", classOf[GenericUDAFPercentileApprox]) + + sparkSession.range(N).createOrReplaceTempView("t") + + benchmark.addCase("hive udaf w/o group by") { _ => + sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "false") + sparkSession.sql("SELECT hive_percentile_approx(id, 0.5) FROM t").collect() + } + + benchmark.addCase("spark af w/o group by") { _ => + sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "true") + sparkSession.sql("SELECT percentile_approx(id, 0.5) FROM t").collect() + } + + benchmark.addCase("hive 
udaf w/ group by") { _ => + sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "false") + sparkSession.sql( + s"SELECT hive_percentile_approx(id, 0.5) FROM t GROUP BY CAST(id / ${N / 4} AS BIGINT)" + ).collect() + } + + benchmark.addCase("spark af w/ group by w/o fallback") { _ => + sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "true") + sparkSession.sql( + s"SELECT percentile_approx(id, 0.5) FROM t GROUP BY CAST(id / ${N / 4} AS BIGINT)" + ).collect() + } + + benchmark.addCase("spark af w/ group by w/ fallback") { _ => + sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "true") + sparkSession.conf.set(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key, "2") + sparkSession.sql( + s"SELECT percentile_approx(id, 0.5) FROM t GROUP BY CAST(id / ${N / 4} AS BIGINT)" + ).collect() + } + + benchmark.run() + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.5 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + + hive udaf vs spark af: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + hive udaf w/o group by 5326 / 5408 0.0 81264.2 1.0X + spark af w/o group by 93 / 111 0.7 1415.6 57.4X + hive udaf w/ group by 3804 / 3946 0.0 58050.1 1.4X + spark af w/ group by w/o fallback 71 / 90 0.9 1085.7 74.8X + spark af w/ group by w/ fallback 98 / 111 0.7 1501.6 54.1X + */ + } + + ignore("ObjectHashAggregateExec vs SortAggregateExec - typed_count") { + val N: Long = 1024 * 1024 * 100 + + val benchmark = new Benchmark( + name = "object agg v.s. 
sort agg", + valuesPerIteration = N, + minNumIters = 1, + warmupTime = 10.seconds, + minTime = 45.seconds, + outputPerIteration = true + ) + + import sparkSession.implicits._ + + def typed_count(column: Column): Column = + Column(TestingTypedCount(column.expr).toAggregateExpression()) + + val df = sparkSession.range(N) + + benchmark.addCase("sort agg w/ group by") { _ => + sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "false") + df.groupBy($"id" < (N / 2)).agg(typed_count($"id")).collect() + } + + benchmark.addCase("object agg w/ group by w/o fallback") { _ => + sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "true") + df.groupBy($"id" < (N / 2)).agg(typed_count($"id")).collect() + } + + benchmark.addCase("object agg w/ group by w/ fallback") { _ => + sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "true") + sparkSession.conf.set(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key, "2") + df.groupBy($"id" < (N / 2)).agg(typed_count($"id")).collect() + } + + benchmark.addCase("sort agg w/o group by") { _ => + sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "false") + df.select(typed_count($"id")).collect() + } + + benchmark.addCase("object agg w/o group by w/o fallback") { _ => + sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "true") + df.select(typed_count($"id")).collect() + } + + benchmark.run() + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.5 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + + object agg v.s. 
sort agg: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + sort agg w/ group by 31251 / 31908 3.4 298.0 1.0X + object agg w/ group by w/o fallback 6903 / 7141 15.2 65.8 4.5X + object agg w/ group by w/ fallback 20945 / 21613 5.0 199.7 1.5X + sort agg w/o group by 4734 / 5463 22.1 45.2 6.6X + object agg w/o group by w/o fallback 4310 / 4529 24.3 41.1 7.3X + */ + } + + ignore("ObjectHashAggregateExec vs SortAggregateExec - percentile_approx") { + val N = 2 << 20 + + val benchmark = new Benchmark( + name = "object agg v.s. sort agg", + valuesPerIteration = N, + minNumIters = 5, + warmupTime = 15.seconds, + minTime = 45.seconds, + outputPerIteration = true + ) + + import sparkSession.implicits._ + + val df = sparkSession.range(N).coalesce(1) + + benchmark.addCase("sort agg w/ group by") { _ => + sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "false") + df.groupBy($"id" / (N / 4) cast LongType).agg(percentile_approx($"id", 0.5)).collect() + } + + benchmark.addCase("object agg w/ group by w/o fallback") { _ => + sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "true") + df.groupBy($"id" / (N / 4) cast LongType).agg(percentile_approx($"id", 0.5)).collect() + } + + benchmark.addCase("object agg w/ group by w/ fallback") { _ => + sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "true") + sparkSession.conf.set(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key, "2") + df.groupBy($"id" / (N / 4) cast LongType).agg(percentile_approx($"id", 0.5)).collect() + } + + benchmark.addCase("sort agg w/o group by") { _ => + sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "false") + df.select(percentile_approx($"id", 0.5)).collect() + } + + benchmark.addCase("object agg w/o group by w/o fallback") { _ => + sparkSession.conf.set(SQLConf.USE_OBJECT_HASH_AGG.key, "true") + df.select(percentile_approx($"id", 0.5)).collect() + } + + benchmark.run() + + /* + Java 
HotSpot(TM) 64-Bit Server VM 1.8.0_92-b14 on Mac OS X 10.10.5 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + + object agg v.s. sort agg: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + sort agg w/ group by 3418 / 3530 0.6 1630.0 1.0X + object agg w/ group by w/o fallback 3210 / 3314 0.7 1530.7 1.1X + object agg w/ group by w/ fallback 3419 / 3511 0.6 1630.1 1.0X + sort agg w/o group by 4336 / 4499 0.5 2067.3 0.8X + object agg w/o group by w/o fallback 4271 / 4372 0.5 2036.7 0.8X + */ + } + + private def registerHiveFunction(functionName: String, clazz: Class[_]): Unit = { + val sessionCatalog = sparkSession.sessionState.catalog.asInstanceOf[HiveSessionCatalog] + val builder = sessionCatalog.makeFunctionBuilder(functionName, clazz.getName) + val info = new ExpressionInfo(clazz.getName, functionName) + sessionCatalog.createTempFunction(functionName, info, builder, ignoreIfExists = false) + } + + private def percentile_approx( + column: Column, percentage: Double, isDistinct: Boolean = false): Column = { + val approxPercentile = new ApproximatePercentile(column.expr, Literal(percentage)) + Column(approxPercentile.toAggregateExpression(isDistinct)) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/27daf6bc/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala new file mode 100644 index 0000000..527626b --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala @@ -0,0 +1,433 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import scala.util.Random + +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMax +import org.scalatest.Matchers._ + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction +import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, ExpressionInfo, Literal} +import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile +import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.hive.HiveSessionCatalog +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types._ + +class ObjectHashAggregateSuite + extends QueryTest + with SQLTestUtils + with TestHiveSingleton + with ExpressionEvalHelper { + + import testImplicits._ + + test("typed_count without grouping keys") { + val df = Seq((1: Integer, 2), (null, 2), (3: Integer, 4)).toDF("a", "b") + + checkAnswer( + df.coalesce(1).select(typed_count($"a")), + Seq(Row(2)) + ) + } + + test("typed_count without 
grouping keys and empty input") { + val df = Seq.empty[(Integer, Int)].toDF("a", "b") + + checkAnswer( + df.coalesce(1).select(typed_count($"a")), + Seq(Row(0)) + ) + } + + test("typed_count with grouping keys") { + val df = Seq((1: Integer, 1), (null, 1), (2: Integer, 2)).toDF("a", "b") + + checkAnswer( + df.coalesce(1).groupBy($"b").agg(typed_count($"a")), + Seq( + Row(1, 1), + Row(2, 1)) + ) + } + + test("typed_count fallback to sort-based aggregation") { + withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "2") { + val df = Seq( + (null, 1), + (null, 1), + (1: Integer, 1), + (2: Integer, 2), + (2: Integer, 2), + (2: Integer, 2) + ).toDF("a", "b") + + checkAnswer( + df.coalesce(1).groupBy($"b").agg(typed_count($"a")), + Seq(Row(1, 1), Row(2, 3)) + ) + } + } + + test("random input data types") { + val dataTypes = Seq( + // Integral types + ByteType, ShortType, IntegerType, LongType, + + // Fractional types + FloatType, DoubleType, + + // Decimal types + DecimalType(25, 5), DecimalType(6, 5), + + // Datetime types + DateType, TimestampType, + + // Complex types + ArrayType(IntegerType), + MapType(DoubleType, LongType), + new StructType() + .add("f1", FloatType, nullable = true) + .add("f2", ArrayType(BooleanType), nullable = true), + + // UDT + new UDT.MyDenseVectorUDT(), + + // Others + StringType, + BinaryType, NullType, BooleanType + ) + + dataTypes.sliding(2, 1).map(_.toSeq).foreach { dataTypes => + // Schema used to generate random input data. + val schemaForGenerator = StructType(dataTypes.zipWithIndex.map { + case (fieldType, index) => + StructField(s"col_$index", fieldType, nullable = true) + }) + + // Schema of the DataFrame to be tested. + val schema = StructType( + StructField("id", IntegerType, nullable = false) +: schemaForGenerator.fields + ) + + logInfo(s"Testing schema:\n${schema.treeString}") + + // Creates a DataFrame for the schema with random data. 
+ val data = generateRandomRows(schemaForGenerator) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data, 1), schema) + val aggFunctions = schema.fieldNames.map(f => typed_count(col(f))) + + checkAnswer( + df.agg(aggFunctions.head, aggFunctions.tail: _*), + Row.fromSeq(data.map(_.toSeq).transpose.map(_.count(_ != null): Long)) + ) + + checkAnswer( + df.groupBy($"id" % 4 as 'mod).agg(aggFunctions.head, aggFunctions.tail: _*), + data.groupBy(_.getInt(0) % 4).map { case (key, value) => + key -> Row.fromSeq(value.map(_.toSeq).transpose.map(_.count(_ != null): Long)) + }.toSeq.map { + case (key, value) => Row.fromSeq(key +: value.toSeq) + } + ) + + withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "5") { + checkAnswer( + df.agg(aggFunctions.head, aggFunctions.tail: _*), + Row.fromSeq(data.map(_.toSeq).transpose.map(_.count(_ != null): Long)) + ) + } + } + } + + private def percentile_approx( + column: Column, percentage: Double, isDistinct: Boolean = false): Column = { + val approxPercentile = new ApproximatePercentile(column.expr, Literal(percentage)) + Column(approxPercentile.toAggregateExpression(isDistinct)) + } + + private def typed_count(column: Column): Column = + Column(TestingTypedCount(column.expr).toAggregateExpression()) + + // Generates 50 random rows for a given schema. + private def generateRandomRows(schemaForGenerator: StructType): Seq[Row] = { + val dataGenerator = RandomDataGenerator.forType( + dataType = schemaForGenerator, + nullable = true, + new Random(System.nanoTime()) + ).getOrElse { + fail(s"Failed to create data generator for schema $schemaForGenerator") + } + + (1 to 50).map { i => + dataGenerator() match { + case row: Row => Row.fromSeq(i +: row.toSeq) + case null => Row.fromSeq(i +: Seq.fill(schemaForGenerator.length)(null)) + case other => fail( + s"Row or null is expected to be generated, " + + s"but a ${other.getClass.getCanonicalName} is generated." 
+ ) + } + } + } + + makeRandomizedTests() + + private def makeRandomizedTests(): Unit = { + // A TypedImperativeAggregate function + val typed = percentile_approx($"c0", 0.5) + + // A Hive UDAF without partial aggregation support + val withoutPartial = { + registerHiveFunction("hive_max", classOf[GenericUDAFMax]) + function("hive_max", $"c1") + } + + // A Spark SQL native aggregate function with partial aggregation support that can be executed + // by the Tungsten `HashAggregateExec` + val withPartialUnsafe = max($"c2") + + // A Spark SQL native aggregate function with partial aggregation support that can only be + // executed by `SortAggregateExec` + val withPartialSafe = max($"c3") + + // A Spark SQL native distinct aggregate function + val withDistinct = countDistinct($"c4") + + val allAggs = Seq( + "typed" -> typed, + "without partial" -> withoutPartial, + "with partial + unsafe" -> withPartialUnsafe, + "with partial + safe" -> withPartialSafe, + "with distinct" -> withDistinct + ) + + val builtinNumericTypes = Seq( + // Integral types + ByteType, ShortType, IntegerType, LongType, + + // Fractional types + FloatType, DoubleType + ) + + val numericTypes = builtinNumericTypes ++ Seq( + // Decimal types + DecimalType(25, 5), DecimalType(6, 5) + ) + + val dateTimeTypes = Seq(DateType, TimestampType) + + val arrayType = ArrayType(IntegerType) + + val structType = new StructType() + .add("f1", FloatType, nullable = true) + .add("f2", ArrayType(BooleanType), nullable = true) + + val mapType = MapType(DoubleType, LongType) + + val complexTypes = Seq(arrayType, mapType, structType) + + val orderedComplexType = Seq(arrayType, structType) + + val orderedTypes = numericTypes ++ dateTimeTypes ++ orderedComplexType ++ Seq( + StringType, BinaryType, NullType, BooleanType + ) + + val udt = new UDT.MyDenseVectorUDT() + + val fixedLengthTypes = builtinNumericTypes ++ Seq(BooleanType, NullType) + + val varLenTypes = complexTypes ++ Seq(StringType, BinaryType, udt) +
+ val varLenOrderedTypes = varLenTypes.intersect(orderedTypes) + + val allTypes = orderedTypes :+ udt + + val seed = System.nanoTime() + val random = new Random(seed) + + logInfo(s"Using random seed $seed") + + // Generates a random schema for the randomized data generator + val schema = new StructType() + .add("c0", numericTypes(random.nextInt(numericTypes.length)), nullable = true) + .add("c1", orderedTypes(random.nextInt(orderedTypes.length)), nullable = true) + .add("c2", fixedLengthTypes(random.nextInt(fixedLengthTypes.length)), nullable = true) + .add("c3", varLenOrderedTypes(random.nextInt(varLenOrderedTypes.length)), nullable = true) + .add("c4", allTypes(random.nextInt(allTypes.length)), nullable = true) + + logInfo( + s"""Using the following random schema to generate all the randomized aggregation tests: + | + |${schema.treeString} + """.stripMargin + ) + + // Builds a randomly generated DataFrame + val schemaWithId = StructType(StructField("id", IntegerType, nullable = false) +: schema.fields) + val data = generateRandomRows(schema) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data, 1), schemaWithId) + + // Tests all combinations of length 1 to 5 types of aggregate functions + (1 to allAggs.length) foreach { i => + allAggs.combinations(i) foreach { targetAggs => + val (names, aggs) = targetAggs.unzip + + // Tests aggregation of w/ and w/o grouping keys + Seq(true, false).foreach { withGroupingKeys => + + // Tests aggregation with empty and non-empty input rows + Seq(true, false).foreach { emptyInput => + + // Builds the aggregation to be tested according to different configurations + def doAggregation(df: DataFrame): DataFrame = { + val baseDf = if (emptyInput) { + val emptyRows = spark.sparkContext.parallelize(Seq.empty[Row], 1) + spark.createDataFrame(emptyRows, schemaWithId) + } else { + df + } + + if (withGroupingKeys) { + baseDf + .groupBy($"id" % 10 as "group") + .agg(aggs.head, aggs.tail: _*) + .orderBy("group") + } else { + 
baseDf.agg(aggs.head, aggs.tail: _*)
+            }
+          }
+
+          // Currently Spark SQL doesn't support evaluating distinct aggregate function together
+          // with aggregate functions without partial aggregation support.
+          if (!(aggs.contains(withoutPartial) && aggs.contains(withDistinct))) {
+            test(
+              s"randomized aggregation test - " +
+                s"${names.mkString("[", ", ", "]")} - " +
+                s"${if (withGroupingKeys) "with" else "without"} grouping keys - " +
+                s"with ${if (emptyInput) "empty" else "non-empty"} input"
+            ) {
+              var expected: Seq[Row] = null
+              var actual1: Seq[Row] = null
+              var actual2: Seq[Row] = null
+
+              // Disables `ObjectHashAggregateExec` to obtain a standard answer
+              withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "false") {
+                val aggDf = doAggregation(df)
+
+                if (aggs.intersect(Seq(withoutPartial, withPartialSafe, typed)).nonEmpty) {
+                  assert(containsSortAggregateExec(aggDf))
+                  assert(!containsObjectHashAggregateExec(aggDf))
+                  assert(!containsHashAggregateExec(aggDf))
+                } else {
+                  assert(!containsSortAggregateExec(aggDf))
+                  assert(!containsObjectHashAggregateExec(aggDf))
+                  assert(containsHashAggregateExec(aggDf))
+                }
+
+                expected = aggDf.collect().toSeq
+              }
+
+              // Enables `ObjectHashAggregateExec`
+              withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") {
+                val aggDf = doAggregation(df)
+
+                if (aggs.contains(typed) && !aggs.contains(withoutPartial)) {
+                  assert(!containsSortAggregateExec(aggDf))
+                  assert(containsObjectHashAggregateExec(aggDf))
+                  assert(!containsHashAggregateExec(aggDf))
+                } else if (aggs.intersect(Seq(withoutPartial, withPartialSafe)).nonEmpty) {
+                  assert(containsSortAggregateExec(aggDf))
+                  assert(!containsObjectHashAggregateExec(aggDf))
+                  assert(!containsHashAggregateExec(aggDf))
+                } else {
+                  assert(!containsSortAggregateExec(aggDf))
+                  assert(!containsObjectHashAggregateExec(aggDf))
+                  assert(containsHashAggregateExec(aggDf))
+                }
+
+                // Disables sort-based aggregation fallback (we only generate 50 rows, so 100 is
+                // big enough) to obtain a result to be checked.
+                withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "100") {
+                  actual1 = aggDf.collect().toSeq
+                }
+
+                // Enables sort-based aggregation fallback to obtain another result to be checked.
+                withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "3") {
+                  // Here we are not reusing `aggDf` because the physical plan in `aggDf` is
+                  // cached and won't be re-planned using the new fallback threshold.
+                  actual2 = doAggregation(df).collect().toSeq
+                }
+              }
+
+              doubleSafeCheckRows(actual1, expected, 1e-4)
+              doubleSafeCheckRows(actual2, expected, 1e-4)
+            }
+          }
+        }
+      }
+    }
+  }
+  }
+
+  private def containsSortAggregateExec(df: DataFrame): Boolean = {
+    df.queryExecution.executedPlan.collectFirst {
+      case _: SortAggregateExec => ()
+    }.nonEmpty
+  }
+
+  private def containsObjectHashAggregateExec(df: DataFrame): Boolean = {
+    df.queryExecution.executedPlan.collectFirst {
+      case _: ObjectHashAggregateExec => ()
+    }.nonEmpty
+  }
+
+  private def containsHashAggregateExec(df: DataFrame): Boolean = {
+    df.queryExecution.executedPlan.collectFirst {
+      case _: HashAggregateExec => ()
+    }.nonEmpty
+  }
+
+  private def doubleSafeCheckRows(actual: Seq[Row], expected: Seq[Row], tolerance: Double): Unit = {
+    assert(actual.length == expected.length)
+    actual.zip(expected).foreach { case (lhs: Row, rhs: Row) =>
+      assert(lhs.length == rhs.length)
+      lhs.toSeq.zip(rhs.toSeq).foreach {
+        case (a: Double, b: Double) => checkResult(a, b +- tolerance)
+        case (a, b) => checkResult(a, b)
+      }
+    }
+  }
+
+  private def registerHiveFunction(functionName: String, clazz: Class[_]): Unit = {
+    val sessionCatalog = spark.sessionState.catalog.asInstanceOf[HiveSessionCatalog]
+    val builder = sessionCatalog.makeFunctionBuilder(functionName, clazz.getName)
+    val info = new ExpressionInfo(clazz.getName, functionName)
+    sessionCatalog.createTempFunction(functionName, info, builder, ignoreIfExists = false)
+  }
+
+  private def function(name: String, args: Column*): Column = {
+    Column(UnresolvedFunction(FunctionIdentifier(name), args.map(_.expr), isDistinct = false))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/27daf6bc/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala
new file mode 100644
index 0000000..a3d48d9
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
+import org.apache.spark.sql.hive.execution.TestingTypedCount.State
+import org.apache.spark.sql.types._
+
+@ExpressionDescription(
+  usage = "_FUNC_(expr) - A testing aggregate function resembles COUNT " +
+          "but implements ObjectAggregateFunction.")
+case class TestingTypedCount(
+    child: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0)
+  extends TypedImperativeAggregate[TestingTypedCount.State] {
+
+  def this(child: Expression) = this(child, 0, 0)
+
+  override def children: Seq[Expression] = child :: Nil
+
+  override def dataType: DataType = LongType
+
+  override def nullable: Boolean = false
+
+  override val supportsPartial: Boolean = true
+
+  override def createAggregationBuffer(): State = TestingTypedCount.State(0L)
+
+  override def update(buffer: State, input: InternalRow): Unit = {
+    if (child.eval(input) != null) {
+      buffer.count += 1
+    }
+  }
+
+  override def merge(buffer: State, input: State): Unit = {
+    buffer.count += input.count
+  }
+
+  override def eval(buffer: State): Any = buffer.count
+
+  override def serialize(buffer: State): Array[Byte] = {
+    val byteStream = new ByteArrayOutputStream()
+    val dataStream = new DataOutputStream(byteStream)
+    dataStream.writeLong(buffer.count)
+    byteStream.toByteArray
+  }
+
+  override def deserialize(storageFormat: Array[Byte]): State = {
+    val byteStream = new ByteArrayInputStream(storageFormat)
+    val dataStream = new DataInputStream(byteStream)
+    TestingTypedCount.State(dataStream.readLong())
+  }
+
+  override def inputTypes: Seq[AbstractDataType] = AnyDataType :: Nil
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override val prettyName: String = "typed_count"
+}
+
+object TestingTypedCount {
+  case class State(var count: Long)
+}
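
The `serialize`, `deserialize`, and `merge` methods of `TestingTypedCount` in the diff above carry a partial aggregation state across a byte-array boundary and then combine it with another state. The following is a minimal sketch of that round trip outside Spark, assuming the stand-in names `CountState` and `RoundTripSketch`, which are hypothetical and not part of this commit. It only mirrors the byte-level logic shown in the diff and does not touch any Spark API.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

// Hypothetical stand-in for TestingTypedCount.State; not part of the commit.
final case class CountState(var count: Long)

// Hypothetical helper object mirroring serialize/deserialize/merge from the diff.
object RoundTripSketch {
  // Writes the counter as a single big-endian long (8 bytes).
  def serialize(buffer: CountState): Array[Byte] = {
    val byteStream = new ByteArrayOutputStream()
    val dataStream = new DataOutputStream(byteStream)
    dataStream.writeLong(buffer.count)
    dataStream.flush()
    byteStream.toByteArray
  }

  // Reads the counter back from the byte array produced by serialize.
  def deserialize(storageFormat: Array[Byte]): CountState = {
    val dataStream = new DataInputStream(new ByteArrayInputStream(storageFormat))
    CountState(dataStream.readLong())
  }

  // Folds a partial state into the target buffer, mutating the buffer in place.
  def merge(buffer: CountState, input: CountState): Unit = {
    buffer.count += input.count
  }

  def main(args: Array[String]): Unit = {
    val local = CountState(3L)
    // Simulates a partial state produced elsewhere and shipped as bytes.
    val shipped = deserialize(serialize(CountState(4L)))
    merge(local, shipped)
    println(local.count)
  }
}
```

Because the state is a single `Long`, the serialized form is always 8 bytes. Aggregation states that hold larger objects would need their own encoding, which this sketch does not cover.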