This is an automated email from the ASF dual-hosted git repository.

panyuepeng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 7b1e2129a65 [FLINK-38943][runtime] Support Adaptive Partition Selection for RescalePartitioner & RebalancePartitioner (#27446)
7b1e2129a65 is described below

commit 7b1e2129a6565b0edcce6f6990dea74df23df9b3
Author: Yuepeng Pan <[email protected]>
AuthorDate: Tue Feb 10 11:19:05 2026 +0800

    [FLINK-38943][runtime] Support Adaptive Partition Selection for RescalePartitioner & RebalancePartitioner (#27446)
    
    Co-authored-by: Tartarus0zm <[email protected]>
    Co-authored-by: 1996fanrui <[email protected]>
---
 .../generated/all_taskmanager_network_section.html |  12 ++
 .../netty_shuffle_environment_configuration.html   |  12 ++
 .../NettyShuffleEnvironmentOptions.java            |  28 +++
 .../api/writer/AdaptiveLoadBasedRecordWriter.java  | 139 +++++++++++++
 .../io/network/api/writer/RecordWriterBuilder.java |  23 ++-
 .../network/api/writer/ResultPartitionWriter.java  |   8 +
 .../runtime/io/network/buffer/BufferPool.java      |   5 +
 .../runtime/io/network/buffer/LocalBufferPool.java |   5 +
 .../partition/BufferWritingResultPartition.java    |  32 ++-
 .../io/network/partition/ResultPartition.java      |   4 +
 .../flink/streaming/runtime/tasks/StreamTask.java  |  27 +++
 .../writer/AdaptiveLoadBasedRecordWriterTest.java  | 224 +++++++++++++++++++++
 .../streaming/runtime/tasks/StreamTaskTest.java    |  23 +++
 13 files changed, 531 insertions(+), 11 deletions(-)
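
The feature is off by default and is controlled by the two new options documented below. A minimal sketch of enabling it, assuming the usual Flink configuration file (the values are illustrative only, not part of this commit):

    taskmanager.network.adaptive-partitioner.enabled: true
    # Optional; defaults to 4 and must be greater than 1.
    taskmanager.network.adaptive-partitioner.max-traverse-size: 4

StreamTask reads the options from the job configuration, so a per-job override should also work.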

diff --git a/docs/layouts/shortcodes/generated/all_taskmanager_network_section.html b/docs/layouts/shortcodes/generated/all_taskmanager_network_section.html
index 0036d781c12..7299bbab326 100644
--- a/docs/layouts/shortcodes/generated/all_taskmanager_network_section.html
+++ b/docs/layouts/shortcodes/generated/all_taskmanager_network_section.html
@@ -8,6 +8,18 @@
         </tr>
     </thead>
     <tbody>
+        <tr>
+            <td><h5>taskmanager.network.adaptive-partitioner.enabled</h5></td>
+            <td style="word-wrap: break-word;">false</td>
+            <td>Boolean</td>
+            <td>Whether to enable the adaptive partitioner feature for the rescale and rebalance partitioners, which selects the target channel based on the load of the downstream tasks.</td>
+        </tr>
+        <tr>
+            <td><h5>taskmanager.network.adaptive-partitioner.max-traverse-size</h5></td>
+            <td style="word-wrap: break-word;">4</td>
+            <td>Integer</td>
+            <td>Maximum number of channels to traverse when looking for the most idle channel for the rescale and rebalance partitioners when <code class="highlighter-rouge">taskmanager.network.adaptive-partitioner.enabled</code> is enabled.<br />Note that the value of this option must be greater than 1.</td>
+        </tr>
         <tr>
             <td><h5>taskmanager.network.compression.codec</h5></td>
             <td style="word-wrap: break-word;">LZ4</td>
diff --git a/docs/layouts/shortcodes/generated/netty_shuffle_environment_configuration.html b/docs/layouts/shortcodes/generated/netty_shuffle_environment_configuration.html
index 3e6012bea1d..7e851456b0d 100644
--- a/docs/layouts/shortcodes/generated/netty_shuffle_environment_configuration.html
+++ b/docs/layouts/shortcodes/generated/netty_shuffle_environment_configuration.html
@@ -26,6 +26,18 @@
             <td>Boolean</td>
             <td>Enable SSL support for the taskmanager data transport. This is applicable only when the global flag for internal SSL (security.ssl.internal.enabled) is set to true</td>
         </tr>
+        <tr>
+            <td><h5>taskmanager.network.adaptive-partitioner.enabled</h5></td>
+            <td style="word-wrap: break-word;">false</td>
+            <td>Boolean</td>
+            <td>Whether to enable the adaptive partitioner feature for the rescale and rebalance partitioners, which selects the target channel based on the load of the downstream tasks.</td>
+        </tr>
+        <tr>
+            <td><h5>taskmanager.network.adaptive-partitioner.max-traverse-size</h5></td>
+            <td style="word-wrap: break-word;">4</td>
+            <td>Integer</td>
+            <td>Maximum number of channels to traverse when looking for the most idle channel for the rescale and rebalance partitioners when <code class="highlighter-rouge">taskmanager.network.adaptive-partitioner.enabled</code> is enabled.<br />Note that the value of this option must be greater than 1.</td>
+        </tr>
         <tr>
             <td><h5>taskmanager.network.compression.codec</h5></td>
             <td style="word-wrap: break-word;">LZ4</td>
diff --git a/flink-core/src/main/java/org/apache/flink/configuration/NettyShuffleEnvironmentOptions.java b/flink-core/src/main/java/org/apache/flink/configuration/NettyShuffleEnvironmentOptions.java
index 018b11e5ecd..d9be6cb1b27 100644
--- a/flink-core/src/main/java/org/apache/flink/configuration/NettyShuffleEnvironmentOptions.java
+++ b/flink-core/src/main/java/org/apache/flink/configuration/NettyShuffleEnvironmentOptions.java
@@ -325,6 +325,34 @@ public class NettyShuffleEnvironmentOptions {
                                             code(NETWORK_REQUEST_BACKOFF_MAX.key()))
                                     .build());
 
+    /** Whether to use adaptive partition selection for the rebalance and rescale partitioners. */
+    @Documentation.Section(Documentation.Sections.ALL_TASK_MANAGER_NETWORK)
+    public static final ConfigOption<Boolean> ADAPTIVE_PARTITIONER_ENABLED =
+            key("taskmanager.network.adaptive-partitioner.enabled")
+                    .booleanType()
+                    .defaultValue(false)
+                    .withDescription(
+                            "Whether to enable the adaptive partitioner feature for the rescale and rebalance partitioners, which selects the target channel based on the load of the downstream tasks.");
+
+    /**
+     * Maximum number of channels to traverse when looking for the most idle channel for the
+     * rescale and rebalance partitioners when {@link #ADAPTIVE_PARTITIONER_ENABLED} is true.
+     */
+    @Documentation.Section(Documentation.Sections.ALL_TASK_MANAGER_NETWORK)
+    public static final ConfigOption<Integer> ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE =
+            key("taskmanager.network.adaptive-partitioner.max-traverse-size")
+                    .intType()
+                    .defaultValue(4)
+                    .withDescription(
+                            Description.builder()
+                                    .text(
+                                            "Maximum number of channels to traverse when looking for the most idle channel for the rescale and rebalance partitioners when %s is enabled.",
+                                            code(ADAPTIVE_PARTITIONER_ENABLED.key()))
+                                    .linebreak()
+                                    .text(
+                                            "Note that the value of this option must be greater than 1.")
+                                    .build());
+
     // ------------------------------------------------------------------------
 
     /** Not intended to be instantiated. */
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/AdaptiveLoadBasedRecordWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/AdaptiveLoadBasedRecordWriter.java
new file mode 100644
index 00000000000..19eb08e6571
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/AdaptiveLoadBasedRecordWriter.java
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.api.writer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.core.io.IOReadableWritable;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+/**
+ * A record writer that selects the target channel based on the load of the downstream tasks, used
+ * for {@link org.apache.flink.streaming.runtime.partitioner.RescalePartitioner} and {@link
+ * org.apache.flink.streaming.runtime.partitioner.RebalancePartitioner}.
+ *
+ * <pre>
+ *
+ * Clarifications on a few points for quick understanding:
+ *
+ * - Two new immutable attributes are introduced in this class:
+ *   -- `numberOfSubpartitions` is the number of downstream partitions that can be written to.
+ *   -- `maxTraverseSize` is the maximum number of partitions that the partition selector compares when performing rescale or rebalance.
+ * - Why don't `maxTraverseSize` and `numberOfSubpartitions` share a common attribute?
+ *   If a single field were shared and `maxTraverseSize` were less than `numberOfSubpartitions` (e.g., 2 < 6), some downstream partitions (4 in this case) would never be written to, which is incorrect behavior.
+ *
+ * - Why can users not explicitly configure `maxTraverseSize` as 1?
+ *   Setting it to 1 would mean that no load comparison is performed, effectively disabling the adaptive partitioning feature.
+ *
+ * - Why may the internal value of `maxTraverseSize` still become 1?
+ *   This is reasonable if and only if the number of downstream partitions is exactly 1 (since no comparison is needed). This situation can arise from framework behaviors such as the {@link org.apache.flink.runtime.scheduler.adaptive.AdaptiveScheduler}, which are not directly controlled by users.
+ *    For example, consider the following job running on the AdaptiveScheduler before rescaling:
+ *
+ *      JobVertexA(parallelism=4, slotSharingGroup=SSG-A) --(rescale)--> JobVertexB(parallelism=5, slotSharingGroup=SSG-B)
+ *
+ *      If the job scales down and only 2 slots are available, the parallelism of the job changes to:
+ *
+ *      JobVertexA(parallelism=1, slotSharingGroup=SSG-A) --(rescale)--> JobVertexB(parallelism=1, slotSharingGroup=SSG-B)
+ *
+ *    In this case, the task of JobVertexA has only one writable downstream partition, so a `maxTraverseSize` of 1 is reasonable and meaningful.
+ *
+ * </pre>
+ *
+ * @param <T> The type of IOReadableWritable records.
+ */
+@Internal
+public final class AdaptiveLoadBasedRecordWriter<T extends IOReadableWritable>
+        extends RecordWriter<T> {
+
+    private final int maxTraverseSize;
+    private final int numberOfSubpartitions;
+    private int currentChannel = -1;
+
+    AdaptiveLoadBasedRecordWriter(
+            ResultPartitionWriter writer, long timeout, String taskName, int maxTraverseSize) {
+        super(writer, timeout, taskName);
+        this.numberOfSubpartitions = writer.getNumberOfSubpartitions();
+        this.maxTraverseSize = Math.min(maxTraverseSize, numberOfSubpartitions);
+    }
+
+    @Override
+    public void emit(T record) throws IOException {
+        checkErroneous();
+
+        currentChannel = getIdlestChannelIndex();
+
+        ByteBuffer byteBuffer = serializeRecord(serializer, record);
+        targetPartition.emitRecord(byteBuffer, currentChannel);
+
+        if (flushAlways) {
+            targetPartition.flush(currentChannel);
+        }
+    }
+
+    @VisibleForTesting
+    int getIdlestChannelIndex() {
+        int bestChannelBuffersCount = Integer.MAX_VALUE;
+        long bestChannelBytesInQueue = Long.MAX_VALUE;
+        int bestChannel = 0;
+        for (int i = 1; i <= maxTraverseSize; i++) {
+            int candidateChannel = (currentChannel + i) % numberOfSubpartitions;
+            int candidateChannelBuffersCount =
+                    targetPartition.getBuffersCountUnsafe(candidateChannel);
+            long candidateChannelBytesInQueue =
+                    targetPartition.getBytesInQueueUnsafe(candidateChannel);
+
+            if (candidateChannelBuffersCount == 0) {
+                // If there isn't any pending data in the current channel, choose this channel
+                // directly.
+                return candidateChannel;
+            }
+
+            if (candidateChannelBuffersCount < bestChannelBuffersCount
+                    || (candidateChannelBuffersCount == bestChannelBuffersCount
+                            && candidateChannelBytesInQueue < bestChannelBytesInQueue)) {
+                bestChannel = candidateChannel;
+                bestChannelBuffersCount = candidateChannelBuffersCount;
+                bestChannelBytesInQueue = candidateChannelBytesInQueue;
+            }
+        }
+        return bestChannel;
+    }
+
+    /** Copied from {@link ChannelSelectorRecordWriter#broadcastEmit}. */
+    @Override
+    public void broadcastEmit(T record) throws IOException {
+        checkErroneous();
+
+        // Emitting to all channels in a for loop can be better than calling
+        // ResultPartitionWriter#broadcastRecord because the broadcastRecord
+        // method incurs extra overhead.
+        ByteBuffer serializedRecord = serializeRecord(serializer, record);
+        for (int channelIndex = 0; channelIndex < numberOfSubpartitions; channelIndex++) {
+            serializedRecord.rewind();
+            emit(record, channelIndex);
+        }
+
+        if (flushAlways) {
+            flushAll();
+        }
+    }
+}
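
To make the selection policy above concrete, here is a standalone sketch (not part of the commit) of the loop in getIdlestChannelIndex, with made-up per-channel load. Starting from the channel after the one used last, at most maxTraverseSize candidates are inspected; a channel with zero queued buffers wins immediately, otherwise the fewest queued buffers wins, with queued bytes as the tie-breaker.

    public class IdlestChannelSketch {
        public static void main(String[] args) {
            int numberOfSubpartitions = 6;
            int maxTraverseSize = 2;
            int currentChannel = 3; // channel chosen by the previous emit
            int[] buffersPerChannel = {2, 3, 4, 2, 3, 1}; // made-up queue lengths
            long[] bytesPerChannel = {10L, 20L, 30L, 10L, 20L, 5L}; // made-up queued bytes

            int bestChannel = 0;
            int bestBuffers = Integer.MAX_VALUE;
            long bestBytes = Long.MAX_VALUE;
            for (int i = 1; i <= maxTraverseSize; i++) {
                int candidate = (currentChannel + i) % numberOfSubpartitions;
                if (buffersPerChannel[candidate] == 0) {
                    bestChannel = candidate;
                    break;
                }
                if (buffersPerChannel[candidate] < bestBuffers
                        || (buffersPerChannel[candidate] == bestBuffers
                                && bytesPerChannel[candidate] < bestBytes)) {
                    bestChannel = candidate;
                    bestBuffers = buffersPerChannel[candidate];
                    bestBytes = bytesPerChannel[candidate];
                }
            }
            // Only channels 4 and 5 are compared here; channel 5 wins with one queued buffer.
            System.out.println("idlest channel: " + bestChannel);
        }
    }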
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterBuilder.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterBuilder.java
index 78e6424844d..d730a73a7fe 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterBuilder.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterBuilder.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.io.network.api.writer;
 
+import org.apache.flink.configuration.NettyShuffleEnvironmentOptions;
 import org.apache.flink.core.io.IOReadableWritable;
 
 /** Utility class to encapsulate the logic of building a {@link RecordWriter} instance. */
@@ -29,6 +30,11 @@ public class RecordWriterBuilder<T extends IOReadableWritable> {
 
     private String taskName = "test";
 
+    private boolean enabledAdaptivePartitioner = false;
+
+    private int maxTraverseSize =
+            NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE.defaultValue();
+
     public RecordWriterBuilder<T> setChannelSelector(ChannelSelector<T> selector) {
         this.selector = selector;
         return this;
@@ -44,11 +50,24 @@ public class RecordWriterBuilder<T extends IOReadableWritable> {
         return this;
     }
 
+    public RecordWriterBuilder<T> setEnabledAdaptivePartitioner(
+            boolean enabledAdaptivePartitioner) {
+        this.enabledAdaptivePartitioner = enabledAdaptivePartitioner;
+        return this;
+    }
+
+    public RecordWriterBuilder<T> setMaxTraverseSize(int maxTraverseSize) {
+        this.maxTraverseSize = maxTraverseSize;
+        return this;
+    }
+
     public RecordWriter<T> build(ResultPartitionWriter writer) {
         if (selector.isBroadcast()) {
             return new BroadcastRecordWriter<>(writer, timeout, taskName);
-        } else {
-            return new ChannelSelectorRecordWriter<>(writer, selector, timeout, taskName);
         }
+        if (enabledAdaptivePartitioner) {
+            return new AdaptiveLoadBasedRecordWriter<>(writer, timeout, taskName, maxTraverseSize);
+        }
+        return new ChannelSelectorRecordWriter<>(writer, selector, timeout, taskName);
     }
 }
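
For readers wiring this up outside of StreamTask, a hedged sketch of how the new builder flags combine; the partition writer and channel selector are assumed to come from the task environment, as they do in StreamTask further down in this commit, and the timeout and task name are placeholders.

    // Sketch only: types come from this package; the writer/selector instances are assumed.
    static RecordWriter<IOReadableWritable> buildAdaptiveWriter(
            ResultPartitionWriter partitionWriter, ChannelSelector<IOReadableWritable> selector) {
        return new RecordWriterBuilder<IOReadableWritable>()
                .setChannelSelector(selector)        // still consulted for isBroadcast()
                .setTimeout(100)                     // flush timeout in milliseconds
                .setTaskName("example-task")
                .setEnabledAdaptivePartitioner(true) // yields an AdaptiveLoadBasedRecordWriter
                .setMaxTraverseSize(4)               // compare at most 4 downstream channels
                .build(partitionWriter);
    }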
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java
index e283fac596f..04cfa0ad33d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java
@@ -60,6 +60,14 @@ public interface ResultPartitionWriter extends AutoCloseable, AvailabilityProvid
     /** Writes the given serialized record to the target subpartition. */
     void emitRecord(ByteBuffer record, int targetSubpartition) throws IOException;
 
+    default long getBytesInQueueUnsafe(int targetSubpartition) {
+        return 0;
+    }
+
+    default int getBuffersCountUnsafe(int targetSubpartition) {
+        return 0;
+    }
+
     /**
     * Writes the given serialized record to all subpartitions. One can also achieve the same effect
     * by emitting the same record to all subpartitions one by one, however, this method can have
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferPool.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferPool.java
index c574607e28e..5061d08353d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferPool.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/BufferPool.java
@@ -75,4 +75,9 @@ public interface BufferPool extends BufferProvider, BufferRecycler {
 
     /** Returns the number of used buffers of this buffer pool. */
     int bestEffortGetNumOfUsedBuffers();
+
+    /** Returns the requested buffer count for the target channel. */
+    default int getBuffersCountUnsafe(int targetChannel) {
+        return 0;
+    }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/LocalBufferPool.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/LocalBufferPool.java
index 873414c6fe2..f31bd95f1a1 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/LocalBufferPool.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/LocalBufferPool.java
@@ -824,4 +824,9 @@ public class LocalBufferPool implements BufferPool {
             }
         }
     }
+
+    @Override
+    public int getBuffersCountUnsafe(int targetChannel) {
+        return subpartitionBuffersCount[targetChannel];
+    }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BufferWritingResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BufferWritingResultPartition.java
index eae9260642a..334647a367b 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BufferWritingResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BufferWritingResultPartition.java
@@ -65,7 +65,7 @@ public abstract class BufferWritingResultPartition extends ResultPartition {
 
     private TimerGauge hardBackPressuredTimeMsPerSecond = new TimerGauge();
 
-    private long totalWrittenBytes;
+    private final long[] writtenBytesPerSubpartition;
 
     public BufferWritingResultPartition(
             String owningTaskName,
@@ -91,6 +91,7 @@ public abstract class BufferWritingResultPartition extends ResultPartition {
 
         this.subpartitions = checkNotNull(subpartitions);
         this.unicastBufferBuilders = new BufferBuilder[subpartitions.length];
+        this.writtenBytesPerSubpartition = new long[subpartitions.length];
     }
 
     @Override
@@ -114,6 +115,11 @@ public abstract class BufferWritingResultPartition extends ResultPartition {
 
     @Override
     public long getSizeOfQueuedBuffersUnsafe() {
+        long totalWrittenBytes = 0;
+        for (int i = 0; i < subpartitions.length; i++) {
+            totalWrittenBytes += writtenBytesPerSubpartition[i];
+        }
+
         long totalNumberOfBytes = 0;
 
         for (ResultSubpartition subpartition : subpartitions) {
@@ -123,6 +129,12 @@ public abstract class BufferWritingResultPartition extends ResultPartition {
         return totalWrittenBytes - totalNumberOfBytes;
     }
 
+    @Override
+    public long getBytesInQueueUnsafe(int targetSubpartition) {
+        return writtenBytesPerSubpartition[targetSubpartition]
+                - subpartitions[targetSubpartition].getTotalNumberOfBytesUnsafe();
+    }
+
     @Override
     public int getNumberOfQueuedBuffers(int targetSubpartition) {
         checkArgument(targetSubpartition >= 0 && targetSubpartition < numSubpartitions);
@@ -151,7 +163,7 @@ public abstract class BufferWritingResultPartition extends ResultPartition {
 
     @Override
     public void emitRecord(ByteBuffer record, int targetSubpartition) throws IOException {
-        totalWrittenBytes += record.remaining();
+        writtenBytesPerSubpartition[targetSubpartition] += record.remaining();
 
         BufferBuilder buffer = appendUnicastDataForNewRecord(record, targetSubpartition);
 
@@ -171,7 +183,9 @@ public abstract class BufferWritingResultPartition extends ResultPartition {
 
     @Override
     public void broadcastRecord(ByteBuffer record) throws IOException {
-        totalWrittenBytes += ((long) record.remaining() * numSubpartitions);
+        for (int i = 0; i < subpartitions.length; i++) {
+            writtenBytesPerSubpartition[i] += record.remaining();
+        }
 
         BufferBuilder buffer = appendBroadcastDataForNewRecord(record);
 
@@ -197,11 +211,11 @@ public abstract class BufferWritingResultPartition extends ResultPartition {
 
         try (BufferConsumer eventBufferConsumer =
                 EventSerializer.toBufferConsumer(event, isPriorityEvent)) {
-            totalWrittenBytes += ((long) eventBufferConsumer.getWrittenBytes() * numSubpartitions);
-            for (ResultSubpartition subpartition : subpartitions) {
+            for (int i = 0; i < subpartitions.length; i++) {
                 // Retain the buffer so that it can be recycled by each subpartition of
                 // targetPartition
-                subpartition.add(eventBufferConsumer.copy(), 0);
+                subpartitions[i].add(eventBufferConsumer.copy(), 0);
+                writtenBytesPerSubpartition[i] += eventBufferConsumer.getWrittenBytes();
             }
         }
     }
@@ -246,8 +260,8 @@ public abstract class BufferWritingResultPartition extends ResultPartition {
         finishBroadcastBufferBuilder();
         finishUnicastBufferBuilders();
 
-        for (ResultSubpartition subpartition : subpartitions) {
-            totalWrittenBytes += subpartition.finish();
+        for (int i = 0; i < subpartitions.length; i++) {
+            writtenBytesPerSubpartition[i] += subpartitions[i].finish();
         }
 
         super.finish();
@@ -340,7 +354,7 @@ public abstract class BufferWritingResultPartition extends ResultPartition {
     protected int addToSubpartition(
             int targetSubpartition, BufferConsumer bufferConsumer, int partialRecordLength)
             throws IOException {
-        totalWrittenBytes += bufferConsumer.getWrittenBytes();
+        writtenBytesPerSubpartition[targetSubpartition] += bufferConsumer.getWrittenBytes();
         return subpartitions[targetSubpartition].add(bufferConsumer, partialRecordLength);
     }
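
Per the existing getSizeOfQueuedBuffersUnsafe accounting above, the per-subpartition backlog exposed to the adaptive writer is simply the bytes written to a subpartition minus the bytes that subpartition has already reported via getTotalNumberOfBytesUnsafe(). A tiny sketch with made-up numbers (not part of the commit):

    // Hypothetical values illustrating getBytesInQueueUnsafe(i) for two subpartitions.
    long[] writtenBytesPerSubpartition = {1_000L, 250L}; // accumulated on emit/broadcast/finish
    long[] reportedBySubpartition = {400L, 250L};        // getTotalNumberOfBytesUnsafe() per subpartition
    long backlog0 = writtenBytesPerSubpartition[0] - reportedBySubpartition[0]; // 600 bytes still queued
    long backlog1 = writtenBytesPerSubpartition[1] - reportedBySubpartition[1]; // 0, i.e. fully drained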
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java
index 6cbcfc0c598..47b52caa8d6 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java
@@ -202,6 +202,10 @@ public abstract class ResultPartition implements ResultPartitionWriter {
     /** Returns the number of queued buffers of the given target subpartition. */
     public abstract int getNumberOfQueuedBuffers(int targetSubpartition);
 
+    public int getBuffersCountUnsafe(int targetSubpartition) {
+        return bufferPool.getBuffersCountUnsafe(targetSubpartition);
+    }
+
     public void setMaxOverdraftBuffersPerGate(int maxOverdraftBuffersPerGate) {
         this.bufferPool.setMaxOverdraftBuffersPerGate(maxOverdraftBuffersPerGate);
     }
diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index 505e67a30d7..07167659cb0 100644
--- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -23,6 +23,7 @@ import org.apache.flink.api.common.operators.MailboxExecutor;
 import org.apache.flink.api.common.operators.ProcessingTimeService.ProcessingTimeCallback;
 import org.apache.flink.configuration.CheckpointingOptions;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.NettyShuffleEnvironmentOptions;
 import org.apache.flink.configuration.TaskManagerOptions;
 import org.apache.flink.core.execution.RecoveryClaimMode;
 import org.apache.flink.core.fs.AutoCloseableRegistry;
@@ -80,6 +81,7 @@ import org.apache.flink.runtime.taskmanager.AsyncExceptionHandler;
 import org.apache.flink.runtime.taskmanager.AsynchronousException;
 import org.apache.flink.runtime.taskmanager.DispatcherThreadFactory;
 import org.apache.flink.runtime.taskmanager.Task;
+import org.apache.flink.runtime.util.ConfigurationParserUtils;
 import org.apache.flink.streaming.api.graph.NonChainedOutput;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.InternalTimeServiceManager;
@@ -95,6 +97,7 @@ import org.apache.flink.streaming.runtime.io.checkpointing.CheckpointBarrierHand
 import org.apache.flink.streaming.runtime.partitioner.ConfigurableStreamPartitioner;
 import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
 import org.apache.flink.streaming.runtime.partitioner.RebalancePartitioner;
+import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner;
 import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.mailbox.GaugePeriodTimer;
@@ -1830,17 +1833,41 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
                 ((ConfigurableStreamPartitioner) outputPartitioner).configure(numKeyGroups);
             }
         }
+        Configuration conf = environment.getJobConfiguration();
+        final boolean enabledAdaptivePartitioner =
+                (outputPartitioner instanceof RebalancePartitioner
+                                || outputPartitioner instanceof RescalePartitioner)
+                        && conf.get(NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_ENABLED)
+                        && bufferWriter.getNumberOfSubpartitions() > 1;
+        final int maxTraverseSize = getAndCheckMaxTraverseSize(conf);
 
         RecordWriter<SerializationDelegate<StreamRecord<OUT>>> output =
                 new RecordWriterBuilder<SerializationDelegate<StreamRecord<OUT>>>()
                         .setChannelSelector(outputPartitioner)
                         .setTimeout(bufferTimeout)
                         .setTaskName(taskNameWithSubtask)
+                        .setEnabledAdaptivePartitioner(enabledAdaptivePartitioner)
+                        .setMaxTraverseSize(maxTraverseSize)
                         .build(bufferWriter);
         output.setMetricGroup(environment.getMetricGroup().getIOMetricGroup());
         return output;
     }
 
+    @VisibleForTesting
+    static int getAndCheckMaxTraverseSize(Configuration jobConf) {
+        final int maxTraverseSize =
+                jobConf.get(NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE);
+        ConfigurationParserUtils.checkConfigParameter(
+                maxTraverseSize > 1,
+                maxTraverseSize,
+                NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE.key(),
+                String.format(
+                        "The value of '%s' must be greater than 1 when '%s' is enabled.",
+                        NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE.key(),
+                        NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_ENABLED.key()));
+        return maxTraverseSize;
+    }
+
     private void handleTimerException(Exception ex) {
         handleAsyncException("Caught exception while processing timer.", new TimerException(ex));
     }
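
Because the options are read from the job configuration here, a job should presumably also be able to opt in on its own rather than cluster-wide. A hedged sketch (this API usage is assumed and not shown in the commit):

    // Hypothetical per-job opt-in; the options end up in the job configuration that
    // StreamTask#createRecordWriter reads above.
    Configuration jobConf = new Configuration();
    jobConf.set(NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_ENABLED, true);
    jobConf.set(NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE, 4);
    StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(jobConf);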
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/AdaptiveLoadBasedRecordWriterTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/AdaptiveLoadBasedRecordWriterTest.java
new file mode 100644
index 00000000000..b835a31a5da
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/AdaptiveLoadBasedRecordWriterTest.java
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.api.writer;
+
+import org.apache.flink.core.io.IOReadableWritable;
+import org.apache.flink.runtime.checkpoint.CheckpointException;
+import org.apache.flink.runtime.event.AbstractEvent;
+import org.apache.flink.runtime.io.network.api.StopMode;
+import org.apache.flink.runtime.io.network.partition.BufferAvailabilityListener;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet;
+import org.apache.flink.runtime.io.network.partition.ResultSubpartitionView;
+import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup;
+
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.stream.Stream;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Test for {@link AdaptiveLoadBasedRecordWriter}. */
+class AdaptiveLoadBasedRecordWriterTest {
+
+    static Stream<Arguments> getTestingParams() {
+        return Stream.of(
+                // maxTraverseSize, bytesPerPartition, bufferPerPartition,
+                // targetResultPartitionIndex
+                Arguments.of(2, new long[] {1L, 2L, 3L}, new int[] {2, 3, 4}, 0),
+                Arguments.of(2, new long[] {0L, 0L, 0L}, new int[] {2, 3, 4}, 0),
+                Arguments.of(2, new long[] {0L, 0L, 0L}, new int[] {0, 0, 0}, 0),
+                Arguments.of(3, new long[] {1L, 2L, 3L}, new int[] {2, 3, 4}, 0),
+                Arguments.of(3, new long[] {0L, 0L, 0L}, new int[] {2, 3, 4}, 0),
+                Arguments.of(3, new long[] {0L, 0L, 0L}, new int[] {0, 0, 0}, 0),
+                Arguments.of(
+                        2, new long[] {1L, 2L, 3L, 1L, 2L, 3L}, new int[] {2, 3, 4, 2, 3, 4}, 0),
+                Arguments.of(
+                        2, new long[] {0L, 0L, 3L, 1L, 2L, 3L}, new int[] {3, 2, 4, 2, 3, 4}, 1),
+                Arguments.of(
+                        2, new long[] {0L, 0L, 3L, 1L, 2L, 3L}, new int[] {0, 0, 4, 2, 3, 4}, 0),
+                Arguments.of(
+                        4, new long[] {1L, 2L, 3L, 0L, 2L, 3L}, new int[] {2, 3, 4, 2, 3, 4}, 3),
+                Arguments.of(
+                        4, new long[] {1L, 1L, 1L, 1L, 2L, 3L}, new int[] {2, 3, 4, 0, 3, 4}, 3),
+                Arguments.of(
+                        4, new long[] {0L, 0L, 0L, 0L, 2L, 3L}, new int[] {2, 3, 0, 2, 3, 4}, 2));
+    }
+
+    @ParameterizedTest(
+            name =
+                    "maxTraverseSize: {0}, bytesPerPartition: {1}, bufferPerPartition: {2}, targetResultPartitionIndex: {3}")
+    @MethodSource("getTestingParams")
+    void testGetIdlestChannelIndex(
+            int maxTraverseSize,
+            long[] bytesPerPartition,
+            int[] buffersPerPartition,
+            int targetResultPartitionIndex) {
+        TestingResultPartitionWriter resultPartitionWriter =
+                getTestingResultPartitionWriter(bytesPerPartition, buffersPerPartition);
+
+        AdaptiveLoadBasedRecordWriter<IOReadableWritable> adaptiveLoadBasedRecordWriter =
+                new AdaptiveLoadBasedRecordWriter<>(
+                        resultPartitionWriter, 5L, "testingTask", maxTraverseSize);
+        assertThat(adaptiveLoadBasedRecordWriter.getIdlestChannelIndex())
+                .isEqualTo(targetResultPartitionIndex);
+    }
+
+    private static TestingResultPartitionWriter getTestingResultPartitionWriter(
+            long[] bytesPerPartition, int[] buffersPerPartition) {
+        final Map<Integer, Long> bytesPerPartitionMap = new HashMap<>();
+        final Map<Integer, Integer> bufferPerPartitionMap = new HashMap<>();
+        for (int i = 0; i < bytesPerPartition.length; i++) {
+            bytesPerPartitionMap.put(i, bytesPerPartition[i]);
+            bufferPerPartitionMap.put(i, buffersPerPartition[i]);
+        }
+
+        return new TestingResultPartitionWriter(
+                buffersPerPartition.length, bytesPerPartitionMap, bufferPerPartitionMap);
+    }
+
+    /** Test utility class to simulate a {@link ResultPartitionWriter}. */
+    static final class TestingResultPartitionWriter implements ResultPartitionWriter {
+
+        private final int numberOfSubpartitions;
+        private final Map<Integer, Long> bytesPerPartition;
+        private final Map<Integer, Integer> bufferPerPartition;
+
+        TestingResultPartitionWriter(
+                int numberOfSubpartitions,
+                Map<Integer, Long> bytesPerPartition,
+                Map<Integer, Integer> bufferPerPartition) {
+            this.numberOfSubpartitions = numberOfSubpartitions;
+            this.bytesPerPartition = bytesPerPartition;
+            this.bufferPerPartition = bufferPerPartition;
+        }
+
+        // The methods that are used in the testing.
+
+        @Override
+        public long getBytesInQueueUnsafe(int targetSubpartition) {
+            return bytesPerPartition.getOrDefault(targetSubpartition, 0L);
+        }
+
+        @Override
+        public int getBuffersCountUnsafe(int targetSubpartition) {
+            return bufferPerPartition.getOrDefault(targetSubpartition, 0);
+        }
+
+        @Override
+        public int getNumberOfSubpartitions() {
+            return numberOfSubpartitions;
+        }
+
+        // The methods that are not used.
+
+        @Override
+        public void setup() throws IOException {}
+
+        @Override
+        public ResultPartitionID getPartitionId() {
+            return null;
+        }
+
+        @Override
+        public int getNumTargetKeyGroups() {
+            return 0;
+        }
+
+        @Override
+        public void setMaxOverdraftBuffersPerGate(int maxOverdraftBuffersPerGate) {}
+
+        @Override
+        public void emitRecord(ByteBuffer record, int targetSubpartition) throws IOException {}
+
+        @Override
+        public void broadcastRecord(ByteBuffer record) throws IOException {}
+
+        @Override
+        public void broadcastEvent(AbstractEvent event, boolean isPriorityEvent)
+                throws IOException {}
+
+        @Override
+        public void alignedBarrierTimeout(long checkpointId) throws IOException {}
+
+        @Override
+        public void abortCheckpoint(long checkpointId, CheckpointException cause) {}
+
+        @Override
+        public void notifyEndOfData(StopMode mode) throws IOException {}
+
+        @Override
+        public CompletableFuture<Void> getAllDataProcessedFuture() {
+            return null;
+        }
+
+        @Override
+        public void setMetricGroup(TaskIOMetricGroup metrics) {}
+
+        @Override
+        public ResultSubpartitionView createSubpartitionView(
+                ResultSubpartitionIndexSet indexSet,
+                BufferAvailabilityListener availabilityListener)
+                throws IOException {
+            return null;
+        }
+
+        @Override
+        public void flushAll() {}
+
+        @Override
+        public void flush(int subpartitionIndex) {}
+
+        @Override
+        public void fail(@Nullable Throwable throwable) {}
+
+        @Override
+        public void finish() throws IOException {}
+
+        @Override
+        public boolean isFinished() {
+            return false;
+        }
+
+        @Override
+        public void release(Throwable cause) {}
+
+        @Override
+        public boolean isReleased() {
+            return false;
+        }
+
+        @Override
+        public void close() throws Exception {}
+
+        @Override
+        public CompletableFuture<?> getAvailableFuture() {
+            return null;
+        }
+    }
+}
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
index de7c3d6f5a8..1b42449f539 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
@@ -28,6 +28,8 @@ import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.configuration.CheckpointingOptions;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.IllegalConfigurationException;
+import org.apache.flink.configuration.NettyShuffleEnvironmentOptions;
 import org.apache.flink.configuration.ReadableConfig;
 import org.apache.flink.core.execution.SavepointFormatType;
 import org.apache.flink.core.fs.FSDataInputStream;
@@ -1885,6 +1887,27 @@ public class StreamTaskTest {
         }
     }
 
+    @Test
+    void testGetAndCheckMaxTraverseSize() {
+        Configuration config = new Configuration();
+        assertThat(StreamTask.getAndCheckMaxTraverseSize(config)).isEqualTo(4);
+
+        config.set(NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE, 2);
+        assertThat(StreamTask.getAndCheckMaxTraverseSize(config)).isEqualTo(2);
+
+        config.set(NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE, -1);
+        assertThatThrownBy(() -> StreamTask.getAndCheckMaxTraverseSize(config))
+                .isInstanceOf(IllegalConfigurationException.class);
+
+        config.set(NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE, 0);
+        assertThatThrownBy(() -> StreamTask.getAndCheckMaxTraverseSize(config))
+                .isInstanceOf(IllegalConfigurationException.class);
+
+        config.set(NettyShuffleEnvironmentOptions.ADAPTIVE_PARTITIONER_MAX_TRAVERSE_SIZE, 1);
+        assertThatThrownBy(() -> StreamTask.getAndCheckMaxTraverseSize(config))
+                .isInstanceOf(IllegalConfigurationException.class);
+    }
+
     private int getCurrentBufferSize(InputGate inputGate) {
         return getTestChannel(inputGate, 0).getCurrentBufferSize();
     }

