This is an automated email from the ASF dual-hosted git repository.

diwu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris-spark-connector.git


The following commit(s) were added to refs/heads/master by this push:
     new 0daf6c4  [optimize][refactor] Optimizing memory usage for writing data (#140)
0daf6c4 is described below

commit 0daf6c4de9c2b1686cd6862e5ada3b1cba36640a
Author: gnehil <[email protected]>
AuthorDate: Wed Sep 13 15:47:57 2023 +0800

    [optimize][refactor] Optimizing memory usage for writing data (#140)
    
    1. Optimize data transmission
    2. Optimize partitioned data iteration
---
 .../apache/doris/spark/backend/BackendClient.java  |   3 +-
 .../java/org/apache/doris/spark/cfg/Settings.java  |   1 +
 .../apache/doris/spark/load/DorisStreamLoad.java   | 238 ++++++---------------
 .../org/apache/doris/spark/load/RecordBatch.java   | 153 +++++++++++++
 .../doris/spark/load/RecordBatchInputStream.java   | 221 +++++++++++++++++++
 .../apache/doris/spark/serialization/RowBatch.java |  38 ++--
 .../java/org/apache/doris/spark/util/DataUtil.java |  53 +++--
 .../org/apache/doris/spark/util/ListUtils.java     |   2 +-
 .../scala/org/apache/doris/spark/package.scala     |   6 +-
 .../org/apache/doris/spark/sql/ScalaDorisRow.scala |   6 +-
 .../org/apache/doris/spark/sql/SchemaUtils.scala   |  58 ++++-
 .../apache/doris/spark/writer/DorisWriter.scala    |  91 ++------
 .../org/apache/doris/spark/util/DataUtilTest.java  |  32 ---
 .../apache/doris/spark/sql/SchemaUtilsTest.scala   |  54 +++++
 14 files changed, 637 insertions(+), 319 deletions(-)

diff --git 
a/spark-doris-connector/src/main/java/org/apache/doris/spark/backend/BackendClient.java
 
b/spark-doris-connector/src/main/java/org/apache/doris/spark/backend/BackendClient.java
index aaafe09..b10797b 100644
--- 
a/spark-doris-connector/src/main/java/org/apache/doris/spark/backend/BackendClient.java
+++ 
b/spark-doris-connector/src/main/java/org/apache/doris/spark/backend/BackendClient.java
@@ -29,9 +29,10 @@ import org.apache.doris.spark.cfg.ConfigurationOptions;
 import org.apache.doris.spark.exception.ConnectedFailedException;
 import org.apache.doris.spark.exception.DorisException;
 import org.apache.doris.spark.exception.DorisInternalException;
-import org.apache.doris.spark.util.ErrorMessages;
 import org.apache.doris.spark.cfg.Settings;
 import org.apache.doris.spark.serialization.Routing;
+import org.apache.doris.spark.util.ErrorMessages;
+
 import org.apache.thrift.TConfiguration;
 import org.apache.thrift.TException;
 import org.apache.thrift.protocol.TBinaryProtocol;
diff --git 
a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/Settings.java 
b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/Settings.java
index 798ec8c..c941fdf 100644
--- 
a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/Settings.java
+++ 
b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/Settings.java
@@ -25,6 +25,7 @@ import org.apache.commons.lang3.StringUtils;
 import org.apache.doris.spark.exception.IllegalArgumentException;
 import org.apache.doris.spark.util.ErrorMessages;
 import org.apache.doris.spark.util.IOUtils;
+
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
diff --git 
a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
 
b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
index ac920cd..9ecfa40 100644
--- 
a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
+++ 
b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
@@ -22,8 +22,6 @@ import org.apache.doris.spark.exception.StreamLoadException;
 import org.apache.doris.spark.rest.RestService;
 import org.apache.doris.spark.rest.models.BackendV2;
 import org.apache.doris.spark.rest.models.RespContent;
-import org.apache.doris.spark.util.DataUtil;
-import org.apache.doris.spark.util.ListUtils;
 import org.apache.doris.spark.util.ResponseUtil;
 
 import com.fasterxml.jackson.core.JsonProcessingException;
@@ -39,71 +37,72 @@ import org.apache.http.HttpResponse;
 import org.apache.http.HttpStatus;
 import org.apache.http.client.methods.CloseableHttpResponse;
 import org.apache.http.client.methods.HttpPut;
+import org.apache.http.client.methods.HttpRequestBase;
 import org.apache.http.entity.BufferedHttpEntity;
-import org.apache.http.entity.StringEntity;
+import org.apache.http.entity.InputStreamEntity;
 import org.apache.http.impl.client.CloseableHttpClient;
 import org.apache.http.impl.client.HttpClientBuilder;
 import org.apache.http.util.EntityUtils;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.types.StructType;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
 import java.io.Serializable;
 import java.nio.charset.StandardCharsets;
-import java.sql.Timestamp;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Base64;
 import java.util.Calendar;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.Properties;
 import java.util.UUID;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
-import java.util.stream.Collectors;
 
 
 /**
  * DorisStreamLoad
  **/
 public class DorisStreamLoad implements Serializable {
-    private String FIELD_DELIMITER;
-    private final String LINE_DELIMITER;
-    private static final String NULL_VALUE = "\\N";
 
     private static final Logger LOG = 
LoggerFactory.getLogger(DorisStreamLoad.class);
 
+    private static final ObjectMapper MAPPER = new ObjectMapper();
+
     private final static List<String> DORIS_SUCCESS_STATUS = new 
ArrayList<>(Arrays.asList("Success", "Publish Timeout"));
-    private static String loadUrlPattern = "http://%s/api/%s/%s/_stream_load?";;
 
-    private static String abortUrlPattern = 
"http://%s/api/%s/%s/_stream_load_2pc?";;
+    private static final String loadUrlPattern = 
"http://%s/api/%s/%s/_stream_load?";;
+
+    private static final String abortUrlPattern = 
"http://%s/api/%s/%s/_stream_load_2pc?";;
 
-    private String user;
-    private String passwd;
     private String loadUrlStr;
-    private String db;
-    private String tbl;
-    private String authEncoded;
-    private String columns;
-    private String maxFilterRatio;
-    private Map<String, String> streamLoadProp;
+    private final String db;
+    private final String tbl;
+    private final String authEncoded;
+    private final String columns;
+    private final String maxFilterRatio;
+    private final Map<String, String> streamLoadProp;
     private static final long cacheExpireTimeout = 4 * 60;
     private final LoadingCache<String, List<BackendV2.BackendRowV2>> cache;
     private final String fileType;
-
-    private boolean readJsonByLine = false;
-
+    private String FIELD_DELIMITER;
+    private final String LINE_DELIMITER;
     private boolean streamingPassthrough = false;
+    private final Integer batchSize;
+    private boolean enable2PC;
 
     public DorisStreamLoad(SparkSettings settings) {
         String[] dbTable = 
settings.getProperty(ConfigurationOptions.DORIS_TABLE_IDENTIFIER).split("\\.");
         this.db = dbTable[0];
         this.tbl = dbTable[1];
-        this.user = 
settings.getProperty(ConfigurationOptions.DORIS_REQUEST_AUTH_USER);
-        this.passwd = 
settings.getProperty(ConfigurationOptions.DORIS_REQUEST_AUTH_PASSWORD);
+        String user = 
settings.getProperty(ConfigurationOptions.DORIS_REQUEST_AUTH_USER);
+        String passwd = 
settings.getProperty(ConfigurationOptions.DORIS_REQUEST_AUTH_PASSWORD);
         this.authEncoded = getAuthEncoded(user, passwd);
         this.columns = 
settings.getProperty(ConfigurationOptions.DORIS_WRITE_FIELDS);
         this.maxFilterRatio = 
settings.getProperty(ConfigurationOptions.DORIS_MAX_FILTER_RATIO);
@@ -113,18 +112,15 @@ public class DorisStreamLoad implements Serializable {
         if ("csv".equals(fileType)) {
             FIELD_DELIMITER = 
escapeString(streamLoadProp.getOrDefault("column_separator", "\t"));
         } else if ("json".equalsIgnoreCase(fileType)) {
-            readJsonByLine = 
Boolean.parseBoolean(streamLoadProp.getOrDefault("read_json_by_line", "false"));
-            boolean stripOuterArray = 
Boolean.parseBoolean(streamLoadProp.getOrDefault("strip_outer_array", "false"));
-            if (readJsonByLine && stripOuterArray) {
-                throw new IllegalArgumentException("Only one of options 
'read_json_by_line' and 'strip_outer_array' can be set to true");
-            } else if (!readJsonByLine && !stripOuterArray) {
-                LOG.info("set default json mode: strip_outer_array");
-                streamLoadProp.put("strip_outer_array", "true");
-            }
+            streamLoadProp.put("read_json_by_line", "true");
         }
         LINE_DELIMITER = 
escapeString(streamLoadProp.getOrDefault("line_delimiter", "\n"));
         this.streamingPassthrough = 
settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_STREAMING_PASSTHROUGH,
                 ConfigurationOptions.DORIS_SINK_STREAMING_PASSTHROUGH_DEFAULT);
+        this.batchSize = 
settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_BATCH_SIZE,
+                ConfigurationOptions.SINK_BATCH_SIZE_DEFAULT);
+        this.enable2PC = 
settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_ENABLE_2PC,
+                ConfigurationOptions.DORIS_SINK_ENABLE_2PC_DEFAULT);
     }
 
     public String getLoadUrlStr() {
@@ -141,9 +137,7 @@ public class DorisStreamLoad implements Serializable {
 
     private HttpPut getHttpPut(String label, String loadUrlStr, Boolean 
enable2PC) {
         HttpPut httpPut = new HttpPut(loadUrlStr);
-        httpPut.setHeader(HttpHeaders.AUTHORIZATION, "Basic " + authEncoded);
-        httpPut.setHeader(HttpHeaders.EXPECT, "100-continue");
-        httpPut.setHeader(HttpHeaders.CONTENT_TYPE, "text/plain; 
charset=UTF-8");
+        addCommonHeader(httpPut);
         httpPut.setHeader("label", label);
         if (StringUtils.isNotBlank(columns)) {
             httpPut.setHeader("columns", columns);
@@ -155,7 +149,11 @@ public class DorisStreamLoad implements Serializable {
             httpPut.setHeader("two_phase_commit", "true");
         }
         if (MapUtils.isNotEmpty(streamLoadProp)) {
-            streamLoadProp.forEach(httpPut::setHeader);
+            streamLoadProp.forEach((k, v) -> {
+                if (!"strip_outer_array".equalsIgnoreCase(k)) {
+                    httpPut.setHeader(k, v);
+                }
+            });
         }
         return httpPut;
     }
@@ -165,10 +163,10 @@ public class DorisStreamLoad implements Serializable {
         public String respMsg;
         public String respContent;
 
-        public LoadResponse(int status, String respMsg, String respContent) {
-            this.status = status;
-            this.respMsg = respMsg;
-            this.respContent = respContent;
+        public LoadResponse(HttpResponse response) throws IOException {
+            this.status = response.getStatusLine().getStatusCode();
+            this.respMsg = response.getStatusLine().getReasonPhrase();
+            this.respContent = EntityUtils.toString(new 
BufferedHttpEntity(response.getEntity()), StandardCharsets.UTF_8);
         }
 
         @Override
@@ -177,95 +175,34 @@ public class DorisStreamLoad implements Serializable {
         }
     }
 
-    public List<Integer> loadV2(List<List<Object>> rows, String[] dfColumns, 
Boolean enable2PC) throws StreamLoadException, JsonProcessingException {
-
-        List<String> loadData = parseLoadData(rows, dfColumns);
-        List<Integer> txnIds = new ArrayList<>(loadData.size());
-
-        try {
-            for (String data : loadData) {
-                txnIds.add(load(data, enable2PC));
-            }
-        } catch (StreamLoadException e) {
-            if (enable2PC && !txnIds.isEmpty()) {
-                LOG.error("load batch failed, abort previously pre-committed 
transactions");
-                for (Integer txnId : txnIds) {
-                    abort(txnId);
-                }
-            }
-            throw e;
-        }
-
-        return txnIds;
-
-    }
-
-    public List<Integer> loadStream(List<List<Object>> rows, String[] 
dfColumns, Boolean enable2PC)
+    public int load(Iterator<InternalRow> rows, StructType schema)
             throws StreamLoadException, JsonProcessingException {
 
-        List<String> loadData;
-
-        if (this.streamingPassthrough) {
-            handleStreamPassThrough();
-            loadData = passthrough(rows);
-        } else {
-            loadData = parseLoadData(rows, dfColumns);
-        }
-
-        List<Integer> txnIds = new ArrayList<>(loadData.size());
-
-        try {
-            for (String data : loadData) {
-                txnIds.add(load(data, enable2PC));
-            }
-        } catch (StreamLoadException e) {
-            if (enable2PC && !txnIds.isEmpty()) {
-                LOG.error("load batch failed, abort previously pre-committed 
transactions");
-                for (Integer txnId : txnIds) {
-                    abort(txnId);
-                }
-            }
-            throw e;
-        }
-
-        return txnIds;
-
-    }
-
-    public int load(String value, Boolean enable2PC) throws 
StreamLoadException {
-
         String label = generateLoadLabel();
-
         LoadResponse loadResponse;
-        int responseHttpStatus = -1;
         try (CloseableHttpClient httpClient = getHttpClient()) {
             String loadUrlStr = String.format(loadUrlPattern, getBackend(), 
db, tbl);
-            LOG.debug("Stream load Request:{} ,Body:{}", loadUrlStr, value);
-            // only to record the BE node in case of an exception
             this.loadUrlStr = loadUrlStr;
-
             HttpPut httpPut = getHttpPut(label, loadUrlStr, enable2PC);
-            httpPut.setEntity(new StringEntity(value, StandardCharsets.UTF_8));
+            RecordBatchInputStream recodeBatchInputStream = new 
RecordBatchInputStream(RecordBatch.newBuilder(rows)
+                    .batchSize(batchSize)
+                    .format(fileType)
+                    .sep(FIELD_DELIMITER)
+                    .delim(LINE_DELIMITER)
+                    .schema(schema).build(), streamingPassthrough);
+            httpPut.setEntity(new InputStreamEntity(recodeBatchInputStream));
             HttpResponse httpResponse = httpClient.execute(httpPut);
-            responseHttpStatus = httpResponse.getStatusLine().getStatusCode();
-            String respMsg = httpResponse.getStatusLine().getReasonPhrase();
-            String response = EntityUtils.toString(new 
BufferedHttpEntity(httpResponse.getEntity()), StandardCharsets.UTF_8);
-            loadResponse = new LoadResponse(responseHttpStatus, respMsg, 
response);
+            loadResponse = new LoadResponse(httpResponse);
         } catch (IOException e) {
-            e.printStackTrace();
-            String err = "http request exception,load url : " + loadUrlStr + 
",failed to execute spark stream load with label: " + label;
-            LOG.warn(err, e);
-            loadResponse = new LoadResponse(responseHttpStatus, 
e.getMessage(), err);
+            throw new RuntimeException(e);
         }
 
         if (loadResponse.status != HttpStatus.SC_OK) {
             LOG.info("Stream load Response HTTP Status Error:{}", 
loadResponse);
-            // throw new StreamLoadException("stream load error: " + 
loadResponse.respContent);
             throw new StreamLoadException("stream load error");
         } else {
-            ObjectMapper obj = new ObjectMapper();
             try {
-                RespContent respContent = 
obj.readValue(loadResponse.respContent, RespContent.class);
+                RespContent respContent = 
MAPPER.readValue(loadResponse.respContent, RespContent.class);
                 if (!DORIS_SUCCESS_STATUS.contains(respContent.getStatus())) {
                     LOG.error("Stream load Response RES STATUS Error:{}", 
loadResponse);
                     throw new StreamLoadException("stream load error");
@@ -279,6 +216,14 @@ public class DorisStreamLoad implements Serializable {
 
     }
 
+    public Integer loadStream(Iterator<InternalRow> rows, StructType schema)
+            throws StreamLoadException, JsonProcessingException {
+        if (this.streamingPassthrough) {
+            handleStreamPassThrough();
+        }
+        return load(rows, schema);
+    }
+
     public void commit(int txnId) throws StreamLoadException {
 
         try (CloseableHttpClient client = getHttpClient()) {
@@ -286,9 +231,7 @@ public class DorisStreamLoad implements Serializable {
             String backend = getBackend();
             String abortUrl = String.format(abortUrlPattern, backend, db, tbl);
             HttpPut httpPut = new HttpPut(abortUrl);
-            httpPut.setHeader(HttpHeaders.AUTHORIZATION, "Basic " + 
authEncoded);
-            httpPut.setHeader(HttpHeaders.EXPECT, "100-continue");
-            httpPut.setHeader(HttpHeaders.CONTENT_TYPE, "text/plain; 
charset=UTF-8");
+            addCommonHeader(httpPut);
             httpPut.setHeader("txn_operation", "commit");
             httpPut.setHeader("txn_id", String.valueOf(txnId));
 
@@ -306,10 +249,9 @@ public class DorisStreamLoad implements Serializable {
                 throw new StreamLoadException("stream load error: " + 
reasonPhrase);
             }
 
-            ObjectMapper mapper = new ObjectMapper();
             if (response.getEntity() != null) {
                 String loadResult = EntityUtils.toString(response.getEntity());
-                Map<String, String> res = mapper.readValue(loadResult, new 
TypeReference<HashMap<String, String>>() {
+                Map<String, String> res = MAPPER.readValue(loadResult, new 
TypeReference<HashMap<String, String>>() {
                 });
                 if (res.get("status").equals("Fail") && 
!ResponseUtil.isCommitted(res.get("msg"))) {
                     throw new StreamLoadException("Commit failed " + 
loadResult);
@@ -331,9 +273,7 @@ public class DorisStreamLoad implements Serializable {
         try (CloseableHttpClient client = getHttpClient()) {
             String abortUrl = String.format(abortUrlPattern, getBackend(), db, 
tbl);
             HttpPut httpPut = new HttpPut(abortUrl);
-            httpPut.setHeader(HttpHeaders.AUTHORIZATION, "Basic " + 
authEncoded);
-            httpPut.setHeader(HttpHeaders.EXPECT, "100-continue");
-            httpPut.setHeader(HttpHeaders.CONTENT_TYPE, "text/plain; 
charset=UTF-8");
+            addCommonHeader(httpPut);
             httpPut.setHeader("txn_operation", "abort");
             httpPut.setHeader("txn_id", String.valueOf(txnId));
 
@@ -344,9 +284,8 @@ public class DorisStreamLoad implements Serializable {
                 throw new StreamLoadException("Fail to abort transaction " + 
txnId + " with url " + abortUrl);
             }
 
-            ObjectMapper mapper = new ObjectMapper();
             String loadResult = EntityUtils.toString(response.getEntity());
-            Map<String, String> res = mapper.readValue(loadResult, new 
TypeReference<HashMap<String, String>>() {
+            Map<String, String> res = MAPPER.readValue(loadResult, new 
TypeReference<HashMap<String, String>>() {
             });
             if (!"Success".equals(res.get("status"))) {
                 if (ResponseUtil.isCommitted(res.get("msg"))) {
@@ -410,52 +349,13 @@ public class DorisStreamLoad implements Serializable {
 
     }
 
-    private List<String> parseLoadData(List<List<Object>> rows, String[] 
dfColumns) throws StreamLoadException, JsonProcessingException {
-
-        List<String> loadDataList;
-
-        switch (fileType.toUpperCase()) {
-
-            case "CSV":
-                loadDataList = Collections.singletonList(
-                        rows.stream()
-                                .map(row -> row.stream()
-                                        .map(DataUtil::handleColumnValue)
-                                        .map(Object::toString)
-                                        
.collect(Collectors.joining(FIELD_DELIMITER))
-                                ).collect(Collectors.joining(LINE_DELIMITER)));
-                break;
-            case "JSON":
-                List<Map<Object, Object>> dataList = new ArrayList<>();
-                try {
-                    for (List<Object> row : rows) {
-                        Map<Object, Object> dataMap = new HashMap<>();
-                        if (dfColumns.length == row.size()) {
-                            for (int i = 0; i < dfColumns.length; i++) {
-                                dataMap.put(dfColumns[i], 
DataUtil.handleColumnValue(row.get(i)));
-                            }
-                        }
-                        dataList.add(dataMap);
-                    }
-                } catch (Exception e) {
-                    throw new StreamLoadException("The number of configured 
columns does not match the number of data columns.");
-                }
-                // splits large collections to normal collection to avoid the 
"Requested array size exceeds VM limit" exception
-                loadDataList = ListUtils.getSerializedList(dataList, 
readJsonByLine ? LINE_DELIMITER : null);
-                break;
-            default:
-                throw new StreamLoadException(String.format("Unsupported file 
format in stream load: %s.", fileType));
-
-        }
-
-        return loadDataList;
-
-    }
-
     private String generateLoadLabel() {
 
         Calendar calendar = Calendar.getInstance();
-        return String.format("spark_streamload_%s%02d%02d_%02d%02d%02d_%s", 
calendar.get(Calendar.YEAR), calendar.get(Calendar.MONTH) + 1, 
calendar.get(Calendar.DAY_OF_MONTH), calendar.get(Calendar.HOUR_OF_DAY), 
calendar.get(Calendar.MINUTE), calendar.get(Calendar.SECOND), 
UUID.randomUUID().toString().replaceAll("-", ""));
+        return String.format("spark_streamload_%s%02d%02d_%02d%02d%02d_%s",
+                calendar.get(Calendar.YEAR), calendar.get(Calendar.MONTH) + 1, 
calendar.get(Calendar.DAY_OF_MONTH),
+                calendar.get(Calendar.HOUR_OF_DAY), 
calendar.get(Calendar.MINUTE), calendar.get(Calendar.SECOND),
+                UUID.randomUUID().toString().replaceAll("-", ""));
 
     }
 
@@ -478,6 +378,12 @@ public class DorisStreamLoad implements Serializable {
         return hexData;
     }
 
+    private void addCommonHeader(HttpRequestBase httpReq) {
+        httpReq.setHeader(HttpHeaders.AUTHORIZATION, "Basic " + authEncoded);
+        httpReq.setHeader(HttpHeaders.EXPECT, "100-continue");
+        httpReq.setHeader(HttpHeaders.CONTENT_TYPE, "text/plain; 
charset=UTF-8");
+    }
+
     private void handleStreamPassThrough() {
 
         if ("json".equalsIgnoreCase(fileType)) {
@@ -488,8 +394,4 @@ public class DorisStreamLoad implements Serializable {
 
     }
 
-    private List<String> passthrough(List<List<Object>> values) {
-        return values.stream().map(list -> 
list.get(0).toString()).collect(Collectors.toList());
-    }
-
 }
diff --git 
a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java
 
b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java
new file mode 100644
index 0000000..779c057
--- /dev/null
+++ 
b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatch.java
@@ -0,0 +1,153 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.spark.load;
+
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.types.StructType;
+
+import java.nio.charset.Charset;
+import java.nio.charset.StandardCharsets;
+import java.util.Iterator;
+
+/**
+ * Immutable description of a single stream-load batch: the source row iterator, the
+ * maximum number of rows to take from it, and the settings used to serialize each row
+ * (format, column separator, line delimiter, schema).
+ *
+ * <p>Instances are obtained through {@link #newBuilder(Iterator)}.
+ */
+public class RecordBatch {
+
+    private static final Charset DEFAULT_CHARSET = StandardCharsets.UTF_8;
+
+    /** Spark rows to load; the very same iterator object the builder was created with. */
+    private final Iterator<InternalRow> iterator;
+
+    /** Upper bound on the number of rows read from {@link #iterator} for one load. */
+    private final int batchSize;
+
+    /** Stream load format name (for example "csv" or "json"). */
+    private final String format;
+
+    /** Column separator; only relevant for the csv format. */
+    private final String sep;
+
+    /** Line delimiter, already encoded with UTF-8. */
+    private final byte[] delim;
+
+    /** Spark schema describing the rows produced by {@link #iterator}. */
+    private final StructType schema;
+
+    private RecordBatch(Builder source) {
+        this.iterator = source.rows;
+        this.batchSize = source.maxRows;
+        this.format = source.loadFormat;
+        this.sep = source.columnSeparator;
+        this.delim = source.lineDelimiter;
+        this.schema = source.rowSchema;
+    }
+
+    /**
+     * Starts building a batch over the given rows.
+     *
+     * @param iterator rows to be loaded
+     * @return a fresh builder; all other settings are unset
+     */
+    public static Builder newBuilder(Iterator<InternalRow> iterator) {
+        return new Builder(iterator);
+    }
+
+    public Iterator<InternalRow> getIterator() {
+        return iterator;
+    }
+
+    public int getBatchSize() {
+        return batchSize;
+    }
+
+    public String getFormat() {
+        return format;
+    }
+
+    public String getSep() {
+        return sep;
+    }
+
+    public byte[] getDelim() {
+        return delim;
+    }
+
+    public StructType getSchema() {
+        return schema;
+    }
+
+    /**
+     * Fluent builder for {@link RecordBatch}. Values are captured when {@link #build()} is called.
+     */
+    public static class Builder {
+
+        private final Iterator<InternalRow> rows;
+
+        private int maxRows;
+
+        private String loadFormat;
+
+        private String columnSeparator;
+
+        private byte[] lineDelimiter;
+
+        private StructType rowSchema;
+
+        public Builder(Iterator<InternalRow> iterator) {
+            this.rows = iterator;
+        }
+
+        public Builder batchSize(int batchSize) {
+            maxRows = batchSize;
+            return this;
+        }
+
+        public Builder format(String format) {
+            loadFormat = format;
+            return this;
+        }
+
+        public Builder sep(String sep) {
+            columnSeparator = sep;
+            return this;
+        }
+
+        /** Stores the delimiter as UTF-8 bytes; a {@code null} argument throws NullPointerException. */
+        public Builder delim(String delim) {
+            lineDelimiter = delim.getBytes(DEFAULT_CHARSET);
+            return this;
+        }
+
+        public Builder schema(StructType schema) {
+            rowSchema = schema;
+            return this;
+        }
+
+        public RecordBatch build() {
+            return new RecordBatch(this);
+        }
+
+    }
+
+}
diff --git 
a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java
 
b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java
new file mode 100644
index 0000000..9444c1d
--- /dev/null
+++ 
b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/RecordBatchInputStream.java
@@ -0,0 +1,221 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.spark.load;
+
+import org.apache.doris.spark.exception.DorisException;
+import org.apache.doris.spark.exception.IllegalArgumentException;
+import org.apache.doris.spark.exception.ShouldNeverHappenException;
+import org.apache.doris.spark.util.DataUtil;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
+import java.util.Iterator;
+
+/**
+ * InputStream for batch load
+ */
+public class RecordBatchInputStream extends InputStream {
+
+    public static final Logger LOG = 
LoggerFactory.getLogger(RecordBatchInputStream.class);
+
+    private static final int DEFAULT_BUF_SIZE = 4096;
+
+    /**
+     * Load record batch
+     */
+    private final RecordBatch recordBatch;
+
+    /**
+     * first line flag
+     */
+    private boolean isFirst = true;
+
+    /**
+     * record buffer
+     */
+    private ByteBuffer buffer = ByteBuffer.allocate(0);
+
+    /**
+     * record count has been read
+     */
+    private int readCount = 0;
+
+    /**
+     * streaming mode pass through data without process
+     */
+    private final boolean passThrough;
+
+    public RecordBatchInputStream(RecordBatch recordBatch, boolean 
passThrough) {
+        this.recordBatch = recordBatch;
+        this.passThrough = passThrough;
+    }
+
+    @Override
+    public int read() throws IOException {
+        try {
+            if (buffer.remaining() == 0 && endOfBatch()) {
+                return -1; // End of stream
+            }
+        } catch (DorisException e) {
+            throw new IOException(e);
+        }
+        return buffer.get() & 0xFF;
+    }
+
+    @Override
+    public int read(byte[] b, int off, int len) throws IOException {
+        try {
+            if (buffer.remaining() == 0 && endOfBatch()) {
+                return -1; // End of stream
+            }
+        } catch (DorisException e) {
+            throw new IOException(e);
+        }
+        int bytesRead = Math.min(len, buffer.remaining());
+        buffer.get(b, off, bytesRead);
+        return bytesRead;
+    }
+
+    /**
+     * Check if the current batch read is over.
+     * If the number of reads is greater than or equal to the batch size or 
there is no next record, return false,
+     * otherwise return true.
+     *
+     * @return Whether the current batch read is over
+     * @throws DorisException
+     */
+    public boolean endOfBatch() throws DorisException {
+        Iterator<InternalRow> iterator = recordBatch.getIterator();
+        if (readCount >= recordBatch.getBatchSize() || !iterator.hasNext()) {
+            return true;
+        }
+        readNext(iterator);
+        return false;
+    }
+
+    /**
+     * read next record into buffer
+     *
+     * @param iterator row iterator
+     * @throws DorisException
+     */
+    private void readNext(Iterator<InternalRow> iterator) throws 
DorisException {
+        if (!iterator.hasNext()) {
+            throw new ShouldNeverHappenException();
+        }
+        byte[] delim = recordBatch.getDelim();
+        byte[] rowBytes = rowToByte(iterator.next());
+        if (isFirst) {
+            ensureCapacity(rowBytes.length);
+            buffer.put(rowBytes);
+            buffer.flip();
+            isFirst = false;
+        } else {
+            ensureCapacity(delim.length + rowBytes.length);
+            buffer.put(delim);
+            buffer.put(rowBytes);
+            buffer.flip();
+        }
+        readCount++;
+    }
+
+    /**
+     * Check if the buffer has enough capacity.
+     *
+     * @param need required buffer space
+     */
+    private void ensureCapacity(int need) {
+
+        int capacity = buffer.capacity();
+
+        if (need <= capacity) {
+            buffer.clear();
+            return;
+        }
+
+        // need to extend
+        int newCapacity = calculateNewCapacity(capacity, need);
+        LOG.info("expand buffer, min cap: {}, now cap: {}, new cap: {}", need, 
capacity, newCapacity);
+        buffer = ByteBuffer.allocate(newCapacity);
+
+    }
+
+    /**
+     * Calculate new capacity for buffer expansion.
+     *
+     * @param capacity current buffer capacity
+     * @param minCapacity required min buffer space
+     * @return new capacity
+     */
+    private int calculateNewCapacity(int capacity, int minCapacity) {
+        int newCapacity;
+        if (capacity == 0) {
+            newCapacity = DEFAULT_BUF_SIZE;
+            while (newCapacity < minCapacity) {
+                newCapacity = newCapacity << 1;
+            }
+        } else {
+            newCapacity = capacity << 1;
+        }
+        return newCapacity;
+    }
+
+    /**
+     * Convert Spark row data to byte array
+     *
+     * @param row row data
+     * @return byte array
+     * @throws DorisException
+     */
+    private byte[] rowToByte(InternalRow row) throws DorisException {
+
+        byte[] bytes;
+
+        if (passThrough) {
+            bytes = row.getString(0).getBytes(StandardCharsets.UTF_8);
+            return bytes;
+        }
+
+        switch (recordBatch.getFormat().toLowerCase()) {
+            case "csv":
+                bytes = DataUtil.rowToCsvBytes(row, recordBatch.getSchema(), 
recordBatch.getSep());
+                break;
+            case "json":
+                try {
+                    bytes = DataUtil.rowToJsonBytes(row, 
recordBatch.getSchema());
+                } catch (JsonProcessingException e) {
+                    throw new DorisException("parse row to json bytes failed", 
e);
+                }
+                break;
+            default:
+                throw new IllegalArgumentException("format", 
recordBatch.getFormat());
+        }
+
+        return bytes;
+
+    }
+
+
+}
diff --git 
a/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
 
b/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
index faa8ef5..3d66db5 100644
--- 
a/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
+++ 
b/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
@@ -17,19 +17,11 @@
 
 package org.apache.doris.spark.serialization;
 
-import java.io.ByteArrayInputStream;
-import java.io.IOException;
-import java.math.BigDecimal;
-import java.math.BigInteger;
-import java.nio.charset.StandardCharsets;
-import java.sql.Date;
-import java.time.LocalDate;
-import java.time.LocalDateTime;
-import java.time.format.DateTimeFormatter;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.NoSuchElementException;
+import org.apache.doris.sdk.thrift.TScanBatchResult;
+import org.apache.doris.spark.exception.DorisException;
+import org.apache.doris.spark.rest.models.Schema;
 
+import com.google.common.base.Preconditions;
 import org.apache.arrow.memory.RootAllocator;
 import org.apache.arrow.vector.BigIntVector;
 import org.apache.arrow.vector.BitVector;
@@ -47,17 +39,21 @@ import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.complex.ListVector;
 import org.apache.arrow.vector.ipc.ArrowStreamReader;
 import org.apache.arrow.vector.types.Types;
-
-import org.apache.doris.sdk.thrift.TScanBatchResult;
-import org.apache.doris.spark.exception.DorisException;
-import org.apache.doris.spark.rest.models.Schema;
-
 import org.apache.commons.lang3.ArrayUtils;
 import org.apache.spark.sql.types.Decimal;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import com.google.common.base.Preconditions;
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.math.BigDecimal;
+import java.math.BigInteger;
+import java.nio.charset.StandardCharsets;
+import java.sql.Date;
+import java.time.LocalDate;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.NoSuchElementException;
 
 /**
  * row batch data container.
@@ -128,7 +124,11 @@ public class RowBatch {
     }
 
     public boolean hasNext() {
-        return offsetInRowBatch < readRowCount;
+        if (offsetInRowBatch >= readRowCount) {
+            rowBatch.clear();
+            return false;
+        }
+        return true;
     }
 
     private void addValueToRow(int rowIndex, Object obj) {
diff --git 
a/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java 
b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java
index 5877447..aea6dde 100644
--- 
a/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java
+++ 
b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/DataUtil.java
@@ -17,35 +17,48 @@
 
 package org.apache.doris.spark.util;
 
-import scala.collection.JavaConversions;
-import scala.collection.mutable.WrappedArray;
+import org.apache.doris.spark.sql.SchemaUtils;
 
-import java.sql.Date;
-import java.sql.Timestamp;
-import java.util.Arrays;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
 
-public class DataUtil {
+import java.nio.charset.StandardCharsets;
+import java.util.HashMap;
+import java.util.Map;
 
-    public static final String NULL_VALUE = "\\N";
+public class DataUtil {
 
-    public static Object handleColumnValue(Object value) {
+    private static final ObjectMapper MAPPER = new ObjectMapper();
 
-        if (value == null) {
-            return NULL_VALUE;
-        }
+    public static final String NULL_VALUE = "\\N";
 
-        if (value instanceof Date || value instanceof Timestamp) {
-            return value.toString();
+    public static byte[] rowToCsvBytes(InternalRow row, StructType schema, 
String sep) {
+        StringBuilder builder = new StringBuilder();
+        StructField[] fields = schema.fields();
+        int n = row.numFields();
+        if (n > 0) {
+            builder.append(SchemaUtils.rowColumnValue(row, 0, 
fields[0].dataType()));
+            int i = 1;
+            while (i < n) {
+                builder.append(sep);
+                builder.append(SchemaUtils.rowColumnValue(row, i, 
fields[i].dataType()));
+                i++;
+            }
         }
+        return builder.toString().getBytes(StandardCharsets.UTF_8);
+    }
 
-        if (value instanceof WrappedArray) {
-
-            Object[] arr = JavaConversions.seqAsJavaList((WrappedArray) 
value).toArray();
-            return Arrays.toString(arr);
+    public static byte[] rowToJsonBytes(InternalRow row, StructType schema)
+            throws JsonProcessingException {
+        StructField[] fields = schema.fields();
+        Map<String, Object> rowMap = new HashMap<>(row.numFields());
+        for (int i = 0; i < fields.length; i++) {
+            rowMap.put(fields[i].name(), SchemaUtils.rowColumnValue(row, i, 
fields[i].dataType()));
         }
-
-        return value;
-
+        return MAPPER.writeValueAsBytes(rowMap);
     }
 
 }
diff --git 
a/spark-doris-connector/src/main/java/org/apache/doris/spark/util/ListUtils.java
 
b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/ListUtils.java
index d8d31b9..fbfab9a 100644
--- 
a/spark-doris-connector/src/main/java/org/apache/doris/spark/util/ListUtils.java
+++ 
b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/ListUtils.java
@@ -34,7 +34,7 @@ public class ListUtils {
     private static final ObjectMapper MAPPER = new ObjectMapper();
 
     public static List<String> getSerializedList(List<Map<Object, Object>> 
batch,
-                                                 String lineDelimiter) throws 
JsonProcessingException {
+                                           String lineDelimiter) throws 
JsonProcessingException {
         List<String> result = new ArrayList<>();
         divideAndSerialize(batch, result, lineDelimiter);
         return result;
diff --git 
a/spark-doris-connector/src/main/scala/org/apache/doris/spark/package.scala 
b/spark-doris-connector/src/main/scala/org/apache/doris/spark/package.scala
index d08bdc0..9dee515 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/package.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/package.scala
@@ -18,18 +18,18 @@
 package org.apache.doris
 
 import scala.language.implicitConversions
-
 import org.apache.doris.spark.rdd.DorisSpark
 import org.apache.spark.SparkContext
+import org.apache.spark.rdd.RDD
 
 package object spark {
-  implicit def sparkContextFunctions(sc: SparkContext) = new 
SparkContextFunctions(sc)
+  implicit def sparkContextFunctions(sc: SparkContext): SparkContextFunctions 
= new SparkContextFunctions(sc)
 
   class SparkContextFunctions(sc: SparkContext) extends Serializable {
     def dorisRDD(
         tableIdentifier: Option[String] = None,
         query: Option[String] = None,
-        cfg: Option[Map[String, String]] = None) =
+        cfg: Option[Map[String, String]] = None): RDD[AnyRef] =
       DorisSpark.dorisRDD(sc, tableIdentifier, query, cfg)
   }
 }
diff --git 
a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRow.scala
 
b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRow.scala
index 06f5ca3..ec8f887 100644
--- 
a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRow.scala
+++ 
b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRow.scala
@@ -27,7 +27,7 @@ private[spark] class ScalaDorisRow(rowOrder: Seq[String]) 
extends Row {
   /** No-arg constructor for Kryo serialization. */
   def this() = this(null)
 
-  def iterator = values.iterator
+  def iterator: Iterator[Any] = values.iterator
 
   override def length: Int = values.length
 
@@ -51,9 +51,9 @@ private[spark] class ScalaDorisRow(rowOrder: Seq[String]) 
extends Row {
 
   override def getByte(i: Int): Byte = getAs[Byte](i)
 
-  override def getString(i: Int): String = get(i).toString()
+  override def getString(i: Int): String = get(i).toString
 
   override def copy(): Row = this
 
-  override def toSeq = values.toSeq
+  override def toSeq: Seq[Any] = values
 }
diff --git 
a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
 
b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
index c8aa034..f5a6a15 100644
--- 
a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
+++ 
b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/SchemaUtils.scala
@@ -18,16 +18,23 @@
 package org.apache.doris.spark.sql
 
 import org.apache.doris.sdk.thrift.TScanColumnDesc
-
-import scala.collection.JavaConversions._
+import org.apache.doris.spark.cfg.ConfigurationOptions.{DORIS_IGNORE_TYPE, 
DORIS_READ_FIELD}
 import org.apache.doris.spark.cfg.Settings
 import org.apache.doris.spark.exception.DorisException
 import org.apache.doris.spark.rest.RestService
 import org.apache.doris.spark.rest.models.{Field, Schema}
-import org.apache.doris.spark.cfg.ConfigurationOptions.{DORIS_IGNORE_TYPE, 
DORIS_READ_FIELD}
+import org.apache.doris.spark.util.DataUtil
+import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 import org.slf4j.LoggerFactory
 
+import java.sql.Timestamp
+import java.time.{LocalDateTime, ZoneOffset}
+import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
 private[spark] object SchemaUtils {
   private val logger = 
LoggerFactory.getLogger(SchemaUtils.getClass.getSimpleName.stripSuffix("$"))
 
@@ -137,4 +144,49 @@ private[spark] object SchemaUtils {
     tscanColumnDescs.foreach(desc => schema.put(new Field(desc.getName, 
desc.getType.name, "", 0, 0, "")))
     schema
   }
+
+  /**
+   * Converts the value at `ordinal` of a Catalyst row or array into a plain value used by the
+   * CSV/JSON writers.
+   *
+   * Nested nulls inside arrays and map values are emitted as `null`.
+   *
+   * NOTE(review): the top-level value is not null-checked here. For a null column, StringType
+   * throws an NPE and primitive getters return whatever the row stores (presumably 0/false).
+   * Callers should check `row.isNullAt(ordinal)` first.
+   *
+   * @param row      row or array accessor
+   * @param ordinal  field / element index
+   * @param dataType Spark SQL type of the value
+   * @return the converted value
+   * @throws DorisException if the Spark type is not supported
+   */
+  def rowColumnValue(row: SpecializedGetters, ordinal: Int, dataType: DataType): Any = {
+
+    dataType match {
+      case NullType => DataUtil.NULL_VALUE
+      case BooleanType => row.getBoolean(ordinal)
+      case ByteType => row.getByte(ordinal)
+      case ShortType => row.getShort(ordinal)
+      case IntegerType => row.getInt(ordinal)
+      case LongType => row.getLong(ordinal)
+      case FloatType => row.getFloat(ordinal)
+      case DoubleType => row.getDouble(ordinal)
+      case StringType => row.getUTF8String(ordinal).toString
+      // TimestampType is stored as epoch microseconds. toJavaTimestamp keeps the microsecond part
+      // and floors pre-1970 values correctly. It replaces a discarded LocalDateTime computation
+      // that threw for negative values, and a millisecond-truncating `new Timestamp(micros / 1000)`.
+      case TimestampType => DateTimeUtils.toJavaTimestamp(row.getLong(ordinal)).toString
+      case DateType => DateTimeUtils.toJavaDate(row.getInt(ordinal)).toString
+      case BinaryType => row.getBinary(ordinal)
+      case dt: DecimalType => row.getDecimal(ordinal, dt.precision, dt.scale)
+      case at: ArrayType =>
+        val arrayData = row.getArray(ordinal)
+        var i = 0
+        val buffer = mutable.Buffer[Any]()
+        while (i < arrayData.numElements()) {
+          if (arrayData.isNullAt(i)) buffer += null else buffer += rowColumnValue(arrayData, i, at.elementType)
+          i += 1
+        }
+        s"[${buffer.mkString(",")}]"
+      case mt: MapType =>
+        val mapData = row.getMap(ordinal)
+        val keys = mapData.keyArray()
+        val values = mapData.valueArray()
+        var i = 0
+        val map = mutable.Map[Any, Any]()
+        while (i < keys.numElements()) {
+          // values may be null (keys may not); mirror the array branch instead of dereferencing null
+          val value = if (values.isNullAt(i)) null else rowColumnValue(values, i, mt.valueType)
+          map += rowColumnValue(keys, i, mt.keyType) -> value
+          i += 1
+        }
+        map.toMap.asJava
+      case st: StructType => row.getStruct(ordinal, st.length)
+      case _ => throw new DorisException(s"Unsupported spark type: ${dataType.typeName}")
+    }
+
+  }
+
 }
diff --git 
a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
 
b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
index e32267e..b278a38 100644
--- 
a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
+++ 
b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
@@ -39,8 +39,6 @@ class DorisWriter(settings: SparkSettings) extends 
Serializable {
 
   private val logger: Logger = LoggerFactory.getLogger(classOf[DorisWriter])
 
-  val batchSize: Int = 
settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_BATCH_SIZE,
-    ConfigurationOptions.SINK_BATCH_SIZE_DEFAULT)
   private val maxRetryTimes: Int = 
settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_MAX_RETRIES,
     ConfigurationOptions.SINK_MAX_RETRIES_DEFAULT)
   private val sinkTaskPartitionSize: Integer = 
settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TASK_PARTITION_SIZE)
@@ -55,45 +53,16 @@ class DorisWriter(settings: SparkSettings) extends 
Serializable {
   private val dorisStreamLoader: DorisStreamLoad = 
CachedDorisStreamLoadClient.getOrCreate(settings)
 
   def write(dataFrame: DataFrame): Unit = {
+    doWrite(dataFrame, dorisStreamLoader.load)
+  }
 
-    val sc = dataFrame.sqlContext.sparkContext
-    val preCommittedTxnAcc = 
sc.collectionAccumulator[Int]("preCommittedTxnAcc")
-    if (enable2PC) {
-      sc.addSparkListener(new DorisTransactionListener(preCommittedTxnAcc, 
dorisStreamLoader))
-    }
+  def writeStream(dataFrame: DataFrame): Unit = {
+    doWrite(dataFrame, dorisStreamLoader.loadStream)
+  }
 
-    var resultRdd = dataFrame.rdd
-    val dfColumns = dataFrame.columns
-    if (Objects.nonNull(sinkTaskPartitionSize)) {
-      resultRdd = if (sinkTaskUseRepartition) 
resultRdd.repartition(sinkTaskPartitionSize) else 
resultRdd.coalesce(sinkTaskPartitionSize)
-    }
-    resultRdd
-      .map(_.toSeq.map(_.asInstanceOf[AnyRef]).toList.asJava)
-      .foreachPartition(partition => {
-        partition
-          .grouped(batchSize)
-          .foreach(batch => flush(batch, dfColumns))
-      })
-
-    /**
-     * flush data to Doris and do retry when flush error
-     *
-     */
-    def flush(batch: Seq[util.List[Object]], dfColumns: Array[String]): Unit = 
{
-      Utils.retry[util.List[Integer], Exception](maxRetryTimes, 
Duration.ofMillis(batchInterValMs.toLong), logger) {
-        dorisStreamLoader.loadV2(batch.asJava, dfColumns, enable2PC)
-      } match {
-        case Success(txnIds) => if (enable2PC) 
handleLoadSuccess(txnIds.asScala, preCommittedTxnAcc)
-        case Failure(e) =>
-          if (enable2PC) handleLoadFailure(preCommittedTxnAcc)
-          throw new IOException(
-            s"Failed to load batch data on BE: 
${dorisStreamLoader.getLoadUrlStr} node and exceeded the max ${maxRetryTimes} 
retry times.", e)
-      }
-    }
+  private def doWrite(dataFrame: DataFrame, loadFunc: 
(util.Iterator[InternalRow], StructType) => Int): Unit = {
 
-  }
 
-  def writeStream(dataFrame: DataFrame): Unit = {
 
     val sc = dataFrame.sqlContext.sparkContext
     val preCommittedTxnAcc = 
sc.collectionAccumulator[Int]("preCommittedTxnAcc")
@@ -103,47 +72,31 @@ class DorisWriter(settings: SparkSettings) extends 
Serializable {
 
     var resultRdd = dataFrame.queryExecution.toRdd
     val schema = dataFrame.schema
-    val dfColumns = dataFrame.columns
     if (Objects.nonNull(sinkTaskPartitionSize)) {
       resultRdd = if (sinkTaskUseRepartition) 
resultRdd.repartition(sinkTaskPartitionSize) else 
resultRdd.coalesce(sinkTaskPartitionSize)
     }
-    resultRdd
-      .foreachPartition(partition => {
-        partition
-          .grouped(batchSize)
-          .foreach(batch =>
-            flush(batch, dfColumns))
-      })
-
-    /**
-     * flush data to Doris and do retry when flush error
-     *
-     */
-    def flush(batch: Seq[InternalRow], dfColumns: Array[String]): Unit = {
-      Utils.retry[util.List[Integer], Exception](maxRetryTimes, 
Duration.ofMillis(batchInterValMs.toLong), logger) {
-        dorisStreamLoader.loadStream(convertToObjectList(batch, schema), 
dfColumns, enable2PC)
-      } match {
-        case Success(txnIds) => if (enable2PC) 
handleLoadSuccess(txnIds.asScala, preCommittedTxnAcc)
-        case Failure(e) =>
-          if (enable2PC) handleLoadFailure(preCommittedTxnAcc)
-          throw new IOException(
-            s"Failed to load batch data on BE: 
${dorisStreamLoader.getLoadUrlStr} node and exceeded the max ${maxRetryTimes} 
retry times.", e)
+    resultRdd.foreachPartition(iterator => {
+      while (iterator.hasNext) {
+        // do load batch with retries
+        Utils.retry[Int, Exception](maxRetryTimes, 
Duration.ofMillis(batchInterValMs.toLong), logger) {
+          loadFunc(iterator.asJava, schema)
+        } match {
+          case Success(txnId) => if (enable2PC) handleLoadSuccess(txnId, 
preCommittedTxnAcc)
+          case Failure(e) =>
+            if (enable2PC) handleLoadFailure(preCommittedTxnAcc)
+            throw new IOException(
+              s"Failed to load batch data on BE: 
${dorisStreamLoader.getLoadUrlStr} node and exceeded the max ${maxRetryTimes} 
retry times.", e)
+        }
       }
-    }
-
-    def convertToObjectList(rows: Seq[InternalRow], schema: StructType): 
util.List[util.List[Object]] = {
-      rows.map(row => {
-        row.toSeq(schema).map(_.asInstanceOf[AnyRef]).toList.asJava
-      }).asJava
-    }
+    })
 
   }
 
-  private def handleLoadSuccess(txnIds: mutable.Buffer[Integer], acc: 
CollectionAccumulator[Int]): Unit = {
-    txnIds.foreach(txnId => acc.add(txnId))
+  private def handleLoadSuccess(txnId: Int, acc: CollectionAccumulator[Int]): 
Unit = {
+    acc.add(txnId)
   }
 
-  def handleLoadFailure(acc: CollectionAccumulator[Int]): Unit = {
+  private def handleLoadFailure(acc: CollectionAccumulator[Int]): Unit = {
     // if task run failed, acc value will not be returned to driver,
     // should abort all pre committed transactions inside the task
     logger.info("load task failed, start aborting previously pre-committed 
transactions")
diff --git 
a/spark-doris-connector/src/test/java/org/apache/doris/spark/util/DataUtilTest.java
 
b/spark-doris-connector/src/test/java/org/apache/doris/spark/util/DataUtilTest.java
deleted file mode 100644
index 020a241..0000000
--- 
a/spark-doris-connector/src/test/java/org/apache/doris/spark/util/DataUtilTest.java
+++ /dev/null
@@ -1,32 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.doris.spark.util;
-
-import junit.framework.TestCase;
-import org.junit.Assert;
-import scala.collection.mutable.WrappedArray;
-
-import java.sql.Timestamp;
-
-public class DataUtilTest extends TestCase {
-
-    public void testHandleColumnValue() {
-        Assert.assertEquals("2023-08-14 18:00:00.0", 
DataUtil.handleColumnValue(Timestamp.valueOf("2023-08-14 18:00:00")));
-        Assert.assertEquals("[1, 2, 3]", 
DataUtil.handleColumnValue(WrappedArray.make(new Integer[]{1,2,3})));
-    }
-}
\ No newline at end of file
diff --git 
a/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/SchemaUtilsTest.scala
 
b/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/SchemaUtilsTest.scala
new file mode 100644
index 0000000..e3868cb
--- /dev/null
+++ 
b/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/SchemaUtilsTest.scala
@@ -0,0 +1,54 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.spark.sql
+
+import org.apache.spark.sql.SparkSession
+import org.junit.{Assert, Ignore, Test}
+
+import java.sql.{Date, Timestamp}
+import scala.collection.JavaConverters._
+
+@Ignore
+class SchemaUtilsTest {
+
+  /**
+   * Checks the conversion of int, date, timestamp, array and map columns by rowColumnValue.
+   * Rows are copied and collected so that assertions run on the driver. The local SparkSession
+   * is always stopped so it does not leak into other tests in the same JVM.
+   */
+  @Test
+  def rowColumnValueTest(): Unit = {
+
+    val spark = SparkSession.builder().master("local").getOrCreate()
+    try {
+      val df = spark.createDataFrame(Seq(
+        (1, Date.valueOf("2023-09-08"), Timestamp.valueOf("2023-09-08 17:00:00"), Array(1, 2, 3), Map[String, String]("a" -> "1"))
+      )).toDF("c1", "c2", "c3", "c4", "c5")
+
+      val fields = df.schema.fields
+      // copy each row: the internal row iterator may reuse one mutable row instance
+      val rows = df.queryExecution.toRdd.map(_.copy()).collect()
+      Assert.assertFalse("expected at least one row", rows.isEmpty)
+
+      rows.foreach { row =>
+        Assert.assertEquals(1, SchemaUtils.rowColumnValue(row, 0, fields(0).dataType))
+        Assert.assertEquals("2023-09-08", SchemaUtils.rowColumnValue(row, 1, fields(1).dataType))
+        Assert.assertEquals("2023-09-08 17:00:00.0", SchemaUtils.rowColumnValue(row, 2, fields(2).dataType))
+        Assert.assertEquals("[1,2,3]", SchemaUtils.rowColumnValue(row, 3, fields(3).dataType))
+        Assert.assertEquals(Map("a" -> "1").asJava, SchemaUtils.rowColumnValue(row, 4, fields(4).dataType))
+      }
+    } finally {
+      spark.stop()
+    }
+
+  }
+
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to