hvanhovell commented on code in PR #38468: URL: https://github.com/apache/spark/pull/38468#discussion_r1018944392
########## connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala: ########## @@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte responseObserver.onNext(response.build()) } - responseObserver.onNext(sendMetricsToResponse(clientId, rows)) + responseObserver.onNext(sendMetricsToResponse(clientId, dataframe)) responseObserver.onCompleted() } + def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = { + val spark = dataframe.sparkSession + val schema = dataframe.schema + // TODO: control the batch size instead of max records + val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch + val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone + + SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) { + val rows = dataframe.queryExecution.executedPlan.execute() + val numPartitions = rows.getNumPartitions + var numSent = 0 + + if (numPartitions > 0) { + type Batch = (Array[Byte], Long, Long) + + val batches = rows.mapPartitionsInternal { iter => + ArrowConverters + .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId) + } + + val signal = new Object + val partitions = Array.fill[Array[Batch]](numPartitions)(null) + + val processPartition = (iter: Iterator[Batch]) => iter.toArray + + val resultHandler = (partitionId: Int, partition: Array[Batch]) => { + signal.synchronized { + partitions(partitionId) = partition + signal.notify() + } + val i = 0 // Unit + } + + spark.sparkContext.runJob(batches, processPartition, resultHandler) + + var currentPartitionId = 0 + while (currentPartitionId < numPartitions) { + val partition = signal.synchronized { + while (partitions(currentPartitionId) == null) { + signal.wait() + } + val partition = partitions(currentPartitionId) + partitions(currentPartitionId) = null + partition + } + + // only send non-empty partitions + if (partition.nonEmpty && 
partition.exists(_._1.nonEmpty)) { Review Comment: Different questions. Why not just iterate over the partitions, and filter out the non-empty batches? That should be the same and it saves you from an unneeded if. ########## connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala: ########## @@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte responseObserver.onNext(response.build()) } - responseObserver.onNext(sendMetricsToResponse(clientId, rows)) + responseObserver.onNext(sendMetricsToResponse(clientId, dataframe)) responseObserver.onCompleted() } + def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = { + val spark = dataframe.sparkSession + val schema = dataframe.schema + // TODO: control the batch size instead of max records + val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch + val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone + + SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) { + val rows = dataframe.queryExecution.executedPlan.execute() + val numPartitions = rows.getNumPartitions + var numSent = 0 + + if (numPartitions > 0) { + type Batch = (Array[Byte], Long, Long) + + val batches = rows.mapPartitionsInternal { iter => + ArrowConverters + .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId) + } + + val signal = new Object + val partitions = Array.fill[Array[Batch]](numPartitions)(null) + + val processPartition = (iter: Iterator[Batch]) => iter.toArray + + val resultHandler = (partitionId: Int, partition: Array[Batch]) => { + signal.synchronized { + partitions(partitionId) = partition + signal.notify() + } + val i = 0 // Unit + } + + spark.sparkContext.runJob(batches, processPartition, resultHandler) + + var currentPartitionId = 0 + while (currentPartitionId < numPartitions) { + val partition = signal.synchronized { + while (partitions(currentPartitionId) 
== null) { + signal.wait() + } + val partition = partitions(currentPartitionId) + partitions(currentPartitionId) = null + partition + } + + // only send non-empty partitions + if (partition.nonEmpty && partition.exists(_._1.nonEmpty)) { Review Comment: Different questions. Why not just iterate over the partitions, and filter out the non-empty batches? That should be the same and it saves you from an unneeded if. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org