This is an automated email from the ASF dual-hosted git repository.
zhengchenyu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new f65d09e42 [#2389] fix:(remote merge): Fixed the issue of losing data
when calling hasNext multiple times. (#2390)
f65d09e42 is described below
commit f65d09e42c96d4fb706e1585a86665a498ae6112
Author: zhengchenyu <[email protected]>
AuthorDate: Wed Mar 12 16:09:04 2025 +0800
[#2389] fix:(remote merge): Fixed the issue of losing data when calling
hasNext multiple times. (#2390)
### What changes were proposed in this pull request?
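KeyValueReader::next combined the semantics of hasNext and next: every call advanced the reader to the next record. A caller that invoked it more than once before consuming the current record (for example, to check whether more data was available) therefore silently skipped records. This PR separates the two operations: hasNext() checks for data without consuming it, and next() returns the current record.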
### Why are the changes needed?
Fix: #2389
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
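Unit tests (RMRecordsReaderTest) and integration tests (RemoteMergeShuffleWithRssClientTest, RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed).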
---
.../hadoop/mapreduce/task/reduce/RMRssShuffle.java | 12 ++++++---
.../client/record/reader/KeyValueReader.java | 8 +++---
.../client/record/reader/RMRecordsReader.java | 31 +++++++++++++---------
.../client/record/reader/RMRecordsReaderTest.java | 30 ++++++++++++---------
.../test/RemoteMergeShuffleWithRssClientTest.java | 30 ++++++++++++---------
...ShuffleWithRssClientTestWhenShuffleFlushed.java | 30 ++++++++++++---------
6 files changed, 83 insertions(+), 58 deletions(-)
diff --git a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RMRssShuffle.java b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RMRssShuffle.java
index 13b3f215c..f4a1beea9 100644
--- a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RMRssShuffle.java
+++ b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RMRssShuffle.java
@@ -46,6 +46,7 @@ import org.apache.hadoop.util.Progress;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.client.record.Record;
import org.apache.uniffle.client.record.reader.KeyValueReader;
import org.apache.uniffle.client.record.reader.RMRecordsReader;
import org.apache.uniffle.client.record.writer.Combiner;
@@ -184,6 +185,7 @@ public class RMRssShuffle<K, V> implements
ShuffleConsumerPlugin<K, V>, Exceptio
RMRecordsReader reader;
KeyValueReader<ComparativeOutputBuffer, ComparativeOutputBuffer>
keyValueReader;
private Progress mergeProgress = new Progress();
+ Record<ComparativeOutputBuffer, ComparativeOutputBuffer> current;
public RecordsRelayer(RMRecordsReader reader, SerializerInstance
keySerializer) {
this.reader = reader;
@@ -192,7 +194,7 @@ public class RMRssShuffle<K, V> implements
ShuffleConsumerPlugin<K, V>, Exceptio
@Override
public DataInputBuffer getKey() throws IOException {
- ComparativeOutputBuffer buffer = keyValueReader.getCurrentKey();
+ ComparativeOutputBuffer buffer = current.getKey();
DataInputBuffer inputBuffer = new DataInputBuffer();
inputBuffer.reset(buffer.getData(), 0, buffer.getLength());
return inputBuffer;
@@ -200,7 +202,7 @@ public class RMRssShuffle<K, V> implements
ShuffleConsumerPlugin<K, V>, Exceptio
@Override
public DataInputBuffer getValue() throws IOException {
- ComparativeOutputBuffer buffer = keyValueReader.getCurrentValue();
+ ComparativeOutputBuffer buffer = current.getValue();
DataInputBuffer inputBuffer = new DataInputBuffer();
inputBuffer.reset(buffer.getData(), 0, buffer.getLength());
return inputBuffer;
@@ -208,7 +210,11 @@ public class RMRssShuffle<K, V> implements
ShuffleConsumerPlugin<K, V>, Exceptio
@Override
public boolean next() throws IOException {
- return keyValueReader.next();
+ boolean hasNext = keyValueReader.hasNext();
+ if (hasNext) {
+ current = keyValueReader.next();
+ }
+ return hasNext;
}
@Override
diff --git a/client/src/main/java/org/apache/uniffle/client/record/reader/KeyValueReader.java b/client/src/main/java/org/apache/uniffle/client/record/reader/KeyValueReader.java
index 606e97ac4..295e02135 100644
--- a/client/src/main/java/org/apache/uniffle/client/record/reader/KeyValueReader.java
+++ b/client/src/main/java/org/apache/uniffle/client/record/reader/KeyValueReader.java
@@ -19,11 +19,11 @@ package org.apache.uniffle.client.record.reader;
import java.io.IOException;
-public abstract class KeyValueReader<K, V> {
+import org.apache.uniffle.client.record.Record;
- public abstract boolean next() throws IOException;
+public abstract class KeyValueReader<K, V> {
- public abstract K getCurrentKey() throws IOException;
+ public abstract boolean hasNext() throws IOException;
- public abstract V getCurrentValue() throws IOException;
+ public abstract Record<K, V> next() throws IOException;
}
diff --git a/client/src/main/java/org/apache/uniffle/client/record/reader/RMRecordsReader.java b/client/src/main/java/org/apache/uniffle/client/record/reader/RMRecordsReader.java
index 43be684cb..46d727b05 100644
--- a/client/src/main/java/org/apache/uniffle/client/record/reader/RMRecordsReader.java
+++ b/client/src/main/java/org/apache/uniffle/client/record/reader/RMRecordsReader.java
@@ -257,8 +257,11 @@ public class RMRecordsReader<K, V, C> {
private Record<ComparativeOutputBuffer, ComparativeOutputBuffer> curr =
null;
@Override
- public boolean next() throws IOException {
+ public boolean hasNext() throws IOException {
try {
+ if (curr != null) {
+ return true;
+ }
curr = results.take();
return curr != null;
} catch (InterruptedException e) {
@@ -266,14 +269,11 @@ public class RMRecordsReader<K, V, C> {
}
}
- @Override
- public ComparativeOutputBuffer getCurrentKey() throws IOException {
- return curr.getKey();
- }
-
- @Override
- public ComparativeOutputBuffer getCurrentValue() throws IOException {
- return curr.getValue();
+ public Record<ComparativeOutputBuffer, ComparativeOutputBuffer> next()
throws IOException {
+ Record<ComparativeOutputBuffer, ComparativeOutputBuffer> next =
+ Record.create(curr.getKey(), curr.getValue());
+ curr = null;
+ return next;
}
};
}
@@ -284,8 +284,11 @@ public class RMRecordsReader<K, V, C> {
private Record<K, C> curr = null;
@Override
- public boolean next() throws IOException {
+ public boolean hasNext() throws IOException {
try {
+ if (curr != null) {
+ return true;
+ }
curr = results.take();
return curr != null;
} catch (InterruptedException e) {
@@ -293,7 +296,12 @@ public class RMRecordsReader<K, V, C> {
}
}
- @Override
+ public Record<K, C> next() throws IOException {
+ Record record = Record.create(getCurrentKey(), getCurrentValue());
+ curr = null;
+ return record;
+ }
+
public K getCurrentKey() throws IOException {
if (raw) {
ComparativeOutputBuffer keyBuffer = (ComparativeOutputBuffer)
curr.getKey();
@@ -305,7 +313,6 @@ public class RMRecordsReader<K, V, C> {
}
}
- @Override
public C getCurrentValue() throws IOException {
if (raw) {
ComparativeOutputBuffer valueBuffer = (ComparativeOutputBuffer)
curr.getValue();
diff --git a/client/src/test/java/org/apache/uniffle/client/record/reader/RMRecordsReaderTest.java b/client/src/test/java/org/apache/uniffle/client/record/reader/RMRecordsReaderTest.java
index 2c543ad25..b1310260b 100644
--- a/client/src/test/java/org/apache/uniffle/client/record/reader/RMRecordsReaderTest.java
+++ b/client/src/test/java/org/apache/uniffle/client/record/reader/RMRecordsReaderTest.java
@@ -31,6 +31,7 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.apache.uniffle.client.api.ShuffleServerClient;
+import org.apache.uniffle.client.record.Record;
import org.apache.uniffle.client.record.writer.Combiner;
import org.apache.uniffle.client.record.writer.SumByKeyCombiner;
import org.apache.uniffle.common.ShuffleServerInfo;
@@ -102,9 +103,10 @@ public class RMRecordsReaderTest {
readerSpy.start();
int index = 0;
KeyValueReader keyValueReader = readerSpy.keyValueReader();
- while (keyValueReader.next()) {
- assertEquals(SerializerUtils.genData(keyClass, index),
keyValueReader.getCurrentKey());
- assertEquals(SerializerUtils.genData(valueClass, index),
keyValueReader.getCurrentValue());
+ while (keyValueReader.hasNext()) {
+ Record record = keyValueReader.next();
+ assertEquals(SerializerUtils.genData(keyClass, index), record.getKey());
+ assertEquals(SerializerUtils.genData(valueClass, index),
record.getValue());
index++;
}
assertEquals(RECORDS_NUM, index);
@@ -172,8 +174,9 @@ public class RMRecordsReaderTest {
readerSpy.start();
int index = 0;
KeyValueReader keyValueReader = readerSpy.keyValueReader();
- while (keyValueReader.next()) {
- assertEquals(SerializerUtils.genData(keyClass, index),
keyValueReader.getCurrentKey());
+ while (keyValueReader.hasNext()) {
+ Record record = keyValueReader.next();
+ assertEquals(SerializerUtils.genData(keyClass, index), record.getKey());
Object value = SerializerUtils.genData(valueClass, index);
Object newValue = value;
if (index % 2 == 0) {
@@ -183,7 +186,7 @@ public class RMRecordsReaderTest {
newValue = (int) value * 2;
}
}
- assertEquals(newValue, keyValueReader.getCurrentValue());
+ assertEquals(newValue, record.getValue());
index++;
}
assertEquals(RECORDS_NUM * 2, index);
@@ -250,9 +253,10 @@ public class RMRecordsReaderTest {
readerSpy.start();
int index = 0;
KeyValueReader keyValueReader = readerSpy.keyValueReader();
- while (keyValueReader.next()) {
- assertEquals(SerializerUtils.genData(keyClass, index),
keyValueReader.getCurrentKey());
- assertEquals(SerializerUtils.genData(valueClass, index),
keyValueReader.getCurrentValue());
+ while (keyValueReader.hasNext()) {
+ Record record = keyValueReader.next();
+ assertEquals(SerializerUtils.genData(keyClass, index), record.getKey());
+ assertEquals(SerializerUtils.genData(valueClass, index),
record.getValue());
index++;
}
assertEquals(RECORDS_NUM * 6, index);
@@ -322,10 +326,10 @@ public class RMRecordsReaderTest {
readerSpy.start();
int index = 0;
KeyValueReader keyValueReader = readerSpy.keyValueReader();
- while (keyValueReader.next()) {
- assertEquals(SerializerUtils.genData(keyClass, index),
keyValueReader.getCurrentKey());
- assertEquals(
- SerializerUtils.genData(valueClass, index * 2),
keyValueReader.getCurrentValue());
+ while (keyValueReader.hasNext()) {
+ Record record = keyValueReader.next();
+ assertEquals(SerializerUtils.genData(keyClass, index), record.getKey());
+ assertEquals(SerializerUtils.genData(valueClass, index * 2),
record.getValue());
index++;
}
assertEquals(RECORDS_NUM * 6, index);
diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTest.java
index 29336daec..760b27b8a 100644
--- a/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTest.java
+++ b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTest.java
@@ -46,6 +46,7 @@ import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.client.impl.ShuffleWriteClientImpl;
+import org.apache.uniffle.client.record.Record;
import org.apache.uniffle.client.record.reader.KeyValueReader;
import org.apache.uniffle.client.record.reader.RMRecordsReader;
import org.apache.uniffle.client.record.writer.Combiner;
@@ -309,9 +310,10 @@ public class RemoteMergeShuffleWithRssClientTest extends
ShuffleReadWriteBase {
reader.start();
int index = 0;
KeyValueReader keyValueReader = reader.keyValueReader();
- while (keyValueReader.next()) {
- assertEquals(SerializerUtils.genData(keyClass, index),
keyValueReader.getCurrentKey());
- assertEquals(SerializerUtils.genData(valueClass, index),
keyValueReader.getCurrentValue());
+ while (keyValueReader.hasNext()) {
+ Record record = keyValueReader.next();
+ assertEquals(SerializerUtils.genData(keyClass, index), record.getKey());
+ assertEquals(SerializerUtils.genData(valueClass, index),
record.getValue());
index++;
}
assertEquals(5 * RECORD_NUMBER, index);
@@ -479,8 +481,9 @@ public class RemoteMergeShuffleWithRssClientTest extends
ShuffleReadWriteBase {
reader.start();
int index = 0;
KeyValueReader keyValueReader = reader.keyValueReader();
- while (keyValueReader.next()) {
- assertEquals(SerializerUtils.genData(keyClass, index),
keyValueReader.getCurrentKey());
+ while (keyValueReader.hasNext()) {
+ Record record = keyValueReader.next();
+ assertEquals(SerializerUtils.genData(keyClass, index), record.getKey());
Object value = SerializerUtils.genData(valueClass, index);
Object newValue = value;
if (index % 3 != 1) {
@@ -490,7 +493,7 @@ public class RemoteMergeShuffleWithRssClientTest extends
ShuffleReadWriteBase {
newValue = (int) value * 2;
}
}
- assertEquals(newValue, keyValueReader.getCurrentValue());
+ assertEquals(newValue, record.getValue());
index++;
}
assertEquals(3 * RECORD_NUMBER, index);
@@ -697,9 +700,10 @@ public class RemoteMergeShuffleWithRssClientTest extends
ShuffleReadWriteBase {
reader.start();
int index = 0;
KeyValueReader keyValueReader = reader.keyValueReader();
- while (keyValueReader.next()) {
- assertEquals(SerializerUtils.genData(keyClass, index),
keyValueReader.getCurrentKey());
- assertEquals(SerializerUtils.genData(valueClass, index),
keyValueReader.getCurrentValue());
+ while (keyValueReader.hasNext()) {
+ Record record = keyValueReader.next();
+ assertEquals(SerializerUtils.genData(keyClass, index), record.getKey());
+ assertEquals(SerializerUtils.genData(valueClass, index),
record.getValue());
index++;
}
assertEquals(6 * RECORD_NUMBER, index);
@@ -910,10 +914,10 @@ public class RemoteMergeShuffleWithRssClientTest extends
ShuffleReadWriteBase {
reader.start();
int index = 0;
KeyValueReader keyValueReader = reader.keyValueReader();
- while (keyValueReader.next()) {
- assertEquals(SerializerUtils.genData(keyClass, index),
keyValueReader.getCurrentKey());
- assertEquals(
- SerializerUtils.genData(valueClass, index * 2),
keyValueReader.getCurrentValue());
+ while (keyValueReader.hasNext()) {
+ Record record = keyValueReader.next();
+ assertEquals(SerializerUtils.genData(keyClass, index), record.getKey());
+ assertEquals(SerializerUtils.genData(valueClass, index * 2),
record.getValue());
index++;
}
assertEquals(6 * RECORD_NUMBER, index);
diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java
index 402d17ead..e63aee262 100644
--- a/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java
+++ b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java
@@ -47,6 +47,7 @@ import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.client.impl.ShuffleWriteClientImpl;
+import org.apache.uniffle.client.record.Record;
import org.apache.uniffle.client.record.reader.KeyValueReader;
import org.apache.uniffle.client.record.reader.RMRecordsReader;
import org.apache.uniffle.client.record.writer.Combiner;
@@ -322,9 +323,10 @@ public class
RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed extends Shuff
reader.start();
int index = 0;
KeyValueReader keyValueReader = reader.keyValueReader();
- while (keyValueReader.next()) {
- assertEquals(SerializerUtils.genData(keyClass, index),
keyValueReader.getCurrentKey());
- assertEquals(SerializerUtils.genData(valueClass, index),
keyValueReader.getCurrentValue());
+ while (keyValueReader.hasNext()) {
+ Record record = keyValueReader.next();
+ assertEquals(SerializerUtils.genData(keyClass, index), record.getKey());
+ assertEquals(SerializerUtils.genData(valueClass, index),
record.getValue());
index++;
}
assertEquals(5 * RECORD_NUMBER, index);
@@ -493,8 +495,9 @@ public class
RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed extends Shuff
reader.start();
int index = 0;
KeyValueReader keyValueReader = reader.keyValueReader();
- while (keyValueReader.next()) {
- assertEquals(SerializerUtils.genData(keyClass, index),
keyValueReader.getCurrentKey());
+ while (keyValueReader.hasNext()) {
+ Record record = keyValueReader.next();
+ assertEquals(SerializerUtils.genData(keyClass, index), record.getKey());
Object value = SerializerUtils.genData(valueClass, index);
Object newValue = value;
if (index % 3 != 1) {
@@ -504,7 +507,7 @@ public class
RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed extends Shuff
newValue = (int) value * 2;
}
}
- assertEquals(newValue, keyValueReader.getCurrentValue());
+ assertEquals(newValue, record.getValue());
index++;
}
assertEquals(3 * RECORD_NUMBER, index);
@@ -711,9 +714,10 @@ public class
RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed extends Shuff
reader.start();
int index = 0;
KeyValueReader keyValueReader = reader.keyValueReader();
- while (keyValueReader.next()) {
- assertEquals(SerializerUtils.genData(keyClass, index),
keyValueReader.getCurrentKey());
- assertEquals(SerializerUtils.genData(valueClass, index),
keyValueReader.getCurrentValue());
+ while (keyValueReader.hasNext()) {
+ Record record = keyValueReader.next();
+ assertEquals(SerializerUtils.genData(keyClass, index), record.getKey());
+ assertEquals(SerializerUtils.genData(valueClass, index),
record.getValue());
index++;
}
assertEquals(6 * RECORD_NUMBER, index);
@@ -925,10 +929,10 @@ public class
RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed extends Shuff
reader.start();
int index = 0;
KeyValueReader keyValueReader = reader.keyValueReader();
- while (keyValueReader.next()) {
- assertEquals(SerializerUtils.genData(keyClass, index),
keyValueReader.getCurrentKey());
- assertEquals(
- SerializerUtils.genData(valueClass, index * 2),
keyValueReader.getCurrentValue());
+ while (keyValueReader.hasNext()) {
+ Record record = keyValueReader.next();
+ assertEquals(SerializerUtils.genData(keyClass, index), record.getKey());
+ assertEquals(SerializerUtils.genData(valueClass, index * 2),
record.getValue());
index++;
}
assertEquals(6 * RECORD_NUMBER, index);