This is an automated email from the ASF dual-hosted git repository.
rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new b11d11153 [CELEBORN-1490][CIP-6] Introduce tier producer in celeborn flink client
b11d11153 is described below
commit b11d111536497282b9e6c43c5993b845ad255a3a
Author: Weijie Guo <[email protected]>
AuthorDate: Fri Sep 20 10:50:26 2024 +0800
[CELEBORN-1490][CIP-6] Introduce tier producer in celeborn flink client
### What changes were proposed in this pull request?
Introduce tier producer in celeborn flink client
Note: Only the last commit needs review.
### Why are the changes needed?
The tier producer is the mediator used by Flink hybrid shuffle to send data to
Celeborn.
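For orientation only, here is a minimal, hypothetical sketch of how Flink's tiered storage layer drives such a producer agent. The method names mirror the TierProducerAgent implementation added in this diff; the driver method itself is illustrative and not part of this change.

import java.util.List;

import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageSubpartitionId;
import org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierProducerAgent;

public class TierProducerUsageSketch {
  // Illustrative driver: open a segment on the tier, then feed it buffers until it asks to stop.
  static void writeSegment(
      TierProducerAgent agent,
      TieredStorageSubpartitionId subpartition,
      int segmentId,
      Object bufferOwner,
      List<Buffer> buffers) {
    if (!agent.tryStartNewSegment(subpartition, segmentId, buffers.size())) {
      return; // the tier declined the segment; the caller would try another tier
    }
    for (int i = 0; i < buffers.size(); i++) {
      // A false return means the current segment is finished; the caller must start a new
      // segment and resend this buffer.
      if (!agent.tryWrite(subpartition, buffers.get(i), bufferOwner, buffers.size() - i - 1)) {
        break;
      }
    }
  }
}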
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Yes
Closes #2733 from reswqa/cip6-5-pr.
Authored-by: Weijie Guo <[email protected]>
Signed-off-by: Shuang <[email protected]>
---
.../plugin/flink/RemoteShuffleOutputGate.java | 5 +-
.../celeborn/plugin/flink/buffer/BufferHeader.java | 9 +
.../celeborn/plugin/flink/buffer/BufferPacker.java | 45 +-
.../flink/buffer/ReceivedNoHeaderBufferPacker.java | 112 +++++
.../celeborn/plugin/flink/utils/BufferUtils.java | 20 +
.../celeborn/plugin/flink/BufferPackSuiteJ.java | 192 +++++++-
.../plugin/flink/tiered/CelebornTierFactory.java | 12 +-
.../flink/tiered/CelebornTierProducerAgent.java | 487 +++++++++++++++++++++
.../tiered/CelebornTierMasterAgentSuiteJ.java | 200 +++++++++
9 files changed, 1043 insertions(+), 39 deletions(-)
diff --git
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGate.java
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGate.java
index d17a182a1..f695af14d 100644
---
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGate.java
+++
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleOutputGate.java
@@ -33,6 +33,7 @@ import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.exception.DriverChangedException;
import org.apache.celeborn.common.identity.UserIdentifier;
import org.apache.celeborn.common.protocol.PartitionLocation;
+import org.apache.celeborn.plugin.flink.buffer.BufferHeader;
import org.apache.celeborn.plugin.flink.buffer.BufferPacker;
import org.apache.celeborn.plugin.flink.readclient.FlinkShuffleClientImpl;
import org.apache.celeborn.plugin.flink.utils.BufferUtils;
@@ -207,13 +208,13 @@ public class RemoteShuffleOutputGate {
}
/** Writes a piece of data to a subpartition. */
- public void write(ByteBuf byteBuf, int subIdx) {
+ public void write(ByteBuf byteBuf, BufferHeader bufferHeader) {
try {
flinkShuffleClient.pushDataToLocation(
shuffleId,
mapId,
attemptId,
- subIdx,
+ bufferHeader.getSubPartitionId(),
io.netty.buffer.Unpooled.wrappedBuffer(byteBuf.nioBuffer()),
partitionLocation,
() -> byteBuf.release());
diff --git
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/BufferHeader.java
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/BufferHeader.java
index 6dc6350ce..59e4d5010 100644
---
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/BufferHeader.java
+++
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/BufferHeader.java
@@ -37,6 +37,11 @@ public class BufferHeader {
this(0, 0, 0, size + 2, dataType, isCompressed, size);
}
+ public BufferHeader(
+ int subPartitionId, Buffer.DataType dataType, boolean isCompressed, int
size) {
+ this(subPartitionId, 0, 0, size + 2, dataType, isCompressed, size);
+ }
+
public BufferHeader(
int subPartitionId,
int attemptId,
@@ -54,6 +59,10 @@ public class BufferHeader {
this.size = size;
}
+ public int getSubPartitionId() {
+ return subPartitionId;
+ }
+
public Buffer.DataType getDataType() {
return dataType;
}
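A hypothetical usage of the new four-argument constructor and the new accessor (the values are made up; attempt id and batch id default to 0 in this constructor, as the delegation above shows):

import org.apache.flink.runtime.io.network.buffer.Buffer;

import org.apache.celeborn.plugin.flink.buffer.BufferHeader;

public class BufferHeaderSketch {
  public static void main(String[] args) {
    // Hybrid Shuffle only knows the sub-partition id, data type, compression flag and payload size.
    BufferHeader header = new BufferHeader(3, Buffer.DataType.DATA_BUFFER, true, 1024);
    System.out.println(header.getSubPartitionId()); // 3
    System.out.println(header.getDataType());       // DATA_BUFFER
    System.out.println(header.getSize());           // 1024
  }
}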
diff --git
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/BufferPacker.java
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/BufferPacker.java
index d0a757f19..76a6c2ef7 100644
---
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/BufferPacker.java
+++
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/BufferPacker.java
@@ -33,7 +33,15 @@ import org.apache.celeborn.plugin.flink.utils.BufferUtils;
import org.apache.celeborn.plugin.flink.utils.Utils;
import org.apache.celeborn.reflect.DynMethods;
-/** Harness used to pack multiple partial buffers together as a full one. */
+/**
+ * Harness used to pack multiple partial buffers together as a full one. There are two Flink
+ * integration strategies: Remote Shuffle Service and Hybrid Shuffle. In the Remote Shuffle
+ * Service integration strategy, the {@link BufferPacker} receives buffers containing both
+ * shuffle data and the Celeborn header. The Hybrid Shuffle integration strategy employs the
+ * subclass {@link ReceivedNoHeaderBufferPacker}, which receives buffers containing only shuffle
+ * data. The two strategies therefore pack buffers differently, but the resulting packed buffer
+ * should be the same.
+ */
public class BufferPacker {
private static Logger logger = LoggerFactory.getLogger(BufferPacker.class);
@@ -41,14 +49,15 @@ public class BufferPacker {
void accept(T var1, U var2) throws E;
}
- private final BiConsumerWithException<ByteBuf, Integer,
InterruptedException> ripeBufferHandler;
+ protected final BiConsumerWithException<ByteBuf, BufferHeader,
InterruptedException>
+ ripeBufferHandler;
- private Buffer cachedBuffer;
+ protected Buffer cachedBuffer;
- private int currentSubIdx = -1;
+ protected int currentSubIdx = -1;
public BufferPacker(
- BiConsumerWithException<ByteBuf, Integer, InterruptedException>
ripeBufferHandler) {
+ BiConsumerWithException<ByteBuf, BufferHeader, InterruptedException>
ripeBufferHandler) {
this.ripeBufferHandler = ripeBufferHandler;
}
@@ -71,7 +80,8 @@ public class BufferPacker {
int targetSubIdx = currentSubIdx;
currentSubIdx = subIdx;
logBufferPack(false, dumpedBuffer.getDataType(),
dumpedBuffer.readableBytes());
- handleRipeBuffer(dumpedBuffer, targetSubIdx);
+ handleRipeBuffer(
+ dumpedBuffer, targetSubIdx, dumpedBuffer.getDataType(),
dumpedBuffer.isCompressed());
} else {
/**
* this is an optimization. if cachedBuffer can contain other buffer,
then other buffer can
@@ -95,12 +105,13 @@ public class BufferPacker {
cachedBuffer = buffer;
logBufferPack(false, dumpedBuffer.getDataType(),
dumpedBuffer.readableBytes());
- handleRipeBuffer(dumpedBuffer, currentSubIdx);
+ handleRipeBuffer(
+ dumpedBuffer, currentSubIdx, dumpedBuffer.getDataType(),
dumpedBuffer.isCompressed());
}
}
}
- private void logBufferPack(boolean isDrain, Buffer.DataType dataType, int
length) {
+ protected void logBufferPack(boolean isDrain, Buffer.DataType dataType, int
length) {
logger.debug(
"isDrain:{}, cachedBuffer pack partition:{} type:{}, length:{}",
isDrain,
@@ -112,15 +123,27 @@ public class BufferPacker {
public void drain() throws InterruptedException {
if (cachedBuffer != null) {
logBufferPack(true, cachedBuffer.getDataType(),
cachedBuffer.readableBytes());
- handleRipeBuffer(cachedBuffer, currentSubIdx);
+ handleRipeBuffer(
+ cachedBuffer, currentSubIdx, cachedBuffer.getDataType(),
cachedBuffer.isCompressed());
}
cachedBuffer = null;
currentSubIdx = -1;
}
- private void handleRipeBuffer(Buffer buffer, int subIdx) throws
InterruptedException {
+ protected void handleRipeBuffer(
+ Buffer buffer, int subIdx, Buffer.DataType dataType, boolean
isCompressed)
+ throws InterruptedException {
+ // Always set the compress flag to false, because the result buffer generated by {@link
+ // BufferPacker} needs to be split into multiple buffers during the unpack process. If the
+ // compress flag were set to true for this result buffer, an exception would be thrown during
+ // unpacking, because a compressed buffer cannot be sliced.
buffer.setCompressed(false);
- ripeBufferHandler.accept(buffer.asByteBuf(), subIdx);
+ ripeBufferHandler.accept(
+ buffer.asByteBuf(), new BufferHeader(subIdx, dataType, isCompressed,
buffer.getSize()));
+ }
+
+ public boolean isEmpty() {
+ return cachedBuffer == null;
}
public void close() {
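To make the new handler signature concrete, here is a minimal, hedged sketch of wiring a BufferPacker with a (ByteBuf, BufferHeader) handler. It loosely follows the test setup in BufferPackSuiteJ below; the buffer is filled the way RemoteShuffleOutputGate does, and the pool sizes are illustrative.

import java.util.ArrayList;
import java.util.List;

import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.BufferPool;
import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;

import org.apache.celeborn.plugin.flink.buffer.BufferHeader;
import org.apache.celeborn.plugin.flink.buffer.BufferPacker;
import org.apache.celeborn.plugin.flink.utils.BufferUtils;

public class BufferPackerSketch {
  public static void main(String[] args) throws Exception {
    NetworkBufferPool networkBufferPool = new NetworkBufferPool(10, 64);
    BufferPool bufferPool = networkBufferPool.createBufferPool(10, 10);

    List<ByteBuf> packed = new ArrayList<>();
    BufferPacker.BiConsumerWithException<ByteBuf, BufferHeader, InterruptedException> handler =
        (ripe, header) -> {
          // The header now carries the sub-partition id, data type, compression flag and size.
          System.out.println(
              "sub-partition " + header.getSubPartitionId() + ", " + header.getSize() + " bytes");
          packed.add(ripe);
        };

    BufferPacker packer = new BufferPacker(handler);

    // Fill a buffer the way RemoteShuffleOutputGate does: Celeborn header region first, then data.
    Buffer buffer = bufferPool.requestBuffer();
    ByteBuf target = buffer.asByteBuf();
    BufferUtils.setBufferHeader(target, buffer.getDataType(), buffer.isCompressed(), 4);
    target.writerIndex(BufferUtils.HEADER_LENGTH);
    target.writeInt(42);

    packer.process(buffer, 0); // cache the buffer under sub-partition 0
    packer.drain();            // flush the cached buffer to the handler
    packer.close();

    packed.forEach(ByteBuf::release);
    bufferPool.lazyDestroy();
    networkBufferPool.destroy();
  }
}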
diff --git
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/ReceivedNoHeaderBufferPacker.java
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/ReceivedNoHeaderBufferPacker.java
new file mode 100644
index 000000000..09337ec4f
--- /dev/null
+++
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/buffer/ReceivedNoHeaderBufferPacker.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink.buffer;
+
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+
+import org.apache.celeborn.plugin.flink.utils.BufferUtils;
+
+/**
+ * Harness used to pack multiple partial buffers together as a full one. It is currently used in
+ * the Flink hybrid shuffle integration strategy.
+ */
+public class ReceivedNoHeaderBufferPacker extends BufferPacker {
+
+ /** The flink buffer header of cached first buffer. */
+ private BufferHeader firstBufferHeader;
+
+ public ReceivedNoHeaderBufferPacker(
+ BiConsumerWithException<ByteBuf, BufferHeader, InterruptedException>
ripeBufferHandler) {
+ super(ripeBufferHandler);
+ }
+
+ @Override
+ public void process(Buffer buffer, int subIdx) throws InterruptedException {
+ if (buffer == null) {
+ return;
+ }
+
+ if (buffer.readableBytes() == 0) {
+ buffer.recycleBuffer();
+ return;
+ }
+
+ if (cachedBuffer == null) {
+ // cache the first buffer and record flink buffer header of first buffer
+ cachedBuffer = buffer;
+ currentSubIdx = subIdx;
+ firstBufferHeader =
+ new BufferHeader(subIdx, buffer.getDataType(),
buffer.isCompressed(), buffer.getSize());
+ } else if (currentSubIdx != subIdx) {
+ // drain the previous cached buffer and cache current buffer
+ Buffer dumpedBuffer = cachedBuffer;
+ cachedBuffer = buffer;
+ int targetSubIdx = currentSubIdx;
+ currentSubIdx = subIdx;
+ logBufferPack(false, dumpedBuffer.getDataType(),
dumpedBuffer.readableBytes());
+ handleRipeBuffer(
+ dumpedBuffer, targetSubIdx, dumpedBuffer.getDataType(),
dumpedBuffer.isCompressed());
+ firstBufferHeader =
+ new BufferHeader(subIdx, buffer.getDataType(),
buffer.isCompressed(), buffer.getSize());
+ } else {
+ int bufferHeaderLength = BufferUtils.HEADER_LENGTH -
BufferUtils.HEADER_LENGTH_PREFIX;
+ if (cachedBuffer.readableBytes() + buffer.readableBytes() +
bufferHeaderLength
+ <= cachedBuffer.getMaxCapacity() - BufferUtils.HEADER_LENGTH) {
+ // if the cache buffer can contain the current buffer, then pack the
current buffer into
+ // cache buffer
+ ByteBuf byteBuf = cachedBuffer.asByteBuf();
+ byteBuf.writeByte(buffer.getDataType().ordinal());
+ byteBuf.writeBoolean(buffer.isCompressed());
+ byteBuf.writeInt(buffer.getSize());
+ byteBuf.writeBytes(buffer.asByteBuf(), 0, buffer.readableBytes());
+ logBufferPack(false, buffer.getDataType(), buffer.readableBytes() +
bufferHeaderLength);
+
+ buffer.recycleBuffer();
+ } else {
+ // if the cache buffer cannot contain the current buffer, drain the
cached buffer, and cache
+ // the current buffer
+ Buffer dumpedBuffer = cachedBuffer;
+ cachedBuffer = buffer;
+ logBufferPack(false, dumpedBuffer.getDataType(),
dumpedBuffer.readableBytes());
+
+ handleRipeBuffer(
+ dumpedBuffer, currentSubIdx, dumpedBuffer.getDataType(),
dumpedBuffer.isCompressed());
+ firstBufferHeader =
+ new BufferHeader(subIdx, buffer.getDataType(),
buffer.isCompressed(), buffer.getSize());
+ }
+ }
+ }
+
+ @Override
+ protected void handleRipeBuffer(
+ Buffer buffer, int subIdx, Buffer.DataType dataType, boolean
isCompressed)
+ throws InterruptedException {
+ if (buffer == null || buffer.readableBytes() == 0) {
+ return;
+ }
+ // Always set the compress flag to false, because this buffer contains the Celeborn header and
+ // multiple Flink data buffers. Keeping the flag false is crucial: this buffer must be sliced to
+ // extract the Flink data buffers during the unpacking process, and the Flink {@link
+ // NetworkBuffer} cannot correctly slice a compressed buffer.
+ buffer.setCompressed(false);
+ ripeBufferHandler.accept(buffer.asByteBuf(), firstBufferHeader);
+ }
+}
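The merge branch above writes a small Flink sub-header (dataType ordinal, isCompressed flag, payload size) in front of every follow-up buffer it appends to the cached one. Below is a hedged sketch of that layout, detached from the packer itself; the reader side is illustrative only, the real reader is BufferPacker.unpack.

import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
import org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled;

public class SubHeaderSketch {
  public static void main(String[] args) {
    ByteBuf merged = Unpooled.buffer(64);
    byte[] payload = {1, 2, 3, 4};

    // writer side, mirroring the packing branch above
    merged.writeByte(Buffer.DataType.DATA_BUFFER.ordinal());
    merged.writeBoolean(false);
    merged.writeInt(payload.length);
    merged.writeBytes(payload);

    // reader side: recover the sub-header fields before slicing out the payload
    Buffer.DataType dataType = Buffer.DataType.values()[merged.readByte()];
    boolean compressed = merged.readBoolean();
    int size = merged.readInt();
    System.out.println(dataType + ", compressed=" + compressed + ", " + size + " bytes");
    merged.release();
  }
}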
diff --git
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/utils/BufferUtils.java
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/utils/BufferUtils.java
index 14599e477..999d1eb10 100644
---
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/utils/BufferUtils.java
+++
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/utils/BufferUtils.java
@@ -59,6 +59,26 @@ public class BufferUtils {
buffer.setSize(dataLength + HEADER_LENGTH);
}
+ /**
+ * Used in the Hybrid Shuffle integration strategy, where the buffer contains data only (no
+ * Celeborn header). Copies the data of the compressed buffer into the origin buffer.
+ */
+ public static void setCompressedDataWithoutHeader(Buffer buffer, Buffer
compressedBuffer) {
+ checkArgument(buffer != null, "Must be not null.");
+ checkArgument(buffer.getReaderIndex() == 0, "Illegal reader index.");
+
+ boolean isCompressed = compressedBuffer != null &&
compressedBuffer.isCompressed();
+ int dataLength = isCompressed ? compressedBuffer.readableBytes() :
buffer.readableBytes();
+ ByteBuf byteBuf = buffer.asByteBuf();
+ if (isCompressed) {
+ byteBuf.writerIndex(0);
+ byteBuf.writeBytes(compressedBuffer.asByteBuf());
+ // set the compression flag here, as we need it when writing the
sub-header of this buffer
+ buffer.setCompressed(true);
+ }
+ buffer.setSize(dataLength);
+ }
+
public static void setBufferHeader(
ByteBuf byteBuf, Buffer.DataType dataType, boolean isCompressed, int
dataLength) {
byteBuf.writerIndex(0);
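A hedged usage sketch of the new helper. Buffer construction mirrors buildSomeBuffer in the test below; the uncompressed path is shown, in which the helper only records the data length on the origin buffer.

import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;

import org.apache.celeborn.plugin.flink.utils.BufferUtils;

public class SetCompressedDataSketch {
  public static void main(String[] args) {
    // Origin buffer owned by the producer agent; its reader index must be 0.
    MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(256);
    Buffer origin =
        new NetworkBuffer(segment, MemorySegment::free, Buffer.DataType.DATA_BUFFER, 256);
    origin.asByteBuf().writerIndex(0);
    origin.asByteBuf().writeInt(42); // 4 bytes of uncompressed payload

    // No compressed variant available (null): the helper only records the data length.
    // If a compressed sibling were passed, its bytes would be copied into the origin buffer and
    // the compression flag set, as the implementation above shows.
    BufferUtils.setCompressedDataWithoutHeader(origin, null);
    System.out.println(origin.getSize()); // 4

    origin.recycleBuffer();
  }
}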
diff --git
a/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/BufferPackSuiteJ.java
b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/BufferPackSuiteJ.java
index 2d5d5e78f..8f3c0ce6e 100644
---
a/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/BufferPackSuiteJ.java
+++
b/client-flink/common/src/test/java/org/apache/celeborn/plugin/flink/BufferPackSuiteJ.java
@@ -23,20 +23,32 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
import java.util.List;
import org.apache.commons.lang3.tuple.Pair;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+import org.apache.flink.shaded.netty4.io.netty.buffer.CompositeByteBuf;
+import org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.apache.celeborn.plugin.flink.buffer.BufferHeader;
import org.apache.celeborn.plugin.flink.buffer.BufferPacker;
+import org.apache.celeborn.plugin.flink.buffer.ReceivedNoHeaderBufferPacker;
import org.apache.celeborn.plugin.flink.utils.BufferUtils;
+@RunWith(Parameterized.class)
public class BufferPackSuiteJ {
private static final int BUFFER_SIZE = 20 + 16;
@@ -44,6 +56,18 @@ public class BufferPackSuiteJ {
private BufferPool bufferPool;
+ private boolean bufferPackerReceivedBufferHasHeader;
+
+ public BufferPackSuiteJ(boolean bufferPackerReceivedBufferHasHeader) {
+ this.bufferPackerReceivedBufferHasHeader =
bufferPackerReceivedBufferHasHeader;
+ }
+
+ @Parameterized.Parameters
+ public static Collection prepareData() {
+ Object[][] object = {{true}, {false}};
+ return Arrays.asList(object);
+ }
+
@Before
public void setup() throws Exception {
networkBufferPool = new NetworkBufferPool(10, BUFFER_SIZE);
@@ -66,13 +90,14 @@ public class BufferPackSuiteJ {
Integer subIdx = 2;
List<ByteBuf> output = new ArrayList<>();
- BufferPacker.BiConsumerWithException<ByteBuf, Integer,
InterruptedException> ripeBufferHandler =
- (ripe, sub) -> {
- assertEquals(subIdx, sub);
- output.add(ripe);
- };
-
- BufferPacker packer = new BufferPacker(ripeBufferHandler);
+ BufferPacker.BiConsumerWithException<ByteBuf, BufferHeader,
InterruptedException>
+ ripeBufferHandler =
+ (ripe, header) -> {
+ assertEquals(subIdx,
Integer.valueOf(header.getSubPartitionId()));
+ output.add(ripe);
+ };
+
+ BufferPacker packer = createBufferPakcer(ripeBufferHandler);
packer.process(buffers.get(0), subIdx);
packer.process(buffers.get(1), subIdx);
packer.process(buffers.get(2), subIdx);
@@ -89,9 +114,12 @@ public class BufferPackSuiteJ {
setDataType(buffers, EVENT_BUFFER, DATA_BUFFER, DATA_BUFFER);
List<Pair<ByteBuf, Integer>> output = new ArrayList<>();
- BufferPacker.BiConsumerWithException<ByteBuf, Integer,
InterruptedException> ripeBufferHandler =
- (ripe, sub) -> output.add(Pair.of(ripe, sub));
- BufferPacker packer = new BufferPacker(ripeBufferHandler);
+ BufferPacker.BiConsumerWithException<ByteBuf, BufferHeader,
InterruptedException>
+ ripeBufferHandler =
+ (ripe, header) ->
+ output.add(
+ Pair.of(addBufferHeaderPossible(ripe, header),
header.getSubPartitionId()));
+ BufferPacker packer = createBufferPakcer(ripeBufferHandler);
fillBuffers(buffers, 0, 1, 2);
packer.process(buffers.get(0), 2);
@@ -123,9 +151,12 @@ public class BufferPackSuiteJ {
setDataType(buffers, EVENT_BUFFER, DATA_BUFFER, DATA_BUFFER);
List<Pair<ByteBuf, Integer>> output = new ArrayList<>();
- BufferPacker.BiConsumerWithException<ByteBuf, Integer,
InterruptedException> ripeBufferHandler =
- (ripe, sub) -> output.add(Pair.of(ripe, sub));
- BufferPacker packer = new BufferPacker(ripeBufferHandler);
+ BufferPacker.BiConsumerWithException<ByteBuf, BufferHeader,
InterruptedException>
+ ripeBufferHandler =
+ (ripe, header) ->
+ output.add(
+ Pair.of(addBufferHeaderPossible(ripe, header),
header.getSubPartitionId()));
+ BufferPacker packer = createBufferPakcer(ripeBufferHandler);
fillBuffers(buffers, 0, 1, 2);
packer.process(buffers.get(0), 0);
@@ -158,9 +189,12 @@ public class BufferPackSuiteJ {
setDataType(buffers, EVENT_BUFFER, DATA_BUFFER, DATA_BUFFER);
List<Pair<ByteBuf, Integer>> output = new ArrayList<>();
- BufferPacker.BiConsumerWithException<ByteBuf, Integer,
InterruptedException> ripeBufferHandler =
- (ripe, sub) -> output.add(Pair.of(ripe, sub));
- BufferPacker packer = new BufferPacker(ripeBufferHandler);
+ BufferPacker.BiConsumerWithException<ByteBuf, BufferHeader,
InterruptedException>
+ ripeBufferHandler =
+ (ripe, header) ->
+ output.add(
+ Pair.of(addBufferHeaderPossible(ripe, header),
header.getSubPartitionId()));
+ BufferPacker packer = createBufferPakcer(ripeBufferHandler);
fillBuffers(buffers, 0, 1, 2);
packer.process(buffers.get(0), 0);
@@ -186,6 +220,59 @@ public class BufferPackSuiteJ {
unpacked.forEach(Buffer::recycleBuffer);
}
+ @Test
+ public void testPackMultipleBuffers() throws Exception {
+ int numBuffers = 7;
+ List<Buffer> buffers = new ArrayList<>();
+ buffers.add(buildSomeBuffer(100));
+ buffers.addAll(requestBuffers(numBuffers - 1));
+ setCompressed(buffers, true, true, true, false, false, false, true);
+ setDataType(
+ buffers,
+ EVENT_BUFFER,
+ DATA_BUFFER,
+ DATA_BUFFER,
+ EVENT_BUFFER,
+ DATA_BUFFER,
+ DATA_BUFFER,
+ EVENT_BUFFER);
+
+ List<Pair<ByteBuf, Integer>> output = new ArrayList<>();
+ BufferPacker.BiConsumerWithException<ByteBuf, BufferHeader,
InterruptedException>
+ ripeBufferHandler =
+ (ripe, header) ->
+ output.add(
+ Pair.of(addBufferHeaderPossible(ripe, header),
header.getSubPartitionId()));
+ BufferPacker packer = createBufferPakcer(ripeBufferHandler);
+ fillBuffers(buffers, 0, 1, 2, 3, 4, 5, 6, 7);
+
+ for (int i = 0; i < buffers.size(); i++) {
+ packer.process(buffers.get(i), 0);
+ }
+ packer.drain();
+
+ List<Buffer> unpacked = new ArrayList<>();
+ for (int i = 0; i < output.size(); i++) {
+ Pair<ByteBuf, Integer> pair = output.get(i);
+ assertEquals(Integer.valueOf(0), pair.getRight());
+ unpacked.addAll(BufferPacker.unpack(pair.getLeft()));
+ }
+ assertEquals(7, unpacked.size());
+
+ checkIfCompressed(unpacked, true, true, true, false, false, false, true);
+ checkDataType(
+ unpacked,
+ EVENT_BUFFER,
+ DATA_BUFFER,
+ DATA_BUFFER,
+ EVENT_BUFFER,
+ DATA_BUFFER,
+ DATA_BUFFER,
+ EVENT_BUFFER);
+ verifyBuffers(unpacked, 0, 1, 2, 3, 4, 5, 6, 7);
+ unpacked.forEach(Buffer::recycleBuffer);
+ }
+
@Test
public void testFailedToHandleRipeBufferAndClose() throws Exception {
List<Buffer> buffers = requestBuffers(1);
@@ -193,12 +280,13 @@ public class BufferPackSuiteJ {
setDataType(buffers, DATA_BUFFER);
fillBuffers(buffers, 0);
- BufferPacker.BiConsumerWithException<ByteBuf, Integer,
InterruptedException> ripeBufferHandler =
- (ripe, sub) -> {
- // ripe.release();
- throw new RuntimeException("Test");
- };
- BufferPacker packer = new BufferPacker(ripeBufferHandler);
+ BufferPacker.BiConsumerWithException<ByteBuf, BufferHeader,
InterruptedException>
+ ripeBufferHandler =
+ (ripe, header) -> {
+ // ripe.release();
+ throw new RuntimeException("Test");
+ };
+ BufferPacker packer = createBufferPakcer(ripeBufferHandler);
System.out.println(buffers.get(0).refCnt());
packer.process(buffers.get(0), 0);
try {
@@ -248,8 +336,17 @@ public class BufferPackSuiteJ {
for (int i = 0; i < buffers.size(); i++) {
Buffer buffer = buffers.get(i);
ByteBuf target = buffer.asByteBuf();
- BufferUtils.setBufferHeader(target, buffer.getDataType(),
buffer.isCompressed(), 4);
- target.writerIndex(BufferUtils.HEADER_LENGTH);
+
+ if (bufferPackerReceivedBufferHasHeader) {
+ // If the buffer includes a header, we need to leave space for the
header, so we should
+ // update the writer index to BufferUtils.HEADER_LENGTH.
+ BufferUtils.setBufferHeader(target, buffer.getDataType(),
buffer.isCompressed(), 4);
+ target.writerIndex(BufferUtils.HEADER_LENGTH);
+ } else {
+ // if the buffer does not have a header, we can directly write data
starting from the
+ // beginning of the buffer.
+ target.writerIndex(0);
+ }
target.writeInt(ints[i]);
}
}
@@ -260,4 +357,51 @@ public class BufferPackSuiteJ {
assertEquals(expects[i], actual.getInt(0));
}
}
+
+ public static Buffer buildSomeBuffer(int size) {
+ final MemorySegment seg =
MemorySegmentFactory.allocateUnpooledSegment(size);
+ return new NetworkBuffer(seg, MemorySegment::free,
Buffer.DataType.DATA_BUFFER, size);
+ }
+
+ public ByteBuf addBufferHeaderPossible(ByteBuf byteBuf, BufferHeader
bufferHeader) {
+ // Add the Celeborn buffer header here only if bufferPackerReceivedBufferHasHeader is false,
+ // i.e. the BufferPacker drained buffers that do not already carry a header.
+ if (bufferPackerReceivedBufferHasHeader) {
+ return byteBuf;
+ }
+
+ CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer();
+ // create a small buffer headerBuf to write the buffer header
+ ByteBuf headerBuf = Unpooled.buffer(BufferUtils.HEADER_LENGTH);
+
+ // write celeborn buffer header (subpartitionid(4) + attemptId(4) +
nextBatchId(4) +
+ // compressedsize)
+ headerBuf.writeInt(bufferHeader.getSubPartitionId());
+ headerBuf.writeInt(0);
+ headerBuf.writeInt(0);
+ headerBuf.writeInt(
+ byteBuf.readableBytes() + (BufferUtils.HEADER_LENGTH -
BufferUtils.HEADER_LENGTH_PREFIX));
+
+ // write flink buffer header (dataType(1) + isCompress(1) + size(4))
+ headerBuf.writeByte(bufferHeader.getDataType().ordinal());
+ headerBuf.writeBoolean(bufferHeader.isCompressed());
+ headerBuf.writeInt(bufferHeader.getSize());
+
+ // composite the headerBuf and data buffer together
+ compositeByteBuf.addComponents(true, headerBuf, byteBuf);
+ ByteBuf packedByteBuf =
Unpooled.wrappedBuffer(compositeByteBuf.nioBuffer());
+ byteBuf.writerIndex(0);
+ byteBuf.writeBytes(packedByteBuf, 0, packedByteBuf.readableBytes());
+ return byteBuf;
+ }
+
+ public BufferPacker createBufferPakcer(
+ BufferPacker.BiConsumerWithException<ByteBuf, BufferHeader,
InterruptedException>
+ ripeBufferHandler) {
+ if (bufferPackerReceivedBufferHasHeader) {
+ return new BufferPacker(ripeBufferHandler);
+ } else {
+ return new ReceivedNoHeaderBufferPacker(ripeBufferHandler);
+ }
+ }
}
diff --git
a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierFactory.java
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierFactory.java
index 326a11985..02306a5ad 100644
---
a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierFactory.java
+++
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierFactory.java
@@ -101,8 +101,16 @@ public class CelebornTierFactory implements TierFactory {
ScheduledExecutorService ioExecutor,
List<TierShuffleDescriptor> shuffleDescriptors,
int maxRequestedBuffers) {
- // TODO impl this in the follow-up PR.
- return null;
+ return new CelebornTierProducerAgent(
+ conf,
+ partitionId,
+ numPartitions,
+ numSubpartitions,
+ NUM_BYTES_PER_SEGMENT,
+ bufferSizeBytes,
+ storageMemoryManager,
+ resourceRegistry,
+ shuffleDescriptors);
}
@Override
diff --git
a/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierProducerAgent.java
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierProducerAgent.java
new file mode 100644
index 000000000..aab2b3ae5
--- /dev/null
+++
b/client-flink/flink-1.20/src/main/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierProducerAgent.java
@@ -0,0 +1,487 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink.tiered;
+
+import static org.apache.celeborn.plugin.flink.utils.Utils.checkArgument;
+import static org.apache.celeborn.plugin.flink.utils.Utils.checkState;
+import static
org.apache.flink.runtime.io.network.buffer.Buffer.DataType.END_OF_SEGMENT;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Optional;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.core.memory.MemorySegmentFactory;
+import org.apache.flink.runtime.io.network.api.EndOfSegmentEvent;
+import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
+import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+import
org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStoragePartitionId;
+import
org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageSubpartitionId;
+import
org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.TieredStorageMemoryManager;
+import
org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.TieredStorageResourceRegistry;
+import
org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierProducerAgent;
+import
org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierShuffleDescriptor;
+import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
+import org.apache.flink.shaded.netty4.io.netty.buffer.CompositeByteBuf;
+import org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled;
+import org.apache.flink.util.ExceptionUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.exception.DriverChangedException;
+import org.apache.celeborn.common.protocol.PartitionLocation;
+import org.apache.celeborn.plugin.flink.buffer.BufferHeader;
+import org.apache.celeborn.plugin.flink.buffer.BufferPacker;
+import org.apache.celeborn.plugin.flink.buffer.ReceivedNoHeaderBufferPacker;
+import org.apache.celeborn.plugin.flink.readclient.FlinkShuffleClientImpl;
+import org.apache.celeborn.plugin.flink.utils.BufferUtils;
+import org.apache.celeborn.plugin.flink.utils.Utils;
+
+public class CelebornTierProducerAgent implements TierProducerAgent {
+
+ private static final Logger LOG =
LoggerFactory.getLogger(CelebornTierProducerAgent.class);
+
+ private final int numBuffersPerSegment;
+
+ private final int bufferSizeBytes;
+
+ private final int numPartitions;
+
+ private final int numSubPartitions;
+
+ private final CelebornConf celebornConf;
+
+ private final TieredStorageMemoryManager memoryManager;
+
+ private final String applicationId;
+
+ private final int shuffleId;
+
+ private final int mapId;
+
+ private final int attemptId;
+
+ private final int partitionId;
+
+ private final String lifecycleManagerHost;
+
+ private final int lifecycleManagerPort;
+
+ private final long lifecycleManagerTimestamp;
+
+ private FlinkShuffleClientImpl flinkShuffleClient;
+
+ private BufferPacker bufferPacker;
+
+ private final int[] subPartitionSegmentIds;
+
+ private final int[] subPartitionSegmentBuffers;
+
+ private final int maxReviveTimes;
+
+ private PartitionLocation partitionLocation;
+
+ private boolean hasRegisteredShuffle;
+
+ private int currentRegionIndex = 0;
+
+ private int currentSubpartition = 0;
+
+ private boolean hasSentHandshake = false;
+
+ private boolean hasSentRegionStart = false;
+
+ private volatile boolean isReleased;
+
+ CelebornTierProducerAgent(
+ CelebornConf conf,
+ TieredStoragePartitionId partitionId,
+ int numPartitions,
+ int numSubPartitions,
+ int numBytesPerSegment,
+ int bufferSizeBytes,
+ TieredStorageMemoryManager memoryManager,
+ TieredStorageResourceRegistry resourceRegistry,
+ List<TierShuffleDescriptor> shuffleDescriptors) {
+ checkArgument(
+ numBytesPerSegment >= bufferSizeBytes, "One segment should contain at
least one buffer.");
+ checkArgument(shuffleDescriptors.size() == 1, "There should be only one
shuffle descriptor.");
+ TierShuffleDescriptor descriptor = shuffleDescriptors.get(0);
+ checkArgument(
+ descriptor instanceof TierShuffleDescriptorImpl,
+ "Wrong shuffle descriptor type " + descriptor.getClass());
+ TierShuffleDescriptorImpl shuffleDesc = (TierShuffleDescriptorImpl)
descriptor;
+
+ this.numBuffersPerSegment = numBytesPerSegment / bufferSizeBytes;
+ this.bufferSizeBytes = bufferSizeBytes;
+ this.memoryManager = memoryManager;
+ this.numPartitions = numPartitions;
+ this.numSubPartitions = numSubPartitions;
+ this.celebornConf = conf;
+ this.subPartitionSegmentIds = new int[numSubPartitions];
+ this.subPartitionSegmentBuffers = new int[numSubPartitions];
+ this.maxReviveTimes = conf.clientPushMaxReviveTimes();
+
+ this.applicationId = shuffleDesc.getCelebornAppId();
+ this.shuffleId =
+
shuffleDesc.getShuffleResource().getMapPartitionShuffleDescriptor().getShuffleId();
+ this.mapId =
shuffleDesc.getShuffleResource().getMapPartitionShuffleDescriptor().getMapId();
+ this.attemptId =
+
shuffleDesc.getShuffleResource().getMapPartitionShuffleDescriptor().getAttemptId();
+ this.partitionId =
+
shuffleDesc.getShuffleResource().getMapPartitionShuffleDescriptor().getPartitionId();
+ this.lifecycleManagerHost =
shuffleDesc.getShuffleResource().getLifecycleManagerHost();
+ this.lifecycleManagerPort =
shuffleDesc.getShuffleResource().getLifecycleManagerPort();
+ this.lifecycleManagerTimestamp =
+ shuffleDesc.getShuffleResource().getLifecycleManagerTimestamp();
+ this.flinkShuffleClient = getShuffleClient();
+
+ Arrays.fill(subPartitionSegmentIds, -1);
+ Arrays.fill(subPartitionSegmentBuffers, 0);
+
+ this.bufferPacker = new ReceivedNoHeaderBufferPacker(this::write);
+ resourceRegistry.registerResource(partitionId, this::releaseResources);
+ registerShuffle();
+ try {
+ handshake();
+ } catch (IOException e) {
+ Utils.rethrowAsRuntimeException(e);
+ }
+ }
+
+ @Override
+ public boolean tryStartNewSegment(
+ TieredStorageSubpartitionId tieredStorageSubpartitionId, int segmentId,
int minNumBuffers) {
+ int subPartitionId = tieredStorageSubpartitionId.getSubpartitionId();
+ checkState(
+ segmentId >= subPartitionSegmentIds[subPartitionId], "Wrong segment id
" + segmentId);
+ subPartitionSegmentIds[subPartitionId] = segmentId;
+ // Once the segment-start RPC is sent, the worker side expects that at least one buffer
+ // will be written in the next moment.
+ try {
+ flinkShuffleClient.segmentStart(
+ shuffleId, mapId, attemptId, subPartitionId, segmentId,
partitionLocation);
+ } catch (IOException e) {
+ Utils.rethrowAsRuntimeException(e);
+ }
+ return true;
+ }
+
+ @Override
+ public boolean tryWrite(
+ TieredStorageSubpartitionId tieredStorageSubpartitionId,
+ Buffer buffer,
+ Object bufferOwner,
+ int numRemainingConsecutiveBuffers) {
+ // Note that, unlike RemoteShuffleOutputGate#write, the received buffer contains only shuffle
+ // data and does not have any remaining space for writing the Celeborn header.
+
+ int subPartitionId = tieredStorageSubpartitionId.getSubpartitionId();
+
+ if (subPartitionSegmentBuffers[subPartitionId] + 1 +
numRemainingConsecutiveBuffers
+ >= numBuffersPerSegment) {
+ // End the current segment if the segment buffer count reaches the
threshold
+ subPartitionSegmentBuffers[subPartitionId] = 0;
+ try {
+ bufferPacker.drain();
+ } catch (InterruptedException e) {
+ buffer.recycleBuffer();
+ ExceptionUtils.rethrow(e, "Failed to process buffer.");
+ }
+ appendEndOfSegmentBuffer(subPartitionId);
+ return false;
+ }
+
+ if (buffer.isBuffer()) {
+ memoryManager.transferBufferOwnership(
+ bufferOwner, CelebornTierFactory.getCelebornTierName(), buffer);
+ }
+
+ // write buffer to BufferPacker and record buffer count per subPartition
per segment
+ processBuffer(buffer, subPartitionId);
+ subPartitionSegmentBuffers[subPartitionId]++;
+ return true;
+ }
+
+ @Override
+ public void close() {
+ if (hasSentRegionStart) {
+ regionFinish();
+ }
+ try {
+ if (hasRegisteredShuffle && partitionLocation != null) {
+ flinkShuffleClient.mapPartitionMapperEnd(
+ shuffleId, mapId, attemptId, numPartitions,
partitionLocation.getId());
+ }
+ } catch (Exception e) {
+ Utils.rethrowAsRuntimeException(e);
+ }
+ bufferPacker.close();
+ bufferPacker = null;
+ flinkShuffleClient.cleanup(shuffleId, mapId, attemptId);
+ flinkShuffleClient = null;
+ }
+
+ private void regionStartOrFinish(int subPartitionId) {
+ // check whether the region should be started or finished
+ regionStart();
+ if (subPartitionId < currentSubpartition) {
+ // if the consumed subPartitionId is out of order, the previous region should be finished
+ // and a new region started.
+ regionFinish();
+ LOG.debug(
+ "Check region finish sub partition id {} and start next region {}",
+ subPartitionId,
+ currentRegionIndex);
+ regionStart();
+ }
+ }
+
+ private void regionStart() {
+ if (hasSentRegionStart) {
+ return;
+ }
+ regionStartWithRevive();
+ }
+
+ private void regionStartWithRevive() {
+ try {
+ int remainingReviveTimes = maxReviveTimes;
+ while (remainingReviveTimes-- > 0 && !hasSentRegionStart) {
+ Optional<PartitionLocation> revivePartition =
+ flinkShuffleClient.regionStart(
+ shuffleId, mapId, attemptId, partitionLocation,
currentRegionIndex, false);
+ if (revivePartition.isPresent()) {
+ LOG.info(
+ "Revive at regionStart, currentTimes:{}, totalTimes:{} for
shuffleId:{}, mapId:{}, "
+ + "attempId:{}, currentRegionIndex:{}, isBroadcast:{},
newPartition:{}, oldPartition:{}",
+ remainingReviveTimes,
+ maxReviveTimes,
+ shuffleId,
+ mapId,
+ attemptId,
+ currentRegionIndex,
+ false,
+ revivePartition,
+ partitionLocation);
+ partitionLocation = revivePartition.get();
+ // For every revived partition, the handshake should be sent first
+ hasSentHandshake = false;
+ handshake();
+ if (numSubPartitions > 0) {
+ for (int i = 0; i < numSubPartitions; i++) {
+ flinkShuffleClient.segmentStart(
+ shuffleId, mapId, attemptId, i, subPartitionSegmentIds[i],
partitionLocation);
+ }
+ }
+ } else {
+ hasSentRegionStart = true;
+ currentSubpartition = 0;
+ }
+ }
+ if (remainingReviveTimes == 0 && !hasSentRegionStart) {
+ throw new RuntimeException(
+ "After retry " + maxReviveTimes + " times, still failed to send
regionStart");
+ }
+ } catch (IOException e) {
+ Utils.rethrowAsRuntimeException(e);
+ }
+ }
+
+ void regionFinish() {
+ try {
+ bufferPacker.drain();
+ flinkShuffleClient.regionFinish(shuffleId, mapId, attemptId,
partitionLocation);
+ hasSentRegionStart = false;
+ currentRegionIndex++;
+ } catch (Exception e) {
+ Utils.rethrowAsRuntimeException(e);
+ }
+ }
+
+ private void handshake() throws IOException {
+ try {
+ int remainingReviveTimes = maxReviveTimes;
+ while (remainingReviveTimes-- > 0 && !hasSentHandshake) {
+ Optional<PartitionLocation> revivePartition =
+ flinkShuffleClient.pushDataHandShake(
+ shuffleId, mapId, attemptId, numSubPartitions,
bufferSizeBytes, partitionLocation);
+ // if remainingReviveTimes == 0 and revivePartition.isPresent(), there
is no need to send
+ // handshake again
+ if (revivePartition.isPresent() && remainingReviveTimes > 0) {
+ LOG.info(
+ "Revive at handshake, currentTimes:{}, totalTimes:{} for
shuffleId:{}, mapId:{}, "
+ + "attempId:{}, currentRegionIndex:{}, newPartition:{},
oldPartition:{}",
+ remainingReviveTimes,
+ maxReviveTimes,
+ shuffleId,
+ mapId,
+ attemptId,
+ currentRegionIndex,
+ revivePartition,
+ partitionLocation);
+ partitionLocation = revivePartition.get();
+ hasSentHandshake = false;
+ } else {
+ hasSentHandshake = true;
+ }
+ }
+ if (remainingReviveTimes == 0 && !hasSentHandshake) {
+ throw new RuntimeException(
+ "After retry " + maxReviveTimes + " times, still failed to send
handshake");
+ }
+ } catch (IOException e) {
+ Utils.rethrowAsRuntimeException(e);
+ }
+ }
+
+ private void releaseResources() {
+ if (!isReleased) {
+ isReleased = true;
+ }
+ }
+
+ private void registerShuffle() {
+ try {
+ if (!hasRegisteredShuffle) {
+ partitionLocation =
+ flinkShuffleClient.registerMapPartitionTask(
+ shuffleId, numPartitions, mapId, attemptId, partitionId, true);
+ Utils.checkNotNull(partitionLocation);
+ hasRegisteredShuffle = true;
+ }
+ } catch (IOException e) {
+ Utils.rethrowAsRuntimeException(e);
+ }
+ }
+
+ private void write(ByteBuf byteBuf, BufferHeader bufferHeader) {
+ try {
+ // create a composite buffer and write a header into it. This composite
buffer will serve as
+ // the result packed buffer.
+ CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer();
+ ByteBuf headerBuf = Unpooled.buffer(BufferUtils.HEADER_LENGTH);
+
+ // write celeborn buffer header (subpartitionid(4) + attemptId(4) +
nextBatchId(4) +
+ // compressedsize)
+ headerBuf.writeInt(bufferHeader.getSubPartitionId());
+ headerBuf.writeInt(attemptId);
+ headerBuf.writeInt(0);
+ headerBuf.writeInt(
+ byteBuf.readableBytes() + (BufferUtils.HEADER_LENGTH -
BufferUtils.HEADER_LENGTH_PREFIX));
+
+ // write flink buffer header (dataType(1) + isCompress(1) + size(4))
+ headerBuf.writeByte(bufferHeader.getDataType().ordinal());
+ headerBuf.writeBoolean(bufferHeader.isCompressed());
+ headerBuf.writeInt(bufferHeader.getSize());
+
+ // composite the headerBuf and data buffer together
+ compositeByteBuf.addComponents(true, headerBuf, byteBuf);
+ io.netty.buffer.ByteBuf wrappedBuffer =
+ io.netty.buffer.Unpooled.wrappedBuffer(compositeByteBuf.nioBuffer());
+
+ int numWritten =
+ flinkShuffleClient.pushDataToLocation(
+ shuffleId,
+ mapId,
+ attemptId,
+ bufferHeader.getSubPartitionId(),
+ wrappedBuffer,
+ partitionLocation,
+ compositeByteBuf::release);
+ checkState(
+ numWritten == byteBuf.readableBytes() + BufferUtils.HEADER_LENGTH,
"Wrong written size.");
+ } catch (IOException e) {
+ Utils.rethrowAsRuntimeException(e);
+ }
+ }
+
+ private void appendEndOfSegmentBuffer(int subPartitionId) {
+ try {
+ checkState(bufferPacker.isEmpty(), "BufferPacker is not empty");
+ MemorySegment endSegmentMemorySegment =
+ MemorySegmentFactory.wrap(
+
EventSerializer.toSerializedEvent(EndOfSegmentEvent.INSTANCE).array());
+ Buffer endOfSegmentBuffer =
+ new NetworkBuffer(
+ endSegmentMemorySegment,
+ FreeingBufferRecycler.INSTANCE,
+ END_OF_SEGMENT,
+ endSegmentMemorySegment.size());
+ processBuffer(endOfSegmentBuffer, subPartitionId);
+ } catch (Exception e) {
+ ExceptionUtils.rethrow(e, "Failed to append end of segment event.");
+ }
+ }
+
+ private void processBuffer(Buffer originBuffer, int subPartitionId) {
+ try {
+ regionStartOrFinish(subPartitionId);
+ currentSubpartition = subPartitionId;
+
+ Buffer buffer = originBuffer;
+ if (originBuffer.isCompressed()) {
+ // In Flink 1.20.0, a compressed buffer may be received here. Since we need to write data
+ // to this buffer and the compressed buffer is read-only, we must create a new Buffer
+ // object to wrap the origin buffer.
+ NetworkBuffer networkBuffer =
+ new NetworkBuffer(
+ originBuffer.getMemorySegment(),
+ originBuffer.getRecycler(),
+ originBuffer.getDataType(),
+ originBuffer.getSize());
+ networkBuffer.writerIndex(originBuffer.asByteBuf().writerIndex());
+ buffer = networkBuffer;
+ }
+
+ // TODO: To enhance performance, Flink should pass an uncompressed buffer to the producer
+ // agent and the buffer should be compressed here.
+
+ // set the buffer meta
+ BufferUtils.setCompressedDataWithoutHeader(buffer, originBuffer);
+
+ bufferPacker.process(buffer, subPartitionId);
+ } catch (InterruptedException e) {
+ originBuffer.recycleBuffer();
+ ExceptionUtils.rethrow(e, "Failed to process buffer.");
+ }
+ }
+
+ @VisibleForTesting
+ FlinkShuffleClientImpl getShuffleClient() {
+ try {
+ return FlinkShuffleClientImpl.get(
+ applicationId,
+ lifecycleManagerHost,
+ lifecycleManagerPort,
+ lifecycleManagerTimestamp,
+ celebornConf,
+ null);
+ } catch (DriverChangedException e) {
+ // would generate a new attempt to retry output gate
+ throw new RuntimeException(e.getMessage());
+ }
+ }
+}
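For reference, a hedged sketch of the on-wire layout that write(...) above composes before pushing to Celeborn. The two length constants below are derived from the field comments in write() (four ints of Celeborn prefix, then a one-byte data type, a one-byte compression flag and a four-byte size); the real constants live in BufferUtils.HEADER_LENGTH and BufferUtils.HEADER_LENGTH_PREFIX, and all field values here are illustrative.

import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
import org.apache.flink.shaded.netty4.io.netty.buffer.Unpooled;

public class CelebornHeaderLayoutSketch {
  // Derived from the layout in CelebornTierProducerAgent#write; see BufferUtils for the real constants.
  private static final int CELEBORN_PREFIX_LENGTH = 4 + 4 + 4 + 4; // subPartitionId, attemptId, batchId, size
  private static final int FLINK_SUB_HEADER_LENGTH = 1 + 1 + 4;    // dataType, isCompressed, size

  public static void main(String[] args) {
    byte[] packedData = new byte[128]; // payload already assembled by the BufferPacker (illustrative)

    ByteBuf header = Unpooled.buffer(CELEBORN_PREFIX_LENGTH + FLINK_SUB_HEADER_LENGTH);
    header.writeInt(5);                          // sub-partition id
    header.writeInt(0);                          // attempt id
    header.writeInt(0);                          // next batch id
    header.writeInt(packedData.length + FLINK_SUB_HEADER_LENGTH); // "compressed size" field
    header.writeByte(Buffer.DataType.DATA_BUFFER.ordinal());
    header.writeBoolean(false);                  // compression flag of the first packed buffer
    header.writeInt(packedData.length);

    ByteBuf wire = Unpooled.wrappedBuffer(header, Unpooled.wrappedBuffer(packedData));
    System.out.println(wire.readableBytes());    // 128 + 22
    wire.release();
  }
}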
diff --git
a/client-flink/flink-1.20/src/test/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierMasterAgentSuiteJ.java
b/client-flink/flink-1.20/src/test/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierMasterAgentSuiteJ.java
new file mode 100644
index 000000000..f53d010cd
--- /dev/null
+++
b/client-flink/flink-1.20/src/test/java/org/apache/celeborn/plugin/flink/tiered/CelebornTierMasterAgentSuiteJ.java
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink.tiered;
+
+import java.net.UnknownHostException;
+import java.util.Collection;
+import java.util.concurrent.CompletableFuture;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.executiongraph.ExecutionGraphID;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import
org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStoragePartitionId;
+import
org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierShuffleDescriptor;
+import
org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierShuffleHandler;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.util.Utils$;
+import org.apache.celeborn.plugin.flink.ShuffleResource;
+import org.apache.celeborn.plugin.flink.ShuffleResourceDescriptor;
+import org.apache.celeborn.plugin.flink.utils.FlinkUtils;
+
+public class CelebornTierMasterAgentSuiteJ {
+ private static final Logger LOG =
LoggerFactory.getLogger(CelebornTierMasterAgentSuiteJ.class);
+ private CelebornTierMasterAgent masterAgent;
+
+ @Before
+ public void setUp() {
+ Configuration configuration = new Configuration();
+ int startPort = Utils$.MODULE$.selectRandomPort(1024, 65535);
+ configuration.setInteger("celeborn.master.port", startPort);
+ configuration.setString("celeborn.master.endpoints", "localhost:" +
startPort);
+ masterAgent = createMasterAgent(configuration);
+ }
+
+ @Test
+ public void testRegisterJob() {
+ TierShuffleHandler tierShuffleHandler = createTierShuffleHandler();
+ JobID jobID = JobID.generate();
+ masterAgent.registerJob(jobID, tierShuffleHandler);
+
+ // re-register the same job
+ try {
+ masterAgent.registerJob(jobID, tierShuffleHandler);
+ Assert.fail("should throw exception if double register job");
+ } catch (Exception e) {
+ Assert.assertTrue(true);
+ }
+
+ // unRegister job
+ masterAgent.unregisterJob(jobID);
+ masterAgent.registerJob(jobID, tierShuffleHandler);
+ }
+
+ private static TierShuffleHandler createTierShuffleHandler() {
+ return new TierShuffleHandler() {
+
+ @Override
+ public CompletableFuture<?> onReleasePartitions(
+ Collection<TieredStoragePartitionId> collection) {
+ return CompletableFuture.completedFuture(null);
+ }
+
+ @Override
+ public void onFatalError(Throwable throwable) {
+ System.exit(-1);
+ }
+ };
+ }
+
+ @Test
+ public void testRegisterPartitionWithProducer() {
+ JobID jobID = JobID.generate();
+ TierShuffleHandler tierShuffleHandler = createTierShuffleHandler();
+ masterAgent.registerJob(jobID, tierShuffleHandler);
+
+ ExecutionAttemptID executionAttemptID =
+ new ExecutionAttemptID(
+ new ExecutionGraphID(), new ExecutionVertexID(new JobVertexID(0L,
0L), 0), 0);
+ ResultPartitionID resultPartitionID =
+ new ResultPartitionID(
+ new IntermediateResultPartitionID(new IntermediateDataSetID(), 0),
executionAttemptID);
+ TierShuffleDescriptor tierShuffleDescriptor =
+ masterAgent.addPartitionAndGetShuffleDescriptor(jobID,
resultPartitionID);
+ Assert.assertTrue(tierShuffleDescriptor instanceof
TierShuffleDescriptorImpl);
+ ShuffleResource shuffleResource =
+ ((TierShuffleDescriptorImpl)
tierShuffleDescriptor).getShuffleResource();
+ ShuffleResourceDescriptor mapPartitionShuffleDescriptor =
+ shuffleResource.getMapPartitionShuffleDescriptor();
+
+ Assert.assertEquals(0, mapPartitionShuffleDescriptor.getShuffleId());
+ Assert.assertEquals(0, mapPartitionShuffleDescriptor.getPartitionId());
+ Assert.assertEquals(0, mapPartitionShuffleDescriptor.getAttemptId());
+ Assert.assertEquals(0, mapPartitionShuffleDescriptor.getMapId());
+
+ // use same partition id
+ tierShuffleDescriptor =
+ masterAgent.addPartitionAndGetShuffleDescriptor(jobID,
resultPartitionID);
+ mapPartitionShuffleDescriptor =
+ ((TierShuffleDescriptorImpl) tierShuffleDescriptor)
+ .getShuffleResource()
+ .getMapPartitionShuffleDescriptor();
+ Assert.assertEquals(0, mapPartitionShuffleDescriptor.getShuffleId());
+ Assert.assertEquals(0, mapPartitionShuffleDescriptor.getMapId());
+ Assert.assertEquals(1, mapPartitionShuffleDescriptor.getPartitionId());
+ Assert.assertEquals(1, mapPartitionShuffleDescriptor.getAttemptId());
+
+ // use another partition number
+ tierShuffleDescriptor =
+ masterAgent.addPartitionAndGetShuffleDescriptor(
+ jobID,
+ new ResultPartitionID(
+ new IntermediateResultPartitionID(
+
resultPartitionID.getPartitionId().getIntermediateDataSetID(), 1),
+ executionAttemptID));
+ mapPartitionShuffleDescriptor =
+ ((TierShuffleDescriptorImpl) tierShuffleDescriptor)
+ .getShuffleResource()
+ .getMapPartitionShuffleDescriptor();
+ Assert.assertEquals(0, mapPartitionShuffleDescriptor.getShuffleId());
+ Assert.assertEquals(1, mapPartitionShuffleDescriptor.getMapId());
+ Assert.assertEquals(2, mapPartitionShuffleDescriptor.getPartitionId());
+ Assert.assertEquals(0, mapPartitionShuffleDescriptor.getAttemptId());
+ }
+
+ @Test
+ public void testRegisterMultipleJobs() throws UnknownHostException {
+ JobID jobID1 = JobID.generate();
+ TierShuffleHandler tierShuffleHandler1 = createTierShuffleHandler();
+ masterAgent.registerJob(jobID1, tierShuffleHandler1);
+
+ JobID jobID2 = JobID.generate();
+ TierShuffleHandler tierShuffleHandler2 = createTierShuffleHandler();
+ masterAgent.registerJob(jobID2, tierShuffleHandler2);
+
+ IntermediateDataSetID intermediateDataSetID = new IntermediateDataSetID();
+ ResultPartitionID resultPartitionID = new ResultPartitionID();
+ TierShuffleDescriptor tierShuffleDescriptor1 =
+ masterAgent.addPartitionAndGetShuffleDescriptor(jobID1,
resultPartitionID);
+
+ // use same partition id but different jobId
+ TierShuffleDescriptor tierShuffleDescriptor2 =
+ masterAgent.addPartitionAndGetShuffleDescriptor(jobID2,
resultPartitionID);
+
+ Assert.assertEquals(
+ ((TierShuffleDescriptorImpl) tierShuffleDescriptor1)
+ .getShuffleResource()
+ .getMapPartitionShuffleDescriptor()
+ .getShuffleId(),
+ 0);
+ Assert.assertEquals(
+ ((TierShuffleDescriptorImpl) tierShuffleDescriptor2)
+ .getShuffleResource()
+ .getMapPartitionShuffleDescriptor()
+ .getShuffleId(),
+ 1);
+ }
+
+ @After
+ public void tearDown() {
+ if (masterAgent != null) {
+ try {
+ masterAgent.close();
+ } catch (Exception e) {
+ LOG.warn(e.getMessage(), e);
+ }
+ }
+ }
+
+ public CelebornTierMasterAgent createMasterAgent(Configuration
configuration) {
+ CelebornConf conf = FlinkUtils.toCelebornConf(configuration);
+ return new CelebornTierMasterAgent(conf);
+ }
+}