showuon commented on code in PR #13135:
URL: https://github.com/apache/kafka/pull/13135#discussion_r1185891570


##########
clients/src/main/java/org/apache/kafka/common/utils/ChunkedBytesStream.java:
##########
@@ -0,0 +1,357 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.utils;
+
+import java.io.BufferedInputStream;
+import java.io.FilterInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+
+/**
+ * ChunkedBytesStream is a copy of {@link ByteBufferInputStream} with the following differences:
+ * - Unlike {@link java.io.BufferedInputStream#skip(long)} this class could be configured to not push skip() to
+ * input stream. We may want to avoid pushing this to input stream because it's implementation maybe inefficient,
+ * e.g. the case of ZstdInputStream which allocates a new buffer from buffer pool, per skip call.
+ * - Unlike {@link java.io.BufferedInputStream}, which allocates an intermediate buffer, this uses a buffer supplier to
+ * create the intermediate buffer.
+ * <p>
+ * Note that:
+ * - this class is not thread safe and shouldn't be used in scenarios where multiple threads access this.
+ * - the implementation of this class is performance sensitive. Minor changes such as usage of ByteBuffer instead of byte[]
+ * can significantly impact performance, hence, proceed with caution.
+ */
+public class ChunkedBytesStream extends FilterInputStream {
+    /**
+     * Supplies the ByteBuffer which is used as intermediate buffer to store the chunk of output data.
+     */
+    private final BufferSupplier bufferSupplier;
+    /**
+     * Intermediate buffer to store the chunk of output data. The ChunkedBytesStream is considered closed if
+     * this buffer is null.
+     */
+    private byte[] intermediateBuf;
+    /**
+     * The index one greater than the index of the last valid byte in
+     * the buffer.
+     * This value is always in the range <code>0</code> through <code>intermediateBuf.length</code>;
+     * elements <code>intermediateBuf[0]</code>  through <code>intermediateBuf[count-1]
+     * </code>contain buffered input data obtained
+     * from the underlying  input stream.
+     */
+    protected int count = 0;
+    /**
+     * The current position in the buffer. This is the index of the next
+     * character to be read from the <code>buf</code> array.
+     * <p>
+     * This value is always in the range <code>0</code>
+     * through <code>count</code>. If it is less
+     * than <code>count</code>, then  <code>intermediateBuf[pos]</code>
+     * is the next byte to be supplied as input;
+     * if it is equal to <code>count</code>, then
+     * the  next <code>read</code> or <code>skip</code>
+     * operation will require more bytes to be
+     * read from the contained  input stream.
+     */
+    protected int pos = 0;
+    /**
+     * Reference for the intermediate buffer. This reference is only kept for releasing the buffer from the
+     * buffer supplier.
+     */
+    private final ByteBuffer intermediateBufRef;
+    /**
+     * Determines if the skip be pushed down
+     */
+    private final boolean pushSkipToSourceStream;
+
+    public ChunkedBytesStream(InputStream in, BufferSupplier bufferSupplier, int intermediateBufSize, boolean pushSkipToSourceStream) {
+        super(in);
+        this.bufferSupplier = bufferSupplier;
+        intermediateBufRef = bufferSupplier.get(intermediateBufSize);
+        if (!intermediateBufRef.hasArray() || (intermediateBufRef.arrayOffset() != 0)) {
+            throw new IllegalArgumentException("provided ByteBuffer lacks array or has non-zero arrayOffset");
+        }
+        intermediateBuf = intermediateBufRef.array();
+        this.pushSkipToSourceStream = pushSkipToSourceStream;
+    }
+
+    /**
+     * Check to make sure that buffer has not been nulled out due to
+     * close; if not return it;
+     */
+    private byte[] getBufIfOpen() throws IOException {
+        byte[] buffer = intermediateBuf;
+        if (buffer == null)
+            throw new IOException("Stream closed");
+        return buffer;
+    }
+
+    /**
+     * See

Review Comment:
   See ?
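
For reference, the skip behaviour the class Javadoc describes (not pushing `skip()` down to the source stream) can be sketched roughly as follows. This is a minimal, hypothetical illustration, not the PR's `ChunkedBytesStream`: the class name `NoPushdownSkipStream` is made up, a plain `byte[]` stands in for the buffer that would come from a `BufferSupplier`, and `mark`/`reset` are simply disabled. `skip()` drains the internal chunk buffer and refills it with one bulk `read()` on the wrapped stream, so the wrapped stream's own `skip()` is never invoked.

```java
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;

// Hypothetical sketch, not the PR's ChunkedBytesStream: reads the wrapped stream in
// fixed-size chunks and implements skip() by discarding buffered bytes, so the
// wrapped stream's skip() is never called.
final class NoPushdownSkipStream extends FilterInputStream {
    private final byte[] buf; // plain array standing in for a BufferSupplier-provided buffer
    private int count = 0;    // number of valid bytes in buf
    private int pos = 0;      // index of the next byte to hand out

    NoPushdownSkipStream(InputStream in, int chunkSize) {
        super(in);
        this.buf = new byte[chunkSize];
    }

    // One bulk read() on the wrapped stream; returns false when it yields no bytes
    // (end of stream for a blocking stream).
    private boolean fill() throws IOException {
        pos = 0;
        count = 0;
        int n = in.read(buf, 0, buf.length);
        if (n > 0)
            count = n;
        return n > 0;
    }

    @Override
    public int read() throws IOException {
        if (pos >= count && !fill())
            return -1;
        return buf[pos++] & 0xff;
    }

    @Override
    public int read(byte[] b, int off, int len) throws IOException {
        if (len == 0)
            return 0;
        if (pos >= count && !fill())
            return -1;
        int n = Math.min(len, count - pos);
        System.arraycopy(buf, pos, b, off, n);
        pos += n;
        return n;
    }

    @Override
    public long skip(long n) throws IOException {
        long remaining = n;
        while (remaining > 0) {
            if (pos >= count && !fill())
                break; // end of stream: report only what was actually skipped
            int step = (int) Math.min(remaining, count - pos);
            pos += step;
            remaining -= step;
        }
        return n - remaining; // 0 for n <= 0
    }

    @Override
    public int available() throws IOException {
        return (count - pos) + in.available();
    }

    @Override
    public boolean markSupported() {
        return false; // mark/reset would also have to account for the private buffer
    }
}
```

The trade-off is that skipped bytes are still produced by the wrapped stream and copied into the chunk buffer, so this only pays off when the wrapped stream's own `skip()` is more expensive than that, which is the case the Javadoc gives for `ZstdInputStream`.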



##########
clients/src/main/java/org/apache/kafka/common/record/DefaultRecord.java:
##########
@@ -356,164 +346,100 @@ private static DefaultRecord readFrom(ByteBuffer buffer,
                 throw new InvalidRecordException("Invalid record size: expected to read " + sizeOfBodyInBytes +
                         " bytes in record payload, but instead read " + (buffer.position() - recordStart));
 
-            return new DefaultRecord(sizeInBytes, attributes, offset, timestamp, sequence, key, value, headers);
+            int totalSizeInBytes = ByteUtils.sizeOfVarint(sizeOfBodyInBytes) + sizeOfBodyInBytes;
+            return new DefaultRecord(totalSizeInBytes, attributes, offset, timestamp, sequence, key, value, headers);
         } catch (BufferUnderflowException | IllegalArgumentException e) {
             throw new InvalidRecordException("Found invalid record structure", e);
         }
     }
 
-    public static PartialDefaultRecord readPartiallyFrom(DataInput input,
-                                                         byte[] skipArray,
+    public static PartialDefaultRecord readPartiallyFrom(InputStream input,
                                                          long baseOffset,
                                                          long baseTimestamp,
                                                          int baseSequence,
                                                          Long logAppendTime) throws IOException {
         int sizeOfBodyInBytes = ByteUtils.readVarint(input);
         int totalSizeInBytes = ByteUtils.sizeOfVarint(sizeOfBodyInBytes) + sizeOfBodyInBytes;
 
-        return readPartiallyFrom(input, skipArray, totalSizeInBytes, sizeOfBodyInBytes, baseOffset, baseTimestamp,
+        return readPartiallyFrom(input, totalSizeInBytes, baseOffset, baseTimestamp,
             baseSequence, logAppendTime);
     }
 
-    private static PartialDefaultRecord readPartiallyFrom(DataInput input,
-                                                          byte[] skipArray,
+    private static PartialDefaultRecord readPartiallyFrom(InputStream input,
                                                           int sizeInBytes,
-                                                          int sizeOfBodyInBytes,
                                                           long baseOffset,
                                                           long baseTimestamp,
                                                           int baseSequence,
                                                           Long logAppendTime) throws IOException {
-        ByteBuffer skipBuffer = ByteBuffer.wrap(skipArray);
-        // set its limit to 0 to indicate no bytes readable yet
-        skipBuffer.limit(0);
-
         try {
-            // reading the attributes / timestamp / offset and key-size does not require
-            // any byte array allocation and therefore we can just read them straight-forwardly
-            IntRef bytesRemaining = PrimitiveRef.ofInt(sizeOfBodyInBytes);
-
-            byte attributes = readByte(skipBuffer, input, bytesRemaining);
-            long timestampDelta = readVarLong(skipBuffer, input, bytesRemaining);
+            byte attributes = (byte) input.read();
+            long timestampDelta = ByteUtils.readVarlong(input);
             long timestamp = baseTimestamp + timestampDelta;
             if (logAppendTime != null)
                 timestamp = logAppendTime;
 
-            int offsetDelta = readVarInt(skipBuffer, input, bytesRemaining);
+            int offsetDelta = ByteUtils.readVarint(input);
             long offset = baseOffset + offsetDelta;
             int sequence = baseSequence >= 0 ?
                 DefaultRecordBatch.incrementSequence(baseSequence, offsetDelta) :
                 RecordBatch.NO_SEQUENCE;
 
-            // first skip key
-            int keySize = skipLengthDelimitedField(skipBuffer, input, bytesRemaining);
+            // skip key
+            int keySize = ByteUtils.readVarint(input);
+            skipBytes(input, keySize);
 
-            // then skip value
-            int valueSize = skipLengthDelimitedField(skipBuffer, input, bytesRemaining);
+            // skip value
+            int valueSize = ByteUtils.readVarint(input);
+            skipBytes(input, valueSize);
 
-            // then skip header
-            int numHeaders = readVarInt(skipBuffer, input, bytesRemaining);
+            // skip header
+            int numHeaders = ByteUtils.readVarint(input);
             if (numHeaders < 0)
                 throw new InvalidRecordException("Found invalid number of record headers " + numHeaders);
             for (int i = 0; i < numHeaders; i++) {
-                int headerKeySize = skipLengthDelimitedField(skipBuffer, input, bytesRemaining);
+                int headerKeySize = ByteUtils.readVarint(input);
                 if (headerKeySize < 0)
                     throw new InvalidRecordException("Invalid negative header key size " + headerKeySize);
+                skipBytes(input, headerKeySize);
 
                 // headerValueSize
-                skipLengthDelimitedField(skipBuffer, input, bytesRemaining);
+                int headerValueSize = ByteUtils.readVarint(input);
+                skipBytes(input, headerValueSize);
             }
 
-            if (bytesRemaining.value > 0 || skipBuffer.remaining() > 0)
-                throw new InvalidRecordException("Invalid record size: expected to read " + sizeOfBodyInBytes +
-                    " bytes in record payload, but there are still bytes remaining");
-
             return new PartialDefaultRecord(sizeInBytes, attributes, offset, timestamp, sequence, keySize, valueSize);
         } catch (BufferUnderflowException | IllegalArgumentException e) {
             throw new InvalidRecordException("Found invalid record structure", e);
         }
     }
 
-    private static byte readByte(ByteBuffer buffer, DataInput input, IntRef bytesRemaining) throws IOException {
-        if (buffer.remaining() < 1 && bytesRemaining.value > 0) {
-            readMore(buffer, input, bytesRemaining);
-        }
-
-        return buffer.get();
-    }
-
-    private static long readVarLong(ByteBuffer buffer, DataInput input, IntRef bytesRemaining) throws IOException {
-        if (buffer.remaining() < 10 && bytesRemaining.value > 0) {
-            readMore(buffer, input, bytesRemaining);
-        }
-
-        return ByteUtils.readVarlong(buffer);
-    }
-
-    private static int readVarInt(ByteBuffer buffer, DataInput input, IntRef bytesRemaining) throws IOException {
-        if (buffer.remaining() < 5 && bytesRemaining.value > 0) {
-            readMore(buffer, input, bytesRemaining);
-        }
-
-        return ByteUtils.readVarint(buffer);
-    }
-
-    private static int skipLengthDelimitedField(ByteBuffer buffer, DataInput input, IntRef bytesRemaining) throws IOException {
-        boolean needMore = false;
-        int sizeInBytes = -1;
-        int bytesToSkip = -1;
-
-        while (true) {
-            if (needMore) {
-                readMore(buffer, input, bytesRemaining);
-                needMore = false;
-            }
-
-            if (bytesToSkip < 0) {
-                if (buffer.remaining() < 5 && bytesRemaining.value > 0) {
-                    needMore = true;
-                } else {
-                    sizeInBytes = ByteUtils.readVarint(buffer);
-                    if (sizeInBytes <= 0)
-                        return sizeInBytes;
-                    else
-                        bytesToSkip = sizeInBytes;
 
+    /**
+     * Skips n bytes from the data input.

Review Comment:
   nit: Skips {@code bytesToSkip} bytes...
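
`readPartiallyFrom` now decodes its varints directly from the `InputStream` via `ByteUtils.readVarint(input)`. As a hypothetical stand-alone illustration (`VarintSketch` is not Kafka's `ByteUtils`), decoding a zigzag varint, the encoding record format v2 uses for these fields, straight from a stream looks roughly like this. One detail worth keeping in mind: `InputStream#read()` returns `-1` at end of stream, so a cast such as `(byte) input.read()` silently turns end of stream into the byte `0xFF`; the sketch checks for `-1` explicitly instead.

```java
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;

// Hypothetical helper (not Kafka's ByteUtils) decoding zigzag varints straight from a stream.
final class VarintSketch {
    private VarintSketch() {}

    // Unsigned base-128 varint, least-significant 7-bit group first, at most 5 bytes for 32 bits.
    static int readUnsignedVarint(InputStream in) throws IOException {
        int value = 0;
        for (int shift = 0; shift <= 28; shift += 7) {
            int b = in.read();
            if (b == -1) // end of stream must not be mistaken for a data byte
                throw new EOFException("Stream ended in the middle of a varint");
            value |= (b & 0x7f) << shift;
            if ((b & 0x80) == 0) // high bit clear: this was the last byte
                return value;
        }
        throw new IllegalArgumentException("Varint is longer than 5 bytes");
    }

    // Zigzag decoding maps 0, 1, 2, 3, ... back to 0, -1, 1, -2, ...
    static int readVarint(InputStream in) throws IOException {
        int raw = readUnsignedVarint(in);
        return (raw >>> 1) ^ -(raw & 1);
    }
}
```

For example, the two bytes `0xAC 0x02` decode to the unsigned value 300, which zigzag-decodes to 150, while the single byte `0x01` decodes to -1, the length used for a `null` key or value.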



##########
clients/src/main/java/org/apache/kafka/common/record/DefaultRecord.java:
##########
@@ -356,164 +346,100 @@ private static DefaultRecord readFrom(ByteBuffer buffer,
                 throw new InvalidRecordException("Invalid record size: expected to read " + sizeOfBodyInBytes +
                         " bytes in record payload, but instead read " + (buffer.position() - recordStart));
 
-            return new DefaultRecord(sizeInBytes, attributes, offset, timestamp, sequence, key, value, headers);
+            int totalSizeInBytes = ByteUtils.sizeOfVarint(sizeOfBodyInBytes) + sizeOfBodyInBytes;
+            return new DefaultRecord(totalSizeInBytes, attributes, offset, timestamp, sequence, key, value, headers);
         } catch (BufferUnderflowException | IllegalArgumentException e) {
             throw new InvalidRecordException("Found invalid record structure", e);
         }
     }
 
-    public static PartialDefaultRecord readPartiallyFrom(DataInput input,
-                                                         byte[] skipArray,
+    public static PartialDefaultRecord readPartiallyFrom(InputStream input,
                                                          long baseOffset,
                                                          long baseTimestamp,
                                                          int baseSequence,
                                                          Long logAppendTime) throws IOException {
         int sizeOfBodyInBytes = ByteUtils.readVarint(input);
         int totalSizeInBytes = ByteUtils.sizeOfVarint(sizeOfBodyInBytes) + sizeOfBodyInBytes;
 
-        return readPartiallyFrom(input, skipArray, totalSizeInBytes, sizeOfBodyInBytes, baseOffset, baseTimestamp,
+        return readPartiallyFrom(input, totalSizeInBytes, baseOffset, baseTimestamp,
             baseSequence, logAppendTime);
     }
 
-    private static PartialDefaultRecord readPartiallyFrom(DataInput input,
-                                                          byte[] skipArray,
+    private static PartialDefaultRecord readPartiallyFrom(InputStream input,
                                                           int sizeInBytes,
-                                                          int sizeOfBodyInBytes,
                                                           long baseOffset,
                                                           long baseTimestamp,
                                                           int baseSequence,
                                                           Long logAppendTime) throws IOException {
-        ByteBuffer skipBuffer = ByteBuffer.wrap(skipArray);
-        // set its limit to 0 to indicate no bytes readable yet
-        skipBuffer.limit(0);
-
         try {
-            // reading the attributes / timestamp / offset and key-size does not require
-            // any byte array allocation and therefore we can just read them straight-forwardly
-            IntRef bytesRemaining = PrimitiveRef.ofInt(sizeOfBodyInBytes);
-
-            byte attributes = readByte(skipBuffer, input, bytesRemaining);
-            long timestampDelta = readVarLong(skipBuffer, input, bytesRemaining);
+            byte attributes = (byte) input.read();
+            long timestampDelta = ByteUtils.readVarlong(input);
             long timestamp = baseTimestamp + timestampDelta;
             if (logAppendTime != null)
                 timestamp = logAppendTime;
 
-            int offsetDelta = readVarInt(skipBuffer, input, bytesRemaining);
+            int offsetDelta = ByteUtils.readVarint(input);
             long offset = baseOffset + offsetDelta;
             int sequence = baseSequence >= 0 ?
                 DefaultRecordBatch.incrementSequence(baseSequence, offsetDelta) :
                 RecordBatch.NO_SEQUENCE;
 
-            // first skip key
-            int keySize = skipLengthDelimitedField(skipBuffer, input, bytesRemaining);
+            // skip key
+            int keySize = ByteUtils.readVarint(input);
+            skipBytes(input, keySize);
 
-            // then skip value
-            int valueSize = skipLengthDelimitedField(skipBuffer, input, bytesRemaining);
+            // skip value
+            int valueSize = ByteUtils.readVarint(input);
+            skipBytes(input, valueSize);
 
-            // then skip header
-            int numHeaders = readVarInt(skipBuffer, input, bytesRemaining);
+            // skip header
+            int numHeaders = ByteUtils.readVarint(input);
             if (numHeaders < 0)
                 throw new InvalidRecordException("Found invalid number of record headers " + numHeaders);
             for (int i = 0; i < numHeaders; i++) {
-                int headerKeySize = skipLengthDelimitedField(skipBuffer, input, bytesRemaining);
+                int headerKeySize = ByteUtils.readVarint(input);
                 if (headerKeySize < 0)
                     throw new InvalidRecordException("Invalid negative header key size " + headerKeySize);
+                skipBytes(input, headerKeySize);
 
                 // headerValueSize
-                skipLengthDelimitedField(skipBuffer, input, bytesRemaining);
+                int headerValueSize = ByteUtils.readVarint(input);
+                skipBytes(input, headerValueSize);
             }
 
-            if (bytesRemaining.value > 0 || skipBuffer.remaining() > 0)
-                throw new InvalidRecordException("Invalid record size: expected to read " + sizeOfBodyInBytes +
-                    " bytes in record payload, but there are still bytes remaining");
-
             return new PartialDefaultRecord(sizeInBytes, attributes, offset, timestamp, sequence, keySize, valueSize);
         } catch (BufferUnderflowException | IllegalArgumentException e) {
             throw new InvalidRecordException("Found invalid record structure", e);
         }
     }
 
-    private static byte readByte(ByteBuffer buffer, DataInput input, IntRef bytesRemaining) throws IOException {
-        if (buffer.remaining() < 1 && bytesRemaining.value > 0) {
-            readMore(buffer, input, bytesRemaining);
-        }
-
-        return buffer.get();
-    }
-
-    private static long readVarLong(ByteBuffer buffer, DataInput input, IntRef bytesRemaining) throws IOException {
-        if (buffer.remaining() < 10 && bytesRemaining.value > 0) {
-            readMore(buffer, input, bytesRemaining);
-        }
-
-        return ByteUtils.readVarlong(buffer);
-    }
-
-    private static int readVarInt(ByteBuffer buffer, DataInput input, IntRef bytesRemaining) throws IOException {
-        if (buffer.remaining() < 5 && bytesRemaining.value > 0) {
-            readMore(buffer, input, bytesRemaining);
-        }
-
-        return ByteUtils.readVarint(buffer);
-    }
-
-    private static int skipLengthDelimitedField(ByteBuffer buffer, DataInput input, IntRef bytesRemaining) throws IOException {
-        boolean needMore = false;
-        int sizeInBytes = -1;
-        int bytesToSkip = -1;
-
-        while (true) {
-            if (needMore) {
-                readMore(buffer, input, bytesRemaining);
-                needMore = false;
-            }
-
-            if (bytesToSkip < 0) {
-                if (buffer.remaining() < 5 && bytesRemaining.value > 0) {
-                    needMore = true;
-                } else {
-                    sizeInBytes = ByteUtils.readVarint(buffer);
-                    if (sizeInBytes <= 0)
-                        return sizeInBytes;
-                    else
-                        bytesToSkip = sizeInBytes;
 
+    /**
+     * Skips n bytes from the data input.
+     *
+     * No-op for case where bytesToSkip <= 0. This could occur for cases where field is expected to be null.
+     * @throws  InvalidRecordException if the number of bytes could not be skipped.
+     */
+    private static void skipBytes(InputStream in, int bytesToSkip) throws IOException {
+        if (bytesToSkip <= 0) return;
+
+        // Starting JDK 12, this implementation could be replaced by InputStream#skipNBytes
+        while (bytesToSkip > 0) {
+            long ns = in.skip(bytesToSkip);
+            if (ns > 0 && ns <= bytesToSkip) {
+                // adjust number to skip
+                bytesToSkip -= ns;
+            } else if (ns == 0) { // no bytes skipped
+                // read one byte to check for EOS
+                if (in.read() == -1) {

Review Comment:
   What should we do if `ns == 0`, but not reaching EOS? Should we also throw an exception in this case?
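
One possible answer, offered as a sketch rather than the PR's code: treat `skip()` returning 0 as "no progress" and probe with a single `read()`. If that returns `-1` the stream is exhausted; otherwise the byte just read counts as skipped and the loop continues. This matches the documented contract of `InputStream#skipNBytes` (JDK 12+), which throws `EOFException` when end of stream is reached before the requested number of bytes has been skipped. `SkipSketch` is a hypothetical name, and `EOFException` stands in for the `InvalidRecordException` that `skipBytes` documents.

```java
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;

// Hypothetical sketch of skipBytes; not the code in the PR.
final class SkipSketch {
    private SkipSketch() {}

    static void skipBytes(InputStream in, int bytesToSkip) throws IOException {
        if (bytesToSkip <= 0)
            return; // null fields carry length -1 and no payload
        long remaining = bytesToSkip;
        while (remaining > 0) {
            long ns = in.skip(remaining);
            if (ns > 0 && ns <= remaining) {
                remaining -= ns;
            } else if (ns == 0) {
                // skip() made no progress: read one byte to tell end of stream apart from a lazy skip()
                if (in.read() == -1)
                    throw new EOFException("Reached end of stream with " + remaining + " bytes left to skip");
                remaining--; // the byte just read counts as skipped
            } else {
                // negative, or more than requested: the stream broke the skip() contract
                throw new IOException("Unexpected skip() result: " + ns);
            }
        }
    }
}
```

With this shape, `ns == 0` without end of stream is not treated as an error; it falls back to reading one byte at a time, so a stream whose `skip()` never advances is still handled, only more slowly.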



##########
clients/src/test/java/org/apache/kafka/common/record/DefaultRecordBatchTest.java:
##########
@@ -367,32 +403,164 @@ public void testStreamingIteratorConsistency() {
         }
     }
 
-    @Test
-    public void testSkipKeyValueIteratorCorrectness() {
-        Header[] headers = {new RecordHeader("k1", "v1".getBytes()), new RecordHeader("k2", "v2".getBytes())};
+    @ParameterizedTest
+    @EnumSource(value = CompressionType.class)
+    public void testSkipKeyValueIteratorCorrectness(CompressionType compressionType) throws NoSuchAlgorithmException {
+        Header[] headers = {new RecordHeader("k1", "v1".getBytes()), new RecordHeader("k2", null)};
+        byte[] largeRecordValue = new byte[200 * 1024]; // 200KB
+        RANDOM.nextBytes(largeRecordValue);
 
         MemoryRecords records = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V2, 0L,
-            CompressionType.LZ4, TimestampType.CREATE_TIME,
+            compressionType, TimestampType.CREATE_TIME,
+            // one sample with small value size
             new SimpleRecord(1L, "a".getBytes(), "1".getBytes()),
-            new SimpleRecord(2L, "b".getBytes(), "2".getBytes()),
-            new SimpleRecord(3L, "c".getBytes(), "3".getBytes()),
-            new SimpleRecord(1000L, "abc".getBytes(), "0".getBytes()),
+            // one sample with null value
+            new SimpleRecord(2L, "b".getBytes(), null),
+            // one sample with null key
+            new SimpleRecord(3L, null, "3".getBytes()),
+            // one sample with null key and null value
+            new SimpleRecord(4L, null, (byte[]) null),
+            // one sample with large value size
+            new SimpleRecord(1000L, "abc".getBytes(), largeRecordValue),
+            // one sample with headers, one of the header has null value
             new SimpleRecord(9999L, "abc".getBytes(), "0".getBytes(), headers)
             );
+
         DefaultRecordBatch batch = new DefaultRecordBatch(records.buffer());
-        try (CloseableIterator<Record> streamingIterator = batch.skipKeyValueIterator(BufferSupplier.NO_CACHING)) {
-            assertEquals(Arrays.asList(
-                new PartialDefaultRecord(9, (byte) 0, 0L, 1L, -1, 1, 1),
-                new PartialDefaultRecord(9, (byte) 0, 1L, 2L, -1, 1, 1),
-                new PartialDefaultRecord(9, (byte) 0, 2L, 3L, -1, 1, 1),
-                new PartialDefaultRecord(12, (byte) 0, 3L, 1000L, -1, 3, 1),
-                new PartialDefaultRecord(25, (byte) 0, 4L, 9999L, -1, 3, 1)
-                ),
-                Utils.toList(streamingIterator)
-            );
+
+        try (BufferSupplier bufferSupplier = BufferSupplier.create();
+             CloseableIterator<Record> skipKeyValueIterator = batch.skipKeyValueIterator(bufferSupplier)) {
+
+            if (CompressionType.NONE == compressionType) {
+                // assert that for uncompressed data stream record iterator is not used
+                assertTrue(skipKeyValueIterator instanceof DefaultRecordBatch.RecordIterator);
+                // superficial validation for correctness. Deep validation is already performed in other tests
+                assertEquals(Utils.toList(records.records()).size(), Utils.toList(skipKeyValueIterator).size());
+            } else {
+                // assert that a streaming iterator is used for compressed records
+                assertTrue(skipKeyValueIterator instanceof DefaultRecordBatch.StreamRecordIterator);
+                // assert correctness for compressed records
+                assertIterableEquals(Arrays.asList(
+                        new PartialDefaultRecord(9, (byte) 0, 0L, 1L, -1, 1, 1),
+                        new PartialDefaultRecord(8, (byte) 0, 1L, 2L, -1, 1, -1),
+                        new PartialDefaultRecord(8, (byte) 0, 2L, 3L, -1, -1, 1),
+                        new PartialDefaultRecord(7, (byte) 0, 3L, 4L, -1, -1, -1),
+                        new PartialDefaultRecord(15 + largeRecordValue.length, (byte) 0, 4L, 1000L, -1, 3, largeRecordValue.length),
+                        new PartialDefaultRecord(23, (byte) 0, 5L, 9999L, -1, 3, 1)
+                    ), Utils.toList(skipKeyValueIterator));
+            }
+        }
+    }
+
+    @ParameterizedTest
+    @MethodSource
+    public void testBufferReuseInSkipKeyValueIterator(CompressionType compressionType, int expectedNumBufferAllocations, byte[] recordValue) {
+        MemoryRecords records = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V2, 0L,
+            compressionType, TimestampType.CREATE_TIME,
+            new SimpleRecord(1000L, "a".getBytes(), "0".getBytes()),
+            new SimpleRecord(9999L, "b".getBytes(), recordValue)
+        );
+
+        DefaultRecordBatch batch = new DefaultRecordBatch(records.buffer());
+
+        try (BufferSupplier bufferSupplier = spy(BufferSupplier.create());
+             CloseableIterator<Record> streamingIterator = batch.skipKeyValueIterator(bufferSupplier)) {
+
+            // Consume through the iterator
+            Utils.toList(streamingIterator);
+
+            // Close the iterator to release any buffers
+            streamingIterator.close();
+
+            // assert number of buffer allocations
+            verify(bufferSupplier, times(expectedNumBufferAllocations)).get(anyInt());
+            verify(bufferSupplier, times(expectedNumBufferAllocations)).release(any(ByteBuffer.class));
+        }
+    }
+    private static Stream<Arguments> testBufferReuseInSkipKeyValueIterator() throws NoSuchAlgorithmException {
+        byte[] smallRecordValue = "1".getBytes();
+        byte[] largeRecordValue = new byte[512 * 1024]; // 512KB
+        RANDOM.nextBytes(largeRecordValue);
+
+        return Stream.of(
+            /*
+             * 1 allocation per batch (i.e. per iterator instance) for buffer holding uncompressed data
+             * = 1 buffer allocations
+             */
+            Arguments.of(CompressionType.GZIP, 1, smallRecordValue),
+            Arguments.of(CompressionType.GZIP, 1, largeRecordValue),
+            Arguments.of(CompressionType.SNAPPY, 1, smallRecordValue),
+            Arguments.of(CompressionType.SNAPPY, 1, largeRecordValue),
+            /*
+             * 1 allocation per batch (i.e. per iterator instance) for buffer holding compressed data
+             * 1 allocation per batch (i.e. per iterator instance) for buffer holding uncompressed data
+             * = 2 buffer allocations
+             */
+            Arguments.of(CompressionType.LZ4, 2, smallRecordValue),
+            Arguments.of(CompressionType.LZ4, 2, largeRecordValue),
+            Arguments.of(CompressionType.ZSTD, 2, smallRecordValue),
+            Arguments.of(CompressionType.ZSTD, 2, largeRecordValue)
+        );
+    }
+
+    @ParameterizedTest
+    @MethodSource
+    public void testZstdJniForSkipKeyValueIterator(int expectedJniCalls, byte[] recordValue) throws IOException {
+        MemoryRecords records = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V2, 0L,
+            CompressionType.ZSTD, TimestampType.CREATE_TIME,
+            new SimpleRecord(9L, "hakuna-matata".getBytes(), recordValue)
+        );
+
+        // Buffer containing compressed data
+        final ByteBuffer compressedBuf = records.buffer();
+        // Create a RecordBatch object
+        final DefaultRecordBatch batch = spy(new DefaultRecordBatch(compressedBuf.duplicate()));
+        final CompressionType mockCompression = mock(CompressionType.ZSTD.getClass());
+        doReturn(mockCompression).when(batch).compressionType();
+
+        // Buffer containing compressed records to be used for creating zstd-jni stream
+        ByteBuffer recordsBuffer = compressedBuf.duplicate();
+        recordsBuffer.position(RECORDS_OFFSET);
+
+        try (final BufferSupplier bufferSupplier = BufferSupplier.create();
+             final InputStream zstdStream = spy(ZstdFactory.wrapForInput(recordsBuffer, batch.magic(), bufferSupplier));
+             final InputStream chunkedStream = new ChunkedBytesStream(zstdStream, bufferSupplier, 16 * 1024, false)) {
+
+            when(mockCompression.wrapForInput(any(ByteBuffer.class), anyByte(), any(BufferSupplier.class))).thenReturn(chunkedStream);
+
+            try (CloseableIterator<Record> streamingIterator = batch.skipKeyValueIterator(bufferSupplier)) {
+                assertNotNull(streamingIterator);
+                Utils.toList(streamingIterator);
+                // verify the number of read() calls to zstd JNI stream. Each read() call is a JNI call.
+                verify(zstdStream, times(expectedJniCalls)).read(any(byte[].class), anyInt(), anyInt());
+                // verify that we don't use the underlying skip() functionality. The underlying skip() allocates
+                // 1 buffer per skip call from he buffer pool whereas our implementation

Review Comment:
   whereas our implementation?
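
As a cross-check on the expected sizes in `testSkipKeyValueIteratorCorrectness`, the numbers can be recomputed from the record layout that the new `readPartiallyFrom` walks: a varint body length, then 1 attributes byte, a varlong timestamp delta, a varint offset delta, length-prefixed key and value, a varint header count, and length-prefixed header keys and values, with `null` written as length `-1` and no payload. The helper below is hypothetical (`RecordSizeSketch` is not a Kafka class) and assumes the batch's base timestamp equals the first record's timestamp (1), so the timestamp deltas are 0, 1, 2, 3, 999 and 9998.

```java
// Hypothetical helper recomputing record sizes for the v2 record layout; not a Kafka class.
final class RecordSizeSketch {
    private RecordSizeSketch() {}

    // Bytes needed for a zigzag-encoded 32-bit varint.
    static int sizeOfVarint(int value) {
        int v = (value << 1) ^ (value >> 31); // zigzag encode
        int bytes = 1;
        while ((v & 0xffffff80) != 0) {
            bytes++;
            v >>>= 7;
        }
        return bytes;
    }

    // Bytes needed for a zigzag-encoded 64-bit varint.
    static int sizeOfVarlong(long value) {
        long v = (value << 1) ^ (value >> 63); // zigzag encode
        int bytes = 1;
        while ((v & 0xffffffffffffff80L) != 0L) {
            bytes++;
            v >>>= 7;
        }
        return bytes;
    }

    // Length-prefixed field: length -1 and no payload for null.
    static int sizeOfField(byte[] field) {
        return field == null ? sizeOfVarint(-1) : sizeOfVarint(field.length) + field.length;
    }

    // headers are passed as alternating key/value arrays, e.g. k1, v1, k2, null
    static int totalSize(long timestampDelta, int offsetDelta, byte[] key, byte[] value, byte[]... headers) {
        int body = 1 // attributes
            + sizeOfVarlong(timestampDelta)
            + sizeOfVarint(offsetDelta)
            + sizeOfField(key)
            + sizeOfField(value)
            + sizeOfVarint(headers.length / 2);
        for (byte[] h : headers)
            body += sizeOfField(h);
        return sizeOfVarint(body) + body; // the body length prefix is counted in the total
    }
}
```

For example, the `9999L` record comes to 23 bytes: a 22-byte body (1 attributes, 3 timestamp delta, 1 offset delta, 4 key, 2 value, 1 header count, 10 for the two headers) plus a 1-byte length prefix.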



##########
clients/src/main/java/org/apache/kafka/common/utils/ChunkedBytesStream.java:
##########
@@ -0,0 +1,357 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.utils;
+
+import java.io.BufferedInputStream;
+import java.io.FilterInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+
+/**
+ * ChunkedBytesStream is a copy of {@link ByteBufferInputStream} with the following differences:

Review Comment:
   a copy of `ByteBufferInputStream`? `BufferedInputStream`?
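
On the second difference listed in the class Javadoc, obtaining the intermediate buffer from a `BufferSupplier` instead of allocating it the way `BufferedInputStream` does, a tiny caching pool illustrates what `testBufferReuseInSkipKeyValueIterator` is counting: a buffer handed back through `release` can be returned by the next `get` of the same size rather than reallocated. `CachingBufferPool` is hypothetical and single-threaded; it is not Kafka's `BufferSupplier`. It returns heap buffers, so they pass the constructor's `hasArray()` / `arrayOffset() == 0` check.

```java
import java.nio.ByteBuffer;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashMap;
import java.util.Map;

// Hypothetical single-threaded pool keyed by capacity; not Kafka's BufferSupplier.
final class CachingBufferPool {
    private final Map<Integer, Deque<ByteBuffer>> free = new HashMap<>();
    private int allocations = 0;

    // Returns a cleared heap buffer of exactly `capacity` bytes, reusing a released one if available.
    ByteBuffer get(int capacity) {
        Deque<ByteBuffer> cached = free.get(capacity);
        if (cached != null && !cached.isEmpty())
            return cached.pollFirst();
        allocations++;
        return ByteBuffer.allocate(capacity); // heap buffer: hasArray() is true, arrayOffset() is 0
    }

    // Makes a buffer available to the next get() with the same capacity.
    void release(ByteBuffer buffer) {
        buffer.clear();
        free.computeIfAbsent(buffer.capacity(), k -> new ArrayDeque<>()).addLast(buffer);
    }

    int allocations() {
        return allocations;
    }
}
```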



##########
clients/src/main/java/org/apache/kafka/common/utils/ChunkedBytesStream.java:
##########
@@ -0,0 +1,357 @@
+/**
+ * ChunkedBytesStream is a copy of {@link ByteBufferInputStream} with the following differences:
+ * - Unlike {@link java.io.BufferedInputStream#skip(long)} this class could be configured to not push skip() to
+ * input stream. We may want to avoid pushing this to input stream because it's implementation maybe inefficient,
+ * e.g. the case of ZstdInputStream which allocates a new buffer from buffer pool, per skip call.
+ * - Unlike {@link java.io.BufferedInputStream}, which allocates an intermediate buffer, this uses a buffer supplier to
+ * create the intermediate buffer.
+ * <p>
+ * Note that:
+ * - this class is not thread safe and shouldn't be used in scenarios where multiple threads access this.
+ * - the implementation of this class is performance sensitive. Minor changes such as usage of ByteBuffer instead of byte[]
+ * can significantly impact performance, hence, proceed with caution.
+ */
+public class ChunkedBytesStream extends FilterInputStream {
+    /**
+     * Supplies the ByteBuffer which is used as intermediate buffer to store the chunk of output data.
+     */
+    private final BufferSupplier bufferSupplier;
+    /**
+     * Intermediate buffer to store the chunk of output data. The ChunkedBytesStream is considered closed if
+     * this buffer is null.
+     */
+    private byte[] intermediateBuf;
+    /**
+     * The index one greater than the index of the last valid byte in
+     * the buffer.
+     * This value is always in the range <code>0</code> through <code>intermediateBuf.length</code>;
+     * elements <code>intermediateBuf[0]</code>  through <code>intermediateBuf[count-1]
+     * </code>contain buffered input data obtained
+     * from the underlying  input stream.
+     */
+    protected int count = 0;
+    /**
+     * The current position in the buffer. This is the index of the next
+     * character to be read from the <code>buf</code> array.
+     * <p>
+     * This value is always in the range <code>0</code>
+     * through <code>count</code>. If it is less
+     * than <code>count</code>, then  <code>intermediateBuf[pos]</code>
+     * is the next byte to be supplied as input;
+     * if it is equal to <code>count</code>, then
+     * the  next <code>read</code> or <code>skip</code>
+     * operation will require more bytes to be
+     * read from the contained  input stream.
+     */
+    protected int pos = 0;
+    /**
+     * Reference for the intermediate buffer. This reference is only kept for releasing the buffer from the
+     * buffer supplier.
+     */
+    private final ByteBuffer intermediateBufRef;
+    /**
+     * Determines if the skip be pushed down
+     */
+    private final boolean pushSkipToSourceStream;

Review Comment:
   Could we add more explanation for this variable? It's quite difficult to
understand from this comment, and the variable name is also hard to follow.
My understanding is: if it's true, we don't read anything from the source
stream for the N bytes and instead push the skip down to it; if it's false,
we read the N bytes into the intermediate buffer and then skip over them.
Is that right?
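   To check my reading, here is a minimal sketch of the two behaviours as I understand them. `SkipSketchStream` is hypothetical. A plain `byte[]` stands in for the buffer obtained from `BufferSupplier`, and the logic is my assumption rather than the PR's code.

```java
import java.io.ByteArrayInputStream;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;

/**
 * Hypothetical sketch of the two skip() strategies discussed above; not the PR's implementation.
 * A plain byte[] stands in for the buffer that the PR obtains from a BufferSupplier.
 */
public class SkipSketchStream extends FilterInputStream {
    private final byte[] buf;
    private final boolean pushSkipToSourceStream;
    private int count = 0; // one past the last valid byte in buf
    private int pos = 0;   // index of the next byte to return from buf

    public SkipSketchStream(InputStream in, int bufSize, boolean pushSkipToSourceStream) {
        super(in);
        this.buf = new byte[bufSize];
        this.pushSkipToSourceStream = pushSkipToSourceStream;
    }

    // Refills buf from the source stream; returns false at end of stream.
    private boolean fill() throws IOException {
        int read = in.read(buf, 0, buf.length);
        pos = 0;
        count = Math.max(read, 0);
        return read > 0;
    }

    @Override
    public int read() throws IOException {
        if (pos >= count && !fill()) return -1;
        return buf[pos++] & 0xff;
    }

    @Override
    public int read(byte[] b, int off, int len) throws IOException {
        if (len == 0) return 0;
        if (pos >= count && !fill()) return -1;
        int n = Math.min(len, count - pos);
        System.arraycopy(buf, pos, b, off, n);
        pos += n;
        return n;
    }

    @Override
    public long skip(long n) throws IOException {
        if (n <= 0) return 0;
        long remaining = n;
        while (remaining > 0) {
            if (pos >= count) {
                if (pushSkipToSourceStream) {
                    // true: delegate the rest to the source's own skip(); nothing is read into buf.
                    return (n - remaining) + in.skip(remaining);
                }
                // false: read the next chunk into buf, then discard bytes from it.
                if (!fill()) break; // end of stream
            }
            int consumed = (int) Math.min(count - pos, remaining);
            pos += consumed;
            remaining -= consumed;
        }
        return n - remaining;
    }

    @Override
    public int available() throws IOException {
        return (count - pos) + in.available();
    }

    @Override
    public boolean markSupported() {
        return false;
    }

    public static void main(String[] args) throws IOException {
        byte[] data = new byte[100];
        for (int i = 0; i < data.length; i++) data[i] = (byte) i;
        for (boolean push : new boolean[] {true, false}) {
            SkipSketchStream s = new SkipSketchStream(new ByteArrayInputStream(data), 16, push);
            long skipped = s.skip(10);
            System.out.println("push=" + push + " skipped=" + skipped + " next=" + s.read());
        }
    }
}
```

   Both modes return the same bytes to the caller. They differ only in which source method is invoked. With `true`, the cost is whatever the source's `skip()` does (the javadoc's `ZstdInputStream` concern). With `false`, the skip is paid for with plain `read()` calls into the caller-owned buffer.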



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: jira-unsubscr...@kafka.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

