This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 4a36c151ea1 [SPARK-41108][CONNECT] Control the max size of arrow batch 4a36c151ea1 is described below commit 4a36c151ea1cf8372ddefbe0a75f3470bfbe1587 Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Fri Nov 11 14:35:31 2022 +0900 [SPARK-41108][CONNECT] Control the max size of arrow batch ### What changes were proposed in this pull request? Control the max size of arrow batch ### Why are the changes needed? as per the suggestion https://github.com/apache/spark/pull/38468#discussion_r1018951362 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? existing tests Closes #38612 from zhengruifeng/connect_arrow_batchsize. Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../service/SparkConnectStreamHandler.scala | 6 ++---- .../sql/execution/arrow/ArrowConverters.scala | 25 ++++++++++++++++------ 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala index 3b734616b21..9652fce5425 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.execution.arrow.ArrowConverters class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) extends Logging { // The maximum batch size in bytes for a single batch of data to be returned via proto. - val MAX_BATCH_SIZE: Long = 10 * 1024 * 1024 + private val MAX_BATCH_SIZE: Long = 4 * 1024 * 1024 def handle(v: Request): Unit = { val session = @@ -127,8 +127,6 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte def processAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = { val spark = dataframe.sparkSession val schema = dataframe.schema - // TODO: control the batch size instead of max records - val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) { @@ -141,7 +139,7 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte val batches = rows.mapPartitionsInternal { iter => ArrowConverters - .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId) + .toBatchWithSchemaIterator(iter, schema, MAX_BATCH_SIZE, timeZoneId) } val signal = new Object diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index a2dce31bc6d..c233ac32c12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -33,12 +33,12 @@ import org.apache.spark.internal.Logging import org.apache.spark.network.util.JavaUtils import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} -import org.apache.spark.util.{ByteBufferOutputStream, Utils} +import org.apache.spark.util.{ByteBufferOutputStream, SizeEstimator, Utils} /** @@ -128,10 +128,14 @@ private[sql] object ArrowConverters extends Logging { } } - private[sql] def toArrowBatchIterator( + /** + * Convert the input rows into fully contained arrow batches. + * Different from [[toBatchIterator]], each output arrow batch starts with the schema. + */ + private[sql] def toBatchWithSchemaIterator( rowIter: Iterator[InternalRow], schema: StructType, - maxRecordsPerBatch: Int, + maxBatchSize: Long, timeZoneId: String): Iterator[(Array[Byte], Long)] = { val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) val allocator = ArrowUtils.rootAllocator.newChildAllocator( @@ -140,6 +144,7 @@ private[sql] object ArrowConverters extends Logging { val root = VectorSchemaRoot.create(arrowSchema, allocator) val unloader = new VectorUnloader(root) val arrowWriter = ArrowWriter.create(root) + val arrowSchemaSize = SizeEstimator.estimate(arrowSchema) Option(TaskContext.get).foreach { _.addTaskCompletionListener[Unit] { _ => @@ -161,17 +166,23 @@ private[sql] object ArrowConverters extends Logging { val writeChannel = new WriteChannel(Channels.newChannel(out)) var rowCount = 0L + var estimatedBatchSize = arrowSchemaSize Utils.tryWithSafeFinally { - while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) { + // Always write the schema. + MessageSerializer.serialize(writeChannel, arrowSchema) + + // Always write the first row. + while (rowIter.hasNext && (rowCount == 0 || estimatedBatchSize < maxBatchSize)) { val row = rowIter.next() arrowWriter.write(row) + estimatedBatchSize += row.asInstanceOf[UnsafeRow].getSizeInBytes rowCount += 1 } arrowWriter.finish() val batch = unloader.getRecordBatch() - - MessageSerializer.serialize(writeChannel, arrowSchema) MessageSerializer.serialize(writeChannel, batch) + + // Always write the Ipc options at the end. ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT) batch.close() --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org