This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new b11d11153 [CELEBORN-1490][CIP-6] Introduce tier producer in celeborn flink client
b11d11153 is described below

commit b11d111536497282b9e6c43c5993b845ad255a3a
Author: Weijie Guo <[email protected]>
AuthorDate: Fri Sep 20 10:50:26 2024 +0800

    [CELEBORN-1490][CIP-6] Introduce tier producer in celeborn flink client
    
    ### What changes were proposed in this pull request?
    Introduce tier producer in celeborn flink client
    
    Note: Only the last commit needs review.
    
    ### Why are the changes needed?
    The tier producer is the mediator that Flink hybrid shuffle uses to send data to Celeborn.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    Yes, by unit tests (BufferPackSuiteJ and CelebornTierMasterAgentSuiteJ are updated in this commit).
    
    Closes #2733 from reswqa/cip6-5-pr.
    
    Authored-by: Weijie Guo <[email protected]>
    Signed-off-by: Shuang <[email protected]>
---
 .../plugin/flink/RemoteShuffleOutputGate.java      |   5 +-
 .../celeborn/plugin/flink/buffer/BufferHeader.java |   9 +
 .../celeborn/plugin/flink/buffer/BufferPacker.java |  45 +-
 .../flink/buffer/ReceivedNoHeaderBufferPacker.java | 112 +++++
 .../celeborn/plugin/flink/utils/BufferUtils.java   |  20 +
 .../celeborn/plugin/flink/BufferPackSuiteJ.java    | 192 +++++++-
 .../plugin/flink/tiered/CelebornTierFactory.java   |  12 +-
 .../flink/tiered/CelebornTierProducerAgent.java    | 487 +++++++++++++++++++++
 .../tiered/CelebornTierMasterAgentSuiteJ.java      | 200 +++++++++
 9 files changed, 1043 insertions(+), 39 deletions(-)

diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGate.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGate.java
index d17a182a1..f695af14d 100644
--- 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGate.java
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGate.java
@@ -33,6 +33,7 @@ import org.apache.celeborn.common.CelebornConf;
 import org.apache.celeborn.common.exception.DriverChangedException;
 import org.apache.celeborn.common.identity.UserIdentifier;
 import org.apache.celeborn.common.protocol.PartitionLocation;
+import org.apache.celeborn.plugin.flink.buffer.BufferHeader;
 import org.apache.celeborn.plugin.flink.buffer.BufferPacker;
 import org.apache.celeborn.plugin.flink.readclient.FlinkShuffleClientImpl;
 import org.apache.celeborn.plugin.flink.utils.BufferUtils;
@@ -207,13 +208,13 @@ public class RemoteShuffleOutputGate {
   }
 
   /** Writes a piece of data to a subpartition. */
-  public void write(ByteBuf byteBuf, int subIdx) {
+  public void write(ByteBuf byteBuf, BufferHeader bufferHeader) {
     try {
       flinkShuffleClient.pushDataToLocation(
           shuffleId,
           mapId,
           attemptId,
-          subIdx,
+          bufferHeader.getSubPartitionId(),
           io.netty.buffer.Unpooled.wrappedBuffer(byteBuf.nioBuffer()),
           partitionLocation,
           () -> byteBuf.release());
diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/BufferHeader.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/BufferHeader.java
index 6dc6350ce..59e4d5010 100644
--- 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/BufferHeader.java
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/BufferHeader.java
@@ -37,6 +37,11 @@ public class BufferHeader {
     this(0, 0, 0, size + 2, dataType, isCompressed, size);
   }
 
+  public BufferHeader(
+      int subPartitionId, Buffer.DataType dataType, boolean isCompressed, int 
size) {
+    this(subPartitionId, 0, 0, size + 2, dataType, isCompressed, size);
+  }
+
   public BufferHeader(
       int subPartitionId,
       int attemptId,
@@ -54,6 +59,10 @@ public class BufferHeader {
     this.size = size;
   }
 
+  public int getSubPartitionId() {
+    return subPartitionId;
+  }
+
   public Buffer.DataType getDataType() {
     return dataType;
   }
diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/BufferPacker.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/BufferPacker.java
index d0a757f19..76a6c2ef7 100644
--- 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/BufferPacker.java
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/BufferPacker.java
@@ -33,7 +33,15 @@ import org.apache.celeborn.plugin.flink.utils.BufferUtils;
 import org.apache.celeborn.plugin.flink.utils.Utils;
 import org.apache.celeborn.reflect.DynMethods;
 
-/** Harness used to pack multiple partial buffers together as a full one. */
+/**
+ * Harness used to pack multiple partial buffers together as a full one. There are two Flink
+ * integration strategies: Remote Shuffle Service and Hybrid Shuffle. In the Remote Shuffle
+ * Service integration strategy, the {@link BufferPacker} receives buffers containing both
+ * shuffle data and the Celeborn header. The Hybrid Shuffle integration strategy employs the
+ * subclass {@link ReceivedNoHeaderBufferPacker}, which receives buffers containing only shuffle
+ * data. The two strategies therefore pack buffers in different ways, but the resulting packed
+ * buffer must be the same.
+ */
 public class BufferPacker {
   private static Logger logger = LoggerFactory.getLogger(BufferPacker.class);
 
@@ -41,14 +49,15 @@ public class BufferPacker {
     void accept(T var1, U var2) throws E;
   }
 
-  private final BiConsumerWithException<ByteBuf, Integer, 
InterruptedException> ripeBufferHandler;
+  protected final BiConsumerWithException<ByteBuf, BufferHeader, 
InterruptedException>
+      ripeBufferHandler;
 
-  private Buffer cachedBuffer;
+  protected Buffer cachedBuffer;
 
-  private int currentSubIdx = -1;
+  protected int currentSubIdx = -1;
 
   public BufferPacker(
-      BiConsumerWithException<ByteBuf, Integer, InterruptedException> 
ripeBufferHandler) {
+      BiConsumerWithException<ByteBuf, BufferHeader, InterruptedException> 
ripeBufferHandler) {
     this.ripeBufferHandler = ripeBufferHandler;
   }
 
@@ -71,7 +80,8 @@ public class BufferPacker {
       int targetSubIdx = currentSubIdx;
       currentSubIdx = subIdx;
       logBufferPack(false, dumpedBuffer.getDataType(), 
dumpedBuffer.readableBytes());
-      handleRipeBuffer(dumpedBuffer, targetSubIdx);
+      handleRipeBuffer(
+          dumpedBuffer, targetSubIdx, dumpedBuffer.getDataType(), 
dumpedBuffer.isCompressed());
     } else {
       /**
        * this is an optimization. if cachedBuffer can contain other buffer, 
then other buffer can
@@ -95,12 +105,13 @@ public class BufferPacker {
         cachedBuffer = buffer;
         logBufferPack(false, dumpedBuffer.getDataType(), 
dumpedBuffer.readableBytes());
 
-        handleRipeBuffer(dumpedBuffer, currentSubIdx);
+        handleRipeBuffer(
+            dumpedBuffer, currentSubIdx, dumpedBuffer.getDataType(), 
dumpedBuffer.isCompressed());
       }
     }
   }
 
-  private void logBufferPack(boolean isDrain, Buffer.DataType dataType, int 
length) {
+  protected void logBufferPack(boolean isDrain, Buffer.DataType dataType, int 
length) {
     logger.debug(
         "isDrain:{}, cachedBuffer pack partition:{} type:{}, length:{}",
         isDrain,
@@ -112,15 +123,27 @@ public class BufferPacker {
   public void drain() throws InterruptedException {
     if (cachedBuffer != null) {
       logBufferPack(true, cachedBuffer.getDataType(), 
cachedBuffer.readableBytes());
-      handleRipeBuffer(cachedBuffer, currentSubIdx);
+      handleRipeBuffer(
+          cachedBuffer, currentSubIdx, cachedBuffer.getDataType(), 
cachedBuffer.isCompressed());
     }
     cachedBuffer = null;
     currentSubIdx = -1;
   }
 
-  private void handleRipeBuffer(Buffer buffer, int subIdx) throws 
InterruptedException {
+  protected void handleRipeBuffer(
+      Buffer buffer, int subIdx, Buffer.DataType dataType, boolean 
isCompressed)
+      throws InterruptedException {
+    // Always set the compress flag to false, because the result buffer generated by
+    // BufferPacker needs to be split into multiple buffers in the unpack process. If the
+    // compress flag were set to true for this result buffer, the unpack process would throw
+    // an exception, as a compressed buffer cannot be sliced.
     buffer.setCompressed(false);
-    ripeBufferHandler.accept(buffer.asByteBuf(), subIdx);
+    ripeBufferHandler.accept(
+        buffer.asByteBuf(), new BufferHeader(subIdx, dataType, isCompressed, 
buffer.getSize()));
+  }
+
+  public boolean isEmpty() {
+    return cachedBuffer == null;
   }
 
   public void close() {
diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/ReceivedNoHeaderBufferPacker.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/ReceivedNoHeaderBufferPacker.java
new file mode 100644
index 000000000..09337ec4f
--- /dev/null
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/ReceivedNoHeaderBufferPacker.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink.buffer;
+
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+
+import org.apache.celeborn.plugin.flink.utils.BufferUtils;
+
+/**
+ * Harness used to pack multiple partial buffers together as a full one. It is currently used in
+ * the Flink hybrid shuffle integration strategy.
+ */
+public class ReceivedNoHeaderBufferPacker extends BufferPacker {
+
+  /** The flink buffer header of cached first buffer. */
+  private BufferHeader firstBufferHeader;
+
+  public ReceivedNoHeaderBufferPacker(
+      BiConsumerWithException<ByteBuf, BufferHeader, InterruptedException> 
ripeBufferHandler) {
+    super(ripeBufferHandler);
+  }
+
+  @Override
+  public void process(Buffer buffer, int subIdx) throws InterruptedException {
+    if (buffer == null) {
+      return;
+    }
+
+    if (buffer.readableBytes() == 0) {
+      buffer.recycleBuffer();
+      return;
+    }
+
+    if (cachedBuffer == null) {
+      // cache the first buffer and record flink buffer header of first buffer
+      cachedBuffer = buffer;
+      currentSubIdx = subIdx;
+      firstBufferHeader =
+          new BufferHeader(subIdx, buffer.getDataType(), 
buffer.isCompressed(), buffer.getSize());
+    } else if (currentSubIdx != subIdx) {
+      // drain the previous cached buffer and cache current buffer
+      Buffer dumpedBuffer = cachedBuffer;
+      cachedBuffer = buffer;
+      int targetSubIdx = currentSubIdx;
+      currentSubIdx = subIdx;
+      logBufferPack(false, dumpedBuffer.getDataType(), 
dumpedBuffer.readableBytes());
+      handleRipeBuffer(
+          dumpedBuffer, targetSubIdx, dumpedBuffer.getDataType(), 
dumpedBuffer.isCompressed());
+      firstBufferHeader =
+          new BufferHeader(subIdx, buffer.getDataType(), 
buffer.isCompressed(), buffer.getSize());
+    } else {
+      int bufferHeaderLength = BufferUtils.HEADER_LENGTH - 
BufferUtils.HEADER_LENGTH_PREFIX;
+      if (cachedBuffer.readableBytes() + buffer.readableBytes() + 
bufferHeaderLength
+          <= cachedBuffer.getMaxCapacity() - BufferUtils.HEADER_LENGTH) {
+        // if the cached buffer can contain the current buffer, then pack the current buffer into
+        // the cached buffer
+        ByteBuf byteBuf = cachedBuffer.asByteBuf();
+        byteBuf.writeByte(buffer.getDataType().ordinal());
+        byteBuf.writeBoolean(buffer.isCompressed());
+        byteBuf.writeInt(buffer.getSize());
+        byteBuf.writeBytes(buffer.asByteBuf(), 0, buffer.readableBytes());
+        logBufferPack(false, buffer.getDataType(), buffer.readableBytes() + 
bufferHeaderLength);
+
+        buffer.recycleBuffer();
+      } else {
+        // if the cached buffer cannot contain the current buffer, drain the cached buffer and
+        // cache the current buffer
+        Buffer dumpedBuffer = cachedBuffer;
+        cachedBuffer = buffer;
+        logBufferPack(false, dumpedBuffer.getDataType(), 
dumpedBuffer.readableBytes());
+
+        handleRipeBuffer(
+            dumpedBuffer, currentSubIdx, dumpedBuffer.getDataType(), 
dumpedBuffer.isCompressed());
+        firstBufferHeader =
+            new BufferHeader(subIdx, buffer.getDataType(), 
buffer.isCompressed(), buffer.getSize());
+      }
+    }
+  }
+
+  @Override
+  protected void handleRipeBuffer(
+      Buffer buffer, int subIdx, Buffer.DataType dataType, boolean 
isCompressed)
+      throws InterruptedException {
+    if (buffer == null || buffer.readableBytes() == 0) {
+      return;
+    }
+    // Always set the compress flag to false, because this buffer contains the Celeborn header
+    // and multiple Flink data buffers.
+    //
+    // It is crucial to keep this flag set to false: during the unpacking process this buffer
+    // must be sliced to extract the Flink data buffers, and the Flink {@link NetworkBuffer}
+    // cannot correctly slice a compressed buffer.
+    buffer.setCompressed(false);
+    ripeBufferHandler.accept(buffer.asByteBuf(), firstBufferHeader);
+  }
+}
diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/utils/BufferUtils.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/utils/BufferUtils.java
index 14599e477..999d1eb10 100644
--- 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/utils/BufferUtils.java
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/utils/BufferUtils.java
@@ -59,6 +59,26 @@ public class BufferUtils {
     buffer.setSize(dataLength + HEADER_LENGTH);
   }
 
+  /**
+   * Used in the Hybrid Shuffle integration strategy, in which the buffer contains only data
+   * (no header). Copies the data of the compressed buffer into the origin buffer.
+   */
+  public static void setCompressedDataWithoutHeader(Buffer buffer, Buffer 
compressedBuffer) {
+    checkArgument(buffer != null, "Must be not null.");
+    checkArgument(buffer.getReaderIndex() == 0, "Illegal reader index.");
+
+    boolean isCompressed = compressedBuffer != null && 
compressedBuffer.isCompressed();
+    int dataLength = isCompressed ? compressedBuffer.readableBytes() : 
buffer.readableBytes();
+    ByteBuf byteBuf = buffer.asByteBuf();
+    if (isCompressed) {
+      byteBuf.writerIndex(0);
+      byteBuf.writeBytes(compressedBuffer.asByteBuf());
+      // set the compression flag here, as we need it when writing the sub-header of this buffer
+      buffer.setCompressed(true);
+    }
+    buffer.setSize(dataLength);
+  }
+
   public static void setBufferHeader(
       ByteBuf byteBuf, Buffer.DataType dataType, boolean isCompressed, int 
dataLength) {
     byteBuf.writerIndex(0);
diff --git 
a/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/BufferPackSuiteJ.java
 
b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/BufferPackSuiteJ.java
index 2d5d5e78f..8f3c0ce6e 100644
--- 
a/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/BufferPackSuiteJ.java
+++ 
b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/BufferPackSuiteJ.java
@@ -23,20 +23,32 @@ import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 
 import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
 import java.util.List;
 
 import org.apache.commons.lang3.tuple.Pair;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
 import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
 import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+import org.apache.flink.shaded.netty4.io.netty.buffer.CompositeByteBuf;
+import org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
 
+import org.apache.celeborn.plugin.flink.buffer.BufferHeader;
 import org.apache.celeborn.plugin.flink.buffer.BufferPacker;
+import org.apache.celeborn.plugin.flink.buffer.ReceivedNoHeaderBufferPacker;
 import org.apache.celeborn.plugin.flink.utils.BufferUtils;
 
+@RunWith(Parameterized.class)
 public class BufferPackSuiteJ {
   private static final int BUFFER_SIZE = 20 + 16;
 
@@ -44,6 +56,18 @@ public class BufferPackSuiteJ {
 
   private BufferPool bufferPool;
 
+  private boolean bufferPackerReceivedBufferHasHeader;
+
+  public BufferPackSuiteJ(boolean bufferPackerReceivedBufferHasHeader) {
+    this.bufferPackerReceivedBufferHasHeader = 
bufferPackerReceivedBufferHasHeader;
+  }
+
+  @Parameterized.Parameters
+  public static Collection prepareData() {
+    Object[][] object = {{true}, {false}};
+    return Arrays.asList(object);
+  }
+
   @Before
   public void setup() throws Exception {
     networkBufferPool = new NetworkBufferPool(10, BUFFER_SIZE);
@@ -66,13 +90,14 @@ public class BufferPackSuiteJ {
     Integer subIdx = 2;
 
     List<ByteBuf> output = new ArrayList<>();
-    BufferPacker.BiConsumerWithException<ByteBuf, Integer, 
InterruptedException> ripeBufferHandler =
-        (ripe, sub) -> {
-          assertEquals(subIdx, sub);
-          output.add(ripe);
-        };
-
-    BufferPacker packer = new BufferPacker(ripeBufferHandler);
+    BufferPacker.BiConsumerWithException<ByteBuf, BufferHeader, 
InterruptedException>
+        ripeBufferHandler =
+            (ripe, header) -> {
+              assertEquals(subIdx, 
Integer.valueOf(header.getSubPartitionId()));
+              output.add(ripe);
+            };
+
+    BufferPacker packer = createBufferPakcer(ripeBufferHandler);
     packer.process(buffers.get(0), subIdx);
     packer.process(buffers.get(1), subIdx);
     packer.process(buffers.get(2), subIdx);
@@ -89,9 +114,12 @@ public class BufferPackSuiteJ {
     setDataType(buffers, EVENT_BUFFER, DATA_BUFFER, DATA_BUFFER);
 
     List<Pair<ByteBuf, Integer>> output = new ArrayList<>();
-    BufferPacker.BiConsumerWithException<ByteBuf, Integer, 
InterruptedException> ripeBufferHandler =
-        (ripe, sub) -> output.add(Pair.of(ripe, sub));
-    BufferPacker packer = new BufferPacker(ripeBufferHandler);
+    BufferPacker.BiConsumerWithException<ByteBuf, BufferHeader, 
InterruptedException>
+        ripeBufferHandler =
+            (ripe, header) ->
+                output.add(
+                    Pair.of(addBufferHeaderPossible(ripe, header), 
header.getSubPartitionId()));
+    BufferPacker packer = createBufferPakcer(ripeBufferHandler);
     fillBuffers(buffers, 0, 1, 2);
 
     packer.process(buffers.get(0), 2);
@@ -123,9 +151,12 @@ public class BufferPackSuiteJ {
     setDataType(buffers, EVENT_BUFFER, DATA_BUFFER, DATA_BUFFER);
 
     List<Pair<ByteBuf, Integer>> output = new ArrayList<>();
-    BufferPacker.BiConsumerWithException<ByteBuf, Integer, 
InterruptedException> ripeBufferHandler =
-        (ripe, sub) -> output.add(Pair.of(ripe, sub));
-    BufferPacker packer = new BufferPacker(ripeBufferHandler);
+    BufferPacker.BiConsumerWithException<ByteBuf, BufferHeader, 
InterruptedException>
+        ripeBufferHandler =
+            (ripe, header) ->
+                output.add(
+                    Pair.of(addBufferHeaderPossible(ripe, header), 
header.getSubPartitionId()));
+    BufferPacker packer = createBufferPakcer(ripeBufferHandler);
     fillBuffers(buffers, 0, 1, 2);
 
     packer.process(buffers.get(0), 0);
@@ -158,9 +189,12 @@ public class BufferPackSuiteJ {
     setDataType(buffers, EVENT_BUFFER, DATA_BUFFER, DATA_BUFFER);
 
     List<Pair<ByteBuf, Integer>> output = new ArrayList<>();
-    BufferPacker.BiConsumerWithException<ByteBuf, Integer, 
InterruptedException> ripeBufferHandler =
-        (ripe, sub) -> output.add(Pair.of(ripe, sub));
-    BufferPacker packer = new BufferPacker(ripeBufferHandler);
+    BufferPacker.BiConsumerWithException<ByteBuf, BufferHeader, 
InterruptedException>
+        ripeBufferHandler =
+            (ripe, header) ->
+                output.add(
+                    Pair.of(addBufferHeaderPossible(ripe, header), 
header.getSubPartitionId()));
+    BufferPacker packer = createBufferPakcer(ripeBufferHandler);
     fillBuffers(buffers, 0, 1, 2);
 
     packer.process(buffers.get(0), 0);
@@ -186,6 +220,59 @@ public class BufferPackSuiteJ {
     unpacked.forEach(Buffer::recycleBuffer);
   }
 
+  @Test
+  public void testPackMultipleBuffers() throws Exception {
+    int numBuffers = 7;
+    List<Buffer> buffers = new ArrayList<>();
+    buffers.add(buildSomeBuffer(100));
+    buffers.addAll(requestBuffers(numBuffers - 1));
+    setCompressed(buffers, true, true, true, false, false, false, true);
+    setDataType(
+        buffers,
+        EVENT_BUFFER,
+        DATA_BUFFER,
+        DATA_BUFFER,
+        EVENT_BUFFER,
+        DATA_BUFFER,
+        DATA_BUFFER,
+        EVENT_BUFFER);
+
+    List<Pair<ByteBuf, Integer>> output = new ArrayList<>();
+    BufferPacker.BiConsumerWithException<ByteBuf, BufferHeader, 
InterruptedException>
+        ripeBufferHandler =
+            (ripe, header) ->
+                output.add(
+                    Pair.of(addBufferHeaderPossible(ripe, header), 
header.getSubPartitionId()));
+    BufferPacker packer = createBufferPakcer(ripeBufferHandler);
+    fillBuffers(buffers, 0, 1, 2, 3, 4, 5, 6, 7);
+
+    for (int i = 0; i < buffers.size(); i++) {
+      packer.process(buffers.get(i), 0);
+    }
+    packer.drain();
+
+    List<Buffer> unpacked = new ArrayList<>();
+    for (int i = 0; i < output.size(); i++) {
+      Pair<ByteBuf, Integer> pair = output.get(i);
+      assertEquals(Integer.valueOf(0), pair.getRight());
+      unpacked.addAll(BufferPacker.unpack(pair.getLeft()));
+    }
+    assertEquals(7, unpacked.size());
+
+    checkIfCompressed(unpacked, true, true, true, false, false, false, true);
+    checkDataType(
+        unpacked,
+        EVENT_BUFFER,
+        DATA_BUFFER,
+        DATA_BUFFER,
+        EVENT_BUFFER,
+        DATA_BUFFER,
+        DATA_BUFFER,
+        EVENT_BUFFER);
+    verifyBuffers(unpacked, 0, 1, 2, 3, 4, 5, 6, 7);
+    unpacked.forEach(Buffer::recycleBuffer);
+  }
+
   @Test
   public void testFailedToHandleRipeBufferAndClose() throws Exception {
     List<Buffer> buffers = requestBuffers(1);
@@ -193,12 +280,13 @@ public class BufferPackSuiteJ {
     setDataType(buffers, DATA_BUFFER);
     fillBuffers(buffers, 0);
 
-    BufferPacker.BiConsumerWithException<ByteBuf, Integer, 
InterruptedException> ripeBufferHandler =
-        (ripe, sub) -> {
-          // ripe.release();
-          throw new RuntimeException("Test");
-        };
-    BufferPacker packer = new BufferPacker(ripeBufferHandler);
+    BufferPacker.BiConsumerWithException<ByteBuf, BufferHeader, 
InterruptedException>
+        ripeBufferHandler =
+            (ripe, header) -> {
+              // ripe.release();
+              throw new RuntimeException("Test");
+            };
+    BufferPacker packer = createBufferPakcer(ripeBufferHandler);
     System.out.println(buffers.get(0).refCnt());
     packer.process(buffers.get(0), 0);
     try {
@@ -248,8 +336,17 @@ public class BufferPackSuiteJ {
     for (int i = 0; i < buffers.size(); i++) {
       Buffer buffer = buffers.get(i);
       ByteBuf target = buffer.asByteBuf();
-      BufferUtils.setBufferHeader(target, buffer.getDataType(), 
buffer.isCompressed(), 4);
-      target.writerIndex(BufferUtils.HEADER_LENGTH);
+
+      if (bufferPackerReceivedBufferHasHeader) {
+        // If the buffer includes a header, we need to leave space for the header, so we should
+        // update the writer index to BufferUtils.HEADER_LENGTH.
+        BufferUtils.setBufferHeader(target, buffer.getDataType(), 
buffer.isCompressed(), 4);
+        target.writerIndex(BufferUtils.HEADER_LENGTH);
+      } else {
+        // If the buffer does not have a header, we can write data directly, starting from the
+        // beginning of the buffer.
+        target.writerIndex(0);
+      }
       target.writeInt(ints[i]);
     }
   }
@@ -260,4 +357,51 @@ public class BufferPackSuiteJ {
       assertEquals(expects[i], actual.getInt(0));
     }
   }
+
+  public static Buffer buildSomeBuffer(int size) {
+    final MemorySegment seg = 
MemorySegmentFactory.allocateUnpooledSegment(size);
+    return new NetworkBuffer(seg, MemorySegment::free, 
Buffer.DataType.DATA_BUFFER, size);
+  }
+
+  public ByteBuf addBufferHeaderPossible(ByteBuf byteBuf, BufferHeader 
bufferHeader) {
+    // Add the buffer header only when bufferPackerReceivedBufferHasHeader is false, i.e. when the
+    // packer under test is ReceivedNoHeaderBufferPacker and its output lacks the Celeborn header
+    if (bufferPackerReceivedBufferHasHeader) {
+      return byteBuf;
+    }
+
+    CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer();
+    // create a small buffer headerBuf to write the buffer header
+    ByteBuf headerBuf = Unpooled.buffer(BufferUtils.HEADER_LENGTH);
+
+    // write celeborn buffer header (subPartitionId(4) + attemptId(4) + nextBatchId(4) +
+    // compressedSize(4))
+    headerBuf.writeInt(bufferHeader.getSubPartitionId());
+    headerBuf.writeInt(0);
+    headerBuf.writeInt(0);
+    headerBuf.writeInt(
+        byteBuf.readableBytes() + (BufferUtils.HEADER_LENGTH - 
BufferUtils.HEADER_LENGTH_PREFIX));
+
+    // write flink buffer header (dataType(1) + isCompress(1) + size(4))
+    headerBuf.writeByte(bufferHeader.getDataType().ordinal());
+    headerBuf.writeBoolean(bufferHeader.isCompressed());
+    headerBuf.writeInt(bufferHeader.getSize());
+
+    // composite the headerBuf and data buffer together
+    compositeByteBuf.addComponents(true, headerBuf, byteBuf);
+    ByteBuf packedByteBuf = 
Unpooled.wrappedBuffer(compositeByteBuf.nioBuffer());
+    byteBuf.writerIndex(0);
+    byteBuf.writeBytes(packedByteBuf, 0, packedByteBuf.readableBytes());
+    return byteBuf;
+  }
+
+  public BufferPacker createBufferPakcer(
+      BufferPacker.BiConsumerWithException<ByteBuf, BufferHeader, 
InterruptedException>
+          ripeBufferHandler) {
+    if (bufferPackerReceivedBufferHasHeader) {
+      return new BufferPacker(ripeBufferHandler);
+    } else {
+      return new ReceivedNoHeaderBufferPacker(ripeBufferHandler);
+    }
+  }
 }
diff --git 
a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierFactory.java
 
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierFactory.java
index 326a11985..02306a5ad 100644
--- 
a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierFactory.java
+++ 
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierFactory.java
@@ -101,8 +101,16 @@ public class CelebornTierFactory implements TierFactory {
       ScheduledExecutorService ioExecutor,
       List<TierShuffleDescriptor> shuffleDescriptors,
       int maxRequestedBuffers) {
-    // TODO impl this in the follow-up PR.
-    return null;
+    return new CelebornTierProducerAgent(
+        conf,
+        partitionId,
+        numPartitions,
+        numSubpartitions,
+        NUM_BYTES_PER_SEGMENT,
+        bufferSizeBytes,
+        storageMemoryManager,
+        resourceRegistry,
+        shuffleDescriptors);
   }
 
   @Override
diff --git 
a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierProducerAgent.java
 
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierProducerAgent.java
new file mode 100644
index 000000000..aab2b3ae5
--- /dev/null
+++ 
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierProducerAgent.java
@@ -0,0 +1,487 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink.tiered;
+
+import static org.apache.celeborn.plugin.flink.utils.Utils.checkArgument;
+import static org.apache.celeborn.plugin.flink.utils.Utils.checkState;
+import static 
org.apache.flink.runtime.io.network.buffer.Buffer.DataType.END_OF_SEGMENT;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Optional;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.core.memory.MemorySegmentFactory;
+import org.apache.flink.runtime.io.network.api.EndOfSegmentEvent;
+import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
+import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStoragePartitionId;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageSubpartitionId;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.TieredStorageMemoryManager;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.TieredStorageResourceRegistry;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierProducerAgent;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierShuffleDescriptor;
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+import org.apache.flink.shaded.netty4.io.netty.buffer.CompositeByteBuf;
+import org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled;
+import org.apache.flink.util.ExceptionUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.exception.DriverChangedException;
+import org.apache.celeborn.common.protocol.PartitionLocation;
+import org.apache.celeborn.plugin.flink.buffer.BufferHeader;
+import org.apache.celeborn.plugin.flink.buffer.BufferPacker;
+import org.apache.celeborn.plugin.flink.buffer.ReceivedNoHeaderBufferPacker;
+import org.apache.celeborn.plugin.flink.readclient.FlinkShuffleClientImpl;
+import org.apache.celeborn.plugin.flink.utils.BufferUtils;
+import org.apache.celeborn.plugin.flink.utils.Utils;
+
+public class CelebornTierProducerAgent implements TierProducerAgent {
+
+  private static final Logger LOG = 
LoggerFactory.getLogger(CelebornTierProducerAgent.class);
+
+  private final int numBuffersPerSegment;
+
+  private final int bufferSizeBytes;
+
+  private final int numPartitions;
+
+  private final int numSubPartitions;
+
+  private final CelebornConf celebornConf;
+
+  private final TieredStorageMemoryManager memoryManager;
+
+  private final String applicationId;
+
+  private final int shuffleId;
+
+  private final int mapId;
+
+  private final int attemptId;
+
+  private final int partitionId;
+
+  private final String lifecycleManagerHost;
+
+  private final int lifecycleManagerPort;
+
+  private final long lifecycleManagerTimestamp;
+
+  private FlinkShuffleClientImpl flinkShuffleClient;
+
+  private BufferPacker bufferPacker;
+
+  private final int[] subPartitionSegmentIds;
+
+  private final int[] subPartitionSegmentBuffers;
+
+  private final int maxReviveTimes;
+
+  private PartitionLocation partitionLocation;
+
+  private boolean hasRegisteredShuffle;
+
+  private int currentRegionIndex = 0;
+
+  private int currentSubpartition = 0;
+
+  private boolean hasSentHandshake = false;
+
+  private boolean hasSentRegionStart = false;
+
+  private volatile boolean isReleased;
+
+  CelebornTierProducerAgent(
+      CelebornConf conf,
+      TieredStoragePartitionId partitionId,
+      int numPartitions,
+      int numSubPartitions,
+      int numBytesPerSegment,
+      int bufferSizeBytes,
+      TieredStorageMemoryManager memoryManager,
+      TieredStorageResourceRegistry resourceRegistry,
+      List<TierShuffleDescriptor> shuffleDescriptors) {
+    checkArgument(
+        numBytesPerSegment >= bufferSizeBytes, "One segment should contain at 
least one buffer.");
+    checkArgument(shuffleDescriptors.size() == 1, "There should be only one 
shuffle descriptor.");
+    TierShuffleDescriptor descriptor = shuffleDescriptors.get(0);
+    checkArgument(
+        descriptor instanceof TierShuffleDescriptorImpl,
+        "Wrong shuffle descriptor type " + descriptor.getClass());
+    TierShuffleDescriptorImpl shuffleDesc = (TierShuffleDescriptorImpl) 
descriptor;
+
+    this.numBuffersPerSegment = numBytesPerSegment / bufferSizeBytes;
+    this.bufferSizeBytes = bufferSizeBytes;
+    this.memoryManager = memoryManager;
+    this.numPartitions = numPartitions;
+    this.numSubPartitions = numSubPartitions;
+    this.celebornConf = conf;
+    this.subPartitionSegmentIds = new int[numSubPartitions];
+    this.subPartitionSegmentBuffers = new int[numSubPartitions];
+    this.maxReviveTimes = conf.clientPushMaxReviveTimes();
+
+    this.applicationId = shuffleDesc.getCelebornAppId();
+    this.shuffleId =
+        
shuffleDesc.getShuffleResource().getMapPartitionShuffleDescriptor().getShuffleId();
+    this.mapId = 
shuffleDesc.getShuffleResource().getMapPartitionShuffleDescriptor().getMapId();
+    this.attemptId =
+        
shuffleDesc.getShuffleResource().getMapPartitionShuffleDescriptor().getAttemptId();
+    this.partitionId =
+        
shuffleDesc.getShuffleResource().getMapPartitionShuffleDescriptor().getPartitionId();
+    this.lifecycleManagerHost = 
shuffleDesc.getShuffleResource().getLifecycleManagerHost();
+    this.lifecycleManagerPort = 
shuffleDesc.getShuffleResource().getLifecycleManagerPort();
+    this.lifecycleManagerTimestamp =
+        shuffleDesc.getShuffleResource().getLifecycleManagerTimestamp();
+    this.flinkShuffleClient = getShuffleClient();
+
+    Arrays.fill(subPartitionSegmentIds, -1);
+    Arrays.fill(subPartitionSegmentBuffers, 0);
+
+    this.bufferPacker = new ReceivedNoHeaderBufferPacker(this::write);
+    resourceRegistry.registerResource(partitionId, this::releaseResources);
+    registerShuffle();
+    try {
+      handshake();
+    } catch (IOException e) {
+      Utils.rethrowAsRuntimeException(e);
+    }
+  }
+
+  @Override
+  public boolean tryStartNewSegment(
+      TieredStorageSubpartitionId tieredStorageSubpartitionId, int segmentId, 
int minNumBuffers) {
+    int subPartitionId = tieredStorageSubpartitionId.getSubpartitionId();
+    checkState(
+        segmentId >= subPartitionSegmentIds[subPartitionId], "Wrong segment id 
" + segmentId);
+    subPartitionSegmentIds[subPartitionId] = segmentId;
+    // If the start segment rpc is sent, the worker side will expect that
+    // there must be at least one buffer will be written in the next moment.
+    try {
+      flinkShuffleClient.segmentStart(
+          shuffleId, mapId, attemptId, subPartitionId, segmentId, 
partitionLocation);
+    } catch (IOException e) {
+      Utils.rethrowAsRuntimeException(e);
+    }
+    return true;
+  }
+
+  @Override
+  public boolean tryWrite(
+      TieredStorageSubpartitionId tieredStorageSubpartitionId,
+      Buffer buffer,
+      Object bufferOwner,
+      int numRemainingConsecutiveBuffers) {
+    // It should be noted that, unlike RemoteShuffleOutputGate#write, the 
received buffer contains
+    // only
+    // and does not have any remaining space for writing the celeborn header.
+
+    int subPartitionId = tieredStorageSubpartitionId.getSubpartitionId();
+
+    if (subPartitionSegmentBuffers[subPartitionId] + 1 + 
numRemainingConsecutiveBuffers
+        >= numBuffersPerSegment) {
+      // End the current segment if the segment buffer count reaches the 
threshold
+      subPartitionSegmentBuffers[subPartitionId] = 0;
+      try {
+        bufferPacker.drain();
+      } catch (InterruptedException e) {
+        buffer.recycleBuffer();
+        ExceptionUtils.rethrow(e, "Failed to process buffer.");
+      }
+      appendEndOfSegmentBuffer(subPartitionId);
+      return false;
+    }
+
+    if (buffer.isBuffer()) {
+      memoryManager.transferBufferOwnership(
+          bufferOwner, CelebornTierFactory.getCelebornTierName(), buffer);
+    }
+
+    // write buffer to BufferPacker and record buffer count per subPartition 
per segment
+    processBuffer(buffer, subPartitionId);
+    subPartitionSegmentBuffers[subPartitionId]++;
+    return true;
+  }
+
+  @Override
+  public void close() {
+    if (hasSentRegionStart) {
+      regionFinish();
+    }
+    try {
+      if (hasRegisteredShuffle && partitionLocation != null) {
+        flinkShuffleClient.mapPartitionMapperEnd(
+            shuffleId, mapId, attemptId, numPartitions, 
partitionLocation.getId());
+      }
+    } catch (Exception e) {
+      Utils.rethrowAsRuntimeException(e);
+    }
+    bufferPacker.close();
+    bufferPacker = null;
+    flinkShuffleClient.cleanup(shuffleId, mapId, attemptId);
+    flinkShuffleClient = null;
+  }
+
+  private void regionStartOrFinish(int subPartitionId) {
+    // check whether the region should be started or finished
+    regionStart();
+    if (subPartitionId < currentSubpartition) {
+      // if the consumed subPartitionId is out of order, it means that should 
the previous region
+      // should be finished, and starting a new region.
+      regionFinish();
+      LOG.debug(
+          "Check region finish sub partition id {} and start next region {}",
+          subPartitionId,
+          currentRegionIndex);
+      regionStart();
+    }
+  }
+
+  private void regionStart() {
+    if (hasSentRegionStart) {
+      return;
+    }
+    regionStartWithRevive();
+  }
+
+  private void regionStartWithRevive() {
+    try {
+      int remainingReviveTimes = maxReviveTimes;
+      while (remainingReviveTimes-- > 0 && !hasSentRegionStart) {
+        Optional<PartitionLocation> revivePartition =
+            flinkShuffleClient.regionStart(
+                shuffleId, mapId, attemptId, partitionLocation, 
currentRegionIndex, false);
+        if (revivePartition.isPresent()) {
+          LOG.info(
+              "Revive at regionStart, currentTimes:{}, totalTimes:{} for 
shuffleId:{}, mapId:{}, "
+                  + "attempId:{}, currentRegionIndex:{}, isBroadcast:{}, 
newPartition:{}, oldPartition:{}",
+              remainingReviveTimes,
+              maxReviveTimes,
+              shuffleId,
+              mapId,
+              attemptId,
+              currentRegionIndex,
+              false,
+              revivePartition,
+              partitionLocation);
+          partitionLocation = revivePartition.get();
+          // For every revive partition, handshake should be sent firstly
+          hasSentHandshake = false;
+          handshake();
+          if (numSubPartitions > 0) {
+            for (int i = 0; i < numSubPartitions; i++) {
+              flinkShuffleClient.segmentStart(
+                  shuffleId, mapId, attemptId, i, subPartitionSegmentIds[i], 
partitionLocation);
+            }
+          }
+        } else {
+          hasSentRegionStart = true;
+          currentSubpartition = 0;
+        }
+      }
+      if (remainingReviveTimes == 0 && !hasSentRegionStart) {
+        throw new RuntimeException(
+            "After retry " + maxReviveTimes + " times, still failed to send 
regionStart");
+      }
+    } catch (IOException e) {
+      Utils.rethrowAsRuntimeException(e);
+    }
+  }
+
+  void regionFinish() {
+    try {
+      bufferPacker.drain();
+      flinkShuffleClient.regionFinish(shuffleId, mapId, attemptId, 
partitionLocation);
+      hasSentRegionStart = false;
+      currentRegionIndex++;
+    } catch (Exception e) {
+      Utils.rethrowAsRuntimeException(e);
+    }
+  }
+
+  private void handshake() throws IOException {
+    try {
+      int remainingReviveTimes = maxReviveTimes;
+      while (remainingReviveTimes-- > 0 && !hasSentHandshake) {
+        Optional<PartitionLocation> revivePartition =
+            flinkShuffleClient.pushDataHandShake(
+                shuffleId, mapId, attemptId, numSubPartitions, 
bufferSizeBytes, partitionLocation);
+        // if remainingReviveTimes == 0 and revivePartition.isPresent(), there 
is no need to send
+        // handshake again
+        if (revivePartition.isPresent() && remainingReviveTimes > 0) {
+          LOG.info(
+              "Revive at handshake, currentTimes:{}, totalTimes:{} for 
shuffleId:{}, mapId:{}, "
+                  + "attempId:{}, currentRegionIndex:{}, newPartition:{}, 
oldPartition:{}",
+              remainingReviveTimes,
+              maxReviveTimes,
+              shuffleId,
+              mapId,
+              attemptId,
+              currentRegionIndex,
+              revivePartition,
+              partitionLocation);
+          partitionLocation = revivePartition.get();
+          hasSentHandshake = false;
+        } else {
+          hasSentHandshake = true;
+        }
+      }
+      if (remainingReviveTimes == 0 && !hasSentHandshake) {
+        throw new RuntimeException(
+            "After retry " + maxReviveTimes + " times, still failed to send 
handshake");
+      }
+    } catch (IOException e) {
+      Utils.rethrowAsRuntimeException(e);
+    }
+  }
+
+  private void releaseResources() {
+    if (!isReleased) {
+      isReleased = true;
+    }
+  }
+
+  private void registerShuffle() {
+    try {
+      if (!hasRegisteredShuffle) {
+        partitionLocation =
+            flinkShuffleClient.registerMapPartitionTask(
+                shuffleId, numPartitions, mapId, attemptId, partitionId, true);
+        Utils.checkNotNull(partitionLocation);
+        hasRegisteredShuffle = true;
+      }
+    } catch (IOException e) {
+      Utils.rethrowAsRuntimeException(e);
+    }
+  }
+
+  private void write(ByteBuf byteBuf, BufferHeader bufferHeader) {
+    try {
+      // create a composite buffer and write a header into it. This composite 
buffer will serve as
+      // the result packed buffer.
+      CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer();
+      ByteBuf headerBuf = Unpooled.buffer(BufferUtils.HEADER_LENGTH);
+
+      // write celeborn buffer header (subpartitionid(4) + attemptId(4) + 
nextBatchId(4) +
+      // compressedsize)
+      headerBuf.writeInt(bufferHeader.getSubPartitionId());
+      headerBuf.writeInt(attemptId);
+      headerBuf.writeInt(0);
+      headerBuf.writeInt(
+          byteBuf.readableBytes() + (BufferUtils.HEADER_LENGTH - 
BufferUtils.HEADER_LENGTH_PREFIX));
+
+      // write flink buffer header (dataType(1) + isCompress(1) + size(4))
+      headerBuf.writeByte(bufferHeader.getDataType().ordinal());
+      headerBuf.writeBoolean(bufferHeader.isCompressed());
+      headerBuf.writeInt(bufferHeader.getSize());
+
+      // composite the headerBuf and data buffer together
+      compositeByteBuf.addComponents(true, headerBuf, byteBuf);
+      io.netty.buffer.ByteBuf wrappedBuffer =
+          io.netty.buffer.Unpooled.wrappedBuffer(compositeByteBuf.nioBuffer());
+
+      int numWritten =
+          flinkShuffleClient.pushDataToLocation(
+              shuffleId,
+              mapId,
+              attemptId,
+              bufferHeader.getSubPartitionId(),
+              wrappedBuffer,
+              partitionLocation,
+              compositeByteBuf::release);
+      checkState(
+          numWritten == byteBuf.readableBytes() + BufferUtils.HEADER_LENGTH, 
"Wrong written size.");
+    } catch (IOException e) {
+      Utils.rethrowAsRuntimeException(e);
+    }
+  }
+
+  private void appendEndOfSegmentBuffer(int subPartitionId) {
+    try {
+      checkState(bufferPacker.isEmpty(), "BufferPacker is not empty");
+      MemorySegment endSegmentMemorySegment =
+          MemorySegmentFactory.wrap(
+              
EventSerializer.toSerializedEvent(EndOfSegmentEvent.INSTANCE).array());
+      Buffer endOfSegmentBuffer =
+          new NetworkBuffer(
+              endSegmentMemorySegment,
+              FreeingBufferRecycler.INSTANCE,
+              END_OF_SEGMENT,
+              endSegmentMemorySegment.size());
+      processBuffer(endOfSegmentBuffer, subPartitionId);
+    } catch (Exception e) {
+      ExceptionUtils.rethrow(e, "Failed to append end of segment event.");
+    }
+  }
+
+  private void processBuffer(Buffer originBuffer, int subPartitionId) {
+    try {
+      regionStartOrFinish(subPartitionId);
+      currentSubpartition = subPartitionId;
+
+      Buffer buffer = originBuffer;
+      if (originBuffer.isCompressed()) {
+        // In flink 1.20.0, it will receive a compressed buffer. However, 
since we need to write
+        // data to this buffer and the compressed buffer is read-only,
+        // we must create a new Buffer object to the wrap origin buffer.
+        NetworkBuffer networkBuffer =
+            new NetworkBuffer(
+                originBuffer.getMemorySegment(),
+                originBuffer.getRecycler(),
+                originBuffer.getDataType(),
+                originBuffer.getSize());
+        networkBuffer.writerIndex(originBuffer.asByteBuf().writerIndex());
+        buffer = networkBuffer;
+      }
+
+      // TODO: To enhance performance, the flink should pass an no-compressed 
buffer to producer
+      // agent and we compress the buffer here
+
+      // set the buffer meta
+      BufferUtils.setCompressedDataWithoutHeader(buffer, originBuffer);
+
+      bufferPacker.process(buffer, subPartitionId);
+    } catch (InterruptedException e) {
+      originBuffer.recycleBuffer();
+      ExceptionUtils.rethrow(e, "Failed to process buffer.");
+    }
+  }
+
+  @VisibleForTesting
+  FlinkShuffleClientImpl getShuffleClient() {
+    try {
+      return FlinkShuffleClientImpl.get(
+          applicationId,
+          lifecycleManagerHost,
+          lifecycleManagerPort,
+          lifecycleManagerTimestamp,
+          celebornConf,
+          null);
+    } catch (DriverChangedException e) {
+      // would generate a new attempt to retry output gate
+      throw new RuntimeException(e.getMessage());
+    }
+  }
+}
diff --git 
a/client-flink/flink-1.20/src/test/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierMasterAgentSuiteJ.java
 
b/client-flink/flink-1.20/src/test/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierMasterAgentSuiteJ.java
new file mode 100644
index 000000000..f53d010cd
--- /dev/null
+++ 
b/client-flink/flink-1.20/src/test/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierMasterAgentSuiteJ.java
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink.tiered;
+
+import java.net.UnknownHostException;
+import java.util.Collection;
+import java.util.concurrent.CompletableFuture;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.executiongraph.ExecutionGraphID;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStoragePartitionId;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierShuffleDescriptor;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierShuffleHandler;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.util.Utils$;
+import org.apache.celeborn.plugin.flink.ShuffleResource;
+import org.apache.celeborn.plugin.flink.ShuffleResourceDescriptor;
+import org.apache.celeborn.plugin.flink.utils.FlinkUtils;
+
+public class CelebornTierMasterAgentSuiteJ {
+  private static final Logger LOG = 
LoggerFactory.getLogger(CelebornTierMasterAgentSuiteJ.class);
+  private CelebornTierMasterAgent masterAgent;
+
+  @Before
+  public void setUp() {
+    Configuration configuration = new Configuration();
+    int startPort = Utils$.MODULE$.selectRandomPort(1024, 65535);
+    configuration.setInteger("celeborn.master.port", startPort);
+    configuration.setString("celeborn.master.endpoints", "localhost:" + 
startPort);
+    masterAgent = createMasterAgent(configuration);
+  }
+
+  @Test
+  public void testRegisterJob() {
+    TierShuffleHandler tierShuffleHandler = createTierShuffleHandler();
+    JobID jobID = JobID.generate();
+    masterAgent.registerJob(jobID, tierShuffleHandler);
+
+    // reRunRegister job
+    try {
+      masterAgent.registerJob(jobID, tierShuffleHandler);
+      Assert.fail("should throw exception if double register job");
+    } catch (Exception e) {
+      Assert.assertTrue(true);
+    }
+
+    // unRegister job
+    masterAgent.unregisterJob(jobID);
+    masterAgent.registerJob(jobID, tierShuffleHandler);
+  }
+
+  private static TierShuffleHandler createTierShuffleHandler() {
+    return new TierShuffleHandler() {
+
+      @Override
+      public CompletableFuture<?> onReleasePartitions(
+          Collection<TieredStoragePartitionId> collection) {
+        return CompletableFuture.completedFuture(null);
+      }
+
+      @Override
+      public void onFatalError(Throwable throwable) {
+        System.exit(-1);
+      }
+    };
+  }
+
+  @Test
+  public void testRegisterPartitionWithProducer() {
+    JobID jobID = JobID.generate();
+    TierShuffleHandler tierShuffleHandler = createTierShuffleHandler();
+    masterAgent.registerJob(jobID, tierShuffleHandler);
+
+    ExecutionAttemptID executionAttemptID =
+        new ExecutionAttemptID(
+            new ExecutionGraphID(), new ExecutionVertexID(new JobVertexID(0L, 
0L), 0), 0);
+    ResultPartitionID resultPartitionID =
+        new ResultPartitionID(
+            new IntermediateResultPartitionID(new IntermediateDataSetID(), 0), 
executionAttemptID);
+    TierShuffleDescriptor tierShuffleDescriptor =
+        masterAgent.addPartitionAndGetShuffleDescriptor(jobID, 
resultPartitionID);
+    Assert.assertTrue(tierShuffleDescriptor instanceof 
TierShuffleDescriptorImpl);
+    ShuffleResource shuffleResource =
+        ((TierShuffleDescriptorImpl) 
tierShuffleDescriptor).getShuffleResource();
+    ShuffleResourceDescriptor mapPartitionShuffleDescriptor =
+        shuffleResource.getMapPartitionShuffleDescriptor();
+
+    Assert.assertEquals(0, mapPartitionShuffleDescriptor.getShuffleId());
+    Assert.assertEquals(0, mapPartitionShuffleDescriptor.getPartitionId());
+    Assert.assertEquals(0, mapPartitionShuffleDescriptor.getAttemptId());
+    Assert.assertEquals(0, mapPartitionShuffleDescriptor.getMapId());
+
+    // use same partition id
+    tierShuffleDescriptor =
+        masterAgent.addPartitionAndGetShuffleDescriptor(jobID, 
resultPartitionID);
+    mapPartitionShuffleDescriptor =
+        ((TierShuffleDescriptorImpl) tierShuffleDescriptor)
+            .getShuffleResource()
+            .getMapPartitionShuffleDescriptor();
+    Assert.assertEquals(0, mapPartitionShuffleDescriptor.getShuffleId());
+    Assert.assertEquals(0, mapPartitionShuffleDescriptor.getMapId());
+    Assert.assertEquals(1, mapPartitionShuffleDescriptor.getPartitionId());
+    Assert.assertEquals(1, mapPartitionShuffleDescriptor.getAttemptId());
+
+    // use another partition number
+    tierShuffleDescriptor =
+        masterAgent.addPartitionAndGetShuffleDescriptor(
+            jobID,
+            new ResultPartitionID(
+                new IntermediateResultPartitionID(
+                    
resultPartitionID.getPartitionId().getIntermediateDataSetID(), 1),
+                executionAttemptID));
+    mapPartitionShuffleDescriptor =
+        ((TierShuffleDescriptorImpl) tierShuffleDescriptor)
+            .getShuffleResource()
+            .getMapPartitionShuffleDescriptor();
+    Assert.assertEquals(0, mapPartitionShuffleDescriptor.getShuffleId());
+    Assert.assertEquals(1, mapPartitionShuffleDescriptor.getMapId());
+    Assert.assertEquals(2, mapPartitionShuffleDescriptor.getPartitionId());
+    Assert.assertEquals(0, mapPartitionShuffleDescriptor.getAttemptId());
+  }
+
+  @Test
+  public void testRegisterMultipleJobs() throws UnknownHostException {
+    JobID jobID1 = JobID.generate();
+    TierShuffleHandler tierShuffleHandler1 = createTierShuffleHandler();
+    masterAgent.registerJob(jobID1, tierShuffleHandler1);
+
+    JobID jobID2 = JobID.generate();
+    TierShuffleHandler tierShuffleHandler2 = createTierShuffleHandler();
+    masterAgent.registerJob(jobID2, tierShuffleHandler2);
+
+    IntermediateDataSetID intermediateDataSetID = new IntermediateDataSetID();
+    ResultPartitionID resultPartitionID = new ResultPartitionID();
+    TierShuffleDescriptor tierShuffleDescriptor1 =
+        masterAgent.addPartitionAndGetShuffleDescriptor(jobID1, 
resultPartitionID);
+
+    // use same partition id but different jobId
+    TierShuffleDescriptor tierShuffleDescriptor2 =
+        masterAgent.addPartitionAndGetShuffleDescriptor(jobID2, 
resultPartitionID);
+
+    Assert.assertEquals(
+        ((TierShuffleDescriptorImpl) tierShuffleDescriptor1)
+            .getShuffleResource()
+            .getMapPartitionShuffleDescriptor()
+            .getShuffleId(),
+        0);
+    Assert.assertEquals(
+        ((TierShuffleDescriptorImpl) tierShuffleDescriptor2)
+            .getShuffleResource()
+            .getMapPartitionShuffleDescriptor()
+            .getShuffleId(),
+        1);
+  }
+
+  @After
+  public void tearDown() {
+    if (masterAgent != null) {
+      try {
+        masterAgent.close();
+      } catch (Exception e) {
+        LOG.warn(e.getMessage(), e);
+      }
+    }
+  }
+
+  public CelebornTierMasterAgent createMasterAgent(Configuration 
configuration) {
+    CelebornConf conf = FlinkUtils.toCelebornConf(configuration);
+    return new CelebornTierMasterAgent(conf);
+  }
+}

Reply via email to