This is an automated email from the ASF dual-hosted git repository.

chia7712 pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 2dd6126b5d8 KAFKA-18855 Slice API for MemoryRecords (#19581)
2dd6126b5d8 is described below

commit 2dd6126b5d8f2cf73470be3884f5ad2aa301b29e
Author: Apoorv Mittal <[email protected]>
AuthorDate: Thu May 8 07:02:25 2025 +0100

    KAFKA-18855 Slice API for MemoryRecords (#19581)
    
    The PR adds a `slice` API to `Records.java` and provides the
    implementation in `MemoryRecords`. With the addition of ShareFetch and
    its support for reading from TieredStorage, ShareFetch might acquire
    only a subset of the fetched batches while TieredStorage emits
    MemoryRecords; hence a slice API is needed for MemoryRecords as well,
    to limit the bytes transferred when only a subset of batches is
    acquired.
    
    MemoryRecords are sliced using the `duplicate` and `slice` APIs of
    ByteBuffer, which are backed by the original buffer itself; no copy is
    created, only the position, limit and offset are adjusted for the new
    position and length, as sketched below.
    
    Reviewers: Andrew Schofield <[email protected]>, Jun Rao
     <[email protected]>, Chia-Ping Tsai <[email protected]>
---
 .../apache/kafka/common/record/FileRecords.java    |  15 +--
 .../apache/kafka/common/record/MemoryRecords.java  |  25 ++++
 .../org/apache/kafka/common/record/Records.java    |  16 +++
 .../kafka/common/record/FileRecordsTest.java       |  24 ++--
 .../kafka/common/record/MemoryRecordsTest.java     | 145 +++++++++++++++++++++
 .../java/kafka/server/share/ShareFetchUtils.java   |  19 ++-
 .../main/scala/kafka/tools/DumpLogSegments.scala   |   2 +-
 .../kafka/server/share/ShareFetchUtilsTest.java    |  77 +++++++----
 8 files changed, 259 insertions(+), 64 deletions(-)

diff --git 
a/clients/src/main/java/org/apache/kafka/common/record/FileRecords.java 
b/clients/src/main/java/org/apache/kafka/common/record/FileRecords.java
index 7f78235ab70..6a42a52d2e0 100644
--- a/clients/src/main/java/org/apache/kafka/common/record/FileRecords.java
+++ b/clients/src/main/java/org/apache/kafka/common/record/FileRecords.java
@@ -120,19 +120,8 @@ public class FileRecords extends AbstractRecords 
implements Closeable {
         buffer.flip();
     }
 
-    /**
-     * Return a slice of records from this instance, which is a view into this 
set starting from the given position
-     * and with the given size limit.
-     *
-     * If the size is beyond the end of the file, the end will be based on the 
size of the file at the time of the read.
-     *
-     * If this message set is already sliced, the position will be taken 
relative to that slicing.
-     *
-     * @param position The start position to begin the read from
-     * @param size The number of bytes after the start position to include
-     * @return A sliced wrapper on this message set limited based on the given 
position and size
-     */
-    public FileRecords slice(int position, int size) throws IOException {
+    @Override
+    public Records slice(int position, int size) throws IOException {
         int availableBytes = availableBytes(position, size);
         int startPosition = this.start + position;
         return new FileRecords(file, channel, startPosition, startPosition + 
availableBytes, true);
diff --git 
a/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java 
b/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
index c2fd231e4b7..1786f61d187 100644
--- a/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
+++ b/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java
@@ -300,6 +300,31 @@ public class MemoryRecords extends AbstractRecords {
         return buffer.duplicate();
     }
 
+    @Override
+    public Records slice(int position, int size) {
+        if (position < 0)
+            throw new IllegalArgumentException("Invalid position: " + position 
+ " in read from " + this);
+        if (position > buffer.limit())
+            throw new IllegalArgumentException("Slice from position " + 
position + " exceeds end position of " + this);
+        if (size < 0)
+            throw new IllegalArgumentException("Invalid size: " + size + " in 
read from " + this);
+
+        int availableBytes = Math.min(size, buffer.limit() - position);
+        // As of now, the clients module supports Java 11, so the
+        // ByteBuffer::slice(position, size) method cannot be used. Instead,
+        // create a duplicate buffer and set its position and limit. The
+        // duplicate buffer is backed by the original bytes, so only its
+        // relative position and limit change, not the content. Once the
+        // position and limit are set, calling slice() yields a buffer backed
+        // by the original one with its position reset to 0 and its limit set
+        // to the size of the slice.
+        ByteBuffer slicedBuffer = buffer.duplicate();
+        slicedBuffer.position(position);
+        slicedBuffer.limit(position + availableBytes);
+        // Reset the position to 0 so that the sliced view has a relative position.
+        slicedBuffer = slicedBuffer.slice();
+
+        return readableRecords(slicedBuffer);
+    }
+
     @Override
     public Iterable<MutableRecordBatch> batches() {
         return batches;
diff --git a/clients/src/main/java/org/apache/kafka/common/record/Records.java 
b/clients/src/main/java/org/apache/kafka/common/record/Records.java
index ec710394bec..3d45762e815 100644
--- a/clients/src/main/java/org/apache/kafka/common/record/Records.java
+++ b/clients/src/main/java/org/apache/kafka/common/record/Records.java
@@ -18,6 +18,7 @@ package org.apache.kafka.common.record;
 
 import org.apache.kafka.common.utils.AbstractIterator;
 
+import java.io.IOException;
 import java.util.Iterator;
 import java.util.Optional;
 
@@ -90,4 +91,19 @@ public interface Records extends TransferableRecords {
      * @return The record iterator
      */
     Iterable<Record> records();
+
+    /**
+     * Return a slice of records from this instance, which is a view into this
+     * set starting from the given position and with the given size limit.
+     *
+     * If the size goes beyond the end of the records, the end will be based on
+     * the size of the records at the time of the read.
+     *
+     * If this record set is already sliced, the position is taken relative to
+     * that slicing.
+     *
+     * @param position The start position to begin the read from. The position
+     *                 must be aligned to a batch boundary, otherwise the
+     *                 returned records cannot be iterated.
+     * @param size The number of bytes after the start position to include
+     * @return A sliced wrapper on this record set limited to the given
+     *         position and size
+     */
+    Records slice(int position, int size) throws IOException;
 }
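
The "already sliced" behaviour in the javadoc above (positions are taken
relative to the previous slicing) mirrors how nested ByteBuffer slices behave.
A short, hedged sketch in plain java.nio (illustrative only, not part of the
PR):

    import java.nio.ByteBuffer;
    import java.nio.charset.StandardCharsets;

    public class NestedSliceSketch {
        public static void main(String[] args) {
            ByteBuffer original = ByteBuffer.wrap("0123456789".getBytes(StandardCharsets.UTF_8));

            // First view: bytes 2..7 of the original ("234567"), position reset to 0.
            ByteBuffer first = original.duplicate();
            first.position(2);
            first.limit(8);
            ByteBuffer firstView = first.slice();

            // Second view: position 3 is interpreted relative to the first
            // view, not the original buffer.
            ByteBuffer second = firstView.duplicate();
            second.position(3);
            ByteBuffer secondView = second.slice();

            byte[] out = new byte[secondView.remaining()];
            secondView.get(out);
            System.out.println(new String(out, StandardCharsets.UTF_8)); // prints "567"
        }
    }
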
diff --git 
a/clients/src/test/java/org/apache/kafka/common/record/FileRecordsTest.java 
b/clients/src/test/java/org/apache/kafka/common/record/FileRecordsTest.java
index a2e89d3f4c6..75babcf95b0 100644
--- a/clients/src/test/java/org/apache/kafka/common/record/FileRecordsTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/record/FileRecordsTest.java
@@ -142,7 +142,7 @@ public class FileRecordsTest {
             Future<Object> readerCompletion = executor.submit(() -> {
                 while (log.sizeInBytes() < maxSizeInBytes) {
                     int currentSize = log.sizeInBytes();
-                    FileRecords slice = log.slice(0, currentSize);
+                    Records slice = log.slice(0, currentSize);
                     assertEquals(currentSize, slice.sizeInBytes());
                 }
                 return null;
@@ -198,9 +198,9 @@ public class FileRecordsTest {
      */
     @Test
     public void testRead() throws IOException {
-        FileRecords read = fileRecords.slice(0, fileRecords.sizeInBytes());
+        Records read = fileRecords.slice(0, fileRecords.sizeInBytes());
         assertEquals(fileRecords.sizeInBytes(), read.sizeInBytes());
-        TestUtils.checkEquals(fileRecords.batches(), read.batches());
+        TestUtils.checkEquals(fileRecords.batches(), ((FileRecords) 
read).batches());
 
         List<RecordBatch> items = batches(read);
         RecordBatch first = items.get(0);
@@ -283,9 +283,9 @@ public class FileRecordsTest {
         RecordBatch batch = batches(fileRecords).get(1);
         int start = fileRecords.searchForOffsetFromPosition(1, 0).position;
         int size = batch.sizeInBytes();
-        FileRecords slice = fileRecords.slice(start, size);
+        Records slice = fileRecords.slice(start, size);
         assertEquals(Collections.singletonList(batch), batches(slice));
-        FileRecords slice2 = fileRecords.slice(start, size - 1);
+        Records slice2 = fileRecords.slice(start, size - 1);
         assertEquals(Collections.emptyList(), batches(slice2));
     }
 
@@ -429,24 +429,22 @@ public class FileRecordsTest {
             "abcd".getBytes(),
             "efgh".getBytes(),
             "ijkl".getBytes(),
-            "mnop".getBytes(),
-            "qrst".getBytes()
+            "mnopqr".getBytes(),
+            "stuv".getBytes()
         };
         try (FileRecords fileRecords = createFileRecords(values)) {
             List<RecordBatch> items = batches(fileRecords.slice(0, 
fileRecords.sizeInBytes()));
 
             // Slice from fourth message until the end.
             int position = IntStream.range(0, 3).map(i -> 
items.get(i).sizeInBytes()).sum();
-            FileRecords sliced  = fileRecords.slice(position, 
fileRecords.sizeInBytes() - position);
+            Records sliced  = fileRecords.slice(position, 
fileRecords.sizeInBytes() - position);
             assertEquals(fileRecords.sizeInBytes() - position, 
sliced.sizeInBytes());
             assertEquals(items.subList(3, items.size()), batches(sliced), 
"Read starting from the fourth message");
 
             // Further slice the already sliced file records, from fifth 
message until the end. Now the
-            // bytes available in the sliced file records are less than the 
start position. However, the
-            // position to slice is relative hence reset position to second 
message in the sliced file
-            // records i.e. reset with the size of the fourth message from the 
original file records.
-            position = items.get(4).sizeInBytes();
-            FileRecords finalSliced = sliced.slice(position, 
sliced.sizeInBytes() - position);
+            // bytes available in the sliced records are fewer than the slice's
+            // start position in the original records.
+            position = items.get(3).sizeInBytes();
+            Records finalSliced = sliced.slice(position, sliced.sizeInBytes() 
- position);
             assertEquals(sliced.sizeInBytes() - position, 
finalSliced.sizeInBytes());
             assertEquals(items.subList(4, items.size()), batches(finalSliced), 
"Read starting from the fifth message");
         }
diff --git 
a/clients/src/test/java/org/apache/kafka/common/record/MemoryRecordsTest.java 
b/clients/src/test/java/org/apache/kafka/common/record/MemoryRecordsTest.java
index 3818976e423..6f98da21eed 100644
--- 
a/clients/src/test/java/org/apache/kafka/common/record/MemoryRecordsTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/common/record/MemoryRecordsTest.java
@@ -35,14 +35,19 @@ import org.junit.jupiter.params.provider.Arguments;
 import org.junit.jupiter.params.provider.ArgumentsProvider;
 import org.junit.jupiter.params.provider.ArgumentsSource;
 
+import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.LinkedHashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.OptionalLong;
 import java.util.function.BiFunction;
 import java.util.function.Supplier;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 import java.util.stream.Stream;
 
 import static java.util.Arrays.asList;
@@ -1068,6 +1073,146 @@ public class MemoryRecordsTest {
         });
     }
 
+    @ParameterizedTest
+    @ArgumentsSource(MemoryRecordsArgumentsProvider.class)
+    public void testSlice(Args args) throws IOException {
+        // Create a MemoryRecords instance with multiple batches. Prior to
+        // RecordBatch.MAGIC_VALUE_V2, every append in a batch is a new batch.
+        // From RecordBatch.MAGIC_VALUE_V2 onwards, we can have multiple
+        // batches in a single MemoryRecords instance. Though with compression,
+        // multiple appends can result in a single batch prior to
+        // RecordBatch.MAGIC_VALUE_V2 as well.
+        LinkedHashMap<Long, Integer> recordsPerOffset = new LinkedHashMap<>();
+        recordsPerOffset.put(args.firstOffset, 3);
+        recordsPerOffset.put(args.firstOffset + 6L, 8);
+        recordsPerOffset.put(args.firstOffset + 15L, 4);
+        MemoryRecords records = createMemoryRecords(args, recordsPerOffset);
+
+        // Test slicing from start
+        Records sliced = records.slice(0, records.sizeInBytes());
+        assertEquals(records.sizeInBytes(), sliced.sizeInBytes());
+        assertEquals(records.validBytes(), ((MemoryRecords) 
sliced).validBytes());
+        TestUtils.checkEquals(records.batches(), ((MemoryRecords) 
sliced).batches());
+
+        List<RecordBatch> items = batches(records);
+        // Test slicing first message.
+        RecordBatch first = items.get(0);
+        sliced = records.slice(first.sizeInBytes(), records.sizeInBytes() - 
first.sizeInBytes());
+        assertEquals(records.sizeInBytes() - first.sizeInBytes(), 
sliced.sizeInBytes());
+        assertEquals(items.subList(1, items.size()), batches(sliced), "Read 
starting from the second message");
+        assertTrue(((MemoryRecords) sliced).validBytes() <= 
sliced.sizeInBytes());
+
+        // Read from second message and size is past the end of the records.
+        sliced = records.slice(first.sizeInBytes(), records.sizeInBytes());
+        assertEquals(records.sizeInBytes() - first.sizeInBytes(), 
sliced.sizeInBytes());
+        assertEquals(items.subList(1, items.size()), batches(sliced), "Read 
starting from the second message");
+        assertTrue(((MemoryRecords) sliced).validBytes() <= 
sliced.sizeInBytes());
+
+        // Read from second message and position + size overflows.
+        sliced = records.slice(first.sizeInBytes(), Integer.MAX_VALUE);
+        assertEquals(records.sizeInBytes() - first.sizeInBytes(), 
sliced.sizeInBytes());
+        assertEquals(items.subList(1, items.size()), batches(sliced), "Read 
starting from the second message");
+        assertTrue(((MemoryRecords) sliced).validBytes() <= 
sliced.sizeInBytes());
+
+        // Read a single message starting from second message.
+        RecordBatch second = items.get(1);
+        sliced = records.slice(first.sizeInBytes(), second.sizeInBytes());
+        assertEquals(second.sizeInBytes(), sliced.sizeInBytes());
+        assertEquals(Collections.singletonList(second), batches(sliced), "Read 
a single message starting from the second message");
+
+        // Read from already sliced view.
+        List<RecordBatch> remainingItems = IntStream.range(0, 
items.size()).filter(i -> i != 0 && i != 
1).mapToObj(items::get).collect(Collectors.toList());
+        int remainingSize = 
remainingItems.stream().mapToInt(RecordBatch::sizeInBytes).sum();
+        sliced = records.slice(first.sizeInBytes(), records.sizeInBytes() - 
first.sizeInBytes())
+                        .slice(second.sizeInBytes(), records.sizeInBytes() - 
first.sizeInBytes() - second.sizeInBytes());
+        assertEquals(remainingSize, sliced.sizeInBytes());
+        assertEquals(remainingItems, batches(sliced), "Read starting from the 
third message");
+
+        // Read from second message and size is past the end of the records
+        // on the already sliced view.
+        sliced = records.slice(1, records.sizeInBytes() - 1)
+            .slice(first.sizeInBytes() - 1, records.sizeInBytes());
+        assertEquals(records.sizeInBytes() - first.sizeInBytes(), 
sliced.sizeInBytes());
+        assertEquals(items.subList(1, items.size()), batches(sliced), "Read 
starting from the second message");
+        assertTrue(((MemoryRecords) sliced).validBytes() <= 
sliced.sizeInBytes());
+
+        // Read from second message and position + size overflows on the 
already sliced view.
+        sliced = records.slice(1, records.sizeInBytes() - 1)
+            .slice(first.sizeInBytes() - 1, Integer.MAX_VALUE);
+        assertEquals(records.sizeInBytes() - first.sizeInBytes(), 
sliced.sizeInBytes());
+        assertEquals(items.subList(1, items.size()), batches(sliced), "Read 
starting from the second message");
+        assertTrue(((MemoryRecords) sliced).validBytes() <= 
sliced.sizeInBytes());
+    }
+
+    @ParameterizedTest
+    @ArgumentsSource(MemoryRecordsArgumentsProvider.class)
+    public void testSliceInvalidPosition(Args args) {
+        MemoryRecords records = createMemoryRecords(args, 
Map.of(args.firstOffset, 1));
+        assertThrows(IllegalArgumentException.class, () -> records.slice(-1, 
records.sizeInBytes()));
+        assertThrows(IllegalArgumentException.class, () -> 
records.slice(records.sizeInBytes() + 1, records.sizeInBytes()));
+    }
+
+    @ParameterizedTest
+    @ArgumentsSource(MemoryRecordsArgumentsProvider.class)
+    public void testSliceInvalidSize(Args args) {
+        MemoryRecords records = createMemoryRecords(args, 
Map.of(args.firstOffset, 1));
+        assertThrows(IllegalArgumentException.class, () -> records.slice(0, 
-1));
+    }
+
+    @Test
+    public void testSliceEmptyRecords() {
+        MemoryRecords empty = MemoryRecords.EMPTY;
+        Records sliced = empty.slice(0, 0);
+        assertEquals(0, sliced.sizeInBytes());
+        assertEquals(0, batches(sliced).size());
+    }
+
+    /**
+     * Test slice when the already sliced memory records have a start position
+     * greater than the available bytes in the memory records.
+     */
+    @ParameterizedTest
+    @ArgumentsSource(MemoryRecordsArgumentsProvider.class)
+    public void testSliceForAlreadySlicedMemoryRecords(Args args) throws 
IOException {
+        LinkedHashMap<Long, Integer> recordsPerOffset = new LinkedHashMap<>();
+        recordsPerOffset.put(args.firstOffset, 5);
+        recordsPerOffset.put(args.firstOffset + 5L, 10);
+        recordsPerOffset.put(args.firstOffset + 15L, 12);
+        recordsPerOffset.put(args.firstOffset + 27L, 4);
+
+        MemoryRecords records = createMemoryRecords(args, recordsPerOffset);
+        List<RecordBatch> items = batches(records.slice(0, 
records.sizeInBytes()));
+
+        // Slice from third message until the end.
+        int position = IntStream.range(0, 2).map(i -> 
items.get(i).sizeInBytes()).sum();
+        Records sliced  = records.slice(position, records.sizeInBytes() - 
position);
+        assertEquals(records.sizeInBytes() - position, sliced.sizeInBytes());
+        assertEquals(items.subList(2, items.size()), batches(sliced), "Read 
starting from the third message");
+
+        // Further slice the already sliced memory records, from the fourth
+        // message until the end. Now the bytes available in the sliced records
+        // are fewer than the slice's start position in the original records.
+        position = items.get(2).sizeInBytes();
+        Records finalSliced = sliced.slice(position, sliced.sizeInBytes() - 
position);
+        assertEquals(sliced.sizeInBytes() - position, 
finalSliced.sizeInBytes());
+        assertEquals(items.subList(3, items.size()), batches(finalSliced), 
"Read starting from the fourth message");
+    }
+
+    private MemoryRecords createMemoryRecords(Args args, Map<Long, Integer> 
recordsPerOffset) {
+        ByteBuffer buffer = ByteBuffer.allocate(1024);
+        recordsPerOffset.forEach((offset, numOfRecords) -> {
+            MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, 
args.magic, args.compression,
+                    TimestampType.CREATE_TIME, offset);
+            for (int i = 0; i < numOfRecords; i++) {
+                builder.appendWithOffset(offset + i, 0L, 
TestUtils.randomString(10).getBytes(), TestUtils.randomString(10).getBytes());
+            }
+            builder.close();
+        });
+        buffer.flip();
+
+        return MemoryRecords.readableRecords(buffer);
+    }
+
+    private static List<RecordBatch> batches(Records buffer) {
+        return TestUtils.toList(buffer.batches());
+    }
+
     private static class RetainNonNullKeysFilter extends 
MemoryRecords.RecordFilter {
         public RetainNonNullKeysFilter() {
             super(0, 0);
diff --git a/core/src/main/java/kafka/server/share/ShareFetchUtils.java 
b/core/src/main/java/kafka/server/share/ShareFetchUtils.java
index 603ae8e048b..ba9e5368bcf 100644
--- a/core/src/main/java/kafka/server/share/ShareFetchUtils.java
+++ b/core/src/main/java/kafka/server/share/ShareFetchUtils.java
@@ -27,9 +27,9 @@ import 
org.apache.kafka.common.errors.OffsetNotAvailableException;
 import org.apache.kafka.common.message.ShareFetchResponseData;
 import org.apache.kafka.common.message.ShareFetchResponseData.AcquiredRecords;
 import org.apache.kafka.common.protocol.Errors;
-import 
org.apache.kafka.common.record.FileLogInputStream.FileChannelRecordBatch;
 import org.apache.kafka.common.record.FileRecords;
 import org.apache.kafka.common.record.MemoryRecords;
+import org.apache.kafka.common.record.RecordBatch;
 import org.apache.kafka.common.record.Records;
 import org.apache.kafka.common.requests.ListOffsetsRequest;
 import org.apache.kafka.coordinator.group.GroupConfigManager;
@@ -205,20 +205,17 @@ public class ShareFetchUtils {
      *
      * @param records The records to be sliced.
      * @param shareAcquiredRecords The share acquired records containing the 
non-empty acquired records.
-     * @return The sliced records, if the records are of type FileRecords and 
the acquired records are a subset
-     *        of the fetched records. Otherwise, the original records are 
returned.
+     * @return The sliced records, if the acquired records are a subset of the 
fetched records. Otherwise,
+     *         the original records are returned.
      */
     static Records maybeSliceFetchRecords(Records records, 
ShareAcquiredRecords shareAcquiredRecords) {
-        if (!(records instanceof FileRecords fileRecords)) {
-            return records;
-        }
         // The acquired records should be non-empty, do not check as the 
method is called only when the
         // acquired records are non-empty.
         List<AcquiredRecords> acquiredRecords = 
shareAcquiredRecords.acquiredRecords();
         try {
-            final Iterator<FileChannelRecordBatch> iterator = 
fileRecords.batchIterator();
+            final Iterator<? extends RecordBatch> iterator = 
records.batchIterator();
             // Track the first overlapping batch with the first acquired 
offset.
-            FileChannelRecordBatch firstOverlapBatch = iterator.next();
+            RecordBatch firstOverlapBatch = iterator.next();
             // If there exists single fetch batch, then return the original 
records.
             if (!iterator.hasNext()) {
                 return records;
@@ -230,7 +227,7 @@ public class ShareFetchUtils {
             int size = 0;
             // Start iterating from the second batch.
             while (iterator.hasNext()) {
-                FileChannelRecordBatch batch = iterator.next();
+                RecordBatch batch = iterator.next();
                 // Iterate until finds the first overlap batch with the first 
acquired offset. All the
                 // batches before this first overlap batch should be sliced 
hence increment the start
                 // position.
@@ -249,10 +246,10 @@ public class ShareFetchUtils {
             // acquired offset.
             size += firstOverlapBatch.sizeInBytes();
             // Check if we do not need slicing i.e. neither start position nor 
size changed.
-            if (startPosition == 0 && size == fileRecords.sizeInBytes()) {
+            if (startPosition == 0 && size == records.sizeInBytes()) {
                 return records;
             }
-            return fileRecords.slice(startPosition, size);
+            return records.slice(startPosition, size);
         } catch (Exception e) {
             log.error("Error while checking batches for acquired records: {}, 
skipping slicing.", acquiredRecords, e);
             // If there is an exception while slicing, return the original 
records so that the fetch
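
The batch-boundary computation in maybeSliceFetchRecords above can be easier to
follow in isolation. Below is a hedged, self-contained sketch of the same idea
over plain (baseOffset, lastOffset, sizeInBytes) tuples instead of Kafka's
RecordBatch; the Batch record and sliceBounds helper are illustrative only, not
part of the PR:

    import java.util.List;

    public class SliceBoundsSketch {
        // Illustrative stand-in for a record batch: offsets [baseOffset, lastOffset]
        // plus the batch size in bytes.
        record Batch(long baseOffset, long lastOffset, int sizeInBytes) { }

        // Compute {startPosition, size} of the byte range covering every batch
        // that overlaps [firstAcquired, lastAcquired]: skip whole batches that
        // end before the first acquired offset, then include batches until one
        // starts after the last acquired offset.
        static int[] sliceBounds(List<Batch> batches, long firstAcquired, long lastAcquired) {
            int startPosition = 0;
            int size = 0;
            int i = 0;
            while (i < batches.size() && batches.get(i).lastOffset() < firstAcquired) {
                startPosition += batches.get(i).sizeInBytes();
                i++;
            }
            while (i < batches.size() && batches.get(i).baseOffset() <= lastAcquired) {
                size += batches.get(i).sizeInBytes();
                i++;
            }
            return new int[] {startPosition, size};
        }

        public static void main(String[] args) {
            List<Batch> batches = List.of(
                new Batch(0, 2, 100), new Batch(3, 4, 80), new Batch(7, 10, 120));
            int[] bounds = sliceBounds(batches, 3, 8);
            System.out.println(bounds[0] + ", " + bounds[1]); // prints "100, 200"
        }
    }
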
diff --git a/core/src/main/scala/kafka/tools/DumpLogSegments.scala 
b/core/src/main/scala/kafka/tools/DumpLogSegments.scala
index a95c77cca4e..85bc4e1a269 100755
--- a/core/src/main/scala/kafka/tools/DumpLogSegments.scala
+++ b/core/src/main/scala/kafka/tools/DumpLogSegments.scala
@@ -277,7 +277,7 @@ object DumpLogSegments {
         println(s"Snapshot end offset: ${path.snapshotId.offset}, epoch: 
${path.snapshotId.epoch}")
       }
     }
-    val fileRecords = FileRecords.open(file, false).slice(0, maxBytes)
+    val fileRecords = FileRecords.open(file, false).slice(0, 
maxBytes).asInstanceOf[FileRecords]
     try {
       var validBytes = 0L
       var lastOffset = -1L
diff --git a/core/src/test/java/kafka/server/share/ShareFetchUtilsTest.java 
b/core/src/test/java/kafka/server/share/ShareFetchUtilsTest.java
index 3bec497d7a1..e3a77158daf 100644
--- a/core/src/test/java/kafka/server/share/ShareFetchUtilsTest.java
+++ b/core/src/test/java/kafka/server/share/ShareFetchUtilsTest.java
@@ -45,9 +45,15 @@ import org.apache.kafka.storage.log.metrics.BrokerTopicStats;
 import org.apache.kafka.test.TestUtils;
 
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtensionContext;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.ArgumentsProvider;
+import org.junit.jupiter.params.provider.ArgumentsSource;
 import org.mockito.Mockito;
 
 import java.io.IOException;
+import java.nio.ByteBuffer;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
@@ -56,6 +62,7 @@ import java.util.OptionalInt;
 import java.util.OptionalLong;
 import java.util.concurrent.CompletableFuture;
 import java.util.function.BiConsumer;
+import java.util.stream.Stream;
 
 import static 
org.apache.kafka.server.share.fetch.ShareFetchTestUtils.createFileRecords;
 import static 
org.apache.kafka.server.share.fetch.ShareFetchTestUtils.createShareAcquiredRecords;
@@ -462,11 +469,9 @@ public class ShareFetchUtilsTest {
         Mockito.verify(sp0, times(0)).updateCacheAndOffsets(any(Long.class));
     }
 
-    @Test
-    public void testMaybeSliceFetchRecordsSingleBatch() throws IOException {
-        // Create 1 batch of records with 10 records.
-        FileRecords records = createFileRecords(Map.of(5L, 10));
-
+    @ParameterizedTest(name = "{0}")
+    @ArgumentsSource(RecordsArgumentsProvider.class)
+    public void testMaybeSliceFetchRecordsSingleBatch(String name, Records 
records) {
         // Acquire all offsets, should return same records.
         List<AcquiredRecords> acquiredRecords = List.of(new 
AcquiredRecords().setFirstOffset(5).setLastOffset(14).setDeliveryCount((short) 
1));
         Records slicedRecords = 
ShareFetchUtils.maybeSliceFetchRecords(records, new 
ShareAcquiredRecords(acquiredRecords, 10));
@@ -498,15 +503,9 @@ public class ShareFetchUtilsTest {
         assertEquals(records, slicedRecords);
     }
 
-    @Test
-    public void testMaybeSliceFetchRecordsMultipleBatches() throws IOException 
{
-        // Create 3 batches of records with 3, 2 and 4 records respectively.
-        LinkedHashMap<Long, Integer> recordsPerOffset = new LinkedHashMap<>();
-        recordsPerOffset.put(0L, 3);
-        recordsPerOffset.put(3L, 2);
-        recordsPerOffset.put(7L, 4); // Gap of 2 offsets between batches.
-        FileRecords records = createFileRecords(recordsPerOffset);
-
+    @ParameterizedTest(name = "{0}")
+    @ArgumentsSource(MultipleBatchesRecordsArgumentsProvider.class)
+    public void testMaybeSliceFetchRecordsMultipleBatches(String name, Records 
records) {
         // Acquire all offsets, should return same records.
         List<AcquiredRecords> acquiredRecords = List.of(new 
AcquiredRecords().setFirstOffset(0).setLastOffset(10).setDeliveryCount((short) 
1));
         Records slicedRecords = 
ShareFetchUtils.maybeSliceFetchRecords(records, new 
ShareAcquiredRecords(acquiredRecords, 11));
@@ -617,10 +616,9 @@ public class ShareFetchUtilsTest {
         assertEquals(records.sizeInBytes(), slicedRecords.sizeInBytes());
     }
 
-    @Test
-    public void testMaybeSliceFetchRecordsException() throws IOException {
-        // Create 1 batch of records with 3 records.
-        FileRecords records = createFileRecords(Map.of(0L, 3));
+    @ParameterizedTest(name = "{0}")
+    @ArgumentsSource(MultipleBatchesRecordsArgumentsProvider.class)
+    public void testMaybeSliceFetchRecordsException(String name, Records 
records) {
         // Send empty acquired records which should trigger an exception and 
same file records should
         // be returned. The method doesn't expect empty acquired records.
         Records slicedRecords = ShareFetchUtils.maybeSliceFetchRecords(
@@ -628,14 +626,41 @@ public class ShareFetchUtilsTest {
         assertEquals(records, slicedRecords);
     }
 
-    @Test
-    public void testMaybeSliceFetchRecordsNonFileRecords() {
-        // Send memory records which should be returned as is.
-        try (MemoryRecordsBuilder records = memoryRecordsBuilder(2, 0)) {
-            List<AcquiredRecords> acquiredRecords = List.of(new 
AcquiredRecords().setFirstOffset(0).setLastOffset(1).setDeliveryCount((short) 
1));
-            Records slicedRecords = ShareFetchUtils.maybeSliceFetchRecords(
-                records.build(), new ShareAcquiredRecords(acquiredRecords, 2));
-            assertEquals(records.build(), slicedRecords);
+    private static class RecordsArgumentsProvider implements ArgumentsProvider 
{
+        @Override
+        public Stream<? extends Arguments> provideArguments(ExtensionContext 
context) throws Exception {
+            return Stream.of(
+                Arguments.of("FileRecords", createFileRecords(Map.of(5L, 10))),
+                Arguments.of("MemoryRecords", createMemoryRecords(5L, 10))
+            );
+        }
+
+        private MemoryRecords createMemoryRecords(long baseOffset, int 
numRecords) {
+            try (MemoryRecordsBuilder recordsBuilder = 
memoryRecordsBuilder(numRecords, baseOffset)) {
+                return recordsBuilder.build();
+            }
+        }
+    }
+
+    private static class MultipleBatchesRecordsArgumentsProvider implements 
ArgumentsProvider {
+        @Override
+        public Stream<? extends Arguments> provideArguments(ExtensionContext 
context) throws Exception {
+            LinkedHashMap<Long, Integer> recordsPerOffset = new 
LinkedHashMap<>();
+            recordsPerOffset.put(0L, 3);
+            recordsPerOffset.put(3L, 2);
+            recordsPerOffset.put(7L, 4); // Gap of 2 offsets between batches.
+            return Stream.of(
+                Arguments.of("FileRecords", 
createFileRecords(recordsPerOffset)),
+                Arguments.of("MemoryRecords", 
createMemoryRecords(recordsPerOffset))
+            );
+        }
+
+        private MemoryRecords createMemoryRecords(Map<Long, Integer> 
recordsPerOffset) {
+            ByteBuffer buffer = ByteBuffer.allocate(1024);
+            recordsPerOffset.forEach((offset, numOfRecords) -> 
memoryRecordsBuilder(buffer, numOfRecords, offset).close());
+            buffer.flip();
+
+            return MemoryRecords.readableRecords(buffer);
         }
     }
 }
