This is an automated email from the ASF dual-hosted git repository.

diwu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris-spark-connector.git


The following commit(s) were added to refs/heads/master by this push:
     new f720167  [Fix](reader) fix arrow timezone bug (#232)
f720167 is described below

commit f7201676f9263b3eb1cf19bf1b27decbd8774610
Author: wudi <676366...@qq.com>
AuthorDate: Wed Sep 25 11:42:29 2024 +0800

    [Fix](reader) fix arrow timezone bug (#232)
---
 .../apache/doris/spark/serialization/RowBatch.java |  23 +--
 .../doris/spark/serialization/TestRowBatch.java    | 229 +++++++++++++++++++--
 2 files changed, 224 insertions(+), 28 deletions(-)

diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
index c3e70e9..ed61b59 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
@@ -17,11 +17,6 @@
 
 package org.apache.doris.spark.serialization;
 
-import org.apache.doris.sdk.thrift.TScanBatchResult;
-import org.apache.doris.spark.exception.DorisException;
-import org.apache.doris.spark.rest.models.Schema;
-import org.apache.doris.spark.util.IPUtils;
-
 import com.google.common.base.Preconditions;
 import org.apache.arrow.memory.RootAllocator;
 import org.apache.arrow.vector.BaseIntVector;
@@ -48,8 +43,11 @@ import org.apache.arrow.vector.complex.impl.UnionMapReader;
 import org.apache.arrow.vector.ipc.ArrowReader;
 import org.apache.arrow.vector.ipc.ArrowStreamReader;
 import org.apache.arrow.vector.types.Types;
-import org.apache.arrow.vector.types.Types.MinorType;
 import org.apache.commons.lang3.ArrayUtils;
+import org.apache.doris.sdk.thrift.TScanBatchResult;
+import org.apache.doris.spark.exception.DorisException;
+import org.apache.doris.spark.rest.models.Schema;
+import org.apache.doris.spark.util.IPUtils;
 import org.apache.spark.sql.types.Decimal;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -385,13 +383,6 @@ public class RowBatch {
                         break;
                     case "DATETIME":
                     case "DATETIMEV2":
-
-                        Preconditions.checkArgument(
-                                mt.equals(Types.MinorType.TIMESTAMPMICRO) || mt.equals(MinorType.VARCHAR) ||
-                                        mt.equals(MinorType.TIMESTAMPMILLI) || mt.equals(MinorType.TIMESTAMPSEC),
-                                typeMismatchMessage(currentType, mt));
-                        typeMismatchMessage(currentType, mt);
-
                         if (mt.equals(Types.MinorType.VARCHAR)) {
                             VarCharVector varCharVector = (VarCharVector) curFieldVector;
                             for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
@@ -404,10 +395,8 @@ public class RowBatch {
                             }
                         } else if (curFieldVector instanceof TimeStampVector) {
                             TimeStampVector timeStampVector = (TimeStampVector) curFieldVector;
-
                             for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                                 if (timeStampVector.isNull(rowIndex)) {
-
                                     addValueToRow(rowIndex, null);
                                     continue;
                                 }
@@ -415,7 +404,9 @@ public class RowBatch {
                                 String formatted = DATE_TIME_FORMATTER.format(dateTime);
                                 addValueToRow(rowIndex, formatted);
                             }
-
+                        } else {
+                            String errMsg = String.format("Unsupported type for DATETIMEV2, minorType %s, class is %s", mt.name(), curFieldVector.getClass());
+                            throw new java.lang.IllegalArgumentException(errMsg);
                         }
                         break;
                     case "CHAR":
diff --git a/spark-doris-connector/src/test/java/org/apache/doris/spark/serialization/TestRowBatch.java b/spark-doris-connector/src/test/java/org/apache/doris/spark/serialization/TestRowBatch.java
index f12014f..20387e4 100644
--- a/spark-doris-connector/src/test/java/org/apache/doris/spark/serialization/TestRowBatch.java
+++ b/spark-doris-connector/src/test/java/org/apache/doris/spark/serialization/TestRowBatch.java
@@ -17,24 +17,13 @@
 
 package org.apache.doris.spark.serialization;
 
-import org.apache.arrow.vector.DateDayVector;
-import org.apache.arrow.vector.TimeStampMicroVector;
-import org.apache.arrow.vector.UInt4Vector;
-import org.apache.arrow.vector.types.DateUnit;
-import org.apache.arrow.vector.types.TimeUnit;
-import org.apache.doris.sdk.thrift.TScanBatchResult;
-import org.apache.doris.sdk.thrift.TStatus;
-import org.apache.doris.sdk.thrift.TStatusCode;
-import org.apache.doris.spark.exception.DorisException;
-import org.apache.doris.spark.rest.RestService;
-import org.apache.doris.spark.rest.models.Schema;
-
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import org.apache.arrow.memory.ArrowBuf;
 import org.apache.arrow.memory.RootAllocator;
 import org.apache.arrow.vector.BigIntVector;
 import org.apache.arrow.vector.BitVector;
+import org.apache.arrow.vector.DateDayVector;
 import org.apache.arrow.vector.DecimalVector;
 import org.apache.arrow.vector.FieldVector;
 import org.apache.arrow.vector.FixedSizeBinaryVector;
@@ -42,7 +31,14 @@ import org.apache.arrow.vector.Float4Vector;
 import org.apache.arrow.vector.Float8Vector;
 import org.apache.arrow.vector.IntVector;
 import org.apache.arrow.vector.SmallIntVector;
+import org.apache.arrow.vector.TimeStampMicroTZVector;
+import org.apache.arrow.vector.TimeStampMicroVector;
+import org.apache.arrow.vector.TimeStampMilliTZVector;
+import org.apache.arrow.vector.TimeStampMilliVector;
+import org.apache.arrow.vector.TimeStampSecTZVector;
+import org.apache.arrow.vector.TimeStampSecVector;
 import org.apache.arrow.vector.TinyIntVector;
+import org.apache.arrow.vector.UInt4Vector;
 import org.apache.arrow.vector.VarBinaryVector;
 import org.apache.arrow.vector.VarCharVector;
 import org.apache.arrow.vector.VectorSchemaRoot;
@@ -52,11 +48,19 @@ import org.apache.arrow.vector.complex.impl.NullableStructWriter;
 import org.apache.arrow.vector.complex.impl.UnionMapWriter;
 import org.apache.arrow.vector.dictionary.DictionaryProvider;
 import org.apache.arrow.vector.ipc.ArrowStreamWriter;
+import org.apache.arrow.vector.types.DateUnit;
 import org.apache.arrow.vector.types.FloatingPointPrecision;
+import org.apache.arrow.vector.types.TimeUnit;
 import org.apache.arrow.vector.types.pojo.ArrowType;
 import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.arrow.vector.types.pojo.FieldType;
 import org.apache.commons.lang3.ArrayUtils;
+import org.apache.doris.sdk.thrift.TScanBatchResult;
+import org.apache.doris.sdk.thrift.TStatus;
+import org.apache.doris.sdk.thrift.TStatusCode;
+import org.apache.doris.spark.exception.DorisException;
+import org.apache.doris.spark.rest.RestService;
+import org.apache.doris.spark.rest.models.Schema;
 import org.apache.spark.sql.types.Decimal;
 import org.junit.Assert;
 import org.junit.Rule;
@@ -74,6 +78,8 @@ import java.nio.charset.StandardCharsets;
 import java.sql.Date;
 import java.time.LocalDateTime;
 import java.time.ZoneId;
+import java.time.format.DateTimeFormatter;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 import java.util.NoSuchElementException;
@@ -1163,4 +1169,203 @@ public class TestRowBatch {
         thrown.expectMessage(startsWith("Get row offset:"));
         rowBatch.next();
     }
+
+    @Test
+    public void timestampVector() throws IOException, DorisException {
+        List<Field> childrenBuilder = new ArrayList<>();
+        childrenBuilder.add(
+                new Field(
+                        "k0",
+                        FieldType.nullable(new 
ArrowType.Timestamp(TimeUnit.MICROSECOND, null)),
+                        null));
+        childrenBuilder.add(
+                new Field(
+                        "k1",
+                        FieldType.nullable(new 
ArrowType.Timestamp(TimeUnit.MILLISECOND, null)),
+                        null));
+        childrenBuilder.add(
+                new Field(
+                        "k2",
+                        FieldType.nullable(new 
ArrowType.Timestamp(TimeUnit.SECOND, null)),
+                        null));
+        childrenBuilder.add(
+                new Field(
+                        "k3",
+                        FieldType.nullable(new 
ArrowType.Timestamp(TimeUnit.MICROSECOND, "UTC")),
+                        null));
+        childrenBuilder.add(
+                new Field(
+                        "k4",
+                        FieldType.nullable(new 
ArrowType.Timestamp(TimeUnit.MILLISECOND, "UTC")),
+                        null));
+        childrenBuilder.add(
+                new Field(
+                        "k5",
+                        FieldType.nullable(new 
ArrowType.Timestamp(TimeUnit.SECOND, "UTC")),
+                        null));
+
+        VectorSchemaRoot root =
+                VectorSchemaRoot.create(
+                        new 
org.apache.arrow.vector.types.pojo.Schema(childrenBuilder, null),
+                        new RootAllocator(Integer.MAX_VALUE));
+        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+        ArrowStreamWriter arrowStreamWriter =
+                new ArrowStreamWriter(
+                        root, new DictionaryProvider.MapDictionaryProvider(), 
outputStream);
+
+        arrowStreamWriter.start();
+        root.setRowCount(1);
+
+        FieldVector vector = root.getVector("k0");
+        TimeStampMicroVector mircoVec = (TimeStampMicroVector) vector;
+        mircoVec.allocateNew(1);
+        mircoVec.setIndexDefined(0);
+        mircoVec.setSafe(0, 1721892143586123L);
+        vector.setValueCount(1);
+
+        vector = root.getVector("k1");
+        TimeStampMilliVector milliVector = (TimeStampMilliVector) vector;
+        milliVector.allocateNew(1);
+        milliVector.setIndexDefined(0);
+        milliVector.setSafe(0, 1721892143586L);
+        vector.setValueCount(1);
+
+        vector = root.getVector("k2");
+        TimeStampSecVector secVector = (TimeStampSecVector) vector;
+        secVector.allocateNew(1);
+        secVector.setIndexDefined(0);
+        secVector.setSafe(0, 1721892143L);
+        vector.setValueCount(1);
+
+        vector = root.getVector("k3");
+        TimeStampMicroTZVector mircoTZVec = (TimeStampMicroTZVector) vector;
+        mircoTZVec.allocateNew(1);
+        mircoTZVec.setIndexDefined(0);
+        mircoTZVec.setSafe(0, 1721892143586123L);
+        vector.setValueCount(1);
+
+        vector = root.getVector("k4");
+        TimeStampMilliTZVector milliTZVector = (TimeStampMilliTZVector) vector;
+        milliTZVector.allocateNew(1);
+        milliTZVector.setIndexDefined(0);
+        milliTZVector.setSafe(0, 1721892143586L);
+        vector.setValueCount(1);
+
+        vector = root.getVector("k5");
+        TimeStampSecTZVector secTZVector = (TimeStampSecTZVector) vector;
+        secTZVector.allocateNew(1);
+        secTZVector.setIndexDefined(0);
+        secTZVector.setSafe(0, 1721892143L);
+        vector.setValueCount(1);
+
+        arrowStreamWriter.writeBatch();
+
+        arrowStreamWriter.end();
+        arrowStreamWriter.close();
+
+        TStatus status = new TStatus();
+        status.setStatusCode(TStatusCode.OK);
+        TScanBatchResult scanBatchResult = new TScanBatchResult();
+        scanBatchResult.setStatus(status);
+        scanBatchResult.setEos(false);
+        scanBatchResult.setRows(outputStream.toByteArray());
+
+        String schemaStr =
+                "{\"properties\":[{\"type\":\"DATETIME\",\"name\":\"k0\",\"comment\":\"\"}, {\"type\":\"DATETIME\",\"name\":\"k1\",\"comment\":\"\"}, {\"type\":\"DATETIME\",\"name\":\"k2\",\"comment\":\"\"}, {\"type\":\"DATETIME\",\"name\":\"k3\",\"comment\":\"\"}, {\"type\":\"DATETIME\",\"name\":\"k4\",\"comment\":\"\"}, {\"type\":\"DATETIME\",\"name\":\"k5\",\"comment\":\"\"}],"
+                        + "\"status\":200}";
+
+        Schema schema = RestService.parseSchema(schemaStr, logger);
+        RowBatch rowBatch = new RowBatch(scanBatchResult, schema);
+        List<Object> next = rowBatch.next();
+        Assert.assertEquals(next.size(), 6);
+        Assert.assertEquals(
+                next.get(0),
+                LocalDateTime.of(2024, 7, 25, 15, 22, 23, 586123000)
+                        .atZone(ZoneId.of("UTC+8"))
+                        .withZoneSameInstant(ZoneId.systemDefault())
+                        .toLocalDateTime().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss.SSSSSS")));
+        Assert.assertEquals(
+                next.get(1),
+                LocalDateTime.of(2024, 7, 25, 15, 22, 23, 586000000)
+                        .atZone(ZoneId.of("UTC+8"))
+                        .withZoneSameInstant(ZoneId.systemDefault())
+                        .toLocalDateTime().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss.SSS")));
+        Assert.assertEquals(
+                next.get(2),
+                LocalDateTime.of(2024, 7, 25, 15, 22, 23, 0)
+                        .atZone(ZoneId.of("UTC+8"))
+                        .withZoneSameInstant(ZoneId.systemDefault())
+                        .toLocalDateTime().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")));
+        Assert.assertEquals(
+                next.get(3),
+                LocalDateTime.of(2024, 7, 25, 15, 22, 23, 586123000)
+                        .atZone(ZoneId.of("UTC+8"))
+                        .withZoneSameInstant(ZoneId.systemDefault())
+                        .toLocalDateTime().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss.SSSSSS")));
+        Assert.assertEquals(
+                next.get(4),
+                LocalDateTime.of(2024, 7, 25, 15, 22, 23, 586000000)
+                        .atZone(ZoneId.of("UTC+8"))
+                        .withZoneSameInstant(ZoneId.systemDefault())
+                        .toLocalDateTime().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss.SSS")));
+        Assert.assertEquals(
+                next.get(5),
+                LocalDateTime.of(2024, 7, 25, 15, 22, 23, 0)
+                        .atZone(ZoneId.of("UTC+8"))
+                        .withZoneSameInstant(ZoneId.systemDefault())
+                        .toLocalDateTime().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")));
+    }
+
+    @Test
+    public void timestampTypeNotMatch() throws IOException, DorisException {
+        List<Field> childrenBuilder = new ArrayList<>();
+        childrenBuilder.add(
+                new Field(
+                        "k0",
+                        FieldType.nullable(new ArrowType.Int(32, false)),
+                        null));
+
+        VectorSchemaRoot root =
+                VectorSchemaRoot.create(
+                        new 
org.apache.arrow.vector.types.pojo.Schema(childrenBuilder, null),
+                        new RootAllocator(Integer.MAX_VALUE));
+        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+        ArrowStreamWriter arrowStreamWriter =
+                new ArrowStreamWriter(
+                        root, new DictionaryProvider.MapDictionaryProvider(), 
outputStream);
+
+        arrowStreamWriter.start();
+        root.setRowCount(1);
+
+        FieldVector vector = root.getVector("k0");
+        UInt4Vector uInt4Vector = (UInt4Vector) vector;
+        uInt4Vector.setInitialCapacity(1);
+        uInt4Vector.allocateNew();
+        uInt4Vector.setIndexDefined(0);
+        uInt4Vector.setSafe(0, 0);
+
+        vector.setValueCount(1);
+        arrowStreamWriter.writeBatch();
+
+        arrowStreamWriter.end();
+        arrowStreamWriter.close();
+
+        TStatus status = new TStatus();
+        status.setStatusCode(TStatusCode.OK);
+        TScanBatchResult scanBatchResult = new TScanBatchResult();
+        scanBatchResult.setStatus(status);
+        scanBatchResult.setEos(false);
+        scanBatchResult.setRows(outputStream.toByteArray());
+
+        String schemaStr =
+                "{\"properties\":["
+                        + "{\"type\":\"DATETIMEV2\",\"name\":\"k0\",\"comment\":\"\"}"
+                        + "], \"status\":200}";
+
+        Schema schema = RestService.parseSchema(schemaStr, logger);
+        thrown.expect(DorisException.class);
+        thrown.expectMessage(startsWith("Unsupported type for DATETIMEV2"));
+        new RowBatch(scanBatchResult, schema);
+    }
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to