This is an automated email from the ASF dual-hosted git repository.
diwu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris-spark-connector.git
The following commit(s) were added to refs/heads/master by this push:
new 0daf6c4 [optimize][refactor] Optimizing memory usage for writing data (#140)
0daf6c4 is described below
commit 0daf6c4de9c2b1686cd6862e5ada3b1cba36640a
Author: gnehil <[email protected]>
AuthorDate: Wed Sep 13 15:47:57 2023 +0800
[optimize][refactor] Optimizing memory usage for writing data (#140)

1. Optimize data transmission
2. Optimize partitioned data iteration
---
.../apache/doris/spark/backend/BackendClient.java | 3 +-
.../java/org/apache/doris/spark/cfg/Settings.java | 1 +
.../apache/doris/spark/load/DorisStreamLoad.java | 238 ++++++---------------
.../org/apache/doris/spark/load/RecordBatch.java | 153 +++++++++++++
.../doris/spark/load/RecordBatchInputStream.java | 221 +++++++++++++++++++
.../apache/doris/spark/serialization/RowBatch.java | 38 ++--
.../java/org/apache/doris/spark/util/DataUtil.java | 53 +++--
.../org/apache/doris/spark/util/ListUtils.java | 2 +-
.../scala/org/apache/doris/spark/package.scala | 6 +-
.../org/apache/doris/spark/sql/ScalaDorisRow.scala | 6 +-
.../org/apache/doris/spark/sql/SchemaUtils.scala | 58 ++++-
.../apache/doris/spark/writer/DorisWriter.scala | 91 ++------
.../org/apache/doris/spark/util/DataUtilTest.java | 32 ---
.../apache/doris/spark/sql/SchemaUtilsTest.scala | 54 +++++
14 files changed, 637 insertions(+), 319 deletions(-)
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/backend/BackendClient.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/backend/BackendClient.java
index aaafe09..b10797b 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/backend/BackendClient.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/backend/BackendClient.java
@@ -29,9 +29,10 @@ import org.apache.doris.spark.cfg.ConfigurationOptions;
import org.apache.doris.spark.exception.ConnectedFailedException;
import org.apache.doris.spark.exception.DorisException;
import org.apache.doris.spark.exception.DorisInternalException;
-import org.apache.doris.spark.util.ErrorMessages;
import org.apache.doris.spark.cfg.Settings;
import org.apache.doris.spark.serialization.Routing;
+import org.apache.doris.spark.util.ErrorMessages;
+
import org.apache.thrift.TConfiguration;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TBinaryProtocol;
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/Settings.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/Settings.java
index 798ec8c..c941fdf 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/Settings.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/Settings.java
@@ -25,6 +25,7 @@ import org.apache.commons.lang3.StringUtils;
import org.apache.doris.spark.exception.IllegalArgumentException;
import org.apache.doris.spark.util.ErrorMessages;
import org.apache.doris.spark.util.IOUtils;
+
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
index ac920cd..9ecfa40 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
@@ -22,8 +22,6 @@ import org.apache.doris.spark.exception.StreamLoadException;
import org.apache.doris.spark.rest.RestService;
import org.apache.doris.spark.rest.models.BackendV2;
import org.apache.doris.spark.rest.models.RespContent;
-import org.apache.doris.spark.util.DataUtil;
-import org.apache.doris.spark.util.ListUtils;
import org.apache.doris.spark.util.ResponseUtil;
import com.fasterxml.jackson.core.JsonProcessingException;
@@ -39,71 +37,72 @@ import org.apache.http.HttpResponse;
import org.apache.http.HttpStatus;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpPut;
+import org.apache.http.client.methods.HttpRequestBase;
import org.apache.http.entity.BufferedHttpEntity;
-import org.apache.http.entity.StringEntity;
+import org.apache.http.entity.InputStreamEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.util.EntityUtils;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.types.StructType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.io.Serializable;
import java.nio.charset.StandardCharsets;
-import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Calendar;
import java.util.Collections;
import java.util.HashMap;
+import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
-import java.util.stream.Collectors;
/**
* DorisStreamLoad
**/
public class DorisStreamLoad implements Serializable {
- private String FIELD_DELIMITER;
- private final String LINE_DELIMITER;
- private static final String NULL_VALUE = "\\N";
private static final Logger LOG =
LoggerFactory.getLogger(DorisStreamLoad.class);
+ private static final ObjectMapper MAPPER = new ObjectMapper();
+
private final static List<String> DORIS_SUCCESS_STATUS = new ArrayList<>(Arrays.asList("Success", "Publish Timeout"));
- private static String loadUrlPattern = "http://%s/api/%s/%s/_stream_load?";
- private static String abortUrlPattern = "http://%s/api/%s/%s/_stream_load_2pc?";
+ private static final String loadUrlPattern = "http://%s/api/%s/%s/_stream_load?";
+
+ private static final String abortUrlPattern = "http://%s/api/%s/%s/_stream_load_2pc?";
- private String user;
- private String passwd;
private String loadUrlStr;
- private String db;
- private String tbl;
- private String authEncoded;
- private String columns;
- private String maxFilterRatio;
- private Map<String, String> streamLoadProp;
+ private final String db;
+ private final String tbl;
+ private final String authEncoded;
+ private final String columns;
+ private final String maxFilterRatio;
+ private final Map<String, String> streamLoadProp;
private static final long cacheExpireTimeout = 4 * 60;
private final LoadingCache<String, List<BackendV2.BackendRowV2>> cache;
private final String fileType;
-
- private boolean readJsonByLine = false;
-
+ private String FIELD_DELIMITER;
+ private final String LINE_DELIMITER;
private boolean streamingPassthrough = false;
+ private final Integer batchSize;
+ private boolean enable2PC;
public DorisStreamLoad(SparkSettings settings) {
String[] dbTable =
settings.getProperty(ConfigurationOptions.DORIS_TABLE_IDENTIFIER).split("\\.");
this.db = dbTable[0];
this.tbl = dbTable[1];
- this.user =
settings.getProperty(ConfigurationOptions.DORIS_REQUEST_AUTH_USER);
- this.passwd =
settings.getProperty(ConfigurationOptions.DORIS_REQUEST_AUTH_PASSWORD);
+ String user =
settings.getProperty(ConfigurationOptions.DORIS_REQUEST_AUTH_USER);
+ String passwd =
settings.getProperty(ConfigurationOptions.DORIS_REQUEST_AUTH_PASSWORD);
this.authEncoded = getAuthEncoded(user, passwd);
this.columns =
settings.getProperty(ConfigurationOptions.DORIS_WRITE_FIELDS);
this.maxFilterRatio =
settings.getProperty(ConfigurationOptions.DORIS_MAX_FILTER_RATIO);
@@ -113,18 +112,15 @@ public class DorisStreamLoad implements Serializable {
if ("csv".equals(fileType)) {
FIELD_DELIMITER =
escapeString(streamLoadProp.getOrDefault("column_separator", "\t"));
} else if ("json".equalsIgnoreCase(fileType)) {
- readJsonByLine = Boolean.parseBoolean(streamLoadProp.getOrDefault("read_json_by_line", "false"));
- boolean stripOuterArray = Boolean.parseBoolean(streamLoadProp.getOrDefault("strip_outer_array", "false"));
- if (readJsonByLine && stripOuterArray) {
- throw new IllegalArgumentException("Only one of options 'read_json_by_line' and 'strip_outer_array' can be set to true");
- } else if (!readJsonByLine && !stripOuterArray) {
- LOG.info("set default json mode: strip_outer_array");
- streamLoadProp.put("strip_outer_array", "true");
- }
+ streamLoadProp.put("read_json_by_line", "true");
}
LINE_DELIMITER =
escapeString(streamLoadProp.getOrDefault("line_delimiter", "\n"));
this.streamingPassthrough =
settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_STREAMING_PASSTHROUGH,
ConfigurationOptions.DORIS_SINK_STREAMING_PASSTHROUGH_DEFAULT);
+ this.batchSize =
settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_BATCH_SIZE,
+ ConfigurationOptions.SINK_BATCH_SIZE_DEFAULT);
+ this.enable2PC =
settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_ENABLE_2PC,
+ ConfigurationOptions.DORIS_SINK_ENABLE_2PC_DEFAULT);
}
public String getLoadUrlStr() {
@@ -141,9 +137,7 @@ public class DorisStreamLoad implements Serializable {
private HttpPut getHttpPut(String label, String loadUrlStr, Boolean
enable2PC) {
HttpPut httpPut = new HttpPut(loadUrlStr);
- httpPut.setHeader(HttpHeaders.AUTHORIZATION, "Basic " + authEncoded);
- httpPut.setHeader(HttpHeaders.EXPECT, "100-continue");
- httpPut.setHeader(HttpHeaders.CONTENT_TYPE, "text/plain;
charset=UTF-8");
+ addCommonHeader(httpPut);
httpPut.setHeader("label", label);
if (StringUtils.isNotBlank(columns)) {
httpPut.setHeader("columns", columns);
@@ -155,7 +149,11 @@ public class DorisStreamLoad implements Serializable {
httpPut.setHeader("two_phase_commit", "true");
}
if (MapUtils.isNotEmpty(streamLoadProp)) {
- streamLoadProp.forEach(httpPut::setHeader);
+ streamLoadProp.forEach((k, v) -> {
+ if (!"strip_outer_array".equalsIgnoreCase(k)) {
+ httpPut.setHeader(k, v);
+ }
+ });
}
return httpPut;
}
@@ -165,10 +163,10 @@ public class DorisStreamLoad implements Serializable {
public String respMsg;
public String respContent;
- public LoadResponse(int status, String respMsg, String respContent) {
- this.status = status;
- this.respMsg = respMsg;
- this.respContent = respContent;
+ public LoadResponse(HttpResponse response) throws IOException {
+ this.status = response.getStatusLine().getStatusCode();
+ this.respMsg = response.getStatusLine().getReasonPhrase();
+ this.respContent = EntityUtils.toString(new
BufferedHttpEntity(response.getEntity()), StandardCharsets.UTF_8);
}
@Override
@@ -177,95 +175,34 @@ public class DorisStreamLoad implements Serializable {
}
}
- public List<Integer> loadV2(List<List<Object>> rows, String[] dfColumns,
Boolean enable2PC) throws StreamLoadException, JsonProcessingException {
-
- List<String> loadData = parseLoadData(rows, dfColumns);
- List<Integer> txnIds = new ArrayList<>(loadData.size());
-
- try {
- for (String data : loadData) {
- txnIds.add(load(data, enable2PC));
- }
- } catch (StreamLoadException e) {
- if (enable2PC && !txnIds.isEmpty()) {
- LOG.error("load batch failed, abort previously pre-committed
transactions");
- for (Integer txnId : txnIds) {
- abort(txnId);
- }
- }
- throw e;
- }
-
- return txnIds;
-
- }
-
- public List<Integer> loadStream(List<List<Object>> rows, String[]
dfColumns, Boolean enable2PC)
+ public int load(Iterator<InternalRow> rows, StructType schema)
throws StreamLoadException, JsonProcessingException {
- List<String> loadData;
-
- if (this.streamingPassthrough) {
- handleStreamPassThrough();
- loadData = passthrough(rows);
- } else {
- loadData = parseLoadData(rows, dfColumns);
- }
-
- List<Integer> txnIds = new ArrayList<>(loadData.size());
-
- try {
- for (String data : loadData) {
- txnIds.add(load(data, enable2PC));
- }
- } catch (StreamLoadException e) {
- if (enable2PC && !txnIds.isEmpty()) {
- LOG.error("load batch failed, abort previously pre-committed
transactions");
- for (Integer txnId : txnIds) {
- abort(txnId);
- }
- }
- throw e;
- }
-
- return txnIds;
-
- }
-
- public int load(String value, Boolean enable2PC) throws
StreamLoadException {
-
String label = generateLoadLabel();
-
LoadResponse loadResponse;
- int responseHttpStatus = -1;
try (CloseableHttpClient httpClient = getHttpClient()) {
String loadUrlStr = String.format(loadUrlPattern, getBackend(),
db, tbl);
- LOG.debug("Stream load Request:{} ,Body:{}", loadUrlStr, value);
- // only to record the BE node in case of an exception
this.loadUrlStr = loadUrlStr;
-
HttpPut httpPut = getHttpPut(label, loadUrlStr, enable2PC);
- httpPut.setEntity(new StringEntity(value, StandardCharsets.UTF_8));
+ RecordBatchInputStream recodeBatchInputStream = new
RecordBatchInputStream(RecordBatch.newBuilder(rows)
+ .batchSize(batchSize)
+ .format(fileType)
+ .sep(FIELD_DELIMITER)
+ .delim(LINE_DELIMITER)
+ .schema(schema).build(), streamingPassthrough);
+ httpPut.setEntity(new InputStreamEntity(recodeBatchInputStream));
HttpResponse httpResponse = httpClient.execute(httpPut);
- responseHttpStatus = httpResponse.getStatusLine().getStatusCode();
- String respMsg = httpResponse.getStatusLine().getReasonPhrase();
- String response = EntityUtils.toString(new
BufferedHttpEntity(httpResponse.getEntity()), StandardCharsets.UTF_8);
- loadResponse = new LoadResponse(responseHttpStatus, respMsg,
response);
+ loadResponse = new LoadResponse(httpResponse);
} catch (IOException e) {
- e.printStackTrace();
- String err = "http request exception,load url : " + loadUrlStr +
",failed to execute spark stream load with label: " + label;
- LOG.warn(err, e);
- loadResponse = new LoadResponse(responseHttpStatus,
e.getMessage(), err);
+ throw new RuntimeException(e);
}
if (loadResponse.status != HttpStatus.SC_OK) {
LOG.info("Stream load Response HTTP Status Error:{}",
loadResponse);
- // throw new StreamLoadException("stream load error: " +
loadResponse.respContent);
throw new StreamLoadException("stream load error");
} else {
- ObjectMapper obj = new ObjectMapper();
try {
- RespContent respContent =
obj.readValue(loadResponse.respContent, RespContent.class);
+ RespContent respContent =
MAPPER.readValue(loadResponse.respContent, RespContent.class);
if (!DORIS_SUCCESS_STATUS.contains(respContent.getStatus())) {
LOG.error("Stream load Response RES STATUS Error:{}",
loadResponse);
throw new StreamLoadException("stream load error");
@@ -279,6 +216,14 @@ public class DorisStreamLoad implements Serializable {
}
+ public Integer loadStream(Iterator<InternalRow> rows, StructType schema)
+ throws StreamLoadException, JsonProcessingException {
+ if (this.streamingPassthrough) {
+ handleStreamPassThrough();
+ }
+ return load(rows, schema);
+ }
+
public void commit(int txnId) throws StreamLoadException {
try (CloseableHttpClient client = getHttpClient()) {
@@ -286,9 +231,7 @@ public class DorisStreamLoad implements Serializable {
String backend = getBackend();
String abortUrl = String.format(abortUrlPattern, backend, db, tbl);
HttpPut httpPut = new HttpPut(abortUrl);
- httpPut.setHeader(HttpHeaders.AUTHORIZATION, "Basic " +
authEncoded);
- httpPut.setHeader(HttpHeaders.EXPECT, "100-continue");
- httpPut.setHeader(HttpHeaders.CONTENT_TYPE, "text/plain;
charset=UTF-8");
+ addCommonHeader(httpPut);
httpPut.setHeader("txn_operation", "commit");
httpPut.setHeader("txn_id", String.valueOf(txnId));
@@ -306,10 +249,9 @@ public class DorisStreamLoad implements Serializable {
throw new StreamLoadException("stream load error: " +
reasonPhrase);
}
- ObjectMapper mapper = new ObjectMapper();
if (response.getEntity() != null) {
String loadResult = EntityUtils.toString(response.getEntity());
- Map<String, String> res = mapper.readValue(loadResult, new
TypeReference<HashMap<String, String>>() {
+ Map<String, String> res = MAPPER.readValue(loadResult, new
TypeReference<HashMap<String, String>>() {
});
if (res.get("status").equals("Fail") &&
!ResponseUtil.isCommitted(res.get("msg"))) {
throw new StreamLoadException("Commit failed " +
loadResult);
@@ -331,9 +273,7 @@ public class DorisStreamLoad implements Serializable {
try (CloseableHttpClient client = getHttpClient()) {
String abortUrl = String.format(abortUrlPattern, getBackend(), db,
tbl);
HttpPut httpPut = new HttpPut(abortUrl);
- httpPut.setHeader(HttpHeaders.AUTHORIZATION, "Basic " +
authEncoded);
- httpPut.setHeader(HttpHeaders.EXPECT, "100-continue");
- httpPut.setHeader(HttpHeaders.CONTENT_TYPE, "text/plain;
charset=UTF-8");
+ addCommonHeader(httpPut);
httpPut.setHeader("txn_operation", "abort");
httpPut.setHeader("txn_id", String.valueOf(txnId));
@@ -344,9 +284,8 @@ public class DorisStreamLoad implements Serializable {
throw new StreamLoadException("Fail to abort transaction " +
txnId + " with url " + abortUrl);
}
- ObjectMapper mapper = new ObjectMapper();
String loadResult = EntityUtils.toString(response.getEntity());
- Map<String, String> res = mapper.readValue(loadResult, new
TypeReference<HashMap<String, String>>() {
+ Map<String, String> res = MAPPER.readValue(loadResult, new
TypeReference<HashMap<String, String>>() {
});
if (!"Success".equals(res.get("status"))) {
if (ResponseUtil.isCommitted(res.get("msg"))) {
@@ -410,52 +349,13 @@ public class DorisStreamLoad implements Serializable {
}
- private List<String> parseLoadData(List<List<Object>> rows, String[]
dfColumns) throws StreamLoadException, JsonProcessingException {
-
- List<String> loadDataList;
-
- switch (fileType.toUpperCase()) {
-
- case "CSV":
- loadDataList = Collections.singletonList(
- rows.stream()
- .map(row -> row.stream()
- .map(DataUtil::handleColumnValue)
- .map(Object::toString)
-
.collect(Collectors.joining(FIELD_DELIMITER))
- ).collect(Collectors.joining(LINE_DELIMITER)));
- break;
- case "JSON":
- List<Map<Object, Object>> dataList = new ArrayList<>();
- try {
- for (List<Object> row : rows) {
- Map<Object, Object> dataMap = new HashMap<>();
- if (dfColumns.length == row.size()) {
- for (int i = 0; i < dfColumns.length; i++) {
- dataMap.put(dfColumns[i],
DataUtil.handleColumnValue(row.get(i)));
- }
- }
- dataList.add(dataMap);
- }
- } catch (Exception e) {
- throw new StreamLoadException("The number of configured
columns does not match the number of data columns.");
- }
- // splits large collections to normal collection to avoid the
"Requested array size exceeds VM limit" exception
- loadDataList = ListUtils.getSerializedList(dataList,
readJsonByLine ? LINE_DELIMITER : null);
- break;
- default:
- throw new StreamLoadException(String.format("Unsupported file
format in stream load: %s.", fileType));
-
- }
-
- return loadDataList;
-
- }
-
private String generateLoadLabel() {
Calendar calendar = Calendar.getInstance();
- return String.format("spark_streamload_%s%02d%02d_%02d%02d%02d_%s",
calendar.get(Calendar.YEAR), calendar.get(Calendar.MONTH) + 1,
calendar.get(Calendar.DAY_OF_MONTH), calendar.get(Calendar.HOUR_OF_DAY),
calendar.get(Calendar.MINUTE), calendar.get(Calendar.SECOND),
UUID.randomUUID().toString().replaceAll("-", ""));
+ return String.format("spark_streamload_%s%02d%02d_%02d%02d%02d_%s",
+ calendar.get(Calendar.YEAR), calendar.get(Calendar.MONTH) + 1,
calendar.get(Calendar.DAY_OF_MONTH),
+ calendar.get(Calendar.HOUR_OF_DAY),
calendar.get(Calendar.MINUTE), calendar.get(Calendar.SECOND),
+ UUID.randomUUID().toString().replaceAll("-", ""));
}
@@ -478,6 +378,12 @@ public class DorisStreamLoad implements Serializable {
return hexData;
}
+ private void addCommonHeader(HttpRequestBase httpReq) {
+ httpReq.setHeader(HttpHeaders.AUTHORIZATION, "Basic " + authEncoded);
+ httpReq.setHeader(HttpHeaders.EXPECT, "100-continue");
+ httpReq.setHeader(HttpHeaders.CONTENT_TYPE, "text/plain;
charset=UTF-8");
+ }
+
private void handleStreamPassThrough() {
if ("json".equalsIgnoreCase(fileType)) {
@@ -488,8 +394,4 @@ public class DorisStreamLoad implements Serializable {
}
- private List<String> passthrough(List<List<Object>> values) {
- return values.stream().map(list ->
list.get(0).toString()).collect(Collectors.toList());
- }
-
}
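For readers following the change, here is a minimal sketch (not part of this patch; the class name StreamingEntitySketch and its parameters are illustrative) of how the pieces introduced above fit together: rows are serialized lazily through RecordBatchInputStream while the HTTP client drains the InputStreamEntity, instead of the whole batch being concatenated into a single String as before.

    import org.apache.doris.spark.load.RecordBatch;
    import org.apache.doris.spark.load.RecordBatchInputStream;
    import org.apache.http.client.methods.HttpPut;
    import org.apache.http.entity.InputStreamEntity;
    import org.apache.spark.sql.catalyst.InternalRow;
    import org.apache.spark.sql.types.StructType;

    import java.util.Iterator;

    class StreamingEntitySketch {
        // Builds a stream load request whose body is produced on demand from the row iterator.
        static HttpPut buildPut(String loadUrl, Iterator<InternalRow> rows, StructType schema,
                                int batchSize, String format, String sep, String delim) {
            RecordBatchInputStream body = new RecordBatchInputStream(
                    RecordBatch.newBuilder(rows)
                            .batchSize(batchSize)   // rows consumed per request
                            .format(format)         // "csv" or "json"
                            .sep(sep)               // column separator, csv only
                            .delim(delim)           // line delimiter, stored as UTF-8 bytes
                            .schema(schema)
                            .build(),
                    false);                         // passThrough = false: serialize each row
            HttpPut httpPut = new HttpPut(loadUrl);
            httpPut.setEntity(new InputStreamEntity(body)); // unknown length, sent chunked
            return httpPut;
        }
    }
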
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java
new file mode 100644
index 0000000..779c057
--- /dev/null
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java
@@ -0,0 +1,153 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.spark.load;
+
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.types.StructType;
+
+import java.nio.charset.Charset;
+import java.nio.charset.StandardCharsets;
+import java.util.Iterator;
+
+/**
+ * Wrapper Object for batch loading
+ */
+public class RecordBatch {
+
+ private static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8;
+
+ /**
+ * Spark row data iterator
+ */
+ private final Iterator<InternalRow> iterator;
+
+ /**
+ * batch size for single load
+ */
+ private final int batchSize;
+
+ /**
+ * stream load format
+ */
+ private final String format;
+
+ /**
+ * column separator, only used when the format is csv
+ */
+ private final String sep;
+
+ /**
+ * line delimiter
+ */
+ private final byte[] delim;
+
+ /**
+ * schema of row
+ */
+ private final StructType schema;
+
+ private RecordBatch(Iterator<InternalRow> iterator, int batchSize, String
format, String sep, byte[] delim,
+ StructType schema) {
+ this.iterator = iterator;
+ this.batchSize = batchSize;
+ this.format = format;
+ this.sep = sep;
+ this.delim = delim;
+ this.schema = schema;
+ }
+
+ public Iterator<InternalRow> getIterator() {
+ return iterator;
+ }
+
+ public int getBatchSize() {
+ return batchSize;
+ }
+
+ public String getFormat() {
+ return format;
+ }
+
+ public String getSep() {
+ return sep;
+ }
+
+ public byte[] getDelim() {
+ return delim;
+ }
+
+ public StructType getSchema() {
+ return schema;
+ }
+ public static Builder newBuilder(Iterator<InternalRow> iterator) {
+ return new Builder(iterator);
+ }
+
+ /**
+ * RecordBatch Builder
+ */
+ public static class Builder {
+
+ private final Iterator<InternalRow> iterator;
+
+ private int batchSize;
+
+ private String format;
+
+ private String sep;
+
+ private byte[] delim;
+
+ private StructType schema;
+
+ public Builder(Iterator<InternalRow> iterator) {
+ this.iterator = iterator;
+ }
+
+ public Builder batchSize(int batchSize) {
+ this.batchSize = batchSize;
+ return this;
+ }
+
+ public Builder format(String format) {
+ this.format = format;
+ return this;
+ }
+
+ public Builder sep(String sep) {
+ this.sep = sep;
+ return this;
+ }
+
+ public Builder delim(String delim) {
+ this.delim = delim.getBytes(DEFAULT_CHARSET);
+ return this;
+ }
+
+ public Builder schema(StructType schema) {
+ this.schema = schema;
+ return this;
+ }
+
+ public RecordBatch build() {
+ return new RecordBatch(iterator, batchSize, format, sep, delim,
schema);
+ }
+
+ }
+
+}
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java
new file mode 100644
index 0000000..9444c1d
--- /dev/null
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java
@@ -0,0 +1,221 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.spark.load;
+
+import org.apache.doris.spark.exception.DorisException;
+import org.apache.doris.spark.exception.IllegalArgumentException;
+import org.apache.doris.spark.exception.ShouldNeverHappenException;
+import org.apache.doris.spark.util.DataUtil;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
+import java.util.Iterator;
+
+/**
+ * InputStream for batch load
+ */
+public class RecordBatchInputStream extends InputStream {
+
+ public static final Logger LOG =
LoggerFactory.getLogger(RecordBatchInputStream.class);
+
+ private static final int DEFAULT_BUF_SIZE = 4096;
+
+ /**
+ * Load record batch
+ */
+ private final RecordBatch recordBatch;
+
+ /**
+ * first line flag
+ */
+ private boolean isFirst = true;
+
+ /**
+ * record buffer
+ */
+ private ByteBuffer buffer = ByteBuffer.allocate(0);
+
+ /**
+ * record count has been read
+ */
+ private int readCount = 0;
+
+ /**
+ * streaming mode pass through data without process
+ */
+ private final boolean passThrough;
+
+ public RecordBatchInputStream(RecordBatch recordBatch, boolean
passThrough) {
+ this.recordBatch = recordBatch;
+ this.passThrough = passThrough;
+ }
+
+ @Override
+ public int read() throws IOException {
+ try {
+ if (buffer.remaining() == 0 && endOfBatch()) {
+ return -1; // End of stream
+ }
+ } catch (DorisException e) {
+ throw new IOException(e);
+ }
+ return buffer.get() & 0xFF;
+ }
+
+ @Override
+ public int read(byte[] b, int off, int len) throws IOException {
+ try {
+ if (buffer.remaining() == 0 && endOfBatch()) {
+ return -1; // End of stream
+ }
+ } catch (DorisException e) {
+ throw new IOException(e);
+ }
+ int bytesRead = Math.min(len, buffer.remaining());
+ buffer.get(b, off, bytesRead);
+ return bytesRead;
+ }
+
+ /**
+ * Check if the current batch read is over.
+ * If the number of records read is greater than or equal to the batch size, or there is no next record,
+ * return true; otherwise read the next record into the buffer and return false.
+ *
+ * @return Whether the current batch read is over
+ * @throws DorisException
+ */
+ public boolean endOfBatch() throws DorisException {
+ Iterator<InternalRow> iterator = recordBatch.getIterator();
+ if (readCount >= recordBatch.getBatchSize() || !iterator.hasNext()) {
+ return true;
+ }
+ readNext(iterator);
+ return false;
+ }
+
+ /**
+ * read next record into buffer
+ *
+ * @param iterator row iterator
+ * @throws DorisException
+ */
+ private void readNext(Iterator<InternalRow> iterator) throws
DorisException {
+ if (!iterator.hasNext()) {
+ throw new ShouldNeverHappenException();
+ }
+ byte[] delim = recordBatch.getDelim();
+ byte[] rowBytes = rowToByte(iterator.next());
+ if (isFirst) {
+ ensureCapacity(rowBytes.length);
+ buffer.put(rowBytes);
+ buffer.flip();
+ isFirst = false;
+ } else {
+ ensureCapacity(delim.length + rowBytes.length);
+ buffer.put(delim);
+ buffer.put(rowBytes);
+ buffer.flip();
+ }
+ readCount++;
+ }
+
+ /**
+ * Check if the buffer has enough capacity.
+ *
+ * @param need required buffer space
+ */
+ private void ensureCapacity(int need) {
+
+ int capacity = buffer.capacity();
+
+ if (need <= capacity) {
+ buffer.clear();
+ return;
+ }
+
+ // need to extend
+ int newCapacity = calculateNewCapacity(capacity, need);
+ LOG.info("expand buffer, min cap: {}, now cap: {}, new cap: {}", need,
capacity, newCapacity);
+ buffer = ByteBuffer.allocate(newCapacity);
+
+ }
+
+ /**
+ * Calculate new capacity for buffer expansion.
+ *
+ * @param capacity current buffer capacity
+ * @param minCapacity required min buffer space
+ * @return new capacity
+ */
+ private int calculateNewCapacity(int capacity, int minCapacity) {
+ int newCapacity;
+ if (capacity == 0) {
+ newCapacity = DEFAULT_BUF_SIZE;
+ while (newCapacity < minCapacity) {
+ newCapacity = newCapacity << 1;
+ }
+ } else {
+ newCapacity = capacity << 1;
+ }
+ return newCapacity;
+ }
+
+ /**
+ * Convert Spark row data to byte array
+ *
+ * @param row row data
+ * @return byte array
+ * @throws DorisException
+ */
+ private byte[] rowToByte(InternalRow row) throws DorisException {
+
+ byte[] bytes;
+
+ if (passThrough) {
+ bytes = row.getString(0).getBytes(StandardCharsets.UTF_8);
+ return bytes;
+ }
+
+ switch (recordBatch.getFormat().toLowerCase()) {
+ case "csv":
+ bytes = DataUtil.rowToCsvBytes(row, recordBatch.getSchema(),
recordBatch.getSep());
+ break;
+ case "json":
+ try {
+ bytes = DataUtil.rowToJsonBytes(row,
recordBatch.getSchema());
+ } catch (JsonProcessingException e) {
+ throw new DorisException("parse row to json bytes failed",
e);
+ }
+ break;
+ default:
+ throw new IllegalArgumentException("format",
recordBatch.getFormat());
+ }
+
+ return bytes;
+
+ }
+
+
+}
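A small consumption sketch (hypothetical; DrainSketch is not in the patch) that makes the batch semantics above concrete: read() serves bytes for at most batchSize rows pulled from the shared iterator and then reports end of stream, so one HTTP request carries one batch while the iterator keeps its position for the next load call. When passThrough is true, the first column of each row is emitted as-is rather than re-serialized.

    import org.apache.doris.spark.load.RecordBatch;
    import org.apache.doris.spark.load.RecordBatchInputStream;

    import java.io.ByteArrayOutputStream;
    import java.io.IOException;

    class DrainSketch {
        // Drains one batch worth of rows into a byte array: the same bytes the HTTP entity would send.
        static byte[] drain(RecordBatch batch) throws IOException {
            try (RecordBatchInputStream in = new RecordBatchInputStream(batch, false)) {
                ByteArrayOutputStream out = new ByteArrayOutputStream();
                byte[] buf = new byte[4096];
                int n;
                while ((n = in.read(buf, 0, buf.length)) != -1) {
                    out.write(buf, 0, n);
                }
                return out.toByteArray();
            }
        }
    }
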
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
index faa8ef5..3d66db5 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
@@ -17,19 +17,11 @@
package org.apache.doris.spark.serialization;
-import java.io.ByteArrayInputStream;
-import java.io.IOException;
-import java.math.BigDecimal;
-import java.math.BigInteger;
-import java.nio.charset.StandardCharsets;
-import java.sql.Date;
-import java.time.LocalDate;
-import java.time.LocalDateTime;
-import java.time.format.DateTimeFormatter;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.NoSuchElementException;
+import org.apache.doris.sdk.thrift.TScanBatchResult;
+import org.apache.doris.spark.exception.DorisException;
+import org.apache.doris.spark.rest.models.Schema;
+import com.google.common.base.Preconditions;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.BitVector;
@@ -47,17 +39,21 @@ import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.types.Types;
-
-import org.apache.doris.sdk.thrift.TScanBatchResult;
-import org.apache.doris.spark.exception.DorisException;
-import org.apache.doris.spark.rest.models.Schema;
-
import org.apache.commons.lang3.ArrayUtils;
import org.apache.spark.sql.types.Decimal;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import com.google.common.base.Preconditions;
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.math.BigDecimal;
+import java.math.BigInteger;
+import java.nio.charset.StandardCharsets;
+import java.sql.Date;
+import java.time.LocalDate;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.NoSuchElementException;
/**
* row batch data container.
@@ -128,7 +124,11 @@ public class RowBatch {
}
public boolean hasNext() {
- return offsetInRowBatch < readRowCount;
+ if (offsetInRowBatch >= readRowCount) {
+ rowBatch.clear();
+ return false;
+ }
+ return true;
}
private void addValueToRow(int rowIndex, Object obj) {
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java
index 5877447..aea6dde 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java
@@ -17,35 +17,48 @@
package org.apache.doris.spark.util;
-import scala.collection.JavaConversions;
-import scala.collection.mutable.WrappedArray;
+import org.apache.doris.spark.sql.SchemaUtils;
-import java.sql.Date;
-import java.sql.Timestamp;
-import java.util.Arrays;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
-public class DataUtil {
+import java.nio.charset.StandardCharsets;
+import java.util.HashMap;
+import java.util.Map;
- public static final String NULL_VALUE = "\\N";
+public class DataUtil {
- public static Object handleColumnValue(Object value) {
+ private static final ObjectMapper MAPPER = new ObjectMapper();
- if (value == null) {
- return NULL_VALUE;
- }
+ public static final String NULL_VALUE = "\\N";
- if (value instanceof Date || value instanceof Timestamp) {
- return value.toString();
+ public static byte[] rowToCsvBytes(InternalRow row, StructType schema,
String sep) {
+ StringBuilder builder = new StringBuilder();
+ StructField[] fields = schema.fields();
+ int n = row.numFields();
+ if (n > 0) {
+ builder.append(SchemaUtils.rowColumnValue(row, 0,
fields[0].dataType()));
+ int i = 1;
+ while (i < n) {
+ builder.append(sep);
+ builder.append(SchemaUtils.rowColumnValue(row, i,
fields[i].dataType()));
+ i++;
+ }
}
+ return builder.toString().getBytes(StandardCharsets.UTF_8);
+ }
- if (value instanceof WrappedArray) {
-
- Object[] arr = JavaConversions.seqAsJavaList((WrappedArray)
value).toArray();
- return Arrays.toString(arr);
+ public static byte[] rowToJsonBytes(InternalRow row, StructType schema)
+ throws JsonProcessingException {
+ StructField[] fields = schema.fields();
+ Map<String, Object> rowMap = new HashMap<>(row.numFields());
+ for (int i = 0; i < fields.length; i++) {
+ rowMap.put(fields[i].name(), SchemaUtils.rowColumnValue(row, i,
fields[i].dataType()));
}
-
- return value;
-
+ return MAPPER.writeValueAsBytes(rowMap);
}
}
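A self-contained example of what the new helpers produce for a single row (the class name DataUtilSketch and the two-column schema are made up for illustration). Note that Catalyst's InternalRow stores strings as UTF8String, and that JSON key order comes from a HashMap, so it may vary.

    import org.apache.doris.spark.util.DataUtil;
    import org.apache.spark.sql.catalyst.InternalRow;
    import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructType;
    import org.apache.spark.unsafe.types.UTF8String;

    import java.nio.charset.StandardCharsets;

    class DataUtilSketch {
        public static void main(String[] args) throws Exception {
            StructType schema = new StructType()
                    .add("id", DataTypes.IntegerType)
                    .add("name", DataTypes.StringType);
            // values in Catalyst internal representation: int stays int, String becomes UTF8String
            InternalRow row = new GenericInternalRow(new Object[]{1, UTF8String.fromString("doris")});

            byte[] csv = DataUtil.rowToCsvBytes(row, schema, "\t");
            System.out.println(new String(csv, StandardCharsets.UTF_8));   // 1<TAB>doris

            byte[] json = DataUtil.rowToJsonBytes(row, schema);
            System.out.println(new String(json, StandardCharsets.UTF_8));  // {"id":1,"name":"doris"}
        }
    }
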
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/util/ListUtils.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/ListUtils.java
index d8d31b9..fbfab9a 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/util/ListUtils.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/ListUtils.java
@@ -34,7 +34,7 @@ public class ListUtils {
private static final ObjectMapper MAPPER = new ObjectMapper();
public static List<String> getSerializedList(List<Map<Object, Object>>
batch,
- String lineDelimiter) throws
JsonProcessingException {
+ String lineDelimiter) throws
JsonProcessingException {
List<String> result = new ArrayList<>();
divideAndSerialize(batch, result, lineDelimiter);
return result;
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/package.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/package.scala
index d08bdc0..9dee515 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/package.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/package.scala
@@ -18,18 +18,18 @@
package org.apache.doris
import scala.language.implicitConversions
-
import org.apache.doris.spark.rdd.DorisSpark
import org.apache.spark.SparkContext
+import org.apache.spark.rdd.RDD
package object spark {
- implicit def sparkContextFunctions(sc: SparkContext) = new
SparkContextFunctions(sc)
+ implicit def sparkContextFunctions(sc: SparkContext): SparkContextFunctions
= new SparkContextFunctions(sc)
class SparkContextFunctions(sc: SparkContext) extends Serializable {
def dorisRDD(
tableIdentifier: Option[String] = None,
query: Option[String] = None,
- cfg: Option[Map[String, String]] = None) =
+ cfg: Option[Map[String, String]] = None): RDD[AnyRef] =
DorisSpark.dorisRDD(sc, tableIdentifier, query, cfg)
}
}
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRow.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRow.scala
index 06f5ca3..ec8f887 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRow.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRow.scala
@@ -27,7 +27,7 @@ private[spark] class ScalaDorisRow(rowOrder: Seq[String]) extends Row {
/** No-arg constructor for Kryo serialization. */
def this() = this(null)
- def iterator = values.iterator
+ def iterator: Iterator[Any] = values.iterator
override def length: Int = values.length
@@ -51,9 +51,9 @@ private[spark] class ScalaDorisRow(rowOrder: Seq[String]) extends Row {
override def getByte(i: Int): Byte = getAs[Byte](i)
- override def getString(i: Int): String = get(i).toString()
+ override def getString(i: Int): String = get(i).toString
override def copy(): Row = this
- override def toSeq = values.toSeq
+ override def toSeq: Seq[Any] = values
}
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
index c8aa034..f5a6a15 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
@@ -18,16 +18,23 @@
package org.apache.doris.spark.sql
import org.apache.doris.sdk.thrift.TScanColumnDesc
-
-import scala.collection.JavaConversions._
+import org.apache.doris.spark.cfg.ConfigurationOptions.{DORIS_IGNORE_TYPE,
DORIS_READ_FIELD}
import org.apache.doris.spark.cfg.Settings
import org.apache.doris.spark.exception.DorisException
import org.apache.doris.spark.rest.RestService
import org.apache.doris.spark.rest.models.{Field, Schema}
-import org.apache.doris.spark.cfg.ConfigurationOptions.{DORIS_IGNORE_TYPE,
DORIS_READ_FIELD}
+import org.apache.doris.spark.util.DataUtil
+import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.slf4j.LoggerFactory
+import java.sql.Timestamp
+import java.time.{LocalDateTime, ZoneOffset}
+import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
private[spark] object SchemaUtils {
private val logger =
LoggerFactory.getLogger(SchemaUtils.getClass.getSimpleName.stripSuffix("$"))
@@ -137,4 +144,49 @@ private[spark] object SchemaUtils {
tscanColumnDescs.foreach(desc => schema.put(new Field(desc.getName,
desc.getType.name, "", 0, 0, "")))
schema
}
+
+ def rowColumnValue(row: SpecializedGetters, ordinal: Int, dataType:
DataType): Any = {
+
+ dataType match {
+ case NullType => DataUtil.NULL_VALUE
+ case BooleanType => row.getBoolean(ordinal)
+ case ByteType => row.getByte(ordinal)
+ case ShortType => row.getShort(ordinal)
+ case IntegerType => row.getInt(ordinal)
+ case LongType => row.getLong(ordinal)
+ case FloatType => row.getFloat(ordinal)
+ case DoubleType => row.getDouble(ordinal)
+ case StringType => row.getUTF8String(ordinal).toString
+ case TimestampType =>
+ LocalDateTime.ofEpochSecond(row.getLong(ordinal) / 100000,
(row.getLong(ordinal) % 1000).toInt, ZoneOffset.UTC)
+ new Timestamp(row.getLong(ordinal) / 1000).toString
+ case DateType => DateTimeUtils.toJavaDate(row.getInt(ordinal)).toString
+ case BinaryType => row.getBinary(ordinal)
+ case dt: DecimalType => row.getDecimal(ordinal, dt.precision, dt.scale)
+ case at: ArrayType =>
+ val arrayData = row.getArray(ordinal)
+ var i = 0
+ val buffer = mutable.Buffer[Any]()
+ while (i < arrayData.numElements()) {
+ if (arrayData.isNullAt(i)) buffer += null else buffer +=
rowColumnValue(arrayData, i, at.elementType)
+ i += 1
+ }
+ s"[${buffer.mkString(",")}]"
+ case mt: MapType =>
+ val mapData = row.getMap(ordinal)
+ val keys = mapData.keyArray()
+ val values = mapData.valueArray()
+ var i = 0
+ val map = mutable.Map[Any, Any]()
+ while (i < keys.numElements()) {
+ map += rowColumnValue(keys, i, mt.keyType) -> rowColumnValue(values,
i, mt.valueType)
+ i += 1
+ }
+ map.toMap.asJava
+ case st: StructType => row.getStruct(ordinal, st.length)
+ case _ => throw new DorisException(s"Unsupported spark type:
${dataType.typeName}")
+ }
+
+ }
+
}
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
index e32267e..b278a38 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
@@ -39,8 +39,6 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
private val logger: Logger = LoggerFactory.getLogger(classOf[DorisWriter])
- val batchSize: Int =
settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_BATCH_SIZE,
- ConfigurationOptions.SINK_BATCH_SIZE_DEFAULT)
private val maxRetryTimes: Int =
settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_MAX_RETRIES,
ConfigurationOptions.SINK_MAX_RETRIES_DEFAULT)
private val sinkTaskPartitionSize: Integer =
settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TASK_PARTITION_SIZE)
@@ -55,45 +53,16 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
private val dorisStreamLoader: DorisStreamLoad =
CachedDorisStreamLoadClient.getOrCreate(settings)
def write(dataFrame: DataFrame): Unit = {
+ doWrite(dataFrame, dorisStreamLoader.load)
+ }
- val sc = dataFrame.sqlContext.sparkContext
- val preCommittedTxnAcc =
sc.collectionAccumulator[Int]("preCommittedTxnAcc")
- if (enable2PC) {
- sc.addSparkListener(new DorisTransactionListener(preCommittedTxnAcc,
dorisStreamLoader))
- }
+ def writeStream(dataFrame: DataFrame): Unit = {
+ doWrite(dataFrame, dorisStreamLoader.loadStream)
+ }
- var resultRdd = dataFrame.rdd
- val dfColumns = dataFrame.columns
- if (Objects.nonNull(sinkTaskPartitionSize)) {
- resultRdd = if (sinkTaskUseRepartition)
resultRdd.repartition(sinkTaskPartitionSize) else
resultRdd.coalesce(sinkTaskPartitionSize)
- }
- resultRdd
- .map(_.toSeq.map(_.asInstanceOf[AnyRef]).toList.asJava)
- .foreachPartition(partition => {
- partition
- .grouped(batchSize)
- .foreach(batch => flush(batch, dfColumns))
- })
-
- /**
- * flush data to Doris and do retry when flush error
- *
- */
- def flush(batch: Seq[util.List[Object]], dfColumns: Array[String]): Unit =
{
- Utils.retry[util.List[Integer], Exception](maxRetryTimes,
Duration.ofMillis(batchInterValMs.toLong), logger) {
- dorisStreamLoader.loadV2(batch.asJava, dfColumns, enable2PC)
- } match {
- case Success(txnIds) => if (enable2PC)
handleLoadSuccess(txnIds.asScala, preCommittedTxnAcc)
- case Failure(e) =>
- if (enable2PC) handleLoadFailure(preCommittedTxnAcc)
- throw new IOException(
- s"Failed to load batch data on BE:
${dorisStreamLoader.getLoadUrlStr} node and exceeded the max ${maxRetryTimes}
retry times.", e)
- }
- }
+ private def doWrite(dataFrame: DataFrame, loadFunc:
(util.Iterator[InternalRow], StructType) => Int): Unit = {
- }
- def writeStream(dataFrame: DataFrame): Unit = {
val sc = dataFrame.sqlContext.sparkContext
val preCommittedTxnAcc =
sc.collectionAccumulator[Int]("preCommittedTxnAcc")
@@ -103,47 +72,31 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
var resultRdd = dataFrame.queryExecution.toRdd
val schema = dataFrame.schema
- val dfColumns = dataFrame.columns
if (Objects.nonNull(sinkTaskPartitionSize)) {
resultRdd = if (sinkTaskUseRepartition)
resultRdd.repartition(sinkTaskPartitionSize) else
resultRdd.coalesce(sinkTaskPartitionSize)
}
- resultRdd
- .foreachPartition(partition => {
- partition
- .grouped(batchSize)
- .foreach(batch =>
- flush(batch, dfColumns))
- })
-
- /**
- * flush data to Doris and do retry when flush error
- *
- */
- def flush(batch: Seq[InternalRow], dfColumns: Array[String]): Unit = {
- Utils.retry[util.List[Integer], Exception](maxRetryTimes,
Duration.ofMillis(batchInterValMs.toLong), logger) {
- dorisStreamLoader.loadStream(convertToObjectList(batch, schema),
dfColumns, enable2PC)
- } match {
- case Success(txnIds) => if (enable2PC)
handleLoadSuccess(txnIds.asScala, preCommittedTxnAcc)
- case Failure(e) =>
- if (enable2PC) handleLoadFailure(preCommittedTxnAcc)
- throw new IOException(
- s"Failed to load batch data on BE:
${dorisStreamLoader.getLoadUrlStr} node and exceeded the max ${maxRetryTimes}
retry times.", e)
+ resultRdd.foreachPartition(iterator => {
+ while (iterator.hasNext) {
+ // do load batch with retries
+ Utils.retry[Int, Exception](maxRetryTimes,
Duration.ofMillis(batchInterValMs.toLong), logger) {
+ loadFunc(iterator.asJava, schema)
+ } match {
+ case Success(txnId) => if (enable2PC) handleLoadSuccess(txnId,
preCommittedTxnAcc)
+ case Failure(e) =>
+ if (enable2PC) handleLoadFailure(preCommittedTxnAcc)
+ throw new IOException(
+ s"Failed to load batch data on BE:
${dorisStreamLoader.getLoadUrlStr} node and exceeded the max ${maxRetryTimes}
retry times.", e)
+ }
}
- }
-
- def convertToObjectList(rows: Seq[InternalRow], schema: StructType):
util.List[util.List[Object]] = {
- rows.map(row => {
- row.toSeq(schema).map(_.asInstanceOf[AnyRef]).toList.asJava
- }).asJava
- }
+ })
}
- private def handleLoadSuccess(txnIds: mutable.Buffer[Integer], acc:
CollectionAccumulator[Int]): Unit = {
- txnIds.foreach(txnId => acc.add(txnId))
+ private def handleLoadSuccess(txnId: Int, acc: CollectionAccumulator[Int]):
Unit = {
+ acc.add(txnId)
}
- def handleLoadFailure(acc: CollectionAccumulator[Int]): Unit = {
+ private def handleLoadFailure(acc: CollectionAccumulator[Int]): Unit = {
// if task run failed, acc value will not be returned to driver,
// should abort all pre committed transactions inside the task
logger.info("load task failed, start aborting previously pre-committed
transactions")
diff --git a/spark-doris-connector/src/test/java/org/apache/doris/spark/util/DataUtilTest.java b/spark-doris-connector/src/test/java/org/apache/doris/spark/util/DataUtilTest.java
deleted file mode 100644
index 020a241..0000000
--- a/spark-doris-connector/src/test/java/org/apache/doris/spark/util/DataUtilTest.java
+++ /dev/null
@@ -1,32 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.doris.spark.util;
-
-import junit.framework.TestCase;
-import org.junit.Assert;
-import scala.collection.mutable.WrappedArray;
-
-import java.sql.Timestamp;
-
-public class DataUtilTest extends TestCase {
-
- public void testHandleColumnValue() {
- Assert.assertEquals("2023-08-14 18:00:00.0",
DataUtil.handleColumnValue(Timestamp.valueOf("2023-08-14 18:00:00")));
- Assert.assertEquals("[1, 2, 3]",
DataUtil.handleColumnValue(WrappedArray.make(new Integer[]{1,2,3})));
- }
-}
\ No newline at end of file
diff --git a/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/SchemaUtilsTest.scala b/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/SchemaUtilsTest.scala
new file mode 100644
index 0000000..e3868cb
--- /dev/null
+++ b/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/SchemaUtilsTest.scala
@@ -0,0 +1,54 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.spark.sql
+
+import org.apache.spark.sql.SparkSession
+import org.junit.{Assert, Ignore, Test}
+
+import java.sql.{Date, Timestamp}
+import scala.collection.JavaConverters._
+
+@Ignore
+class SchemaUtilsTest {
+
+ @Test
+ def rowColumnValueTest(): Unit = {
+
+ val spark = SparkSession.builder().master("local").getOrCreate()
+
+ val df = spark.createDataFrame(Seq(
+ (1, Date.valueOf("2023-09-08"), Timestamp.valueOf("2023-09-08
17:00:00"), Array(1, 2, 3), Map[String, String]("a" -> "1"))
+ )).toDF("c1", "c2", "c3", "c4", "c5")
+
+ val schema = df.schema
+
+ df.queryExecution.toRdd.foreach(row => {
+
+ val fields = schema.fields
+ Assert.assertEquals(1, SchemaUtils.rowColumnValue(row, 0,
fields(0).dataType))
+ Assert.assertEquals("2023-09-08", SchemaUtils.rowColumnValue(row, 1,
fields(1).dataType))
+ Assert.assertEquals("2023-09-08 17:00:00.0",
SchemaUtils.rowColumnValue(row, 2, fields(2).dataType))
+ Assert.assertEquals("[1,2,3]", SchemaUtils.rowColumnValue(row, 3,
fields(3).dataType))
+ println(SchemaUtils.rowColumnValue(row, 4, fields(4).dataType))
+ Assert.assertEquals(Map("a" -> "1").asJava,
SchemaUtils.rowColumnValue(row, 4, fields(4).dataType))
+
+ })
+
+ }
+
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]