This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 4c19158b835b Revert "[SPARK-55351][PYTHON][SQL] PythonArrowInput encapsulate resource allocation inside `newWriter`"
4c19158b835b is described below
commit 4c19158b835b85a2b6be1be071af4acdf1d02c4f
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Feb 4 19:49:55 2026 +0800
Revert "[SPARK-55351][PYTHON][SQL] PythonArrowInput encapsulate resource allocation inside `newWriter`"
Revert https://github.com/apache/spark/pull/54128 due to a potential memory leak issue.
Closes #54138 from zhengruifeng/revert_new_writer.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../sql/execution/python/PythonArrowInput.scala | 73 +++++++++-------------
.../ApplyInPandasWithStatePythonRunner.scala | 1 +
.../TransformWithStateInPySparkPythonRunner.scala | 2 +
3 files changed, 33 insertions(+), 43 deletions(-)
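
The diff below moves the Arrow allocator, VectorSchemaRoot, and ArrowStreamWriter back to the trait level and restores a close() that runs writer.end() inside a try block and releases the root and child allocator in the finally block. For reference, here is a minimal standalone sketch of that same cleanup shape, written against plain Arrow Java APIs with illustrative names rather than the Spark classes (ArrowCleanupSketch and its fields are assumptions, not Spark code):

import java.io.ByteArrayOutputStream

import org.apache.arrow.memory.{BufferAllocator, RootAllocator}
import org.apache.arrow.vector.{IntVector, VectorSchemaRoot}
import org.apache.arrow.vector.ipc.ArrowStreamWriter

object ArrowCleanupSketch {
  def main(args: Array[String]): Unit = {
    val rootAllocator = new RootAllocator(Long.MaxValue)
    // Child allocator plays the role of the trait-level `allocator` lazy val.
    val allocator: BufferAllocator =
      rootAllocator.newChildAllocator("stdout writer sketch", 0, Long.MaxValue)
    val vector = new IntVector("value", allocator)
    val root = VectorSchemaRoot.of(vector)
    val out = new ByteArrayOutputStream()
    val writer = new ArrowStreamWriter(root, null, out)

    // Write one small batch, as writeNextBatchToArrowStream would per call.
    writer.start()
    vector.allocateNew(3)
    (0 until 3).foreach(i => vector.setSafe(i, i))
    vector.setValueCount(3)
    root.setRowCount(3)
    writer.writeBatch()

    // Same shape as the restored close(): end() writes the stream footer and can
    // throw if the stream is already closed, so it sits in the try block; the
    // root and child allocator are released in the finally block either way.
    try {
      writer.end()
    } finally {
      root.close()
      allocator.close()
    }
    rootAllocator.close()
    println(s"bytes in Arrow stream: ${out.size()}")
  }
}

Closing the root before the child allocator keeps Arrow's leak accounting clean; a child allocator closed while buffers are still outstanding throws an IllegalStateException.
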
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
index a8f0f8ba5c56..58a48b1815e1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
@@ -53,10 +53,6 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
protected def pythonMetrics: Map[String, SQLMetric]
- /**
- * Writes input batch to the stream connected to the Python worker.
- * Returns true if any data was written to the stream, false if the input is exhausted.
- */
protected def writeNextBatchToArrowStream(
root: VectorSchemaRoot,
writer: ArrowStreamWriter,
@@ -65,6 +61,15 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
protected def writeUDF(dataOut: DataOutputStream): Unit
+ protected lazy val allocator: BufferAllocator =
+ ArrowUtils.rootAllocator.newChildAllocator(s"stdout writer for $pythonExec", 0, Long.MaxValue)
+
+ protected lazy val root: VectorSchemaRoot = {
+ val arrowSchema = ArrowUtils.toArrowSchema(
+ schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
+ VectorSchemaRoot.create(arrowSchema, allocator)
+ }
+
// Create compression codec based on config
protected def codec: CompressionCodec = SQLConf.get.arrowCompressionCodec match {
case "none" => NoCompressionCodec.INSTANCE
@@ -82,6 +87,20 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
s"Unsupported Arrow compression codec: $other. Supported values: none, zstd, lz4")
}
+ protected var writer: ArrowStreamWriter = _
+
+ protected def close(): Unit = {
+ Utils.tryWithSafeFinally {
+ // end writes footer to the output stream and doesn't clean any resources.
+ // It could throw exception if the output stream is closed, so it should be
+ // in the try block.
+ writer.end()
+ } {
+ root.close()
+ allocator.close()
+ }
+ }
+
protected override def newWriter(
env: SparkEnv,
worker: PythonWorker,
@@ -89,45 +108,20 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
partitionIndex: Int,
context: TaskContext): Writer = {
new Writer(env, worker, inputIterator, partitionIndex, context) {
- private val arrowSchema = ArrowUtils.toArrowSchema(
- schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
- private val allocator: BufferAllocator = ArrowUtils.rootAllocator
- .newChildAllocator(s"stdout writer for $pythonExec", 0, Long.MaxValue)
- private val root: VectorSchemaRoot = VectorSchemaRoot.create(arrowSchema, allocator)
- private var writer: ArrowStreamWriter = _
-
- context.addTaskCompletionListener[Unit] { _ => this.terminate() }
protected override def writeCommand(dataOut: DataOutputStream): Unit = {
writeUDF(dataOut)
}
override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = {
+
if (writer == null) {
writer = new ArrowStreamWriter(root, null, dataOut)
writer.start()
}
assert(writer != null)
- val hasInput = writeNextBatchToArrowStream(root, writer, dataOut, inputIterator)
- if (!hasInput) {
- this.terminate()
- }
- hasInput
- }
-
- private def terminate(): Unit = {
- Utils.tryWithSafeFinally {
- // end writes footer to the output stream and doesn't clean any resources.
- // It could throw exception if the output stream is closed, so it should be
- // in the try block.
- if (writer != null) {
- writer.end()
- }
- } {
- root.close()
- allocator.close()
- }
+ writeNextBatchToArrowStream(root, writer, dataOut, inputIterator)
}
}
}
@@ -135,6 +129,9 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[InternalRow]] {
self: BasePythonRunner[Iterator[InternalRow], _] =>
+ protected lazy val arrowWriter: arrow.ArrowWriter = ArrowWriter.create(root)
+ protected lazy val unloader = new VectorUnloader(root, true, codec, true)
+
protected val maxRecordsPerBatch: Int = {
val v = SQLConf.get.arrowMaxRecordsPerBatch
if (v > 0) v else Int.MaxValue
@@ -142,18 +139,11 @@ private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[In
protected val maxBytesPerBatch: Long = SQLConf.get.arrowMaxBytesPerBatch
- protected var arrowWriter: arrow.ArrowWriter = _
- protected var unloader: VectorUnloader = _
-
protected def writeNextBatchToArrowStream(
root: VectorSchemaRoot,
writer: ArrowStreamWriter,
dataOut: DataOutputStream,
inputIterator: Iterator[Iterator[InternalRow]]): Boolean = {
- if (arrowWriter == null && unloader == null) {
- arrowWriter = ArrowWriter.create(root)
- unloader = new VectorUnloader(root, true, codec, true)
- }
if (inputIterator.hasNext) {
val startData = dataOut.size()
@@ -177,6 +167,7 @@ private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[In
pythonMetrics("pythonDataSent") += deltaData
true
} else {
+ super[PythonArrowInput].close()
false
}
}
@@ -193,11 +184,6 @@ private[python] trait BatchedPythonArrowInput extends BasicPythonArrowInput {
writer: ArrowStreamWriter,
dataOut: DataOutputStream,
inputIterator: Iterator[Iterator[InternalRow]]): Boolean = {
- if (arrowWriter == null && unloader == null) {
- arrowWriter = ArrowWriter.create(root)
- unloader = new VectorUnloader(root, true, codec, true)
- }
-
if (!nextBatchStart.hasNext) {
if (inputIterator.hasNext) {
nextBatchStart = inputIterator.next()
@@ -215,6 +201,7 @@ private[python] trait BatchedPythonArrowInput extends BasicPythonArrowInput {
pythonMetrics("pythonDataSent") += deltaData
true
} else {
+ super[BasicPythonArrowInput].close()
false
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
index 7de7f140e03d..89d8e425fd2b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
@@ -163,6 +163,7 @@ class ApplyInPandasWithStatePythonRunner(
true
} else {
pandasWriter.finalizeData()
+ super[PythonArrowInput].close()
false
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
index 792ba2b100b7..05771d38cd84 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
@@ -105,6 +105,7 @@ class TransformWithStateInPySparkPythonRunner(
true
} else {
pandasWriter.finalizeCurrentArrowBatch()
+ super[PythonArrowInput].close()
false
}
val deltaData = dataOut.size() - startData
@@ -200,6 +201,7 @@ class TransformWithStateInPySparkPythonInitialStateRunner(
if (pandasWriter.getTotalNumRowsForBatch > 0) {
pandasWriter.finalizeCurrentArrowBatch()
}
+ super[PythonArrowInput].close()
false
}
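
The runners above finish their input by calling super[PythonArrowInput].close() or super[BasicPythonArrowInput].close(), a qualified super call that names the parent trait whose close() should run. A small self-contained sketch of that call form, using hypothetical trait names (ArrowInputLike, BasicArrowInputLike) rather than the Spark traits:

// Hypothetical stand-ins for PythonArrowInput and BasicPythonArrowInput.
// close() is defined once on the base trait; the mixed-in layer invokes exactly
// that definition with a qualified super call once the input is exhausted.
trait ArrowInputLike {
  protected def close(): Unit =
    println("base close(): writer.end(); root.close(); allocator.close()")
}

trait BasicArrowInputLike extends ArrowInputLike {
  protected def writeNextBatch(hasNext: Boolean): Boolean = {
    if (hasNext) {
      true
    } else {
      // Mirrors `super[PythonArrowInput].close()` in the diff above.
      super[ArrowInputLike].close()
      false
    }
  }
}

object QualifiedSuperSketch extends BasicArrowInputLike {
  def main(args: Array[String]): Unit = {
    println(writeNextBatch(hasNext = true))  // true; resources stay open
    println(writeNextBatch(hasNext = false)) // base close() runs, then false
  }
}

Because close() lives on the base trait, the qualified form makes explicit which definition runs, regardless of what later mix-ins add to the linearization.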