This is an automated email from the ASF dual-hosted git repository.

wangzhen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 3a9a967ac0 [GLUTEN-9611][CORE] Improve performance of PayloadCloser Iterator (#9612)
3a9a967ac0 is described below

commit 3a9a967ac0781a3197c75bd027c0155802ad6b66
Author: Zhen Wang <[email protected]>
AuthorDate: Thu May 15 11:31:28 2025 +0800

    [GLUTEN-9611][CORE] Improve performance of PayloadCloser Iterator (#9612)
    
    * [GLUTEN-9611][VL] Improve performance of PayloadCloser Iterator
    
    * add none flag to allow null value callback closer
---
 .../org/apache/gluten/iterator/IteratorsV1.scala   | 19 +++-----
 .../apache/spark/iterator/IteratorBenchmark.scala  | 57 ++++++++++++++++++++++
 2 files changed, 64 insertions(+), 12 deletions(-)

diff --git a/gluten-core/src/main/scala/org/apache/gluten/iterator/IteratorsV1.scala b/gluten-core/src/main/scala/org/apache/gluten/iterator/IteratorsV1.scala
index 120d4cb2b0..3d50b95def 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/iterator/IteratorsV1.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/iterator/IteratorsV1.scala
@@ -26,7 +26,8 @@ import java.util.concurrent.atomic.AtomicBoolean
 
 object IteratorsV1 {
   private class PayloadCloser[A](in: Iterator[A])(closeCallback: A => Unit) extends Iterator[A] {
-    private var closer: Option[() => Unit] = None
+    private val _none = new Object
+    private var _prev: Any = _none
 
     TaskResources.addRecycler("Iterators#PayloadCloser", 100) {
       tryClose()
@@ -39,22 +40,16 @@ object IteratorsV1 {
 
     override def next(): A = {
       val a: A = in.next()
-      closer.synchronized {
-        closer = Some(
-          () => {
-            closeCallback.apply(a)
-          })
+      this.synchronized {
+        _prev = a
       }
       a
     }
 
     private def tryClose(): Unit = {
-      closer.synchronized {
-        closer match {
-          case Some(c) => c.apply()
-          case None =>
-        }
-        closer = None // make sure the payload is closed once
+      this.synchronized {
+        if (_prev != _none) closeCallback.apply(_prev.asInstanceOf[A])
+        _prev = _none // make sure the payload is closed once
       }
     }
   }
diff --git a/gluten-core/src/test/scala/org/apache/spark/iterator/IteratorBenchmark.scala b/gluten-core/src/test/scala/org/apache/spark/iterator/IteratorBenchmark.scala
index 047deebfac..063dfd0d92 100644
--- a/gluten-core/src/test/scala/org/apache/spark/iterator/IteratorBenchmark.scala
+++ b/gluten-core/src/test/scala/org/apache/spark/iterator/IteratorBenchmark.scala
@@ -21,6 +21,9 @@ import org.apache.gluten.iterator.Iterators.V1
 
 import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
 import org.apache.spark.task.TaskResources
+import org.apache.spark.util.ThreadUtils
+
+import java.util.concurrent.TimeUnit
 
 object IteratorBenchmark extends BenchmarkBase {
 
@@ -125,5 +128,59 @@ object IteratorBenchmark extends BenchmarkBase {
         }
       }
     }
+
+    runBenchmark("Iterator Multi Threads") {
+      val nPayloads: Int = 50000000 // 50 millions
+
+      def makeScalaIterator: Iterator[Any] = {
+        (0 until nPayloads).view.map { _: Int => new Object }.iterator
+      }
+
+      def compareMultiThreadsIterator(name: String, threads: Int = 3)(
+          makeGlutenIterator: Iterators.Version => Iterator[Any]): Unit = {
+        val benchmark = new Benchmark(name, nPayloads, output = output)
+        benchmark.addCase("Scala Iterator") {
+          _ =>
+            val pool = ThreadUtils.newDaemonFixedThreadPool(threads, "ScalaIterator")
+            for (_ <- 0 until threads) {
+              pool.execute(
+                () => {
+                  TaskResources.runUnsafe {
+                    val count = makeScalaIterator.count(_ => true)
+                    assert(count == nPayloads)
+                  }
+                })
+            }
+            pool.shutdown()
+            pool.awaitTermination(10, TimeUnit.SECONDS)
+        }
+        benchmark.addCase("Gluten Iterator V1") {
+          _ =>
+            val pool = ThreadUtils.newDaemonFixedThreadPool(threads, "GlutenIteratorV1")
+            for (_ <- 0 until threads) {
+              pool.execute(
+                () => {
+                  TaskResources.runUnsafe {
+                    val count = makeGlutenIterator(V1).count(_ => true)
+                    assert(count == nPayloads)
+                  }
+                })
+            }
+            pool.shutdown()
+            pool.awaitTermination(10, TimeUnit.SECONDS)
+        }
+        benchmark.run()
+      }
+
+      compareMultiThreadsIterator("Multi Threads - recycle") {
+        version =>
+          var count = 0
+          Iterators
+            .wrap(version, makeScalaIterator)
+            .recyclePayload(_ => count += 1)
+            .recycleIterator(assert(count == nPayloads))
+            .create()
+      }
+    }
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to