cloud-fan commented on a change in pull request #32242: URL: https://github.com/apache/spark/pull/32242#discussion_r618140415
########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala ########## @@ -687,40 +725,59 @@ case class HashAggregateExec( val thisPlan = ctx.addReferenceObj("plan", this) - // Create a name for the iterator from the fast hash map, and the code to create fast hash map. - val (iterTermForFastHashMap, createFastHashMap) = if (isFastHashMapEnabled) { - // Generates the fast hash map class and creates the fast hash map term. - val fastHashMapClassName = ctx.freshName("FastHashMap") - if (isVectorizedHashMapEnabled) { - val generatedMap = new VectorizedHashMapGenerator(ctx, aggregateExpressions, - fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate() - ctx.addInnerClass(generatedMap) - - // Inline mutable state since not many aggregation operations in a task - fastHashMapTerm = ctx.addMutableState( - fastHashMapClassName, "vectorizedFastHashMap", forceInline = true) - val iter = ctx.addMutableState( - "java.util.Iterator<InternalRow>", - "vectorizedFastHashMapIter", - forceInline = true) - val create = s"$fastHashMapTerm = new $fastHashMapClassName();" - (iter, create) - } else { - val generatedMap = new RowBasedHashMapGenerator(ctx, aggregateExpressions, - fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate() - ctx.addInnerClass(generatedMap) - - // Inline mutable state since not many aggregation operations in a task - fastHashMapTerm = ctx.addMutableState( - fastHashMapClassName, "fastHashMap", forceInline = true) - val iter = ctx.addMutableState( - "org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow>", - "fastHashMapIter", forceInline = true) - val create = s"$fastHashMapTerm = new $fastHashMapClassName(" + - s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());" - (iter, create) - } - } else ("", "") + // Create a name for the iterator from the fast hash map, the code to create + // and add hook to close fast hash map. 
+ val (iterTermForFastHashMap, createFastHashMap, addHookToCloseFastHashMap) = + if (isFastHashMapEnabled) { + // Generates the fast hash map class and creates the fast hash map term. + val fastHashMapClassName = ctx.freshName("FastHashMap") + val (iter, create) = if (isVectorizedHashMapEnabled) { + val generatedMap = new VectorizedHashMapGenerator(ctx, aggregateExpressions, + fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate() + ctx.addInnerClass(generatedMap) + + // Inline mutable state since not many aggregation operations in a task + fastHashMapTerm = ctx.addMutableState( + fastHashMapClassName, "vectorizedFastHashMap", forceInline = true) + val iter = ctx.addMutableState( + "java.util.Iterator<InternalRow>", + "vectorizedFastHashMapIter", + forceInline = true) + val create = s"$fastHashMapTerm = new $fastHashMapClassName();" + (iter, create) + } else { + val generatedMap = new RowBasedHashMapGenerator(ctx, aggregateExpressions, + fastHashMapClassName, groupingKeySchema, bufferSchema, bitMaxCapacity).generate() + ctx.addInnerClass(generatedMap) + + // Inline mutable state since not many aggregation operations in a task + fastHashMapTerm = ctx.addMutableState( + fastHashMapClassName, "fastHashMap", forceInline = true) + val iter = ctx.addMutableState( + "org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow>", + "fastHashMapIter", forceInline = true) + val create = s"$fastHashMapTerm = new $fastHashMapClassName(" + + s"$thisPlan.getTaskContext().taskMemoryManager(), " + + s"$thisPlan.getEmptyAggregationBuffer());" + (iter, create) + } + + // Generates the code to register a cleanup task with TaskContext to ensure that memory + // is guaranteed to be freed at the end of the task. This is necessary to avoid memory + // leaks in when the downstream operator does not fully consume the aggregation map's + // output (e.g. aggregate followed by limit). 
+ val hookToCloseFastHashMap = + s""" + |$thisPlan.getTaskContext().addTaskCompletionListener( + | new org.apache.spark.util.TaskCompletionListener() { + | @Override + | public void onTaskCompletion(org.apache.spark.TaskContext context) { + | $fastHashMapTerm.close(); + | } + |}); + """.stripMargin + (iter, create, hookToCloseFastHashMap) Review comment: Can we make this change in a way that produces a smaller diff? For example, leave the existing code as it is and add a separate definition: ``` ... val hookToCloseFastHashMap = if (isFastHashMapEnabled) { ... } else "" ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org