cloud-fan commented on a change in pull request #32242:
URL: https://github.com/apache/spark/pull/32242#discussion_r618140415



##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
##########
@@ -687,40 +725,59 @@ case class HashAggregateExec(
 
     val thisPlan = ctx.addReferenceObj("plan", this)
 
-    // Create a name for the iterator from the fast hash map, and the code to 
create fast hash map.
-    val (iterTermForFastHashMap, createFastHashMap) = if 
(isFastHashMapEnabled) {
-      // Generates the fast hash map class and creates the fast hash map term.
-      val fastHashMapClassName = ctx.freshName("FastHashMap")
-      if (isVectorizedHashMapEnabled) {
-        val generatedMap = new VectorizedHashMapGenerator(ctx, 
aggregateExpressions,
-          fastHashMapClassName, groupingKeySchema, bufferSchema, 
bitMaxCapacity).generate()
-        ctx.addInnerClass(generatedMap)
-
-        // Inline mutable state since not many aggregation operations in a task
-        fastHashMapTerm = ctx.addMutableState(
-          fastHashMapClassName, "vectorizedFastHashMap", forceInline = true)
-        val iter = ctx.addMutableState(
-          "java.util.Iterator<InternalRow>",
-          "vectorizedFastHashMapIter",
-          forceInline = true)
-        val create = s"$fastHashMapTerm = new $fastHashMapClassName();"
-        (iter, create)
-      } else {
-        val generatedMap = new RowBasedHashMapGenerator(ctx, 
aggregateExpressions,
-          fastHashMapClassName, groupingKeySchema, bufferSchema, 
bitMaxCapacity).generate()
-        ctx.addInnerClass(generatedMap)
-
-        // Inline mutable state since not many aggregation operations in a task
-        fastHashMapTerm = ctx.addMutableState(
-          fastHashMapClassName, "fastHashMap", forceInline = true)
-        val iter = ctx.addMutableState(
-          "org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow>",
-          "fastHashMapIter", forceInline = true)
-        val create = s"$fastHashMapTerm = new $fastHashMapClassName(" +
-          s"$thisPlan.getTaskMemoryManager(), 
$thisPlan.getEmptyAggregationBuffer());"
-        (iter, create)
-      }
-    } else ("", "")
+    // Create a name for the iterator from the fast hash map, the code to 
create
+    // and add hook to close fast hash map.
+    val (iterTermForFastHashMap, createFastHashMap, addHookToCloseFastHashMap) 
=
+      if (isFastHashMapEnabled) {
+        // Generates the fast hash map class and creates the fast hash map 
term.
+        val fastHashMapClassName = ctx.freshName("FastHashMap")
+        val (iter, create) = if (isVectorizedHashMapEnabled) {
+          val generatedMap = new VectorizedHashMapGenerator(ctx, 
aggregateExpressions,
+            fastHashMapClassName, groupingKeySchema, bufferSchema, 
bitMaxCapacity).generate()
+          ctx.addInnerClass(generatedMap)
+
+          // Inline mutable state since not many aggregation operations in a 
task
+          fastHashMapTerm = ctx.addMutableState(
+            fastHashMapClassName, "vectorizedFastHashMap", forceInline = true)
+          val iter = ctx.addMutableState(
+            "java.util.Iterator<InternalRow>",
+            "vectorizedFastHashMapIter",
+            forceInline = true)
+          val create = s"$fastHashMapTerm = new $fastHashMapClassName();"
+          (iter, create)
+        } else {
+          val generatedMap = new RowBasedHashMapGenerator(ctx, 
aggregateExpressions,
+            fastHashMapClassName, groupingKeySchema, bufferSchema, 
bitMaxCapacity).generate()
+          ctx.addInnerClass(generatedMap)
+
+          // Inline mutable state since not many aggregation operations in a 
task
+          fastHashMapTerm = ctx.addMutableState(
+            fastHashMapClassName, "fastHashMap", forceInline = true)
+          val iter = ctx.addMutableState(
+            "org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow>",
+            "fastHashMapIter", forceInline = true)
+          val create = s"$fastHashMapTerm = new $fastHashMapClassName(" +
+            s"$thisPlan.getTaskContext().taskMemoryManager(), " +
+            s"$thisPlan.getEmptyAggregationBuffer());"
+          (iter, create)
+        }
+
+        // Generates the code to register a cleanup task with TaskContext to 
ensure that memory
+        // is guaranteed to be freed at the end of the task. This is necessary 
to avoid memory
+        // leaks when the downstream operator does not fully consume the 
aggregation map's
+        // output (e.g. aggregate followed by limit).
+        val hookToCloseFastHashMap =
+          s"""
+             |$thisPlan.getTaskContext().addTaskCompletionListener(
+             |  new org.apache.spark.util.TaskCompletionListener() {
+             |    @Override
+             |    public void onTaskCompletion(org.apache.spark.TaskContext 
context) {
+             |      $fastHashMapTerm.close();
+             |    }
+             |});
+           """.stripMargin
+        (iter, create, hookToCloseFastHashMap)

Review comment:
       Can we restructure this change so that it produces a smaller diff? For example:
   ```
   ...
   val hookToCloseFastHashMap = if (isFastHashMapEnabled) {
    ...
   } else ""
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to