This is an automated email from the ASF dual-hosted git repository. yuzelin pushed a commit to branch release-1.3 in repository https://gitbox.apache.org/repos/asf/paimon.git
commit cafcf4e511ae3dfccaa5790d06c7cf6bf4e21d23 Author: yuzelin <[email protected]> AuthorDate: Wed Nov 12 18:01:53 2025 +0800 [arrow] Fix that complex writers didn't reset inner writer state (#6591) (cherry picked from commit 516d8dd42919b7ccc4dfe865f69bd26937a36b23) --- .../paimon/arrow/writer/ArrowFieldWriters.java | 15 ++- .../paimon/arrow/vector/ArrowFormatWriterTest.java | 145 +++++++++++++++++++-- 2 files changed, 148 insertions(+), 12 deletions(-) diff --git a/paimon-arrow/src/main/java/org/apache/paimon/arrow/writer/ArrowFieldWriters.java b/paimon-arrow/src/main/java/org/apache/paimon/arrow/writer/ArrowFieldWriters.java index 1c7bb742f5..9e4f371a79 100644 --- a/paimon-arrow/src/main/java/org/apache/paimon/arrow/writer/ArrowFieldWriters.java +++ b/paimon-arrow/src/main/java/org/apache/paimon/arrow/writer/ArrowFieldWriters.java @@ -526,7 +526,8 @@ public class ArrowFieldWriters { @Override public void reset() { - fieldVector.reset(); + super.reset(); + elementWriter.reset(); offset = 0; } @@ -613,7 +614,9 @@ public class ArrowFieldWriters { @Override public void reset() { - fieldVector.reset(); + super.reset(); + keyWriter.reset(); + valueWriter.reset(); offset = 0; } @@ -769,6 +772,14 @@ public class ArrowFieldWriters { this.fieldWriters = fieldWriters; } + @Override + public void reset() { + super.reset(); + for (ArrowFieldWriter fieldWriter : fieldWriters) { + fieldWriter.reset(); + } + } + @Override protected void doWrite( ColumnVector columnVector, diff --git a/paimon-arrow/src/test/java/org/apache/paimon/arrow/vector/ArrowFormatWriterTest.java b/paimon-arrow/src/test/java/org/apache/paimon/arrow/vector/ArrowFormatWriterTest.java index 76df181683..6436be955d 100644 --- a/paimon-arrow/src/test/java/org/apache/paimon/arrow/vector/ArrowFormatWriterTest.java +++ b/paimon-arrow/src/test/java/org/apache/paimon/arrow/vector/ArrowFormatWriterTest.java @@ -22,6 +22,8 @@ import org.apache.paimon.arrow.ArrowBundleRecords; import org.apache.paimon.arrow.reader.ArrowBatchReader; import org.apache.paimon.data.BinaryString; import org.apache.paimon.data.Decimal; +import org.apache.paimon.data.GenericArray; +import org.apache.paimon.data.GenericMap; import org.apache.paimon.data.GenericRow; import org.apache.paimon.data.InternalRow; import org.apache.paimon.data.Timestamp; @@ -34,8 +36,12 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.OutOfMemoryException; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; -import org.assertj.core.api.Assertions; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.MapVector; +import org.apache.arrow.vector.complex.StructVector; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; @@ -43,11 +49,15 @@ import org.junit.jupiter.params.provider.ValueSource; import java.math.BigDecimal; import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; import java.util.Iterator; import java.util.List; +import java.util.Map; import java.util.Random; import java.util.concurrent.ThreadLocalRandom; +import static org.assertj.core.api.Assertions.assertThat; + /** Test for {@link org.apache.paimon.arrow.vector.ArrowFormatWriter}. */ public class ArrowFormatWriterTest { @@ -115,7 +125,7 @@ public class ArrowFormatWriterTest { InternalRow expectec = list.get(i); for (InternalRow.FieldGetter fieldGetter : fieldGetters) { - Assertions.assertThat(fieldGetter.getFieldOrNull(actual)) + assertThat(fieldGetter.getFieldOrNull(actual)) .isEqualTo(fieldGetter.getFieldOrNull(expectec)); } } @@ -158,7 +168,7 @@ public class ArrowFormatWriterTest { InternalRow expectec = list.get(i); for (InternalRow.FieldGetter fieldGetter : fieldGetters) { - Assertions.assertThat(fieldGetter.getFieldOrNull(actual)) + assertThat(fieldGetter.getFieldOrNull(actual)) .isEqualTo(fieldGetter.getFieldOrNull(expectec)); } } @@ -194,9 +204,9 @@ public class ArrowFormatWriterTest { if (limitMemory) { for (int i = 0; i < 64; i++) { - Assertions.assertThat(writer.write(genericRow)).isTrue(); + assertThat(writer.write(genericRow)).isTrue(); } - Assertions.assertThat(writer.write(genericRow)).isFalse(); + assertThat(writer.write(genericRow)).isFalse(); } writer.reset(); @@ -211,8 +221,8 @@ public class ArrowFormatWriterTest { } if (limitMemory) { - Assertions.assertThat(writer.memoryUsed()).isLessThan(memoryLimit); - Assertions.assertThat(writer.getAllocator().getAllocatedMemory()) + assertThat(writer.memoryUsed()).isLessThan(memoryLimit); + assertThat(writer.getAllocator().getAllocatedMemory()) .isGreaterThan(memoryLimit) .isLessThan(2 * memoryLimit); } @@ -244,7 +254,7 @@ public class ArrowFormatWriterTest { InternalRow expectec = list.get(i); for (InternalRow.FieldGetter fieldGetter : fieldGetters) { - Assertions.assertThat(fieldGetter.getFieldOrNull(actual)) + assertThat(fieldGetter.getFieldOrNull(actual)) .isEqualTo(fieldGetter.getFieldOrNull(expectec)); } } @@ -290,11 +300,126 @@ public class ArrowFormatWriterTest { } writer.flush(); ArrowCStruct cStruct = writer.toCStruct(); - Assertions.assertThat(cStruct).isNotNull(); + assertThat(cStruct).isNotNull(); writer.release(); } } + @Test + public void testWriteArrayMapTwice() { + try (ArrowFormatWriter arrowFormatWriter = + new ArrowFormatWriter( + RowType.of( + DataTypes.ARRAY( + DataTypes.MAP(DataTypes.STRING(), DataTypes.STRING()))), + 1, + true)) { + writeAndCheckArrayMap(arrowFormatWriter); + writeAndCheckArrayMap(arrowFormatWriter); + } + } + + private void writeAndCheckArrayMap(ArrowFormatWriter arrowFormatWriter) { + GenericRow genericRow = new GenericRow(1); + Map<BinaryString, BinaryString> map = new HashMap<>(); + map.put(BinaryString.fromString("a"), BinaryString.fromString("b")); + map.put(BinaryString.fromString("c"), BinaryString.fromString("d")); + GenericArray array = new GenericArray(new Object[] {new GenericMap(map)}); + genericRow.setField(0, array); + arrowFormatWriter.write(genericRow); + arrowFormatWriter.flush(); + + VectorSchemaRoot vsr = arrowFormatWriter.getVectorSchemaRoot(); + ListVector listVector = (ListVector) vsr.getVector(0); + MapVector mapVector = (MapVector) listVector.getDataVector(); + assertThat(mapVector.getValueCount()).isEqualTo(1); + VarCharVector keyVector = + (VarCharVector) mapVector.getDataVector().getChildrenFromFields().get(0); + assertThat(keyVector.getValueCount()).isEqualTo(2); + assertThat(new String(keyVector.get(0))).isEqualTo("a"); + assertThat(new String(keyVector.get(1))).isEqualTo("c"); + VarCharVector valueVector = + (VarCharVector) mapVector.getDataVector().getChildrenFromFields().get(1); + assertThat(valueVector.getValueCount()).isEqualTo(2); + assertThat(new String(valueVector.get(0))).isEqualTo("b"); + assertThat(new String(valueVector.get(1))).isEqualTo("d"); + arrowFormatWriter.reset(); + } + + @Test + public void testWriteMapArrayTwice() { + try (ArrowFormatWriter arrowFormatWriter = + new ArrowFormatWriter( + RowType.of( + DataTypes.MAP(DataTypes.INT(), DataTypes.ARRAY(DataTypes.INT()))), + 1, + true)) { + writeAndCheckMapArray(arrowFormatWriter); + writeAndCheckMapArray(arrowFormatWriter); + } + } + + private void writeAndCheckMapArray(ArrowFormatWriter arrowFormatWriter) { + GenericRow genericRow = new GenericRow(1); + GenericArray array1 = new GenericArray(new Object[] {1, 2}); + GenericArray array2 = new GenericArray(new Object[] {3, 4}); + Map<Integer, GenericArray> map = new HashMap<>(); + map.put(1, array1); + map.put(2, array2); + GenericMap genericMap = new GenericMap(map); + genericRow.setField(0, genericMap); + arrowFormatWriter.write(genericRow); + arrowFormatWriter.flush(); + + VectorSchemaRoot vsr = arrowFormatWriter.getVectorSchemaRoot(); + MapVector mapVector = (MapVector) vsr.getVector(0); + IntVector keyVector = (IntVector) mapVector.getDataVector().getChildrenFromFields().get(0); + assertThat(keyVector.getValueCount()).isEqualTo(2); + assertThat(keyVector.get(0)).isEqualTo(1); + assertThat(keyVector.get(1)).isEqualTo(2); + ListVector valueVector = + (ListVector) mapVector.getDataVector().getChildrenFromFields().get(1); + assertThat(valueVector.getValueCount()).isEqualTo(2); + IntVector innerValueVector = (IntVector) valueVector.getDataVector(); + assertThat(innerValueVector.getValueCount()).isEqualTo(4); + assertThat(innerValueVector.get(0)).isEqualTo(1); + assertThat(innerValueVector.get(1)).isEqualTo(2); + assertThat(innerValueVector.get(2)).isEqualTo(3); + assertThat(innerValueVector.get(3)).isEqualTo(4); + arrowFormatWriter.reset(); + } + + @Test + public void testWriteRowArrayTwice() { + try (ArrowFormatWriter arrowFormatWriter = + new ArrowFormatWriter( + RowType.of(DataTypes.ROW(DataTypes.ARRAY(DataTypes.INT()))), 1, true)) { + writeAndCheckRowArray(arrowFormatWriter); + writeAndCheckRowArray(arrowFormatWriter); + } + } + + private void writeAndCheckRowArray(ArrowFormatWriter arrowFormatWriter) { + GenericRow genericRow = new GenericRow(1); + GenericRow innerRow = new GenericRow(1); + GenericArray array = new GenericArray(new Object[] {1, 2}); + innerRow.setField(0, array); + genericRow.setField(0, innerRow); + arrowFormatWriter.write(genericRow); + arrowFormatWriter.flush(); + + VectorSchemaRoot vsr = arrowFormatWriter.getVectorSchemaRoot(); + assertThat(vsr.getRowCount()).isEqualTo(1); + StructVector structVector = (StructVector) vsr.getVector(0); + ListVector listVector = (ListVector) structVector.getChildrenFromFields().get(0); + assertThat(listVector.getValueCount()).isEqualTo(1); + IntVector dataVector = (IntVector) listVector.getDataVector(); + assertThat(dataVector.getValueCount()).isEqualTo(2); + assertThat(dataVector.get(0)).isEqualTo(1); + assertThat(dataVector.get(1)).isEqualTo(2); + arrowFormatWriter.reset(); + } + private void writeAndCheck(ArrowFormatCWriter writer) { List<InternalRow> list = new ArrayList<>(); List<InternalRow.FieldGetter> fieldGetters = new ArrayList<>(); @@ -320,7 +445,7 @@ public class ArrowFormatWriterTest { InternalRow expectec = list.get(i); for (InternalRow.FieldGetter fieldGetter : fieldGetters) { - Assertions.assertThat(fieldGetter.getFieldOrNull(actual)) + assertThat(fieldGetter.getFieldOrNull(actual)) .isEqualTo(fieldGetter.getFieldOrNull(expectec)); } }
