This is an automated email from the ASF dual-hosted git repository.

diwu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris-spark-connector.git
The following commit(s) were added to refs/heads/master by this push:
     new f720167  [Fix](reader) fix arrow timezone bug (#232)
f720167 is described below

commit f7201676f9263b3eb1cf19bf1b27decbd8774610
Author: wudi <676366...@qq.com>
AuthorDate: Wed Sep 25 11:42:29 2024 +0800

    [Fix](reader) fix arrow timezone bug (#232)
---
 .../apache/doris/spark/serialization/RowBatch.java |  23 +--
 .../doris/spark/serialization/TestRowBatch.java    | 229 +++++++++++++++++++--
 2 files changed, 224 insertions(+), 28 deletions(-)

diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
index c3e70e9..ed61b59 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
@@ -17,11 +17,6 @@
 
 package org.apache.doris.spark.serialization;
 
-import org.apache.doris.sdk.thrift.TScanBatchResult;
-import org.apache.doris.spark.exception.DorisException;
-import org.apache.doris.spark.rest.models.Schema;
-import org.apache.doris.spark.util.IPUtils;
-
 import com.google.common.base.Preconditions;
 import org.apache.arrow.memory.RootAllocator;
 import org.apache.arrow.vector.BaseIntVector;
@@ -48,8 +43,11 @@ import org.apache.arrow.vector.complex.impl.UnionMapReader;
 import org.apache.arrow.vector.ipc.ArrowReader;
 import org.apache.arrow.vector.ipc.ArrowStreamReader;
 import org.apache.arrow.vector.types.Types;
-import org.apache.arrow.vector.types.Types.MinorType;
 import org.apache.commons.lang3.ArrayUtils;
+import org.apache.doris.sdk.thrift.TScanBatchResult;
+import org.apache.doris.spark.exception.DorisException;
+import org.apache.doris.spark.rest.models.Schema;
+import org.apache.doris.spark.util.IPUtils;
 import org.apache.spark.sql.types.Decimal;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -385,13 +383,6 @@ public class RowBatch {
                     break;
                 case "DATETIME":
                 case "DATETIMEV2":
-
-                    Preconditions.checkArgument(
-                            mt.equals(Types.MinorType.TIMESTAMPMICRO) || mt.equals(MinorType.VARCHAR) ||
-                                    mt.equals(MinorType.TIMESTAMPMILLI) || mt.equals(MinorType.TIMESTAMPSEC),
-                            typeMismatchMessage(currentType, mt));
-                    typeMismatchMessage(currentType, mt);
-
                     if (mt.equals(Types.MinorType.VARCHAR)) {
                         VarCharVector varCharVector = (VarCharVector) curFieldVector;
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
@@ -404,10 +395,8 @@
                         }
                     } else if (curFieldVector instanceof TimeStampVector) {
                         TimeStampVector timeStampVector = (TimeStampVector) curFieldVector;
-
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                             if (timeStampVector.isNull(rowIndex)) {
-
                                 addValueToRow(rowIndex, null);
                                 continue;
                             }
@@ -415,7 +404,9 @@
                             String formatted = DATE_TIME_FORMATTER.format(dateTime);
                             addValueToRow(rowIndex, formatted);
                         }
-
+                    } else {
+                        String errMsg = String.format("Unsupported type for DATETIMEV2, minorType %s, class is %s", mt.name(), curFieldVector.getClass());
+                        throw new java.lang.IllegalArgumentException(errMsg);
                     }
                     break;
                 case "CHAR":
diff --git a/spark-doris-connector/src/test/java/org/apache/doris/spark/serialization/TestRowBatch.java b/spark-doris-connector/src/test/java/org/apache/doris/spark/serialization/TestRowBatch.java
index f12014f..20387e4 100644
--- a/spark-doris-connector/src/test/java/org/apache/doris/spark/serialization/TestRowBatch.java
+++ b/spark-doris-connector/src/test/java/org/apache/doris/spark/serialization/TestRowBatch.java
@@ -17,24 +17,13 @@
 
 package org.apache.doris.spark.serialization;
 
-import org.apache.arrow.vector.DateDayVector;
-import org.apache.arrow.vector.TimeStampMicroVector;
-import org.apache.arrow.vector.UInt4Vector;
-import org.apache.arrow.vector.types.DateUnit;
-import org.apache.arrow.vector.types.TimeUnit;
-import org.apache.doris.sdk.thrift.TScanBatchResult;
-import org.apache.doris.sdk.thrift.TStatus;
-import org.apache.doris.sdk.thrift.TStatusCode;
-import org.apache.doris.spark.exception.DorisException;
-import org.apache.doris.spark.rest.RestService;
-import org.apache.doris.spark.rest.models.Schema;
-
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import org.apache.arrow.memory.ArrowBuf;
 import org.apache.arrow.memory.RootAllocator;
 import org.apache.arrow.vector.BigIntVector;
 import org.apache.arrow.vector.BitVector;
+import org.apache.arrow.vector.DateDayVector;
 import org.apache.arrow.vector.DecimalVector;
 import org.apache.arrow.vector.FieldVector;
 import org.apache.arrow.vector.FixedSizeBinaryVector;
@@ -42,7 +31,14 @@ import org.apache.arrow.vector.Float4Vector;
 import org.apache.arrow.vector.Float8Vector;
 import org.apache.arrow.vector.IntVector;
 import org.apache.arrow.vector.SmallIntVector;
+import org.apache.arrow.vector.TimeStampMicroTZVector;
+import org.apache.arrow.vector.TimeStampMicroVector;
+import org.apache.arrow.vector.TimeStampMilliTZVector;
+import org.apache.arrow.vector.TimeStampMilliVector;
+import org.apache.arrow.vector.TimeStampSecTZVector;
+import org.apache.arrow.vector.TimeStampSecVector;
 import org.apache.arrow.vector.TinyIntVector;
+import org.apache.arrow.vector.UInt4Vector;
 import org.apache.arrow.vector.VarBinaryVector;
 import org.apache.arrow.vector.VarCharVector;
 import org.apache.arrow.vector.VectorSchemaRoot;
@@ -52,11 +48,19 @@ import org.apache.arrow.vector.complex.impl.NullableStructWriter;
 import org.apache.arrow.vector.complex.impl.UnionMapWriter;
 import org.apache.arrow.vector.dictionary.DictionaryProvider;
 import org.apache.arrow.vector.ipc.ArrowStreamWriter;
+import org.apache.arrow.vector.types.DateUnit;
 import org.apache.arrow.vector.types.FloatingPointPrecision;
+import org.apache.arrow.vector.types.TimeUnit;
 import org.apache.arrow.vector.types.pojo.ArrowType;
 import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.arrow.vector.types.pojo.FieldType;
 import org.apache.commons.lang3.ArrayUtils;
+import org.apache.doris.sdk.thrift.TScanBatchResult;
+import org.apache.doris.sdk.thrift.TStatus;
+import org.apache.doris.sdk.thrift.TStatusCode;
+import org.apache.doris.spark.exception.DorisException;
+import org.apache.doris.spark.rest.RestService;
+import org.apache.doris.spark.rest.models.Schema;
 import org.apache.spark.sql.types.Decimal;
 import org.junit.Assert;
 import org.junit.Rule;
@@ -74,6 +78,8 @@ import java.nio.charset.StandardCharsets;
 import java.sql.Date;
 import java.time.LocalDateTime;
 import java.time.ZoneId;
+import java.time.format.DateTimeFormatter;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 import java.util.NoSuchElementException;
@@ -1163,4 +1169,203 @@ public class TestRowBatch {
         thrown.expectMessage(startsWith("Get row offset:"));
         rowBatch.next();
     }
+
+    @Test
+    public void timestampVector() throws IOException, DorisException {
+        List<Field> childrenBuilder = new ArrayList<>();
+        childrenBuilder.add(
+                new Field(
+                        "k0",
+                        FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)),
+                        null));
+        childrenBuilder.add(
+                new Field(
+                        "k1",
+                        FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MILLISECOND, null)),
+                        null));
+        childrenBuilder.add(
+                new Field(
+                        "k2",
+                        FieldType.nullable(new ArrowType.Timestamp(TimeUnit.SECOND, null)),
+                        null));
+        childrenBuilder.add(
+                new Field(
+                        "k3",
+                        FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MICROSECOND, "UTC")),
+                        null));
+        childrenBuilder.add(
+                new Field(
+                        "k4",
+                        FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MILLISECOND, "UTC")),
+                        null));
+        childrenBuilder.add(
+                new Field(
+                        "k5",
+                        FieldType.nullable(new ArrowType.Timestamp(TimeUnit.SECOND, "UTC")),
+                        null));
+
+        VectorSchemaRoot root =
+                VectorSchemaRoot.create(
+                        new org.apache.arrow.vector.types.pojo.Schema(childrenBuilder, null),
+                        new RootAllocator(Integer.MAX_VALUE));
+        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+        ArrowStreamWriter arrowStreamWriter =
+                new ArrowStreamWriter(
+                        root, new DictionaryProvider.MapDictionaryProvider(), outputStream);
+
+        arrowStreamWriter.start();
+        root.setRowCount(1);
+
+        FieldVector vector = root.getVector("k0");
+        TimeStampMicroVector mircoVec = (TimeStampMicroVector) vector;
+        mircoVec.allocateNew(1);
+        mircoVec.setIndexDefined(0);
+        mircoVec.setSafe(0, 1721892143586123L);
+        vector.setValueCount(1);
+
+        vector = root.getVector("k1");
+        TimeStampMilliVector milliVector = (TimeStampMilliVector) vector;
+        milliVector.allocateNew(1);
+        milliVector.setIndexDefined(0);
+        milliVector.setSafe(0, 1721892143586L);
+        vector.setValueCount(1);
+
+        vector = root.getVector("k2");
+        TimeStampSecVector secVector = (TimeStampSecVector) vector;
+        secVector.allocateNew(1);
+        secVector.setIndexDefined(0);
+        secVector.setSafe(0, 1721892143L);
+        vector.setValueCount(1);
+
+        vector = root.getVector("k3");
+        TimeStampMicroTZVector mircoTZVec = (TimeStampMicroTZVector) vector;
+        mircoTZVec.allocateNew(1);
+        mircoTZVec.setIndexDefined(0);
+        mircoTZVec.setSafe(0, 1721892143586123L);
+        vector.setValueCount(1);
+
+        vector = root.getVector("k4");
+        TimeStampMilliTZVector milliTZVector = (TimeStampMilliTZVector) vector;
+        milliTZVector.allocateNew(1);
+        milliTZVector.setIndexDefined(0);
+        milliTZVector.setSafe(0, 1721892143586L);
+        vector.setValueCount(1);
+
+        vector = root.getVector("k5");
+        TimeStampSecTZVector secTZVector = (TimeStampSecTZVector) vector;
+        secTZVector.allocateNew(1);
+        secTZVector.setIndexDefined(0);
+        secTZVector.setSafe(0, 1721892143L);
+        vector.setValueCount(1);
+
+        arrowStreamWriter.writeBatch();
+
+        arrowStreamWriter.end();
+        arrowStreamWriter.close();
+
+        TStatus status = new TStatus();
+        status.setStatusCode(TStatusCode.OK);
+        TScanBatchResult scanBatchResult = new TScanBatchResult();
+        scanBatchResult.setStatus(status);
+        scanBatchResult.setEos(false);
+        scanBatchResult.setRows(outputStream.toByteArray());
+
+        String schemaStr =
+                "{\"properties\":[{\"type\":\"DATETIME\",\"name\":\"k0\",\"comment\":\"\"}, {\"type\":\"DATETIME\",\"name\":\"k1\",\"comment\":\"\"}, {\"type\":\"DATETIME\",\"name\":\"k2\",\"comment\":\"\"}, {\"type\":\"DATETIME\",\"name\":\"k3\",\"comment\":\"\"}, {\"type\":\"DATETIME\",\"name\":\"k4\",\"comment\":\"\"}, {\"type\":\"DATETIME\",\"name\":\"k5\",\"comment\":\"\"}],"
+                        + "\"status\":200}";
+
+        Schema schema = RestService.parseSchema(schemaStr, logger);
+        RowBatch rowBatch = new RowBatch(scanBatchResult, schema);
+        List<Object> next = rowBatch.next();
+        Assert.assertEquals(next.size(), 6);
+        Assert.assertEquals(
+                next.get(0),
+                LocalDateTime.of(2024, 7, 25, 15, 22, 23, 586123000)
+                        .atZone(ZoneId.of("UTC+8"))
+                        .withZoneSameInstant(ZoneId.systemDefault())
+                        .toLocalDateTime().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss.SSSSSS")));
+        Assert.assertEquals(
+                next.get(1),
+                LocalDateTime.of(2024, 7, 25, 15, 22, 23, 586000000)
+                        .atZone(ZoneId.of("UTC+8"))
+                        .withZoneSameInstant(ZoneId.systemDefault())
+                        .toLocalDateTime().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss.SSS")));
+        Assert.assertEquals(
+                next.get(2),
+                LocalDateTime.of(2024, 7, 25, 15, 22, 23, 0)
+                        .atZone(ZoneId.of("UTC+8"))
+                        .withZoneSameInstant(ZoneId.systemDefault())
+                        .toLocalDateTime().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")));
+        Assert.assertEquals(
+                next.get(3),
+                LocalDateTime.of(2024, 7, 25, 15, 22, 23, 586123000)
+                        .atZone(ZoneId.of("UTC+8"))
+                        .withZoneSameInstant(ZoneId.systemDefault())
+                        .toLocalDateTime().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss.SSSSSS")));
+        Assert.assertEquals(
+                next.get(4),
+                LocalDateTime.of(2024, 7, 25, 15, 22, 23, 586000000)
+                        .atZone(ZoneId.of("UTC+8"))
+                        .withZoneSameInstant(ZoneId.systemDefault())
+                        .toLocalDateTime().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss.SSS")));
+        Assert.assertEquals(
+                next.get(5),
+                LocalDateTime.of(2024, 7, 25, 15, 22, 23, 0)
+                        .atZone(ZoneId.of("UTC+8"))
+                        .withZoneSameInstant(ZoneId.systemDefault())
+                        .toLocalDateTime().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")));
+    }
+
+    @Test
+    public void timestampTypeNotMatch() throws IOException, DorisException {
+        List<Field> childrenBuilder = new ArrayList<>();
+        childrenBuilder.add(
+                new Field(
+                        "k0",
+                        FieldType.nullable(new ArrowType.Int(32, false)),
+                        null));
+
+        VectorSchemaRoot root =
+                VectorSchemaRoot.create(
+                        new org.apache.arrow.vector.types.pojo.Schema(childrenBuilder, null),
+                        new RootAllocator(Integer.MAX_VALUE));
+        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+        ArrowStreamWriter arrowStreamWriter =
+                new ArrowStreamWriter(
+                        root, new DictionaryProvider.MapDictionaryProvider(), outputStream);
+
+        arrowStreamWriter.start();
+        root.setRowCount(1);
+
+        FieldVector vector = root.getVector("k0");
+        UInt4Vector uInt4Vector = (UInt4Vector) vector;
+        uInt4Vector.setInitialCapacity(1);
+        uInt4Vector.allocateNew();
+        uInt4Vector.setIndexDefined(0);
+        uInt4Vector.setSafe(0, 0);
+
+        vector.setValueCount(1);
+        arrowStreamWriter.writeBatch();
+
+        arrowStreamWriter.end();
+        arrowStreamWriter.close();
+
+        TStatus status = new TStatus();
+        status.setStatusCode(TStatusCode.OK);
+        TScanBatchResult scanBatchResult = new TScanBatchResult();
+        scanBatchResult.setStatus(status);
+        scanBatchResult.setEos(false);
+        scanBatchResult.setRows(outputStream.toByteArray());
+
+        String schemaStr =
+                "{\"properties\":["
+                        + "{\"type\":\"DATETIMEV2\",\"name\":\"k0\",\"comment\":\"\"}"
+                        + "], \"status\":200}";
+
+        Schema schema = RestService.parseSchema(schemaStr, logger);
+        thrown.expect(DorisException.class);
+        thrown.expectMessage(startsWith("Unsupported type for DATETIMEV2"));
+        new RowBatch(scanBatchResult, schema);
+    }
 }
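A note on the conversion this change exercises: Doris returns DATETIME/DATETIMEV2
columns as Arrow timestamp vectors (second, milli, or micro precision, with or
without a timezone id), and the patched reader accepts any TimeStampVector
subclass and renders each value in the JVM's default zone. The sketch below is a
minimal, standalone illustration of the epoch-microsecond arithmetic involved,
using the same sample value as timestampVector() above; the class name is
hypothetical and the formatter pattern is inferred from the test assertions, so
treat it as an approximation rather than connector code.

    import java.time.Instant;
    import java.time.LocalDateTime;
    import java.time.ZoneId;
    import java.time.format.DateTimeFormatter;

    // Hypothetical demo class, not part of the connector.
    public class EpochMicrosDemo {
        public static void main(String[] args) {
            long epochMicros = 1721892143586123L; // sample value from the test above
            // Split into whole seconds plus the microsecond remainder (as nanos).
            Instant instant = Instant.ofEpochSecond(
                    epochMicros / 1_000_000L,
                    (epochMicros % 1_000_000L) * 1_000L);
            // Render in the JVM default zone, as the fixed reader does.
            String formatted = LocalDateTime.ofInstant(instant, ZoneId.systemDefault())
                    .format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss.SSSSSS"));
            System.out.println(formatted); // prints 2024-07-25 15:22:23.586123 under UTC+8
        }
    }

Run with a UTC+8 default zone this prints the exact string the first assertion
in timestampVector() expects, which is how the test pins down the timezone
handling this commit fixes.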