This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 663a6c479b22 [SPARK-55351][PYTHON][SQL] PythonArrowInput encapsulate resource allocation inside `newWriter`
663a6c479b22 is described below
commit 663a6c479b22d4d63c58210db45f2015a24182b8
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Feb 4 13:58:27 2026 +0800
[SPARK-55351][PYTHON][SQL] PythonArrowInput encapsulate resource allocation inside `newWriter`
### What changes were proposed in this pull request?
Make `PythonArrowInput` encapsulate resource allocation inside `newWriter`.
### Why are the changes needed?
It is up to the writer to manage its resources; `PythonArrowInput` is just a helper layer that builds the writer.
Currently, subclasses always have to release the resources (`allocator`/`root`) even if they are not used in the subclass.
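Roughly, the new ownership pattern looks like the minimal standalone sketch below. It is not the actual Spark code: `ResourceOwningWriter`, `registerCompletion`, and the `closed` flag are illustrative stand-ins for `BasePythonRunner.Writer`, `TaskContext.addTaskCompletionListener`, and the real cleanup path in this patch.

```scala
import org.apache.arrow.memory.{BufferAllocator, RootAllocator}
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.types.pojo.Schema

// Sketch only: the writer owns its Arrow allocator/root and releases them
// itself, either when the input is exhausted or on task completion, so no
// outer helper trait or subclass has to call a shared close().
class ResourceOwningWriter(
    arrowSchema: Schema,
    registerCompletion: (() => Unit) => Unit) {

  private val allocator: BufferAllocator = new RootAllocator(Long.MaxValue)
  private val root: VectorSchemaRoot = VectorSchemaRoot.create(arrowSchema, allocator)
  // Guard added only for this standalone example, to keep terminate() idempotent.
  private var closed = false

  // Mirrors context.addTaskCompletionListener[Unit] { _ => this.terminate() }
  registerCompletion(() => terminate())

  // Returns true if data was written; on exhaustion the writer cleans up itself.
  def writeNextInput(hasInput: Boolean): Boolean = {
    if (!hasInput) terminate()
    hasInput
  }

  private def terminate(): Unit = {
    if (!closed) {
      closed = true
      root.close()
      allocator.close()
    }
  }
}
```

With this shape, the `writeNextBatchToArrowStream` implementations in subclasses only report whether data was written; they no longer need to remember to call `close()` on the shared trait.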
### Does this PR introduce _any_ user-facing change?
No, this is an internal refactoring.
### How was this patch tested?
CI.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #54128 from zhengruifeng/refactor_pai_new_writer.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../sql/execution/python/PythonArrowInput.scala | 73 +++++++++++++---------
.../ApplyInPandasWithStatePythonRunner.scala | 1 -
.../TransformWithStateInPySparkPythonRunner.scala | 2 -
3 files changed, 43 insertions(+), 33 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
index 58a48b1815e1..a8f0f8ba5c56 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
@@ -53,6 +53,10 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
protected def pythonMetrics: Map[String, SQLMetric]
+ /**
+ * Writes input batch to the stream connected to the Python worker.
+ * Returns true if any data was written to the stream, false if the input is exhausted.
+ */
protected def writeNextBatchToArrowStream(
root: VectorSchemaRoot,
writer: ArrowStreamWriter,
@@ -61,15 +65,6 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
protected def writeUDF(dataOut: DataOutputStream): Unit
- protected lazy val allocator: BufferAllocator =
- ArrowUtils.rootAllocator.newChildAllocator(s"stdout writer for $pythonExec", 0, Long.MaxValue)
-
- protected lazy val root: VectorSchemaRoot = {
- val arrowSchema = ArrowUtils.toArrowSchema(
- schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
- VectorSchemaRoot.create(arrowSchema, allocator)
- }
-
// Create compression codec based on config
protected def codec: CompressionCodec = SQLConf.get.arrowCompressionCodec match {
case "none" => NoCompressionCodec.INSTANCE
@@ -87,20 +82,6 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
s"Unsupported Arrow compression codec: $other. Supported values: none,
zstd, lz4")
}
- protected var writer: ArrowStreamWriter = _
-
- protected def close(): Unit = {
- Utils.tryWithSafeFinally {
- // end writes footer to the output stream and doesn't clean any resources.
- // It could throw exception if the output stream is closed, so it should be
- // in the try block.
- writer.end()
- } {
- root.close()
- allocator.close()
- }
- }
-
protected override def newWriter(
env: SparkEnv,
worker: PythonWorker,
@@ -108,20 +89,45 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
partitionIndex: Int,
context: TaskContext): Writer = {
new Writer(env, worker, inputIterator, partitionIndex, context) {
+ private val arrowSchema = ArrowUtils.toArrowSchema(
+ schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
+ private val allocator: BufferAllocator = ArrowUtils.rootAllocator
+ .newChildAllocator(s"stdout writer for $pythonExec", 0, Long.MaxValue)
+ private val root: VectorSchemaRoot = VectorSchemaRoot.create(arrowSchema, allocator)
+ private var writer: ArrowStreamWriter = _
+
+ context.addTaskCompletionListener[Unit] { _ => this.terminate() }
protected override def writeCommand(dataOut: DataOutputStream): Unit = {
writeUDF(dataOut)
}
override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = {
-
if (writer == null) {
writer = new ArrowStreamWriter(root, null, dataOut)
writer.start()
}
assert(writer != null)
- writeNextBatchToArrowStream(root, writer, dataOut, inputIterator)
+ val hasInput = writeNextBatchToArrowStream(root, writer, dataOut, inputIterator)
+ if (!hasInput) {
+ this.terminate()
+ }
+ hasInput
+ }
+
+ private def terminate(): Unit = {
+ Utils.tryWithSafeFinally {
+ // end writes footer to the output stream and doesn't clean any resources.
+ // It could throw exception if the output stream is closed, so it should be
+ // in the try block.
+ if (writer != null) {
+ writer.end()
+ }
+ } {
+ root.close()
+ allocator.close()
+ }
}
}
}
@@ -129,9 +135,6 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[InternalRow]] {
self: BasePythonRunner[Iterator[InternalRow], _] =>
- protected lazy val arrowWriter: arrow.ArrowWriter = ArrowWriter.create(root)
- protected lazy val unloader = new VectorUnloader(root, true, codec, true)
-
protected val maxRecordsPerBatch: Int = {
val v = SQLConf.get.arrowMaxRecordsPerBatch
if (v > 0) v else Int.MaxValue
@@ -139,11 +142,18 @@ private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[In
protected val maxBytesPerBatch: Long = SQLConf.get.arrowMaxBytesPerBatch
+ protected var arrowWriter: arrow.ArrowWriter = _
+ protected var unloader: VectorUnloader = _
+
protected def writeNextBatchToArrowStream(
root: VectorSchemaRoot,
writer: ArrowStreamWriter,
dataOut: DataOutputStream,
inputIterator: Iterator[Iterator[InternalRow]]): Boolean = {
+ if (arrowWriter == null && unloader == null) {
+ arrowWriter = ArrowWriter.create(root)
+ unloader = new VectorUnloader(root, true, codec, true)
+ }
if (inputIterator.hasNext) {
val startData = dataOut.size()
@@ -167,7 +177,6 @@ private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[In
pythonMetrics("pythonDataSent") += deltaData
true
} else {
- super[PythonArrowInput].close()
false
}
}
@@ -184,6 +193,11 @@ private[python] trait BatchedPythonArrowInput extends BasicPythonArrowInput {
writer: ArrowStreamWriter,
dataOut: DataOutputStream,
inputIterator: Iterator[Iterator[InternalRow]]): Boolean = {
+ if (arrowWriter == null && unloader == null) {
+ arrowWriter = ArrowWriter.create(root)
+ unloader = new VectorUnloader(root, true, codec, true)
+ }
+
if (!nextBatchStart.hasNext) {
if (inputIterator.hasNext) {
nextBatchStart = inputIterator.next()
@@ -201,7 +215,6 @@ private[python] trait BatchedPythonArrowInput extends BasicPythonArrowInput {
pythonMetrics("pythonDataSent") += deltaData
true
} else {
- super[BasicPythonArrowInput].close()
false
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
index 89d8e425fd2b..7de7f140e03d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
@@ -163,7 +163,6 @@ class ApplyInPandasWithStatePythonRunner(
true
} else {
pandasWriter.finalizeData()
- super[PythonArrowInput].close()
false
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
index 05771d38cd84..792ba2b100b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
@@ -105,7 +105,6 @@ class TransformWithStateInPySparkPythonRunner(
true
} else {
pandasWriter.finalizeCurrentArrowBatch()
- super[PythonArrowInput].close()
false
}
val deltaData = dataOut.size() - startData
@@ -201,7 +200,6 @@ class TransformWithStateInPySparkPythonInitialStateRunner(
if (pandasWriter.getTotalNumRowsForBatch > 0) {
pandasWriter.finalizeCurrentArrowBatch()
}
- super[PythonArrowInput].close()
false
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]