This is an automated email from the ASF dual-hosted git repository.
zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 907364db [CELEBORN-156] add remoteShuffleResultPartition in
flink-plugin (#1103)
907364db is described below
commit 907364dbf2236230a591406fe4bb2d62fa3b2e47
Author: zhongqiangczq <[email protected]>
AuthorDate: Wed Dec 21 12:22:17 2022 +0800
[CELEBORN-156] add remoteShuffleResultPartition in flink-plugin (#1103)
---
.../plugin/flink/RemoteShuffleResultPartition.java | 380 ++++++++++++++++++
.../flink/RemoteShuffleResultPartitionSuiteJ.java | 93 +++++
.../plugin/flink/buffer/PartitionSortedBuffer.java | 440 +++++++++++++++++++++
.../celeborn/plugin/flink/buffer/SortBuffer.java | 92 +++++
4 files changed, 1005 insertions(+)
diff --git
a/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartition.java
b/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartition.java
new file mode 100644
index 00000000..6be7c665
--- /dev/null
+++
b/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartition.java
@@ -0,0 +1,380 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.concurrent.CompletableFuture;
+
+import javax.annotation.Nullable;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.runtime.event.AbstractEvent;
+import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
+import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.Buffer.DataType;
+import org.apache.flink.runtime.io.network.buffer.BufferCompressor;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+import
org.apache.flink.runtime.io.network.partition.BufferAvailabilityListener;
+import org.apache.flink.runtime.io.network.partition.ResultPartition;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
+import org.apache.flink.runtime.io.network.partition.ResultSubpartitionView;
+import org.apache.flink.util.FlinkRuntimeException;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.function.SupplierWithException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.plugin.flink.buffer.PartitionSortedBuffer;
+import org.apache.celeborn.plugin.flink.buffer.SortBuffer;
+import org.apache.celeborn.plugin.flink.utils.BufferUtils;
+import org.apache.celeborn.plugin.flink.utils.Utils;
+
/**
 * A {@link ResultPartition} which appends records and events to {@link SortBuffer} and after the
 * {@link SortBuffer} is full, all data in the {@link SortBuffer} will be copied and spilled to the
 * remote shuffle service in subpartition index order sequentially. Large records that can not be
 * appended to an empty {@link SortBuffer} will be spilled directly.
 */
public class RemoteShuffleResultPartition extends ResultPartition {

  private static final Logger LOG = LoggerFactory.getLogger(RemoteShuffleResultPartition.class);

  /** Size of network buffer and write buffer. */
  private final int networkBufferSize;

  /** {@link SortBuffer} for records sent by {@link #broadcastRecord(ByteBuffer)}. */
  private SortBuffer broadcastSortBuffer;

  /** {@link SortBuffer} for records sent by {@link #emitRecord(ByteBuffer, int)}. */
  private SortBuffer unicastSortBuffer;

  /** Utility to spill data to shuffle workers. */
  private final RemoteShuffleOutputGate outputGate;

  /**
   * Creates a result partition that spills sorted data to remote shuffle workers through the
   * given {@code outputGate} instead of serving it locally.
   *
   * @param networkBufferSize size (bytes) of each network/write buffer; also the segment size of
   *     the sort buffers created by this partition.
   * @param bufferCompressor optional compressor applied to data buffers before writing; may be
   *     null, in which case buffers are written uncompressed.
   * @param outputGate the gate used to ship buffers to the remote shuffle service.
   */
  public RemoteShuffleResultPartition(
      String owningTaskName,
      int partitionIndex,
      ResultPartitionID partitionId,
      ResultPartitionType partitionType,
      int numSubpartitions,
      int numTargetKeyGroups,
      int networkBufferSize,
      ResultPartitionManager partitionManager,
      @Nullable BufferCompressor bufferCompressor,
      SupplierWithException<BufferPool, IOException> bufferPoolFactory,
      RemoteShuffleOutputGate outputGate) {

    super(
        owningTaskName,
        partitionIndex,
        partitionId,
        partitionType,
        numSubpartitions,
        numTargetKeyGroups,
        partitionManager,
        bufferCompressor,
        bufferPoolFactory);

    this.networkBufferSize = networkBufferSize;
    this.outputGate = outputGate;
  }

  /**
   * Sets up the local buffer pool (via the super class), reserves buffers, and initializes the
   * remote output gate. A gate setup failure is rethrown as an unchecked exception.
   */
  @Override
  public void setup() throws IOException {
    LOG.info("Setup {}", this);
    super.setup();
    // make sure at least one buffer is available so writing can make progress
    BufferUtils.reserveNumRequiredBuffers(bufferPool, 1);
    try {
      outputGate.setup();
    } catch (Throwable throwable) {
      LOG.error("Failed to setup remote output gate.", throwable);
      Utils.rethrowAsRuntimeException(throwable);
    }
  }

  /** Appends a record destined for a single subpartition. */
  @Override
  public void emitRecord(ByteBuffer record, int targetSubpartition) throws IOException {
    emit(record, targetSubpartition, DataType.DATA_BUFFER, false);
  }

  /** Appends a record to be broadcast to all subpartitions. */
  @Override
  public void broadcastRecord(ByteBuffer record) throws IOException {
    broadcast(record, DataType.DATA_BUFFER);
  }

  /**
   * Serializes and broadcasts an event. The temporary event buffer is always recycled, whether or
   * not the broadcast succeeds.
   */
  @Override
  public void broadcastEvent(AbstractEvent event, boolean isPriorityEvent) throws IOException {
    Buffer buffer = EventSerializer.toBuffer(event, isPriorityEvent);
    try {
      ByteBuffer serializedEvent = buffer.getNioBufferReadable();
      broadcast(serializedEvent, buffer.getDataType());
    } finally {
      buffer.recycleBuffer();
    }
  }

  /** Broadcast data is emitted once with target subpartition 0; the gate fans it out. */
  private void broadcast(ByteBuffer record, DataType dataType) throws IOException {
    emit(record, 0, dataType, true);
  }

  /**
   * Core append path. Tries to append the record to the relevant sort buffer; if the buffer is
   * full it is flushed and the append is retried (recursive call). A record too large to fit even
   * an empty sort buffer is written directly via {@link #writeLargeRecord}.
   */
  private void emit(
      ByteBuffer record, int targetSubpartition, DataType dataType, boolean isBroadcast)
      throws IOException {

    checkInProduceState();
    if (isBroadcast) {
      Preconditions.checkState(
          targetSubpartition == 0, "Target subpartition index can only be 0 when broadcast.");
    }

    SortBuffer sortBuffer = isBroadcast ? getBroadcastSortBuffer() : getUnicastSortBuffer();
    if (sortBuffer.append(record, targetSubpartition, dataType)) {
      return;
    }

    try {
      if (!sortBuffer.hasRemaining()) {
        // the record can not be appended to the free sort buffer because it is too large
        sortBuffer.finish();
        sortBuffer.release();
        writeLargeRecord(record, targetSubpartition, dataType, isBroadcast);
        return;
      }
      flushSortBuffer(sortBuffer, isBroadcast);
    } catch (InterruptedException e) {
      LOG.error("Failed to flush the sort buffer.", e);
      Utils.rethrowAsRuntimeException(e);
    }
    // retry now that the sort buffer has been flushed and released
    emit(record, targetSubpartition, dataType, isBroadcast);
  }

  /** Releases the given sort buffer if it is non-null; safe to call repeatedly. */
  private void releaseSortBuffer(SortBuffer sortBuffer) {
    if (sortBuffer != null) {
      sortBuffer.release();
    }
  }

  /**
   * Returns the current unicast sort buffer, creating a new one if needed. Any pending broadcast
   * data is flushed first so unicast and broadcast data never interleave within a region.
   */
  @VisibleForTesting
  SortBuffer getUnicastSortBuffer() throws IOException {
    flushBroadcastSortBuffer();

    if (unicastSortBuffer != null && !unicastSortBuffer.isFinished()) {
      return unicastSortBuffer;
    }

    unicastSortBuffer =
        new PartitionSortedBuffer(bufferPool, numSubpartitions, networkBufferSize, null);
    return unicastSortBuffer;
  }

  /**
   * Returns the current broadcast sort buffer, creating a new one if needed. Any pending unicast
   * data is flushed first (mirror of {@link #getUnicastSortBuffer()}).
   */
  private SortBuffer getBroadcastSortBuffer() throws IOException {
    flushUnicastSortBuffer();

    if (broadcastSortBuffer != null && !broadcastSortBuffer.isFinished()) {
      return broadcastSortBuffer;
    }

    broadcastSortBuffer =
        new PartitionSortedBuffer(bufferPool, numSubpartitions, networkBufferSize, null);
    return broadcastSortBuffer;
  }

  private void flushBroadcastSortBuffer() throws IOException {
    flushSortBuffer(broadcastSortBuffer, true);
  }

  private void flushUnicastSortBuffer() throws IOException {
    flushSortBuffer(unicastSortBuffer, false);
  }

  /**
   * Drains the given sort buffer to the output gate as one region: finish the buffer, copy its
   * contents segment-by-segment (leaving header space for the shuffle header), compress and write
   * each buffer, then release the sort buffer. No-op for null/already-released buffers.
   */
  @VisibleForTesting
  void flushSortBuffer(SortBuffer sortBuffer, boolean isBroadcast) throws IOException {
    if (sortBuffer == null || sortBuffer.isReleased()) {
      return;
    }
    sortBuffer.finish();
    if (sortBuffer.hasRemaining()) {
      try {
        outputGate.regionStart(isBroadcast);
        while (sortBuffer.hasRemaining()) {
          MemorySegment segment = outputGate.getBufferPool().requestMemorySegmentBlocking();
          SortBuffer.BufferWithChannel bufferWithChannel;
          try {
            bufferWithChannel =
                sortBuffer.copyIntoSegment(
                    segment, outputGate.getBufferPool(), BufferUtils.HEADER_LENGTH);
          } catch (Throwable t) {
            // the segment is not owned by a Buffer yet, so recycle it manually
            outputGate.getBufferPool().recycle(segment);
            throw new FlinkRuntimeException("Shuffle write failure.", t);
          }

          Buffer buffer = bufferWithChannel.getBuffer();
          int subpartitionIndex = bufferWithChannel.getChannelIndex();
          updateStatistics(bufferWithChannel.getBuffer());
          writeCompressedBufferIfPossible(buffer, subpartitionIndex);
        }
        outputGate.regionFinish();
      } catch (InterruptedException e) {
        throw new IOException("Failed to flush the sort buffer, broadcast=" + isBroadcast, e);
      }
    }
    releaseSortBuffer(sortBuffer);
  }

  /**
   * Compresses the data portion of the buffer (if a compressor is configured and compression
   * helps), rewrites the header accordingly, and hands the buffer to the output gate. On failure
   * the buffer is recycled before rethrowing.
   */
  private void writeCompressedBufferIfPossible(Buffer buffer, int targetSubpartition)
      throws InterruptedException {
    Buffer compressedBuffer = null;
    try {
      if (canBeCompressed(buffer)) {
        // compress only the payload; the header is written separately below
        Buffer dataBuffer =
            buffer.readOnlySlice(
                BufferUtils.HEADER_LENGTH, buffer.getSize() - BufferUtils.HEADER_LENGTH);
        compressedBuffer =
            Utils.checkNotNull(bufferCompressor).compressToIntermediateBuffer(dataBuffer);
      }
      BufferUtils.setCompressedDataWithHeader(buffer, compressedBuffer);
    } catch (Throwable throwable) {
      buffer.recycleBuffer();
      throw new RuntimeException("Shuffle write failure.", throwable);
    } finally {
      if (compressedBuffer != null && compressedBuffer.isCompressed()) {
        // the intermediate buffer is reused by the compressor; reset and recycle it
        compressedBuffer.setReaderIndex(0);
        compressedBuffer.recycleBuffer();
      }
    }
    outputGate.write(buffer, targetSubpartition);
  }

  /** Updates the output counters; byte count excludes the shuffle header. */
  private void updateStatistics(Buffer buffer) {
    numBuffersOut.inc();
    numBytesOut.inc(buffer.readableBytes() - BufferUtils.HEADER_LENGTH);
  }

  /** Spills the large record into {@link RemoteShuffleOutputGate}. */
  private void writeLargeRecord(
      ByteBuffer record, int targetSubpartition, DataType dataType, boolean isBroadcast)
      throws InterruptedException {

    outputGate.regionStart(isBroadcast);
    while (record.hasRemaining()) {
      MemorySegment writeBuffer = outputGate.getBufferPool().requestMemorySegmentBlocking();
      // copy as much as fits after the reserved header space
      int toCopy = Math.min(record.remaining(), writeBuffer.size() - BufferUtils.HEADER_LENGTH);
      writeBuffer.put(BufferUtils.HEADER_LENGTH, record, toCopy);
      NetworkBuffer buffer =
          new NetworkBuffer(
              writeBuffer,
              outputGate.getBufferPool(),
              dataType,
              toCopy + BufferUtils.HEADER_LENGTH);

      updateStatistics(buffer);
      writeCompressedBufferIfPossible(buffer, targetSubpartition);
    }
    outputGate.regionFinish();
  }

  /**
   * Finishes the partition: broadcasts {@link EndOfPartitionEvent}, flushes remaining broadcast
   * data, and finishes the output gate. The unicast buffer must already be released at this point
   * (broadcasting the event flushes it as a side effect).
   */
  @Override
  public void finish() throws IOException {
    Utils.checkState(!isReleased(), "Result partition is already released.");
    broadcastEvent(EndOfPartitionEvent.INSTANCE, false);
    Utils.checkState(
        unicastSortBuffer == null || unicastSortBuffer.isReleased(),
        "The unicast sort buffer should be either null or released.");
    flushBroadcastSortBuffer();
    try {
      outputGate.finish();
    } catch (InterruptedException e) {
      throw new IOException("Output gate fails to finish.", e);
    }
    super.finish();
  }

  /** Releases both sort buffers and closes the output gate; synchronized with other callers. */
  @Override
  public synchronized void close() {
    releaseSortBuffer(unicastSortBuffer);
    releaseSortBuffer(broadcastSortBuffer);
    super.close();
    try {
      outputGate.close();
    } catch (Exception e) {
      Utils.rethrowAsRuntimeException(e);
    }
  }

  @Override
  protected void releaseInternal() {
    // no-op: nothing is served locally, all data lives on the remote shuffle service
  }

  /** Flushes both sort buffers; any failure is rethrown as an unchecked exception. */
  @Override
  public void flushAll() {
    try {
      flushUnicastSortBuffer();
      flushBroadcastSortBuffer();
    } catch (Throwable t) {
      LOG.error("Failed to flush the current sort buffer.", t);
      Utils.rethrowAsRuntimeException(t);
    }
  }

  @Override
  public void flush(int subpartitionIndex) {
    // per-subpartition flush is not supported; flush everything instead
    flushAll();
  }

  @Override
  public CompletableFuture<?> getAvailableFuture() {
    // back-pressure is handled by the output gate's buffer pool, so always report available
    return AVAILABLE;
  }

  @Override
  public int getNumberOfQueuedBuffers() {
    // nothing is queued locally; data is spilled to the remote shuffle service
    return 0;
  }

  @Override
  public int getNumberOfQueuedBuffers(int targetSubpartition) {
    return 0;
  }

  @Override
  public ResultSubpartitionView createSubpartitionView(
      int index, BufferAvailabilityListener availabilityListener) {
    // consumers read from the remote shuffle service, never from this partition directly
    throw new UnsupportedOperationException("Not supported.");
  }

  @Override
  public String toString() {
    return "ResultPartition "
        + partitionId.toString()
        + " ["
        + partitionType
        + ", "
        + numSubpartitions
        + " subpartitions, shuffle-descriptor: "
        + outputGate.getShuffleDesc()
        + "]";
  }
}
diff --git
a/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartitionSuiteJ.java
b/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartitionSuiteJ.java
new file mode 100644
index 00000000..a0b620a5
--- /dev/null
+++
b/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartitionSuiteJ.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink;
+
+import org.apache.celeborn.plugin.flink.buffer.SortBuffer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferCompressor;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
+import org.apache.flink.util.function.SupplierWithException;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.mockito.Matchers.anyBoolean;
+import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class RemoteShuffleResultPartitionSuiteJ {
+ private BufferCompressor bufferCompressor =
+ new BufferCompressor(32 * 1024, "lz4");
+ private RemoteShuffleOutputGate remoteShuffleOutputGate =
mock(RemoteShuffleOutputGate.class);
+
+ @Before
+ public void setup() {
+
+ }
+
+ @Test
+ public void tesSimpleFlush() throws IOException, InterruptedException {
+ List<SupplierWithException<BufferPool, IOException>> bufferPool =
createBufferPoolFactory();
+ RemoteShuffleResultPartition remoteShuffleResultPartition = new
RemoteShuffleResultPartition("test",
+ 0,
+ new ResultPartitionID(),
+ ResultPartitionType.BLOCKING,
+ 2,
+ 2,
+ 32 * 1024,
+ new ResultPartitionManager(),
+ bufferCompressor,
+ bufferPool.get(0),
+ remoteShuffleOutputGate);
+ remoteShuffleResultPartition.setup();
+ doNothing().when(remoteShuffleOutputGate).regionStart(anyBoolean());
+ doNothing().when(remoteShuffleOutputGate).regionFinish();
+
when(remoteShuffleOutputGate.getBufferPool()).thenReturn(bufferPool.get(1).get());
+ SortBuffer sortBuffer =
remoteShuffleResultPartition.getUnicastSortBuffer();
+ ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[] {1, 2, 3});
+ sortBuffer.append(byteBuffer, 0, Buffer.DataType.DATA_BUFFER);
+ remoteShuffleResultPartition.flushSortBuffer(sortBuffer, true);
+ }
+
+ private List<SupplierWithException<BufferPool, IOException>>
createBufferPoolFactory() {
+ NetworkBufferPool networkBufferPool =
+ new NetworkBufferPool(256 * 8, 32 * 1024,
Duration.ofMillis(1000));
+
+ int numBuffersPerPartition = 64 * 1024 / 32;
+ int numForResultPartition = numBuffersPerPartition * 7 / 8;
+ int numForOutputGate = numBuffersPerPartition - numForResultPartition;
+
+ List<SupplierWithException<BufferPool, IOException>> factories = new
ArrayList<>();
+ factories.add(
+ () ->
networkBufferPool.createBufferPool(numForResultPartition,
numForResultPartition));
+ factories.add(() ->
networkBufferPool.createBufferPool(numForOutputGate, numForOutputGate));
+ return factories;
+ }
+
+
+}
diff --git
a/client-flink/flink-common/src/main/java/org/apache/celeborn/plugin/flink/buffer/PartitionSortedBuffer.java
b/client-flink/flink-common/src/main/java/org/apache/celeborn/plugin/flink/buffer/PartitionSortedBuffer.java
new file mode 100644
index 00000000..a4184cac
--- /dev/null
+++
b/client-flink/flink-common/src/main/java/org/apache/celeborn/plugin/flink/buffer/PartitionSortedBuffer.java
@@ -0,0 +1,440 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink.buffer;
+
+import static org.apache.celeborn.plugin.flink.utils.Utils.checkArgument;
+import static org.apache.celeborn.plugin.flink.utils.Utils.checkNotNull;
+import static org.apache.celeborn.plugin.flink.utils.Utils.checkState;
+import static org.apache.flink.runtime.io.network.buffer.Buffer.DataType;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.GuardedBy;
+import javax.annotation.concurrent.NotThreadSafe;
+
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.BufferRecycler;
+import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+import org.apache.flink.util.FlinkRuntimeException;
+
+/**
+ * A {@link SortBuffer} implementation which sorts all appended records only
by subpartition index.
+ * Records of the same subpartition keep the appended order.
+ *
+ * <p>It maintains a list of {@link MemorySegment}s as a joint buffer. Data
will be appended to the
+ * joint buffer sequentially. When writing a record, an index entry will be
appended first. An index
+ * entry consists of 4 fields: 4 bytes for record length, 4 bytes for {@link
DataType} and 8 bytes
+ * for address pointing to the next index entry of the same channel which will
be used to index the
+ * next record to read when coping data from this {@link SortBuffer}. For
simplicity, no index entry
+ * can span multiple segments. The corresponding record data is seated right
after its index entry
+ * and different from the index entry, records have variable length thus may
span multiple segments.
+ */
+@NotThreadSafe
+public class PartitionSortedBuffer implements SortBuffer {
+
+ /**
+ * Size of an index entry: 4 bytes for record length, 4 bytes for data type
and 8 bytes for
+ * pointer to next entry.
+ */
+ private static final int INDEX_ENTRY_SIZE = 4 + 4 + 8;
+
+ private final Object lock;
+ /** A buffer pool to request memory segments from. */
+ private final BufferPool bufferPool;
+
+ /** A segment list as a joint buffer which stores all records and index
entries. */
+ @GuardedBy("lock")
+ private final ArrayList<MemorySegment> buffers = new ArrayList<>();
+
+ /** Addresses of the first record's index entry for each subpartition. */
+ private final long[] firstIndexEntryAddresses;
+
+ /** Addresses of the last record's index entry for each subpartition. */
+ private final long[] lastIndexEntryAddresses;
+ /** Size of buffers requested from buffer pool. All buffers must be of the
same size. */
+ private final int bufferSize;
+ /** Data of different subpartitions in this sort buffer will be read in this
order. */
+ private final int[] subpartitionReadOrder;
+
+ //
---------------------------------------------------------------------------------------------
+ // Statistics and states
+ //
---------------------------------------------------------------------------------------------
+ /** Total number of bytes already appended to this sort buffer. */
+ private long numTotalBytes;
+ /** Total number of records already appended to this sort buffer. */
+ private long numTotalRecords;
+ /** Total number of bytes already read from this sort buffer. */
+ private long numTotalBytesRead;
+ /** Whether this sort buffer is finished. One can only read a finished sort
buffer. */
+ private boolean isFinished;
+
+ //
---------------------------------------------------------------------------------------------
+ // For writing
+ //
---------------------------------------------------------------------------------------------
+ /** Whether this sort buffer is released. A released sort buffer can not be
used. */
+ @GuardedBy("lock")
+ private boolean isReleased;
+ /** Array index in the segment list of the current available buffer for
writing. */
+ private int writeSegmentIndex;
+
+ //
---------------------------------------------------------------------------------------------
+ // For reading
+ //
---------------------------------------------------------------------------------------------
+ /** Next position in the current available buffer for writing. */
+ private int writeSegmentOffset;
+ /** Index entry address of the current record or event to be read. */
+ private long readIndexEntryAddress;
+
+ /** Record bytes remaining after last copy, which must be read first in next
copy. */
+ private int recordRemainingBytes;
+
+ /** Used to index the current available channel to read data from. */
+ private int readOrderIndex = -1;
+
+ public PartitionSortedBuffer(
+ BufferPool bufferPool,
+ int numSubpartitions,
+ int bufferSize,
+ @Nullable int[] customReadOrder) {
+ checkArgument(bufferSize > INDEX_ENTRY_SIZE, "Buffer size is too small.");
+
+ this.lock = new Object();
+ this.bufferPool = checkNotNull(bufferPool);
+ this.bufferSize = bufferSize;
+ this.firstIndexEntryAddresses = new long[numSubpartitions];
+ this.lastIndexEntryAddresses = new long[numSubpartitions];
+
+ // initialized with -1 means the corresponding channel has no data.
+ Arrays.fill(firstIndexEntryAddresses, -1L);
+ Arrays.fill(lastIndexEntryAddresses, -1L);
+
+ this.subpartitionReadOrder = new int[numSubpartitions];
+ if (customReadOrder != null) {
+ checkArgument(customReadOrder.length == numSubpartitions, "Illegal data
read order.");
+ System.arraycopy(customReadOrder, 0, this.subpartitionReadOrder, 0,
numSubpartitions);
+ } else {
+ for (int channel = 0; channel < numSubpartitions; ++channel) {
+ this.subpartitionReadOrder[channel] = channel;
+ }
+ }
+ }
+
+ @Override
+ public boolean append(ByteBuffer source, int targetChannel, DataType
dataType)
+ throws IOException {
+ checkArgument(source.hasRemaining(), "Cannot append empty data.");
+ checkState(!isFinished, "Sort buffer is already finished.");
+ checkState(!isReleased, "Sort buffer is already released.");
+
+ int totalBytes = source.remaining();
+
+ // return false directly if it can not allocate enough buffers for the
given record
+ if (!allocateBuffersForRecord(totalBytes)) {
+ return false;
+ }
+
+ // write the index entry and record or event data
+ writeIndex(targetChannel, totalBytes, dataType);
+ writeRecord(source);
+
+ ++numTotalRecords;
+ numTotalBytes += totalBytes;
+
+ return true;
+ }
+
+ private void writeIndex(int channelIndex, int numRecordBytes, DataType
dataType) {
+ MemorySegment segment = buffers.get(writeSegmentIndex);
+
+ // record length takes the high 32 bits and data type takes the low 32 bits
+ segment.putLong(writeSegmentOffset, ((long) numRecordBytes << 32) |
dataType.ordinal());
+
+ // segment index takes the high 32 bits and segment offset takes the low
32 bits
+ long indexEntryAddress = ((long) writeSegmentIndex << 32) |
writeSegmentOffset;
+
+ long lastIndexEntryAddress = lastIndexEntryAddresses[channelIndex];
+ lastIndexEntryAddresses[channelIndex] = indexEntryAddress;
+
+ if (lastIndexEntryAddress >= 0) {
+ // link the previous index entry of the given channel to the new index
entry
+ segment = buffers.get(getSegmentIndexFromPointer(lastIndexEntryAddress));
+ segment.putLong(getSegmentOffsetFromPointer(lastIndexEntryAddress) + 8,
indexEntryAddress);
+ } else {
+ firstIndexEntryAddresses[channelIndex] = indexEntryAddress;
+ }
+
+ // move the writer position forward to write the corresponding record
+ updateWriteSegmentIndexAndOffset(INDEX_ENTRY_SIZE);
+ }
+
+ private void writeRecord(ByteBuffer source) {
+ while (source.hasRemaining()) {
+ MemorySegment segment = buffers.get(writeSegmentIndex);
+ int toCopy = Math.min(bufferSize - writeSegmentOffset,
source.remaining());
+ segment.put(writeSegmentOffset, source, toCopy);
+
+ // move the writer position forward to write the remaining bytes or next
record
+ updateWriteSegmentIndexAndOffset(toCopy);
+ }
+ }
+
+ private boolean allocateBuffersForRecord(int numRecordBytes) throws
IOException {
+ int numBytesRequired = INDEX_ENTRY_SIZE + numRecordBytes;
+ int availableBytes = writeSegmentIndex == buffers.size() ? 0 : bufferSize
- writeSegmentOffset;
+
+ // return directly if current available bytes is adequate
+ if (availableBytes >= numBytesRequired) {
+ return true;
+ }
+
+ // skip the remaining free space if the available bytes is not enough for
an index entry
+ if (availableBytes < INDEX_ENTRY_SIZE) {
+ updateWriteSegmentIndexAndOffset(availableBytes);
+ availableBytes = 0;
+ }
+
+ // allocate exactly enough buffers for the appended record
+ do {
+ MemorySegment segment = requestBufferFromPool();
+ if (segment == null) {
+ // return false if we can not allocate enough buffers for the appended
record
+ return false;
+ }
+
+ availableBytes += bufferSize;
+ addBuffer(segment);
+ } while (availableBytes < numBytesRequired);
+
+ return true;
+ }
+
+ private void addBuffer(MemorySegment segment) {
+ synchronized (lock) {
+ if (segment.size() != bufferSize) {
+ bufferPool.recycle(segment);
+ throw new IllegalStateException("Illegal memory segment size.");
+ }
+
+ if (isReleased) {
+ bufferPool.recycle(segment);
+ throw new IllegalStateException("Sort buffer is already released.");
+ }
+
+ buffers.add(segment);
+ }
+ }
+
+ private MemorySegment requestBufferFromPool() throws IOException {
+ try {
+ // blocking request buffers if there is still guaranteed memory
+ if (buffers.size() < bufferPool.getNumberOfRequiredMemorySegments()) {
+ return bufferPool.requestMemorySegmentBlocking();
+ }
+ } catch (InterruptedException e) {
+ throw new IOException("Interrupted while requesting buffer.");
+ }
+
+ return bufferPool.requestMemorySegment();
+ }
+
+ private void updateWriteSegmentIndexAndOffset(int numBytes) {
+ writeSegmentOffset += numBytes;
+
+ // using the next available free buffer if the current is full
+ if (writeSegmentOffset == bufferSize) {
+ ++writeSegmentIndex;
+ writeSegmentOffset = 0;
+ }
+ }
+
+ @Override
+ public BufferWithChannel copyIntoSegment(
+ MemorySegment target, BufferRecycler recycler, int offset) {
+ synchronized (lock) {
+ checkState(hasRemaining(), "No data remaining.");
+ checkState(isFinished, "Should finish the sort buffer first before
coping any data.");
+ checkState(!isReleased, "Sort buffer is already released.");
+
+ int numBytesCopied = 0;
+ DataType bufferDataType = DataType.DATA_BUFFER;
+ int channelIndex = subpartitionReadOrder[readOrderIndex];
+
+ do {
+ int sourceSegmentIndex =
getSegmentIndexFromPointer(readIndexEntryAddress);
+ int sourceSegmentOffset =
getSegmentOffsetFromPointer(readIndexEntryAddress);
+ MemorySegment sourceSegment = buffers.get(sourceSegmentIndex);
+
+ long lengthAndDataType = sourceSegment.getLong(sourceSegmentOffset);
+ int length = getSegmentIndexFromPointer(lengthAndDataType);
+ DataType dataType =
DataType.values()[getSegmentOffsetFromPointer(lengthAndDataType)];
+
+ // return the data read directly if the next to read is an event
+ if (dataType.isEvent() && numBytesCopied > 0) {
+ break;
+ }
+ bufferDataType = dataType;
+
+ // get the next index entry address and move the read position forward
+ long nextReadIndexEntryAddress =
sourceSegment.getLong(sourceSegmentOffset + 8);
+ sourceSegmentOffset += INDEX_ENTRY_SIZE;
+
+ // throws if the event is too big to be accommodated by a buffer.
+ if (bufferDataType.isEvent() && target.size() < length) {
+ throw new FlinkRuntimeException("Event is too big to be accommodated
by a buffer");
+ }
+
+ numBytesCopied +=
+ copyRecordOrEvent(
+ target, numBytesCopied + offset, sourceSegmentIndex,
sourceSegmentOffset, length);
+
+ if (recordRemainingBytes == 0) {
+ // move to next channel if the current channel has been finished
+ if (readIndexEntryAddress == lastIndexEntryAddresses[channelIndex]) {
+ updateReadChannelAndIndexEntryAddress();
+ break;
+ }
+ readIndexEntryAddress = nextReadIndexEntryAddress;
+ }
+ } while (numBytesCopied < target.size() - offset &&
bufferDataType.isBuffer());
+
+ numTotalBytesRead += numBytesCopied;
+ Buffer buffer = new NetworkBuffer(target, recycler, bufferDataType,
numBytesCopied + offset);
+ return new BufferWithChannel(buffer, channelIndex);
+ }
+ }
+
+ private int copyRecordOrEvent(
+ MemorySegment targetSegment,
+ int targetSegmentOffset,
+ int sourceSegmentIndex,
+ int sourceSegmentOffset,
+ int recordLength) {
+ if (recordRemainingBytes > 0) {
+ // skip the data already read if there is remaining partial record after
the previous
+ // copy
+ long position = (long) sourceSegmentOffset + (recordLength -
recordRemainingBytes);
+ sourceSegmentIndex += (position / bufferSize);
+ sourceSegmentOffset = (int) (position % bufferSize);
+ } else {
+ recordRemainingBytes = recordLength;
+ }
+
+ int targetSegmentSize = targetSegment.size();
+ int numBytesToCopy = Math.min(targetSegmentSize - targetSegmentOffset,
recordRemainingBytes);
+ do {
+ // move to next data buffer if all data of the current buffer has been
copied
+ if (sourceSegmentOffset == bufferSize) {
+ ++sourceSegmentIndex;
+ sourceSegmentOffset = 0;
+ }
+
+ int sourceRemainingBytes = Math.min(bufferSize - sourceSegmentOffset,
recordRemainingBytes);
+ int numBytes = Math.min(targetSegmentSize - targetSegmentOffset,
sourceRemainingBytes);
+ MemorySegment sourceSegment = buffers.get(sourceSegmentIndex);
+ sourceSegment.copyTo(sourceSegmentOffset, targetSegment,
targetSegmentOffset, numBytes);
+
+ recordRemainingBytes -= numBytes;
+ targetSegmentOffset += numBytes;
+ sourceSegmentOffset += numBytes;
+ } while (recordRemainingBytes > 0 && targetSegmentOffset <
targetSegmentSize);
+
+ return numBytesToCopy;
+ }
+
+ private void updateReadChannelAndIndexEntryAddress() {
+ // skip the channels without any data
+ while (++readOrderIndex < firstIndexEntryAddresses.length) {
+ int channelIndex = subpartitionReadOrder[readOrderIndex];
+ if ((readIndexEntryAddress = firstIndexEntryAddresses[channelIndex]) >=
0) {
+ break;
+ }
+ }
+ }
+
+ private int getSegmentIndexFromPointer(long value) {
+ return (int) (value >>> 32);
+ }
+
+ private int getSegmentOffsetFromPointer(long value) {
+ return (int) (value);
+ }
+
+ @Override
+ public long numRecords() {
+ return numTotalRecords;
+ }
+
+ @Override
+ public long numBytes() {
+ return numTotalBytes;
+ }
+
+ @Override
+ public boolean hasRemaining() {
+ return numTotalBytesRead < numTotalBytes;
+ }
+
+ @Override
+ public void finish() {
+ checkState(
+ !isFinished, "com.alibaba.flink.shuffle.plugin.transfer.SortBuffer is
already finished.");
+
+ isFinished = true;
+
+ // prepare for reading
+ updateReadChannelAndIndexEntryAddress();
+ }
+
+ @Override
+ public boolean isFinished() {
+ return isFinished;
+ }
+
+ @Override
+ public void release() {
+ // the sort buffer can be released by other threads
+ synchronized (lock) {
+ if (isReleased) {
+ return;
+ }
+
+ isReleased = true;
+
+ for (MemorySegment segment : buffers) {
+ bufferPool.recycle(segment);
+ }
+ buffers.clear();
+
+ numTotalBytes = 0;
+ numTotalRecords = 0;
+ }
+ }
+
+ @Override
+ public boolean isReleased() {
+ synchronized (lock) {
+ return isReleased;
+ }
+ }
+}
diff --git
a/client-flink/flink-common/src/main/java/org/apache/celeborn/plugin/flink/buffer/SortBuffer.java
b/client-flink/flink-common/src/main/java/org/apache/celeborn/plugin/flink/buffer/SortBuffer.java
new file mode 100644
index 00000000..7dd43f81
--- /dev/null
+++
b/client-flink/flink-common/src/main/java/org/apache/celeborn/plugin/flink/buffer/SortBuffer.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink.buffer;
+
+import static org.apache.celeborn.plugin.flink.utils.Utils.checkNotNull;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferRecycler;
+
/**
 * Data of different channels can be appended to a {@link SortBuffer}. After appending is
 * finished, data can be copied from it in channel index order.
 */
public interface SortBuffer {

  /**
   * Appends data of the specified channel to this {@link SortBuffer} and returns true if all
   * bytes of the source buffer are copied to this {@link SortBuffer} successfully; otherwise
   * returns false and nothing is copied.
   */
  boolean append(ByteBuffer source, int targetChannel, Buffer.DataType dataType) throws IOException;

  /**
   * Copies data from this {@link SortBuffer} to the target {@link MemorySegment} in channel index
   * order and returns a {@link BufferWithChannel} which contains the copied data and the
   * corresponding channel index.
   */
  BufferWithChannel copyIntoSegment(MemorySegment target, BufferRecycler recycler, int offset);

  /** Returns the number of records written to this {@link SortBuffer}. */
  long numRecords();

  /** Returns the number of bytes written to this {@link SortBuffer}. */
  long numBytes();

  /** Returns true if there is still data that can be consumed from this {@link SortBuffer}. */
  boolean hasRemaining();

  /** Finishes this {@link SortBuffer}, after which no record can be appended any more. */
  void finish();

  /** Whether this {@link SortBuffer} is finished or not. */
  boolean isFinished();

  /** Releases this {@link SortBuffer} and all resources it holds. */
  void release();

  /** Whether this {@link SortBuffer} is released or not. */
  boolean isReleased();

  /** A buffer paired with the channel index it belongs to, returned to the reader. */
  class BufferWithChannel {

    private final Buffer buffer;

    private final int channelIndex;

    BufferWithChannel(Buffer buffer, int channelIndex) {
      this.buffer = checkNotNull(buffer);
      this.channelIndex = channelIndex;
    }

    /** Gets the wrapped {@link Buffer}. */
    public Buffer getBuffer() {
      return buffer;
    }

    /** Gets the channel index this buffer belongs to. */
    public int getChannelIndex() {
      return channelIndex;
    }
  }
}