This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 8647b243dee [SPARK-44533][PYTHON] Add support for accumulator, broadcast, and Spark files in Python UDTF's analyze
8647b243dee is described below

commit 8647b243deed8f2c3279ed17fe196006b6c923af
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Wed Jul 26 21:03:08 2023 -0700

[SPARK-44533][PYTHON] Add support for accumulator, broadcast, and Spark files in Python UDTF's analyze

### What changes were proposed in this pull request?

Adds support for `accumulator` and `broadcast` in vanilla PySpark, and for Spark files in both vanilla PySpark and the Spark Connect Python client, in Python UDTF's `analyze`.

For example, in vanilla PySpark:

```py
>>> colname = sc.broadcast("col1")
>>> test_accum = sc.accumulator(0)

>>> @udtf
... class TestUDTF:
...     @staticmethod
...     def analyze(a: AnalyzeArgument) -> AnalyzeResult:
...         test_accum.add(1)
...         return AnalyzeResult(StructType().add(colname.value, a.data_type))
...     def eval(self, a):
...         test_accum.add(1)
...         yield a,
...

>>> df = TestUDTF(lit(10))
>>> df.printSchema()
root
 |-- col1: integer (nullable = true)

>>> df.show()
+----+
|col1|
+----+
|  10|
+----+

>>> test_accum.value
2
```

or

```py
>>> pyfile_path = "my_pyfile.py"
>>> with open(pyfile_path, "w") as f:
...     f.write("my_func = lambda: 'col1'")
...
24
>>> sc.addPyFile(pyfile_path)
>>> # or spark.addArtifacts(pyfile_path, pyfile=True)
>>>
>>> @udtf
... class TestUDTF:
...     @staticmethod
...     def analyze(a: AnalyzeArgument) -> AnalyzeResult:
...         import my_pyfile
...         return AnalyzeResult(StructType().add(my_pyfile.my_func(), a.data_type))
...     def eval(self, a):
...         yield a,
...

>>> df = TestUDTF(lit(10))
>>> df.printSchema()
root
 |-- col1: integer (nullable = true)

>>> df.show()
+----+
|col1|
+----+
|  10|
+----+
```

### Why are the changes needed?

To support missing features: `accumulator`, `broadcast`, and Spark files in Python UDTF's `analyze`.

### Does this PR introduce _any_ user-facing change?

Yes: `accumulator` and `broadcast` in vanilla PySpark, and Spark files in both vanilla PySpark and the Spark Connect Python client, become usable in Python UDTF's `analyze`.

### How was this patch tested?

Added related tests.

Closes #42135 from ueshin/issues/SPARK-44533/analyze.
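The added tests also cover reading regular Spark files and archives from `analyze`. Below is a minimal sketch of the file variant, modeled on the `test_udtf_with_analyze_using_file` test added in this PR; it assumes an active REPL `sc` as in the examples above, and the file name `my_file.txt` and its "col1" content are purely illustrative:

```py
import os

from pyspark.files import SparkFiles
from pyspark.sql.functions import lit, udtf
from pyspark.sql.types import StructType
from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult

# Create a small text file and register it so it is visible to analyze and eval.
with open("my_file.txt", "w") as f:
    f.write("col1")
sc.addFile("my_file.txt")


@udtf
class TestUDTF:
    @staticmethod
    def read_my_file() -> str:
        # Added files land under SparkFiles.getRootDirectory() on the Python worker.
        path = os.path.join(SparkFiles.getRootDirectory(), "my_file.txt")
        with open(path, "r") as my_file:
            return my_file.read().strip()

    @staticmethod
    def analyze(a: AnalyzeArgument) -> AnalyzeResult:
        # The file content ("col1") determines the output column name at analysis time.
        return AnalyzeResult(StructType().add(TestUDTF.read_my_file(), a.data_type))

    def eval(self, a):
        yield a,


TestUDTF(lit(10)).printSchema()  # root |-- col1: integer (nullable = true)
```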
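Archives added with `sc.addArchive` are handled the same way. A sketch modeled on the `test_udtf_with_analyze_using_archive` test in this PR; the archive layout and the `#my_files` alias are chosen for illustration:

```py
import os
import shutil

from pyspark.files import SparkFiles
from pyspark.sql.functions import lit, udtf
from pyspark.sql.types import StructType
from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult

# Build my_archive.zip containing my_archive/my_file.txt with the text "col1".
os.mkdir("my_archive")
with open(os.path.join("my_archive", "my_file.txt"), "w") as f:
    f.write("col1")
shutil.make_archive("my_archive", "zip", ".", "my_archive")

# The "#my_files" fragment is the directory name the archive is unpacked into.
sc.addArchive("my_archive.zip#my_files")


@udtf
class TestUDTF:
    @staticmethod
    def read_my_archive() -> str:
        path = os.path.join(
            SparkFiles.getRootDirectory(), "my_files", "my_archive", "my_file.txt"
        )
        with open(path, "r") as my_file:
            return my_file.read().strip()

    @staticmethod
    def analyze(a: AnalyzeArgument) -> AnalyzeResult:
        # The archive content decides the output column name at analysis time.
        return AnalyzeResult(StructType().add(TestUDTF.read_my_archive(), a.data_type))

    def eval(self, a):
        yield a,


TestUDTF(lit(10)).printSchema()  # root |-- col1: integer (nullable = true)
```

On the Spark Connect Python client, the parity tests register the same resources with `spark.addArtifacts(path, file=True)` or `spark.addArtifacts(path, archive=True)` instead of the `sc` methods.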
Authored-by: Takuya UESHIN <ues...@databricks.com> Signed-off-by: Takuya UESHIN <ues...@databricks.com> --- .../org/apache/spark/api/python/PythonRDD.scala | 4 +- .../org/apache/spark/api/python/PythonRunner.scala | 83 +------- .../spark/api/python/PythonWorkerUtils.scala | 152 ++++++++++++++ .../pyspark/sql/tests/connect/test_parity_udtf.py | 19 ++ python/pyspark/sql/tests/test_udtf.py | 224 ++++++++++++++++++++- python/pyspark/sql/worker/analyze_udtf.py | 18 +- python/pyspark/worker.py | 91 ++------- python/pyspark/worker_util.py | 132 ++++++++++++ .../execution/python/BatchEvalPythonUDTFExec.scala | 7 +- .../python/UserDefinedPythonFunction.scala | 23 ++- 10 files changed, 584 insertions(+), 169 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 95fbc145d83..91fd92d4422 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -487,9 +487,7 @@ private[spark] object PythonRDD extends Logging { } def writeUTF(str: String, dataOut: DataOutputStream): Unit = { - val bytes = str.getBytes(StandardCharsets.UTF_8) - dataOut.writeInt(bytes.length) - dataOut.write(bytes) + PythonWorkerUtils.writeUTF(str, dataOut) } /** diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 5d719b33a30..0173de75ff2 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -309,8 +309,9 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( val dataOut = new DataOutputStream(stream) // Partition index dataOut.writeInt(partitionIndex) - // Python version of driver - PythonRDD.writeUTF(pythonVer, dataOut) + + PythonWorkerUtils.writePythonVersion(pythonVer, dataOut) + // Init a ServerSocket to accept method calls from Python side. 
val isBarrier = context.isInstanceOf[BarrierTaskContext] if (isBarrier) { @@ -406,69 +407,8 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( PythonRDD.writeUTF(v, dataOut) } - // sparkFilesDir - val root = jobArtifactUUID.map { uuid => - new File(SparkFiles.getRootDirectory(), uuid).getAbsolutePath - }.getOrElse(SparkFiles.getRootDirectory()) - PythonRDD.writeUTF(root, dataOut) - // Python includes (*.zip and *.egg files) - dataOut.writeInt(pythonIncludes.size) - for (include <- pythonIncludes) { - PythonRDD.writeUTF(include, dataOut) - } - // Broadcast variables - val oldBids = PythonRDD.getWorkerBroadcasts(worker) - val newBids = broadcastVars.map(_.id).toSet - // number of different broadcasts - val toRemove = oldBids.diff(newBids) - val addedBids = newBids.diff(oldBids) - val cnt = toRemove.size + addedBids.size - val needsDecryptionServer = env.serializerManager.encryptionEnabled && addedBids.nonEmpty - dataOut.writeBoolean(needsDecryptionServer) - dataOut.writeInt(cnt) - def sendBidsToRemove(): Unit = { - for (bid <- toRemove) { - // remove the broadcast from worker - dataOut.writeLong(-bid - 1) // bid >= 0 - oldBids.remove(bid) - } - } - if (needsDecryptionServer) { - // if there is encryption, we setup a server which reads the encrypted files, and sends - // the decrypted data to python - val idsAndFiles = broadcastVars.flatMap { broadcast => - if (!oldBids.contains(broadcast.id)) { - oldBids.add(broadcast.id) - Some((broadcast.id, broadcast.value.path)) - } else { - None - } - } - val server = new EncryptedPythonBroadcastServer(env, idsAndFiles) - dataOut.writeInt(server.port) - logTrace(s"broadcast decryption server setup on ${server.port}") - PythonRDD.writeUTF(server.secret, dataOut) - sendBidsToRemove() - idsAndFiles.foreach { case (id, _) => - // send new broadcast - dataOut.writeLong(id) - } - dataOut.flush() - logTrace("waiting for python to read decrypted broadcast data from server") - server.waitTillBroadcastDataSent() - logTrace("done sending decrypted data to python") - } else { - sendBidsToRemove() - for (broadcast <- broadcastVars) { - if (!oldBids.contains(broadcast.id)) { - // send new broadcast - dataOut.writeLong(broadcast.id) - PythonRDD.writeUTF(broadcast.value.path, dataOut) - oldBids.add(broadcast.id) - } - } - } - dataOut.flush() + PythonWorkerUtils.writeSparkFiles(jobArtifactUUID, pythonIncludes, dataOut) + PythonWorkerUtils.writeBroadcasts(broadcastVars, worker, env, dataOut) dataOut.writeInt(evalType) writeCommand(dataOut) @@ -524,9 +464,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( } def writeUTF(str: String, dataOut: DataOutputStream): Unit = { - val bytes = str.getBytes(UTF_8) - dataOut.writeInt(bytes.length) - dataOut.write(bytes) + PythonWorkerUtils.writeUTF(str, dataOut) } } @@ -599,13 +537,8 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( protected def handleEndOfDataSection(): Unit = { // We've finished the data section of the output, but we can still // read some accumulator updates: - val numAccumulatorUpdates = stream.readInt() - (1 to numAccumulatorUpdates).foreach { _ => - val updateLen = stream.readInt() - val update = new Array[Byte](updateLen) - stream.readFully(update) - maybeAccumulator.foreach(_.add(update)) - } + PythonWorkerUtils.receiveAccumulatorUpdates(maybeAccumulator, stream) + // Check whether the worker is ready to be re-used. 
if (stream.readInt() == SpecialLengths.END_OF_STREAM) { if (reuseWorker && releasedOrClosed.compareAndSet(false, true)) { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala new file mode 100644 index 00000000000..b6ab031d388 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.python + +import java.io.{DataInputStream, DataOutputStream, File} +import java.net.Socket +import java.nio.charset.StandardCharsets + +import org.apache.spark.{SparkEnv, SparkFiles} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.internal.Logging + +private[spark] object PythonWorkerUtils extends Logging { + + /** + * Write a string in UTF-8 charset. + * + * It will be read by `UTF8Deserializer.loads` in Python. + */ + def writeUTF(str: String, dataOut: DataOutputStream): Unit = { + val bytes = str.getBytes(StandardCharsets.UTF_8) + dataOut.writeInt(bytes.length) + dataOut.write(bytes) + } + + /** + * Write a Python version to check if the Python version is expected. + * + * It will be read and checked by `worker_util.check_python_version`. + */ + def writePythonVersion(pythonVer: String, dataOut: DataOutputStream): Unit = { + writeUTF(pythonVer, dataOut) + } + + /** + * Write Spark files to set up them in the worker. + * + * It will be read and used by `worker_util.setup_spark_files`. + */ + def writeSparkFiles( + jobArtifactUUID: Option[String], + pythonIncludes: Set[String], + dataOut: DataOutputStream): Unit = { + // sparkFilesDir + val root = jobArtifactUUID.map { uuid => + new File(SparkFiles.getRootDirectory(), uuid).getAbsolutePath + }.getOrElse(SparkFiles.getRootDirectory()) + writeUTF(root, dataOut) + + // Python includes (*.zip and *.egg files) + dataOut.writeInt(pythonIncludes.size) + for (include <- pythonIncludes) { + writeUTF(include, dataOut) + } + } + + /** + * Write broadcasted variables to set up them in the worker. + * + * It will be read and used by 'worker_util.setup_broadcasts`. 
+ */ + def writeBroadcasts( + broadcastVars: Seq[Broadcast[PythonBroadcast]], + worker: Socket, + env: SparkEnv, + dataOut: DataOutputStream): Unit = { + // Broadcast variables + val oldBids = PythonRDD.getWorkerBroadcasts(worker) + val newBids = broadcastVars.map(_.id).toSet + // number of different broadcasts + val toRemove = oldBids.diff(newBids) + val addedBids = newBids.diff(oldBids) + val cnt = toRemove.size + addedBids.size + val needsDecryptionServer = env.serializerManager.encryptionEnabled && addedBids.nonEmpty + dataOut.writeBoolean(needsDecryptionServer) + dataOut.writeInt(cnt) + def sendBidsToRemove(): Unit = { + for (bid <- toRemove) { + // remove the broadcast from worker + dataOut.writeLong(-bid - 1) // bid >= 0 + oldBids.remove(bid) + } + } + if (needsDecryptionServer) { + // if there is encryption, we setup a server which reads the encrypted files, and sends + // the decrypted data to python + val idsAndFiles = broadcastVars.flatMap { broadcast => + if (!oldBids.contains(broadcast.id)) { + oldBids.add(broadcast.id) + Some((broadcast.id, broadcast.value.path)) + } else { + None + } + } + val server = new EncryptedPythonBroadcastServer(env, idsAndFiles) + dataOut.writeInt(server.port) + logTrace(s"broadcast decryption server setup on ${server.port}") + writeUTF(server.secret, dataOut) + sendBidsToRemove() + idsAndFiles.foreach { case (id, _) => + // send new broadcast + dataOut.writeLong(id) + } + dataOut.flush() + logTrace("waiting for python to read decrypted broadcast data from server") + server.waitTillBroadcastDataSent() + logTrace("done sending decrypted data to python") + } else { + sendBidsToRemove() + for (broadcast <- broadcastVars) { + if (!oldBids.contains(broadcast.id)) { + // send new broadcast + dataOut.writeLong(broadcast.id) + writeUTF(broadcast.value.path, dataOut) + oldBids.add(broadcast.id) + } + } + } + dataOut.flush() + } + + /** + * Receive accumulator updates from the worker. + * + * The updates are sent by `worker_util.send_accumulator_updates`. + */ + def receiveAccumulatorUpdates( + maybeAccumulator: Option[PythonAccumulatorV2], dataIn: DataInputStream): Unit = { + val numAccumulatorUpdates = dataIn.readInt() + (1 to numAccumulatorUpdates).foreach { _ => + val updateLen = dataIn.readInt() + val update = new Array[Byte](updateLen) + dataIn.readFully(update) + maybeAccumulator.foreach(_.add(update)) + } + } +} diff --git a/python/pyspark/sql/tests/connect/test_parity_udtf.py b/python/pyspark/sql/tests/connect/test_parity_udtf.py index e18e116e003..1aff1bd0686 100644 --- a/python/pyspark/sql/tests/connect/test_parity_udtf.py +++ b/python/pyspark/sql/tests/connect/test_parity_udtf.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import unittest + from pyspark.testing.connectutils import should_test_connect if should_test_connect: @@ -104,6 +106,23 @@ class UDTFParityTests(BaseUDTFTestsMixin, ReusedConnectTestCase): with self.assertRaisesRegex(SparkConnectGrpcException, err_msg): TestUDTF(lit(1)).show() + @unittest.skip("Spark Connect does not support broadcast but the test depends on it.") + def test_udtf_with_analyze_using_broadcast(self): + super().test_udtf_with_analyze_using_broadcast() + + @unittest.skip("Spark Connect does not support accumulator but the test depends on it.") + def test_udtf_with_analyze_using_accumulator(self): + super().test_udtf_with_analyze_using_accumulator() + + def _add_pyfile(self, path): + self.spark.addArtifacts(path, pyfile=True) + + def _add_archive(self, path): + self.spark.addArtifacts(path, archive=True) + + def _add_file(self, path): + self.spark.addArtifacts(path, file=True) + class ArrowUDTFParityTests(UDTFArrowTestsMixin, UDTFParityTests): @classmethod diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index 688034b9930..e5b29b36034 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -14,7 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +import os +import shutil +import tempfile import unittest from typing import Iterator @@ -27,6 +29,7 @@ from pyspark.errors import ( PySparkTypeError, AnalysisException, ) +from pyspark.files import SparkFiles from pyspark.rdd import PythonEvalType from pyspark.sql.functions import ( array, @@ -1222,6 +1225,225 @@ class BaseUDTFTestsMixin: ): self.spark.sql("SELECT * FROM test_udtf(1, 'x')").collect() + def test_udtf_with_analyze_using_broadcast(self): + colname = self.sc.broadcast("col1") + + @udtf + class TestUDTF: + @staticmethod + def analyze(a: AnalyzeArgument) -> AnalyzeResult: + return AnalyzeResult(StructType().add(colname.value, a.data_type)) + + def eval(self, a): + assert colname.value == "col1" + yield a, + + def terminate(self): + assert colname.value == "col1" + yield 100, + + self.spark.udtf.register("test_udtf", TestUDTF) + + for i, df in enumerate([TestUDTF(lit(10)), self.spark.sql("SELECT * FROM test_udtf(10)")]): + with self.subTest(query_no=i): + assertSchemaEqual(df.schema, StructType().add("col1", IntegerType())) + assertDataFrameEqual(df, [Row(col1=10), Row(col1=100)]) + + def test_udtf_with_analyze_using_accumulator(self): + test_accum = self.sc.accumulator(0) + + @udtf + class TestUDTF: + @staticmethod + def analyze(a: AnalyzeArgument) -> AnalyzeResult: + test_accum.add(1) + return AnalyzeResult(StructType().add("col1", a.data_type)) + + def eval(self, a): + test_accum.add(10) + yield a, + + def terminate(self): + test_accum.add(100) + yield 100, + + self.spark.udtf.register("test_udtf", TestUDTF) + + for i, df in enumerate([TestUDTF(lit(10)), self.spark.sql("SELECT * FROM test_udtf(10)")]): + with self.subTest(query_no=i): + assertSchemaEqual(df.schema, StructType().add("col1", IntegerType())) + assertDataFrameEqual(df, [Row(col1=10), Row(col1=100)]) + + self.assertEqual(test_accum.value, 222) + + def _add_pyfile(self, path): + self.sc.addPyFile(path) + + def test_udtf_with_analyze_using_pyfile(self): + with tempfile.TemporaryDirectory() as d: + pyfile_path = os.path.join(d, "my_pyfile.py") + with open(pyfile_path, "w") as f: + f.write("my_func = lambda: 'col1'") + + self._add_pyfile(pyfile_path) + + class TestUDTF: + @staticmethod + def call_my_func() -> str: 
+ import my_pyfile + + return my_pyfile.my_func() + + @staticmethod + def analyze(a: AnalyzeArgument) -> AnalyzeResult: + return AnalyzeResult(StructType().add(TestUDTF.call_my_func(), a.data_type)) + + def eval(self, a): + assert TestUDTF.call_my_func() == "col1" + yield a, + + def terminate(self): + assert TestUDTF.call_my_func() == "col1" + yield 100, + + test_udtf = udtf(TestUDTF) + self.spark.udtf.register("test_udtf", test_udtf) + + for i, df in enumerate( + [test_udtf(lit(10)), self.spark.sql("SELECT * FROM test_udtf(10)")] + ): + with self.subTest(query_no=i): + assertSchemaEqual(df.schema, StructType().add("col1", IntegerType())) + assertDataFrameEqual(df, [Row(col1=10), Row(col1=100)]) + + def test_udtf_with_analyze_using_zipped_package(self): + with tempfile.TemporaryDirectory() as d: + package_path = os.path.join(d, "my_zipfile") + os.mkdir(package_path) + pyfile_path = os.path.join(package_path, "__init__.py") + with open(pyfile_path, "w") as f: + f.write("my_func = lambda: 'col1'") + shutil.make_archive(package_path, "zip", d, "my_zipfile") + + self._add_pyfile(f"{package_path}.zip") + + class TestUDTF: + @staticmethod + def call_my_func() -> str: + import my_zipfile + + return my_zipfile.my_func() + + @staticmethod + def analyze(a: AnalyzeArgument) -> AnalyzeResult: + return AnalyzeResult(StructType().add(TestUDTF.call_my_func(), a.data_type)) + + def eval(self, a): + assert TestUDTF.call_my_func() == "col1" + yield a, + + def terminate(self): + assert TestUDTF.call_my_func() == "col1" + yield 100, + + test_udtf = udtf(TestUDTF) + self.spark.udtf.register("test_udtf", test_udtf) + + for i, df in enumerate( + [test_udtf(lit(10)), self.spark.sql("SELECT * FROM test_udtf(10)")] + ): + with self.subTest(query_no=i): + assertSchemaEqual(df.schema, StructType().add("col1", IntegerType())) + assertDataFrameEqual(df, [Row(col1=10), Row(col1=100)]) + + def _add_archive(self, path): + self.sc.addArchive(path) + + def test_udtf_with_analyze_using_archive(self): + with tempfile.TemporaryDirectory() as d: + archive_path = os.path.join(d, "my_archive") + os.mkdir(archive_path) + pyfile_path = os.path.join(archive_path, "my_file.txt") + with open(pyfile_path, "w") as f: + f.write("col1") + shutil.make_archive(archive_path, "zip", d, "my_archive") + + self._add_archive(f"{archive_path}.zip#my_files") + + class TestUDTF: + @staticmethod + def read_my_archive() -> str: + with open( + os.path.join( + SparkFiles.getRootDirectory(), "my_files", "my_archive", "my_file.txt" + ), + "r", + ) as my_file: + return my_file.read().strip() + + @staticmethod + def analyze(a: AnalyzeArgument) -> AnalyzeResult: + return AnalyzeResult(StructType().add(TestUDTF.read_my_archive(), a.data_type)) + + def eval(self, a): + assert TestUDTF.read_my_archive() == "col1" + yield a, + + def terminate(self): + assert TestUDTF.read_my_archive() == "col1" + yield 100, + + test_udtf = udtf(TestUDTF) + self.spark.udtf.register("test_udtf", test_udtf) + + for i, df in enumerate( + [test_udtf(lit(10)), self.spark.sql("SELECT * FROM test_udtf(10)")] + ): + with self.subTest(query_no=i): + assertSchemaEqual(df.schema, StructType().add("col1", IntegerType())) + assertDataFrameEqual(df, [Row(col1=10), Row(col1=100)]) + + def _add_file(self, path): + self.sc.addFile(path) + + def test_udtf_with_analyze_using_file(self): + with tempfile.TemporaryDirectory() as d: + file_path = os.path.join(d, "my_file.txt") + with open(file_path, "w") as f: + f.write("col1") + + self._add_file(file_path) + + class TestUDTF: + @staticmethod + def 
read_my_file() -> str: + with open( + os.path.join(SparkFiles.getRootDirectory(), "my_file.txt"), "r" + ) as my_file: + return my_file.read().strip() + + @staticmethod + def analyze(a: AnalyzeArgument) -> AnalyzeResult: + return AnalyzeResult(StructType().add(TestUDTF.read_my_file(), a.data_type)) + + def eval(self, a): + assert TestUDTF.read_my_file() == "col1" + yield a, + + def terminate(self): + assert TestUDTF.read_my_file() == "col1" + yield 100, + + test_udtf = udtf(TestUDTF) + self.spark.udtf.register("test_udtf", test_udtf) + + for i, df in enumerate( + [test_udtf(lit(10)), self.spark.sql("SELECT * FROM test_udtf(10)")] + ): + with self.subTest(query_no=i): + assertSchemaEqual(df.schema, StructType().add("col1", IntegerType())) + assertDataFrameEqual(df, [Row(col1=10), Row(col1=100)]) + class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase): @classmethod diff --git a/python/pyspark/sql/worker/analyze_udtf.py b/python/pyspark/sql/worker/analyze_udtf.py index ea85d937705..44dcd8c892c 100644 --- a/python/pyspark/sql/worker/analyze_udtf.py +++ b/python/pyspark/sql/worker/analyze_udtf.py @@ -21,6 +21,7 @@ import sys import traceback from typing import List, IO +from pyspark.accumulators import _accumulatorRegistry from pyspark.errors import PySparkRuntimeError, PySparkValueError from pyspark.java_gateway import local_connect_and_auth from pyspark.serializers import ( @@ -33,7 +34,15 @@ from pyspark.serializers import ( from pyspark.sql.types import _parse_datatype_json_string from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult from pyspark.util import try_simplify_traceback -from pyspark.worker import check_python_version, read_command, pickleSer, utf8_deserializer +from pyspark.worker_util import ( + check_python_version, + read_command, + pickleSer, + send_accumulator_updates, + setup_broadcasts, + setup_spark_files, + utf8_deserializer, +) def read_udtf(infile: IO) -> type: @@ -87,6 +96,11 @@ def main(infile: IO, outfile: IO) -> None: """ try: check_python_version(infile) + setup_spark_files(infile) + setup_broadcasts(infile) + + _accumulatorRegistry.clear() + handler = read_udtf(infile) args = read_arguments(infile) @@ -122,6 +136,8 @@ def main(infile: IO, outfile: IO) -> None: print(traceback.format_exc(), file=sys.stderr) sys.exit(-1) + send_accumulator_updates(outfile) + # check end of stream if read_int(infile) == SpecialLengths.END_OF_STREAM: write_int(SpecialLengths.END_OF_STREAM, outfile) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 95924757242..bfe788faf6d 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -22,7 +22,6 @@ import os import sys import time from inspect import currentframe, getframeinfo, getfullargspec -import importlib import json from typing import Iterator @@ -37,10 +36,8 @@ import warnings import faulthandler from pyspark.accumulators import _accumulatorRegistry -from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.java_gateway import local_connect_and_auth from pyspark.taskcontext import BarrierTaskContext, TaskContext -from pyspark.files import SparkFiles from pyspark.resource import ResourceInformation from pyspark.rdd import PythonEvalType from pyspark.serializers import ( @@ -67,9 +64,15 @@ from pyspark.sql.types import StructType, _parse_datatype_json_string from pyspark.util import fail_on_stopiteration, try_simplify_traceback from pyspark import shuffle from pyspark.errors import PySparkRuntimeError, PySparkTypeError - -pickleSer = CPickleSerializer() -utf8_deserializer = 
UTF8Deserializer() +from pyspark.worker_util import ( + check_python_version, + read_command, + pickleSer, + send_accumulator_updates, + setup_broadcasts, + setup_spark_files, + utf8_deserializer, +) def report_times(outfile, boot, init, finish): @@ -79,20 +82,6 @@ def report_times(outfile, boot, init, finish): write_long(int(1000 * finish), outfile) -def add_path(path): - # worker can be used, so do not add path multiple times - if path not in sys.path: - # overwrite system packages - sys.path.insert(1, path) - - -def read_command(serializer, file): - command = serializer._read_with_length(file) - if isinstance(command, Broadcast): - command = serializer.loads(command.value) - return command - - def chain(f, g): """chain two functions together""" return lambda *a: g(f(*a)) @@ -970,21 +959,6 @@ def read_udfs(pickleSer, infile, eval_type): return func, None, ser, ser -def check_python_version(infile): - """ - Check the Python version between the running process and the one used to serialize the command. - """ - version = utf8_deserializer.loads(infile) - if version != "%d.%d" % sys.version_info[:2]: - raise PySparkRuntimeError( - error_class="PYTHON_VERSION_MISMATCH", - message_parameters={ - "worker_version": str(sys.version_info[:2]), - "driver_version": str(version), - }, - ) - - def main(infile, outfile): faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None) try: @@ -1074,47 +1048,8 @@ def main(infile, outfile): shuffle.DiskBytesSpilled = 0 _accumulatorRegistry.clear() - # fetch name of workdir - spark_files_dir = utf8_deserializer.loads(infile) - SparkFiles._root_directory = spark_files_dir - SparkFiles._is_running_on_worker = True - - # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH - add_path(spark_files_dir) # *.py files that were added will be copied here - num_python_includes = read_int(infile) - for _ in range(num_python_includes): - filename = utf8_deserializer.loads(infile) - add_path(os.path.join(spark_files_dir, filename)) - - importlib.invalidate_caches() - - # fetch names and values of broadcast variables - needs_broadcast_decryption_server = read_bool(infile) - num_broadcast_variables = read_int(infile) - if needs_broadcast_decryption_server: - # read the decrypted data from a server in the jvm - port = read_int(infile) - auth_secret = utf8_deserializer.loads(infile) - (broadcast_sock_file, _) = local_connect_and_auth(port, auth_secret) - - for _ in range(num_broadcast_variables): - bid = read_long(infile) - if bid >= 0: - if needs_broadcast_decryption_server: - read_bid = read_long(broadcast_sock_file) - assert read_bid == bid - _broadcastRegistry[bid] = Broadcast(sock_file=broadcast_sock_file) - else: - path = utf8_deserializer.loads(infile) - _broadcastRegistry[bid] = Broadcast(path=path) - - else: - bid = -bid - 1 - _broadcastRegistry.pop(bid) - - if needs_broadcast_decryption_server: - broadcast_sock_file.write(b"1") - broadcast_sock_file.close() + setup_spark_files(infile) + setup_broadcasts(infile) _accumulatorRegistry.clear() eval_type = read_int(infile) @@ -1178,9 +1113,7 @@ def main(infile, outfile): # Mark the beginning of the accumulators section of the output write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) - write_int(len(_accumulatorRegistry), outfile) - for (aid, accum) in _accumulatorRegistry.items(): - pickleSer._write_with_length((aid, accum._value), outfile) + send_accumulator_updates(outfile) # check end of stream if read_int(infile) == SpecialLengths.END_OF_STREAM: diff --git 
a/python/pyspark/worker_util.py b/python/pyspark/worker_util.py new file mode 100644 index 00000000000..eab0daf8f59 --- /dev/null +++ b/python/pyspark/worker_util.py @@ -0,0 +1,132 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Util functions for workers. +""" +import importlib +import os +import sys +from typing import Any, IO + +from pyspark.accumulators import _accumulatorRegistry +from pyspark.broadcast import Broadcast, _broadcastRegistry +from pyspark.errors import PySparkRuntimeError +from pyspark.files import SparkFiles +from pyspark.java_gateway import local_connect_and_auth +from pyspark.serializers import ( + read_bool, + read_int, + read_long, + write_int, + FramedSerializer, + UTF8Deserializer, + CPickleSerializer, +) + +pickleSer = CPickleSerializer() +utf8_deserializer = UTF8Deserializer() + + +def add_path(path: str) -> None: + # worker can be used, so do not add path multiple times + if path not in sys.path: + # overwrite system packages + sys.path.insert(1, path) + + +def read_command(serializer: FramedSerializer, file: IO) -> Any: + command = serializer._read_with_length(file) + if isinstance(command, Broadcast): + command = serializer.loads(command.value) + return command + + +def check_python_version(infile: IO) -> None: + """ + Check the Python version between the running process and the one used to serialize the command. + """ + version = utf8_deserializer.loads(infile) + if version != "%d.%d" % sys.version_info[:2]: + raise PySparkRuntimeError( + error_class="PYTHON_VERSION_MISMATCH", + message_parameters={ + "worker_version": str(sys.version_info[:2]), + "driver_version": str(version), + }, + ) + + +def setup_spark_files(infile: IO) -> None: + """ + Set up Spark files, archives, and pyfiles. + """ + # fetch name of workdir + spark_files_dir = utf8_deserializer.loads(infile) + SparkFiles._root_directory = spark_files_dir + SparkFiles._is_running_on_worker = True + + # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH + add_path(spark_files_dir) # *.py files that were added will be copied here + num_python_includes = read_int(infile) + for _ in range(num_python_includes): + filename = utf8_deserializer.loads(infile) + add_path(os.path.join(spark_files_dir, filename)) + + importlib.invalidate_caches() + + +def setup_broadcasts(infile: IO) -> None: + """ + Set up broadcasted variables. 
+ """ + # fetch names and values of broadcast variables + needs_broadcast_decryption_server = read_bool(infile) + num_broadcast_variables = read_int(infile) + if needs_broadcast_decryption_server: + # read the decrypted data from a server in the jvm + port = read_int(infile) + auth_secret = utf8_deserializer.loads(infile) + (broadcast_sock_file, _) = local_connect_and_auth(port, auth_secret) + + for _ in range(num_broadcast_variables): + bid = read_long(infile) + if bid >= 0: + if needs_broadcast_decryption_server: + read_bid = read_long(broadcast_sock_file) + assert read_bid == bid + _broadcastRegistry[bid] = Broadcast(sock_file=broadcast_sock_file) + else: + path = utf8_deserializer.loads(infile) + _broadcastRegistry[bid] = Broadcast(path=path) + + else: + bid = -bid - 1 + _broadcastRegistry.pop(bid) + + if needs_broadcast_decryption_server: + broadcast_sock_file.write(b"1") + broadcast_sock_file.close() + + +def send_accumulator_updates(outfile: IO) -> None: + """ + Send the accumulator updates back to JVM. + """ + write_int(len(_accumulatorRegistry), outfile) + for (aid, accum) in _accumulatorRegistry.items(): + pickleSer._write_with_length((aid, accum._value), outfile) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala index 7d9fdb3bf48..9dae874e3ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala @@ -19,14 +19,13 @@ package org.apache.spark.sql.execution.python import java.io.DataOutputStream import java.net.Socket -import java.nio.charset.StandardCharsets.UTF_8 import scala.collection.JavaConverters._ import net.razorvine.pickle.Unpickler import org.apache.spark.{JobArtifactSet, SparkEnv, TaskContext} -import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonWorkerUtils} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.GenericArrayData @@ -126,8 +125,6 @@ object PythonUDTFRunner { } dataOut.writeInt(udtf.func.command.length) dataOut.write(udtf.func.command.toArray) - val schemaBytes = udtf.elementSchema.json.getBytes(UTF_8) - dataOut.writeInt(schemaBytes.length) - dataOut.write(schemaBytes) + PythonWorkerUtils.writeUTF(udtf.elementSchema.json, dataOut) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index d683f865dcd..05239d8d164 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -26,8 +26,8 @@ import scala.collection.JavaConverters._ import net.razorvine.pickle.Pickler -import org.apache.spark.{SparkEnv, SparkException} -import org.apache.spark.api.python.{PythonEvalType, PythonFunction, PythonRDD, SpecialLengths} +import org.apache.spark.{JobArtifactSet, SparkEnv, SparkException} +import org.apache.spark.api.python.{PythonEvalType, PythonFunction, PythonWorkerUtils, SpecialLengths} import org.apache.spark.internal.config.BUFFER_SIZE import org.apache.spark.internal.config.Python._ 
import org.apache.spark.sql.{Column, DataFrame, Dataset, SparkSession} @@ -173,12 +173,19 @@ object UserDefinedPythonTableFunction { val bufferSize: Int = env.conf.get(BUFFER_SIZE) val authSocketTimeout = env.conf.get(PYTHON_AUTH_SOCKET_TIMEOUT) val reuseWorker = env.conf.get(PYTHON_WORKER_REUSE) + val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback + val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) + val envVars = new HashMap[String, String](func.envVars) val pythonExec = func.pythonExec val pythonVer = func.pythonVer + val pythonIncludes = func.pythonIncludes.asScala.toSet + val broadcastVars = func.broadcastVars.asScala.toSeq + val maybeAccumulator = Option(func.accumulator).map(_.copyAndReset()) + envVars.put("SPARK_LOCAL_DIRS", localdir) if (reuseWorker) { envVars.put("SPARK_REUSE_WORKER", "1") } @@ -188,6 +195,8 @@ object UserDefinedPythonTableFunction { envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString) envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString) + envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default")) + EvaluatePython.registerPicklers() val pickler = new Pickler(/* useMemo = */ true, /* valueCompare = */ false) @@ -200,8 +209,9 @@ object UserDefinedPythonTableFunction { new DataOutputStream(new BufferedOutputStream(worker.getOutputStream, bufferSize)) val dataIn = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) - // Python version of driver - PythonRDD.writeUTF(pythonVer, dataOut) + PythonWorkerUtils.writePythonVersion(pythonVer, dataOut) + PythonWorkerUtils.writeSparkFiles(jobArtifactUUID, pythonIncludes, dataOut) + PythonWorkerUtils.writeBroadcasts(broadcastVars, worker, env, dataOut) // Send Python UDTF dataOut.writeInt(func.command.length) @@ -210,7 +220,7 @@ object UserDefinedPythonTableFunction { // Send arguments dataOut.writeInt(exprs.length) exprs.zip(tableArgs).foreach { case (expr, is_table) => - PythonRDD.writeUTF(expr.dataType.json, dataOut) + PythonWorkerUtils.writeUTF(expr.dataType.json, dataOut) if (expr.foldable) { dataOut.writeBoolean(true) val obj = pickler.dumps(EvaluatePython.toJava(expr.eval(), expr.dataType)) @@ -240,6 +250,9 @@ object UserDefinedPythonTableFunction { throw QueryCompilationErrors.tableValuedFunctionFailedToAnalyseInPythonError(msg) } + PythonWorkerUtils.receiveAccumulatorUpdates(maybeAccumulator, dataIn) + Option(func.accumulator).foreach(_.merge(maybeAccumulator.get)) + dataIn.readInt() match { case SpecialLengths.END_OF_STREAM if reuseWorker => env.releasePythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org