zhengruifeng commented on code in PR #52303:
URL: https://github.com/apache/spark/pull/52303#discussion_r2366623458
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala:
##########
@@ -194,3 +180,115 @@ private[python] trait BatchedPythonArrowInput extends BasicPythonArrowInput {
}
}
}
+
+object BatchedPythonArrowInput {
+  /**
+   * Splits a group into smaller Arrow batches, written within a
+   * separate and complete Arrow streaming format, to work around the
+   * Arrow 2G limit; see ARROW-4890.
+   *
+   * The return value is the number of rows in the batch.
+   * No split Arrow batch mixes rows from different groups. For example:
+   *
+   * +------------------------+ +------------------------+ +--------------------
+   * |Group (by k1) v1, v2, v3| |Group (by k2) v1, v2, v3| | ...
+   * +------------------------+ +------------------------+ +--------------------
+   *
+   * +------+-----------------+------+------+-----------------+------+------+--------------------
+   * |Schema| Batch| Batch|Schema| Batch| Batch|Schema| Batch ...
+   * +------+-----------------+------+------+-----------------+------+------+--------------------
+   * | Arrow Streaming Format | Arrow Streaming Format | Arrow Streaming Form...
+   *
+   * Here, each (Arrow) batch does not span multiple groups.
+   * The (Arrow) batches within each complete Arrow IPC format are
+   * reassembled back into their group as pandas instances later, on the
+   * Python worker side.
+   */
+  def writeSizedBatch(
+      arrowWriter: ArrowWriter,
+      writer: ArrowStreamWriter,
+      rowIter: Iterator[InternalRow],
+      maxBytesPerBatch: Long,
+      maxRecordsPerBatch: Int): Int = {
+    var numRowsInBatch: Int = 0
+
+    def underBatchSizeLimit: Boolean =
+      (maxBytesPerBatch == Int.MaxValue) || (arrowWriter.sizeInBytes() < maxBytesPerBatch)
+
+    while (rowIter.hasNext && numRowsInBatch < maxRecordsPerBatch &&
+        underBatchSizeLimit) {
+      arrowWriter.write(rowIter.next())
+      numRowsInBatch += 1
+    }
+
+    assert(numRowsInBatch > 0)
+    assert(numRowsInBatch <= maxRecordsPerBatch)
+    arrowWriter.finish()
+    writer.writeBatch()
+
+    arrowWriter.reset()
+    numRowsInBatch
+  }
+}
+
+/**
+ * Enables an optimization that splits each group into sized batches.
+ */
+private[python] trait GroupedPythonArrowInput { self: RowInputArrowPythonRunner =>
+
+  private val maxRecordsPerBatch = {
+    val v = SQLConf.get.arrowMaxRecordsPerBatch
+    if (v > 0) v else Int.MaxValue
+  }
+
+  private val maxBytesPerBatch = SQLConf.get.arrowMaxBytesPerBatch
+
+  protected override def newWriter(
+      env: SparkEnv,
+      worker: PythonWorker,
+      inputIterator: Iterator[Iterator[InternalRow]],
+      partitionIndex: Int,
+      context: TaskContext): Writer = {
+    new Writer(env, worker, inputIterator, partitionIndex, context) {
+      protected override def writeCommand(dataOut: DataOutputStream): Unit = {
+        handleMetadataBeforeExec(dataOut)
+        writeUDF(dataOut)
+      }
+
+      var writer: ArrowWriterWrapper = null
+      // Marker inside the input iterator to indicate the start of the next batch.
+      private var nextBatchStart: Iterator[InternalRow] = Iterator.empty
+
+      override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = {
Review Comment:
Yes, this is as expected.
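
For anyone reading along, here is a minimal sketch of how I understand `writeSizedBatch` is meant to be driven for each group. The PR's actual driver is `writeNextInputToStream` above; `writeGroup` and its parameter list here are hypothetical names for illustration, not part of this change:

```scala
import java.io.DataOutputStream

import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamWriter

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow.ArrowWriter

// Hypothetical driver loop (not the PR's code): writes one complete Arrow
// IPC stream per group, splitting the group's rows into size/record-capped
// batches via writeSizedBatch.
def writeGroup(
    root: VectorSchemaRoot,            // vectors shared by arrowWriter and streamWriter
    groupRows: Iterator[InternalRow],  // all rows of a single group
    dataOut: DataOutputStream,
    maxBytesPerBatch: Long,            // e.g. SQLConf.get.arrowMaxBytesPerBatch
    maxRecordsPerBatch: Int): Unit = { // normalized to Int.MaxValue when <= 0
  val arrowWriter = ArrowWriter.create(root)
  // A fresh ArrowStreamWriter per group yields a separate, complete Arrow
  // streaming format (Schema + Batches + EOS), so no batch spans two groups.
  val streamWriter = new ArrowStreamWriter(root, null, dataOut)
  streamWriter.start()
  while (groupRows.hasNext) {
    // writeSizedBatch asserts it wrote at least one row, hence the hasNext guard.
    BatchedPythonArrowInput.writeSizedBatch(
      arrowWriter, streamWriter, groupRows, maxBytesPerBatch, maxRecordsPerBatch)
  }
  streamWriter.end() // terminates this group's IPC stream
}
```

Opening a fresh `ArrowStreamWriter` per group is what keeps each group in its own complete Arrow streaming format, matching the diagram in the doc comment; the Python worker can then rebuild each stream's batches into a single pandas instance per group.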