This is an automated email from the ASF dual-hosted git repository.

diwu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris-spark-connector.git


The following commit(s) were added to refs/heads/master by this push:
     new d157a44  [refactor](load) Refactor the loader to expand the way data is written (#187)
d157a44 is described below

commit d157a44968b5f23e7a2533ecda1d447b50ec3907
Author: gnehil <[email protected]>
AuthorDate: Thu Feb 1 10:56:24 2024 +0800

    [refactor](load) Refactor the loader to expand the way data is written (#187)
---
 spark-doris-connector/pom.xml                      |   2 +-
 .../apache/doris/spark/load/DorisStreamLoad.java   |   1 +
 .../doris/spark/rest/models/RespContent.java       |  11 +
 .../scala/org/apache/doris/spark/load/Loader.scala |  58 +++
 .../org/apache/doris/spark/load/StreamLoader.scala | 536 +++++++++++++++++++++
 .../doris/spark/sql/DorisSourceProvider.scala      |   5 +-
 .../doris/spark/sql/DorisStreamLoadSink.scala      |  12 +-
 .../doris/spark/txn/TransactionHandler.scala       |  56 +--
 .../txn/listener/DorisTransactionListener.scala    |  20 +-
 .../listener/DorisTxnStreamingQueryListener.scala  |  25 +-
 .../apache/doris/spark/writer/DorisWriter.scala    |  63 +--
 11 files changed, 697 insertions(+), 92 deletions(-)

diff --git a/spark-doris-connector/pom.xml b/spark-doris-connector/pom.xml
index ee39932..bc7dc4a 100644
--- a/spark-doris-connector/pom.xml
+++ b/spark-doris-connector/pom.xml
@@ -67,7 +67,7 @@
     </mailingLists>
 
     <properties>
-        <revision>1.3.0-SNAPSHOT</revision>
+        <revision>1.4.0-SNAPSHOT</revision>
         <spark.version>3.1.2</spark.version>
         <spark.major.version>3.1</spark.major.version>
         <scala.version>2.12</scala.version>
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
index ac74063..6b1708d 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
@@ -78,6 +78,7 @@ import java.util.zip.GZIPOutputStream;
 /**
  * DorisStreamLoad
  **/
+@Deprecated
 public class DorisStreamLoad implements Serializable {
 
     private static final Logger LOG = LoggerFactory.getLogger(DorisStreamLoad.class);
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/rest/models/RespContent.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/rest/models/RespContent.java
index a52ea13..60e3949 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/rest/models/RespContent.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/rest/models/RespContent.java
@@ -21,9 +21,15 @@ import com.fasterxml.jackson.annotation.JsonProperty;
 import com.fasterxml.jackson.core.JsonProcessingException;
 import com.fasterxml.jackson.databind.ObjectMapper;
 
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
 @JsonIgnoreProperties(ignoreUnknown = true)
 public class RespContent {
 
+    private final static List<String> DORIS_SUCCESS_STATUS = new ArrayList<>(Arrays.asList("Success", "Publish Timeout"));
+
     @JsonProperty(value = "TxnId")
     private long TxnId;
 
@@ -97,4 +103,9 @@ public class RespContent {
         }
 
     }
+
+    public boolean isSuccess() {
+        return DORIS_SUCCESS_STATUS.contains(getStatus());
+    }
+
 }
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/load/Loader.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/load/Loader.scala
new file mode 100644
index 0000000..e8489d0
--- /dev/null
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/load/Loader.scala
@@ -0,0 +1,58 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.spark.load
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Loader, the interface for writing data to Doris
+ */
+trait Loader extends Serializable {
+
+  /**
+   * execute load
+   *
+   * @param iterator row data iterator
+   * @param schema row data schema
+   * @return commit message
+   */
+  def load(iterator: Iterator[InternalRow], schema: StructType): Option[CommitMessage]
+
+  /**
+   * commit transaction
+   *
+   * @param msg commit message
+   */
+  def commit(msg: CommitMessage): Unit
+
+  /**
+   * abort transaction
+   *
+   * @param msg commit message
+   */
+  def abort(msg: CommitMessage): Unit
+
+}
+
+/**
+ * Commit message class
+ *
+ * @param value message value
+ */
+case class CommitMessage(value: Any) extends Serializable
\ No newline at end of file
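
The `Loader` trait above is the extension point this refactor introduces: a write path only has to turn a row iterator into a `CommitMessage`, plus know how to commit or abort whatever that message identifies. As a rough sketch of how an alternative loader could plug in later (the `CopyIntoLoader` name and body here are hypothetical, not part of this commit; it assumes the `Loader` and `CommitMessage` types from this package):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.types.StructType

    // Hypothetical alternative write path implementing the new trait.
    class CopyIntoLoader extends Loader {

      // Write the rows, pre-commit, and wrap the transaction handle in a CommitMessage.
      override def load(iterator: Iterator[InternalRow], schema: StructType): Option[CommitMessage] = {
        val txnHandle: Long = ??? // e.g. begin transaction, stream rows, pre-commit
        Some(CommitMessage(txnHandle))
      }

      // Finalize the transaction identified by msg.value.
      override def commit(msg: CommitMessage): Unit = ???

      // Roll back the transaction identified by msg.value.
      override def abort(msg: CommitMessage): Unit = ???
    }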
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/load/StreamLoader.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/load/StreamLoader.scala
new file mode 100644
index 0000000..8cb4942
--- /dev/null
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/load/StreamLoader.scala
@@ -0,0 +1,536 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.spark.load
+
+import com.fasterxml.jackson.core.`type`.TypeReference
+import com.fasterxml.jackson.databind.ObjectMapper
+import com.fasterxml.jackson.databind.json.JsonMapper
+import org.apache.commons.lang3.StringUtils
+import org.apache.commons.lang3.exception.ExceptionUtils
+import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings}
+import org.apache.doris.spark.exception.{IllegalArgumentException, StreamLoadException}
+import org.apache.doris.spark.rest.RestService
+import org.apache.doris.spark.rest.models.BackendV2.BackendRowV2
+import org.apache.doris.spark.rest.models.RespContent
+import org.apache.doris.spark.sql.Utils
+import org.apache.doris.spark.util.ResponseUtil
+import org.apache.http.client.methods.{CloseableHttpResponse, HttpPut, HttpRequestBase, HttpUriRequest}
+import org.apache.http.entity.{BufferedHttpEntity, ByteArrayEntity, InputStreamEntity}
+import org.apache.http.impl.client.{CloseableHttpClient, DefaultRedirectStrategy, HttpClients}
+import org.apache.http.util.EntityUtils
+import org.apache.http.{HttpEntity, HttpHeaders, HttpStatus}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types.StructType
+import org.slf4j.{Logger, LoggerFactory}
+
+import java.io.{ByteArrayOutputStream, IOException}
+import java.net.{HttpURLConnection, URL}
+import java.nio.charset.StandardCharsets
+import java.util
+import java.util.concurrent.ExecutionException
+import java.util.zip.GZIPOutputStream
+import java.util.{Base64, Calendar, Collections, UUID}
+import scala.collection.JavaConverters._
+import scala.util.{Failure, Success, Try}
+
+case class StreamLoadResponse(code: Int, msg: String, content: String)
+
+class StreamLoader(settings: SparkSettings, isStreaming: Boolean) extends Loader {
+
+  private final val LOG: Logger = LoggerFactory.getLogger(classOf[StreamLoader])
+
+  private final val MAPPER: ObjectMapper = JsonMapper.builder().build()
+
+  private val LOAD_URL_PATTERN = "http://%s/api/%s/%s/_stream_load"
+
+  private val LOAD_2PC_URL_PATTERN = "http://%s/api/%s/_stream_load_2pc"
+
+  private val database: String = settings.getProperty(ConfigurationOptions.DORIS_TABLE_IDENTIFIER).split("\\.")(0)
+
+  private val table: String = settings.getProperty(ConfigurationOptions.DORIS_TABLE_IDENTIFIER).split("\\.")(1)
+
+  private val authEncoded: String = getAuthEncoded
+
+  private val enableTwoPhaseCommit = settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_ENABLE_2PC,
+    ConfigurationOptions.DORIS_SINK_ENABLE_2PC_DEFAULT)
+
+  private val streamLoadProps: Map[String, String] = getStreamLoadProps
+
+  private val format: DataFormat = DataFormat.valueOf(streamLoadProps.getOrElse("format", "csv").toUpperCase)
+
+  private var currentLoadUrl: String = _
+
+  /**
+   * execute stream load
+   *
+   * @param iterator row data iterator
+   * @param schema row schema
+   * @throws stream load exception
+   * @return transaction id
+   */
+  @throws[StreamLoadException]
+  override def load(iterator: Iterator[InternalRow], schema: StructType): Option[CommitMessage] = {
+
+    var msg: Option[CommitMessage] = None
+
+    val client: CloseableHttpClient = getHttpClient
+    val label: String = generateLoadLabel()
+
+    Try {
+      val request = buildLoadRequest(iterator, schema, label)
+      val response = client.execute(request)
+      val txnId = handleStreamLoadResponse(response)
+      msg = Some(CommitMessage(txnId))
+    } match {
+      case Success(_) => client.close()
+      case Failure(e) =>
+        LOG.error(s"stream load failed, err: 
${ExceptionUtils.getStackTrace(e)}")
+        if (enableTwoPhaseCommit) abortByLabel(label)
+        if (e.isInstanceOf[StreamLoadException]) throw e
+        throw new StreamLoadException(s"failed to load data on 
$currentLoadUrl", e)
+    }
+
+    msg
+
+  }
+
+  /**
+   * commit transaction
+   *
+   * @param msg commit message with transaction id
+   * @throws stream load exception
+   */
+  @throws[StreamLoadException]
+  override def commit(msg: CommitMessage): Unit = {
+
+    val client = getHttpClient
+
+    Try {
+
+      val address = getAddress
+      val abortUrl = String.format(LOAD_2PC_URL_PATTERN, address, database)
+      val httpPut = new HttpPut(abortUrl)
+      addCommonHeader(httpPut)
+      httpPut.setHeader("txn_operation", "commit")
+      httpPut.setHeader("txn_id", String.valueOf(msg.value))
+
+      val response = client.execute(httpPut)
+      var statusCode = response.getStatusLine.getStatusCode
+      if (statusCode != 200 || response.getEntity == null) {
+        LOG.warn("commit transaction response: " + 
response.getStatusLine.toString)
+        throw new StreamLoadException("Fail to commit transaction " + 
msg.value + " with url " + abortUrl)
+      }
+
+      statusCode = response.getStatusLine.getStatusCode
+      val reasonPhrase = response.getStatusLine.getReasonPhrase
+      if (statusCode != 200) {
+        LOG.warn(s"commit failed with $address, reason $reasonPhrase")
+        throw new StreamLoadException("stream load error: " + reasonPhrase)
+      }
+
+      if (response.getEntity != null) {
+        val loadResult = EntityUtils.toString(response.getEntity)
+        val res = MAPPER.readValue(loadResult, new TypeReference[util.HashMap[String, String]]() {})
+        if (res.get("status") == "Fail" && !ResponseUtil.isCommitted(res.get("msg"))) throw new StreamLoadException("Commit failed " + loadResult)
+        else LOG.info("load result {}", loadResult)
+      }
+
+    } match {
+      case Success(_) => client.close()
+      case Failure(e) =>
+        client.close()
+        LOG.error("commit transaction failed, {}", 
ExceptionUtils.getStackTrace(e))
+        if (e.isInstanceOf[StreamLoadException]) throw e
+        throw new StreamLoadException(e)
+    }
+
+  }
+
+  /**
+   * abort transaction
+   *
+   * @param msg commit message with transaction id
+   * @throws stream load exception
+   */
+  override def abort(msg: CommitMessage): Unit = {
+    doAbort(_.setHeader("txn_id", String.valueOf(msg.value)))
+  }
+
+  private def abortByLabel(label: String): Unit = {
+    Try {
+      doAbort(_.setHeader("label", label))
+    } match {
+      case Success(_) => // do nothing
+      case Failure(e) =>
+        LOG.warn(s"abort by label failed, label: $label, err: 
${ExceptionUtils.getStackTrace(e)}")
+    }
+  }
+
+  /**
+   * get stream load properties from settings
+   *
+   * @return map data of stream load properties
+   */
+  private def getStreamLoadProps: Map[String, String] = {
+    val props = settings.asProperties().asScala.filter(_._1.startsWith(ConfigurationOptions.STREAM_LOAD_PROP_PREFIX))
+      .map { case (k,v) => (k.substring(ConfigurationOptions.STREAM_LOAD_PROP_PREFIX.length), v)}
+    if (props.getOrElse("add_double_quotes", "false").toBoolean) {
+      LOG.info("set add_double_quotes for csv mode, add trim_double_quotes to 
true for prop.")
+      props.put("trim_double_quotes", "true")
+    }
+    if ("json".equalsIgnoreCase(props.getOrElse("format", "csv"))) {
+      props += "read_json_by_line" -> "true"
+      props.remove("strip_outer_array")
+    }
+    props.remove("columns")
+    props.toMap
+  }
+
+  /**
+   * get http client
+   *
+   * @return http client
+   */
+  private def getHttpClient: CloseableHttpClient = {
+    HttpClients.custom().setRedirectStrategy(new DefaultRedirectStrategy() {
+      override def isRedirectable(method: String): Boolean = true
+    }).build()
+  }
+
+  /**
+   * add some common header for doris http request
+   *
+   * @param req http request
+   */
+  private def addCommonHeader(req: HttpRequestBase): Unit = {
+    req.setHeader(HttpHeaders.AUTHORIZATION, "Basic " + authEncoded)
+    req.setHeader(HttpHeaders.EXPECT, "100-continue")
+    req.setHeader(HttpHeaders.CONTENT_TYPE, "text/plain; charset=UTF-8")
+  }
+
+  /**
+   * build load request, set params as request header
+   *
+   * @param iterator row data iterator
+   * @param schema row data schema
+   * @param label load label
+   * @return http request
+   */
+  private def buildLoadRequest(iterator: Iterator[InternalRow], schema: StructType, label: String): HttpUriRequest = {
+
+    currentLoadUrl = String.format(LOAD_URL_PATTERN, getAddress, database, table)
+    val put = new HttpPut(currentLoadUrl)
+    addCommonHeader(put)
+
+    put.setHeader("label", label)
+
+    val columns = settings.getProperty(ConfigurationOptions.DORIS_WRITE_FIELDS)
+    if (StringUtils.isNotBlank(columns)) {
+      put.setHeader("columns", columns)
+    } else if (schema != null && schema.nonEmpty) {
+      put.setHeader("columns", 
schema.fieldNames.map(Utils.quote).mkString(","))
+    }
+
+    val maxFilterRatio = settings.getProperty(ConfigurationOptions.DORIS_MAX_FILTER_RATIO)
+    if (StringUtils.isNotBlank(maxFilterRatio)) put.setHeader("max_filter_ratio", maxFilterRatio)
+
+    if (enableTwoPhaseCommit) put.setHeader("two_phase_commit", "true")
+
+    if (streamLoadProps != null && streamLoadProps.nonEmpty) {
+      streamLoadProps.foreach(prop => put.setHeader(prop._1, prop._2))
+    }
+
+    put.setEntity(generateHttpEntity(iterator, schema))
+
+    put
+
+  }
+
+  /**
+   * get load address
+   * if auto redirect is enabled, return the fe address;
+   * otherwise, request the fe for alive be nodes and return one at random.
+   *
+   * when loading to a be directly, the node's availability is checked before returning.
+   *
+   * @throws [[org.apache.doris.spark.exception.StreamLoadException]]
+   * @return address
+   */
+  @throws[StreamLoadException]
+  private def getAddress: String = {
+
+    var address: Option[String] = None
+
+    Try {
+
+      val autoRedirect: Boolean = settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_AUTO_REDIRECT,
+        ConfigurationOptions.DORIS_SINK_AUTO_REDIRECT_DEFAULT)
+      if (autoRedirect) {
+        val feNodes = settings.getProperty(ConfigurationOptions.DORIS_FENODES)
+        address = Some(RestService.randomEndpoint(feNodes, LOG))
+      } else {
+        val backends = RestService.getBackendRows(settings, LOG)
+        val iter = backends.iterator()
+        while (iter.hasNext) {
+          if (!checkAvailable(iter.next())) {
+            iter.remove()
+          }
+        }
+        if (backends.isEmpty) throw new StreamLoadException("no backend alive")
+        Collections.shuffle(backends)
+        val backend = backends.get(0)
+        address = Some(backend.getIp + ":" + backend.getHttpPort)
+      }
+
+    } match {
+      case Success(_) => // do nothing
+      case Failure(e: ExecutionException) => throw new StreamLoadException("get backends info fail", e)
+      case Failure(e: IllegalArgumentException) => throw new StreamLoadException("get frontend info fail", e)
+    }
+
+    address.get
+
+  }
+
+  /**
+   * check be is alive or not
+   *
+   * @param backend backend
+   * @return is alive or not
+   */
+  private def checkAvailable(backend: BackendRowV2): Boolean = {
+    Try {
+      val url = new URL(s"http://${backend.getIp}:${backend.getHttpPort}")
+      val connection = url.openConnection.asInstanceOf[HttpURLConnection]
+      connection.setConnectTimeout(60 * 1000)
+      connection.connect()
+      connection.disconnect()
+    } match {
+      case Success(_) => true
+      case Failure(e) =>
+        LOG.warn(s"Failed to connect to backend: 
${backend.getIp}:${backend.getHttpPort}", e)
+        false
+    }
+
+  }
+
+  /**
+   * authorization info after base 64 encoded
+   *
+   * @return auth info
+   */
+  private def getAuthEncoded: String = {
+    val user = settings.getProperty(ConfigurationOptions.DORIS_REQUEST_AUTH_USER)
+    val passwd = settings.getProperty(ConfigurationOptions.DORIS_REQUEST_AUTH_PASSWORD)
+    Base64.getEncoder.encodeToString(s"$user:$passwd".getBytes(StandardCharsets.UTF_8))
+  }
+
+  /**
+   * generate load label
+   *
+   * spark_streamload_YYYYMMDD_HHMMSS_{UUID}
+   *
+   * @return load label
+   */
+  private def generateLoadLabel(): String = {
+    val calendar = Calendar.getInstance
+    "spark_streamload_" +
+      f"${calendar.get(Calendar.YEAR)}${calendar.get(Calendar.MONTH) + 
1}%02d${calendar.get(Calendar.DAY_OF_MONTH)}%02d" +
+      
f"_${calendar.get(Calendar.HOUR_OF_DAY)}%02d${calendar.get(Calendar.MINUTE)}%02d${calendar.get(Calendar.SECOND)}%02d"
 +
+      f"_${UUID.randomUUID.toString.replaceAll("-", "")}"
+  }
+
+  private def generateHttpEntity(iterator: Iterator[InternalRow], schema: StructType): HttpEntity = {
+
+    var entity: Option[HttpEntity] = None
+
+    val compressType = streamLoadProps.get("compress_type")
+    val columnSeparator = escapeString(streamLoadProps.getOrElse("column_separator", "\t"))
+    val lineDelimiter = escapeString(streamLoadProps.getOrElse("line_delimiter", "\n"))
+    val addDoubleQuotes = streamLoadProps.getOrElse("add_double_quotes", "false").toBoolean
+    val streamingPassthrough: Boolean = isStreaming && settings.getBooleanProperty(
+      ConfigurationOptions.DORIS_SINK_STREAMING_PASSTHROUGH,
+      ConfigurationOptions.DORIS_SINK_STREAMING_PASSTHROUGH_DEFAULT)
+
+    if (compressType.nonEmpty) {
+      if ("gz".equalsIgnoreCase(compressType.get) && format == DataFormat.CSV) 
{
+        val recordBatchString = new RecordBatchString(RecordBatch.newBuilder(iterator.asJava)
+          .format(format)
+          .sep(columnSeparator)
+          .delim(lineDelimiter)
+          .schema(schema)
+          .addDoubleQuotes(addDoubleQuotes).build, streamingPassthrough)
+        val content = recordBatchString.getContent
+        val compressedData = compressByGZ(content)
+        entity = Some(new ByteArrayEntity(compressedData))
+      }
+      else {
+        val msg = s"Not support the compress type [$compressType] for the 
format [$format]"
+        throw new StreamLoadException(msg)
+      }
+    }
+    else {
+      val recordBatchInputStream = new RecordBatchInputStream(RecordBatch.newBuilder(iterator.asJava)
+        .format(format)
+        .sep(columnSeparator)
+        .delim(lineDelimiter)
+        .schema(schema)
+        .addDoubleQuotes(addDoubleQuotes).build, streamingPassthrough)
+      entity = Some(new InputStreamEntity(recordBatchInputStream))
+    }
+
+    entity.get
+
+  }
+
+  /**
+   * Escape special characters
+   *
+   * @param hexData origin string
+   * @return escaped string
+   */
+  private def escapeString(hexData: String): String = {
+    if (hexData.startsWith("\\x") || hexData.startsWith("\\X")) {
+      try {
+        val tmp = hexData.substring(2)
+        val stringBuilder = new StringBuilder
+        var i = 0
+        while (i < tmp.length) {
+          val hexByte = tmp.substring(i, i + 2)
+          val decimal = Integer.parseInt(hexByte, 16)
+          val character = decimal.toChar
+          stringBuilder.append(character)
+          i += 2
+        }
+        return stringBuilder.toString
+      } catch {
+        case e: Exception =>
+          throw new RuntimeException("escape column_separator or line_delimiter error.", e)
+      }
+    }
+    hexData
+  }
+
+  /**
+   * compress data by gzip
+   *
+   * @param content data content
+   * @throws
+   * @return compressed byte array data
+   */
+  @throws[IOException]
+  def compressByGZ(content: String): Array[Byte] = {
+    var compressedData: Array[Byte] = null
+    try {
+      val baos = new ByteArrayOutputStream
+      val gzipOutputStream = new GZIPOutputStream(baos)
+      try {
+        gzipOutputStream.write(content.getBytes("UTF-8"))
+        gzipOutputStream.finish()
+        compressedData = baos.toByteArray
+      } finally {
+        if (baos != null) baos.close()
+        if (gzipOutputStream != null) gzipOutputStream.close()
+      }
+    }
+    compressedData
+  }
+
+  /**
+   * handle stream load response
+   *
+   * @param response http response
+   * @throws
+   * @return transaction id
+   */
+  @throws[StreamLoadException]
+  private def handleStreamLoadResponse(response: CloseableHttpResponse): Long = {
+
+    val code = response.getStatusLine.getStatusCode
+    val msg = response.getStatusLine.getReasonPhrase
+    val content = EntityUtils.toString(new BufferedHttpEntity(response.getEntity), StandardCharsets.UTF_8)
+    val loadResponse: StreamLoadResponse = StreamLoadResponse(code, msg, content)
+
+    if (loadResponse.code != HttpStatus.SC_OK) {
+      LOG.error(s"Stream load http status is not OK, status: 
${loadResponse.code}, response: $loadResponse")
+      throw new StreamLoadException(String.format("stream load error, http 
status:%d, response:%s",
+        new Integer(loadResponse.code), loadResponse))
+    } else {
+      try {
+        val respContent = MAPPER.readValue(loadResponse.content, classOf[RespContent])
+        if (!respContent.isSuccess) {
+          LOG.error(s"Stream load status is not success, 
status:${respContent.getStatus}, response:$loadResponse")
+          throw new StreamLoadException(String.format("stream load error, load 
status:%s, response:%s", respContent.getStatus, loadResponse))
+        }
+        LOG.info("Stream load Response:{}", loadResponse)
+        respContent.getTxnId
+      } catch {
+        case e: IOException =>
+          throw new StreamLoadException(e)
+      }
+    }
+
+  }
+
+  /**
+   * execute abort
+   *
+   * @param f function to set header
+   * @throws
+   */
+  @throws[StreamLoadException]
+  private def doAbort(f: HttpRequestBase => Unit): Unit = {
+
+    val client = getHttpClient
+
+    Try {
+
+      val abortUrl = String.format(LOAD_2PC_URL_PATTERN, getAddress, database)
+      val httpPut = new HttpPut(abortUrl)
+      addCommonHeader(httpPut)
+      httpPut.setHeader("txn_operation", "abort")
+      f(httpPut)
+
+      val response = client.execute(httpPut)
+      val statusCode = response.getStatusLine.getStatusCode
+      if (statusCode != HttpStatus.SC_OK || response.getEntity == null) {
+        LOG.error("abort transaction response: " + 
response.getStatusLine.toString)
+        throw new StreamLoadException("Fail to abort transaction with url " + 
abortUrl)
+      }
+
+      val loadResult = EntityUtils.toString(response.getEntity)
+      val res = MAPPER.readValue(loadResult, new TypeReference[util.HashMap[String, String]]() {})
+      if (!"Success".equalsIgnoreCase(res.get("status"))) {
+        if (ResponseUtil.isCommitted(res.get("msg"))) throw new IOException("try abort committed transaction")
+        LOG.error(s"Fail to abort transaction. error: ${res.get("msg")}")
+        throw new StreamLoadException(String.format("Fail to abort transaction. error: %s", res.get("msg")))
+      }
+
+    } match {
+      case Success(_) => client.close()
+      case Failure(e) =>
+        client.close()
+        LOG.error(s"abort transaction failed, 
${ExceptionUtils.getStackTrace(e)}")
+        if (e.isInstanceOf[StreamLoadException]) throw e
+        throw new StreamLoadException(e)
+    }
+
+  }
+
+}
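
`StreamLoader` now owns the whole stream load lifecycle: `load` issues one labeled HTTP PUT per batch (optionally gzip-compressed, with the `two_phase_commit` header when 2PC is enabled) and returns the transaction id wrapped in a `CommitMessage`; `commit` and `abort` then drive the `_stream_load_2pc` endpoint. A minimal sketch of that contract, assuming `settings`, `rows` and `schema` are already in hand (the connector itself drives this from `DorisWriter`, not by hand like this):

    val loader: Loader = new StreamLoader(settings, isStreaming = false)
    val msg: Option[CommitMessage] = loader.load(rows, schema) // pre-commits when 2PC is on
    try {
      msg.foreach(loader.commit) // sends txn_operation=commit to _stream_load_2pc
    } catch {
      case e: Exception =>
        msg.foreach(loader.abort) // sends txn_operation=abort
        throw e
    }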
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala
index 995bd41..0156d37 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala
@@ -21,6 +21,7 @@ import org.apache.commons.lang3.exception.ExceptionUtils
 import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings}
 import org.apache.doris.spark.exception.DorisException
 import org.apache.doris.spark.jdbc.JdbcUtils
+import org.apache.doris.spark.load.{CommitMessage, StreamLoader}
 import org.apache.doris.spark.sql.DorisSourceProvider.SHORT_NAME
 import org.apache.doris.spark.writer.DorisWriter
 import org.apache.spark.SparkConf
@@ -68,9 +69,9 @@ private[sql] class DorisSourceProvider extends DataSourceRegister
     }
 
     // accumulator for transaction handling
-    val acc = sqlContext.sparkContext.collectionAccumulator[Long]("BatchTxnAcc")
+    val acc = sqlContext.sparkContext.collectionAccumulator[CommitMessage]("BatchTxnAcc")
     // init stream loader
-    val writer = new DorisWriter(sparkSettings, acc)
+    val writer = new DorisWriter(sparkSettings, acc, false)
     writer.write(data)
 
     new BaseRelation {
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamLoadSink.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamLoadSink.scala
index 9a80fa8..eb0ac12 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamLoadSink.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamLoadSink.scala
@@ -17,8 +17,8 @@
 
 package org.apache.doris.spark.sql
 
-import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings}
-import org.apache.doris.spark.load.{CachedDorisStreamLoadClient, DorisStreamLoad}
+import org.apache.doris.spark.cfg.SparkSettings
+import org.apache.doris.spark.load.CommitMessage
 import org.apache.doris.spark.txn.listener.DorisTxnStreamingQueryListener
 import org.apache.doris.spark.writer.DorisWriter
 import org.apache.spark.sql.execution.streaming.Sink
@@ -31,17 +31,17 @@ private[sql] class DorisStreamLoadSink(sqlContext: SQLContext, settings: SparkSe
   @volatile private var latestBatchId = -1L
 
   // accumulator for transaction handling
-  private val acc = sqlContext.sparkContext.collectionAccumulator[Long]("StreamTxnAcc")
-  private val writer = new DorisWriter(settings, acc)
+  private val acc = sqlContext.sparkContext.collectionAccumulator[CommitMessage]("StreamTxnAcc")
+  private val writer = new DorisWriter(settings, acc, true)
 
   // add listener for structured streaming
-  sqlContext.streams.addListener(new DorisTxnStreamingQueryListener(acc, settings))
+  sqlContext.streams.addListener(new DorisTxnStreamingQueryListener(acc, writer.getTransactionHandler))
 
   override def addBatch(batchId: Long, data: DataFrame): Unit = {
     if (batchId <= latestBatchId) {
       logger.info(s"Skipping already committed batch $batchId")
     } else {
-      writer.writeStream(data)
+      writer.write(data)
       latestBatchId = batchId
     }
   }
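
With this change the sink funnels every micro-batch through the same `DorisWriter.write` path as batch mode, while the listener registered above commits or aborts the accumulated `CommitMessage`s as each batch finishes. For reference, a structured streaming job targeting the sink looks roughly like this (fenodes, credentials and paths are placeholders; option keys as defined in ConfigurationOptions):

    df.writeStream
      .format("doris")
      .option("checkpointLocation", "/tmp/checkpoint")
      .option("doris.table.identifier", "db.tbl")
      .option("doris.fenodes", "127.0.0.1:8030")
      .option("user", "root")
      .option("password", "")
      .option("doris.sink.enable-2pc", "true")
      .start()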
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/txn/TransactionHandler.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/txn/TransactionHandler.scala
index deeb40b..a64e437 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/txn/TransactionHandler.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/txn/TransactionHandler.scala
@@ -17,8 +17,7 @@
 
 package org.apache.doris.spark.txn
 
-import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings}
-import org.apache.doris.spark.load.{CachedDorisStreamLoadClient, DorisStreamLoad}
+import org.apache.doris.spark.load.{CommitMessage, Loader}
 import org.apache.doris.spark.sql.Utils
 import org.apache.spark.internal.Logging
 
@@ -27,29 +26,25 @@ import scala.collection.mutable
 import scala.util.{Failure, Success}
 
 /**
- * Stream load transaction handler
+ * load transaction handler
  *
- * @param settings job settings
+ * @param loader loader
+ * @param retries max retry times
+ * @param interval retry interval ms
  */
-class TransactionHandler(settings: SparkSettings) extends Logging {
-
-  private val sinkTxnIntervalMs: Int = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TXN_INTERVAL_MS,
-    ConfigurationOptions.DORIS_SINK_TXN_INTERVAL_MS_DEFAULT)
-  private val sinkTxnRetries: Integer = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TXN_RETRIES,
-    ConfigurationOptions.DORIS_SINK_TXN_RETRIES_DEFAULT)
-  private val dorisStreamLoad: DorisStreamLoad = CachedDorisStreamLoadClient.getOrCreate(settings)
+class TransactionHandler(loader: Loader, retries: Int, interval: Int) extends Logging with Serializable {
 
   /**
    * commit transactions
    *
-   * @param txnIds transaction id list
+   * @param messages commit message list
    */
-  def commitTransactions(txnIds: List[Long]): Unit = {
-    log.debug(s"start to commit transactions, count ${txnIds.size}")
-    val (failedTxnIds, ex) = txnIds.map(commitTransaction).filter(_._1.nonEmpty)
+  def commitTransactions(messages: List[CommitMessage]): Unit = {
+    log.debug(s"start to commit transactions, count ${messages.size}")
+    val (failedTxnIds, ex) = messages.map(commitTransaction).filter(_._1.nonEmpty)
       .map(e => (e._1.get, e._2.get))
-      .aggregate((mutable.Buffer[Long](), new Exception))(
-        (z, r) => ((z._1 += r._1).asInstanceOf[mutable.Buffer[Long]], r._2), (r1, r2) => (r1._1 ++ r2._1, r2._2))
+      .aggregate((mutable.Buffer[Any](), new Exception))(
+        (z, r) => ((z._1 += r._1).asInstanceOf[mutable.Buffer[Any]], r._2), (r1, r2) => (r1._1 ++ r2._1, r2._2))
     if (failedTxnIds.nonEmpty) {
       log.error("uncommitted txn ids: {}", failedTxnIds.mkString("[", ",", 
"]"))
       throw ex
@@ -59,34 +54,34 @@ class TransactionHandler(settings: SparkSettings) extends Logging {
   /**
    * commit single transaction
    *
-   * @param txnId transaction id
+   * @param msg commit message
    * @return
    */
-  private def commitTransaction(txnId: Long): (Option[Long], Option[Exception]) = {
-    Utils.retry(sinkTxnRetries, Duration.ofMillis(sinkTxnIntervalMs), log) {
-      dorisStreamLoad.commit(txnId)
+  private def commitTransaction(msg: CommitMessage): (Option[Any], Option[Exception]) = {
+    Utils.retry(retries, Duration.ofMillis(interval), log) {
+      loader.commit(msg)
     }() match {
       case Success(_) => (None, None)
-      case Failure(e: Exception) => (Option(txnId), Option(e))
+      case Failure(e: Exception) => (Option(msg.value), Option(e))
     }
   }
 
   /**
    * abort transactions
    *
-   * @param txnIds transaction id list
+   * @param messages commit message list
    */
-  def abortTransactions(txnIds: List[Long]): Unit = {
-    log.debug(s"start to abort transactions, count ${txnIds.size}")
+  def abortTransactions(messages: List[CommitMessage]): Unit = {
+    log.debug(s"start to abort transactions, count ${messages.size}")
     var ex: Option[Exception] = None
-    val failedTxnIds = txnIds.map(txnId =>
-      Utils.retry(sinkTxnRetries, Duration.ofMillis(sinkTxnIntervalMs), log) {
-        dorisStreamLoad.abortById(txnId)
+    val failedTxnIds = messages.map(msg =>
+      Utils.retry(retries, Duration.ofMillis(interval), log) {
+        loader.abort(msg)
       }() match {
         case Success(_) => None
         case Failure(e: Exception) =>
           ex = Option(e)
-          Option(txnId)
+          Option(msg.value)
       }).filter(_.nonEmpty).map(_.get)
     if (failedTxnIds.nonEmpty) {
       log.error("not aborted txn ids: {}", failedTxnIds.mkString("[", ",", 
"]"))
@@ -96,5 +91,6 @@ class TransactionHandler(settings: SparkSettings) extends Logging {
 }
 
 object TransactionHandler {
-  def apply(settings: SparkSettings): TransactionHandler = new TransactionHandler(settings)
+  def apply(loader: Loader, retries: Int, interval: Int): TransactionHandler =
+    new TransactionHandler(loader, retries, interval)
 }
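
`TransactionHandler` is now loader-agnostic: it simply retries `loader.commit` / `loader.abort` across a list of `CommitMessage`s and aggregates whatever still failed. A small sketch of the calling convention (retry count and interval are illustrative values):

    // Retry each commit up to 3 times, 10 seconds apart; throws after logging
    // the txn ids that could not be committed.
    val handler = TransactionHandler(loader, 3, 10000)
    handler.commitTransactions(messages)
    // failure paths abort instead:
    // handler.abortTransactions(messages)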
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/txn/listener/DorisTransactionListener.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/txn/listener/DorisTransactionListener.scala
index b23dcae..1d33190 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/txn/listener/DorisTransactionListener.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/txn/listener/DorisTransactionListener.scala
@@ -17,7 +17,7 @@
 
 package org.apache.doris.spark.txn.listener
 
-import org.apache.doris.spark.cfg.SparkSettings
+import org.apache.doris.spark.load.CommitMessage
 import org.apache.doris.spark.txn.TransactionHandler
 import org.apache.spark.internal.Logging
 import org.apache.spark.scheduler._
@@ -26,39 +26,37 @@ import org.apache.spark.util.CollectionAccumulator
 import scala.collection.JavaConverters._
 import scala.collection.mutable
 
-class DorisTransactionListener(preCommittedTxnAcc: CollectionAccumulator[Long], settings: SparkSettings)
+class DorisTransactionListener(txnAcc: CollectionAccumulator[CommitMessage], txnHandler: TransactionHandler)
   extends SparkListener with Logging {
 
-  val txnHandler: TransactionHandler = TransactionHandler(settings)
-
   override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
-    val txnIds: mutable.Buffer[Long] = preCommittedTxnAcc.value.asScala
+    val messages: mutable.Buffer[CommitMessage] = txnAcc.value.asScala
     jobEnd.jobResult match {
       // if job succeed, commit all transactions
       case JobSucceeded =>
-        if (txnIds.isEmpty) {
+        if (messages.isEmpty) {
           log.debug("job run succeed, but there is no pre-committed txn ids")
           return
         }
         log.info("job run succeed, start committing transactions")
-        try txnHandler.commitTransactions(txnIds.toList)
+        try txnHandler.commitTransactions(messages.toList)
         catch {
           case e: Exception => throw e
         }
-        finally preCommittedTxnAcc.reset()
+        finally txnAcc.reset()
         log.info("commit transaction success")
       // if job failed, abort all pre committed transactions
       case _ =>
-        if (txnIds.isEmpty) {
+        if (messages.isEmpty) {
           log.debug("job run failed, but there is no pre-committed txn ids")
           return
         }
         log.info("job run failed, start aborting transactions")
-        try txnHandler.abortTransactions(txnIds.toList)
+        try txnHandler.abortTransactions(messages.toList)
         catch {
           case e: Exception => throw e
         }
-        finally preCommittedTxnAcc.reset()
+        finally txnAcc.reset()
         log.info("abort transaction success")
     }
   }
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/txn/listener/DorisTxnStreamingQueryListener.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/txn/listener/DorisTxnStreamingQueryListener.scala
index 77ac9c3..0ddd4ef 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/txn/listener/DorisTxnStreamingQueryListener.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/txn/listener/DorisTxnStreamingQueryListener.scala
@@ -18,6 +18,7 @@
 package org.apache.doris.spark.txn.listener
 
 import org.apache.doris.spark.cfg.SparkSettings
+import org.apache.doris.spark.load.CommitMessage
 import org.apache.doris.spark.txn.TransactionHandler
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.streaming.StreamingQueryListener
@@ -26,42 +27,40 @@ import org.apache.spark.util.CollectionAccumulator
 import scala.collection.JavaConverters._
 import scala.collection.mutable
 
-class DorisTxnStreamingQueryListener(preCommittedTxnAcc: CollectionAccumulator[Long], settings: SparkSettings)
+class DorisTxnStreamingQueryListener(txnAcc: CollectionAccumulator[CommitMessage], txnHandler: TransactionHandler)
   extends StreamingQueryListener with Logging {
 
-  private val txnHandler = TransactionHandler(settings)
-
   override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {}
 
   override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {
     // do commit transaction when each batch ends
-    val txnIds: mutable.Buffer[Long] = preCommittedTxnAcc.value.asScala
-    if (txnIds.isEmpty) {
-      log.warn("job run succeed, but there is no pre-committed txn ids")
+    val messages: mutable.Buffer[CommitMessage] = txnAcc.value.asScala
+    if (messages.isEmpty) {
+      log.warn("job run succeed, but there is no pre-committed txn")
       return
     }
     log.info(s"batch[${event.progress.batchId}] run succeed, start committing 
transactions")
-    try txnHandler.commitTransactions(txnIds.toList)
+    try txnHandler.commitTransactions(messages.toList)
     catch {
       case e: Exception => throw e
-    } finally preCommittedTxnAcc.reset()
+    } finally txnAcc.reset()
     log.info(s"batch[${event.progress.batchId}] commit transaction success")
   }
 
 
   override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {
-    val txnIds: mutable.Buffer[Long] = preCommittedTxnAcc.value.asScala
+    val messages: mutable.Buffer[CommitMessage] = txnAcc.value.asScala
     // if job failed, abort all pre committed transactions
     if (event.exception.nonEmpty) {
-      if (txnIds.isEmpty) {
-        log.warn("job run failed, but there is no pre-committed txn ids")
+      if (messages.isEmpty) {
+        log.warn("job run failed, but there is no pre-committed txn")
         return
       }
       log.info("job run failed, start aborting transactions")
-      try txnHandler.abortTransactions(txnIds.toList)
+      try txnHandler.abortTransactions(messages.toList)
       catch {
         case e: Exception => throw e
-      } finally preCommittedTxnAcc.reset()
+      } finally txnAcc.reset()
       log.info("abort transaction success")
     }
   }
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
index 53187f4..26491df 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
@@ -18,7 +18,7 @@
 package org.apache.doris.spark.writer
 
 import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings}
-import org.apache.doris.spark.load.{CachedDorisStreamLoadClient, DorisStreamLoad}
+import org.apache.doris.spark.load.{CommitMessage, Loader, StreamLoader}
 import org.apache.doris.spark.sql.Utils
 import org.apache.doris.spark.txn.TransactionHandler
 import org.apache.doris.spark.txn.listener.DorisTransactionListener
@@ -30,13 +30,14 @@ import org.slf4j.{Logger, LoggerFactory}
 
 import java.io.IOException
 import java.time.Duration
-import java.util
 import java.util.Objects
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
 import scala.util.{Failure, Success}
 
-class DorisWriter(settings: SparkSettings, preCommittedTxnAcc: CollectionAccumulator[Long]) extends Serializable {
+class DorisWriter(settings: SparkSettings,
+                  txnAcc: CollectionAccumulator[CommitMessage],
+                  isStreaming: Boolean) extends Serializable {
 
   private val logger: Logger = LoggerFactory.getLogger(classOf[DorisWriter])
 
@@ -58,9 +59,15 @@ class DorisWriter(settings: SparkSettings, preCommittedTxnAcc: CollectionAccumul
   private val enable2PC: Boolean = settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_ENABLE_2PC,
     ConfigurationOptions.DORIS_SINK_ENABLE_2PC_DEFAULT)
 
-  private val dorisStreamLoader: DorisStreamLoad = CachedDorisStreamLoadClient.getOrCreate(settings)
+  private val loader: Loader = generateLoader
 
-  private var isStreaming = false;
+  private val sinkTxnRetries = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TXN_RETRIES,
+    ConfigurationOptions.DORIS_SINK_TXN_RETRIES_DEFAULT)
+
+  private val sinkTxnIntervalMs = settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_TXN_INTERVAL_MS,
+    ConfigurationOptions.DORIS_SINK_TXN_INTERVAL_MS_DEFAULT)
+
+  private val txnHandler: TransactionHandler = TransactionHandler(loader, sinkTxnRetries, sinkTxnIntervalMs)
 
   /**
    * write data in batch mode
@@ -68,23 +75,13 @@ class DorisWriter(settings: SparkSettings, preCommittedTxnAcc: CollectionAccumul
    * @param dataFrame source dataframe
    */
   def write(dataFrame: DataFrame): Unit = {
-    doWrite(dataFrame, dorisStreamLoader.load)
+    doWrite(dataFrame, loader.load)
   }
 
-  /**
-   * write data in stream mode
-   *
-   * @param dataFrame source dataframe
-   */
-  def writeStream(dataFrame: DataFrame): Unit = {
-    isStreaming = true
-    doWrite(dataFrame, dorisStreamLoader.loadStream)
-  }
-
-  private def doWrite(dataFrame: DataFrame, loadFunc: (util.Iterator[InternalRow], StructType) => Long): Unit = {
+  private def doWrite(dataFrame: DataFrame, loadFunc: (Iterator[InternalRow], StructType) => Option[CommitMessage]): Unit = {
     // do not add spark listener when job is streaming mode
     if (enable2PC && !isStreaming) {
-      dataFrame.sparkSession.sparkContext.addSparkListener(new DorisTransactionListener(preCommittedTxnAcc, settings))
+      dataFrame.sparkSession.sparkContext.addSparkListener(new DorisTransactionListener(txnAcc, txnHandler))
     }
 
     var resultRdd = dataFrame.queryExecution.toRdd
@@ -96,16 +93,15 @@ class DorisWriter(settings: SparkSettings, preCommittedTxnAcc: CollectionAccumul
 
       while (iterator.hasNext) {
        val batchIterator = new BatchIterator[InternalRow](iterator, batchSize, maxRetryTimes > 0)
-        val retry = Utils.retry[Long, Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) _
-        retry(loadFunc(batchIterator.asJava, schema))(batchIterator.reset()) match {
-          case Success(txnId) =>
-            if (enable2PC) handleLoadSuccess(txnId, preCommittedTxnAcc)
+        val retry = Utils.retry[Option[CommitMessage], Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) _
+        retry(loadFunc(batchIterator, schema))(batchIterator.reset()) match {
+          case Success(msg) =>
+            if (enable2PC) handleLoadSuccess(msg, txnAcc)
             batchIterator.close()
           case Failure(e) =>
-            if (enable2PC) handleLoadFailure(preCommittedTxnAcc)
+            if (enable2PC) handleLoadFailure(txnAcc)
             batchIterator.close()
-            throw new IOException(
-              s"Failed to load batch data on BE: 
${dorisStreamLoader.getLoadUrlStr} node.", e)
+            throw e
         }
       }
 
@@ -113,11 +109,11 @@ class DorisWriter(settings: SparkSettings, preCommittedTxnAcc: CollectionAccumul
 
   }
 
-  private def handleLoadSuccess(txnId: Long, acc: CollectionAccumulator[Long]): Unit = {
-    acc.add(txnId)
+  private def handleLoadSuccess(msg: Option[CommitMessage], acc: CollectionAccumulator[CommitMessage]): Unit = {
+    acc.add(msg.get)
   }
 
-  private def handleLoadFailure(acc: CollectionAccumulator[Long]): Unit = {
+  private def handleLoadFailure(acc: CollectionAccumulator[CommitMessage]): Unit = {
     // if task run failed, acc value will not be returned to driver,
     // should abort all pre committed transactions inside the task
     logger.info("load task failed, start aborting previously pre-committed 
transactions")
@@ -126,7 +122,7 @@ class DorisWriter(settings: SparkSettings, preCommittedTxnAcc: CollectionAccumul
       return
     }
 
-    try TransactionHandler(settings).abortTransactions(acc.value.asScala.toList)
+    try txnHandler.abortTransactions(acc.value.asScala.toList)
     catch {
       case e: Exception => throw e
     }
@@ -212,5 +208,14 @@ class DorisWriter(settings: SparkSettings, preCommittedTxnAcc: CollectionAccumul
 
   }
 
+  @throws[IllegalArgumentException]
+  private def generateLoader: Loader = {
+    val loadMode = settings.getProperty("load_mode", "stream_load")
+    if ("stream_load".equalsIgnoreCase(loadMode)) new StreamLoader(settings, 
isStreaming)
+    else throw new IllegalArgumentException(s"Unsupported load mode: 
$loadMode")
+  }
+
+  def getTransactionHandler: TransactionHandler = txnHandler
+
 
 }
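
End to end, `DorisWriter` now assembles the pieces from configuration: `generateLoader` selects the `Loader` implementation (only `stream_load` is recognized so far, so any other `load_mode` setting fails fast), and the single `TransactionHandler` it builds is shared with the job and query listeners via `getTransactionHandler`. Condensed from the call sites in this diff:

    // batch path (DorisSourceProvider):
    val acc = sqlContext.sparkContext.collectionAccumulator[CommitMessage]("BatchTxnAcc")
    val writer = new DorisWriter(sparkSettings, acc, false)
    writer.write(data) // with 2PC, DorisTransactionListener commits acc's messages at job end

    // streaming path (DorisStreamLoadSink):
    val streamWriter = new DorisWriter(settings, streamAcc, true)
    sqlContext.streams.addListener(
      new DorisTxnStreamingQueryListener(streamAcc, streamWriter.getTransactionHandler))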


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
