divijvaidya commented on code in PR #13135:
URL: https://github.com/apache/kafka/pull/13135#discussion_r1200358594
##########
clients/src/test/java/org/apache/kafka/common/record/DefaultRecordBatchTest.java:
##########
@@ -367,32 +403,164 @@ public void testStreamingIteratorConsistency() {
         }
     }
 
-    @Test
-    public void testSkipKeyValueIteratorCorrectness() {
-        Header[] headers = {new RecordHeader("k1", "v1".getBytes()), new RecordHeader("k2", "v2".getBytes())};
+    @ParameterizedTest
+    @EnumSource(value = CompressionType.class)
+    public void testSkipKeyValueIteratorCorrectness(CompressionType compressionType) throws NoSuchAlgorithmException {
+        Header[] headers = {new RecordHeader("k1", "v1".getBytes()), new RecordHeader("k2", null)};
+        byte[] largeRecordValue = new byte[200 * 1024]; // 200KB
+        RANDOM.nextBytes(largeRecordValue);
 
         MemoryRecords records = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V2, 0L,
-            CompressionType.LZ4, TimestampType.CREATE_TIME,
+            compressionType, TimestampType.CREATE_TIME,
+            // one sample with a small value size
             new SimpleRecord(1L, "a".getBytes(), "1".getBytes()),
-            new SimpleRecord(2L, "b".getBytes(), "2".getBytes()),
-            new SimpleRecord(3L, "c".getBytes(), "3".getBytes()),
-            new SimpleRecord(1000L, "abc".getBytes(), "0".getBytes()),
+            // one sample with a null value
+            new SimpleRecord(2L, "b".getBytes(), null),
+            // one sample with a null key
+            new SimpleRecord(3L, null, "3".getBytes()),
+            // one sample with a null key and a null value
+            new SimpleRecord(4L, null, (byte[]) null),
+            // one sample with a large value size
+            new SimpleRecord(1000L, "abc".getBytes(), largeRecordValue),
+            // one sample with headers, where one header has a null value
             new SimpleRecord(9999L, "abc".getBytes(), "0".getBytes(), headers)
         );
+
         DefaultRecordBatch batch = new DefaultRecordBatch(records.buffer());
-        try (CloseableIterator<Record> streamingIterator = batch.skipKeyValueIterator(BufferSupplier.NO_CACHING)) {
-            assertEquals(Arrays.asList(
-                new PartialDefaultRecord(9, (byte) 0, 0L, 1L, -1, 1, 1),
-                new PartialDefaultRecord(9, (byte) 0, 1L, 2L, -1, 1, 1),
-                new PartialDefaultRecord(9, (byte) 0, 2L, 3L, -1, 1, 1),
-                new PartialDefaultRecord(12, (byte) 0, 3L, 1000L, -1, 3, 1),
-                new PartialDefaultRecord(25, (byte) 0, 4L, 9999L, -1, 3, 1)
-            ),
-            Utils.toList(streamingIterator)
-            );
+
+        try (BufferSupplier bufferSupplier = BufferSupplier.create();
+             CloseableIterator<Record> skipKeyValueIterator = batch.skipKeyValueIterator(bufferSupplier)) {
+
+            if (CompressionType.NONE == compressionType) {
+                // assert that the streaming record iterator is not used for uncompressed data
+                assertTrue(skipKeyValueIterator instanceof DefaultRecordBatch.RecordIterator);
+                // superficial validation of correctness; deep validation is already performed in other tests
+                assertEquals(Utils.toList(records.records()).size(), Utils.toList(skipKeyValueIterator).size());
+            } else {
+                // assert that a streaming iterator is used for compressed records
+                assertTrue(skipKeyValueIterator instanceof DefaultRecordBatch.StreamRecordIterator);
+                // assert correctness for compressed records
+                assertIterableEquals(Arrays.asList(
+                    new PartialDefaultRecord(9, (byte) 0, 0L, 1L, -1, 1, 1),
+                    new PartialDefaultRecord(8, (byte) 0, 1L, 2L, -1, 1, -1),
+                    new PartialDefaultRecord(8, (byte) 0, 2L, 3L, -1, -1, 1),
+                    new PartialDefaultRecord(7, (byte) 0, 3L, 4L, -1, -1, -1),
+                    new PartialDefaultRecord(15 + largeRecordValue.length, (byte) 0, 4L, 1000L, -1, 3, largeRecordValue.length),
+                    new PartialDefaultRecord(23, (byte) 0, 5L, 9999L, -1, 3, 1)
+                ), Utils.toList(skipKeyValueIterator));
+            }
         }
     }
 
+    @ParameterizedTest
+    @MethodSource
+    public void testBufferReuseInSkipKeyValueIterator(CompressionType compressionType, int expectedNumBufferAllocations, byte[] recordValue) {
+        MemoryRecords records = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V2, 0L,
+            compressionType, TimestampType.CREATE_TIME,
+            new SimpleRecord(1000L, "a".getBytes(), "0".getBytes()),
+            new SimpleRecord(9999L, "b".getBytes(), recordValue)
+        );
+
+        DefaultRecordBatch batch = new DefaultRecordBatch(records.buffer());
+
+        try (BufferSupplier bufferSupplier = spy(BufferSupplier.create());
+             CloseableIterator<Record> streamingIterator = batch.skipKeyValueIterator(bufferSupplier)) {
+
+            // consume all records through the iterator
+            Utils.toList(streamingIterator);
+
+            // close the iterator to release any buffers back to the pool
+            streamingIterator.close();
+
+            // assert the number of buffer allocations
+            verify(bufferSupplier, times(expectedNumBufferAllocations)).get(anyInt());
+            verify(bufferSupplier, times(expectedNumBufferAllocations)).release(any(ByteBuffer.class));
+        }
+    }
+
+    private static Stream<Arguments> testBufferReuseInSkipKeyValueIterator() throws NoSuchAlgorithmException {
+        byte[] smallRecordValue = "1".getBytes();
+        byte[] largeRecordValue = new byte[512 * 1024]; // 512KB
+        RANDOM.nextBytes(largeRecordValue);
+
+        return Stream.of(
+            /*
+             * 1 allocation per batch (i.e. per iterator instance) for the buffer holding uncompressed data
+             * = 1 buffer allocation
+             */
+            Arguments.of(CompressionType.GZIP, 1, smallRecordValue),
+            Arguments.of(CompressionType.GZIP, 1, largeRecordValue),
+            Arguments.of(CompressionType.SNAPPY, 1, smallRecordValue),
+            Arguments.of(CompressionType.SNAPPY, 1, largeRecordValue),
+            /*
+             * 1 allocation per batch (i.e. per iterator instance) for the buffer holding compressed data
+             * 1 allocation per batch (i.e. per iterator instance) for the buffer holding uncompressed data
+             * = 2 buffer allocations
+             */
+            Arguments.of(CompressionType.LZ4, 2, smallRecordValue),
+            Arguments.of(CompressionType.LZ4, 2, largeRecordValue),
+            Arguments.of(CompressionType.ZSTD, 2, smallRecordValue),
+            Arguments.of(CompressionType.ZSTD, 2, largeRecordValue)
+        );
+    }
+
+    @ParameterizedTest
+    @MethodSource
+    public void testZstdJniForSkipKeyValueIterator(int expectedJniCalls, byte[] recordValue) throws IOException {
+        MemoryRecords records = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V2, 0L,
+            CompressionType.ZSTD, TimestampType.CREATE_TIME,
+            new SimpleRecord(9L, "hakuna-matata".getBytes(), recordValue)
+        );
+
+        // buffer containing the compressed data
+        final ByteBuffer compressedBuf = records.buffer();
+        // create a RecordBatch object
+        final DefaultRecordBatch batch = spy(new DefaultRecordBatch(compressedBuf.duplicate()));
+        final CompressionType mockCompression = mock(CompressionType.ZSTD.getClass());
+        doReturn(mockCompression).when(batch).compressionType();
+
+        // buffer containing the compressed records, used to create the zstd-jni stream
+        ByteBuffer recordsBuffer = compressedBuf.duplicate();
+        recordsBuffer.position(RECORDS_OFFSET);
+
+        try (final BufferSupplier bufferSupplier = BufferSupplier.create();
+             final InputStream zstdStream = spy(ZstdFactory.wrapForInput(recordsBuffer, batch.magic(), bufferSupplier));
+             final InputStream chunkedStream = new ChunkedBytesStream(zstdStream, bufferSupplier, 16 * 1024, false)) {
+
+            when(mockCompression.wrapForInput(any(ByteBuffer.class), anyByte(), any(BufferSupplier.class))).thenReturn(chunkedStream);
+
+            try (CloseableIterator<Record> streamingIterator = batch.skipKeyValueIterator(bufferSupplier)) {
+                assertNotNull(streamingIterator);
+                Utils.toList(streamingIterator);
+                // verify the number of read() calls to the zstd-jni stream; each read() call is a JNI call
+                verify(zstdStream, times(expectedJniCalls)).read(any(byte[].class), anyInt(), anyInt());
+                // verify that we don't use the underlying skip() functionality; the underlying skip() allocates
+                // 1 buffer per skip call from the buffer pool, whereas our implementation

Review Comment:
   Fixed
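The allocation counts asserted in `testBufferReuseInSkipKeyValueIterator` rely on `BufferSupplier.create()` returning a caching supplier: a buffer released back to the pool can be handed out again by a later `get()` of the same size, so each iterator instance should trigger exactly one allocation per intermediate buffer. A minimal sketch of that behavior, assuming the pooling semantics of Kafka's `org.apache.kafka.common.utils.BufferSupplier` (the demo class name is hypothetical):

```java
import java.nio.ByteBuffer;

import org.apache.kafka.common.utils.BufferSupplier;

// Hypothetical demo class illustrating the buffer pooling the test counts on.
public class BufferSupplierReuseDemo {
    public static void main(String[] args) {
        BufferSupplier supplier = BufferSupplier.create();

        ByteBuffer first = supplier.get(1024); // first request: expected to allocate a fresh buffer
        supplier.release(first);               // return the buffer to the pool

        ByteBuffer second = supplier.get(1024); // same size: expected to come from the pool
        System.out.println(first == second);    // prints true when the cached buffer is reused

        supplier.release(second);
        supplier.close();
    }
}
```

Under that assumption, the provider's expectations line up with its own comments: GZIP and SNAPPY pull one pooled buffer (for the uncompressed data), while LZ4 and ZSTD pull a second one for the compressed input, hence the 1 vs. 2 allocation counts.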
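The JNI-call verification in `testZstdJniForSkipKeyValueIterator` captures the point of the change: each `read()` on the zstd-jni stream crosses the JNI boundary, so pulling data through a 16KB intermediate chunk keeps the number of native calls proportional to the batch size rather than to the number of record fields skipped. A rough stand-in sketch of that effect, using the JDK's `BufferedInputStream` in place of `ChunkedBytesStream` and a hypothetical counting wrapper in place of the zstd-jni stream:

```java
import java.io.BufferedInputStream;
import java.io.ByteArrayInputStream;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;

// Hypothetical demo: counts bulk read() calls on an inner stream to show how a
// chunked intermediate buffer amortizes per-call overhead (for zstd-jni, each
// such read() would be a JNI call).
public class ChunkedReadDemo {
    static final class CountingStream extends FilterInputStream {
        int bulkReads = 0;
        CountingStream(InputStream in) { super(in); }
        @Override
        public int read(byte[] b, int off, int len) throws IOException {
            bulkReads++; // in the real setup, each of these crosses the JNI boundary
            return super.read(b, off, len);
        }
    }

    public static void main(String[] args) throws IOException {
        byte[] uncompressed = new byte[64 * 1024]; // stands in for decompressed batch data
        CountingStream inner = new CountingStream(new ByteArrayInputStream(uncompressed));
        // 16KB intermediate buffer, mirroring the chunk size used in the test above
        InputStream chunked = new BufferedInputStream(inner, 16 * 1024);

        byte[] scratch = new byte[100]; // emulate many small reads, as record parsing does
        while (chunked.read(scratch) != -1) {
            // consume and discard
        }
        // ~64KB / 16KB = 4 chunk fills, plus one final read that signals end-of-stream
        System.out.println("bulk reads against inner stream: " + inner.bulkReads);
    }
}
```

With a 64KB payload and 16KB chunks, the inner stream sees about five bulk reads (four chunk fills plus the end-of-stream probe) regardless of how small the consumer's individual reads are; without the intermediate buffer, every small read would hit the inner stream directly.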