This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 9641f9425 [CELEBORN-369] [FLINK] Add ut for 
RemoteShuffleResultPartition (#1297)
9641f9425 is described below

commit 9641f942558496876df3fb4ab5fc4d9907868667
Author: zhongqiangchen <[email protected]>
AuthorDate: Fri Mar 3 15:46:43 2023 +0800

    [CELEBORN-369] [FLINK] Add ut for RemoteShuffleResultPartition (#1297)
---
 .../flink/RemoteShuffleResultPartitionSuiteJ.java  | 538 ++++++++++++++++++++-
 1 file changed, 536 insertions(+), 2 deletions(-)

diff --git 
a/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartitionSuiteJ.java
 
b/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartitionSuiteJ.java
index 384adf9dd..62b71d712 100644
--- 
a/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartitionSuiteJ.java
+++ 
b/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartitionSuiteJ.java
@@ -17,7 +17,11 @@
 
 package org.apache.celeborn.plugin.flink;
 
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
 import static org.mockito.Matchers.anyBoolean;
+import static org.mockito.Matchers.anyInt;
+import static org.mockito.Matchers.anyString;
 import static org.mockito.Mockito.doNothing;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
@@ -25,28 +29,89 @@ import static org.mockito.Mockito.when;
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.time.Duration;
+import java.util.ArrayDeque;
 import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
+import java.util.Map;
+import java.util.Queue;
+import java.util.Random;
+import java.util.Set;
+import java.util.stream.IntStream;
 
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.core.memory.MemorySegmentFactory;
+import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
+import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.BufferCompressor;
+import org.apache.flink.runtime.io.network.buffer.BufferDecompressor;
 import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
 import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
 import org.apache.flink.util.function.SupplierWithException;
+import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
+import org.mockito.Mockito;
 
+import org.apache.celeborn.client.LifecycleManager;
+import org.apache.celeborn.client.ShuffleClient;
+import org.apache.celeborn.client.ShuffleClientImpl;
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.plugin.flink.buffer.BufferPacker;
 import org.apache.celeborn.plugin.flink.buffer.SortBuffer;
+import org.apache.celeborn.plugin.flink.utils.BufferUtils;
 
 public class RemoteShuffleResultPartitionSuiteJ {
-  private BufferCompressor bufferCompressor = new BufferCompressor(32 * 1024, 
"lz4");
+  private final int networkBufferSize = 32 * 1024;
+  private BufferCompressor bufferCompressor = new 
BufferCompressor(networkBufferSize, "lz4");
   private RemoteShuffleOutputGate remoteShuffleOutputGate = 
mock(RemoteShuffleOutputGate.class);
+  private final String compressCodec = "LZ4";
+  private final CelebornConf conf = new CelebornConf();
+  BufferDecompressor bufferDecompressor = new 
BufferDecompressor(networkBufferSize, "LZ4");
+
+  private static final int totalBuffers = 1000;
+
+  private static final int bufferSize = 1024;
+
+  private NetworkBufferPool globalBufferPool;
+
+  private BufferPool sortBufferPool;
+
+  private BufferPool nettyBufferPool;
+
+  private RemoteShuffleResultPartition partitionWriter;
+
+  private FakedRemoteShuffleOutputGate outputGate;
 
   @Before
-  public void setup() {}
+  public void setup() {
+    globalBufferPool = new NetworkBufferPool(totalBuffers, bufferSize);
+  }
+
+  @After
+  public void tearDown() throws Exception {
+    if (outputGate != null) {
+      outputGate.release();
+    }
+
+    if (sortBufferPool != null) {
+      sortBufferPool.lazyDestroy();
+    }
+    if (nettyBufferPool != null) {
+      nettyBufferPool.lazyDestroy();
+    }
+    assertEquals(totalBuffers, 
globalBufferPool.getNumberOfAvailableMemorySegments());
+    globalBufferPool.destroy();
+  }
 
   @Test
   public void tesSimpleFlush() throws IOException, InterruptedException {
@@ -88,4 +153,473 @@ public class RemoteShuffleResultPartitionSuiteJ {
     factories.add(() -> networkBufferPool.createBufferPool(numForOutputGate, 
numForOutputGate));
     return factories;
   }
+
+  @Test
+  public void testWriteNormalRecordWithCompressionEnabled() throws Exception {
+    testWriteNormalRecord(true);
+  }
+
+  @Test
+  public void testWriteNormalRecordWithCompressionDisabled() throws Exception {
+    testWriteNormalRecord(false);
+  }
+
+  @Test
+  public void testWriteLargeRecord() throws Exception {
+    int numSubpartitions = 2;
+    int numBuffers = 100;
+    initResultPartitionWriter(numSubpartitions, 10, 200, false, conf, 10);
+
+    partitionWriter.setup();
+
+    byte[] dataWritten = new byte[bufferSize * numBuffers];
+    Random random = new Random();
+    random.nextBytes(dataWritten);
+    ByteBuffer recordWritten = ByteBuffer.wrap(dataWritten);
+    partitionWriter.emitRecord(recordWritten, 0);
+    assertEquals(0, sortBufferPool.bestEffortGetNumOfUsedBuffers());
+
+    partitionWriter.finish();
+    partitionWriter.close();
+
+    List<Buffer> receivedBuffers = outputGate.getReceivedBuffers()[0];
+
+    ByteBuffer recordRead = ByteBuffer.allocate(bufferSize * numBuffers);
+    for (Buffer buffer : receivedBuffers) {
+      if (buffer.isBuffer()) {
+        recordRead.put(
+            buffer.getNioBuffer(
+                BufferUtils.HEADER_LENGTH, buffer.readableBytes() - 
BufferUtils.HEADER_LENGTH));
+      }
+    }
+    recordWritten.rewind();
+    recordRead.flip();
+    assertEquals(recordWritten, recordRead);
+  }
+
+  @Test
+  public void testBroadcastLargeRecord() throws Exception {
+    int numSubpartitions = 2;
+    int numBuffers = 100;
+    initResultPartitionWriter(numSubpartitions, 10, 200, false, conf, 10);
+
+    partitionWriter.setup();
+
+    byte[] dataWritten = new byte[bufferSize * numBuffers];
+    Random random = new Random();
+    random.nextBytes(dataWritten);
+    ByteBuffer recordWritten = ByteBuffer.wrap(dataWritten);
+    partitionWriter.broadcastRecord(recordWritten);
+    assertEquals(0, sortBufferPool.bestEffortGetNumOfUsedBuffers());
+
+    partitionWriter.finish();
+    partitionWriter.close();
+
+    ByteBuffer recordRead0 = ByteBuffer.allocate(bufferSize * numBuffers);
+    for (Buffer buffer : outputGate.getReceivedBuffers()[0]) {
+      if (buffer.isBuffer()) {
+        recordRead0.put(
+            buffer.getNioBuffer(
+                BufferUtils.HEADER_LENGTH, buffer.readableBytes() - 
BufferUtils.HEADER_LENGTH));
+      }
+    }
+    recordWritten.rewind();
+    recordRead0.flip();
+    assertEquals(recordWritten, recordRead0);
+
+    ByteBuffer recordRead1 = ByteBuffer.allocate(bufferSize * numBuffers);
+    for (Buffer buffer : outputGate.getReceivedBuffers()[1]) {
+      if (buffer.isBuffer()) {
+        recordRead1.put(
+            buffer.getNioBuffer(
+                BufferUtils.HEADER_LENGTH, buffer.readableBytes() - 
BufferUtils.HEADER_LENGTH));
+      }
+    }
+    recordWritten.rewind();
+    recordRead1.flip();
+    assertEquals(recordWritten, recordRead1);
+  }
+
+  @Test
+  public void testFlush() throws Exception {
+    int numSubpartitions = 10;
+
+    initResultPartitionWriter(numSubpartitions, 10, 20, false, conf, 10);
+    partitionWriter.setup();
+
+    partitionWriter.emitRecord(ByteBuffer.allocate(bufferSize), 0);
+    partitionWriter.emitRecord(ByteBuffer.allocate(bufferSize), 1);
+    assertEquals(3, sortBufferPool.bestEffortGetNumOfUsedBuffers());
+
+    partitionWriter.broadcastRecord(ByteBuffer.allocate(bufferSize));
+    assertEquals(2, sortBufferPool.bestEffortGetNumOfUsedBuffers());
+
+    partitionWriter.flush(0);
+    assertEquals(0, sortBufferPool.bestEffortGetNumOfUsedBuffers());
+
+    partitionWriter.emitRecord(ByteBuffer.allocate(bufferSize), 2);
+    partitionWriter.emitRecord(ByteBuffer.allocate(bufferSize), 3);
+    assertEquals(3, sortBufferPool.bestEffortGetNumOfUsedBuffers());
+
+    partitionWriter.flushAll();
+    assertEquals(0, sortBufferPool.bestEffortGetNumOfUsedBuffers());
+
+    partitionWriter.finish();
+    partitionWriter.close();
+  }
+
+  private void testWriteNormalRecord(boolean compressionEnabled) throws 
Exception {
+    int numSubpartitions = 4;
+    int numRecords = 100;
+    Random random = new Random();
+
+    initResultPartitionWriter(numSubpartitions, 100, 500, compressionEnabled, 
conf, 10);
+    partitionWriter.setup();
+    assertTrue(outputGate.isSetup());
+
+    Queue<DataAndType>[] dataWritten = new Queue[numSubpartitions];
+    IntStream.range(0, numSubpartitions).forEach(i -> dataWritten[i] = new 
ArrayDeque<>());
+    int[] numBytesWritten = new int[numSubpartitions];
+    Arrays.fill(numBytesWritten, 0);
+
+    for (int i = 0; i < numRecords; i++) {
+      byte[] data = new byte[random.nextInt(2 * bufferSize) + 1];
+      if (compressionEnabled) {
+        byte randomByte = (byte) random.nextInt();
+        Arrays.fill(data, randomByte);
+      } else {
+        random.nextBytes(data);
+      }
+      ByteBuffer record = ByteBuffer.wrap(data);
+      boolean isBroadCast = random.nextBoolean();
+
+      if (isBroadCast) {
+        partitionWriter.broadcastRecord(record);
+        IntStream.range(0, numSubpartitions)
+            .forEach(
+                subpartition ->
+                    recordDataWritten(
+                        record,
+                        Buffer.DataType.DATA_BUFFER,
+                        subpartition,
+                        dataWritten,
+                        numBytesWritten));
+      } else {
+        int subpartition = random.nextInt(numSubpartitions);
+        partitionWriter.emitRecord(record, subpartition);
+        recordDataWritten(
+            record, Buffer.DataType.DATA_BUFFER, subpartition, dataWritten, 
numBytesWritten);
+      }
+    }
+
+    partitionWriter.finish();
+    assertTrue(outputGate.isFinished());
+    partitionWriter.close();
+    assertTrue(outputGate.isClosed());
+
+    for (int subpartition = 0; subpartition < numSubpartitions; 
++subpartition) {
+      ByteBuffer record = 
EventSerializer.toSerializedEvent(EndOfPartitionEvent.INSTANCE);
+      recordDataWritten(
+          record, Buffer.DataType.EVENT_BUFFER, subpartition, dataWritten, 
numBytesWritten);
+    }
+
+    outputGate
+        .getFinishedRegions()
+        .forEach(
+            regionIndex -> 
assertTrue(outputGate.getNumBuffersByRegion().containsKey(regionIndex)));
+
+    int[] numBytesRead = new int[numSubpartitions];
+    List<Buffer>[] receivedBuffers = outputGate.getReceivedBuffers();
+    List<Buffer>[] validateTarget = new List[numSubpartitions];
+    Arrays.fill(numBytesRead, 0);
+    for (int i = 0; i < numSubpartitions; i++) {
+      validateTarget[i] = new ArrayList<>();
+      for (Buffer buffer : receivedBuffers[i]) {
+        for (Buffer unpackedBuffer : BufferPacker.unpack(buffer.asByteBuf())) {
+          if (compressionEnabled && unpackedBuffer.isCompressed()) {
+            Buffer decompressedBuffer =
+                
bufferDecompressor.decompressToIntermediateBuffer(unpackedBuffer);
+            ByteBuffer decompressed = 
decompressedBuffer.getNioBufferReadable();
+            int numBytes = decompressed.remaining();
+            MemorySegment segment = 
MemorySegmentFactory.allocateUnpooledSegment(numBytes);
+            segment.put(0, decompressed, numBytes);
+            decompressedBuffer.recycleBuffer();
+            validateTarget[i].add(
+                new NetworkBuffer(segment, buf -> {}, 
unpackedBuffer.getDataType(), numBytes));
+            numBytesRead[i] += numBytes;
+          } else {
+            numBytesRead[i] += unpackedBuffer.readableBytes();
+            validateTarget[i].add(unpackedBuffer);
+          }
+        }
+      }
+    }
+    checkWriteReadResult(
+        numSubpartitions, numBytesWritten, numBytesRead, dataWritten, validateTarget);
+  }
+
+  private void initResultPartitionWriter(
+      int numSubpartitions,
+      int sortBufferPoolSize,
+      int nettyBufferPoolSize,
+      boolean compressionEnabled,
+      CelebornConf conf,
+      int numMappers)
+      throws Exception {
+
+    sortBufferPool = globalBufferPool.createBufferPool(sortBufferPoolSize, 
sortBufferPoolSize);
+    nettyBufferPool = globalBufferPool.createBufferPool(nettyBufferPoolSize, 
nettyBufferPoolSize);
+
+    outputGate =
+        new FakedRemoteShuffleOutputGate(
+            getShuffleDescriptor(), numSubpartitions, () -> nettyBufferPool, 
conf, numMappers);
+    outputGate.setup();
+
+    if (compressionEnabled) {
+      partitionWriter =
+          new RemoteShuffleResultPartition(
+              "RemoteShuffleResultPartitionWriterTest",
+              0,
+              new ResultPartitionID(),
+              ResultPartitionType.BLOCKING,
+              numSubpartitions,
+              numSubpartitions,
+              bufferSize,
+              new ResultPartitionManager(),
+              bufferCompressor,
+              () -> sortBufferPool,
+              outputGate);
+    } else {
+      partitionWriter =
+          new RemoteShuffleResultPartition(
+              "RemoteShuffleResultPartitionWriterTest",
+              0,
+              new ResultPartitionID(),
+              ResultPartitionType.BLOCKING,
+              numSubpartitions,
+              numSubpartitions,
+              bufferSize,
+              new ResultPartitionManager(),
+              null,
+              () -> sortBufferPool,
+              outputGate);
+    }
+  }
+
+  private void recordDataWritten(
+      ByteBuffer record,
+      Buffer.DataType dataType,
+      int subpartition,
+      Queue<DataAndType>[] dataWritten,
+      int[] numBytesWritten) {
+
+    record.rewind();
+    dataWritten[subpartition].add(new DataAndType(record, dataType));
+    numBytesWritten[subpartition] += record.remaining();
+  }
+
+  private static class FakedRemoteShuffleOutputGate extends 
RemoteShuffleOutputGate {
+
+    private boolean isSetup;
+    private boolean isFinished;
+    private boolean isClosed;
+    private final List<Buffer>[] receivedBuffers;
+    private final Map<Integer, Integer> numBuffersByRegion;
+    private final Set<Integer> finishedRegions;
+    private int currentRegionIndex;
+    private boolean currentIsBroadcast;
+
+    FakedRemoteShuffleOutputGate(
+        RemoteShuffleDescriptor shuffleDescriptor,
+        int numSubpartitions,
+        SupplierWithException<BufferPool, IOException> bufferPoolFactory,
+        CelebornConf celebornConf,
+        int numMappers) {
+
+      super(
+          shuffleDescriptor,
+          numSubpartitions,
+          bufferSize,
+          bufferPoolFactory,
+          celebornConf,
+          numMappers);
+      isSetup = false;
+      isFinished = false;
+      isClosed = false;
+      numBuffersByRegion = new HashMap<>();
+      finishedRegions = new HashSet<>();
+      currentRegionIndex = -1;
+      receivedBuffers = new ArrayList[numSubpartitions];
+      IntStream.range(0, numSubpartitions).forEach(i -> receivedBuffers[i] = 
new ArrayList<>());
+      currentIsBroadcast = false;
+    }
+
+    @Override
+    ShuffleClient createWriteClient() {
+      ShuffleClient client = mock(ShuffleClientImpl.class);
+
+      doNothing().when(client).cleanup(anyString(), anyInt(), anyInt(), 
anyInt());
+      return client;
+    }
+
+    @Override
+    public void setup() throws IOException, InterruptedException {
+      bufferPool = bufferPoolFactory.get();
+      isSetup = true;
+    }
+
+    @Override
+    public void write(Buffer buffer, int subIdx) {
+      if (currentIsBroadcast) {
+        assertEquals(0, subIdx);
+        ByteBuffer byteBuffer = buffer.getNioBufferReadable();
+        for (int i = 0; i < numSubs; i++) {
+          int numBytes = buffer.readableBytes();
+          MemorySegment segment = 
MemorySegmentFactory.allocateUnpooledSegment(numBytes);
+          byteBuffer.rewind();
+          segment.put(0, byteBuffer, numBytes);
+          receivedBuffers[i].add(
+              new NetworkBuffer(
+                  segment, buf -> {}, buffer.getDataType(), 
buffer.isCompressed(), numBytes));
+        }
+        buffer.recycleBuffer();
+      } else {
+        receivedBuffers[subIdx].add(buffer);
+      }
+      if (numBuffersByRegion.containsKey(currentRegionIndex)) {
+        int prev = numBuffersByRegion.get(currentRegionIndex);
+        numBuffersByRegion.put(currentRegionIndex, prev + 1);
+      } else {
+        numBuffersByRegion.put(currentRegionIndex, 1);
+      }
+    }
+
+    @Override
+    public void regionStart(boolean isBroadcast) {
+      currentIsBroadcast = isBroadcast;
+      currentRegionIndex++;
+    }
+
+    @Override
+    public void regionFinish() {
+      if (finishedRegions.contains(currentRegionIndex)) {
+        throw new IllegalStateException("Unexpected region: " + 
currentRegionIndex);
+      }
+      finishedRegions.add(currentRegionIndex);
+    }
+
+    @Override
+    public void finish() throws InterruptedException {
+      isFinished = true;
+    }
+
+    @Override
+    public void close() {
+      isClosed = true;
+    }
+
+    public List<Buffer>[] getReceivedBuffers() {
+      return receivedBuffers;
+    }
+
+    public Map<Integer, Integer> getNumBuffersByRegion() {
+      return numBuffersByRegion;
+    }
+
+    public Set<Integer> getFinishedRegions() {
+      return finishedRegions;
+    }
+
+    public boolean isSetup() {
+      return isSetup;
+    }
+
+    public boolean isFinished() {
+      return isFinished;
+    }
+
+    public boolean isClosed() {
+      return isClosed;
+    }
+
+    public void release() throws Exception {
+      IntStream.range(0, numSubs)
+          .forEach(
+              subpartitionIndex -> {
+                
receivedBuffers[subpartitionIndex].forEach(Buffer::recycleBuffer);
+                receivedBuffers[subpartitionIndex].clear();
+              });
+      numBuffersByRegion.clear();
+      finishedRegions.clear();
+      super.close();
+    }
+  }
+
+  private RemoteShuffleDescriptor getShuffleDescriptor() throws Exception {
+    Random random = new Random();
+    byte[] bytes = new byte[16];
+    random.nextBytes(bytes);
+    LifecycleManager.ShuffleTask shuffleTask = 
Mockito.mock(LifecycleManager.ShuffleTask.class);
+    Mockito.when(shuffleTask.attemptId()).thenReturn(1);
+    Mockito.when(shuffleTask.mapId()).thenReturn(1);
+    Mockito.when(shuffleTask.shuffleId()).thenReturn(1);
+    return new RemoteShuffleDescriptor(
+        new JobID(bytes).toString(),
+        new JobID(bytes).toString(),
+        new ResultPartitionID(),
+        new RemoteShuffleResource("1", 2, new 
ShuffleResourceDescriptor(shuffleTask)));
+  }
+
+  /** Data written and its {@link Buffer.DataType}. */
+  public static class DataAndType {
+    private final ByteBuffer data;
+    private final Buffer.DataType dataType;
+
+    DataAndType(ByteBuffer data, Buffer.DataType dataType) {
+      this.data = data;
+      this.dataType = dataType;
+    }
+  }
+
+  public static void checkWriteReadResult(
+      int numSubpartitions,
+      int[] numBytesWritten,
+      int[] numBytesRead,
+      Queue<DataAndType>[] dataWritten,
+      Collection<Buffer>[] buffersRead) {
+    for (int subpartitionIndex = 0; subpartitionIndex < numSubpartitions; 
++subpartitionIndex) {
+      assertEquals(numBytesWritten[subpartitionIndex], 
numBytesRead[subpartitionIndex]);
+
+      List<DataAndType> eventsWritten = new ArrayList<>();
+      List<Buffer> eventsRead = new ArrayList<>();
+
+      ByteBuffer subpartitionDataWritten = 
ByteBuffer.allocate(numBytesWritten[subpartitionIndex]);
+      for (DataAndType dataAndType : dataWritten[subpartitionIndex]) {
+        subpartitionDataWritten.put(dataAndType.data);
+        dataAndType.data.rewind();
+        if (dataAndType.dataType.isEvent()) {
+          eventsWritten.add(dataAndType);
+        }
+      }
+
+      ByteBuffer subpartitionDataRead = 
ByteBuffer.allocate(numBytesRead[subpartitionIndex]);
+      for (Buffer buffer : buffersRead[subpartitionIndex]) {
+        subpartitionDataRead.put(buffer.getNioBufferReadable());
+        if (!buffer.isBuffer()) {
+          eventsRead.add(buffer);
+        }
+      }
+
+      subpartitionDataWritten.flip();
+      subpartitionDataRead.flip();
+      assertEquals(subpartitionDataWritten, subpartitionDataRead);
+
+      assertEquals(eventsWritten.size(), eventsRead.size());
+      for (int i = 0; i < eventsWritten.size(); ++i) {
+        assertEquals(eventsWritten.get(i).dataType, 
eventsRead.get(i).getDataType());
+        assertEquals(eventsWritten.get(i).data, 
eventsRead.get(i).getNioBufferReadable());
+      }
+    }
+  }
 }

Reply via email to