This is an automated email from the ASF dual-hosted git repository.

zhengchenyu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new f65d09e42 [#2389] fix:(remote merge): Fixed the issue of losing data 
when calling hasNext multiple times. (#2390)
f65d09e42 is described below

commit f65d09e42c96d4fb706e1585a86665a498ae6112
Author: zhengchenyu <[email protected]>
AuthorDate: Wed Mar 12 16:09:04 2025 +0800

    [#2389] fix:(remote merge): Fixed the issue of losing data when calling 
hasNext multiple times. (#2390)
    
    ### What changes were proposed in this pull request?
    
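    KeyValueReader::next combines the semantics of hasNext and next: every call
    takes a new record, so calling it more than once before reading the current
    record silently drops data. The two operations should be separated into
    hasNext() and next().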
    
    ### Why are the changes needed?
    
    Fix: #2389
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Unit tests and integration tests.
---
 .../hadoop/mapreduce/task/reduce/RMRssShuffle.java | 12 ++++++---
 .../client/record/reader/KeyValueReader.java       |  8 +++---
 .../client/record/reader/RMRecordsReader.java      | 31 +++++++++++++---------
 .../client/record/reader/RMRecordsReaderTest.java  | 30 ++++++++++++---------
 .../test/RemoteMergeShuffleWithRssClientTest.java  | 30 ++++++++++++---------
 ...ShuffleWithRssClientTestWhenShuffleFlushed.java | 30 ++++++++++++---------
 6 files changed, 83 insertions(+), 58 deletions(-)

diff --git 
a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RMRssShuffle.java
 
b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RMRssShuffle.java
index 13b3f215c..f4a1beea9 100644
--- 
a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RMRssShuffle.java
+++ 
b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RMRssShuffle.java
@@ -46,6 +46,7 @@ import org.apache.hadoop.util.Progress;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
 
 import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.client.record.Record;
 import org.apache.uniffle.client.record.reader.KeyValueReader;
 import org.apache.uniffle.client.record.reader.RMRecordsReader;
 import org.apache.uniffle.client.record.writer.Combiner;
@@ -184,6 +185,7 @@ public class RMRssShuffle<K, V> implements 
ShuffleConsumerPlugin<K, V>, Exceptio
     RMRecordsReader reader;
     KeyValueReader<ComparativeOutputBuffer, ComparativeOutputBuffer> 
keyValueReader;
     private Progress mergeProgress = new Progress();
+    Record<ComparativeOutputBuffer, ComparativeOutputBuffer> current;
 
     public RecordsRelayer(RMRecordsReader reader, SerializerInstance 
keySerializer) {
       this.reader = reader;
@@ -192,7 +194,7 @@ public class RMRssShuffle<K, V> implements 
ShuffleConsumerPlugin<K, V>, Exceptio
 
     @Override
     public DataInputBuffer getKey() throws IOException {
-      ComparativeOutputBuffer buffer = keyValueReader.getCurrentKey();
+      ComparativeOutputBuffer buffer = current.getKey();
       DataInputBuffer inputBuffer = new DataInputBuffer();
       inputBuffer.reset(buffer.getData(), 0, buffer.getLength());
       return inputBuffer;
@@ -200,7 +202,7 @@ public class RMRssShuffle<K, V> implements 
ShuffleConsumerPlugin<K, V>, Exceptio
 
     @Override
     public DataInputBuffer getValue() throws IOException {
-      ComparativeOutputBuffer buffer = keyValueReader.getCurrentValue();
+      ComparativeOutputBuffer buffer = current.getValue();
       DataInputBuffer inputBuffer = new DataInputBuffer();
       inputBuffer.reset(buffer.getData(), 0, buffer.getLength());
       return inputBuffer;
@@ -208,7 +210,11 @@ public class RMRssShuffle<K, V> implements 
ShuffleConsumerPlugin<K, V>, Exceptio
 
     @Override
     public boolean next() throws IOException {
-      return keyValueReader.next();
+      boolean hasNext = keyValueReader.hasNext();
+      if (hasNext) {
+        current = keyValueReader.next();
+      }
+      return hasNext;
     }
 
     @Override
diff --git 
a/client/src/main/java/org/apache/uniffle/client/record/reader/KeyValueReader.java
 
b/client/src/main/java/org/apache/uniffle/client/record/reader/KeyValueReader.java
index 606e97ac4..295e02135 100644
--- 
a/client/src/main/java/org/apache/uniffle/client/record/reader/KeyValueReader.java
+++ 
b/client/src/main/java/org/apache/uniffle/client/record/reader/KeyValueReader.java
@@ -19,11 +19,11 @@ package org.apache.uniffle.client.record.reader;
 
 import java.io.IOException;
 
-public abstract class KeyValueReader<K, V> {
+import org.apache.uniffle.client.record.Record;
 
-  public abstract boolean next() throws IOException;
+public abstract class KeyValueReader<K, V> {
 
-  public abstract K getCurrentKey() throws IOException;
+  public abstract boolean hasNext() throws IOException;
 
-  public abstract V getCurrentValue() throws IOException;
+  public abstract Record<K, V> next() throws IOException;
 }
diff --git 
a/client/src/main/java/org/apache/uniffle/client/record/reader/RMRecordsReader.java
 
b/client/src/main/java/org/apache/uniffle/client/record/reader/RMRecordsReader.java
index 43be684cb..46d727b05 100644
--- 
a/client/src/main/java/org/apache/uniffle/client/record/reader/RMRecordsReader.java
+++ 
b/client/src/main/java/org/apache/uniffle/client/record/reader/RMRecordsReader.java
@@ -257,8 +257,11 @@ public class RMRecordsReader<K, V, C> {
       private Record<ComparativeOutputBuffer, ComparativeOutputBuffer> curr = 
null;
 
       @Override
-      public boolean next() throws IOException {
+      public boolean hasNext() throws IOException {
         try {
+          if (curr != null) {
+            return true;
+          }
           curr = results.take();
           return curr != null;
         } catch (InterruptedException e) {
@@ -266,14 +269,11 @@ public class RMRecordsReader<K, V, C> {
         }
       }
 
-      @Override
-      public ComparativeOutputBuffer getCurrentKey() throws IOException {
-        return curr.getKey();
-      }
-
-      @Override
-      public ComparativeOutputBuffer getCurrentValue() throws IOException {
-        return curr.getValue();
+      public Record<ComparativeOutputBuffer, ComparativeOutputBuffer> next() 
throws IOException {
+        Record<ComparativeOutputBuffer, ComparativeOutputBuffer> next =
+            Record.create(curr.getKey(), curr.getValue());
+        curr = null;
+        return next;
       }
     };
   }
@@ -284,8 +284,11 @@ public class RMRecordsReader<K, V, C> {
       private Record<K, C> curr = null;
 
       @Override
-      public boolean next() throws IOException {
+      public boolean hasNext() throws IOException {
         try {
+          if (curr != null) {
+            return true;
+          }
           curr = results.take();
           return curr != null;
         } catch (InterruptedException e) {
@@ -293,7 +296,12 @@ public class RMRecordsReader<K, V, C> {
         }
       }
 
-      @Override
+      public Record<K, C> next() throws IOException {
+        Record record = Record.create(getCurrentKey(), getCurrentValue());
+        curr = null;
+        return record;
+      }
+
       public K getCurrentKey() throws IOException {
         if (raw) {
           ComparativeOutputBuffer keyBuffer = (ComparativeOutputBuffer) 
curr.getKey();
@@ -305,7 +313,6 @@ public class RMRecordsReader<K, V, C> {
         }
       }
 
-      @Override
       public C getCurrentValue() throws IOException {
         if (raw) {
           ComparativeOutputBuffer valueBuffer = (ComparativeOutputBuffer) 
curr.getValue();
diff --git 
a/client/src/test/java/org/apache/uniffle/client/record/reader/RMRecordsReaderTest.java
 
b/client/src/test/java/org/apache/uniffle/client/record/reader/RMRecordsReaderTest.java
index 2c543ad25..b1310260b 100644
--- 
a/client/src/test/java/org/apache/uniffle/client/record/reader/RMRecordsReaderTest.java
+++ 
b/client/src/test/java/org/apache/uniffle/client/record/reader/RMRecordsReaderTest.java
@@ -31,6 +31,7 @@ import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.ValueSource;
 
 import org.apache.uniffle.client.api.ShuffleServerClient;
+import org.apache.uniffle.client.record.Record;
 import org.apache.uniffle.client.record.writer.Combiner;
 import org.apache.uniffle.client.record.writer.SumByKeyCombiner;
 import org.apache.uniffle.common.ShuffleServerInfo;
@@ -102,9 +103,10 @@ public class RMRecordsReaderTest {
     readerSpy.start();
     int index = 0;
     KeyValueReader keyValueReader = readerSpy.keyValueReader();
-    while (keyValueReader.next()) {
-      assertEquals(SerializerUtils.genData(keyClass, index), 
keyValueReader.getCurrentKey());
-      assertEquals(SerializerUtils.genData(valueClass, index), 
keyValueReader.getCurrentValue());
+    while (keyValueReader.hasNext()) {
+      Record record = keyValueReader.next();
+      assertEquals(SerializerUtils.genData(keyClass, index), record.getKey());
+      assertEquals(SerializerUtils.genData(valueClass, index), 
record.getValue());
       index++;
     }
     assertEquals(RECORDS_NUM, index);
@@ -172,8 +174,9 @@ public class RMRecordsReaderTest {
     readerSpy.start();
     int index = 0;
     KeyValueReader keyValueReader = readerSpy.keyValueReader();
-    while (keyValueReader.next()) {
-      assertEquals(SerializerUtils.genData(keyClass, index), 
keyValueReader.getCurrentKey());
+    while (keyValueReader.hasNext()) {
+      Record record = keyValueReader.next();
+      assertEquals(SerializerUtils.genData(keyClass, index), record.getKey());
       Object value = SerializerUtils.genData(valueClass, index);
       Object newValue = value;
       if (index % 2 == 0) {
@@ -183,7 +186,7 @@ public class RMRecordsReaderTest {
           newValue = (int) value * 2;
         }
       }
-      assertEquals(newValue, keyValueReader.getCurrentValue());
+      assertEquals(newValue, record.getValue());
       index++;
     }
     assertEquals(RECORDS_NUM * 2, index);
@@ -250,9 +253,10 @@ public class RMRecordsReaderTest {
     readerSpy.start();
     int index = 0;
     KeyValueReader keyValueReader = readerSpy.keyValueReader();
-    while (keyValueReader.next()) {
-      assertEquals(SerializerUtils.genData(keyClass, index), 
keyValueReader.getCurrentKey());
-      assertEquals(SerializerUtils.genData(valueClass, index), 
keyValueReader.getCurrentValue());
+    while (keyValueReader.hasNext()) {
+      Record record = keyValueReader.next();
+      assertEquals(SerializerUtils.genData(keyClass, index), record.getKey());
+      assertEquals(SerializerUtils.genData(valueClass, index), 
record.getValue());
       index++;
     }
     assertEquals(RECORDS_NUM * 6, index);
@@ -322,10 +326,10 @@ public class RMRecordsReaderTest {
     readerSpy.start();
     int index = 0;
     KeyValueReader keyValueReader = readerSpy.keyValueReader();
-    while (keyValueReader.next()) {
-      assertEquals(SerializerUtils.genData(keyClass, index), 
keyValueReader.getCurrentKey());
-      assertEquals(
-          SerializerUtils.genData(valueClass, index * 2), 
keyValueReader.getCurrentValue());
+    while (keyValueReader.hasNext()) {
+      Record record = keyValueReader.next();
+      assertEquals(SerializerUtils.genData(keyClass, index), record.getKey());
+      assertEquals(SerializerUtils.genData(valueClass, index * 2), 
record.getValue());
       index++;
     }
     assertEquals(RECORDS_NUM * 6, index);
diff --git 
a/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTest.java
 
b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTest.java
index 29336daec..760b27b8a 100644
--- 
a/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTest.java
+++ 
b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTest.java
@@ -46,6 +46,7 @@ import org.roaringbitmap.longlong.Roaring64NavigableMap;
 
 import org.apache.uniffle.client.factory.ShuffleClientFactory;
 import org.apache.uniffle.client.impl.ShuffleWriteClientImpl;
+import org.apache.uniffle.client.record.Record;
 import org.apache.uniffle.client.record.reader.KeyValueReader;
 import org.apache.uniffle.client.record.reader.RMRecordsReader;
 import org.apache.uniffle.client.record.writer.Combiner;
@@ -309,9 +310,10 @@ public class RemoteMergeShuffleWithRssClientTest extends 
ShuffleReadWriteBase {
     reader.start();
     int index = 0;
     KeyValueReader keyValueReader = reader.keyValueReader();
-    while (keyValueReader.next()) {
-      assertEquals(SerializerUtils.genData(keyClass, index), 
keyValueReader.getCurrentKey());
-      assertEquals(SerializerUtils.genData(valueClass, index), 
keyValueReader.getCurrentValue());
+    while (keyValueReader.hasNext()) {
+      Record record = keyValueReader.next();
+      assertEquals(SerializerUtils.genData(keyClass, index), record.getKey());
+      assertEquals(SerializerUtils.genData(valueClass, index), 
record.getValue());
       index++;
     }
     assertEquals(5 * RECORD_NUMBER, index);
@@ -479,8 +481,9 @@ public class RemoteMergeShuffleWithRssClientTest extends 
ShuffleReadWriteBase {
     reader.start();
     int index = 0;
     KeyValueReader keyValueReader = reader.keyValueReader();
-    while (keyValueReader.next()) {
-      assertEquals(SerializerUtils.genData(keyClass, index), 
keyValueReader.getCurrentKey());
+    while (keyValueReader.hasNext()) {
+      Record record = keyValueReader.next();
+      assertEquals(SerializerUtils.genData(keyClass, index), record.getKey());
       Object value = SerializerUtils.genData(valueClass, index);
       Object newValue = value;
       if (index % 3 != 1) {
@@ -490,7 +493,7 @@ public class RemoteMergeShuffleWithRssClientTest extends 
ShuffleReadWriteBase {
           newValue = (int) value * 2;
         }
       }
-      assertEquals(newValue, keyValueReader.getCurrentValue());
+      assertEquals(newValue, record.getValue());
       index++;
     }
     assertEquals(3 * RECORD_NUMBER, index);
@@ -697,9 +700,10 @@ public class RemoteMergeShuffleWithRssClientTest extends 
ShuffleReadWriteBase {
     reader.start();
     int index = 0;
     KeyValueReader keyValueReader = reader.keyValueReader();
-    while (keyValueReader.next()) {
-      assertEquals(SerializerUtils.genData(keyClass, index), 
keyValueReader.getCurrentKey());
-      assertEquals(SerializerUtils.genData(valueClass, index), 
keyValueReader.getCurrentValue());
+    while (keyValueReader.hasNext()) {
+      Record record = keyValueReader.next();
+      assertEquals(SerializerUtils.genData(keyClass, index), record.getKey());
+      assertEquals(SerializerUtils.genData(valueClass, index), 
record.getValue());
       index++;
     }
     assertEquals(6 * RECORD_NUMBER, index);
@@ -910,10 +914,10 @@ public class RemoteMergeShuffleWithRssClientTest extends 
ShuffleReadWriteBase {
     reader.start();
     int index = 0;
     KeyValueReader keyValueReader = reader.keyValueReader();
-    while (keyValueReader.next()) {
-      assertEquals(SerializerUtils.genData(keyClass, index), 
keyValueReader.getCurrentKey());
-      assertEquals(
-          SerializerUtils.genData(valueClass, index * 2), 
keyValueReader.getCurrentValue());
+    while (keyValueReader.hasNext()) {
+      Record record = keyValueReader.next();
+      assertEquals(SerializerUtils.genData(keyClass, index), record.getKey());
+      assertEquals(SerializerUtils.genData(valueClass, index * 2), 
record.getValue());
       index++;
     }
     assertEquals(6 * RECORD_NUMBER, index);
diff --git 
a/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java
 
b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java
index 402d17ead..e63aee262 100644
--- 
a/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java
+++ 
b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java
@@ -47,6 +47,7 @@ import org.roaringbitmap.longlong.Roaring64NavigableMap;
 
 import org.apache.uniffle.client.factory.ShuffleClientFactory;
 import org.apache.uniffle.client.impl.ShuffleWriteClientImpl;
+import org.apache.uniffle.client.record.Record;
 import org.apache.uniffle.client.record.reader.KeyValueReader;
 import org.apache.uniffle.client.record.reader.RMRecordsReader;
 import org.apache.uniffle.client.record.writer.Combiner;
@@ -322,9 +323,10 @@ public class 
RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed extends Shuff
     reader.start();
     int index = 0;
     KeyValueReader keyValueReader = reader.keyValueReader();
-    while (keyValueReader.next()) {
-      assertEquals(SerializerUtils.genData(keyClass, index), 
keyValueReader.getCurrentKey());
-      assertEquals(SerializerUtils.genData(valueClass, index), 
keyValueReader.getCurrentValue());
+    while (keyValueReader.hasNext()) {
+      Record record = keyValueReader.next();
+      assertEquals(SerializerUtils.genData(keyClass, index), record.getKey());
+      assertEquals(SerializerUtils.genData(valueClass, index), 
record.getValue());
       index++;
     }
     assertEquals(5 * RECORD_NUMBER, index);
@@ -493,8 +495,9 @@ public class 
RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed extends Shuff
     reader.start();
     int index = 0;
     KeyValueReader keyValueReader = reader.keyValueReader();
-    while (keyValueReader.next()) {
-      assertEquals(SerializerUtils.genData(keyClass, index), 
keyValueReader.getCurrentKey());
+    while (keyValueReader.hasNext()) {
+      Record record = keyValueReader.next();
+      assertEquals(SerializerUtils.genData(keyClass, index), record.getKey());
       Object value = SerializerUtils.genData(valueClass, index);
       Object newValue = value;
       if (index % 3 != 1) {
@@ -504,7 +507,7 @@ public class 
RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed extends Shuff
           newValue = (int) value * 2;
         }
       }
-      assertEquals(newValue, keyValueReader.getCurrentValue());
+      assertEquals(newValue, record.getValue());
       index++;
     }
     assertEquals(3 * RECORD_NUMBER, index);
@@ -711,9 +714,10 @@ public class 
RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed extends Shuff
     reader.start();
     int index = 0;
     KeyValueReader keyValueReader = reader.keyValueReader();
-    while (keyValueReader.next()) {
-      assertEquals(SerializerUtils.genData(keyClass, index), 
keyValueReader.getCurrentKey());
-      assertEquals(SerializerUtils.genData(valueClass, index), 
keyValueReader.getCurrentValue());
+    while (keyValueReader.hasNext()) {
+      Record record = keyValueReader.next();
+      assertEquals(SerializerUtils.genData(keyClass, index), record.getKey());
+      assertEquals(SerializerUtils.genData(valueClass, index), 
record.getValue());
       index++;
     }
     assertEquals(6 * RECORD_NUMBER, index);
@@ -925,10 +929,10 @@ public class 
RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed extends Shuff
     reader.start();
     int index = 0;
     KeyValueReader keyValueReader = reader.keyValueReader();
-    while (keyValueReader.next()) {
-      assertEquals(SerializerUtils.genData(keyClass, index), 
keyValueReader.getCurrentKey());
-      assertEquals(
-          SerializerUtils.genData(valueClass, index * 2), 
keyValueReader.getCurrentValue());
+    while (keyValueReader.hasNext()) {
+      Record record = keyValueReader.next();
+      assertEquals(SerializerUtils.genData(keyClass, index), record.getKey());
+      assertEquals(SerializerUtils.genData(valueClass, index * 2), 
record.getValue());
       index++;
     }
     assertEquals(6 * RECORD_NUMBER, index);

Reply via email to