This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch branch-2.4 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-2.4 by this push: new 277ccba [SPARK-31511][SQL][2.4] Make BytesToBytesMap iterators thread-safe 277ccba is described below commit 277ccba382ac19debeaf15cd648cf6ab69603012 Author: sychen <syc...@ctrip.com> AuthorDate: Tue Sep 8 03:23:59 2020 +0000 [SPARK-31511][SQL][2.4] Make BytesToBytesMap iterators thread-safe ### What changes were proposed in this pull request? This PR increases the thread safety of the `BytesToBytesMap`: - It makes the `iterator()` and `destructiveIterator()` methods use their own `Location` object. This used to be shared, and this was causing issues when the map was being iterated over in two threads by two different iterators. - Removes the `safeIterator()` function. This is not needed anymore. - Improves the documentation of a couple of methods w.r.t. thread-safety. ### Why are the changes needed? It is unexpected for an iterator to share the object it is returning with all other iterators. This is a violation of the iterator contract, and it causes issues with iterators over a map that are consumed in different threads. ### Does this PR introduce any user-facing change? No ### How was this patch tested? Added a unit test. Closes #29605 from cxzl25/SPARK-31511. 
Lead-authored-by: sychen <syc...@ctrip.com> Co-authored-by: herman <her...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../apache/spark/unsafe/map/BytesToBytesMap.java | 18 +++++----- .../sql/execution/joins/HashedRelationSuite.scala | 39 ++++++++++++++++++++++ 2 files changed, 49 insertions(+), 8 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 2b23fbbf..5ab52cc 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -421,11 +421,11 @@ public final class BytesToBytesMap extends MemoryConsumer { * * For efficiency, all calls to `next()` will return the same {@link Location} object. * - * If any other lookups or operations are performed on this map while iterating over it, including - * `lookup()`, the behavior of the returned iterator is undefined. + * The returned iterator is thread-safe. However if the map is modified while iterating over it, + * the behavior of the returned iterator is undefined. */ public MapIterator iterator() { - return new MapIterator(numValues, loc, false); + return new MapIterator(numValues, new Location(), false); } /** @@ -435,19 +435,20 @@ public final class BytesToBytesMap extends MemoryConsumer { * * For efficiency, all calls to `next()` will return the same {@link Location} object. * - * If any other lookups or operations are performed on this map while iterating over it, including - * `lookup()`, the behavior of the returned iterator is undefined. + * The returned iterator is thread-safe. However if the map is modified while iterating over it, + * the behavior of the returned iterator is undefined. 
*/ public MapIterator destructiveIterator() { updatePeakMemoryUsed(); - return new MapIterator(numValues, loc, true); + return new MapIterator(numValues, new Location(), true); } /** * Looks up a key, and return a {@link Location} handle that can be used to test existence * and read/write values. * - * This function always return the same {@link Location} instance to avoid object allocation. + * This function always returns the same {@link Location} instance to avoid object allocation. + * This function is not thread-safe. */ public Location lookup(Object keyBase, long keyOffset, int keyLength) { safeLookup(keyBase, keyOffset, keyLength, loc, @@ -459,7 +460,8 @@ public final class BytesToBytesMap extends MemoryConsumer { * Looks up a key, and return a {@link Location} handle that can be used to test existence * and read/write values. * - * This function always return the same {@link Location} instance to avoid object allocation. + * This function always returns the same {@link Location} instance to avoid object allocation. + * This function is not thread-safe. 
*/ public Location lookup(Object keyBase, long keyOffset, int keyLength, int hash) { safeLookup(keyBase, keyOffset, keyLength, loc, hash); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index d9b34dc..1bdd6fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -341,6 +341,45 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { assert(java.util.Arrays.equals(os.toByteArray, os2.toByteArray)) } + test("SPARK-31511: Make BytesToBytesMap iterators thread-safe") { + val ser = sparkContext.env.serializer.newInstance() + val key = Seq(BoundReference(0, LongType, false)) + + val unsafeProj = UnsafeProjection.create( + Seq(BoundReference(0, LongType, false), BoundReference(1, IntegerType, true))) + val rows = (0 until 10000).map(i => unsafeProj(InternalRow(Int.int2long(i), i + 1)).copy()) + val unsafeHashed = UnsafeHashedRelation(rows.iterator, key, 1, mm) + + val os = new ByteArrayOutputStream() + val thread1 = new Thread { + override def run(): Unit = { + val out = new ObjectOutputStream(os) + unsafeHashed.asInstanceOf[UnsafeHashedRelation].writeExternal(out) + out.flush() + } + } + + val thread2 = new Thread { + override def run(): Unit = { + val threadOut = new ObjectOutputStream(new ByteArrayOutputStream()) + unsafeHashed.asInstanceOf[UnsafeHashedRelation].writeExternal(threadOut) + threadOut.flush() + } + } + + thread1.start() + thread2.start() + thread1.join() + thread2.join() + + val unsafeHashed2 = ser.deserialize[UnsafeHashedRelation](ser.serialize(unsafeHashed)) + val os2 = new ByteArrayOutputStream() + val out2 = new ObjectOutputStream(os2) + unsafeHashed2.writeExternal(out2) + out2.flush() + assert(java.util.Arrays.equals(os.toByteArray, os2.toByteArray)) + } 
+ // This test require 4G heap to run, should run it manually ignore("build HashedRelation that is larger than 1G") { val unsafeProj = UnsafeProjection.create( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org