Repository: spark
Updated Branches:
  refs/heads/branch-1.5 2d86faddd -> 4c6b1296d


[SPARK-9747] [SQL] Avoid starving an unsafe operator in aggregation

This is the sister patch to #8011, but for aggregation.

In a nutshell: create the `TungstenAggregationIterator` before computing the
parent partition. Internally, this creates a `BytesToBytesMap`, which (as of
this patch) acquires a page in its constructor. This ensures that the
aggregation operator is not starved, because at least one page is reserved in
advance.

rxin yhuai

Author: Andrew Or <and...@databricks.com>

Closes #8038 from andrewor14/unsafe-starve-memory-agg.

(cherry picked from commit e0110792ef71ebfd3727b970346a2e13695990a4)
Signed-off-by: Reynold Xin <r...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4c6b1296
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4c6b1296
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4c6b1296

Branch: refs/heads/branch-1.5
Commit: 4c6b1296d20f594f71e63b0772b5290ef21ddd21
Parents: 2d86fad
Author: Andrew Or <and...@databricks.com>
Authored: Wed Aug 12 10:08:35 2015 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Wed Aug 12 10:08:47 2015 -0700

----------------------------------------------------------------------
 .../spark/unsafe/map/BytesToBytesMap.java       | 34 ++++++--
 .../unsafe/sort/UnsafeExternalSorter.java       |  9 +-
 .../map/AbstractBytesToBytesMapSuite.java       | 11 ++-
 .../UnsafeFixedWidthAggregationMap.java         |  7 ++
 .../execution/aggregate/TungstenAggregate.scala | 72 ++++++++++------
 .../aggregate/TungstenAggregationIterator.scala | 88 ++++++++++++--------
 .../TungstenAggregationIteratorSuite.scala      | 56 +++++++++++++
 7 files changed, 201 insertions(+), 76 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4c6b1296/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
----------------------------------------------------------------------
diff --git 
a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java 
b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 85b46ec..87ed47e 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -193,6 +193,11 @@ public final class BytesToBytesMap {
         TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES);
     }
     allocate(initialCapacity);
+
+    // Acquire a new page as soon as we construct the map to ensure that we 
have at least
+    // one page to work with. Otherwise, other operators in the same task may 
starve this
+    // map (SPARK-9747).
+    acquireNewPage();
   }
 
   public BytesToBytesMap(
@@ -574,16 +579,9 @@ public final class BytesToBytesMap {
           final long lengthOffsetInPage = currentDataPage.getBaseOffset() + 
pageCursor;
           Platform.putInt(pageBaseObject, lengthOffsetInPage, 
END_OF_PAGE_MARKER);
         }
-        final long memoryGranted = 
shuffleMemoryManager.tryToAcquire(pageSizeBytes);
-        if (memoryGranted != pageSizeBytes) {
-          shuffleMemoryManager.release(memoryGranted);
-          logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes);
+        if (!acquireNewPage()) {
           return false;
         }
-        MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes);
-        dataPages.add(newPage);
-        pageCursor = 0;
-        currentDataPage = newPage;
         dataPage = currentDataPage;
         dataPageBaseObject = currentDataPage.getBaseObject();
         dataPageInsertOffset = currentDataPage.getBaseOffset();
@@ -643,6 +641,24 @@ public final class BytesToBytesMap {
   }
 
   /**
+   * Acquire a new page from the {@link ShuffleMemoryManager}.
+   * @return whether there is enough space to allocate the new page.
+   */
+  private boolean acquireNewPage() {
+    final long memoryGranted = 
shuffleMemoryManager.tryToAcquire(pageSizeBytes);
+    if (memoryGranted != pageSizeBytes) {
+      shuffleMemoryManager.release(memoryGranted);
+      logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes);
+      return false;
+    }
+    MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes);
+    dataPages.add(newPage);
+    pageCursor = 0;
+    currentDataPage = newPage;
+    return true;
+  }
+
+  /**
    * Allocate new data structures for this map. When calling this outside of 
the constructor,
    * make sure to keep references to the old data structures so that you can 
free them.
    *
@@ -748,7 +764,7 @@ public final class BytesToBytesMap {
   }
 
   @VisibleForTesting
-  int getNumDataPages() {
+  public int getNumDataPages() {
     return dataPages.size();
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/4c6b1296/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
----------------------------------------------------------------------
diff --git 
a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
 
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 9601aaf..fc364e0 100644
--- 
a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ 
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -132,16 +132,15 @@ public final class UnsafeExternalSorter {
 
     if (existingInMemorySorter == null) {
       initializeForWriting();
+      // Acquire a new page as soon as we construct the sorter to ensure that 
we have at
+      // least one page to work with. Otherwise, other operators in the same 
task may starve
+      // this sorter (SPARK-9709). We don't need to do this if we already have 
an existing sorter.
+      acquireNewPage();
     } else {
       this.isInMemSorterExternal = true;
       this.inMemSorter = existingInMemorySorter;
     }
 
-    // Acquire a new page as soon as we construct the sorter to ensure that we 
have at
-    // least one page to work with. Otherwise, other operators in the same 
task may starve
-    // this sorter (SPARK-9709).
-    acquireNewPage();
-
     // Register a cleanup task with TaskContext to ensure that memory is 
guaranteed to be freed at
     // the end of the task. This is necessary to avoid memory leaks in when 
the downstream operator
     // does not fully consume the sorter's output (e.g. sort followed by 
limit).

http://git-wip-us.apache.org/repos/asf/spark/blob/4c6b1296/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
----------------------------------------------------------------------
diff --git 
a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
 
b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 1a79c20..ab480b6 100644
--- 
a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ 
b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -543,7 +543,7 @@ public abstract class AbstractBytesToBytesMapSuite {
           Platform.LONG_ARRAY_OFFSET,
           8);
         newPeakMemory = map.getPeakMemoryUsedBytes();
-        if (i % numRecordsPerPage == 0) {
+        if (i % numRecordsPerPage == 0 && i > 0) {
           // We allocated a new page for this record, so peak memory should 
change
           assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
         } else {
@@ -561,4 +561,13 @@ public abstract class AbstractBytesToBytesMapSuite {
       map.free();
     }
   }
+
+  @Test
+  public void testAcquirePageInConstructor() {
+    final BytesToBytesMap map = new BytesToBytesMap(
+      taskMemoryManager, shuffleMemoryManager, 1, PAGE_SIZE_BYTES);
+    assertEquals(1, map.getNumDataPages());
+    map.free();
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/4c6b1296/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
 
b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 5cce41d..09511ff 100644
--- 
a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ 
b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution;
 
 import java.io.IOException;
 
+import com.google.common.annotations.VisibleForTesting;
+
 import org.apache.spark.SparkEnv;
 import org.apache.spark.shuffle.ShuffleMemoryManager;
 import org.apache.spark.sql.catalyst.InternalRow;
@@ -220,6 +222,11 @@ public final class UnsafeFixedWidthAggregationMap {
     return map.getPeakMemoryUsedBytes();
   }
 
+  @VisibleForTesting
+  public int getNumDataPages() {
+    return map.getNumDataPages();
+  }
+
   /**
    * Free the memory associated with this map. This is idempotent and can be 
called multiple times.
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/4c6b1296/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 6b5935a..c40ca97 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -17,12 +17,13 @@
 
 package org.apache.spark.sql.execution.aggregate
 
-import org.apache.spark.rdd.RDD
+import org.apache.spark.TaskContext
+import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, 
ClusteredDistribution, AllTuples, Distribution}
+import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.{UnaryNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
@@ -68,35 +69,56 @@ case class TungstenAggregate(
   protected override def doExecute(): RDD[InternalRow] = attachTree(this, 
"execute") {
     val numInputRows = longMetric("numInputRows")
     val numOutputRows = longMetric("numOutputRows")
-    child.execute().mapPartitions { iter =>
-      val hasInput = iter.hasNext
-      if (!hasInput && groupingExpressions.nonEmpty) {
-        // This is a grouped aggregate and the input iterator is empty,
-        // so return an empty iterator.
-        Iterator.empty.asInstanceOf[Iterator[UnsafeRow]]
-      } else {
-        val aggregationIterator =
-          new TungstenAggregationIterator(
-            groupingExpressions,
-            nonCompleteAggregateExpressions,
-            completeAggregateExpressions,
-            initialInputBufferOffset,
-            resultExpressions,
-            newMutableProjection,
-            child.output,
-            iter,
-            testFallbackStartsAt,
-            numInputRows,
-            numOutputRows)
-
-        if (!hasInput && groupingExpressions.isEmpty) {
+
+    /**
+     * Set up the underlying unsafe data structures used before computing the 
parent partition.
+     * This makes sure our iterator is not starved by other operators in the 
same task.
+     */
+    def preparePartition(): TungstenAggregationIterator = {
+      new TungstenAggregationIterator(
+        groupingExpressions,
+        nonCompleteAggregateExpressions,
+        completeAggregateExpressions,
+        initialInputBufferOffset,
+        resultExpressions,
+        newMutableProjection,
+        child.output,
+        testFallbackStartsAt,
+        numInputRows,
+        numOutputRows)
+    }
+
+    /** Compute a partition using the iterator already set up previously. */
+    def executePartition(
+        context: TaskContext,
+        partitionIndex: Int,
+        aggregationIterator: TungstenAggregationIterator,
+        parentIterator: Iterator[InternalRow]): Iterator[UnsafeRow] = {
+      val hasInput = parentIterator.hasNext
+      if (!hasInput) {
+        // We're not using the underlying map, so we just can free it here
+        aggregationIterator.free()
+        if (groupingExpressions.isEmpty) {
           numOutputRows += 1
           
Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput())
         } else {
-          aggregationIterator
+          // This is a grouped aggregate and the input iterator is empty,
+          // so return an empty iterator.
+          Iterator[UnsafeRow]()
         }
+      } else {
+        aggregationIterator.start(parentIterator)
+        aggregationIterator
       }
     }
+
+    // Note: we need to set up the iterator in each partition before computing 
the
+    // parent partition, so we cannot simply use `mapPartitions` here 
(SPARK-9747).
+    val resultRdd = {
+      new MapPartitionsWithPreparationRDD[UnsafeRow, InternalRow, 
TungstenAggregationIterator](
+        child.execute(), preparePartition, executePartition, 
preservesPartitioning = true)
+    }
+    resultRdd.asInstanceOf[RDD[InternalRow]]
   }
 
   override def simpleString: String = {

http://git-wip-us.apache.org/repos/asf/spark/blob/4c6b1296/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 1f383dd..af7e0fc 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -72,8 +72,6 @@ import org.apache.spark.sql.types.StructType
  *   the function used to create mutable projections.
  * @param originalInputAttributes
  *   attributes of representing input rows from `inputIter`.
- * @param inputIter
- *   the iterator containing input [[UnsafeRow]]s.
  */
 class TungstenAggregationIterator(
     groupingExpressions: Seq[NamedExpression],
@@ -83,12 +81,14 @@ class TungstenAggregationIterator(
     resultExpressions: Seq[NamedExpression],
     newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => 
MutableProjection),
     originalInputAttributes: Seq[Attribute],
-    inputIter: Iterator[InternalRow],
     testFallbackStartsAt: Option[Int],
     numInputRows: LongSQLMetric,
     numOutputRows: LongSQLMetric)
   extends Iterator[UnsafeRow] with Logging {
 
+  // The parent partition iterator, to be initialized later in `start`
+  private[this] var inputIter: Iterator[InternalRow] = null
+
   ///////////////////////////////////////////////////////////////////////////
   // Part 1: Initializing aggregate functions.
   ///////////////////////////////////////////////////////////////////////////
@@ -348,11 +348,15 @@ class TungstenAggregationIterator(
     false // disable tracking of performance metrics
   )
 
+  // Exposed for testing
+  private[aggregate] def getHashMap: UnsafeFixedWidthAggregationMap = hashMap
+
   // The function used to read and process input rows. When processing input 
rows,
   // it first uses hash-based aggregation by putting groups and their buffers 
in
   // hashMap. If we could not allocate more memory for the map, we switch to
   // sort-based aggregation (by calling switchToSortBasedAggregation).
   private def processInputs(): Unit = {
+    assert(inputIter != null, "attempted to process input when iterator was 
null")
     while (!sortBased && inputIter.hasNext) {
       val newInput = inputIter.next()
       numInputRows += 1
@@ -372,6 +376,7 @@ class TungstenAggregationIterator(
   // that it switch to sort-based aggregation after `fallbackStartsAt` input 
rows have
   // been processed.
   private def processInputsWithControlledFallback(fallbackStartsAt: Int): Unit 
= {
+    assert(inputIter != null, "attempted to process input when iterator was 
null")
     var i = 0
     while (!sortBased && inputIter.hasNext) {
       val newInput = inputIter.next()
@@ -412,6 +417,7 @@ class TungstenAggregationIterator(
    * Switch to sort-based aggregation when the hash-based approach is unable 
to acquire memory.
    */
   private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: 
InternalRow): Unit = {
+    assert(inputIter != null, "attempted to process input when iterator was 
null")
     logInfo("falling back to sort based aggregation.")
     // Step 1: Get the ExternalSorter containing sorted entries of the map.
     externalSorter = hashMap.destructAndCreateExternalSorter()
@@ -431,6 +437,11 @@ class TungstenAggregationIterator(
       case _ => false
     }
 
+    // Note: Since we spill the sorter's contents immediately after creating 
it, we must insert
+    // something into the sorter here to ensure that we acquire at least a 
page of memory.
+    // This is done through `externalSorter.insertKV`, which will trigger the 
page allocation.
+    // Otherwise, children operators may steal the window of opportunity and 
starve our sorter.
+
     if (needsProcess) {
       // First, we create a buffer.
       val buffer = createNewAggregationBuffer()
@@ -588,27 +599,33 @@ class TungstenAggregationIterator(
   //         have not switched to sort-based aggregation.
   ///////////////////////////////////////////////////////////////////////////
 
-  // Starts to process input rows.
-  testFallbackStartsAt match {
-    case None =>
-      processInputs()
-    case Some(fallbackStartsAt) =>
-      // This is the testing path. processInputsWithControlledFallback is same 
as processInputs
-      // except that it switches to sort-based aggregation after 
`fallbackStartsAt` input rows
-      // have been processed.
-      processInputsWithControlledFallback(fallbackStartsAt)
-  }
+  /**
+   * Start processing input rows.
+   * Only after this method is called will this iterator be non-empty.
+   */
+  def start(parentIter: Iterator[InternalRow]): Unit = {
+    inputIter = parentIter
+    testFallbackStartsAt match {
+      case None =>
+        processInputs()
+      case Some(fallbackStartsAt) =>
+        // This is the testing path. processInputsWithControlledFallback is 
same as processInputs
+        // except that it switches to sort-based aggregation after 
`fallbackStartsAt` input rows
+        // have been processed.
+        processInputsWithControlledFallback(fallbackStartsAt)
+    }
 
-  // If we did not switch to sort-based aggregation in processInputs,
-  // we pre-load the first key-value pair from the map (to make hasNext 
idempotent).
-  if (!sortBased) {
-    // First, set aggregationBufferMapIterator.
-    aggregationBufferMapIterator = hashMap.iterator()
-    // Pre-load the first key-value pair from the aggregationBufferMapIterator.
-    mapIteratorHasNext = aggregationBufferMapIterator.next()
-    // If the map is empty, we just free it.
-    if (!mapIteratorHasNext) {
-      hashMap.free()
+    // If we did not switch to sort-based aggregation in processInputs,
+    // we pre-load the first key-value pair from the map (to make hasNext 
idempotent).
+    if (!sortBased) {
+      // First, set aggregationBufferMapIterator.
+      aggregationBufferMapIterator = hashMap.iterator()
+      // Pre-load the first key-value pair from the 
aggregationBufferMapIterator.
+      mapIteratorHasNext = aggregationBufferMapIterator.next()
+      // If the map is empty, we just free it.
+      if (!mapIteratorHasNext) {
+        hashMap.free()
+      }
     }
   }
 
@@ -673,21 +690,20 @@ class TungstenAggregationIterator(
   }
 
   ///////////////////////////////////////////////////////////////////////////
-  // Part 8: A utility function used to generate a output row when there is no
-  // input and there is no grouping expression.
+  // Part 8: Utility functions
   ///////////////////////////////////////////////////////////////////////////
 
+  /**
+   * Generate a output row when there is no input and there is no grouping 
expression.
+   */
   def outputForEmptyGroupingKeyWithoutInput(): UnsafeRow = {
-    if (groupingExpressions.isEmpty) {
-      sortBasedAggregationBuffer.copyFrom(initialAggregationBuffer)
-      // We create a output row and copy it. So, we can free the map.
-      val resultCopy =
-        generateOutput(UnsafeRow.createFromByteArray(0, 0), 
sortBasedAggregationBuffer).copy()
-      hashMap.free()
-      resultCopy
-    } else {
-      throw new IllegalStateException(
-        "This method should not be called when groupingExpressions is not 
empty.")
-    }
+    assert(groupingExpressions.isEmpty)
+    assert(inputIter == null)
+    generateOutput(UnsafeRow.createFromByteArray(0, 0), 
initialAggregationBuffer)
+  }
+
+  /** Free memory used in the underlying map. */
+  def free(): Unit = {
+    hashMap.free()
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/4c6b1296/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
new file mode 100644
index 0000000..ac22c2f
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection
+import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.unsafe.memory.TaskMemoryManager
+
+class TungstenAggregationIteratorSuite extends SparkFunSuite {
+
+  test("memory acquired on construction") {
+    // set up environment
+    val ctx = TestSQLContext
+
+    val taskMemoryManager = new 
TaskMemoryManager(SparkEnv.get.executorMemoryManager)
+    val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, null, 
Seq.empty)
+    TaskContext.setTaskContext(taskContext)
+
+    // Assert that a page is allocated before processing starts
+    var iter: TungstenAggregationIterator = null
+    try {
+      val newMutableProjection = (expr: Seq[Expression], schema: 
Seq[Attribute]) => {
+        () => new InterpretedMutableProjection(expr, schema)
+      }
+      val dummyAccum = SQLMetrics.createLongMetric(ctx.sparkContext, "dummy")
+      iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, 
0,
+        Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, 
dummyAccum)
+      val numPages = iter.getHashMap.getNumDataPages
+      assert(numPages === 1)
+    } finally {
+      // Clean up
+      if (iter != null) {
+        iter.free()
+      }
+      TaskContext.unset()
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to