This is an automated email from the ASF dual-hosted git repository.
wangzhen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 3a9a967ac0 [GLUTEN-9611][CORE] Improve performance of PayloadCloser Iterator (#9612)
3a9a967ac0 is described below
commit 3a9a967ac0781a3197c75bd027c0155802ad6b66
Author: Zhen Wang <[email protected]>
AuthorDate: Thu May 15 11:31:28 2025 +0800
[GLUTEN-9611][CORE] Improve performance of PayloadCloser Iterator (#9612)
* [GLUTEN-9611][VL] Improve performance of PayloadCloser Iterator
* add none flag to allow null value callback closer
---
.../org/apache/gluten/iterator/IteratorsV1.scala | 19 +++-----
.../apache/spark/iterator/IteratorBenchmark.scala | 57 ++++++++++++++++++++++
2 files changed, 64 insertions(+), 12 deletions(-)
diff --git a/gluten-core/src/main/scala/org/apache/gluten/iterator/IteratorsV1.scala b/gluten-core/src/main/scala/org/apache/gluten/iterator/IteratorsV1.scala
index 120d4cb2b0..3d50b95def 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/iterator/IteratorsV1.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/iterator/IteratorsV1.scala
@@ -26,7 +26,8 @@ import java.util.concurrent.atomic.AtomicBoolean
object IteratorsV1 {
private class PayloadCloser[A](in: Iterator[A])(closeCallback: A => Unit)
extends Iterator[A] {
- private var closer: Option[() => Unit] = None
+ private val _none = new Object
+ private var _prev: Any = _none
TaskResources.addRecycler("Iterators#PayloadCloser", 100) {
tryClose()
@@ -39,22 +40,16 @@ object IteratorsV1 {
override def next(): A = {
val a: A = in.next()
- closer.synchronized {
- closer = Some(
- () => {
- closeCallback.apply(a)
- })
+ this.synchronized {
+ _prev = a
}
a
}
private def tryClose(): Unit = {
- closer.synchronized {
- closer match {
- case Some(c) => c.apply()
- case None =>
- }
- closer = None // make sure the payload is closed once
+ this.synchronized {
+ if (_prev != _none) closeCallback.apply(_prev.asInstanceOf[A])
+ _prev = _none // make sure the payload is closed once
}
}
}
diff --git a/gluten-core/src/test/scala/org/apache/spark/iterator/IteratorBenchmark.scala b/gluten-core/src/test/scala/org/apache/spark/iterator/IteratorBenchmark.scala
index 047deebfac..063dfd0d92 100644
--- a/gluten-core/src/test/scala/org/apache/spark/iterator/IteratorBenchmark.scala
+++ b/gluten-core/src/test/scala/org/apache/spark/iterator/IteratorBenchmark.scala
@@ -21,6 +21,9 @@ import org.apache.gluten.iterator.Iterators.V1
import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
import org.apache.spark.task.TaskResources
+import org.apache.spark.util.ThreadUtils
+
+import java.util.concurrent.TimeUnit
object IteratorBenchmark extends BenchmarkBase {
@@ -125,5 +128,59 @@ object IteratorBenchmark extends BenchmarkBase {
}
}
}
+
+ runBenchmark("Iterator Multi Threads") {
+ val nPayloads: Int = 50000000 // 50 millions
+
+ def makeScalaIterator: Iterator[Any] = {
+ (0 until nPayloads).view.map { _: Int => new Object }.iterator
+ }
+
+ def compareMultiThreadsIterator(name: String, threads: Int = 3)(
+ makeGlutenIterator: Iterators.Version => Iterator[Any]): Unit = {
+ val benchmark = new Benchmark(name, nPayloads, output = output)
+ benchmark.addCase("Scala Iterator") {
+ _ =>
+ val pool = ThreadUtils.newDaemonFixedThreadPool(threads, "ScalaIterator")
+ for (_ <- 0 until threads) {
+ pool.execute(
+ () => {
+ TaskResources.runUnsafe {
+ val count = makeScalaIterator.count(_ => true)
+ assert(count == nPayloads)
+ }
+ })
+ }
+ pool.shutdown()
+ pool.awaitTermination(10, TimeUnit.SECONDS)
+ }
+ benchmark.addCase("Gluten Iterator V1") {
+ _ =>
+ val pool = ThreadUtils.newDaemonFixedThreadPool(threads, "GlutenIteratorV1")
+ for (_ <- 0 until threads) {
+ pool.execute(
+ () => {
+ TaskResources.runUnsafe {
+ val count = makeGlutenIterator(V1).count(_ => true)
+ assert(count == nPayloads)
+ }
+ })
+ }
+ pool.shutdown()
+ pool.awaitTermination(10, TimeUnit.SECONDS)
+ }
+ benchmark.run()
+ }
+
+ compareMultiThreadsIterator("Multi Threads - recycle") {
+ version =>
+ var count = 0
+ Iterators
+ .wrap(version, makeScalaIterator)
+ .recyclePayload(_ => count += 1)
+ .recycleIterator(assert(count == nPayloads))
+ .create()
+ }
+ }
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]