This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new f9b7ac2e92 GH-37841: [Java] Dictionary decoding not using the
compression factory from the ArrowReader (#38371)
f9b7ac2e92 is described below
commit f9b7ac2e922bceed8bab09b1e28d7261cbe8b41d
Author: Vibhatha Lakmal Abeykoon <[email protected]>
AuthorDate: Thu Feb 1 23:08:21 2024 +0530
GH-37841: [Java] Dictionary decoding not using the compression factory from
the ArrowReader (#38371)
### Rationale for this change
This PR addresses https://github.com/apache/arrow/issues/37841.
### What changes are included in this PR?
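Dictionary batches are now compressed on write using the codec configured on the `ArrowWriter`, and decompressed on read using the compression factory supplied to the `ArrowReader`.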
### Are these changes tested?
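Yes. `TestArrowReaderWriterWithCompression` now covers ZSTD round trips for both the file and stream formats, with and without a dictionary provider.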
### Are there any user-facing changes?
No.
* Closes: #37841
Lead-authored-by: Vibhatha Lakmal Abeykoon <[email protected]>
Co-authored-by: vibhatha <[email protected]>
Signed-off-by: David Li <[email protected]>
---
.../TestArrowReaderWriterWithCompression.java | 206 ++++++++++++++++++---
.../org/apache/arrow/vector/ipc/ArrowReader.java | 2 +-
.../org/apache/arrow/vector/ipc/ArrowWriter.java | 23 ++-
3 files changed, 201 insertions(+), 30 deletions(-)
diff --git
a/java/compression/src/test/java/org/apache/arrow/compression/TestArrowReaderWriterWithCompression.java
b/java/compression/src/test/java/org/apache/arrow/compression/TestArrowReaderWriterWithCompression.java
index 6104cb1a13..af28333746 100644
---
a/java/compression/src/test/java/org/apache/arrow/compression/TestArrowReaderWriterWithCompression.java
+++
b/java/compression/src/test/java/org/apache/arrow/compression/TestArrowReaderWriterWithCompression.java
@@ -18,7 +18,9 @@
package org.apache.arrow.compression;
import java.io.ByteArrayOutputStream;
+import java.io.IOException;
import java.nio.channels.Channels;
+import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
@@ -27,63 +29,223 @@ import java.util.Optional;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.GenerateSampleData;
+import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.compression.CompressionUtil;
import org.apache.arrow.vector.compression.NoCompressionCodec;
+import org.apache.arrow.vector.dictionary.Dictionary;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.ArrowFileReader;
import org.apache.arrow.vector.ipc.ArrowFileWriter;
+import org.apache.arrow.vector.ipc.ArrowStreamReader;
+import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.ipc.message.IpcOption;
import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel;
+import org.junit.After;
import org.junit.Assert;
-import org.junit.Test;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
public class TestArrowReaderWriterWithCompression {
- @Test
- public void testArrowFileZstdRoundTrip() throws Exception {
- // Prepare sample data
- final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
+ private BufferAllocator allocator;
+ private ByteArrayOutputStream out;
+ private VectorSchemaRoot root;
+
+ @BeforeEach
+ public void setup() {
+ if (allocator == null) {
+ allocator = new RootAllocator(Integer.MAX_VALUE);
+ }
+ out = new ByteArrayOutputStream();
+ root = null;
+ }
+
+ @After
+ public void tearDown() {
+ if (root != null) {
+ root.close();
+ }
+ if (allocator != null) {
+ allocator.close();
+ }
+ if (out != null) {
+ out.reset();
+ }
+
+ }
+
+ private void createAndWriteArrowFile(DictionaryProvider provider,
+ CompressionUtil.CodecType codecType) throws IOException {
List<Field> fields = new ArrayList<>();
fields.add(new Field("col", FieldType.notNullable(new ArrowType.Utf8()),
new ArrayList<>()));
- VectorSchemaRoot root = VectorSchemaRoot.create(new Schema(fields),
allocator);
+ root = VectorSchemaRoot.create(new Schema(fields), allocator);
+
final int rowCount = 10;
GenerateSampleData.generateTestData(root.getVector(0), rowCount);
root.setRowCount(rowCount);
- // Write an in-memory compressed arrow file
- ByteArrayOutputStream out = new ByteArrayOutputStream();
- try (final ArrowFileWriter writer =
- new ArrowFileWriter(root, null, Channels.newChannel(out), new
HashMap<>(),
- IpcOption.DEFAULT, CommonsCompressionFactory.INSTANCE,
CompressionUtil.CodecType.ZSTD, Optional.of(7))) {
+ try (final ArrowFileWriter writer = new ArrowFileWriter(root, provider,
Channels.newChannel(out),
+ new HashMap<>(), IpcOption.DEFAULT,
CommonsCompressionFactory.INSTANCE, codecType, Optional.of(7))) {
writer.start();
writer.writeBatch();
writer.end();
}
+ }
+
+ private void createAndWriteArrowStream(DictionaryProvider provider,
+ CompressionUtil.CodecType codecType)
throws IOException {
+ List<Field> fields = new ArrayList<>();
+ fields.add(new Field("col", FieldType.notNullable(new ArrowType.Utf8()),
new ArrayList<>()));
+ root = VectorSchemaRoot.create(new Schema(fields), allocator);
+
+ final int rowCount = 10;
+ GenerateSampleData.generateTestData(root.getVector(0), rowCount);
+ root.setRowCount(rowCount);
+
+ try (final ArrowStreamWriter writer = new ArrowStreamWriter(root,
provider, Channels.newChannel(out),
+ IpcOption.DEFAULT, CommonsCompressionFactory.INSTANCE, codecType,
Optional.of(7))) {
+ writer.start();
+ writer.writeBatch();
+ writer.end();
+ }
+ }
- // Read the in-memory compressed arrow file with CommonsCompressionFactory
provided
+ private Dictionary createDictionary(VarCharVector dictionaryVector) {
+ setVector(dictionaryVector,
+ "foo".getBytes(StandardCharsets.UTF_8),
+ "bar".getBytes(StandardCharsets.UTF_8),
+ "baz".getBytes(StandardCharsets.UTF_8));
+
+ return new Dictionary(dictionaryVector,
+ new DictionaryEncoding(/*id=*/1L, /*ordered=*/false,
/*indexType=*/null));
+ }
+
+ @Test
+ public void testArrowFileZstdRoundTrip() throws Exception {
+ createAndWriteArrowFile(null, CompressionUtil.CodecType.ZSTD);
+ // with compression
+ try (ArrowFileReader reader =
+ new ArrowFileReader(new
ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
+ CommonsCompressionFactory.INSTANCE)) {
+ Assertions.assertEquals(1, reader.getRecordBlocks().size());
+ Assertions.assertTrue(reader.loadNextBatch());
+ Assertions.assertTrue(root.equals(reader.getVectorSchemaRoot()));
+ Assertions.assertFalse(reader.loadNextBatch());
+ }
+ // without compression
try (ArrowFileReader reader =
- new ArrowFileReader(new
ByteArrayReadableSeekableByteChannel(out.toByteArray()),
- allocator, CommonsCompressionFactory.INSTANCE)) {
- Assert.assertEquals(1, reader.getRecordBlocks().size());
+ new ArrowFileReader(new
ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
+ NoCompressionCodec.Factory.INSTANCE)) {
+ Assertions.assertEquals(1, reader.getRecordBlocks().size());
+ Exception exception = Assert.assertThrows(IllegalArgumentException.class,
+ reader::loadNextBatch);
+ Assertions.assertEquals("Please add arrow-compression module to use
CommonsCompressionFactory for ZSTD",
+ exception.getMessage());
+ }
+ }
+
+ @Test
+ public void testArrowStreamZstdRoundTrip() throws Exception {
+ createAndWriteArrowStream(null, CompressionUtil.CodecType.ZSTD);
+ // with compression
+ try (ArrowStreamReader reader =
+ new ArrowStreamReader(new
ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
+ CommonsCompressionFactory.INSTANCE)) {
Assert.assertTrue(reader.loadNextBatch());
Assert.assertTrue(root.equals(reader.getVectorSchemaRoot()));
Assert.assertFalse(reader.loadNextBatch());
}
+ // without compression
+ try (ArrowStreamReader reader =
+ new ArrowStreamReader(new
ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
+ NoCompressionCodec.Factory.INSTANCE)) {
+ Exception exception = Assert.assertThrows(IllegalArgumentException.class,
+ reader::loadNextBatch);
+ Assert.assertEquals(
+ "Please add arrow-compression module to use
CommonsCompressionFactory for ZSTD",
+ exception.getMessage()
+ );
+ }
+ }
- // Read the in-memory compressed arrow file without CompressionFactory
provided
+ @Test
+ public void testArrowFileZstdRoundTripWithDictionary() throws Exception {
+ VarCharVector dictionaryVector = (VarCharVector)
+ FieldType.nullable(new
ArrowType.Utf8()).createNewSingleVector("f1_file", allocator, null);
+ Dictionary dictionary = createDictionary(dictionaryVector);
+ DictionaryProvider.MapDictionaryProvider provider = new
DictionaryProvider.MapDictionaryProvider();
+ provider.put(dictionary);
+
+ createAndWriteArrowFile(provider, CompressionUtil.CodecType.ZSTD);
+
+ // with compression
+ try (ArrowFileReader reader =
+ new ArrowFileReader(new
ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
+ CommonsCompressionFactory.INSTANCE)) {
+ Assertions.assertEquals(1, reader.getRecordBlocks().size());
+ Assertions.assertTrue(reader.loadNextBatch());
+ Assertions.assertTrue(root.equals(reader.getVectorSchemaRoot()));
+ Assertions.assertFalse(reader.loadNextBatch());
+ }
+ // without compression
try (ArrowFileReader reader =
- new ArrowFileReader(new
ByteArrayReadableSeekableByteChannel(out.toByteArray()),
- allocator, NoCompressionCodec.Factory.INSTANCE)) {
- Assert.assertEquals(1, reader.getRecordBlocks().size());
+ new ArrowFileReader(new
ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
+ NoCompressionCodec.Factory.INSTANCE)) {
+ Assertions.assertEquals(1, reader.getRecordBlocks().size());
+ Exception exception = Assert.assertThrows(IllegalArgumentException.class,
+ reader::loadNextBatch);
+ Assertions.assertEquals("Please add arrow-compression module to use
CommonsCompressionFactory for ZSTD",
+ exception.getMessage());
+ }
+ dictionaryVector.close();
+ }
+
+ @Test
+ public void testArrowStreamZstdRoundTripWithDictionary() throws Exception {
+ VarCharVector dictionaryVector = (VarCharVector)
+ FieldType.nullable(new
ArrowType.Utf8()).createNewSingleVector("f1_stream", allocator, null);
+ Dictionary dictionary = createDictionary(dictionaryVector);
+ DictionaryProvider.MapDictionaryProvider provider = new
DictionaryProvider.MapDictionaryProvider();
+ provider.put(dictionary);
+
+ createAndWriteArrowStream(provider, CompressionUtil.CodecType.ZSTD);
+
+ // with compression
+ try (ArrowStreamReader reader =
+ new ArrowStreamReader(new
ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
+ CommonsCompressionFactory.INSTANCE)) {
+ Assertions.assertTrue(reader.loadNextBatch());
+ Assertions.assertTrue(root.equals(reader.getVectorSchemaRoot()));
+ Assertions.assertFalse(reader.loadNextBatch());
+ }
+ // without compression
+ try (ArrowStreamReader reader =
+ new ArrowStreamReader(new
ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator,
+ NoCompressionCodec.Factory.INSTANCE)) {
+ Exception exception = Assert.assertThrows(IllegalArgumentException.class,
+ reader::loadNextBatch);
+ Assertions.assertEquals("Please add arrow-compression module to use
CommonsCompressionFactory for ZSTD",
+ exception.getMessage());
+ }
+ dictionaryVector.close();
+ }
- Exception exception =
Assert.assertThrows(IllegalArgumentException.class, () ->
reader.loadNextBatch());
- String expectedMessage = "Please add arrow-compression module to use
CommonsCompressionFactory for ZSTD";
- Assert.assertEquals(expectedMessage, exception.getMessage());
+ public static void setVector(VarCharVector vector, byte[]... values) {
+ final int length = values.length;
+ vector.allocateNewSafe();
+ for (int i = 0; i < length; i++) {
+ if (values[i] != null) {
+ vector.set(i, values[i]);
+ }
}
+ vector.setValueCount(length);
}
}
diff --git
a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowReader.java
b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowReader.java
index 04c57d7e82..01f4e925c6 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowReader.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowReader.java
@@ -251,7 +251,7 @@ public abstract class ArrowReader implements
DictionaryProvider, AutoCloseable {
VectorSchemaRoot root = new VectorSchemaRoot(
Collections.singletonList(vector.getField()),
Collections.singletonList(vector), 0);
- VectorLoader loader = new VectorLoader(root);
+ VectorLoader loader = new VectorLoader(root, this.compressionFactory);
try {
loader.load(dictionaryBatch.getDictionary());
} finally {
diff --git
a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java
b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java
index a33c55de53..1cc201ae56 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java
@@ -61,9 +61,14 @@ public abstract class ArrowWriter implements AutoCloseable {
private final DictionaryProvider dictionaryProvider;
private final Set<Long> dictionaryIdsUsed = new HashSet<>();
+ private final CompressionCodec.Factory compressionFactory;
+ private final CompressionUtil.CodecType codecType;
+ private final Optional<Integer> compressionLevel;
private boolean started = false;
private boolean ended = false;
+ private final CompressionCodec codec;
+
protected IpcOption option;
protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider,
WritableByteChannel out) {
@@ -89,16 +94,19 @@ public abstract class ArrowWriter implements AutoCloseable {
protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider,
WritableByteChannel out, IpcOption option,
CompressionCodec.Factory compressionFactory,
CompressionUtil.CodecType codecType,
Optional<Integer> compressionLevel) {
- this.unloader = new VectorUnloader(
- root, /*includeNullCount*/ true,
- compressionLevel.isPresent() ?
- compressionFactory.createCodec(codecType, compressionLevel.get()) :
- compressionFactory.createCodec(codecType),
- /*alignBuffers*/ true);
this.out = new WriteChannel(out);
this.option = option;
this.dictionaryProvider = provider;
+ this.compressionFactory = compressionFactory;
+ this.codecType = codecType;
+ this.compressionLevel = compressionLevel;
+ this.codec = this.compressionLevel.isPresent() ?
+ this.compressionFactory.createCodec(this.codecType,
this.compressionLevel.get()) :
+ this.compressionFactory.createCodec(this.codecType);
+ this.unloader = new VectorUnloader(root, /*includeNullCount*/ true, codec,
+ /*alignBuffers*/ true);
+
List<Field> fields = new ArrayList<>(root.getSchema().getFields().size());
MetadataV4UnionChecker.checkForUnion(root.getSchema().getFields().iterator(),
option.metadataVersion);
@@ -133,7 +141,8 @@ public abstract class ArrowWriter implements AutoCloseable {
Collections.singletonList(vector.getField()),
Collections.singletonList(vector),
count);
- VectorUnloader unloader = new VectorUnloader(dictRoot);
+ VectorUnloader unloader = new VectorUnloader(dictRoot,
/*includeNullCount*/ true, this.codec,
+ /*alignBuffers*/ true);
ArrowRecordBatch batch = unloader.getRecordBatch();
ArrowDictionaryBatch dictionaryBatch = new ArrowDictionaryBatch(id, batch,
false);
try {