This is an automated email from the ASF dual-hosted git repository.
diwu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris-spark-connector.git
The following commit(s) were added to refs/heads/master by this push:
new 7a45abe [feature] support read map and struct type (#116)
7a45abe is described below
commit 7a45abe3f955e938f6161e95c43a911464996b5a
Author: gnehil <[email protected]>
AuthorDate: Tue Oct 17 11:58:35 2023 +0800
[feature] support read map and struct type (#116)
---
.../apache/doris/spark/serialization/RowBatch.java | 37 +++++
.../org/apache/doris/spark/sql/SchemaUtils.scala | 3 +-
.../doris/spark/serialization/TestRowBatch.java | 180 +++++++++++++++++++--
.../doris/spark/sql/TestSparkConnector.scala | 1 +
4 files changed, 203 insertions(+), 18 deletions(-)
diff --git
a/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
b/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
index 3d66db5..b43b0a2 100644
---
a/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
+++
b/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
@@ -37,12 +37,16 @@ import org.apache.arrow.vector.VarBinaryVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.complex.ListVector;
+import org.apache.arrow.vector.complex.MapVector;
+import org.apache.arrow.vector.complex.StructVector;
+import org.apache.arrow.vector.complex.impl.UnionMapReader;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.types.Types;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.spark.sql.types.Decimal;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import scala.collection.JavaConverters;
import java.io.ByteArrayInputStream;
import java.io.IOException;
@@ -52,7 +56,9 @@ import java.nio.charset.StandardCharsets;
import java.sql.Date;
import java.time.LocalDate;
import java.util.ArrayList;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import java.util.NoSuchElementException;
/**
@@ -338,6 +344,37 @@ public class RowBatch {
addValueToRow(rowIndex, value);
}
break;
+                    case "MAP":
+                        Preconditions.checkArgument(mt.equals(Types.MinorType.MAP),
+                                typeMismatchMessage(currentType, mt));
+                        MapVector mapVector = (MapVector) curFieldVector;
+                        UnionMapReader reader = mapVector.getReader();
+                        for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
+                            if (mapVector.isNull(rowIndex)) {
+                                addValueToRow(rowIndex, null);
+                                continue;
+                            }
+                            reader.setPosition(rowIndex);
+                            Map<String, String> value = new HashMap<>();
+                            while (reader.next()) {
+                                value.put(reader.key().readObject().toString(), reader.value().readObject().toString());
+                            }
+                            addValueToRow(rowIndex, JavaConverters.mapAsScalaMapConverter(value).asScala());
+                        }
+                        break;
+                    case "STRUCT":
+                        Preconditions.checkArgument(mt.equals(Types.MinorType.STRUCT),
+                                typeMismatchMessage(currentType, mt));
+                        StructVector structVector = (StructVector) curFieldVector;
+                        for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
+                            if (structVector.isNull(rowIndex)) {
+                                addValueToRow(rowIndex, null);
+                                continue;
+                            }
+                            String value = structVector.getObject(rowIndex).toString();
+                            addValueToRow(rowIndex, value);
+                        }
+                        break;
default:
String errMsg = "Unsupported type " +
schema.get(col).getType();
logger.error(errMsg);
diff --git
a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
index 677cc2e..44baa95 100644
---
a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
+++
b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
@@ -32,7 +32,6 @@ import org.slf4j.LoggerFactory
import java.sql.Timestamp
import java.time.{LocalDateTime, ZoneOffset}
import scala.collection.JavaConversions._
-import scala.collection.JavaConverters._
import scala.collection.mutable
private[spark] object SchemaUtils {
@@ -126,6 +125,8 @@ private[spark] object SchemaUtils {
case "TIME" => DataTypes.DoubleType
case "STRING" => DataTypes.StringType
case "ARRAY" => DataTypes.StringType
+      case "MAP" => MapType(DataTypes.StringType, DataTypes.StringType)
+      case "STRUCT" => DataTypes.StringType
case "HLL" =>
throw new DorisException("Unsupported type " + dorisType)
case _ =>
diff --git
a/spark-doris-connector/src/test/java/org/apache/doris/spark/serialization/TestRowBatch.java
b/spark-doris-connector/src/test/java/org/apache/doris/spark/serialization/TestRowBatch.java
index ace928f..cb7e0b8 100644
---
a/spark-doris-connector/src/test/java/org/apache/doris/spark/serialization/TestRowBatch.java
+++
b/spark-doris-connector/src/test/java/org/apache/doris/spark/serialization/TestRowBatch.java
@@ -25,6 +25,8 @@ import org.apache.doris.spark.rest.RestService;
import org.apache.doris.spark.rest.models.Schema;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.BitVector;
@@ -39,6 +41,10 @@ import org.apache.arrow.vector.TinyIntVector;
import org.apache.arrow.vector.VarBinaryVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.complex.MapVector;
+import org.apache.arrow.vector.complex.StructVector;
+import org.apache.arrow.vector.complex.impl.NullableStructWriter;
+import org.apache.arrow.vector.complex.impl.UnionMapWriter;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.types.FloatingPointPrecision;
@@ -53,11 +59,13 @@ import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import scala.collection.JavaConverters;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.math.BigDecimal;
import java.math.BigInteger;
+import java.nio.charset.StandardCharsets;
import java.sql.Date;
import java.util.Arrays;
import java.util.List;
@@ -100,7 +108,7 @@ public class TestRowBatch {
root.setRowCount(3);
FieldVector vector = root.getVector("k0");
- BitVector bitVector = (BitVector)vector;
+ BitVector bitVector = (BitVector) vector;
bitVector.setInitialCapacity(3);
bitVector.allocateNew(3);
bitVector.setSafe(0, 1);
@@ -109,7 +117,7 @@ public class TestRowBatch {
vector.setValueCount(3);
vector = root.getVector("k1");
- TinyIntVector tinyIntVector = (TinyIntVector)vector;
+ TinyIntVector tinyIntVector = (TinyIntVector) vector;
tinyIntVector.setInitialCapacity(3);
tinyIntVector.allocateNew(3);
tinyIntVector.setSafe(0, 1);
@@ -118,7 +126,7 @@ public class TestRowBatch {
vector.setValueCount(3);
vector = root.getVector("k2");
- SmallIntVector smallIntVector = (SmallIntVector)vector;
+ SmallIntVector smallIntVector = (SmallIntVector) vector;
smallIntVector.setInitialCapacity(3);
smallIntVector.allocateNew(3);
smallIntVector.setSafe(0, 1);
@@ -127,7 +135,7 @@ public class TestRowBatch {
vector.setValueCount(3);
vector = root.getVector("k3");
- IntVector intVector = (IntVector)vector;
+ IntVector intVector = (IntVector) vector;
intVector.setInitialCapacity(3);
intVector.allocateNew(3);
intVector.setSafe(0, 1);
@@ -136,7 +144,7 @@ public class TestRowBatch {
vector.setValueCount(3);
vector = root.getVector("k4");
- BigIntVector bigIntVector = (BigIntVector)vector;
+ BigIntVector bigIntVector = (BigIntVector) vector;
bigIntVector.setInitialCapacity(3);
bigIntVector.allocateNew(3);
bigIntVector.setSafe(0, 1);
@@ -145,7 +153,7 @@ public class TestRowBatch {
vector.setValueCount(3);
vector = root.getVector("k5");
- VarCharVector varCharVector = (VarCharVector)vector;
+ VarCharVector varCharVector = (VarCharVector) vector;
varCharVector.setInitialCapacity(3);
varCharVector.allocateNew();
varCharVector.setIndexDefined(0);
@@ -160,7 +168,7 @@ public class TestRowBatch {
vector.setValueCount(3);
vector = root.getVector("k6");
- VarCharVector charVector = (VarCharVector)vector;
+ VarCharVector charVector = (VarCharVector) vector;
charVector.setInitialCapacity(3);
charVector.allocateNew();
charVector.setIndexDefined(0);
@@ -175,7 +183,7 @@ public class TestRowBatch {
vector.setValueCount(3);
vector = root.getVector("k8");
- Float8Vector float8Vector = (Float8Vector)vector;
+ Float8Vector float8Vector = (Float8Vector) vector;
float8Vector.setInitialCapacity(3);
float8Vector.allocateNew(3);
float8Vector.setSafe(0, 1.1);
@@ -184,7 +192,7 @@ public class TestRowBatch {
vector.setValueCount(3);
vector = root.getVector("k9");
- Float4Vector float4Vector = (Float4Vector)vector;
+ Float4Vector float4Vector = (Float4Vector) vector;
float4Vector.setInitialCapacity(3);
float4Vector.allocateNew(3);
float4Vector.setSafe(0, 1.1f);
@@ -193,7 +201,7 @@ public class TestRowBatch {
vector.setValueCount(3);
vector = root.getVector("k10");
- VarCharVector datecharVector = (VarCharVector)vector;
+ VarCharVector datecharVector = (VarCharVector) vector;
datecharVector.setInitialCapacity(3);
datecharVector.allocateNew();
datecharVector.setIndexDefined(0);
@@ -208,7 +216,7 @@ public class TestRowBatch {
vector.setValueCount(3);
vector = root.getVector("k11");
- VarCharVector timecharVector = (VarCharVector)vector;
+ VarCharVector timecharVector = (VarCharVector) vector;
timecharVector.setInitialCapacity(3);
timecharVector.allocateNew();
timecharVector.setIndexDefined(0);
@@ -364,15 +372,15 @@ public class TestRowBatch {
Assert.assertTrue(rowBatch.hasNext());
List<Object> actualRow0 = rowBatch.next();
- Assert.assertArrayEquals(binaryRow0, (byte[])actualRow0.get(0));
+ Assert.assertArrayEquals(binaryRow0, (byte[]) actualRow0.get(0));
Assert.assertTrue(rowBatch.hasNext());
List<Object> actualRow1 = rowBatch.next();
- Assert.assertArrayEquals(binaryRow1, (byte[])actualRow1.get(0));
+ Assert.assertArrayEquals(binaryRow1, (byte[]) actualRow1.get(0));
Assert.assertTrue(rowBatch.hasNext());
List<Object> actualRow2 = rowBatch.next();
- Assert.assertArrayEquals(binaryRow2, (byte[])actualRow2.get(0));
+ Assert.assertArrayEquals(binaryRow2, (byte[]) actualRow2.get(0));
Assert.assertFalse(rowBatch.hasNext());
thrown.expect(NoSuchElementException.class);
@@ -428,15 +436,15 @@ public class TestRowBatch {
Assert.assertTrue(rowBatch.hasNext());
List<Object> actualRow0 = rowBatch.next();
- Assert.assertEquals(Decimal.apply(12340000000L, 11, 9),
(Decimal)actualRow0.get(0));
+ Assert.assertEquals(Decimal.apply(12340000000L, 11, 9), (Decimal)
actualRow0.get(0));
Assert.assertTrue(rowBatch.hasNext());
List<Object> actualRow1 = rowBatch.next();
- Assert.assertEquals(Decimal.apply(88880000000L, 11, 9),
(Decimal)actualRow1.get(0));
+ Assert.assertEquals(Decimal.apply(88880000000L, 11, 9), (Decimal)
actualRow1.get(0));
Assert.assertTrue(rowBatch.hasNext());
List<Object> actualRow2 = rowBatch.next();
- Assert.assertEquals(Decimal.apply(10000000000L, 11, 9),
(Decimal)actualRow2.get(0));
+ Assert.assertEquals(Decimal.apply(10000000000L, 11, 9), (Decimal)
actualRow2.get(0));
Assert.assertFalse(rowBatch.hasNext());
thrown.expect(NoSuchElementException.class);
@@ -591,4 +599,142 @@ public class TestRowBatch {
}
+ @Test
+ public void testMap() throws IOException, DorisException {
+
+ ImmutableList<Field> mapChildren = ImmutableList.of(
+ new Field("child", new FieldType(false, new
ArrowType.Struct(), null),
+ ImmutableList.of(
+ new Field("key", new FieldType(false, new
ArrowType.Utf8(), null), null),
+ new Field("value", new FieldType(false, new
ArrowType.Int(32, true), null),
+ null)
+ )
+ ));
+
+ ImmutableList<Field> fields = ImmutableList.of(
+ new Field("col_map", new FieldType(false, new
ArrowType.Map(false), null),
+ mapChildren)
+ );
+
+ RootAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
+ VectorSchemaRoot root = VectorSchemaRoot.create(
+ new org.apache.arrow.vector.types.pojo.Schema(fields, null),
allocator);
+ ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+ ArrowStreamWriter arrowStreamWriter = new ArrowStreamWriter(
+ root,
+ new DictionaryProvider.MapDictionaryProvider(),
+ outputStream);
+
+ arrowStreamWriter.start();
+ root.setRowCount(3);
+
+ MapVector mapVector = (MapVector) root.getVector("col_map");
+ mapVector.allocateNew();
+ UnionMapWriter mapWriter = mapVector.getWriter();
+ for (int i = 0; i < 3; i++) {
+ mapWriter.setPosition(i);
+ mapWriter.startMap();
+ mapWriter.startEntry();
+ String key = "k" + (i + 1);
+ byte[] bytes = key.getBytes(StandardCharsets.UTF_8);
+ ArrowBuf buffer = allocator.buffer(bytes.length);
+ buffer.setBytes(0, bytes);
+ mapWriter.key().varChar().writeVarChar(0, bytes.length, buffer);
+ buffer.close();
+ mapWriter.value().integer().writeInt(i);
+ mapWriter.endEntry();
+ mapWriter.endMap();
+ }
+ mapWriter.setValueCount(3);
+
+ arrowStreamWriter.writeBatch();
+
+ arrowStreamWriter.end();
+ arrowStreamWriter.close();
+
+ TStatus status = new TStatus();
+ status.setStatusCode(TStatusCode.OK);
+ TScanBatchResult scanBatchResult = new TScanBatchResult();
+ scanBatchResult.setStatus(status);
+ scanBatchResult.setEos(false);
+ scanBatchResult.setRows(outputStream.toByteArray());
+
+ String schemaStr =
"{\"properties\":[{\"type\":\"MAP\",\"name\":\"col_map\",\"comment\":\"\"}" +
+ "], \"status\":200}";
+
+
+ Schema schema = RestService.parseSchema(schemaStr, logger);
+
+ RowBatch rowBatch = new RowBatch(scanBatchResult, schema);
+        Assert.assertTrue(rowBatch.hasNext());
+        Assert.assertEquals(JavaConverters.mapAsScalaMapConverter(ImmutableMap.of("k1", "0")).asScala(),
+                rowBatch.next().get(0));
+        Assert.assertTrue(rowBatch.hasNext());
+        Assert.assertEquals(JavaConverters.mapAsScalaMapConverter(ImmutableMap.of("k2", "1")).asScala(),
+                rowBatch.next().get(0));
+        Assert.assertTrue(rowBatch.hasNext());
+        Assert.assertEquals(JavaConverters.mapAsScalaMapConverter(ImmutableMap.of("k3", "2")).asScala(),
+                rowBatch.next().get(0));
+        Assert.assertFalse(rowBatch.hasNext());
+
+ }
+
+ @Test
+ public void testStruct() throws IOException, DorisException {
+
+ ImmutableList<Field> fields = ImmutableList.of(
+ new Field("col_struct", new FieldType(false, new
ArrowType.Struct(), null),
+ ImmutableList.of(new Field("a", new FieldType(false,
new ArrowType.Utf8(), null), null),
+ new Field("b", new FieldType(false, new
ArrowType.Int(32, true), null), null))
+ ));
+
+ RootAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
+ VectorSchemaRoot root = VectorSchemaRoot.create(
+ new org.apache.arrow.vector.types.pojo.Schema(fields, null),
allocator);
+ ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+ ArrowStreamWriter arrowStreamWriter = new ArrowStreamWriter(
+ root,
+ new DictionaryProvider.MapDictionaryProvider(),
+ outputStream);
+
+ arrowStreamWriter.start();
+ root.setRowCount(3);
+
+ StructVector structVector = (StructVector)
root.getVector("col_struct");
+ structVector.allocateNew();
+ NullableStructWriter writer = structVector.getWriter();
+ writer.setPosition(0);
+ writer.start();
+ byte[] bytes = "a1".getBytes(StandardCharsets.UTF_8);
+ ArrowBuf buffer = allocator.buffer(bytes.length);
+ buffer.setBytes(0, bytes);
+ writer.varChar("a").writeVarChar(0, bytes.length, buffer);
+ buffer.close();
+ writer.integer("b").writeInt(1);
+ writer.end();
+ writer.setValueCount(1);
+
+ arrowStreamWriter.writeBatch();
+
+ arrowStreamWriter.end();
+ arrowStreamWriter.close();
+
+ TStatus status = new TStatus();
+ status.setStatusCode(TStatusCode.OK);
+ TScanBatchResult scanBatchResult = new TScanBatchResult();
+ scanBatchResult.setStatus(status);
+ scanBatchResult.setEos(false);
+ scanBatchResult.setRows(outputStream.toByteArray());
+
+ String schemaStr =
"{\"properties\":[{\"type\":\"STRUCT\",\"name\":\"col_struct\",\"comment\":\"\"}"
+
+ "], \"status\":200}";
+
+ Schema schema = RestService.parseSchema(schemaStr, logger);
+
+ RowBatch rowBatch = new RowBatch(scanBatchResult, schema);
+ Assert.assertTrue(rowBatch.hasNext());
+ Assert.assertEquals("{\"a\":\"a1\",\"b\":1}", rowBatch.next().get(0));
+
+ }
+
}
diff --git
a/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestSparkConnector.scala
b/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestSparkConnector.scala
index 54771df..3f05da2 100644
---
a/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestSparkConnector.scala
+++
b/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestSparkConnector.scala
@@ -115,5 +115,6 @@ class TestSparkConnector {
.start().awaitTermination()
spark.stop()
}
+
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]