This is an automated email from the ASF dual-hosted git repository.
yuzelin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new 516d8dd429 [arrow] Fix that complex writers didn't reset inner writer state (#6591)
516d8dd429 is described below
commit 516d8dd42919b7ccc4dfe865f69bd26937a36b23
Author: yuzelin <[email protected]>
AuthorDate: Wed Nov 12 18:01:53 2025 +0800
[arrow] Fix that complex writers didn't reset inner writer state (#6591)
---
.../paimon/arrow/writer/ArrowFieldWriters.java | 15 ++-
.../paimon/arrow/vector/ArrowFormatWriterTest.java | 145 +++++++++++++++++++--
2 files changed, 148 insertions(+), 12 deletions(-)
diff --git a/paimon-arrow/src/main/java/org/apache/paimon/arrow/writer/ArrowFieldWriters.java b/paimon-arrow/src/main/java/org/apache/paimon/arrow/writer/ArrowFieldWriters.java
index 1c7bb742f5..9e4f371a79 100644
--- a/paimon-arrow/src/main/java/org/apache/paimon/arrow/writer/ArrowFieldWriters.java
+++ b/paimon-arrow/src/main/java/org/apache/paimon/arrow/writer/ArrowFieldWriters.java
@@ -526,7 +526,8 @@ public class ArrowFieldWriters {
+        // Prepare this writer for a new batch. super.reset() replaces the former direct
+        // fieldVector.reset() call (presumably it resets the vector itself -- confirm in
+        // ArrowFieldWriter). The nested element writer is reset as well, so its state is not
+        // carried over into the next batch; this writer's own offset is cleared last.
@Override
public void reset() {
- fieldVector.reset();
+ super.reset();
+ elementWriter.reset();
offset = 0;
}
@@ -613,7 +614,9 @@ public class ArrowFieldWriters {
+        // Prepare this writer for a new batch. super.reset() replaces the former direct
+        // fieldVector.reset() call (presumably it resets the vector itself -- confirm in
+        // ArrowFieldWriter). Both nested writers (keys and values) are reset as well, so neither
+        // carries state over into the next batch; this writer's own offset is cleared last.
@Override
public void reset() {
- fieldVector.reset();
+ super.reset();
+ keyWriter.reset();
+ valueWriter.reset();
offset = 0;
}
@@ -769,6 +772,14 @@ public class ArrowFieldWriters {
this.fieldWriters = fieldWriters;
}
+        // Reset the base writer, then every child field writer, so that nested writers
+        // (e.g. an array field inside this row) start the next batch from a clean state.
+ @Override
+ public void reset() {
+ super.reset();
+ for (ArrowFieldWriter fieldWriter : fieldWriters) {
+ fieldWriter.reset();
+ }
+ }
+
@Override
protected void doWrite(
ColumnVector columnVector,
diff --git a/paimon-arrow/src/test/java/org/apache/paimon/arrow/vector/ArrowFormatWriterTest.java b/paimon-arrow/src/test/java/org/apache/paimon/arrow/vector/ArrowFormatWriterTest.java
index 76df181683..6436be955d 100644
--- a/paimon-arrow/src/test/java/org/apache/paimon/arrow/vector/ArrowFormatWriterTest.java
+++ b/paimon-arrow/src/test/java/org/apache/paimon/arrow/vector/ArrowFormatWriterTest.java
@@ -22,6 +22,8 @@ import org.apache.paimon.arrow.ArrowBundleRecords;
import org.apache.paimon.arrow.reader.ArrowBatchReader;
import org.apache.paimon.data.BinaryString;
import org.apache.paimon.data.Decimal;
+import org.apache.paimon.data.GenericArray;
+import org.apache.paimon.data.GenericMap;
import org.apache.paimon.data.GenericRow;
import org.apache.paimon.data.InternalRow;
import org.apache.paimon.data.Timestamp;
@@ -34,8 +36,12 @@ import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.OutOfMemoryException;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
-import org.assertj.core.api.Assertions;
+import org.apache.arrow.vector.complex.ListVector;
+import org.apache.arrow.vector.complex.MapVector;
+import org.apache.arrow.vector.complex.StructVector;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
@@ -43,11 +49,15 @@ import org.junit.jupiter.params.provider.ValueSource;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
+import java.util.Map;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
+import static org.assertj.core.api.Assertions.assertThat;
+
/** Test for {@link org.apache.paimon.arrow.vector.ArrowFormatWriter}. */
public class ArrowFormatWriterTest {
@@ -115,7 +125,7 @@ public class ArrowFormatWriterTest {
InternalRow expectec = list.get(i);
for (InternalRow.FieldGetter fieldGetter : fieldGetters) {
- Assertions.assertThat(fieldGetter.getFieldOrNull(actual))
+ assertThat(fieldGetter.getFieldOrNull(actual))
.isEqualTo(fieldGetter.getFieldOrNull(expectec));
}
}
@@ -158,7 +168,7 @@ public class ArrowFormatWriterTest {
InternalRow expectec = list.get(i);
for (InternalRow.FieldGetter fieldGetter : fieldGetters) {
- Assertions.assertThat(fieldGetter.getFieldOrNull(actual))
+ assertThat(fieldGetter.getFieldOrNull(actual))
.isEqualTo(fieldGetter.getFieldOrNull(expectec));
}
}
@@ -194,9 +204,9 @@ public class ArrowFormatWriterTest {
if (limitMemory) {
for (int i = 0; i < 64; i++) {
- Assertions.assertThat(writer.write(genericRow)).isTrue();
+ assertThat(writer.write(genericRow)).isTrue();
}
- Assertions.assertThat(writer.write(genericRow)).isFalse();
+ assertThat(writer.write(genericRow)).isFalse();
}
writer.reset();
@@ -211,8 +221,8 @@ public class ArrowFormatWriterTest {
}
if (limitMemory) {
-
Assertions.assertThat(writer.memoryUsed()).isLessThan(memoryLimit);
-
Assertions.assertThat(writer.getAllocator().getAllocatedMemory())
+ assertThat(writer.memoryUsed()).isLessThan(memoryLimit);
+ assertThat(writer.getAllocator().getAllocatedMemory())
.isGreaterThan(memoryLimit)
.isLessThan(2 * memoryLimit);
}
@@ -244,7 +254,7 @@ public class ArrowFormatWriterTest {
InternalRow expectec = list.get(i);
for (InternalRow.FieldGetter fieldGetter : fieldGetters) {
- Assertions.assertThat(fieldGetter.getFieldOrNull(actual))
+ assertThat(fieldGetter.getFieldOrNull(actual))
.isEqualTo(fieldGetter.getFieldOrNull(expectec));
}
}
@@ -290,11 +300,126 @@ public class ArrowFormatWriterTest {
}
writer.flush();
ArrowCStruct cStruct = writer.toCStruct();
- Assertions.assertThat(cStruct).isNotNull();
+ assertThat(cStruct).isNotNull();
writer.release();
}
}
+    @Test
+    public void testWriteArrayMapTwice() {
+        // ARRAY<MAP<STRING, STRING>>: the helper resets the writer after each check, so running
+        // it twice on the same writer exercises reset() of the nested map writer.
+        RowType rowType =
+                RowType.of(
+                        DataTypes.ARRAY(
+                                DataTypes.MAP(DataTypes.STRING(), DataTypes.STRING())));
+        try (ArrowFormatWriter writer = new ArrowFormatWriter(rowType, 1, true)) {
+            for (int round = 0; round < 2; round++) {
+                writeAndCheckArrayMap(writer);
+            }
+        }
+    }
+
+    /**
+     * Writes one row holding {@code [{a -> b, c -> d}]}, checks the resulting Arrow vectors and
+     * resets the writer so the same instance can be used again.
+     */
+    private void writeAndCheckArrayMap(ArrowFormatWriter arrowFormatWriter) {
+        GenericRow genericRow = new GenericRow(1);
+        Map<BinaryString, BinaryString> map = new HashMap<>();
+        map.put(BinaryString.fromString("a"), BinaryString.fromString("b"));
+        map.put(BinaryString.fromString("c"), BinaryString.fromString("d"));
+        GenericArray array = new GenericArray(new Object[] {new GenericMap(map)});
+        genericRow.setField(0, array);
+        arrowFormatWriter.write(genericRow);
+        arrowFormatWriter.flush();
+
+        VectorSchemaRoot vsr = arrowFormatWriter.getVectorSchemaRoot();
+        ListVector listVector = (ListVector) vsr.getVector(0);
+        MapVector mapVector = (MapVector) listVector.getDataVector();
+        assertThat(mapVector.getValueCount()).isEqualTo(1);
+
+        List<FieldVector> entryVectors = mapVector.getDataVector().getChildrenFromFields();
+        VarCharVector keyVector = (VarCharVector) entryVectors.get(0);
+        VarCharVector valueVector = (VarCharVector) entryVectors.get(1);
+        // Exactly two entries: stale entries left over from a previous batch would show up here.
+        assertThat(keyVector.getValueCount()).isEqualTo(2);
+        assertThat(valueVector.getValueCount()).isEqualTo(2);
+
+        // The entry order follows the iteration order of the HashMap above, which is
+        // unspecified, so compare key/value pairs rather than positions. getObject() returns an
+        // Arrow Text whose toString() decodes UTF-8, unlike new String(byte[]), which uses the
+        // platform default charset.
+        Map<String, String> actual = new HashMap<>();
+        for (int i = 0; i < keyVector.getValueCount(); i++) {
+            actual.put(keyVector.getObject(i).toString(), valueVector.getObject(i).toString());
+        }
+        assertThat(actual).hasSize(2).containsEntry("a", "b").containsEntry("c", "d");
+        arrowFormatWriter.reset();
+    }
+
+    @Test
+    public void testWriteMapArrayTwice() {
+        // MAP<INT, ARRAY<INT>>: the helper resets the writer after each check, so running it
+        // twice on the same writer exercises reset() of the nested key/value writers.
+        RowType rowType =
+                RowType.of(DataTypes.MAP(DataTypes.INT(), DataTypes.ARRAY(DataTypes.INT())));
+        try (ArrowFormatWriter writer = new ArrowFormatWriter(rowType, 1, true)) {
+            for (int round = 0; round < 2; round++) {
+                writeAndCheckMapArray(writer);
+            }
+        }
+    }
+
+    /**
+     * Writes one row holding {@code {1 -> [1, 2], 2 -> [3, 4]}}, checks the resulting Arrow
+     * vectors and resets the writer so the same instance can be used again.
+     */
+    private void writeAndCheckMapArray(ArrowFormatWriter arrowFormatWriter) {
+        GenericRow genericRow = new GenericRow(1);
+        GenericArray array1 = new GenericArray(new Object[] {1, 2});
+        GenericArray array2 = new GenericArray(new Object[] {3, 4});
+        Map<Integer, GenericArray> map = new HashMap<>();
+        map.put(1, array1);
+        map.put(2, array2);
+        GenericMap genericMap = new GenericMap(map);
+        genericRow.setField(0, genericMap);
+        arrowFormatWriter.write(genericRow);
+        arrowFormatWriter.flush();
+
+        VectorSchemaRoot vsr = arrowFormatWriter.getVectorSchemaRoot();
+        MapVector mapVector = (MapVector) vsr.getVector(0);
+        List<FieldVector> entryVectors = mapVector.getDataVector().getChildrenFromFields();
+        IntVector keyVector = (IntVector) entryVectors.get(0);
+        ListVector valueVector = (ListVector) entryVectors.get(1);
+        assertThat(keyVector.getValueCount()).isEqualTo(2);
+        assertThat(valueVector.getValueCount()).isEqualTo(2);
+        IntVector innerValueVector = (IntVector) valueVector.getDataVector();
+        assertThat(innerValueVector.getValueCount()).isEqualTo(4);
+
+        // HashMap iteration order is unspecified, so do not assume key 1 is written first:
+        // pair each key with the list at the same entry index. Reading the lists through
+        // getObject() also checks the list offsets, which a missing reset would corrupt.
+        Map<Integer, List<?>> actual = new HashMap<>();
+        for (int i = 0; i < keyVector.getValueCount(); i++) {
+            actual.put(keyVector.get(i), valueVector.getObject(i));
+        }
+        assertThat(actual)
+                .hasSize(2)
+                .containsEntry(1, Arrays.asList(1, 2))
+                .containsEntry(2, Arrays.asList(3, 4));
+        arrowFormatWriter.reset();
+    }
+
+    @Test
+    public void testWriteRowArrayTwice() {
+        // ROW<ARRAY<INT>>: the helper resets the writer after each check, so running it twice
+        // on the same writer exercises reset() of the row's nested field writers.
+        RowType rowType = RowType.of(DataTypes.ROW(DataTypes.ARRAY(DataTypes.INT())));
+        try (ArrowFormatWriter writer = new ArrowFormatWriter(rowType, 1, true)) {
+            for (int round = 0; round < 2; round++) {
+                writeAndCheckRowArray(writer);
+            }
+        }
+    }
+
+    /**
+     * Writes one row holding {@code ROW([1, 2])}, checks the resulting Arrow vectors and resets
+     * the writer so the same instance can be used again.
+     */
+    private void writeAndCheckRowArray(ArrowFormatWriter arrowFormatWriter) {
+        GenericRow nested = new GenericRow(1);
+        nested.setField(0, new GenericArray(new Object[] {1, 2}));
+        GenericRow outer = new GenericRow(1);
+        outer.setField(0, nested);
+
+        arrowFormatWriter.write(outer);
+        arrowFormatWriter.flush();
+
+        VectorSchemaRoot root = arrowFormatWriter.getVectorSchemaRoot();
+        assertThat(root.getRowCount()).isEqualTo(1);
+
+        StructVector rowVector = (StructVector) root.getVector(0);
+        ListVector arrayVector = (ListVector) rowVector.getChildrenFromFields().get(0);
+        assertThat(arrayVector.getValueCount()).isEqualTo(1);
+
+        IntVector elements = (IntVector) arrayVector.getDataVector();
+        assertThat(elements.getValueCount()).isEqualTo(2);
+        assertThat(elements.get(0)).isEqualTo(1);
+        assertThat(elements.get(1)).isEqualTo(2);
+
+        arrowFormatWriter.reset();
+    }
+
private void writeAndCheck(ArrowFormatCWriter writer) {
List<InternalRow> list = new ArrayList<>();
List<InternalRow.FieldGetter> fieldGetters = new ArrayList<>();
@@ -320,7 +445,7 @@ public class ArrowFormatWriterTest {
InternalRow expectec = list.get(i);
for (InternalRow.FieldGetter fieldGetter : fieldGetters) {
- Assertions.assertThat(fieldGetter.getFieldOrNull(actual))
+ assertThat(fieldGetter.getFieldOrNull(actual))
.isEqualTo(fieldGetter.getFieldOrNull(expectec));
}
}