[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

2022-11-10 Thread GitBox


zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1019170138


##
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala:
##
@@ -128,6 +128,97 @@ private[sql] object ArrowConverters extends Logging {
 }
   }
 
+  private[sql] def toArrowBatchIterator(
+      rowIter: Iterator[InternalRow],
+      schema: StructType,
+      maxRecordsPerBatch: Int,
+      timeZoneId: String): Iterator[(Array[Byte], Long, Long)] = {
+    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+    val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+      "toArrowBatchIterator", 0, Long.MaxValue)
+
+    val root = VectorSchemaRoot.create(arrowSchema, allocator)
+    val unloader = new VectorUnloader(root)
+    val arrowWriter = ArrowWriter.create(root)
+
+    Option(TaskContext.get).foreach {
+      _.addTaskCompletionListener[Unit] { _ =>
+        root.close()
+        allocator.close()
+      }
+    }
+
+    new Iterator[(Array[Byte], Long, Long)] {
+
+      override def hasNext: Boolean = rowIter.hasNext || {
+        root.close()
+        allocator.close()
+        false
+      }
+
+      override def next(): (Array[Byte], Long, Long) = {
+        val out = new ByteArrayOutputStream()
+        val writeChannel = new WriteChannel(Channels.newChannel(out))
+
+        var rowCount = 0L
+        var estimatedSize = SizeEstimator.estimate(arrowSchema) +
+          SizeEstimator.estimate(IpcOption.DEFAULT)
+        Utils.tryWithSafeFinally {
+          while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) {
+            val row = rowIter.next()
+            arrowWriter.write(row)
+            rowCount += 1
+            estimatedSize += SizeEstimator.estimate(row)
+          }
+          arrowWriter.finish()
+          val batch = unloader.getRecordBatch()
+
+          MessageSerializer.serialize(writeChannel, arrowSchema)
+          MessageSerializer.serialize(writeChannel, batch)
+          ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT)
+
+          batch.close()
+        } {
+          arrowWriter.reset()
+        }
+
+        (out.toByteArray, rowCount, estimatedSize)
+      }
+    }
+  }
+
+  private[sql] def createEmptyArrowBatch(

Review Comment:
   Calling `toArrowBatchIterator` with an empty iterator returns an empty iterator.
   
   Here we need an Arrow batch that carries the schema but contains no data.
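   
   A minimal sketch of what such a helper could look like, assuming it sits next to `ArrowConverters` so the `private[sql]` `ArrowUtils`/`ArrowWriter` utilities are visible (this is only an illustration, not the PR's actual `createEmptyArrowBatch`): serialize the schema plus a zero-row record batch into a single IPC payload.
   
   ```scala
   import java.io.ByteArrayOutputStream
   import java.nio.channels.Channels
   
   import org.apache.arrow.vector.{VectorSchemaRoot, VectorUnloader}
   import org.apache.arrow.vector.ipc.{ArrowStreamWriter, WriteChannel}
   import org.apache.arrow.vector.ipc.message.{IpcOption, MessageSerializer}
   
   import org.apache.spark.sql.execution.arrow.ArrowWriter
   import org.apache.spark.sql.types.StructType
   import org.apache.spark.sql.util.ArrowUtils
   
   def emptyArrowBatch(schema: StructType, timeZoneId: String): Array[Byte] = {
     val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
     val allocator = ArrowUtils.rootAllocator.newChildAllocator("emptyArrowBatch", 0, Long.MaxValue)
     val root = VectorSchemaRoot.create(arrowSchema, allocator)
     try {
       ArrowWriter.create(root).finish() // zero rows written, row count set to 0
       val batch = new VectorUnloader(root).getRecordBatch()
       val out = new ByteArrayOutputStream()
       val channel = new WriteChannel(Channels.newChannel(out))
       MessageSerializer.serialize(channel, arrowSchema) // schema message
       MessageSerializer.serialize(channel, batch) // empty record batch
       ArrowStreamWriter.writeEndOfStream(channel, IpcOption.DEFAULT)
       batch.close()
       out.toByteArray
     } finally {
       root.close()
       allocator.close()
     }
   }
   ```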






[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

2022-11-10 Thread GitBox


zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1019118121


##
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala:
##
@@ -128,6 +128,92 @@ private[sql] object ArrowConverters extends Logging {
 }
   }
 
+  private[sql] def toArrowBatchIterator(

Review Comment:
   `toArrowBatchIterator` also writes the schema before each record batch, while `toBatchIterator` only outputs the record batch.
   
   I am going to update `toArrowBatchIterator` in a follow-up to keep each batch size under 4 MB, per the suggestions in
   https://github.com/apache/spark/pull/38468#discussion_r1018951362 and
   https://github.com/apache/spark/pull/38468#discussion_r1013186031
   
   I think we can deduplicate the code then.
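   
   A rough sketch (not part of the PR) of why that difference matters: because each payload from `toArrowBatchIterator` includes the schema message and the end-of-stream marker, it is a complete Arrow IPC stream that a client can parse on its own, for example with a plain `ArrowStreamReader`:
   
   ```scala
   import java.io.ByteArrayInputStream
   
   import org.apache.arrow.memory.RootAllocator
   import org.apache.arrow.vector.ipc.ArrowStreamReader
   
   def countRows(payload: Array[Byte]): Long = {
     val allocator = new RootAllocator(Long.MaxValue)
     val reader = new ArrowStreamReader(new ByteArrayInputStream(payload), allocator)
     try {
       var rows = 0L
       while (reader.loadNextBatch()) { // the schema message is consumed implicitly
         rows += reader.getVectorSchemaRoot.getRowCount
       }
       rows
     } finally {
       reader.close()
       allocator.close()
     }
   }
   ```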






[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

2022-11-10 Thread GitBox


zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1019100602


##
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##
@@ -114,10 +120,93 @@ class SparkConnectStreamHandler(responseObserver: 
StreamObserver[Response]) exte
   responseObserver.onNext(response.build())
 }
 
-responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
 responseObserver.onCompleted()
   }
 
+  def processAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+val spark = dataframe.sparkSession
+val schema = dataframe.schema
+// TODO: control the batch size instead of max records
+val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+SQLExecution.withNewExecutionId(dataframe.queryExecution, 
Some("collectArrow")) {
+  val rows = dataframe.queryExecution.executedPlan.execute()
+  val numPartitions = rows.getNumPartitions
+  var numSent = 0
+
+  if (numPartitions > 0) {
+type Batch = (Array[Byte], Long)
+
+val batches = rows.mapPartitionsInternal { iter =>
+  ArrowConverters
+.toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+}
+
+val signal = new Object
+val partitions = collection.mutable.Map.empty[Int, Array[Batch]]

Review Comment:
   > can we apply the same idea to JSON batches?
   
   I think so, let's optimize it later






[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

2022-11-10 Thread GitBox


zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1019099531


##
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: 
StreamObserver[Response]) exte
   responseObserver.onNext(response.build())
 }
 
-responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
 responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit 
= {
+val spark = dataframe.sparkSession
+val schema = dataframe.schema
+// TODO: control the batch size instead of max records
+val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+SQLExecution.withNewExecutionId(dataframe.queryExecution, 
Some("collectArrow")) {
+  val rows = dataframe.queryExecution.executedPlan.execute()
+  val numPartitions = rows.getNumPartitions
+  var numSent = 0
+
+  if (numPartitions > 0) {
+type Batch = (Array[Byte], Long, Long)
+
+val batches = rows.mapPartitionsInternal { iter =>
+  ArrowConverters
+.toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+}
+
+val signal = new Object
+val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+  signal.synchronized {
+partitions(partitionId) = partition
+signal.notify()
+  }
+  val i = 0 // Unit
+}
+
+spark.sparkContext.runJob(batches, processPartition, resultHandler)
+
+var currentPartitionId = 0
+while (currentPartitionId < numPartitions) {
+  val partition = signal.synchronized {
+while (partitions(currentPartitionId) == null) {
+  signal.wait()

Review Comment:
   > If the first partition arrives last, the whole dataset stays in the driver's memory, right?
   
   Yes, but at least it is not worse than the existing `collect`, which always keeps the whole dataset in memory.
   
   Receiving the partitions in order may make them easier to consume on the client, if ordering matters.
   
   I think we will optimize it further; this is just an initial implementation.






[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

2022-11-10 Thread GitBox


zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1019093054


##
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##
@@ -114,10 +120,93 @@ class SparkConnectStreamHandler(responseObserver: 
StreamObserver[Response]) exte
   responseObserver.onNext(response.build())
 }
 
-responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
 responseObserver.onCompleted()
   }
 
+  def processAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+val spark = dataframe.sparkSession
+val schema = dataframe.schema
+// TODO: control the batch size instead of max records
+val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+SQLExecution.withNewExecutionId(dataframe.queryExecution, 
Some("collectArrow")) {
+  val rows = dataframe.queryExecution.executedPlan.execute()
+  val numPartitions = rows.getNumPartitions
+  var numSent = 0
+
+  if (numPartitions > 0) {
+type Batch = (Array[Byte], Long)
+
+val batches = rows.mapPartitionsInternal { iter =>
+  ArrowConverters
+.toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+}
+
+val signal = new Object
+val partitions = collection.mutable.Map.empty[Int, Array[Batch]]

Review Comment:
   Just changed from an array to a map ... see https://github.com/apache/spark/pull/38468#discussion_r1018938395






[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

2022-11-10 Thread GitBox


zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018939295


##
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: 
StreamObserver[Response]) exte
   responseObserver.onNext(response.build())
 }
 
-responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
 responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit 
= {
+val spark = dataframe.sparkSession
+val schema = dataframe.schema
+// TODO: control the batch size instead of max records
+val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+SQLExecution.withNewExecutionId(dataframe.queryExecution, 
Some("collectArrow")) {
+  val rows = dataframe.queryExecution.executedPlan.execute()
+  val numPartitions = rows.getNumPartitions
+  var numSent = 0
+
+  if (numPartitions > 0) {
+type Batch = (Array[Byte], Long, Long)
+
+val batches = rows.mapPartitionsInternal { iter =>
+  ArrowConverters
+.toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+}
+
+val signal = new Object
+val partitions = Array.fill[Array[Batch]](numPartitions)(null)

Review Comment:
   will update soon






[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

2022-11-10 Thread GitBox


zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018928333


##
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala:
##
@@ -128,6 +128,97 @@ private[sql] object ArrowConverters extends Logging {
 }
   }
 
+  private[sql] def toArrowBatchIterator(
+      rowIter: Iterator[InternalRow],
+      schema: StructType,
+      maxRecordsPerBatch: Int,
+      timeZoneId: String): Iterator[(Array[Byte], Long, Long)] = {
+    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+    val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+      "toArrowBatchIterator", 0, Long.MaxValue)
+
+    val root = VectorSchemaRoot.create(arrowSchema, allocator)
+    val unloader = new VectorUnloader(root)
+    val arrowWriter = ArrowWriter.create(root)
+
+    Option(TaskContext.get).foreach {
+      _.addTaskCompletionListener[Unit] { _ =>
+        root.close()
+        allocator.close()
+      }
+    }
+
+    new Iterator[(Array[Byte], Long, Long)] {
+
+      override def hasNext: Boolean = rowIter.hasNext || {
+        root.close()
+        allocator.close()
+        false
+      }
+
+      override def next(): (Array[Byte], Long, Long) = {
+        val out = new ByteArrayOutputStream()
+        val writeChannel = new WriteChannel(Channels.newChannel(out))
+
+        var rowCount = 0L
+        var estimatedSize = SizeEstimator.estimate(arrowSchema) +
+          SizeEstimator.estimate(IpcOption.DEFAULT)
+        Utils.tryWithSafeFinally {
+          while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) {
+            val row = rowIter.next()
+            arrowWriter.write(row)
+            rowCount += 1
+            estimatedSize += SizeEstimator.estimate(row)
+          }
+          arrowWriter.finish()
+          val batch = unloader.getRecordBatch()
+
+          MessageSerializer.serialize(writeChannel, arrowSchema)
+          MessageSerializer.serialize(writeChannel, batch)
+          ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT)
+
+          batch.close()
+        } {
+          arrowWriter.reset()
+        }
+
+        (out.toByteArray, rowCount, estimatedSize)
+      }
+    }
+  }
+
+  private[sql] def createEmptyArrowBatch(

Review Comment:
   I still haven't figured out how to deduplicate; `toArrowBatchIterator` also needs to return the `rowCount`.
   
   What about trying this after switching to `batch size < 4MB`?






[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

2022-11-10 Thread GitBox


zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018875460


##
connector/connect/src/main/protobuf/spark/connect/base.proto:
##
@@ -83,7 +83,6 @@ message Response {
 int64 uncompressed_bytes = 2;

Review Comment:
   yeah, will get rid of it






[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

2022-11-10 Thread GitBox


zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018794542


##
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: 
StreamObserver[Response]) exte
   responseObserver.onNext(response.build())
 }
 
-responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
 responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit 
= {
+val spark = dataframe.sparkSession
+val schema = dataframe.schema
+// TODO: control the batch size instead of max records
+val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+SQLExecution.withNewExecutionId(dataframe.queryExecution, 
Some("collectArrow")) {
+  val rows = dataframe.queryExecution.executedPlan.execute()
+  val numPartitions = rows.getNumPartitions
+  var numSent = 0
+
+  if (numPartitions > 0) {
+type Batch = (Array[Byte], Long, Long)
+
+val batches = rows.mapPartitionsInternal { iter =>
+  ArrowConverters
+.toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+}
+
+val signal = new Object
+val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+  signal.synchronized {
+partitions(partitionId) = partition
+signal.notify()
+  }
+  val i = 0 // Unit
+}
+
+spark.sparkContext.runJob(batches, processPartition, resultHandler)
+
+var currentPartitionId = 0
+while (currentPartitionId < numPartitions) {
+  val partition = signal.synchronized {
+while (partitions(currentPartitionId) == null) {
+  signal.wait()
+}
+val partition = partitions(currentPartitionId)
+partitions(currentPartitionId) = null
+partition
+  }
+
+  // only send non-empty partitions
+  if (partition.nonEmpty && partition.exists(_._1.nonEmpty)) {
+partition.foreach { case (bytes, count, size) =>
+  val response = proto.Response.newBuilder().setClientId(clientId)
+  val batch = proto.Response.ArrowBatch
+.newBuilder()
+.setRowCount(count)
+.setUncompressedBytes(size)
+.setCompressedBytes(bytes.length)
+.setData(ByteString.copyFrom(bytes))
+.build()
+  response.setArrowBatch(batch)
+  responseObserver.onNext(response.build())
+}
+numSent += 1
+  }
+
+  currentPartitionId += 1
+}
+  }
+
+  // make sure at least 1 batch will be sent
+  if (numSent == 0) {

Review Comment:
   `Optional[Pandas]` may still be needed, since 0 JSON batches may be returned






[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

2022-11-09 Thread GitBox


zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018741978


##
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala:
##
@@ -128,6 +128,97 @@ private[sql] object ArrowConverters extends Logging {
 }
   }
 
+  private[sql] def toArrowBatchIterator(
+      rowIter: Iterator[InternalRow],
+      schema: StructType,
+      maxRecordsPerBatch: Int,
+      timeZoneId: String): Iterator[(Array[Byte], Long, Long)] = {
+    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+    val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+      "toArrowBatchIterator", 0, Long.MaxValue)
+
+    val root = VectorSchemaRoot.create(arrowSchema, allocator)
+    val unloader = new VectorUnloader(root)
+    val arrowWriter = ArrowWriter.create(root)
+
+    Option(TaskContext.get).foreach {
+      _.addTaskCompletionListener[Unit] { _ =>
+        root.close()
+        allocator.close()
+      }
+    }
+
+    new Iterator[(Array[Byte], Long, Long)] {
+
+      override def hasNext: Boolean = rowIter.hasNext || {
+        root.close()
+        allocator.close()
+        false
+      }
+
+      override def next(): (Array[Byte], Long, Long) = {
+        val out = new ByteArrayOutputStream()
+        val writeChannel = new WriteChannel(Channels.newChannel(out))
+
+        var rowCount = 0L
+        var estimatedSize = SizeEstimator.estimate(arrowSchema) +
+          SizeEstimator.estimate(IpcOption.DEFAULT)
+        Utils.tryWithSafeFinally {
+          while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) {
+            val row = rowIter.next()
+            arrowWriter.write(row)
+            rowCount += 1
+            estimatedSize += SizeEstimator.estimate(row)
+          }
+          arrowWriter.finish()
+          val batch = unloader.getRecordBatch()
+
+          MessageSerializer.serialize(writeChannel, arrowSchema)
+          MessageSerializer.serialize(writeChannel, batch)
+          ArrowStreamWriter.writeEndOfStream(writeChannel, IpcOption.DEFAULT)
+
+          batch.close()
+        } {
+          arrowWriter.reset()
+        }
+
+        (out.toByteArray, rowCount, estimatedSize)
+      }
+    }
+  }
+
+  private[sql] def createEmptyArrowBatch(

Review Comment:
   let me take a look






[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

2022-11-09 Thread GitBox


zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018741578


##
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: 
StreamObserver[Response]) exte
   responseObserver.onNext(response.build())
 }
 
-responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
 responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit 
= {
+val spark = dataframe.sparkSession
+val schema = dataframe.schema
+// TODO: control the batch size instead of max records
+val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+SQLExecution.withNewExecutionId(dataframe.queryExecution, 
Some("collectArrow")) {
+  val rows = dataframe.queryExecution.executedPlan.execute()
+  val numPartitions = rows.getNumPartitions
+  var numSent = 0
+
+  if (numPartitions > 0) {
+type Batch = (Array[Byte], Long, Long)
+
+val batches = rows.mapPartitionsInternal { iter =>
+  ArrowConverters
+.toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+}
+
+val signal = new Object
+val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+  signal.synchronized {
+partitions(partitionId) = partition
+signal.notify()
+  }
+  val i = 0 // Unit
+}
+
+spark.sparkContext.runJob(batches, processPartition, resultHandler)
+
+var currentPartitionId = 0
+while (currentPartitionId < numPartitions) {
+  val partition = signal.synchronized {
+while (partitions(currentPartitionId) == null) {
+  signal.wait()
+}
+val partition = partitions(currentPartitionId)
+partitions(currentPartitionId) = null
+partition
+  }
+
+  // only send non-empty partitions
+  if (partition.nonEmpty && partition.exists(_._1.nonEmpty)) {

Review Comment:
   Yes, at least for `collect` an empty partition is meaningless.
   
   It may be meaningful for some partitioning-aware operations like `RDD.zipPartitions`.






[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

2022-11-09 Thread GitBox


zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018690469


##
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: 
StreamObserver[Response]) exte
   responseObserver.onNext(response.build())
 }
 
-responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
 responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit 
= {
+val spark = dataframe.sparkSession
+val schema = dataframe.schema
+// TODO: control the batch size instead of max records
+val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+SQLExecution.withNewExecutionId(dataframe.queryExecution, 
Some("collectArrow")) {
+  val rows = dataframe.queryExecution.executedPlan.execute()
+  val numPartitions = rows.getNumPartitions
+  var numSent = 0
+
+  if (numPartitions > 0) {
+type Batch = (Array[Byte], Long, Long)
+
+val batches = rows.mapPartitionsInternal { iter =>
+  ArrowConverters
+.toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+}
+
+val signal = new Object
+val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+  signal.synchronized {
+partitions(partitionId) = partition
+signal.notify()
+  }
+  val i = 0 // Unit

Review Comment:
   got it, will update






[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

2022-11-09 Thread GitBox


zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018684111


##
python/pyspark/sql/connect/client.py:
##
@@ -400,6 +400,14 @@ def _execute_and_fetch(self, req: pb2.Request) -> 
typing.Optional[pandas.DataFra
 
 if len(result_dfs) > 0:
 df = pd.concat(result_dfs)
+del result_dfs

Review Comment:
   Just want to release the buffer ASAP; maybe it is not needed.






[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

2022-11-09 Thread GitBox


zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018683751


##
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: 
StreamObserver[Response]) exte
   responseObserver.onNext(response.build())
 }
 
-responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
 responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit 
= {
+val spark = dataframe.sparkSession
+val schema = dataframe.schema
+// TODO: control the batch size instead of max records
+val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+SQLExecution.withNewExecutionId(dataframe.queryExecution, 
Some("collectArrow")) {
+  val rows = dataframe.queryExecution.executedPlan.execute()
+  val numPartitions = rows.getNumPartitions
+  var numSent = 0
+
+  if (numPartitions > 0) {
+type Batch = (Array[Byte], Long, Long)
+
+val batches = rows.mapPartitionsInternal { iter =>
+  ArrowConverters
+.toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+}
+
+val signal = new Object
+val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+  signal.synchronized {
+partitions(partitionId) = partition
+signal.notify()
+  }
+  val i = 0 // Unit
+}
+
+spark.sparkContext.runJob(batches, processPartition, resultHandler)
+
+var currentPartitionId = 0
+while (currentPartitionId < numPartitions) {
+  val partition = signal.synchronized {
+while (partitions(currentPartitionId) == null) {
+  signal.wait()

Review Comment:
   No, partitions can arrive in random order. Here we wait for the `currentPartitionId`-th partition (starting from 0).
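   
   A stripped-down sketch of this hand-off (hypothetical class, with plain byte arrays standing in for the `(Array[Byte], Long, Long)` batches): the result handler fills slots in whatever order tasks finish, while the single consumer thread drains them strictly in partition order.
   
   ```scala
   class OrderedPartitionBuffer(numPartitions: Int) {
     private val lock = new Object
     private val slots = new Array[Array[Byte]](numPartitions)
   
     // Called from the job's result handler; partitions may complete in any order.
     def put(partitionId: Int, data: Array[Byte]): Unit = lock.synchronized {
       slots(partitionId) = data
       lock.notify() // wake the consumer so it re-checks its current slot
     }
   
     // Called on the handler thread; sends partitions strictly in order 0..n-1.
     def consume(send: Array[Byte] => Unit): Unit = {
       var current = 0
       while (current < numPartitions) {
         val data = lock.synchronized {
           while (slots(current) == null) lock.wait()
           val d = slots(current)
           slots(current) = null // drop the reference once taken
           d
         }
         send(data)
         current += 1
       }
     }
   }
   ```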






[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

2022-11-09 Thread GitBox


zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018682801


##
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: 
StreamObserver[Response]) exte
   responseObserver.onNext(response.build())
 }
 
-responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
 responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit 
= {
+val spark = dataframe.sparkSession
+val schema = dataframe.schema
+// TODO: control the batch size instead of max records
+val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+SQLExecution.withNewExecutionId(dataframe.queryExecution, 
Some("collectArrow")) {
+  val rows = dataframe.queryExecution.executedPlan.execute()
+  val numPartitions = rows.getNumPartitions
+  var numSent = 0
+
+  if (numPartitions > 0) {
+type Batch = (Array[Byte], Long, Long)
+
+val batches = rows.mapPartitionsInternal { iter =>
+  ArrowConverters
+.toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+}
+
+val signal = new Object
+val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+  signal.synchronized {
+partitions(partitionId) = partition
+signal.notify()
+  }
+  val i = 0 // Unit

Review Comment:
   Yes, it is used to change the return type to `Unit`.
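   
   For illustration only, one common way to avoid the dummy `val i = 0` is to end the lambda with the unit literal (the types here are stand-ins for the real `Batch` tuples):
   
   ```scala
   val resultHandler = (partitionId: Int, partition: Array[Array[Byte]]) => {
     // ... store the partition and notify the consumer ...
     () // makes the lambda's result type Unit explicitly
   }
   ```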



##
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: 
StreamObserver[Response]) exte
   responseObserver.onNext(response.build())
 }
 
-responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
 responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit 
= {
+val spark = dataframe.sparkSession
+val schema = dataframe.schema
+// TODO: control the batch size instead of max records
+val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+SQLExecution.withNewExecutionId(dataframe.queryExecution, 
Some("collectArrow")) {
+  val rows = dataframe.queryExecution.executedPlan.execute()
+  val numPartitions = rows.getNumPartitions
+  var numSent = 0
+
+  if (numPartitions > 0) {
+type Batch = (Array[Byte], Long, Long)
+
+val batches = rows.mapPartitionsInternal { iter =>
+  ArrowConverters
+.toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+}
+
+val signal = new Object
+val partitions = Array.fill[Array[Batch]](numPartitions)(null)
+
+val processPartition = (iter: Iterator[Batch]) => iter.toArray
+
+val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+  signal.synchronized {
+partitions(partitionId) = partition
+signal.notify()
+  }
+  val i = 0 // Unit
+}
+
+spark.sparkContext.runJob(batches, processPartition, resultHandler)
+
+var currentPartitionId = 0
+while (currentPartitionId < numPartitions) {
+  val partition = signal.synchronized {
+while (partitions(currentPartitionId) == null) {
+  signal.wait()
+}
+val partition = partitions(currentPartitionId)
+partitions(currentPartitionId) = null
+partition
+  }
+
+  // only send non-empty partitions
+  if (partition.nonEmpty && partition.exists(_._1.nonEmpty)) {
+partition.foreach { case (bytes, count, size) =>
+  val response = proto.Response.newBuilder().setClientId(clientId)
+  val batch = proto.Response.ArrowBatch
+.newBuilder()
+.setRowCount(count)
+.setUncompressedBytes(size)
+.setCompressedBytes(bytes.length)
+.setData(ByteString.copyFrom(bytes))
+.build()
+  response.setArrowBatch(batch)
+  responseObserver.onNext(response.build())
+}
+numSent += 1

Review Comment:
   sure




[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

2022-11-09 Thread GitBox


zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018682607


##
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##
@@ -114,10 +123,97 @@ class SparkConnectStreamHandler(responseObserver: 
StreamObserver[Response]) exte
   responseObserver.onNext(response.build())
 }
 
-responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
 responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit 
= {
+val spark = dataframe.sparkSession
+val schema = dataframe.schema
+// TODO: control the batch size instead of max records

Review Comment:
   https://github.com/apache/spark/pull/38468#discussion_r1013186031 suggested controlling the batch size to stay under 4 MB.
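   
   A self-contained sketch of the "cap by size instead of record count" idea, with `sizeOf` as a stand-in for `SizeEstimator.estimate` (this is not the planned implementation, just the shape of it):
   
   ```scala
   def chunkBySize[T](iter: Iterator[T], maxBytes: Long, sizeOf: T => Long): Iterator[Seq[T]] =
     new Iterator[Seq[T]] {
       override def hasNext: Boolean = iter.hasNext
       override def next(): Seq[T] = {
         val buf = scala.collection.mutable.ArrayBuffer.empty[T]
         var bytes = 0L
         // always take at least one element, then stop once the budget is reached
         while (iter.hasNext && (buf.isEmpty || bytes < maxBytes)) {
           val t = iter.next()
           buf += t
           bytes += sizeOf(t)
         }
         buf.toSeq
       }
     }
   
   // e.g. chunkBySize(rows, 4L * 1024 * 1024, estimateRowSize) yields groups of
   // roughly 4 MB each (the last element added may push a group slightly over).
   ```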
   
   






[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

2022-11-09 Thread GitBox


zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1018681437


##
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##
@@ -48,19 +51,25 @@ class SparkConnectStreamHandler(responseObserver: 
StreamObserver[Response]) exte
 }
   }
 
-  def handlePlan(session: SparkSession, request: proto.Request): Unit = {
+  def handlePlan(session: SparkSession, request: Request): Unit = {
 // Extract the plan from the request and convert it to a logical plan
 val planner = new SparkConnectPlanner(request.getPlan.getRoot, session)
-val rows =
-  Dataset.ofRows(session, planner.transform())
-processRows(request.getClientId, rows)
+val dataframe = Dataset.ofRows(session, planner.transform())
+// check whether all data types are supported
+if (Try {
+ArrowUtils.toArrowSchema(dataframe.schema, 
session.sessionState.conf.sessionLocalTimeZone)
+  }.isSuccess) {
+  processRowsAsArrowBatches(request.getClientId, dataframe)
+} else {
+  processRowsAsJsonBatches(request.getClientId, dataframe)

Review Comment:
   Nice, will update.






[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

2022-11-09 Thread GitBox


zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1017740678


##
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##
@@ -117,10 +127,99 @@ class SparkConnectStreamHandler(responseObserver: 
StreamObserver[Response]) exte
   responseObserver.onNext(response.build())
 }
 
-responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
 responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit 
= {
+val spark = dataframe.sparkSession
+val schema = dataframe.schema
+// TODO: control the batch size instead of max records
+val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+SQLExecution.withNewExecutionId(dataframe.queryExecution, 
Some("collectArrow")) {
+  val rows = dataframe.queryExecution.executedPlan.execute()
+  val numPartitions = rows.getNumPartitions
+  var numSent = 0
+
+  if (numPartitions > 0) {
+val batches = rows.mapPartitionsInternal { iter =>
+  ArrowConverters
+.toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+}
+
+val signal = new Object
+val queue = collection.mutable.Queue.empty[(Int, Array[(Array[Byte], 
Long, Long)])]

Review Comment:
   OK, I will make sure the batches are sent in order from the server side.
   
   Then we don't need `partition_id` and `batch_id` any more.






[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

2022-11-08 Thread GitBox


zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1016358065


##
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##
@@ -117,10 +129,91 @@ class SparkConnectStreamHandler(responseObserver: 
StreamObserver[Response]) exte
   responseObserver.onNext(response.build())
 }
 
-responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
 responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit 
= {
+val spark = dataframe.sparkSession
+val schema = dataframe.schema
+// TODO: control the batch size instead of max records
+val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+SQLExecution.withNewExecutionId(dataframe.queryExecution, 
Some("collectArrow")) {
+  val pool = 
ThreadUtils.newDaemonSingleThreadExecutor("connect-collect-arrow")

Review Comment:
   > Why not use the main thread for this?
   
   OK, will follow https://github.com/apache/spark/pull/38468#discussion_r1013184548
   
   > You can also use 2 fields in the proto: partition_id & batch_id
   
   Done






[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

2022-11-08 Thread GitBox


zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1016363328


##
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##
@@ -117,10 +129,91 @@ class SparkConnectStreamHandler(responseObserver: 
StreamObserver[Response]) exte
   responseObserver.onNext(response.build())
 }
 
-responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
 responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit 
= {
+val spark = dataframe.sparkSession
+val schema = dataframe.schema
+// TODO: control the batch size instead of max records
+val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+SQLExecution.withNewExecutionId(dataframe.queryExecution, 
Some("collectArrow")) {
+  val pool = 
ThreadUtils.newDaemonSingleThreadExecutor("connect-collect-arrow")
+  val tasks = collection.mutable.ArrayBuffer.empty[Future[_]]
+  val rows = dataframe.queryExecution.executedPlan.execute()
+
+  if (rows.getNumPartitions > 0) {
+val batches = rows.mapPartitionsInternal { iter =>
+  ArrowConverters
+.toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+}
+
+val processPartition = (iter: Iterator[(Array[Byte], Long, Long)]) => 
iter.toArray
+
+val resultHandler = (partitionId: Int, taskResult: Array[(Array[Byte], 
Long, Long)]) => {

Review Comment:
   I guess the reordering on the client side is cheap; it only sorts as many integers as there are batches.








[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

2022-11-08 Thread GitBox


zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1016362177


##
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##
@@ -117,10 +129,91 @@ class SparkConnectStreamHandler(responseObserver: 
StreamObserver[Response]) exte
   responseObserver.onNext(response.build())
 }
 
-responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
 responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit 
= {
+val spark = dataframe.sparkSession
+val schema = dataframe.schema
+// TODO: control the batch size instead of max records
+val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+SQLExecution.withNewExecutionId(dataframe.queryExecution, 
Some("collectArrow")) {
+  val pool = 
ThreadUtils.newDaemonSingleThreadExecutor("connect-collect-arrow")
+  val tasks = collection.mutable.ArrayBuffer.empty[Future[_]]
+  val rows = dataframe.queryExecution.executedPlan.execute()
+
+  if (rows.getNumPartitions > 0) {
+val batches = rows.mapPartitionsInternal { iter =>
+  ArrowConverters
+.toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+}
+
+val processPartition = (iter: Iterator[(Array[Byte], Long, Long)]) => 
iter.toArray
+
+val resultHandler = (partitionId: Int, taskResult: Array[(Array[Byte], 
Long, Long)]) => {

Review Comment:
   I guess the reordering on the client side is cheap; it only sorts as many integers as there are batches.






[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

2022-11-08 Thread GitBox


zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1016358065


##
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##
@@ -117,10 +129,91 @@ class SparkConnectStreamHandler(responseObserver: 
StreamObserver[Response]) exte
   responseObserver.onNext(response.build())
 }
 
-responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
 responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit 
= {
+val spark = dataframe.sparkSession
+val schema = dataframe.schema
+// TODO: control the batch size instead of max records
+val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+SQLExecution.withNewExecutionId(dataframe.queryExecution, 
Some("collectArrow")) {
+  val pool = 
ThreadUtils.newDaemonSingleThreadExecutor("connect-collect-arrow")

Review Comment:
   > Why not use the main thread for this?
   OK, will follow https://github.com/apache/spark/pull/38468#discussion_r1013184548



##
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##
@@ -117,10 +129,91 @@ class SparkConnectStreamHandler(responseObserver: 
StreamObserver[Response]) exte
   responseObserver.onNext(response.build())
 }
 
-responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
 responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit 
= {
+val spark = dataframe.sparkSession
+val schema = dataframe.schema
+// TODO: control the batch size instead of max records
+val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+SQLExecution.withNewExecutionId(dataframe.queryExecution, 
Some("collectArrow")) {
+  val pool = 
ThreadUtils.newDaemonSingleThreadExecutor("connect-collect-arrow")

Review Comment:
   > Why not use the main thread for this?
   
   OK, will follow https://github.com/apache/spark/pull/38468#discussion_r1013184548






[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

2022-11-05 Thread GitBox


zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1014600233


##
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##
@@ -117,10 +129,91 @@ class SparkConnectStreamHandler(responseObserver: 
StreamObserver[Response]) exte
   responseObserver.onNext(response.build())
 }
 
-responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
 responseObserver.onCompleted()
   }
 
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit 
= {
+val spark = dataframe.sparkSession
+val schema = dataframe.schema
+// TODO: control the batch size instead of max records
+val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+SQLExecution.withNewExecutionId(dataframe.queryExecution, 
Some("collectArrow")) {
+  val pool = 
ThreadUtils.newDaemonSingleThreadExecutor("connect-collect-arrow")
+  val tasks = collection.mutable.ArrayBuffer.empty[Future[_]]
+  val rows = dataframe.queryExecution.executedPlan.execute()
+
+  if (rows.getNumPartitions > 0) {
+val batches = rows.mapPartitionsInternal { iter =>
+  ArrowConverters
+.toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+}
+
+val processPartition = (iter: Iterator[(Array[Byte], Long, Long)]) => 
iter.toArray
+
+val resultHandler = (partitionId: Int, taskResult: Array[(Array[Byte], 
Long, Long)]) => {
+  if (taskResult.exists(_._1.nonEmpty)) {
+// only send non-empty partitions
+val task = pool.submit(new Runnable {
+  override def run(): Unit = {
+var batchId = partitionId.toLong << 33

Review Comment:
   Generate batch IDs in the same way as `monotonically_increasing_id`.
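   
   A small sketch of that ID scheme (assumed layout, mirroring `monotonically_increasing_id`): the partition ID sits in the bits above bit 33, and batches within a partition count up from 0 in the lower bits.
   
   ```scala
   def batchIds(partitionId: Int, numBatches: Int): Seq[Long] = {
     val base = partitionId.toLong << 33 // partition id in the upper bits
     (0 until numBatches).map(i => base + i) // per-partition counter in the lower bits
   }
   
   // batchIds(2, 3) == Seq(17179869184L, 17179869185L, 17179869186L)
   ```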






[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

2022-11-04 Thread GitBox


zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1014560227


##
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##
@@ -117,7 +126,70 @@ class SparkConnectStreamHandler(responseObserver: 
StreamObserver[Response]) exte
   responseObserver.onNext(response.build())
 }
 
-responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
+responseObserver.onCompleted()
+  }
+
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit 
= {
+val spark = dataframe.sparkSession
+val schema = dataframe.schema
+// TODO: control the batch size instead of max records
+val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+val rows = dataframe.queryExecution.executedPlan.execute()
+var numBatches = 0L
+
+if (rows.getNumPartitions > 0) {
+  val batches = rows.mapPartitionsInternal { iter =>
+ArrowConverters
+  .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+  }
+
+  val obj = new Object
+
+  val processPartition = (iter: Iterator[(Array[Byte], Long, Long)]) => iter.toArray

Review Comment:
   With `batch_id`, we can send higher partitions before lower ones.






[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

2022-11-04 Thread GitBox


zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1014559977


##
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##
@@ -117,7 +126,70 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
   responseObserver.onNext(response.build())
 }
 
-responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
+responseObserver.onCompleted()
+  }
+
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+val spark = dataframe.sparkSession
+val schema = dataframe.schema
+// TODO: control the batch size instead of max records
+val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+val rows = dataframe.queryExecution.executedPlan.execute()

Review Comment:
   good point






[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

2022-11-04 Thread GitBox


zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1014559938


##
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##
@@ -117,7 +126,70 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
   responseObserver.onNext(response.build())
 }
 
-responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
+responseObserver.onCompleted()
+  }
+
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+val spark = dataframe.sparkSession
+val schema = dataframe.schema
+// TODO: control the batch size instead of max records
+val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
+val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
+
+val rows = dataframe.queryExecution.executedPlan.execute()
+var numBatches = 0L
+
+if (rows.getNumPartitions > 0) {
+  val batches = rows.mapPartitionsInternal { iter =>
+ArrowConverters
+  .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId)
+  }
+
+  val obj = new Object
+
+  val processPartition = (iter: Iterator[(Array[Byte], Long, Long)]) => iter.toArray
+
+  val resultHandler = (partitionId: Int, taskResult: Array[(Array[Byte], Long, Long)]) =>
+obj.synchronized {
+  var batchId = partitionId.toLong << 33
+  taskResult.foreach { case (bytes, count, size) =>
+val response = proto.Response.newBuilder().setClientId(clientId)
+val batch = proto.Response.ArrowBatch
+  .newBuilder()
+  .setBatchId(batchId)
+  .setRowCount(count)
+  .setUncompressedBytes(size)
+  .setCompressedBytes(bytes.length)
+  .setData(ByteString.copyFrom(bytes))
+  .build()
+response.setArrowBatch(batch)
+responseObserver.onNext(response.build())

Review Comment:
   OK, will update.






[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

2022-11-04 Thread GitBox


zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1013700526


##
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala:
##
@@ -128,6 +128,65 @@ private[sql] object ArrowConverters extends Logging {
 }
   }
 
+  private[sql] def toArrowBatchIterator(
+  rowIter: Iterator[InternalRow],
+  schema: StructType,
+  maxRecordsPerBatch: Int,
+  timeZoneId: String,
+  context: TaskContext): Iterator[(Array[Byte], Long, Long)] = {
+val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+  "toArrowBatchIterator", 0, Long.MaxValue)
+
+val root = VectorSchemaRoot.create(arrowSchema, allocator)
+val unloader = new VectorUnloader(root)
+val arrowWriter = ArrowWriter.create(root)
+
+if (context != null) { // for test at driver
+  context.addTaskCompletionListener[Unit] { _ =>
+root.close()
+allocator.close()
+  }
+}
+
+new Iterator[(Array[Byte], Long, Long)] {
+
+  override def hasNext: Boolean = rowIter.hasNext || {
+root.close()
+allocator.close()
+false
+  }
+
+  override def next(): (Array[Byte], Long, Long) = {

Review Comment:
   OK, let me check it.






[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

2022-11-03 Thread GitBox


zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1013560372


##
python/pyspark/sql/connect/client.py:
##
@@ -182,6 +191,10 @@ def _to_pandas(self, plan: pb2.Plan) -> Optional[pandas.DataFrame]:
 req = pb2.Request()
 req.user_context.user_id = self._user_id
 req.plan.CopyFrom(plan)
+if self.has_arrow:
+req.preferred_result_type = pb2.Request.ArrowBatch
+else:
+req.preferred_result_type = pb2.Request.JSONBatch

Review Comment:
   I notice that PySpark checks whether the schema is supported (https://github.com/apache/spark/blob/master/python/pyspark/sql/pandas/conversion.py#L102); I moved this check to the server side since it needs the schema.

   ~~Maybe we can keep it as a fallback if JSON could support more data types (not sure about this).~~

   Yes, we always prefer Arrow, but we may get JSON batches from the server if the schema is not supported.
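   A rough sketch of what that server-side check could look like, assuming it simply probes whether the schema converts to Arrow and answers with JSON batches otherwise (illustrative wiring, not the PR's exact code; `ArrowUtils` is `private[sql]`, so such a helper would have to live in the sql package):

   ```scala
   import scala.util.control.NonFatal

   import org.apache.spark.sql.types.StructType
   import org.apache.spark.sql.util.ArrowUtils

   object ResultTypeCheck {
     // True if the whole schema converts to Arrow; otherwise the handler would
     // fall back to JSON batches even though the client preferred Arrow.
     def canUseArrow(schema: StructType, timeZoneId: String): Boolean =
       try {
         ArrowUtils.toArrowSchema(schema, timeZoneId)
         true
       } catch {
         case NonFatal(_) => false
       }
   }
   ```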






[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

2022-11-03 Thread GitBox


zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1013560372


##
python/pyspark/sql/connect/client.py:
##
@@ -182,6 +191,10 @@ def _to_pandas(self, plan: pb2.Plan) -> Optional[pandas.DataFrame]:
 req = pb2.Request()
 req.user_context.user_id = self._user_id
 req.plan.CopyFrom(plan)
+if self.has_arrow:
+req.preferred_result_type = pb2.Request.ArrowBatch
+else:
+req.preferred_result_type = pb2.Request.JSONBatch

Review Comment:
   I notice that PySpark checks whether the schema is supported (https://github.com/apache/spark/blob/master/python/pyspark/sql/pandas/conversion.py#L102); I moved this check to the server side since it needs the schema.

   Maybe we can keep it as a fallback if JSON could support more data types (not sure about this).






[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

2022-11-03 Thread GitBox


zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1013550540


##
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##
@@ -117,7 +131,36 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
   responseObserver.onNext(response.build())
 }
 
-responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
+responseObserver.onCompleted()
+  }
+
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+val schema = dataframe.schema
+val maxRecordsPerBatch = dataframe.sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
+val timeZoneId = dataframe.sparkSession.sessionState.conf.sessionLocalTimeZone
+
+val batches = dataframe.queryExecution.executedPlan
+  .execute()
+  .mapPartitionsInternal { iter =>
+ArrowConverters
+  .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId, TaskContext.get)
+  }
+
+batches.toLocalIterator.foreach { case (bytes, count, size) =>

Review Comment:
   Maybe we should also add a `partitionId` field to the proto message; then we sort by it in the client to keep the order.
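   A minimal sketch of that client-side reordering, assuming the client buffers `(partitionId, payload)` pairs as they arrive (hypothetical helper, not code from this PR):

   ```scala
   object ReorderBatches {
     // Sketch only: a stable sort by partition id restores the original order
     // before the Arrow payloads are concatenated on the client.
     def reorder(received: Seq[(Int, Array[Byte])]): Seq[Array[Byte]] =
       received.sortBy(_._1).map(_._2)
   }
   ```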






[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

2022-11-03 Thread GitBox


zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1013549587


##
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala:
##
@@ -128,6 +128,65 @@ private[sql] object ArrowConverters extends Logging {
 }
   }
 
+  private[sql] def toArrowBatchIterator(
+  rowIter: Iterator[InternalRow],
+  schema: StructType,
+  maxRecordsPerBatch: Int,
+  timeZoneId: String,
+  context: TaskContext): Iterator[(Array[Byte], Long, Long)] = {
+val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+  "toArrowBatchIterator", 0, Long.MaxValue)
+
+val root = VectorSchemaRoot.create(arrowSchema, allocator)
+val unloader = new VectorUnloader(root)
+val arrowWriter = ArrowWriter.create(root)
+
+if (context != null) { // for test at driver
+  context.addTaskCompletionListener[Unit] { _ =>
+root.close()
+allocator.close()
+  }
+}
+
+new Iterator[(Array[Byte], Long, Long)] {
+
+  override def hasNext: Boolean = rowIter.hasNext || {
+root.close()
+allocator.close()
+false
+  }
+
+  override def next(): (Array[Byte], Long, Long) = {
+val out = new ByteArrayOutputStream()
+val writeChannel = new WriteChannel(Channels.newChannel(out))
+
+var rowCount = 0L
+var estimatedSize = 0L
+Utils.tryWithSafeFinally {
+  while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) {
+val row = rowIter.next()
+arrowWriter.write(row)
+rowCount += 1
+estimatedSize += SizeEstimator.estimate(row)
+  }
+  arrowWriter.finish()
+  val batch = unloader.getRecordBatch()
+
+  MessageSerializer.serialize(writeChannel, arrowSchema)

Review Comment:
   I had tried to split the schema and the record batches, but didn't find an easy way to do this, so I made this change as per the suggestion from @grundprinzip and @HyukjinKwon.
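   Assuming each payload then forms a complete Arrow IPC stream (schema, record batch, end-of-stream marker), a consumer could read a single payload on its own; a rough sketch (illustrative, not the PR's client code):

   ```scala
   import java.io.ByteArrayInputStream

   import org.apache.arrow.memory.RootAllocator
   import org.apache.arrow.vector.ipc.ArrowStreamReader

   object BatchReaderSketch {
     // Sketch only: count the rows in one serialized batch produced by
     // toArrowBatchIterator, treating the payload as a complete IPC stream.
     def rowCountOf(payload: Array[Byte]): Long = {
       val allocator = new RootAllocator(Long.MaxValue)
       val reader = new ArrowStreamReader(new ByteArrayInputStream(payload), allocator)
       try {
         var rows = 0L
         while (reader.loadNextBatch()) {
           rows += reader.getVectorSchemaRoot.getRowCount
         }
         rows
       } finally {
         reader.close()
         allocator.close()
       }
     }
   }
   ```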






[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

2022-11-03 Thread GitBox


zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1013547085


##
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##
@@ -49,21 +51,33 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
 }
   }
 
-  def handlePlan(session: SparkSession, request: proto.Request): Unit = {
+  def handlePlan(session: SparkSession, request: Request): Unit = {
 // Extract the plan from the request and convert it to a logical plan
 val planner = new SparkConnectPlanner(request.getPlan.getRoot, session)
-val rows =
-  Dataset.ofRows(session, planner.transform())
-processRows(request.getClientId, rows)
+val dataframe = Dataset.ofRows(session, planner.transform())
+request.getPreferredResultType match {
+  case Request.ResultType.ArrowBatch =>
+// check whether all data types are supported

Review Comment:
   For example: CharType, VarcharType, UserDefinedType (like VectorUDT).

   Supported lists:

   https://github.com/apache/spark/blob/1a90512f605c490255f7b38215c207e64621475b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala#L38-L60

   https://github.com/apache/spark/blob/7b8016a578f511d1c17b16393c487429ce08f132/python/pyspark/sql/pandas/types.py#L54-L120
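   For illustration, a recursive check that flags those types explicitly could look roughly like this (a sketch of the idea, not the exact server-side check; the catch-all branch is optimistic and leaves the remaining types to `ArrowUtils.toArrowSchema`):

   ```scala
   import org.apache.spark.sql.types._

   object ArrowTypeSupport {
     // Sketch only: reject the known-unsupported types and recurse into nested
     // types; anything else is assumed to be handled by ArrowUtils.toArrowSchema.
     def isSupported(dt: DataType): Boolean = dt match {
       case _: CharType | _: VarcharType   => false
       case _: UserDefinedType[_]          => false
       case ArrayType(elementType, _)      => isSupported(elementType)
       case MapType(keyType, valueType, _) => isSupported(keyType) && isSupported(valueType)
       case StructType(fields)             => fields.forall(f => isSupported(f.dataType))
       case _                              => true
     }
   }
   ```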






[GitHub] [spark] zhengruifeng commented on a diff in pull request #38468: [SPARK-41005][CONNECT][PYTHON] Arrow-based collect

2022-11-03 Thread GitBox


zhengruifeng commented on code in PR #38468:
URL: https://github.com/apache/spark/pull/38468#discussion_r1013540451


##
connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala:
##
@@ -117,7 +131,36 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
   responseObserver.onNext(response.build())
 }
 
-responseObserver.onNext(sendMetricsToResponse(clientId, rows))
+responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
+responseObserver.onCompleted()
+  }
+
+  def processRowsAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+val schema = dataframe.schema
+val maxRecordsPerBatch = dataframe.sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
+val timeZoneId = dataframe.sparkSession.sessionState.conf.sessionLocalTimeZone
+
+val batches = dataframe.queryExecution.executedPlan
+  .execute()
+  .mapPartitionsInternal { iter =>
+ArrowConverters
+  .toArrowBatchIterator(iter, schema, maxRecordsPerBatch, timeZoneId, TaskContext.get)

Review Comment:
   OK, let me add a TODO item.


