This is an automated email from the ASF dual-hosted git repository.

guoweijie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit daf476f76bd3ff58e34a01848a76ee1c985b2fc3
Author: Yuxin Tan <tanyuxinw...@gmail.com>
AuthorDate: Thu May 25 15:12:20 2023 +0800

    [FLINK-31640][network] Write the accumulated buffers to the right storage tier
---
 .../tiered/shuffle/TieredResultPartition.java      |   9 +
 .../storage/TieredStorageProducerClient.java       | 113 ++++++++-
 .../TieredStorageProducerMetricUpdate.java}        |  33 ++-
 .../hybrid/tiered/tier/TierProducerAgent.java      |  31 ++-
 .../hybrid/tiered/TestingBufferAccumulator.java    |  23 +-
 .../hybrid/tiered/TestingTierProducerAgent.java    |  99 ++++++++
 ...ccumulator.java => TieredStorageTestUtils.java} |  27 +--
 .../tiered/shuffle/TieredResultPartitionTest.java  |   6 +-
 .../tiered/storage/HashBufferAccumulatorTest.java  |   7 +-
 .../storage/TieredStorageProducerClientTest.java   | 253 +++++++++++++++++++++
 10 files changed, 540 insertions(+), 61 deletions(-)

diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/TieredResultPartition.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/TieredResultPartition.java
index cd24ac89577..5ddfb2cd90c 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/TieredResultPartition.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/TieredResultPartition.java
@@ -36,6 +36,7 @@ import 
org.apache.flink.runtime.io.network.partition.ResultSubpartitionView;
 import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageIdMappingUtils;
 import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStoragePartitionId;
 import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.TieredStorageProducerClient;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.TieredStorageProducerMetricUpdate;
 import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.TieredStorageResourceRegistry;
 import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup;
 import org.apache.flink.util.concurrent.FutureUtils;
@@ -102,6 +103,8 @@ public class TieredResultPartition extends ResultPartition {
     @Override
     public void setMetricGroup(TaskIOMetricGroup metrics) {
         super.setMetricGroup(metrics);
+        tieredStorageProducerClient.setMetricStatisticsUpdater(
+                this::updateProducerMetricStatistics);
     }
 
     @Override
@@ -139,6 +142,12 @@ public class TieredResultPartition extends ResultPartition 
{
                 record, TieredStorageIdMappingUtils.convertId(consumerId), 
dataType, isBroadcast);
     }
 
+    private void updateProducerMetricStatistics(
+            TieredStorageProducerMetricUpdate metricStatistics) {
+        numBuffersOut.inc(metricStatistics.numWriteBuffersDelta());
+        numBytesOut.inc(metricStatistics.numWriteBytesDelta());
+    }
+
     @Override
     public ResultSubpartitionView createSubpartitionView(
             int subpartitionId, BufferAvailabilityListener 
availabilityListener)
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageProducerClient.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageProducerClient.java
index 299c71a7bfc..d6b72b25c39 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageProducerClient.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageProducerClient.java
@@ -28,10 +28,17 @@ import javax.annotation.Nullable;
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.Iterator;
 import java.util.List;
+import java.util.function.Consumer;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
 
 /** Client of the Tiered Storage used by the producer. */
 public class TieredStorageProducerClient {
+
     private final boolean isBroadcastOnly;
 
     private final int numSubpartitions;
@@ -40,8 +47,24 @@ public class TieredStorageProducerClient {
 
     private final BufferCompressor bufferCompressor;
 
+    /**
+     * Note that the {@link TierProducerAgent}s are sorted by priority, with a 
lower index
+     * indicating a higher priority.
+     */
     private final List<TierProducerAgent> tierProducerAgents;
 
+    /** The current writing segment index for each subpartition. */
+    private final int[] currentSubpartitionSegmentId;
+
+    /** The current writing storage tier for each subpartition. */
+    private final TierProducerAgent[] currentSubpartitionTierAgent;
+
+    /**
+     * The updater that consumes the producer metric statistics. It stays null until it is set and
+     * must be set before any buffers are written, otherwise updating the metrics will fail.
+     */
+    @Nullable private Consumer<TieredStorageProducerMetricUpdate> 
metricStatisticsUpdater;
+
     public TieredStorageProducerClient(
             int numSubpartitions,
             boolean isBroadcastOnly,
@@ -53,6 +76,10 @@ public class TieredStorageProducerClient {
         this.bufferAccumulator = bufferAccumulator;
         this.bufferCompressor = bufferCompressor;
         this.tierProducerAgents = tierProducerAgents;
+        this.currentSubpartitionSegmentId = new int[numSubpartitions];
+        this.currentSubpartitionTierAgent = new 
TierProducerAgent[numSubpartitions];
+
+        Arrays.fill(currentSubpartitionSegmentId, -1);
 
         bufferAccumulator.setup(this::writeAccumulatedBuffers);
     }
@@ -92,6 +119,11 @@ public class TieredStorageProducerClient {
         }
     }
 
+    public void setMetricStatisticsUpdater(
+            Consumer<TieredStorageProducerMetricUpdate> 
metricStatisticsUpdater) {
+        this.metricStatisticsUpdater = checkNotNull(metricStatisticsUpdater);
+    }
+
     public void close() {
         bufferAccumulator.close();
         tierProducerAgents.forEach(TierProducerAgent::close);
@@ -105,26 +137,93 @@ public class TieredStorageProducerClient {
      */
     private void writeAccumulatedBuffers(
             TieredStorageSubpartitionId subpartitionId, List<Buffer> 
accumulatedBuffers) {
-        try {
-            for (Buffer finishedBuffer : accumulatedBuffers) {
-                writeAccumulatedBuffer(subpartitionId, finishedBuffer);
+        Iterator<Buffer> bufferIterator = accumulatedBuffers.iterator();
+
+        int numWriteBytes = 0;
+        int numWriteBuffers = 0;
+        while (bufferIterator.hasNext()) {
+            Buffer buffer = bufferIterator.next();
+            numWriteBuffers++;
+            numWriteBytes += buffer.readableBytes();
+            try {
+                writeAccumulatedBuffer(subpartitionId, buffer);
+            } catch (IOException ioe) {
+                buffer.recycleBuffer();
+                while (bufferIterator.hasNext()) {
+                    bufferIterator.next().recycleBuffer();
+                }
+                ExceptionUtils.rethrow(ioe);
             }
-        } catch (IOException e) {
-            ExceptionUtils.rethrow(e);
         }
+        updateMetricStatistics(numWriteBuffers, numWriteBytes);
     }
 
     /**
      * Write the accumulated buffer of this subpartitionId to an appropriate 
tier. After the tier is
      * decided, the buffer will be written to the selected tier.
      *
+     * <p>Note that the method throws an {@link IOException} only if no storage tier can start a
+     * new segment, and the caller is responsible for recycling the buffer in that case.
+     *
      * @param subpartitionId the subpartition identifier
      * @param accumulatedBuffer one accumulated buffer of this subpartition
      */
     private void writeAccumulatedBuffer(
             TieredStorageSubpartitionId subpartitionId, Buffer 
accumulatedBuffer)
             throws IOException {
-        // TODO, Try to write the accumulated buffer to the appropriate tier. 
After the tier is
-        // decided, then write the accumulated buffer to the tier.
+        Buffer compressedBuffer = compressBufferIfPossible(accumulatedBuffer);
+
+        if (currentSubpartitionTierAgent[subpartitionId.getSubpartitionId()] 
== null) {
+            chooseStorageTierToStartSegment(subpartitionId);
+        }
+
+        if 
(!currentSubpartitionTierAgent[subpartitionId.getSubpartitionId()].tryWrite(
+                subpartitionId, compressedBuffer)) {
+            chooseStorageTierToStartSegment(subpartitionId);
+            checkState(
+                    
currentSubpartitionTierAgent[subpartitionId.getSubpartitionId()].tryWrite(
+                            subpartitionId, compressedBuffer),
+                    "Failed to write the first buffer to the new segment");
+        }
+    }
+
+    private void chooseStorageTierToStartSegment(TieredStorageSubpartitionId 
subpartitionId)
+            throws IOException {
+        int subpartitionIndex = subpartitionId.getSubpartitionId();
+        int segmentIndex = currentSubpartitionSegmentId[subpartitionIndex];
+        int nextSegmentIndex = segmentIndex + 1;
+
+        for (TierProducerAgent tierProducerAgent : tierProducerAgents) {
+            if (tierProducerAgent.tryStartNewSegment(subpartitionId, 
nextSegmentIndex)) {
+                // Update the segment index and the chosen storage tier for 
the subpartition.
+                currentSubpartitionSegmentId[subpartitionIndex] = 
nextSegmentIndex;
+                currentSubpartitionTierAgent[subpartitionIndex] = 
tierProducerAgent;
+                return;
+            }
+        }
+        throw new IOException("Failed to choose a storage tier to start a new 
segment.");
+    }
+
+    private Buffer compressBufferIfPossible(Buffer buffer) {
+        if (!canBeCompressed(buffer)) {
+            return buffer;
+        }
+
+        return checkNotNull(bufferCompressor).compressToOriginalBuffer(buffer);
+    }
+
+    /**
+     * Whether the buffer can be compressed or not. Note that events are not compressed because
+     * they are usually small and their size can become even larger after compression.
+     */
+    private boolean canBeCompressed(Buffer buffer) {
+        return bufferCompressor != null && buffer.isBuffer() && 
buffer.readableBytes() > 0;
+    }
+
+    private void updateMetricStatistics(int numWriteBuffersDelta, int 
numWriteBytesDelta) {
+        checkNotNull(metricStatisticsUpdater)
+                .accept(
+                        new TieredStorageProducerMetricUpdate(
+                                numWriteBuffersDelta, numWriteBytesDelta));
     }
 }
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/TestingBufferAccumulator.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageProducerMetricUpdate.java
similarity index 51%
copy from 
flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/TestingBufferAccumulator.java
copy to 
flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageProducerMetricUpdate.java
index adc9481f4b3..d2351504e84 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/TestingBufferAccumulator.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageProducerMetricUpdate.java
@@ -16,28 +16,25 @@
  * limitations under the License.
  */
 
-package org.apache.flink.runtime.io.network.partition.hybrid.tiered;
+package org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage;
 
-import org.apache.flink.runtime.io.network.buffer.Buffer;
-import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageSubpartitionId;
-import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.BufferAccumulator;
+/** The metric statistics for the tiered storage producer. */
+public class TieredStorageProducerMetricUpdate {
 
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.List;
-import java.util.function.BiConsumer;
+    private final int numWriteBuffersDelta;
 
-/** Test implementation for {@link BufferAccumulator}. */
-public class TestingBufferAccumulator implements BufferAccumulator {
+    private final int numWriteBytesDelta;
 
-    @Override
-    public void setup(BiConsumer<TieredStorageSubpartitionId, List<Buffer>> 
bufferFlusher) {}
+    TieredStorageProducerMetricUpdate(int numWriteBuffersDelta, int 
numWriteBytesDelta) {
+        this.numWriteBuffersDelta = numWriteBuffersDelta;
+        this.numWriteBytesDelta = numWriteBytesDelta;
+    }
 
-    @Override
-    public void receive(
-            ByteBuffer record, TieredStorageSubpartitionId subpartitionId, 
Buffer.DataType dataType)
-            throws IOException {}
+    public int numWriteBuffersDelta() {
+        return numWriteBuffersDelta;
+    }
 
-    @Override
-    public void close() {}
+    public int numWriteBytesDelta() {
+        return numWriteBytesDelta;
+    }
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/TierProducerAgent.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/TierProducerAgent.java
index 9c39e8e05ef..5c12f301af6 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/TierProducerAgent.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/TierProducerAgent.java
@@ -21,10 +21,16 @@ package 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageSubpartitionId;
 
-import java.io.IOException;
-
-/** The producer-side agent of a Tier. */
-public interface TierProducerAgent {
+/**
+ * The producer-side agent of a Tier.
+ *
+ * <p>Note that when writing buffers to a tier, the caller should first call {@code
+ * tryStartNewSegment} to start a new segment. The caller can then keep writing buffers to the
+ * tier as long as {@code tryWrite} returns true. If {@code tryWrite} returns false, the current
+ * segment can no longer store the buffer, and the caller should try to start a new segment before
+ * writing the buffer again.
+ */
+public interface TierProducerAgent extends AutoCloseable {
 
     /**
      * Try to start a new segment in the Tier.
@@ -35,9 +41,20 @@ public interface TierProducerAgent {
      */
     boolean tryStartNewSegment(TieredStorageSubpartitionId subpartitionId, int 
segmentId);
 
-    /** Writes the finished {@link Buffer} to the consumer. */
-    boolean write(TieredStorageSubpartitionId subpartitionId, Buffer 
finishedBuffer)
-            throws IOException;
+    /**
+     * Writes the finished {@link Buffer} to the consumer.
+     *
+     * <p>Note that if the method is successfully executed (without throwing any exception), the
+     * buffer should be released by the caller; otherwise, the tier is responsible for recycling
+     * the buffer.
+     *
+     * @param subpartitionId the subpartition id that the buffer is written to
+     * @param finishedBuffer the buffer to write
+     * @return true if the buffer is written successfully, or false if the current segment cannot
+     *     store this buffer and the current segment is finished. When false is returned, the
+     *     caller should try to start a new segment before writing the buffer again.
+     */
+    boolean tryWrite(TieredStorageSubpartitionId subpartitionId, Buffer 
finishedBuffer);
 
     /**
      * Close the agent.
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/TestingBufferAccumulator.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/TestingBufferAccumulator.java
index adc9481f4b3..d9971101322 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/TestingBufferAccumulator.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/TestingBufferAccumulator.java
@@ -18,25 +18,44 @@
 
 package org.apache.flink.runtime.io.network.partition.hybrid.tiered;
 
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
+import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
 import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageSubpartitionId;
 import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.BufferAccumulator;
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
+import java.util.Collections;
 import java.util.List;
 import java.util.function.BiConsumer;
 
 /** Test implementation for {@link BufferAccumulator}. */
 public class TestingBufferAccumulator implements BufferAccumulator {
 
+    private BiConsumer<TieredStorageSubpartitionId, List<Buffer>> 
bufferFlusher;
+
     @Override
-    public void setup(BiConsumer<TieredStorageSubpartitionId, List<Buffer>> 
bufferFlusher) {}
+    public void setup(BiConsumer<TieredStorageSubpartitionId, List<Buffer>> 
bufferFlusher) {
+        this.bufferFlusher = bufferFlusher;
+    }
 
     @Override
     public void receive(
             ByteBuffer record, TieredStorageSubpartitionId subpartitionId, 
Buffer.DataType dataType)
-            throws IOException {}
+            throws IOException {
+        MemorySegment recordData = MemorySegmentFactory.wrap(record.array());
+        bufferFlusher.accept(
+                subpartitionId,
+                Collections.singletonList(
+                        new NetworkBuffer(
+                                recordData,
+                                FreeingBufferRecycler.INSTANCE,
+                                dataType,
+                                recordData.size())));
+    }
 
     @Override
     public void close() {}
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/TestingTierProducerAgent.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/TestingTierProducerAgent.java
new file mode 100644
index 00000000000..0dc623d8334
--- /dev/null
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/TestingTierProducerAgent.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition.hybrid.tiered;
+
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageSubpartitionId;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierProducerAgent;
+
+import java.util.function.BiFunction;
+
+/** Test implementation for {@link TierProducerAgent}. */
+public class TestingTierProducerAgent implements TierProducerAgent {
+
+    private final BiFunction<TieredStorageSubpartitionId, Integer, Boolean>
+            tryStartNewSegmentSupplier;
+
+    private final BiFunction<TieredStorageSubpartitionId, Buffer, Boolean> 
tryWriterFunction;
+
+    private final Runnable closeRunnable;
+
+    private TestingTierProducerAgent(
+            BiFunction<TieredStorageSubpartitionId, Integer, Boolean> 
tryStartNewSegmentSupplier,
+            BiFunction<TieredStorageSubpartitionId, Buffer, Boolean> 
tryWriterFunction,
+            Runnable closeRunnable) {
+        this.tryStartNewSegmentSupplier = tryStartNewSegmentSupplier;
+        this.tryWriterFunction = tryWriterFunction;
+        this.closeRunnable = closeRunnable;
+    }
+
+    public static TestingTierProducerAgent.Builder builder() {
+        return new TestingTierProducerAgent.Builder();
+    }
+
+    @Override
+    public boolean tryStartNewSegment(TieredStorageSubpartitionId 
subpartitionId, int segmentId) {
+        return tryStartNewSegmentSupplier.apply(subpartitionId, segmentId);
+    }
+
+    @Override
+    public boolean tryWrite(TieredStorageSubpartitionId subpartitionId, Buffer 
finishedBuffer) {
+        return tryWriterFunction.apply(subpartitionId, finishedBuffer);
+    }
+
+    @Override
+    public void close() {
+        closeRunnable.run();
+    }
+
+    /** Builder for {@link TestingTierProducerAgent}. */
+    public static class Builder {
+        private BiFunction<TieredStorageSubpartitionId, Integer, Boolean> 
tryStartSegmentSupplier =
+                (subpartitionId, integer) -> true;
+
+        private BiFunction<TieredStorageSubpartitionId, Buffer, Boolean> 
tryWriterFunction =
+                (subpartitionId, buffer) -> true;
+
+        private Runnable closeRunnable = () -> {};
+
+        public Builder() {}
+
+        public TestingTierProducerAgent.Builder setTryStartSegmentSupplier(
+                BiFunction<TieredStorageSubpartitionId, Integer, Boolean> 
tryStartSegmentSupplier) {
+            this.tryStartSegmentSupplier = tryStartSegmentSupplier;
+            return this;
+        }
+
+        public TestingTierProducerAgent.Builder setTryWriterFunction(
+                BiFunction<TieredStorageSubpartitionId, Buffer, Boolean> 
tryWriterFunction) {
+            this.tryWriterFunction = tryWriterFunction;
+            return this;
+        }
+
+        public TestingTierProducerAgent.Builder setCloseRunnable(Runnable 
closeRunnable) {
+            this.closeRunnable = closeRunnable;
+            return this;
+        }
+
+        public TestingTierProducerAgent build() {
+            return new TestingTierProducerAgent(
+                    tryStartSegmentSupplier, tryWriterFunction, closeRunnable);
+        }
+    }
+}
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/TestingBufferAccumulator.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/TieredStorageTestUtils.java
similarity index 53%
copy from 
flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/TestingBufferAccumulator.java
copy to 
flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/TieredStorageTestUtils.java
index adc9481f4b3..ab910d10771 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/TestingBufferAccumulator.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/TieredStorageTestUtils.java
@@ -18,26 +18,15 @@
 
 package org.apache.flink.runtime.io.network.partition.hybrid.tiered;
 
-import org.apache.flink.runtime.io.network.buffer.Buffer;
-import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageSubpartitionId;
-import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.BufferAccumulator;
-
-import java.io.IOException;
 import java.nio.ByteBuffer;
-import java.util.List;
-import java.util.function.BiConsumer;
-
-/** Test implementation for {@link BufferAccumulator}. */
-public class TestingBufferAccumulator implements BufferAccumulator {
-
-    @Override
-    public void setup(BiConsumer<TieredStorageSubpartitionId, List<Buffer>> 
bufferFlusher) {}
+import java.util.Random;
 
-    @Override
-    public void receive(
-            ByteBuffer record, TieredStorageSubpartitionId subpartitionId, 
Buffer.DataType dataType)
-            throws IOException {}
+/** Test utils for the tiered storage. */
+public class TieredStorageTestUtils {
 
-    @Override
-    public void close() {}
+    public static ByteBuffer generateRandomData(int dataSize, Random random) {
+        byte[] dataWritten = new byte[dataSize];
+        random.nextBytes(dataWritten);
+        return ByteBuffer.wrap(dataWritten);
+    }
 }
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/TieredResultPartitionTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/TieredResultPartitionTest.java
index 7e69f015366..5faa1210ac1 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/TieredResultPartitionTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/shuffle/TieredResultPartitionTest.java
@@ -31,6 +31,7 @@ import 
org.apache.flink.runtime.io.network.partition.ResultPartitionID;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
 import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.TestingBufferAccumulator;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.TestingTierProducerAgent;
 import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.TieredStorageProducerClient;
 import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.TieredStorageResourceRegistry;
 import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup;
@@ -45,7 +46,7 @@ import org.junit.jupiter.api.io.TempDir;
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
-import java.util.ArrayList;
+import java.util.Collections;
 import java.util.concurrent.RejectedExecutionException;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.ScheduledThreadPoolExecutor;
@@ -189,6 +190,7 @@ class TieredResultPartitionTest {
     private TieredResultPartition createTieredStoreResultPartition(
             int numSubpartitions, BufferPool bufferPool, boolean 
isBroadcastOnly)
             throws IOException {
+        TestingTierProducerAgent tierProducerAgent = new 
TestingTierProducerAgent.Builder().build();
         TieredResultPartition tieredResultPartition =
                 new TieredResultPartition(
                         "TieredStoreResultPartitionTest",
@@ -205,7 +207,7 @@ class TieredResultPartitionTest {
                                 isBroadcastOnly,
                                 new TestingBufferAccumulator(),
                                 null,
-                                new ArrayList<>()),
+                                Collections.singletonList(tierProducerAgent)),
                         new TieredStorageResourceRegistry());
         taskIOMetricGroup =
                 
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup();
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/HashBufferAccumulatorTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/HashBufferAccumulatorTest.java
index 4cc8c3e3d16..8e69c3d24ca 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/HashBufferAccumulatorTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/HashBufferAccumulatorTest.java
@@ -35,6 +35,7 @@ import java.util.Collections;
 import java.util.Random;
 import java.util.concurrent.atomic.AtomicInteger;
 
+import static 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.TieredStorageTestUtils.generateRandomData;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
@@ -161,10 +162,4 @@ class HashBufferAccumulatorTest {
                 bufferPool, Collections.singletonList(new 
TieredStorageMemorySpec(this, 1)));
         return storageMemoryManager;
     }
-
-    private static ByteBuffer generateRandomData(int dataSize, Random random) {
-        byte[] dataWritten = new byte[dataSize];
-        random.nextBytes(dataWritten);
-        return ByteBuffer.wrap(dataWritten);
-    }
 }
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageProducerClientTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageProducerClientTest.java
new file mode 100644
index 00000000000..476bbbd58cc
--- /dev/null
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageProducerClientTest.java
@@ -0,0 +1,253 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage;
+
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.TestingBufferAccumulator;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.TestingTierProducerAgent;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageSubpartitionId;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierProducerAgent;
+import org.apache.flink.testutils.junit.extensions.parameterized.Parameter;
+import 
org.apache.flink.testutils.junit.extensions.parameterized.ParameterizedTestExtension;
+import org.apache.flink.testutils.junit.extensions.parameterized.Parameters;
+
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.TestTemplate;
+import org.junit.jupiter.api.extension.ExtendWith;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static 
org.apache.flink.runtime.io.network.partition.hybrid.tiered.TieredStorageTestUtils.generateRandomData;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
+
+/** Tests for {@link TieredStorageProducerClient}. */
+@ExtendWith(ParameterizedTestExtension.class)
+public class TieredStorageProducerClientTest {
+
+    private static final int NUM_TOTAL_BUFFERS = 1000;
+
+    private static final int NETWORK_BUFFER_SIZE = 1024;
+
+    @Parameter public boolean isBroadcast;
+
+    private NetworkBufferPool globalPool;
+
+    @Parameters(name = "isBroadcast={0}")
+    public static Collection<Boolean> parameters() {
+        return Arrays.asList(false, true);
+    }
+
+    @BeforeEach
+    void before() {
+        globalPool = new NetworkBufferPool(NUM_TOTAL_BUFFERS, 
NETWORK_BUFFER_SIZE);
+    }
+
+    @AfterEach
+    void after() {
+        globalPool.destroy();
+    }
+
+    @TestTemplate
+    void testWriteRecordsToEmptyStorageTiers() {
+        int numSubpartitions = 10;
+        int bufferSize = 1024;
+        Random random = new Random();
+
+        TieredStorageProducerClient tieredStorageProducerClient =
+                createTieredStorageProducerClient(numSubpartitions, 
Collections.emptyList());
+        assertThatThrownBy(
+                        () ->
+                                tieredStorageProducerClient.write(
+                                        generateRandomData(bufferSize, random),
+                                        new TieredStorageSubpartitionId(0),
+                                        Buffer.DataType.DATA_BUFFER,
+                                        isBroadcast))
+                .isInstanceOf(RuntimeException.class)
+                .hasMessageContaining("Failed to choose a storage tier");
+    }
+
+    @TestTemplate
+    void testWriteRecords() throws IOException {
+        int numSubpartitions = 10;
+        int numToWriteRecords = 20;
+        int bufferSize = 1024;
+        Random random = new Random();
+
+        AtomicInteger numReceivedBuffers = new AtomicInteger(0);
+        AtomicInteger numReceivedBytes = new AtomicInteger(0);
+        AtomicInteger numReceivedBuffersInTier1 = new AtomicInteger(0);
+        AtomicInteger numReceivedBuffersInTier2 = new AtomicInteger(0);
+
+        TestingTierProducerAgent tierProducerAgent1 =
+                new TestingTierProducerAgent.Builder()
+                        .setTryStartSegmentSupplier(
+                                ((subpartitionId, integer) -> 
numReceivedBuffersInTier1.get() < 1))
+                        .setTryWriterFunction(
+                                ((subpartitionId, buffer) -> {
+                                    boolean isSuccess = 
numReceivedBuffersInTier1.get() % 2 == 0;
+                                    if (isSuccess) {
+                                        numReceivedBuffers.incrementAndGet();
+                                        
numReceivedBuffersInTier1.incrementAndGet();
+                                        numReceivedBytes.set(
+                                                numReceivedBytes.get() + 
buffer.readableBytes());
+                                    }
+                                    return isSuccess;
+                                }))
+                        .build();
+        TestingTierProducerAgent tierProducerAgent2 =
+                new TestingTierProducerAgent.Builder()
+                        .setTryWriterFunction(
+                                ((subpartitionId, buffer) -> {
+                                    numReceivedBuffers.incrementAndGet();
+                                    
numReceivedBuffersInTier2.incrementAndGet();
+                                    numReceivedBytes.set(
+                                            numReceivedBytes.get() + 
buffer.readableBytes());
+                                    return true;
+                                }))
+                        .build();
+        List<TierProducerAgent> tierProducerAgents = new ArrayList<>();
+        tierProducerAgents.add(tierProducerAgent1);
+        tierProducerAgents.add(tierProducerAgent2);
+
+        TieredStorageProducerClient tieredStorageProducerClient =
+                createTieredStorageProducerClient(numSubpartitions, 
tierProducerAgents);
+        TieredStorageSubpartitionId subpartitionId = new 
TieredStorageSubpartitionId(0);
+
+        for (int i = 0; i < numToWriteRecords; i++) {
+            tieredStorageProducerClient.write(
+                    generateRandomData(bufferSize, random),
+                    subpartitionId,
+                    Buffer.DataType.DATA_BUFFER,
+                    isBroadcast);
+        }
+
+        int numExpectedBytes =
+                isBroadcast
+                        ? numSubpartitions * numToWriteRecords * bufferSize
+                        : numToWriteRecords * bufferSize;
+        assertThat(numReceivedBuffersInTier1.get()).isEqualTo(1);
+        assertThat(numReceivedBuffers.get())
+                .isEqualTo(numReceivedBuffersInTier1.get() + 
numReceivedBuffersInTier2.get());
+        assertThat(numReceivedBytes.get()).isEqualTo(numExpectedBytes);
+    }
+
+    @TestTemplate
+    void testTierCanNotStartNewSegment() {
+        int numSubpartitions = 10;
+        int bufferSize = 1024;
+        Random random = new Random();
+
+        TestingTierProducerAgent tierProducerAgent =
+                new TestingTierProducerAgent.Builder()
+                        .setTryStartSegmentSupplier(((subpartitionId, integer) 
-> false))
+                        .build();
+        TieredStorageProducerClient tieredStorageProducerClient =
+                createTieredStorageProducerClient(
+                        numSubpartitions, 
Collections.singletonList(tierProducerAgent));
+
+        assertThatThrownBy(
+                        () ->
+                                tieredStorageProducerClient.write(
+                                        generateRandomData(bufferSize, random),
+                                        new TieredStorageSubpartitionId(0),
+                                        Buffer.DataType.DATA_BUFFER,
+                                        isBroadcast))
+                .isInstanceOf(RuntimeException.class)
+                .hasMessageContaining("Failed to choose a storage tier");
+    }
+
+    @TestTemplate
+    void testUpdateMetrics() throws IOException {
+        int numSubpartitions = 10;
+        int bufferSize = 1024;
+        Random random = new Random();
+
+        TestingTierProducerAgent tierProducerAgent = new 
TestingTierProducerAgent.Builder().build();
+        TieredStorageProducerClient tieredStorageProducerClient =
+                new TieredStorageProducerClient(
+                        numSubpartitions,
+                        false,
+                        new TestingBufferAccumulator(),
+                        null,
+                        Collections.singletonList(tierProducerAgent));
+
+        AtomicInteger numWriteBuffers = new AtomicInteger(0);
+        AtomicInteger numWriteBytes = new AtomicInteger(0);
+        tieredStorageProducerClient.setMetricStatisticsUpdater(
+                metricStatistics -> {
+                    numWriteBuffers.set(
+                            numWriteBuffers.get() + 
metricStatistics.numWriteBuffersDelta());
+                    numWriteBytes.set(numWriteBytes.get() + 
metricStatistics.numWriteBytesDelta());
+                });
+
+        tieredStorageProducerClient.write(
+                generateRandomData(bufferSize, random),
+                new TieredStorageSubpartitionId(0),
+                Buffer.DataType.DATA_BUFFER,
+                isBroadcast);
+
+        int numExpectedBuffers = isBroadcast ? numSubpartitions : 1;
+        int numExpectedBytes = isBroadcast ? bufferSize * numSubpartitions : 
bufferSize;
+        assertThat(numWriteBuffers.get()).isEqualTo(numExpectedBuffers);
+        assertThat(numWriteBytes.get()).isEqualTo(numExpectedBytes);
+    }
+
+    @TestTemplate
+    void testClose() {
+        int numSubpartitions = 10;
+
+        AtomicBoolean isClosed = new AtomicBoolean(false);
+        TestingTierProducerAgent tierProducerAgent =
+                new TestingTierProducerAgent.Builder()
+                        .setCloseRunnable(() -> isClosed.set(true))
+                        .build();
+
+        TieredStorageProducerClient tieredStorageProducerClient =
+                createTieredStorageProducerClient(
+                        numSubpartitions, 
Collections.singletonList(tierProducerAgent));
+
+        assertThat(isClosed.get()).isFalse();
+        tieredStorageProducerClient.close();
+        assertThat(isClosed.get()).isTrue();
+    }
+
+    private static TieredStorageProducerClient 
createTieredStorageProducerClient(
+            int numSubpartitions, List<TierProducerAgent> tierProducerAgents) {
+        TieredStorageProducerClient tieredStorageProducerClient =
+                new TieredStorageProducerClient(
+                        numSubpartitions,
+                        false,
+                        new TestingBufferAccumulator(),
+                        null,
+                        tierProducerAgents);
+        
tieredStorageProducerClient.setMetricStatisticsUpdater(metricStatistics -> {});
+        return tieredStorageProducerClient;
+    }
+}


Reply via email to