This is an automated email from the ASF dual-hosted git repository.
diwu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris-spark-connector.git
The following commit(s) were added to refs/heads/master by this push:
     new 20b1228  [fix](connector) fix arrow deserialize issue due to data being inconsistent with column order (#256)
20b1228 is described below
commit 20b12282ce4735353e3d86d6b1c079e206200f64
Author: gnehil <[email protected]>
AuthorDate: Thu Jan 9 15:32:09 2025 +0800
    [fix](connector) fix arrow deserialize issue due to data being inconsistent with column order (#256)
---
.../doris/spark/client/DorisFrontendClient.java | 6 ++-
.../spark/client/read/AbstractThriftReader.java | 18 ++++-----
.../apache/doris/spark/client/read/RowBatch.java | 43 +++++++++++-----------
.../apache/doris/spark/sql/ScalaDorisRowRDD.scala | 1 -
.../apache/doris/spark/util/SchemaConvertors.scala | 1 -
5 files changed, 36 insertions(+), 33 deletions(-)
diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/DorisFrontendClient.java b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/DorisFrontendClient.java
index 67b510e..ce6b289 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/DorisFrontendClient.java
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/DorisFrontendClient.java
@@ -316,7 +316,11 @@ public class DorisFrontendClient implements Serializable {
                     throw new DorisException();
                 }
                 String entity = EntityUtils.toString(response.getEntity());
-                return MAPPER.readValue(extractEntity(entity, "data").traverse(), QueryPlan.class);
+                JsonNode dataJsonNode = extractEntity(entity, "data");
+                if (dataJsonNode.get("exception") != null) {
+                    throw new DorisException("query plan failed, exception: " + dataJsonNode.get("exception").asText());
+                }
+                return MAPPER.readValue(dataJsonNode.traverse(), QueryPlan.class);
             } catch (Exception e) {
                 throw new RuntimeException("query plan request failed", e);
             }
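Note: when the FE cannot produce a query plan, the "data" node of the response carries an "exception" field rather than a plan payload, so mapping it straight into QueryPlan used to fail with an opaque Jackson error. A minimal sketch of the guarded parse, using a hypothetical error payload (the real FE response shape may differ):

    import com.fasterxml.jackson.databind.JsonNode;
    import com.fasterxml.jackson.databind.ObjectMapper;

    public class QueryPlanGuardSketch {
        public static void main(String[] args) throws Exception {
            ObjectMapper mapper = new ObjectMapper();
            // Hypothetical FE error payload, for illustration only.
            String entity = "{\"data\":{\"exception\":\"table not found\"}}";
            JsonNode dataJsonNode = mapper.readTree(entity).get("data");
            if (dataJsonNode.get("exception") != null) {
                // Surface the FE-side error instead of a confusing mapping failure.
                throw new RuntimeException("query plan failed, exception: "
                        + dataJsonNode.get("exception").asText());
            }
        }
    }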
diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/AbstractThriftReader.java b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/AbstractThriftReader.java
index d5f9f88..608e30c 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/AbstractThriftReader.java
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/AbstractThriftReader.java
@@ -88,8 +88,9 @@ public abstract class AbstractThriftReader extends DorisReader {
         this.contextId = scanOpenResult.getContextId();
         Schema schema = getDorisSchema();
         this.dorisSchema = processDorisSchema(partition, schema);
-        logger.debug("origin thrift read Schema: " + schema + ", processed schema: " + dorisSchema);
-
+        if (logger.isDebugEnabled()) {
+            logger.debug("origin thrift read Schema: " + schema + ", processed schema: " + dorisSchema);
+        }
         if (isAsync) {
             int blockingQueueSize = config.getValue(DorisOptions.DORIS_DESERIALIZE_QUEUE_SIZE);
             this.rowBatchQueue = new ArrayBlockingQueue<>(blockingQueueSize);
@@ -241,22 +242,21 @@ public abstract class AbstractThriftReader extends DorisReader {
         Schema tableSchema = frontend.getTableSchema(partition.getDatabase(), partition.getTable());
         Map<String, Field> fieldTypeMap = tableSchema.getProperties().stream()
                 .collect(Collectors.toMap(Field::getName, Function.identity()));
+        Map<String, Field> scanTypeMap = originSchema.getProperties().stream()
+                .collect(Collectors.toMap(Field::getName, Function.identity()));
         String[] readColumns = partition.getReadColumns();
         List<Field> newFieldList = new ArrayList<>();
-        int offset = 0;
-        for (int i = 0; i < readColumns.length; i++) {
-            String readColumn = readColumns[i];
-            if (!fieldTypeMap.containsKey(readColumn) && readColumn.contains(" AS ")) {
+        for (String readColumn : readColumns) {
+            if (readColumn.contains(" AS ")) {
                 int asIdx = readColumn.indexOf(" AS ");
                 String realColumn = readColumn.substring(asIdx + 4).trim().replaceAll("`", "");
-                if (fieldTypeMap.containsKey(realColumn)
+                if (fieldTypeMap.containsKey(realColumn) && scanTypeMap.containsKey(realColumn)
                         && ("BITMAP".equalsIgnoreCase(fieldTypeMap.get(realColumn).getType())
                         || "HLL".equalsIgnoreCase(fieldTypeMap.get(realColumn).getType()))) {
                     newFieldList.add(new Field(realColumn, TPrimitiveType.VARCHAR.name(), null, 0, 0, null));
-                    offset++;
                 }
             } else {
-                newFieldList.add(originSchema.getProperties().get(i + offset));
+                newFieldList.add(scanTypeMap.get(readColumn.trim().replaceAll("`", "")));
             }
         }
         processedSchema.setProperties(newFieldList);
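Note: the old loop resolved each read column positionally (originSchema.getProperties().get(i + offset)), which silently pairs a column with the wrong field whenever the scan schema's order differs from the requested column order; keying the scan fields by name makes the pairing order-independent. A toy sketch of the difference, with plain strings standing in for Field objects:

    import java.util.Arrays;
    import java.util.List;
    import java.util.Map;
    import java.util.function.Function;
    import java.util.stream.Collectors;

    public class ColumnOrderSketch {
        public static void main(String[] args) {
            // Hypothetical scan schema whose order differs from the requested columns.
            List<String> scanSchema = Arrays.asList("c2", "c1");
            String[] readColumns = {"`c1`", "`c2`"};

            // Positional lookup: readColumns[0] is c1, but index 0 of the scan schema is c2.
            System.out.println(scanSchema.get(0)); // prints "c2" -- the mismatch

            // Name-keyed lookup: order no longer matters.
            Map<String, String> scanTypeMap = scanSchema.stream()
                    .collect(Collectors.toMap(Function.identity(), Function.identity()));
            for (String readColumn : readColumns) {
                System.out.println(scanTypeMap.get(readColumn.trim().replaceAll("`", "")));
            }
        }
    }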
diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/RowBatch.java b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/RowBatch.java
index c1613e8..840825c 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/RowBatch.java
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/RowBatch.java
@@ -194,6 +194,7 @@ public class RowBatch implements Serializable {
                 FieldVector curFieldVector = fieldVectors.get(col);
                 MinorType mt = curFieldVector.getMinorType();
+                final String colName = schema.get(col).getName();
                 final String currentType = schema.get(col).getType();
                 switch (currentType) {
                     case "NULL_TYPE":
@@ -203,7 +204,7 @@
                         break;
                     case "BOOLEAN":
                         Preconditions.checkArgument(mt.equals(MinorType.BIT),
-                                typeMismatchMessage(currentType, mt));
+                                typeMismatchMessage(colName, currentType, mt));
                         BitVector bitVector = (BitVector) curFieldVector;
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                             Object fieldValue = bitVector.isNull(rowIndex) ? null : bitVector.get(rowIndex) != 0;
@@ -212,7 +213,7 @@
                         break;
                     case "TINYINT":
                         Preconditions.checkArgument(mt.equals(MinorType.TINYINT),
-                                typeMismatchMessage(currentType, mt));
+                                typeMismatchMessage(colName, currentType, mt));
                         TinyIntVector tinyIntVector = (TinyIntVector) curFieldVector;
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                             Object fieldValue = tinyIntVector.isNull(rowIndex) ? null : tinyIntVector.get(rowIndex);
@@ -221,7 +222,7 @@
                         break;
                     case "SMALLINT":
                         Preconditions.checkArgument(mt.equals(MinorType.SMALLINT),
-                                typeMismatchMessage(currentType, mt));
+                                typeMismatchMessage(colName, currentType, mt));
                         SmallIntVector smallIntVector = (SmallIntVector) curFieldVector;
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                             Object fieldValue = smallIntVector.isNull(rowIndex) ? null : smallIntVector.get(rowIndex);
@@ -230,7 +231,7 @@
                         break;
                     case "INT":
                         Preconditions.checkArgument(mt.equals(MinorType.INT),
-                                typeMismatchMessage(currentType, mt));
+                                typeMismatchMessage(colName, currentType, mt));
                         IntVector intVector = (IntVector) curFieldVector;
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                             Object fieldValue = intVector.isNull(rowIndex) ? null : intVector.get(rowIndex);
@@ -239,7 +240,7 @@
                         break;
                     case "BIGINT":
                         Preconditions.checkArgument(mt.equals(MinorType.BIGINT),
-                                typeMismatchMessage(currentType, mt));
+                                typeMismatchMessage(colName, currentType, mt));
                         BigIntVector bigIntVector = (BigIntVector) curFieldVector;
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                             Object fieldValue = bigIntVector.isNull(rowIndex) ? null : bigIntVector.get(rowIndex);
@@ -248,7 +249,7 @@
                         break;
                     case "LARGEINT":
                         Preconditions.checkArgument(mt.equals(MinorType.FIXEDSIZEBINARY) ||
-                                mt.equals(MinorType.VARCHAR), typeMismatchMessage(currentType, mt));
+                                mt.equals(MinorType.VARCHAR), typeMismatchMessage(colName, currentType, mt));
                         if (mt.equals(MinorType.FIXEDSIZEBINARY)) {
                             FixedSizeBinaryVector largeIntVector = (FixedSizeBinaryVector) curFieldVector;
                             for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
@@ -276,7 +277,7 @@
                         break;
                     case "IPV4":
                         Preconditions.checkArgument(mt.equals(MinorType.UINT4) || mt.equals(MinorType.INT),
-                                typeMismatchMessage(currentType, mt));
+                                typeMismatchMessage(colName, currentType, mt));
                         BaseIntVector ipv4Vector;
                         if (mt.equals(MinorType.INT)) {
                             ipv4Vector = (IntVector) curFieldVector;
@@ -291,7 +292,7 @@
                         break;
                     case "FLOAT":
                         Preconditions.checkArgument(mt.equals(MinorType.FLOAT4),
-                                typeMismatchMessage(currentType, mt));
+                                typeMismatchMessage(colName, currentType, mt));
                         Float4Vector float4Vector = (Float4Vector) curFieldVector;
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                             Object fieldValue = float4Vector.isNull(rowIndex) ? null : float4Vector.get(rowIndex);
@@ -301,7 +302,7 @@
                     case "TIME":
                     case "DOUBLE":
                         Preconditions.checkArgument(mt.equals(MinorType.FLOAT8),
-                                typeMismatchMessage(currentType, mt));
+                                typeMismatchMessage(colName, currentType, mt));
                         Float8Vector float8Vector = (Float8Vector) curFieldVector;
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                             Object fieldValue = float8Vector.isNull(rowIndex) ? null : float8Vector.get(rowIndex);
@@ -310,7 +311,7 @@
                         break;
                     case "BINARY":
                         Preconditions.checkArgument(mt.equals(MinorType.VARBINARY),
-                                typeMismatchMessage(currentType, mt));
+                                typeMismatchMessage(colName, currentType, mt));
                         VarBinaryVector varBinaryVector = (VarBinaryVector) curFieldVector;
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                             Object fieldValue = varBinaryVector.isNull(rowIndex) ? null : varBinaryVector.get(rowIndex);
@@ -319,7 +320,7 @@
                         break;
                     case "DECIMAL":
                         Preconditions.checkArgument(mt.equals(MinorType.VARCHAR),
-                                typeMismatchMessage(currentType, mt));
+                                typeMismatchMessage(colName, currentType, mt));
                         VarCharVector varCharVectorForDecimal = (VarCharVector) curFieldVector;
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                             if (varCharVectorForDecimal.isNull(rowIndex)) {
@@ -343,7 +344,7 @@
                     case "DECIMAL64":
                     case "DECIMAL128I":
                         Preconditions.checkArgument(mt.equals(MinorType.DECIMAL),
-                                typeMismatchMessage(currentType, mt));
+                                typeMismatchMessage(colName, currentType, mt));
                         DecimalVector decimalVector = (DecimalVector) curFieldVector;
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                             if (decimalVector.isNull(rowIndex)) {
@@ -357,7 +358,7 @@
                     case "DATE":
                     case "DATEV2":
                         Preconditions.checkArgument(mt.equals(MinorType.VARCHAR)
-                                || mt.equals(MinorType.DATEDAY), typeMismatchMessage(currentType, mt));
+                                || mt.equals(MinorType.DATEDAY), typeMismatchMessage(colName, currentType, mt));
                         if (mt.equals(MinorType.VARCHAR)) {
                             VarCharVector date = (VarCharVector) curFieldVector;
                             for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
@@ -417,7 +418,7 @@
                     case "JSONB":
                     case "VARIANT":
                         Preconditions.checkArgument(mt.equals(MinorType.VARCHAR),
-                                typeMismatchMessage(currentType, mt));
+                                typeMismatchMessage(colName, currentType, mt));
                         VarCharVector varCharVector = (VarCharVector) curFieldVector;
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                             if (varCharVector.isNull(rowIndex)) {
@@ -430,7 +431,7 @@
                         break;
                     case "IPV6":
                         Preconditions.checkArgument(mt.equals(MinorType.VARCHAR),
-                                typeMismatchMessage(currentType, mt));
+                                typeMismatchMessage(colName, currentType, mt));
                         VarCharVector ipv6VarcharVector = (VarCharVector) curFieldVector;
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                             if (ipv6VarcharVector.isNull(rowIndex)) {
@@ -444,7 +445,7 @@
                         break;
                     case "ARRAY":
                         Preconditions.checkArgument(mt.equals(MinorType.LIST),
-                                typeMismatchMessage(currentType, mt));
+                                typeMismatchMessage(colName, currentType, mt));
                         ListVector listVector = (ListVector) curFieldVector;
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                             if (listVector.isNull(rowIndex)) {
@@ -457,7 +458,7 @@
                         break;
                     case "MAP":
                         Preconditions.checkArgument(mt.equals(MinorType.MAP),
-                                typeMismatchMessage(currentType, mt));
+                                typeMismatchMessage(colName, currentType, mt));
                         MapVector mapVector = (MapVector) curFieldVector;
                         UnionMapReader reader = mapVector.getReader();
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
@@ -476,7 +477,7 @@
                         break;
                     case "STRUCT":
                         Preconditions.checkArgument(mt.equals(MinorType.STRUCT),
-                                typeMismatchMessage(currentType, mt));
+                                typeMismatchMessage(colName, currentType, mt));
                         StructVector structVector = (StructVector) curFieldVector;
                         for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                             if (structVector.isNull(rowIndex)) {
@@ -508,9 +509,9 @@
         return rowBatch.get(offsetInRowBatch++).getCols();
     }

-    private String typeMismatchMessage(final String sparkType, final MinorType arrowType) {
-        final String messageTemplate = "Spark type is %1$s, but arrow type is %2$s.";
-        return String.format(messageTemplate, sparkType, arrowType.name());
+    private String typeMismatchMessage(final String columnName, final String sparkType, final MinorType arrowType) {
+        final String messageTemplate = "Spark type for column %1$s is %2$s, but arrow type is %3$s.";
+        return String.format(messageTemplate, columnName, sparkType, arrowType.name());
     }

     public int getReadRowCount() {
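Note: with the column name threaded through every checkArgument call, a mismatch now points at the offending column directly. A standalone sketch of the updated helper and the message it produces (shown in isolation, not as it sits in RowBatch):

    import org.apache.arrow.vector.types.Types.MinorType;

    public class MismatchMessageSketch {
        // Mirrors the updated helper, extracted for illustration.
        private static String typeMismatchMessage(String columnName, String sparkType, MinorType arrowType) {
            return String.format("Spark type for column %1$s is %2$s, but arrow type is %3$s.",
                    columnName, sparkType, arrowType.name());
        }

        public static void main(String[] args) {
            // Prints: Spark type for column c1 is INT, but arrow type is VARCHAR.
            System.out.println(typeMismatchMessage("c1", "INT", MinorType.VARCHAR));
        }
    }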
diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRowRDD.scala b/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRowRDD.scala
index 9713bf3..0e6038d 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRowRDD.scala
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRowRDD.scala
@@ -38,7 +38,6 @@ private[spark] class ScalaDorisRowRDDIterator(context: TaskContext,
   extends AbstractDorisRDDIterator[Row](context, partition) {

   override def initReader(config: DorisConfig): Unit = {
-    config.setProperty(DorisOptions.DORIS_READ_FIELDS, schema.map(f => s"`${f.name}`").mkString(","))
     config.getValue(DorisOptions.READ_MODE).toLowerCase match {
       case "thrift" => config.setProperty(DorisOptions.DORIS_VALUE_READER_CLASS, classOf[DorisRowThriftReader].getName)
       case "arrow" => config.setProperty(DorisOptions.DORIS_VALUE_READER_CLASS, classOf[DorisRowFlightSqlReader].getName)
diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/SchemaConvertors.scala b/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/SchemaConvertors.scala
index f85eb28..303aa1f 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/SchemaConvertors.scala
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/scala/org/apache/doris/spark/util/SchemaConvertors.scala
@@ -67,7 +67,6 @@ object SchemaConvertors {
   def convertToSchema(tscanColumnDescs: Seq[TScanColumnDesc]): Schema = {
     val schema = new Schema(tscanColumnDescs.length)
     tscanColumnDescs.foreach(desc => {
-      // println(desc.getName + " " + desc.getType.name())
       schema.put(new Field(desc.getName, desc.getType.name, "", 0, 0, ""))
     })
     schema
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]