This is an automated email from the ASF dual-hosted git repository.

yuzelin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new 516d8dd429 [arrow] Fix that complex writers didn't reset inner writer state (#6591)
516d8dd429 is described below

commit 516d8dd42919b7ccc4dfe865f69bd26937a36b23
Author: yuzelin <[email protected]>
AuthorDate: Wed Nov 12 18:01:53 2025 +0800

    [arrow] Fix that complex writers didn't reset inner writer state (#6591)
---
 .../paimon/arrow/writer/ArrowFieldWriters.java     |  15 ++-
 .../paimon/arrow/vector/ArrowFormatWriterTest.java | 145 +++++++++++++++++++--
 2 files changed, 148 insertions(+), 12 deletions(-)

diff --git a/paimon-arrow/src/main/java/org/apache/paimon/arrow/writer/ArrowFieldWriters.java b/paimon-arrow/src/main/java/org/apache/paimon/arrow/writer/ArrowFieldWriters.java
index 1c7bb742f5..9e4f371a79 100644
--- a/paimon-arrow/src/main/java/org/apache/paimon/arrow/writer/ArrowFieldWriters.java
+++ b/paimon-arrow/src/main/java/org/apache/paimon/arrow/writer/ArrowFieldWriters.java
@@ -526,7 +526,8 @@ public class ArrowFieldWriters {
 
         @Override
         public void reset() {
-            fieldVector.reset();
+            super.reset();
+            elementWriter.reset();
             offset = 0;
         }
 
@@ -613,7 +614,9 @@ public class ArrowFieldWriters {
 
         @Override
         public void reset() {
-            fieldVector.reset();
+            super.reset();
+            keyWriter.reset();
+            valueWriter.reset();
             offset = 0;
         }
 
@@ -769,6 +772,14 @@ public class ArrowFieldWriters {
             this.fieldWriters = fieldWriters;
         }
 
+        @Override
+        public void reset() {
+            super.reset();
+            for (ArrowFieldWriter fieldWriter : fieldWriters) {
+                fieldWriter.reset();
+            }
+        }
+
         @Override
         protected void doWrite(
                 ColumnVector columnVector,
diff --git a/paimon-arrow/src/test/java/org/apache/paimon/arrow/vector/ArrowFormatWriterTest.java b/paimon-arrow/src/test/java/org/apache/paimon/arrow/vector/ArrowFormatWriterTest.java
index 76df181683..6436be955d 100644
--- a/paimon-arrow/src/test/java/org/apache/paimon/arrow/vector/ArrowFormatWriterTest.java
+++ b/paimon-arrow/src/test/java/org/apache/paimon/arrow/vector/ArrowFormatWriterTest.java
@@ -22,6 +22,8 @@ import org.apache.paimon.arrow.ArrowBundleRecords;
 import org.apache.paimon.arrow.reader.ArrowBatchReader;
 import org.apache.paimon.data.BinaryString;
 import org.apache.paimon.data.Decimal;
+import org.apache.paimon.data.GenericArray;
+import org.apache.paimon.data.GenericMap;
 import org.apache.paimon.data.GenericRow;
 import org.apache.paimon.data.InternalRow;
 import org.apache.paimon.data.Timestamp;
@@ -34,8 +36,12 @@ import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.memory.OutOfMemoryException;
 import org.apache.arrow.memory.RootAllocator;
 import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.VarCharVector;
 import org.apache.arrow.vector.VectorSchemaRoot;
-import org.assertj.core.api.Assertions;
+import org.apache.arrow.vector.complex.ListVector;
+import org.apache.arrow.vector.complex.MapVector;
+import org.apache.arrow.vector.complex.StructVector;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.ValueSource;
@@ -43,11 +49,15 @@ import org.junit.jupiter.params.provider.ValueSource;
 import java.math.BigDecimal;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
+import java.util.Map;
 import java.util.Random;
 import java.util.concurrent.ThreadLocalRandom;
 
+import static org.assertj.core.api.Assertions.assertThat;
+
 /** Test for {@link org.apache.paimon.arrow.vector.ArrowFormatWriter}. */
 public class ArrowFormatWriterTest {
 
@@ -115,7 +125,7 @@ public class ArrowFormatWriterTest {
                 InternalRow expectec = list.get(i);
 
                 for (InternalRow.FieldGetter fieldGetter : fieldGetters) {
-                    Assertions.assertThat(fieldGetter.getFieldOrNull(actual))
+                    assertThat(fieldGetter.getFieldOrNull(actual))
                             .isEqualTo(fieldGetter.getFieldOrNull(expectec));
                 }
             }
@@ -158,7 +168,7 @@ public class ArrowFormatWriterTest {
                 InternalRow expectec = list.get(i);
 
                 for (InternalRow.FieldGetter fieldGetter : fieldGetters) {
-                    Assertions.assertThat(fieldGetter.getFieldOrNull(actual))
+                    assertThat(fieldGetter.getFieldOrNull(actual))
                             .isEqualTo(fieldGetter.getFieldOrNull(expectec));
                 }
             }
@@ -194,9 +204,9 @@ public class ArrowFormatWriterTest {
 
             if (limitMemory) {
                 for (int i = 0; i < 64; i++) {
-                    Assertions.assertThat(writer.write(genericRow)).isTrue();
+                    assertThat(writer.write(genericRow)).isTrue();
                 }
-                Assertions.assertThat(writer.write(genericRow)).isFalse();
+                assertThat(writer.write(genericRow)).isFalse();
             }
             writer.reset();
 
@@ -211,8 +221,8 @@ public class ArrowFormatWriterTest {
             }
 
             if (limitMemory) {
-                Assertions.assertThat(writer.memoryUsed()).isLessThan(memoryLimit);
-                Assertions.assertThat(writer.getAllocator().getAllocatedMemory())
+                assertThat(writer.memoryUsed()).isLessThan(memoryLimit);
+                assertThat(writer.getAllocator().getAllocatedMemory())
                         .isGreaterThan(memoryLimit)
                         .isLessThan(2 * memoryLimit);
             }
@@ -244,7 +254,7 @@ public class ArrowFormatWriterTest {
                 InternalRow expectec = list.get(i);
 
                 for (InternalRow.FieldGetter fieldGetter : fieldGetters) {
-                    Assertions.assertThat(fieldGetter.getFieldOrNull(actual))
+                    assertThat(fieldGetter.getFieldOrNull(actual))
                             .isEqualTo(fieldGetter.getFieldOrNull(expectec));
                 }
             }
@@ -290,11 +300,126 @@ public class ArrowFormatWriterTest {
             }
             writer.flush();
             ArrowCStruct cStruct = writer.toCStruct();
-            Assertions.assertThat(cStruct).isNotNull();
+            assertThat(cStruct).isNotNull();
             writer.release();
         }
     }
 
+    @Test
+    public void testWriteArrayMapTwice() {
+        try (ArrowFormatWriter arrowFormatWriter =
+                new ArrowFormatWriter(
+                        RowType.of(
+                                DataTypes.ARRAY(
+                                        DataTypes.MAP(DataTypes.STRING(), DataTypes.STRING()))),
+                        1,
+                        true)) {
+            writeAndCheckArrayMap(arrowFormatWriter);
+            writeAndCheckArrayMap(arrowFormatWriter);
+        }
+    }
+
+    private void writeAndCheckArrayMap(ArrowFormatWriter arrowFormatWriter) {
+        GenericRow genericRow = new GenericRow(1);
+        Map<BinaryString, BinaryString> map = new HashMap<>();
+        map.put(BinaryString.fromString("a"), BinaryString.fromString("b"));
+        map.put(BinaryString.fromString("c"), BinaryString.fromString("d"));
+        GenericArray array = new GenericArray(new Object[] {new GenericMap(map)});
+        genericRow.setField(0, array);
+        arrowFormatWriter.write(genericRow);
+        arrowFormatWriter.flush();
+
+        VectorSchemaRoot vsr = arrowFormatWriter.getVectorSchemaRoot();
+        ListVector listVector = (ListVector) vsr.getVector(0);
+        MapVector mapVector = (MapVector) listVector.getDataVector();
+        assertThat(mapVector.getValueCount()).isEqualTo(1);
+        VarCharVector keyVector =
+                (VarCharVector) mapVector.getDataVector().getChildrenFromFields().get(0);
+        assertThat(keyVector.getValueCount()).isEqualTo(2);
+        assertThat(new String(keyVector.get(0))).isEqualTo("a");
+        assertThat(new String(keyVector.get(1))).isEqualTo("c");
+        VarCharVector valueVector =
+                (VarCharVector) mapVector.getDataVector().getChildrenFromFields().get(1);
+        assertThat(valueVector.getValueCount()).isEqualTo(2);
+        assertThat(new String(valueVector.get(0))).isEqualTo("b");
+        assertThat(new String(valueVector.get(1))).isEqualTo("d");
+        arrowFormatWriter.reset();
+    }
+
+    @Test
+    public void testWriteMapArrayTwice() {
+        try (ArrowFormatWriter arrowFormatWriter =
+                new ArrowFormatWriter(
+                        RowType.of(
+                                DataTypes.MAP(DataTypes.INT(), DataTypes.ARRAY(DataTypes.INT()))),
+                        1,
+                        true)) {
+            writeAndCheckMapArray(arrowFormatWriter);
+            writeAndCheckMapArray(arrowFormatWriter);
+        }
+    }
+
+    private void writeAndCheckMapArray(ArrowFormatWriter arrowFormatWriter) {
+        GenericRow genericRow = new GenericRow(1);
+        GenericArray array1 = new GenericArray(new Object[] {1, 2});
+        GenericArray array2 = new GenericArray(new Object[] {3, 4});
+        Map<Integer, GenericArray> map = new HashMap<>();
+        map.put(1, array1);
+        map.put(2, array2);
+        GenericMap genericMap = new GenericMap(map);
+        genericRow.setField(0, genericMap);
+        arrowFormatWriter.write(genericRow);
+        arrowFormatWriter.flush();
+
+        VectorSchemaRoot vsr = arrowFormatWriter.getVectorSchemaRoot();
+        MapVector mapVector = (MapVector) vsr.getVector(0);
+        IntVector keyVector = (IntVector) mapVector.getDataVector().getChildrenFromFields().get(0);
+        assertThat(keyVector.getValueCount()).isEqualTo(2);
+        assertThat(keyVector.get(0)).isEqualTo(1);
+        assertThat(keyVector.get(1)).isEqualTo(2);
+        ListVector valueVector =
+                (ListVector) mapVector.getDataVector().getChildrenFromFields().get(1);
+        assertThat(valueVector.getValueCount()).isEqualTo(2);
+        IntVector innerValueVector = (IntVector) valueVector.getDataVector();
+        assertThat(innerValueVector.getValueCount()).isEqualTo(4);
+        assertThat(innerValueVector.get(0)).isEqualTo(1);
+        assertThat(innerValueVector.get(1)).isEqualTo(2);
+        assertThat(innerValueVector.get(2)).isEqualTo(3);
+        assertThat(innerValueVector.get(3)).isEqualTo(4);
+        arrowFormatWriter.reset();
+    }
+
+    @Test
+    public void testWriteRowArrayTwice() {
+        try (ArrowFormatWriter arrowFormatWriter =
+                new ArrowFormatWriter(
+                        RowType.of(DataTypes.ROW(DataTypes.ARRAY(DataTypes.INT()))), 1, true)) {
+            writeAndCheckRowArray(arrowFormatWriter);
+            writeAndCheckRowArray(arrowFormatWriter);
+        }
+    }
+
+    private void writeAndCheckRowArray(ArrowFormatWriter arrowFormatWriter) {
+        GenericRow genericRow = new GenericRow(1);
+        GenericRow innerRow = new GenericRow(1);
+        GenericArray array = new GenericArray(new Object[] {1, 2});
+        innerRow.setField(0, array);
+        genericRow.setField(0, innerRow);
+        arrowFormatWriter.write(genericRow);
+        arrowFormatWriter.flush();
+
+        VectorSchemaRoot vsr = arrowFormatWriter.getVectorSchemaRoot();
+        assertThat(vsr.getRowCount()).isEqualTo(1);
+        StructVector structVector = (StructVector) vsr.getVector(0);
+        ListVector listVector = (ListVector) structVector.getChildrenFromFields().get(0);
+        assertThat(listVector.getValueCount()).isEqualTo(1);
+        IntVector dataVector = (IntVector) listVector.getDataVector();
+        assertThat(dataVector.getValueCount()).isEqualTo(2);
+        assertThat(dataVector.get(0)).isEqualTo(1);
+        assertThat(dataVector.get(1)).isEqualTo(2);
+        arrowFormatWriter.reset();
+    }
+
     private void writeAndCheck(ArrowFormatCWriter writer) {
         List<InternalRow> list = new ArrayList<>();
         List<InternalRow.FieldGetter> fieldGetters = new ArrayList<>();
@@ -320,7 +445,7 @@ public class ArrowFormatWriterTest {
             InternalRow expectec = list.get(i);
 
             for (InternalRow.FieldGetter fieldGetter : fieldGetters) {
-                Assertions.assertThat(fieldGetter.getFieldOrNull(actual))
+                assertThat(fieldGetter.getFieldOrNull(actual))
                         .isEqualTo(fieldGetter.getFieldOrNull(expectec));
             }
         }

Reply via email to