This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4a36c151ea1 [SPARK-41108][CONNECT] Control the max size of arrow batch
4a36c151ea1 is described below

commit 4a36c151ea1cf8372ddefbe0a75f3470bfbe1587
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Fri Nov 11 14:35:31 2022 +0900

    [SPARK-41108][CONNECT] Control the max size of arrow batch
    
    ### What changes were proposed in this pull request?
    
    Control the max size of arrow batch
    
    ### Why are the changes needed?
    
    as per the suggestion https://github.com/apache/spark/pull/38468#discussion_r1018951362
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    existing tests
    
    Closes #38612 from zhengruifeng/connect_arrow_batchsize.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../service/SparkConnectStreamHandler.scala        |  6 ++----
 .../sql/execution/arrow/ArrowConverters.scala      | 25 ++++++++++++++++------
 2 files changed, 20 insertions(+), 11 deletions(-)

diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index 3b734616b21..9652fce5425 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.execution.arrow.ArrowConverters
 class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) extends Logging {
 
   // The maximum batch size in bytes for a single batch of data to be returned via proto.
-  val MAX_BATCH_SIZE: Long = 10 * 1024 * 1024
+  private val MAX_BATCH_SIZE: Long = 4 * 1024 * 1024
 
   def handle(v: Request): Unit = {
     val session =
@@ -127,8 +127,6 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
   def processAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
     val spark = dataframe.sparkSession
     val schema = dataframe.schema
-    // TODO: control the batch size instead of max records
-    val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
     val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
 
     SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
@@ -141,7 +139,7 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
 
         val batches = rows.mapPartitionsInternal { iter =>
           ArrowConverters
-            .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+            .toBatchWithSchemaIterator(iter, schema, MAX_BATCH_SIZE, timeZoneId)
         }
 
         val signal = new Object
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index a2dce31bc6d..c233ac32c12 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -33,12 +33,12 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
-import org.apache.spark.util.{ByteBufferOutputStream, Utils}
+import org.apache.spark.util.{ByteBufferOutputStream, SizeEstimator, Utils}
 
 
 /**
@@ -128,10 +128,14 @@ private[sql] object ArrowConverters extends Logging {
     }
   }
 
-  private[sql] def toArrowBatchIterator(
+  /**
+   * Convert the input rows into fully contained arrow batches.
+   * Different from [[toBatchIterator]], each output arrow batch starts with the schema.
+   */
+  private[sql] def toBatchWithSchemaIterator(
       rowIter: Iterator[InternalRow],
       schema: StructType,
-      maxRecordsPerBatch: Int,
+      maxBatchSize: Long,
       timeZoneId: String): Iterator[(Array[Byte], Long)] = {
     val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
     val allocator = ArrowUtils.rootAllocator.newChildAllocator(
@@ -140,6 +144,7 @@ private[sql] object ArrowConverters extends Logging {
     val root = VectorSchemaRoot.create(arrowSchema, allocator)
     val unloader = new VectorUnloader(root)
     val arrowWriter = ArrowWriter.create(root)
+    val arrowSchemaSize = SizeEstimator.estimate(arrowSchema)
 
     Option(TaskContext.get).foreach {
       _.addTaskCompletionListener[Unit] { _ =>
@@ -161,17 +166,23 @@ private[sql] object ArrowConverters extends Logging {
         val writeChannel = new WriteChannel(Channels.newChannel(out))
 
         var rowCount = 0L
+        var estimatedBatchSize = arrowSchemaSize
         Utils.tryWithSafeFinally {
-          while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) {
+          // Always write the schema.
+          MessageSerializer.serialize(writeChannel, arrowSchema)
+
+          // Always write the first row.
+          while (rowIter.hasNext && (rowCount == 0 || estimatedBatchSize < maxBatchSize)) {
             val row = rowIter.next()
             arrowWriter.write(row)
+            estimatedBatchSize += row.asInstanceOf[UnsafeRow].getSizeInBytes
             rowCount += 1
           }
           arrowWriter.finish()
           val batch = unloader.getRecordBatch()
-
-          MessageSerializer.serialize(writeChannel, arrowSchema)
           MessageSerializer.serialize(writeChannel, batch)
+
+          // Always write the Ipc options at the end.
           ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT)
 
           batch.close()

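For context, here is a minimal, self-contained Scala sketch of the batching policy this patch switches to: instead of cutting a batch after a fixed number of records, rows are accumulated until an estimated byte budget is reached, and the first row of each batch is always taken so that a single oversized row still makes progress. The SizeBoundedBatching object, its estimateSize callback, and the batchOverhead parameter are illustrative only and not part of the Spark or Spark Connect API.

import scala.collection.mutable.ArrayBuffer

// Hypothetical helper, not Spark code: groups rows into batches whose estimated
// size stays under maxBatchSize, mirroring the patch's loop condition
// `rowIter.hasNext && (rowCount == 0 || estimatedBatchSize < maxBatchSize)`.
object SizeBoundedBatching {

  def batchBySize[T](
      rows: Iterator[T],
      maxBatchSize: Long,
      estimateSize: T => Long,
      batchOverhead: Long = 0L): Iterator[Seq[T]] = new Iterator[Seq[T]] {

    override def hasNext: Boolean = rows.hasNext

    override def next(): Seq[T] = {
      val batch = ArrayBuffer.empty[T]
      // Start from the fixed per-batch overhead; in the patch this is the
      // serialized schema size, estimated via SizeEstimator.estimate(arrowSchema).
      var estimated = batchOverhead
      // The first row is always taken, so a row larger than maxBatchSize still
      // produces an (oversized) batch instead of stalling the iterator.
      while (rows.hasNext && (batch.isEmpty || estimated < maxBatchSize)) {
        val row = rows.next()
        batch += row
        estimated += estimateSize(row)
      }
      batch.toSeq
    }
  }
}

// Example: five 10-byte strings with a 16-byte budget yield batches of 2, 2 and 1.
SizeBoundedBatching
  .batchBySize(Iterator.fill(5)("x" * 10), maxBatchSize = 16L, estimateSize = (s: String) => s.length.toLong)
  .foreach(b => println(b.size))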

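Since toBatchWithSchemaIterator writes the schema, one record batch, and the end-of-stream marker into every emitted byte array, each element is a complete Arrow IPC stream that can be decoded on its own. Below is a sketch of that consumer side (not part of this patch), using Arrow's ArrowStreamReader.

import java.io.ByteArrayInputStream

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.ipc.ArrowStreamReader

// Decode one self-contained batch produced by the new iterator. Because the
// schema is embedded in `bytes`, no out-of-band schema exchange is required.
def readSelfContainedBatch(bytes: Array[Byte]): Unit = {
  val allocator = new RootAllocator(Long.MaxValue)
  val reader = new ArrowStreamReader(new ByteArrayInputStream(bytes), allocator)
  try {
    while (reader.loadNextBatch()) {
      val root = reader.getVectorSchemaRoot
      println(s"schema=${root.getSchema}, rows=${root.getRowCount}")
    }
  } finally {
    reader.close()
    allocator.close()
  }
}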