This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8647b243dee [SPARK-44533][PYTHON] Add support for accumulator, broadcast, and Spark files in Python UDTF's analyze
8647b243dee is described below

commit 8647b243deed8f2c3279ed17fe196006b6c923af
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Wed Jul 26 21:03:08 2023 -0700

    [SPARK-44533][PYTHON] Add support for accumulator, broadcast, and Spark files in Python UDTF's analyze
    
    ### What changes were proposed in this pull request?
    
    Adds support for `accumulator` and `broadcast` in vanilla PySpark, and for Spark files in both vanilla PySpark and the Spark Connect Python client, in Python UDTF's analyze.
    
    For example, in vanilla PySpark:
    
    ```py
    >>> colname = sc.broadcast("col1")
    >>> test_accum = sc.accumulator(0)
    
    >>> @udtf
    ... class TestUDTF:
    ...     @staticmethod
    ...     def analyze(a: AnalyzeArgument) -> AnalyzeResult:
    ...         test_accum.add(1)
    ...         return AnalyzeResult(StructType().add(colname.value, a.data_type))
    ...     def eval(self, a):
    ...         test_accum.add(1)
    ...         yield a,
    ...
    >>> df = TestUDTF(lit(10))
    >>> df.printSchema()
    root
     |-- col1: integer (nullable = true)
    
    >>> df.show()
    +----+
    |col1|
    +----+
    |  10|
    +----+
    
    >>> test_accum.value
    2
    ```
    
    or
    
    ```py
    >>> pyfile_path = "my_pyfile.py"
    >>> with open(pyfile_path, "w") as f:
    ...     f.write("my_func = lambda: 'col1'")
    ...
    24
    >>> sc.addPyFile(pyfile_path)
    >>> # or spark.addArtifacts(pyfile_path, pyfile=True)
    >>>
    >>> @udtf
    ... class TestUDTF:
    ...     @staticmethod
    ...     def analyze(a: AnalyzeArgument) -> AnalyzeResult:
    ...         import my_pyfile
    ...         return AnalyzeResult(StructType().add(my_pyfile.my_func(), a.data_type))
    ...     def eval(self, a):
    ...         yield a,
    ...
    >>> df = TestUDTF(lit(10))
    >>> df.printSchema()
    root
     |-- col1: integer (nullable = true)
    
    >>> df.show()
    +----+
    |col1|
    +----+
    |  10|
    +----+
    ```
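
    The same mechanism also covers regular Spark files and archives. Below is a minimal sketch mirroring the `test_udtf_with_analyze_using_file` test added in this patch; the file name and contents are illustrative only, and the session setup (`sc`, `spark`, and the imports) is the same as in the examples above:

    ```py
    >>> import os
    >>> from pyspark import SparkFiles
    >>> file_path = "my_file.txt"
    >>> with open(file_path, "w") as f:
    ...     f.write("col1")
    ...
    4
    >>> sc.addFile(file_path)
    >>> # or spark.addArtifacts(file_path, file=True)
    >>>
    >>> @udtf
    ... class TestUDTF:
    ...     @staticmethod
    ...     def analyze(a: AnalyzeArgument) -> AnalyzeResult:
    ...         # the added file is resolved under SparkFiles' root directory on the worker
    ...         path = os.path.join(SparkFiles.getRootDirectory(), "my_file.txt")
    ...         with open(path, "r") as my_file:
    ...             return AnalyzeResult(StructType().add(my_file.read().strip(), a.data_type))
    ...     def eval(self, a):
    ...         yield a,
    ...
    >>> TestUDTF(lit(10)).printSchema()
    root
     |-- col1: integer (nullable = true)
    ```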
    
    ### Why are the changes needed?
    
    To support the missing features: `accumulator`, `broadcast`, and Spark files in Python UDTF's analyze.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. `accumulator` and `broadcast` become available in Python UDTF's analyze in vanilla PySpark, and Spark files become available in both vanilla PySpark and the Spark Connect Python client.
    
    ### How was this patch tested?
    
    Added related tests.
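
    The added tests exercise both the DataFrame API and the SQL registration paths. A condensed sketch of the accumulator test added to python/pyspark/sql/tests/test_udtf.py (imports for the assertion helpers shown; the rest of the session setup follows the examples above):

    ```py
    >>> from pyspark.sql import Row
    >>> from pyspark.testing import assertDataFrameEqual
    >>> test_accum = sc.accumulator(0)
    >>> @udtf
    ... class TestUDTF:
    ...     @staticmethod
    ...     def analyze(a: AnalyzeArgument) -> AnalyzeResult:
    ...         test_accum.add(1)
    ...         return AnalyzeResult(StructType().add("col1", a.data_type))
    ...     def eval(self, a):
    ...         test_accum.add(10)
    ...         yield a,
    ...     def terminate(self):
    ...         test_accum.add(100)
    ...         yield 100,
    ...
    >>> _ = spark.udtf.register("test_udtf", TestUDTF)
    >>> assertDataFrameEqual(TestUDTF(lit(10)), [Row(col1=10), Row(col1=100)])
    >>> assertDataFrameEqual(spark.sql("SELECT * FROM test_udtf(10)"), [Row(col1=10), Row(col1=100)])
    >>> # 2 queries x (1 from analyze + 10 from eval + 100 from terminate)
    >>> test_accum.value
    222
    ```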
    
    Closes #42135 from ueshin/issues/SPARK-44533/analyze.
    
    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Takuya UESHIN <ues...@databricks.com>
---
 .../org/apache/spark/api/python/PythonRDD.scala    |   4 +-
 .../org/apache/spark/api/python/PythonRunner.scala |  83 +-------
 .../spark/api/python/PythonWorkerUtils.scala       | 152 ++++++++++++++
 .../pyspark/sql/tests/connect/test_parity_udtf.py  |  19 ++
 python/pyspark/sql/tests/test_udtf.py              | 224 ++++++++++++++++++++-
 python/pyspark/sql/worker/analyze_udtf.py          |  18 +-
 python/pyspark/worker.py                           |  91 ++-------
 python/pyspark/worker_util.py                      | 132 ++++++++++++
 .../execution/python/BatchEvalPythonUDTFExec.scala |   7 +-
 .../python/UserDefinedPythonFunction.scala         |  23 ++-
 10 files changed, 584 insertions(+), 169 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 95fbc145d83..91fd92d4422 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -487,9 +487,7 @@ private[spark] object PythonRDD extends Logging {
   }
 
   def writeUTF(str: String, dataOut: DataOutputStream): Unit = {
-    val bytes = str.getBytes(StandardCharsets.UTF_8)
-    dataOut.writeInt(bytes.length)
-    dataOut.write(bytes)
+    PythonWorkerUtils.writeUTF(str, dataOut)
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 5d719b33a30..0173de75ff2 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -309,8 +309,9 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
         val dataOut = new DataOutputStream(stream)
         // Partition index
         dataOut.writeInt(partitionIndex)
-        // Python version of driver
-        PythonRDD.writeUTF(pythonVer, dataOut)
+
+        PythonWorkerUtils.writePythonVersion(pythonVer, dataOut)
+
         // Init a ServerSocket to accept method calls from Python side.
         val isBarrier = context.isInstanceOf[BarrierTaskContext]
         if (isBarrier) {
@@ -406,69 +407,8 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
           PythonRDD.writeUTF(v, dataOut)
         }
 
-        // sparkFilesDir
-        val root = jobArtifactUUID.map { uuid =>
-          new File(SparkFiles.getRootDirectory(), uuid).getAbsolutePath
-        }.getOrElse(SparkFiles.getRootDirectory())
-        PythonRDD.writeUTF(root, dataOut)
-        // Python includes (*.zip and *.egg files)
-        dataOut.writeInt(pythonIncludes.size)
-        for (include <- pythonIncludes) {
-          PythonRDD.writeUTF(include, dataOut)
-        }
-        // Broadcast variables
-        val oldBids = PythonRDD.getWorkerBroadcasts(worker)
-        val newBids = broadcastVars.map(_.id).toSet
-        // number of different broadcasts
-        val toRemove = oldBids.diff(newBids)
-        val addedBids = newBids.diff(oldBids)
-        val cnt = toRemove.size + addedBids.size
-        val needsDecryptionServer = env.serializerManager.encryptionEnabled && addedBids.nonEmpty
-        dataOut.writeBoolean(needsDecryptionServer)
-        dataOut.writeInt(cnt)
-        def sendBidsToRemove(): Unit = {
-          for (bid <- toRemove) {
-            // remove the broadcast from worker
-            dataOut.writeLong(-bid - 1) // bid >= 0
-            oldBids.remove(bid)
-          }
-        }
-        if (needsDecryptionServer) {
-          // if there is encryption, we setup a server which reads the encrypted files, and sends
-          // the decrypted data to python
-          val idsAndFiles = broadcastVars.flatMap { broadcast =>
-            if (!oldBids.contains(broadcast.id)) {
-              oldBids.add(broadcast.id)
-              Some((broadcast.id, broadcast.value.path))
-            } else {
-              None
-            }
-          }
-          val server = new EncryptedPythonBroadcastServer(env, idsAndFiles)
-          dataOut.writeInt(server.port)
-          logTrace(s"broadcast decryption server setup on ${server.port}")
-          PythonRDD.writeUTF(server.secret, dataOut)
-          sendBidsToRemove()
-          idsAndFiles.foreach { case (id, _) =>
-            // send new broadcast
-            dataOut.writeLong(id)
-          }
-          dataOut.flush()
-          logTrace("waiting for python to read decrypted broadcast data from 
server")
-          server.waitTillBroadcastDataSent()
-          logTrace("done sending decrypted data to python")
-        } else {
-          sendBidsToRemove()
-          for (broadcast <- broadcastVars) {
-            if (!oldBids.contains(broadcast.id)) {
-              // send new broadcast
-              dataOut.writeLong(broadcast.id)
-              PythonRDD.writeUTF(broadcast.value.path, dataOut)
-              oldBids.add(broadcast.id)
-            }
-          }
-        }
-        dataOut.flush()
+        PythonWorkerUtils.writeSparkFiles(jobArtifactUUID, pythonIncludes, dataOut)
+        PythonWorkerUtils.writeBroadcasts(broadcastVars, worker, env, dataOut)
 
         dataOut.writeInt(evalType)
         writeCommand(dataOut)
@@ -524,9 +464,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     }
 
     def writeUTF(str: String, dataOut: DataOutputStream): Unit = {
-      val bytes = str.getBytes(UTF_8)
-      dataOut.writeInt(bytes.length)
-      dataOut.write(bytes)
+      PythonWorkerUtils.writeUTF(str, dataOut)
     }
   }
 
@@ -599,13 +537,8 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     protected def handleEndOfDataSection(): Unit = {
       // We've finished the data section of the output, but we can still
       // read some accumulator updates:
-      val numAccumulatorUpdates = stream.readInt()
-      (1 to numAccumulatorUpdates).foreach { _ =>
-        val updateLen = stream.readInt()
-        val update = new Array[Byte](updateLen)
-        stream.readFully(update)
-        maybeAccumulator.foreach(_.add(update))
-      }
+      PythonWorkerUtils.receiveAccumulatorUpdates(maybeAccumulator, stream)
+
       // Check whether the worker is ready to be re-used.
       if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
         if (reuseWorker && releasedOrClosed.compareAndSet(false, true)) {
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala
new file mode 100644
index 00000000000..b6ab031d388
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.python
+
+import java.io.{DataInputStream, DataOutputStream, File}
+import java.net.Socket
+import java.nio.charset.StandardCharsets
+
+import org.apache.spark.{SparkEnv, SparkFiles}
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.Logging
+
+private[spark] object PythonWorkerUtils extends Logging {
+
+  /**
+   * Write a string in UTF-8 charset.
+   *
+   * It will be read by `UTF8Deserializer.loads` in Python.
+   */
+  def writeUTF(str: String, dataOut: DataOutputStream): Unit = {
+    val bytes = str.getBytes(StandardCharsets.UTF_8)
+    dataOut.writeInt(bytes.length)
+    dataOut.write(bytes)
+  }
+
+  /**
+   * Write a Python version to check if the Python version is expected.
+   *
+   * It will be read and checked by `worker_util.check_python_version`.
+   */
+  def writePythonVersion(pythonVer: String, dataOut: DataOutputStream): Unit = {
+    writeUTF(pythonVer, dataOut)
+  }
+
+  /**
+   * Write Spark files to set them up in the worker.
+   *
+   * It will be read and used by `worker_util.setup_spark_files`.
+   */
+  def writeSparkFiles(
+      jobArtifactUUID: Option[String],
+      pythonIncludes: Set[String],
+      dataOut: DataOutputStream): Unit = {
+    // sparkFilesDir
+    val root = jobArtifactUUID.map { uuid =>
+      new File(SparkFiles.getRootDirectory(), uuid).getAbsolutePath
+    }.getOrElse(SparkFiles.getRootDirectory())
+    writeUTF(root, dataOut)
+
+    // Python includes (*.zip and *.egg files)
+    dataOut.writeInt(pythonIncludes.size)
+    for (include <- pythonIncludes) {
+      writeUTF(include, dataOut)
+    }
+  }
+
+  /**
+   * Write broadcasted variables to set them up in the worker.
+   *
+   * It will be read and used by `worker_util.setup_broadcasts`.
+   */
+  def writeBroadcasts(
+      broadcastVars: Seq[Broadcast[PythonBroadcast]],
+      worker: Socket,
+      env: SparkEnv,
+      dataOut: DataOutputStream): Unit = {
+    // Broadcast variables
+    val oldBids = PythonRDD.getWorkerBroadcasts(worker)
+    val newBids = broadcastVars.map(_.id).toSet
+    // number of different broadcasts
+    val toRemove = oldBids.diff(newBids)
+    val addedBids = newBids.diff(oldBids)
+    val cnt = toRemove.size + addedBids.size
+    val needsDecryptionServer = env.serializerManager.encryptionEnabled && addedBids.nonEmpty
+    dataOut.writeBoolean(needsDecryptionServer)
+    dataOut.writeInt(cnt)
+    def sendBidsToRemove(): Unit = {
+      for (bid <- toRemove) {
+        // remove the broadcast from worker
+        dataOut.writeLong(-bid - 1) // bid >= 0
+        oldBids.remove(bid)
+      }
+    }
+    if (needsDecryptionServer) {
+      // if there is encryption, we setup a server which reads the encrypted files, and sends
+      // the decrypted data to python
+      val idsAndFiles = broadcastVars.flatMap { broadcast =>
+        if (!oldBids.contains(broadcast.id)) {
+          oldBids.add(broadcast.id)
+          Some((broadcast.id, broadcast.value.path))
+        } else {
+          None
+        }
+      }
+      val server = new EncryptedPythonBroadcastServer(env, idsAndFiles)
+      dataOut.writeInt(server.port)
+      logTrace(s"broadcast decryption server setup on ${server.port}")
+      writeUTF(server.secret, dataOut)
+      sendBidsToRemove()
+      idsAndFiles.foreach { case (id, _) =>
+        // send new broadcast
+        dataOut.writeLong(id)
+      }
+      dataOut.flush()
+      logTrace("waiting for python to read decrypted broadcast data from 
server")
+      server.waitTillBroadcastDataSent()
+      logTrace("done sending decrypted data to python")
+    } else {
+      sendBidsToRemove()
+      for (broadcast <- broadcastVars) {
+        if (!oldBids.contains(broadcast.id)) {
+          // send new broadcast
+          dataOut.writeLong(broadcast.id)
+          writeUTF(broadcast.value.path, dataOut)
+          oldBids.add(broadcast.id)
+        }
+      }
+    }
+    dataOut.flush()
+  }
+
+  /**
+   * Receive accumulator updates from the worker.
+   *
+   * The updates are sent by `worker_util.send_accumulator_updates`.
+   */
+  def receiveAccumulatorUpdates(
+      maybeAccumulator: Option[PythonAccumulatorV2], dataIn: DataInputStream): Unit = {
+    val numAccumulatorUpdates = dataIn.readInt()
+    (1 to numAccumulatorUpdates).foreach { _ =>
+      val updateLen = dataIn.readInt()
+      val update = new Array[Byte](updateLen)
+      dataIn.readFully(update)
+      maybeAccumulator.foreach(_.add(update))
+    }
+  }
+}
diff --git a/python/pyspark/sql/tests/connect/test_parity_udtf.py b/python/pyspark/sql/tests/connect/test_parity_udtf.py
index e18e116e003..1aff1bd0686 100644
--- a/python/pyspark/sql/tests/connect/test_parity_udtf.py
+++ b/python/pyspark/sql/tests/connect/test_parity_udtf.py
@@ -14,6 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import unittest
+
 from pyspark.testing.connectutils import should_test_connect
 
 if should_test_connect:
@@ -104,6 +106,23 @@ class UDTFParityTests(BaseUDTFTestsMixin, ReusedConnectTestCase):
         with self.assertRaisesRegex(SparkConnectGrpcException, err_msg):
             TestUDTF(lit(1)).show()
 
+    @unittest.skip("Spark Connect does not support broadcast but the test 
depends on it.")
+    def test_udtf_with_analyze_using_broadcast(self):
+        super().test_udtf_with_analyze_using_broadcast()
+
+    @unittest.skip("Spark Connect does not support accumulator but the test 
depends on it.")
+    def test_udtf_with_analyze_using_accumulator(self):
+        super().test_udtf_with_analyze_using_accumulator()
+
+    def _add_pyfile(self, path):
+        self.spark.addArtifacts(path, pyfile=True)
+
+    def _add_archive(self, path):
+        self.spark.addArtifacts(path, archive=True)
+
+    def _add_file(self, path):
+        self.spark.addArtifacts(path, file=True)
+
 
 class ArrowUDTFParityTests(UDTFArrowTestsMixin, UDTFParityTests):
     @classmethod
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index 688034b9930..e5b29b36034 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -14,7 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import os
+import shutil
+import tempfile
 import unittest
 
 from typing import Iterator
@@ -27,6 +29,7 @@ from pyspark.errors import (
     PySparkTypeError,
     AnalysisException,
 )
+from pyspark.files import SparkFiles
 from pyspark.rdd import PythonEvalType
 from pyspark.sql.functions import (
     array,
@@ -1222,6 +1225,225 @@ class BaseUDTFTestsMixin:
         ):
             self.spark.sql("SELECT * FROM test_udtf(1, 'x')").collect()
 
+    def test_udtf_with_analyze_using_broadcast(self):
+        colname = self.sc.broadcast("col1")
+
+        @udtf
+        class TestUDTF:
+            @staticmethod
+            def analyze(a: AnalyzeArgument) -> AnalyzeResult:
+                return AnalyzeResult(StructType().add(colname.value, a.data_type))
+
+            def eval(self, a):
+                assert colname.value == "col1"
+                yield a,
+
+            def terminate(self):
+                assert colname.value == "col1"
+                yield 100,
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+
+        for i, df in enumerate([TestUDTF(lit(10)), self.spark.sql("SELECT * FROM test_udtf(10)")]):
+            with self.subTest(query_no=i):
+                assertSchemaEqual(df.schema, StructType().add("col1", IntegerType()))
+                assertDataFrameEqual(df, [Row(col1=10), Row(col1=100)])
+
+    def test_udtf_with_analyze_using_accumulator(self):
+        test_accum = self.sc.accumulator(0)
+
+        @udtf
+        class TestUDTF:
+            @staticmethod
+            def analyze(a: AnalyzeArgument) -> AnalyzeResult:
+                test_accum.add(1)
+                return AnalyzeResult(StructType().add("col1", a.data_type))
+
+            def eval(self, a):
+                test_accum.add(10)
+                yield a,
+
+            def terminate(self):
+                test_accum.add(100)
+                yield 100,
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+
+        for i, df in enumerate([TestUDTF(lit(10)), self.spark.sql("SELECT * FROM test_udtf(10)")]):
+            with self.subTest(query_no=i):
+                assertSchemaEqual(df.schema, StructType().add("col1", IntegerType()))
+                assertDataFrameEqual(df, [Row(col1=10), Row(col1=100)])
+
+        self.assertEqual(test_accum.value, 222)
+
+    def _add_pyfile(self, path):
+        self.sc.addPyFile(path)
+
+    def test_udtf_with_analyze_using_pyfile(self):
+        with tempfile.TemporaryDirectory() as d:
+            pyfile_path = os.path.join(d, "my_pyfile.py")
+            with open(pyfile_path, "w") as f:
+                f.write("my_func = lambda: 'col1'")
+
+            self._add_pyfile(pyfile_path)
+
+            class TestUDTF:
+                @staticmethod
+                def call_my_func() -> str:
+                    import my_pyfile
+
+                    return my_pyfile.my_func()
+
+                @staticmethod
+                def analyze(a: AnalyzeArgument) -> AnalyzeResult:
+                    return AnalyzeResult(StructType().add(TestUDTF.call_my_func(), a.data_type))
+
+                def eval(self, a):
+                    assert TestUDTF.call_my_func() == "col1"
+                    yield a,
+
+                def terminate(self):
+                    assert TestUDTF.call_my_func() == "col1"
+                    yield 100,
+
+            test_udtf = udtf(TestUDTF)
+            self.spark.udtf.register("test_udtf", test_udtf)
+
+            for i, df in enumerate(
+                [test_udtf(lit(10)), self.spark.sql("SELECT * FROM test_udtf(10)")]
+            ):
+                with self.subTest(query_no=i):
+                    assertSchemaEqual(df.schema, StructType().add("col1", IntegerType()))
+                    assertDataFrameEqual(df, [Row(col1=10), Row(col1=100)])
+
+    def test_udtf_with_analyze_using_zipped_package(self):
+        with tempfile.TemporaryDirectory() as d:
+            package_path = os.path.join(d, "my_zipfile")
+            os.mkdir(package_path)
+            pyfile_path = os.path.join(package_path, "__init__.py")
+            with open(pyfile_path, "w") as f:
+                f.write("my_func = lambda: 'col1'")
+            shutil.make_archive(package_path, "zip", d, "my_zipfile")
+
+            self._add_pyfile(f"{package_path}.zip")
+
+            class TestUDTF:
+                @staticmethod
+                def call_my_func() -> str:
+                    import my_zipfile
+
+                    return my_zipfile.my_func()
+
+                @staticmethod
+                def analyze(a: AnalyzeArgument) -> AnalyzeResult:
+                    return AnalyzeResult(StructType().add(TestUDTF.call_my_func(), a.data_type))
+
+                def eval(self, a):
+                    assert TestUDTF.call_my_func() == "col1"
+                    yield a,
+
+                def terminate(self):
+                    assert TestUDTF.call_my_func() == "col1"
+                    yield 100,
+
+            test_udtf = udtf(TestUDTF)
+            self.spark.udtf.register("test_udtf", test_udtf)
+
+            for i, df in enumerate(
+                [test_udtf(lit(10)), self.spark.sql("SELECT * FROM test_udtf(10)")]
+            ):
+                with self.subTest(query_no=i):
+                    assertSchemaEqual(df.schema, StructType().add("col1", IntegerType()))
+                    assertDataFrameEqual(df, [Row(col1=10), Row(col1=100)])
+
+    def _add_archive(self, path):
+        self.sc.addArchive(path)
+
+    def test_udtf_with_analyze_using_archive(self):
+        with tempfile.TemporaryDirectory() as d:
+            archive_path = os.path.join(d, "my_archive")
+            os.mkdir(archive_path)
+            pyfile_path = os.path.join(archive_path, "my_file.txt")
+            with open(pyfile_path, "w") as f:
+                f.write("col1")
+            shutil.make_archive(archive_path, "zip", d, "my_archive")
+
+            self._add_archive(f"{archive_path}.zip#my_files")
+
+            class TestUDTF:
+                @staticmethod
+                def read_my_archive() -> str:
+                    with open(
+                        os.path.join(
+                            SparkFiles.getRootDirectory(), "my_files", "my_archive", "my_file.txt"
+                        ),
+                        "r",
+                    ) as my_file:
+                        return my_file.read().strip()
+
+                @staticmethod
+                def analyze(a: AnalyzeArgument) -> AnalyzeResult:
+                    return AnalyzeResult(StructType().add(TestUDTF.read_my_archive(), a.data_type))
+
+                def eval(self, a):
+                    assert TestUDTF.read_my_archive() == "col1"
+                    yield a,
+
+                def terminate(self):
+                    assert TestUDTF.read_my_archive() == "col1"
+                    yield 100,
+
+            test_udtf = udtf(TestUDTF)
+            self.spark.udtf.register("test_udtf", test_udtf)
+
+            for i, df in enumerate(
+                [test_udtf(lit(10)), self.spark.sql("SELECT * FROM test_udtf(10)")]
+            ):
+                with self.subTest(query_no=i):
+                    assertSchemaEqual(df.schema, StructType().add("col1", IntegerType()))
+                    assertDataFrameEqual(df, [Row(col1=10), Row(col1=100)])
+
+    def _add_file(self, path):
+        self.sc.addFile(path)
+
+    def test_udtf_with_analyze_using_file(self):
+        with tempfile.TemporaryDirectory() as d:
+            file_path = os.path.join(d, "my_file.txt")
+            with open(file_path, "w") as f:
+                f.write("col1")
+
+            self._add_file(file_path)
+
+            class TestUDTF:
+                @staticmethod
+                def read_my_file() -> str:
+                    with open(
+                        os.path.join(SparkFiles.getRootDirectory(), "my_file.txt"), "r"
+                    ) as my_file:
+                        return my_file.read().strip()
+
+                @staticmethod
+                def analyze(a: AnalyzeArgument) -> AnalyzeResult:
+                    return AnalyzeResult(StructType().add(TestUDTF.read_my_file(), a.data_type))
+
+                def eval(self, a):
+                    assert TestUDTF.read_my_file() == "col1"
+                    yield a,
+
+                def terminate(self):
+                    assert TestUDTF.read_my_file() == "col1"
+                    yield 100,
+
+            test_udtf = udtf(TestUDTF)
+            self.spark.udtf.register("test_udtf", test_udtf)
+
+            for i, df in enumerate(
+                [test_udtf(lit(10)), self.spark.sql("SELECT * FROM test_udtf(10)")]
+            ):
+                with self.subTest(query_no=i):
+                    assertSchemaEqual(df.schema, StructType().add("col1", IntegerType()))
+                    assertDataFrameEqual(df, [Row(col1=10), Row(col1=100)])
+
 
 class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/python/pyspark/sql/worker/analyze_udtf.py b/python/pyspark/sql/worker/analyze_udtf.py
index ea85d937705..44dcd8c892c 100644
--- a/python/pyspark/sql/worker/analyze_udtf.py
+++ b/python/pyspark/sql/worker/analyze_udtf.py
@@ -21,6 +21,7 @@ import sys
 import traceback
 from typing import List, IO
 
+from pyspark.accumulators import _accumulatorRegistry
 from pyspark.errors import PySparkRuntimeError, PySparkValueError
 from pyspark.java_gateway import local_connect_and_auth
 from pyspark.serializers import (
@@ -33,7 +34,15 @@ from pyspark.serializers import (
 from pyspark.sql.types import _parse_datatype_json_string
 from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult
 from pyspark.util import try_simplify_traceback
-from pyspark.worker import check_python_version, read_command, pickleSer, utf8_deserializer
+from pyspark.worker_util import (
+    check_python_version,
+    read_command,
+    pickleSer,
+    send_accumulator_updates,
+    setup_broadcasts,
+    setup_spark_files,
+    utf8_deserializer,
+)
 
 
 def read_udtf(infile: IO) -> type:
@@ -87,6 +96,11 @@ def main(infile: IO, outfile: IO) -> None:
     """
     try:
         check_python_version(infile)
+        setup_spark_files(infile)
+        setup_broadcasts(infile)
+
+        _accumulatorRegistry.clear()
+
         handler = read_udtf(infile)
         args = read_arguments(infile)
 
@@ -122,6 +136,8 @@ def main(infile: IO, outfile: IO) -> None:
             print(traceback.format_exc(), file=sys.stderr)
         sys.exit(-1)
 
+    send_accumulator_updates(outfile)
+
     # check end of stream
     if read_int(infile) == SpecialLengths.END_OF_STREAM:
         write_int(SpecialLengths.END_OF_STREAM, outfile)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 95924757242..bfe788faf6d 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -22,7 +22,6 @@ import os
 import sys
 import time
 from inspect import currentframe, getframeinfo, getfullargspec
-import importlib
 import json
 from typing import Iterator
 
@@ -37,10 +36,8 @@ import warnings
 import faulthandler
 
 from pyspark.accumulators import _accumulatorRegistry
-from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.java_gateway import local_connect_and_auth
 from pyspark.taskcontext import BarrierTaskContext, TaskContext
-from pyspark.files import SparkFiles
 from pyspark.resource import ResourceInformation
 from pyspark.rdd import PythonEvalType
 from pyspark.serializers import (
@@ -67,9 +64,15 @@ from pyspark.sql.types import StructType, _parse_datatype_json_string
 from pyspark.util import fail_on_stopiteration, try_simplify_traceback
 from pyspark import shuffle
 from pyspark.errors import PySparkRuntimeError, PySparkTypeError
-
-pickleSer = CPickleSerializer()
-utf8_deserializer = UTF8Deserializer()
+from pyspark.worker_util import (
+    check_python_version,
+    read_command,
+    pickleSer,
+    send_accumulator_updates,
+    setup_broadcasts,
+    setup_spark_files,
+    utf8_deserializer,
+)
 
 
 def report_times(outfile, boot, init, finish):
@@ -79,20 +82,6 @@ def report_times(outfile, boot, init, finish):
     write_long(int(1000 * finish), outfile)
 
 
-def add_path(path):
-    # worker can be used, so do not add path multiple times
-    if path not in sys.path:
-        # overwrite system packages
-        sys.path.insert(1, path)
-
-
-def read_command(serializer, file):
-    command = serializer._read_with_length(file)
-    if isinstance(command, Broadcast):
-        command = serializer.loads(command.value)
-    return command
-
-
 def chain(f, g):
     """chain two functions together"""
     return lambda *a: g(f(*a))
@@ -970,21 +959,6 @@ def read_udfs(pickleSer, infile, eval_type):
     return func, None, ser, ser
 
 
-def check_python_version(infile):
-    """
-    Check the Python version between the running process and the one used to serialize the command.
-    """
-    version = utf8_deserializer.loads(infile)
-    if version != "%d.%d" % sys.version_info[:2]:
-        raise PySparkRuntimeError(
-            error_class="PYTHON_VERSION_MISMATCH",
-            message_parameters={
-                "worker_version": str(sys.version_info[:2]),
-                "driver_version": str(version),
-            },
-        )
-
-
 def main(infile, outfile):
     faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
     try:
@@ -1074,47 +1048,8 @@ def main(infile, outfile):
         shuffle.DiskBytesSpilled = 0
         _accumulatorRegistry.clear()
 
-        # fetch name of workdir
-        spark_files_dir = utf8_deserializer.loads(infile)
-        SparkFiles._root_directory = spark_files_dir
-        SparkFiles._is_running_on_worker = True
-
-        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
-        add_path(spark_files_dir)  # *.py files that were added will be copied here
-        num_python_includes = read_int(infile)
-        for _ in range(num_python_includes):
-            filename = utf8_deserializer.loads(infile)
-            add_path(os.path.join(spark_files_dir, filename))
-
-        importlib.invalidate_caches()
-
-        # fetch names and values of broadcast variables
-        needs_broadcast_decryption_server = read_bool(infile)
-        num_broadcast_variables = read_int(infile)
-        if needs_broadcast_decryption_server:
-            # read the decrypted data from a server in the jvm
-            port = read_int(infile)
-            auth_secret = utf8_deserializer.loads(infile)
-            (broadcast_sock_file, _) = local_connect_and_auth(port, auth_secret)
-
-        for _ in range(num_broadcast_variables):
-            bid = read_long(infile)
-            if bid >= 0:
-                if needs_broadcast_decryption_server:
-                    read_bid = read_long(broadcast_sock_file)
-                    assert read_bid == bid
-                    _broadcastRegistry[bid] = Broadcast(sock_file=broadcast_sock_file)
-                else:
-                    path = utf8_deserializer.loads(infile)
-                    _broadcastRegistry[bid] = Broadcast(path=path)
-
-            else:
-                bid = -bid - 1
-                _broadcastRegistry.pop(bid)
-
-        if needs_broadcast_decryption_server:
-            broadcast_sock_file.write(b"1")
-            broadcast_sock_file.close()
+        setup_spark_files(infile)
+        setup_broadcasts(infile)
 
         _accumulatorRegistry.clear()
         eval_type = read_int(infile)
@@ -1178,9 +1113,7 @@ def main(infile, outfile):
 
     # Mark the beginning of the accumulators section of the output
     write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
-    write_int(len(_accumulatorRegistry), outfile)
-    for (aid, accum) in _accumulatorRegistry.items():
-        pickleSer._write_with_length((aid, accum._value), outfile)
+    send_accumulator_updates(outfile)
 
     # check end of stream
     if read_int(infile) == SpecialLengths.END_OF_STREAM:
diff --git a/python/pyspark/worker_util.py b/python/pyspark/worker_util.py
new file mode 100644
index 00000000000..eab0daf8f59
--- /dev/null
+++ b/python/pyspark/worker_util.py
@@ -0,0 +1,132 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Util functions for workers.
+"""
+import importlib
+import os
+import sys
+from typing import Any, IO
+
+from pyspark.accumulators import _accumulatorRegistry
+from pyspark.broadcast import Broadcast, _broadcastRegistry
+from pyspark.errors import PySparkRuntimeError
+from pyspark.files import SparkFiles
+from pyspark.java_gateway import local_connect_and_auth
+from pyspark.serializers import (
+    read_bool,
+    read_int,
+    read_long,
+    write_int,
+    FramedSerializer,
+    UTF8Deserializer,
+    CPickleSerializer,
+)
+
+pickleSer = CPickleSerializer()
+utf8_deserializer = UTF8Deserializer()
+
+
+def add_path(path: str) -> None:
+    # worker can be used, so do not add path multiple times
+    if path not in sys.path:
+        # overwrite system packages
+        sys.path.insert(1, path)
+
+
+def read_command(serializer: FramedSerializer, file: IO) -> Any:
+    command = serializer._read_with_length(file)
+    if isinstance(command, Broadcast):
+        command = serializer.loads(command.value)
+    return command
+
+
+def check_python_version(infile: IO) -> None:
+    """
+    Check the Python version between the running process and the one used to serialize the command.
+    """
+    version = utf8_deserializer.loads(infile)
+    if version != "%d.%d" % sys.version_info[:2]:
+        raise PySparkRuntimeError(
+            error_class="PYTHON_VERSION_MISMATCH",
+            message_parameters={
+                "worker_version": str(sys.version_info[:2]),
+                "driver_version": str(version),
+            },
+        )
+
+
+def setup_spark_files(infile: IO) -> None:
+    """
+    Set up Spark files, archives, and pyfiles.
+    """
+    # fetch name of workdir
+    spark_files_dir = utf8_deserializer.loads(infile)
+    SparkFiles._root_directory = spark_files_dir
+    SparkFiles._is_running_on_worker = True
+
+    # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
+    add_path(spark_files_dir)  # *.py files that were added will be copied here
+    num_python_includes = read_int(infile)
+    for _ in range(num_python_includes):
+        filename = utf8_deserializer.loads(infile)
+        add_path(os.path.join(spark_files_dir, filename))
+
+    importlib.invalidate_caches()
+
+
+def setup_broadcasts(infile: IO) -> None:
+    """
+    Set up broadcasted variables.
+    """
+    # fetch names and values of broadcast variables
+    needs_broadcast_decryption_server = read_bool(infile)
+    num_broadcast_variables = read_int(infile)
+    if needs_broadcast_decryption_server:
+        # read the decrypted data from a server in the jvm
+        port = read_int(infile)
+        auth_secret = utf8_deserializer.loads(infile)
+        (broadcast_sock_file, _) = local_connect_and_auth(port, auth_secret)
+
+    for _ in range(num_broadcast_variables):
+        bid = read_long(infile)
+        if bid >= 0:
+            if needs_broadcast_decryption_server:
+                read_bid = read_long(broadcast_sock_file)
+                assert read_bid == bid
+                _broadcastRegistry[bid] = Broadcast(sock_file=broadcast_sock_file)
+            else:
+                path = utf8_deserializer.loads(infile)
+                _broadcastRegistry[bid] = Broadcast(path=path)
+
+        else:
+            bid = -bid - 1
+            _broadcastRegistry.pop(bid)
+
+    if needs_broadcast_decryption_server:
+        broadcast_sock_file.write(b"1")
+        broadcast_sock_file.close()
+
+
+def send_accumulator_updates(outfile: IO) -> None:
+    """
+    Send the accumulator updates back to JVM.
+    """
+    write_int(len(_accumulatorRegistry), outfile)
+    for (aid, accum) in _accumulatorRegistry.items():
+        pickleSer._write_with_length((aid, accum._value), outfile)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
index 7d9fdb3bf48..9dae874e3ed 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
@@ -19,14 +19,13 @@ package org.apache.spark.sql.execution.python
 
 import java.io.DataOutputStream
 import java.net.Socket
-import java.nio.charset.StandardCharsets.UTF_8
 
 import scala.collection.JavaConverters._
 
 import net.razorvine.pickle.Unpickler
 
 import org.apache.spark.{JobArtifactSet, SparkEnv, TaskContext}
-import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonWorkerUtils}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.GenericArrayData
@@ -126,8 +125,6 @@ object PythonUDTFRunner {
     }
     dataOut.writeInt(udtf.func.command.length)
     dataOut.write(udtf.func.command.toArray)
-    val schemaBytes = udtf.elementSchema.json.getBytes(UTF_8)
-    dataOut.writeInt(schemaBytes.length)
-    dataOut.write(schemaBytes)
+    PythonWorkerUtils.writeUTF(udtf.elementSchema.json, dataOut)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index d683f865dcd..05239d8d164 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -26,8 +26,8 @@ import scala.collection.JavaConverters._
 
 import net.razorvine.pickle.Pickler
 
-import org.apache.spark.{SparkEnv, SparkException}
-import org.apache.spark.api.python.{PythonEvalType, PythonFunction, PythonRDD, SpecialLengths}
+import org.apache.spark.{JobArtifactSet, SparkEnv, SparkException}
+import org.apache.spark.api.python.{PythonEvalType, PythonFunction, PythonWorkerUtils, SpecialLengths}
 import org.apache.spark.internal.config.BUFFER_SIZE
 import org.apache.spark.internal.config.Python._
 import org.apache.spark.sql.{Column, DataFrame, Dataset, SparkSession}
@@ -173,12 +173,19 @@ object UserDefinedPythonTableFunction {
     val bufferSize: Int = env.conf.get(BUFFER_SIZE)
     val authSocketTimeout = env.conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
     val reuseWorker = env.conf.get(PYTHON_WORKER_REUSE)
+    val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
     val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
 
+    val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+
     val envVars = new HashMap[String, String](func.envVars)
     val pythonExec = func.pythonExec
     val pythonVer = func.pythonVer
+    val pythonIncludes = func.pythonIncludes.asScala.toSet
+    val broadcastVars = func.broadcastVars.asScala.toSeq
+    val maybeAccumulator = Option(func.accumulator).map(_.copyAndReset())
 
+    envVars.put("SPARK_LOCAL_DIRS", localdir)
     if (reuseWorker) {
       envVars.put("SPARK_REUSE_WORKER", "1")
     }
@@ -188,6 +195,8 @@ object UserDefinedPythonTableFunction {
     envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
     envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
 
+    envVars.put("SPARK_JOB_ARTIFACT_UUID", 
jobArtifactUUID.getOrElse("default"))
+
     EvaluatePython.registerPicklers()
     val pickler = new Pickler(/* useMemo = */ true,
       /* valueCompare = */ false)
@@ -200,8 +209,9 @@ object UserDefinedPythonTableFunction {
        new DataOutputStream(new BufferedOutputStream(worker.getOutputStream, bufferSize))
      val dataIn = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
 
-      // Python version of driver
-      PythonRDD.writeUTF(pythonVer, dataOut)
+      PythonWorkerUtils.writePythonVersion(pythonVer, dataOut)
+      PythonWorkerUtils.writeSparkFiles(jobArtifactUUID, pythonIncludes, dataOut)
+      PythonWorkerUtils.writeBroadcasts(broadcastVars, worker, env, dataOut)
 
       // Send Python UDTF
       dataOut.writeInt(func.command.length)
@@ -210,7 +220,7 @@ object UserDefinedPythonTableFunction {
       // Send arguments
       dataOut.writeInt(exprs.length)
       exprs.zip(tableArgs).foreach { case (expr, is_table) =>
-        PythonRDD.writeUTF(expr.dataType.json, dataOut)
+        PythonWorkerUtils.writeUTF(expr.dataType.json, dataOut)
         if (expr.foldable) {
           dataOut.writeBoolean(true)
          val obj = pickler.dumps(EvaluatePython.toJava(expr.eval(), expr.dataType))
@@ -240,6 +250,9 @@ object UserDefinedPythonTableFunction {
          throw QueryCompilationErrors.tableValuedFunctionFailedToAnalyseInPythonError(msg)
       }
 
+      PythonWorkerUtils.receiveAccumulatorUpdates(maybeAccumulator, dataIn)
+      Option(func.accumulator).foreach(_.merge(maybeAccumulator.get))
+
       dataIn.readInt() match {
         case SpecialLengths.END_OF_STREAM if reuseWorker =>
          env.releasePythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

