Repository: spark
Updated Branches:
  refs/heads/master 6fe32869c -> e0559f238


[SPARK-21743][SQL][FOLLOWUP] free aggregate map when task ends

## What changes were proposed in this pull request?

This is the first follow-up of https://github.com/apache/spark/pull/21573, which was only merged to branch 2.3.

This PR fixes the memory leak in another way: free the aggregation map (a 
`BytesToBytesMap`, wrapped by `UnsafeFixedWidthAggregationMap`) when the task 
ends. All the data buffers in Spark SQL use `BytesToBytesMap` or 
`UnsafeExternalSorter` under the hood, e.g. sort, aggregate, window, 
sort-merge join, etc. `UnsafeExternalSorter` already registers a task 
completion listener to free its resources; we should do the same for the 
aggregation map.
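
For reference, a minimal Scala sketch of that listener pattern (`ExampleBuffer` and its `free()` method are hypothetical stand-ins; the real change registers the listener inside `UnsafeFixedWidthAggregationMap`, see the diff below):

```scala
import org.apache.spark.TaskContext
import org.apache.spark.util.TaskCompletionListener

// Hypothetical stand-in for a structure that holds task memory.
class ExampleBuffer(taskContext: TaskContext) {

  def free(): Unit = {
    // release all pages back to the task memory manager
  }

  // Free the memory when the task ends, even if a downstream operator
  // (e.g. a limit) never fully consumes this buffer's output.
  taskContext.addTaskCompletionListener(new TaskCompletionListener {
    override def onTaskCompletion(context: TaskContext): Unit = free()
  })
}
```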

TODO for the next PR:
do not consume all of the input when there is a limit in whole-stage codegen.
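
As an illustration (this query is not from the patch), an aggregate followed by a limit is the shape that could leak: each task builds a hash map whose output may only be partially drained, and under whole-stage codegen the input is still fully consumed:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// Aggregate followed by limit: the per-task aggregation map may never be
// fully drained, so its memory is now freed by the task completion listener.
spark.range(0, 1000000)
  .groupBy($"id" % 100)
  .count()
  .limit(1)
  .collect()
```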

## How was this patch tested?

existing tests
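
The key change in the suite is that the map is now constructed with a `TaskContext` instead of a `TaskMemoryManager`; per the test diff below, a Mockito mock stands in for a real context:

```scala
import org.mockito.Mockito._
import org.apache.spark.TaskContext

// `taskMemoryManager` is the suite's existing TaskMemoryManager; the mocked
// context simply hands it out so the map can be built outside a running task.
val taskContext = mock(classOf[TaskContext])
when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager)
```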

Author: Wenchen Fan <wenc...@databricks.com>

Closes #21738 from cloud-fan/limit.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e0559f23
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e0559f23
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e0559f23

Branch: refs/heads/master
Commit: e0559f238009e02c40f65678fec691c07904e8c0
Parents: 6fe3286
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Tue Jul 10 23:07:10 2018 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Tue Jul 10 23:07:10 2018 +0800

----------------------------------------------------------------------
 .../UnsafeFixedWidthAggregationMap.java         | 17 +++++++++++-----
 .../spark/sql/execution/SparkStrategies.scala   |  7 +------
 .../execution/aggregate/HashAggregateExec.scala |  2 +-
 .../aggregate/TungstenAggregationIterator.scala |  2 +-
 .../UnsafeFixedWidthAggregationMapSuite.scala   | 21 ++++++++++++--------
 5 files changed, 28 insertions(+), 21 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e0559f23/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index c7c4c7b..c8cf44b 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -20,8 +20,8 @@ package org.apache.spark.sql.execution;
 import java.io.IOException;
 
 import org.apache.spark.SparkEnv;
+import org.apache.spark.TaskContext;
 import org.apache.spark.internal.config.package$;
-import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
@@ -82,7 +82,7 @@ public final class UnsafeFixedWidthAggregationMap {
   * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function)
   * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion.
   * @param groupingKeySchema the schema of the grouping key, used for row conversion.
-   * @param taskMemoryManager the memory manager used to allocate our Unsafe memory structures.
+   * @param taskContext the current task context.
   * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
   * @param pageSizeBytes the data page size, in bytes; limits the maximum record size.
   */
@@ -90,19 +90,26 @@ public final class UnsafeFixedWidthAggregationMap {
       InternalRow emptyAggregationBuffer,
       StructType aggregationBufferSchema,
       StructType groupingKeySchema,
-      TaskMemoryManager taskMemoryManager,
+      TaskContext taskContext,
       int initialCapacity,
       long pageSizeBytes) {
     this.aggregationBufferSchema = aggregationBufferSchema;
     this.currentAggregationBuffer = new UnsafeRow(aggregationBufferSchema.length());
     this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema);
     this.groupingKeySchema = groupingKeySchema;
-    this.map =
-      new BytesToBytesMap(taskMemoryManager, initialCapacity, pageSizeBytes, true);
+    this.map = new BytesToBytesMap(
+      taskContext.taskMemoryManager(), initialCapacity, pageSizeBytes, true);
 
     // Initialize the buffer for aggregation value
     final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema);
     this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes();
+
+    // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
+    // the end of the task. This is necessary to avoid memory leaks when the downstream operator
+    // does not fully consume the aggregation map's output (e.g. aggregate followed by limit).
+    taskContext.addTaskCompletionListener(context -> {
+      free();
+    });
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/e0559f23/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 07a6fca..cfbcb9a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -73,12 +73,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
             if limit < conf.topKSortFallbackThreshold =>
           TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil
         case Limit(IntegerLiteral(limit), child) =>
-          // With whole stage codegen, Spark releases resources only when all the output data of the
-          // query plan are consumed. It's possible that `CollectLimitExec` only consumes a little
-          // data from child plan and finishes the query without releasing resources. Here we wrap
-          // the child plan with `LocalLimitExec`, to stop the processing of whole stage codegen and
-          // trigger the resource releasing work, after we consume `limit` rows.
-          CollectLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil
+          CollectLimitExec(limit, planLater(child)) :: Nil
         case other => planLater(other) :: Nil
       }
       case Limit(IntegerLiteral(limit), Sort(order, true, child))

http://git-wip-us.apache.org/repos/asf/spark/blob/e0559f23/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 8c7b2c1..2cac0cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -328,7 +328,7 @@ case class HashAggregateExec(
       initialBuffer,
       bufferSchema,
       groupingKeySchema,
-      TaskContext.get().taskMemoryManager(),
+      TaskContext.get(),
       1024 * 16, // initial capacity
       TaskContext.get().taskMemoryManager().pageSizeBytes
     )

http://git-wip-us.apache.org/repos/asf/spark/blob/e0559f23/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 9dc334c..c191123 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -166,7 +166,7 @@ class TungstenAggregationIterator(
     initialAggregationBuffer,
     StructType.fromAttributes(aggregateFunctions.flatMap(_.aggBufferAttributes)),
     StructType.fromAttributes(groupingExpressions.map(_.toAttribute)),
-    TaskContext.get().taskMemoryManager(),
+    TaskContext.get(),
     1024 * 16, // initial capacity
     TaskContext.get().taskMemoryManager().pageSizeBytes
   )

http://git-wip-us.apache.org/repos/asf/spark/blob/e0559f23/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index 3e31d22..5c15ecd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -23,6 +23,7 @@ import scala.collection.mutable
 import scala.util.{Random, Try}
 import scala.util.control.NonFatal
 
+import org.mockito.Mockito._
 import org.scalatest.Matchers
 
 import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext, TaskContextImpl}
@@ -54,6 +55,8 @@ class UnsafeFixedWidthAggregationMapSuite
   private var memoryManager: TestMemoryManager = null
   private var taskMemoryManager: TaskMemoryManager = null
 
+  private var taskContext: TaskContext = null
+
   def testWithMemoryLeakDetection(name: String)(f: => Unit) {
     def cleanup(): Unit = {
       if (taskMemoryManager != null) {
@@ -67,6 +70,8 @@ class UnsafeFixedWidthAggregationMapSuite
       val conf = new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false")
       memoryManager = new TestMemoryManager(conf)
       taskMemoryManager = new TaskMemoryManager(memoryManager, 0)
+      taskContext = mock(classOf[TaskContext])
+      when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager)
 
       TaskContext.setTaskContext(new TaskContextImpl(
         stageId = 0,
@@ -111,7 +116,7 @@ class UnsafeFixedWidthAggregationMapSuite
       emptyAggregationBuffer,
       aggBufferSchema,
       groupKeySchema,
-      taskMemoryManager,
+      taskContext,
       1024, // initial capacity,
       PAGE_SIZE_BYTES
     )
@@ -124,7 +129,7 @@ class UnsafeFixedWidthAggregationMapSuite
       emptyAggregationBuffer,
       aggBufferSchema,
       groupKeySchema,
-      taskMemoryManager,
+      taskContext,
       1024, // initial capacity
       PAGE_SIZE_BYTES
     )
@@ -151,7 +156,7 @@ class UnsafeFixedWidthAggregationMapSuite
       emptyAggregationBuffer,
       aggBufferSchema,
       groupKeySchema,
-      taskMemoryManager,
+      taskContext,
       128, // initial capacity
       PAGE_SIZE_BYTES
     )
@@ -176,7 +181,7 @@ class UnsafeFixedWidthAggregationMapSuite
       emptyAggregationBuffer,
       aggBufferSchema,
       groupKeySchema,
-      taskMemoryManager,
+      taskContext,
       128, // initial capacity
       PAGE_SIZE_BYTES
     )
@@ -223,7 +228,7 @@ class UnsafeFixedWidthAggregationMapSuite
       emptyAggregationBuffer,
       aggBufferSchema,
       groupKeySchema,
-      taskMemoryManager,
+      taskContext,
       128, // initial capacity
       PAGE_SIZE_BYTES
     )
@@ -263,7 +268,7 @@ class UnsafeFixedWidthAggregationMapSuite
       emptyAggregationBuffer,
       StructType(Nil),
       StructType(Nil),
-      taskMemoryManager,
+      taskContext,
       128, // initial capacity
       PAGE_SIZE_BYTES
     )
@@ -307,7 +312,7 @@ class UnsafeFixedWidthAggregationMapSuite
       emptyAggregationBuffer,
       aggBufferSchema,
       groupKeySchema,
-      taskMemoryManager,
+      taskContext,
       128, // initial capacity
       pageSize
     )
@@ -344,7 +349,7 @@ class UnsafeFixedWidthAggregationMapSuite
       emptyAggregationBuffer,
       aggBufferSchema,
       groupKeySchema,
-      taskMemoryManager,
+      taskContext,
       128, // initial capacity
       pageSize
     )

