This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 71d59a85081 [SPARK-43768][PYTHON][CONNECT] Python dependency management support in Python Spark Connect
71d59a85081 is described below

commit 71d59a85081f20cd179f5282e19aebcefa59174b
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Thu May 25 20:03:37 2023 +0900

    [SPARK-43768][PYTHON][CONNECT] Python dependency management support in Python Spark Connect
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to add support for archives (`.zip`, `.jar`, `.tar.gz`, `.tgz`, or `.tar` files) in `SparkSession.addArtifacts` so we can support Python dependency management in Python Spark Connect.
    
    ### Why are the changes needed?
    
    This allows end users to add dependencies and archive files from the Python Spark Connect client.
    
    This PR enables the Python dependency management (https://www.databricks.com/blog/2020/12/22/how-to-manage-python-dependencies-in-pyspark.html) use case in Spark Connect.
    
    See below for how to do this with the Spark Connect Python client:
    
    #### Precondition
    
    Assume that we have a Spark Connect server already running, e.g., by:
    
    ```bash
    ./sbin/start-connect-server.sh --jars `ls connector/connect/server/target/**/spark-connect*SNAPSHOT.jar` --master "local-cluster[2,2,1024]"
    ```
    
    and assume that you already have a development environment set up:
    
    ```bash
    # Notice that you should install `conda-pack`.
    conda create -y -n pyspark_conda_env -c conda-forge conda-pack python=3.9
    conda activate pyspark_conda_env
    pip install --upgrade -r dev/requirements.txt
    ```
    
    #### Dependency management
    
    ```bash
    ./bin/pyspark --remote "sc://localhost:15002"
    ```
    
    ```python
    import conda_pack
    import os
    # Pack the current environment ('pyspark_conda_env') to 'pyspark_conda_env.tar.gz'.
    # Or you can run 'conda pack' in your shell.
    conda_pack.pack()
    spark.addArtifact(f"{os.environ.get('CONDA_DEFAULT_ENV')}.tar.gz#environment", archive=True)
    spark.conf.set("spark.sql.execution.pyspark.python", "environment/bin/python")
    # From now on, Python workers on executors use the `pyspark_conda_env` Conda environment.
    ```
    
    Run your Python UDFs:
    
    ```python
    import pandas as pd
    from pyspark.sql.functions import pandas_udf
    
    @pandas_udf("long")
    def plug_one(s: pd.Series) -> pd.Series:
        return s + 1
    
    spark.range(10).select(plug_one("id")).show()
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, it adds support for archives (`.zip`, `.jar`, `.tar.gz`, `.tgz`, or `.tar` files) in `SparkSession.addArtifacts`.
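    
    For example, under the same assumptions as the walkthrough above (a packed Conda environment tarball already exists locally; the tarball name and the `environment` fragment are illustrative), the new flag can be used as follows:
    
    ```python
    # Ship a locally packed Conda environment; the "#environment" fragment names the
    # directory the archive is unpacked into on the executors.
    spark.addArtifact("pyspark_conda_env.tar.gz#environment", archive=True)
    # Point the Python workers on the executors at the interpreter inside the unpacked archive.
    spark.conf.set("spark.sql.execution.pyspark.python", "environment/bin/python")
    ```
    
    Note that `pyfile=True` and `archive=True` cannot be combined in a single call; `addArtifacts` raises a `ValueError` in that case.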
    
    ### How was this patch tested?
    
    Manually tested as described above, and added a unit test.
    
    Also manually tested in `local-cluster` mode, and verified via the code below:
    
    ```python
    import sys
    from pyspark.sql.functions import udf
    
    spark.range(1).select(udf(lambda x: sys.executable)("id")).show(truncate=False)
    ```
    ```
    +----------------------------------------------------------------+
    |<lambda>(id)                                                    |
    +----------------------------------------------------------------+
    |/.../spark/work/app-20230524132024-0000/1/environment/bin/python|
    +----------------------------------------------------------------+
    ```
    
    Closes #41292 from HyukjinKwon/python-addArchive.
    
    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../artifact/SparkConnectArtifactManager.scala     | 22 ++++++---
 .../service/SparkConnectAddArtifactsHandler.scala  | 19 +++++++-
 .../connect/artifact/ArtifactManagerSuite.scala    | 12 ++---
 python/pyspark/sql/connect/client/artifact.py      | 52 +++++++++++++++++-----
 python/pyspark/sql/connect/client/core.py          |  4 +-
 python/pyspark/sql/connect/session.py              | 11 ++++-
 .../sql/tests/connect/client/test_artifact.py      | 44 +++++++++++++++---
 7 files changed, 130 insertions(+), 34 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
index 7a36c46c672..604108f68d2 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
@@ -17,9 +17,11 @@
 
 package org.apache.spark.sql.connect.artifact
 
+import java.io.File
 import java.net.{URL, URLClassLoader}
 import java.nio.file.{Files, Path, Paths, StandardCopyOption}
 import java.util.concurrent.CopyOnWriteArrayList
+import javax.ws.rs.core.UriBuilder
 
 import scala.collection.JavaConverters._
 import scala.reflect.ClassTag
@@ -99,16 +101,17 @@ class SparkConnectArtifactManager private[connect] {
   private[connect] def addArtifact(
       sessionHolder: SessionHolder,
       remoteRelativePath: Path,
-      serverLocalStagingPath: Path): Unit = {
+      serverLocalStagingPath: Path,
+      fragment: Option[String]): Unit = {
     require(!remoteRelativePath.isAbsolute)
-    if (remoteRelativePath.startsWith("cache/")) {
+    if (remoteRelativePath.startsWith(s"cache${File.separator}")) {
       val tmpFile = serverLocalStagingPath.toFile
       Utils.tryWithSafeFinallyAndFailureCallbacks {
         val blockManager = sessionHolder.session.sparkContext.env.blockManager
         val blockId = CacheId(
           userId = sessionHolder.userId,
           sessionId = sessionHolder.sessionId,
-          hash = remoteRelativePath.toString.stripPrefix("cache/"))
+          hash = remoteRelativePath.toString.stripPrefix(s"cache${File.separator}"))
         val updater = blockManager.TempFileBasedBlockStoreUpdater(
           blockId = blockId,
           level = StorageLevel.MEMORY_AND_DISK_SER,
@@ -118,9 +121,10 @@ class SparkConnectArtifactManager private[connect] {
           tellMaster = false)
         updater.save()
       }(catchBlock = { tmpFile.delete() })
-    } else if (remoteRelativePath.startsWith("classes/")) {
+    } else if (remoteRelativePath.startsWith(s"classes${File.separator}")) {
       // Move class files to common location (shared among all users)
-      val target = classArtifactDir.resolve(remoteRelativePath.toString.stripPrefix("classes/"))
+      val target = classArtifactDir.resolve(
+        remoteRelativePath.toString.stripPrefix(s"classes${File.separator}"))
       Files.createDirectories(target.getParent)
       // Allow overwriting class files to capture updates to classes.
       Files.move(serverLocalStagingPath, target, StandardCopyOption.REPLACE_EXISTING)
@@ -135,17 +139,21 @@ class SparkConnectArtifactManager private[connect] {
             s"Jars cannot be overwritten.")
       }
       Files.move(serverLocalStagingPath, target)
-      if (remoteRelativePath.startsWith("jars/")) {
+      if (remoteRelativePath.startsWith(s"jars${File.separator}")) {
         // Adding Jars to the underlying spark context (visible to all users)
         sessionHolder.session.sessionState.resourceLoader.addJar(target.toString)
         jarsList.add(target)
-      } else if (remoteRelativePath.startsWith("pyfiles/")) {
+      } else if (remoteRelativePath.startsWith(s"pyfiles${File.separator}")) {
         sessionHolder.session.sparkContext.addFile(target.toString)
         val stringRemotePath = remoteRelativePath.toString
         if (stringRemotePath.endsWith(".zip") || stringRemotePath.endsWith(
             ".egg") || stringRemotePath.endsWith(".jar")) {
           pythonIncludeList.add(target.getFileName.toString)
         }
+      } else if (remoteRelativePath.startsWith(s"archives${File.separator}")) {
+        val canonicalUri =
+          fragment.map(UriBuilder.fromUri(target.toUri).fragment).getOrElse(target.toUri)
+        sessionHolder.session.sparkContext.addArchive(canonicalUri.toString)
       }
     }
   }
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
index 99e92e42fff..f8bdb58ed85 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
@@ -16,6 +16,7 @@
  */
 package org.apache.spark.sql.connect.service
 
+import java.io.File
 import java.nio.file.{Files, Path, Paths}
 import java.util.zip.{CheckedOutputStream, CRC32}
 
@@ -85,7 +86,7 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr
   }
 
   protected def addStagedArtifactToArtifactManager(artifact: StagedArtifact): Unit = {
-    artifactManager.addArtifact(holder, artifact.path, artifact.stagedPath)
+    artifactManager.addArtifact(holder, artifact.path, artifact.stagedPath, artifact.fragment)
   }
 
   /**
@@ -148,7 +149,21 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr
    * Handles rebuilding an artifact from bytes sent over the wire.
    */
   class StagedArtifact(val name: String) {
-    val path: Path = Paths.get(name)
+    // Workaround to keep the fragment.
+    val (canonicalFileName: String, fragment: Option[String]) =
+      if (name.startsWith(s"archives${File.separator}")) {
+        val splits = name.split("#")
+        assert(splits.length <= 2, "'#' in the path is not supported for adding an archive.")
+        if (splits.length == 2) {
+          (splits(0), Some(splits(1)))
+        } else {
+          (splits(0), None)
+        }
+      } else {
+        (name, None)
+      }
+
+    val path: Path = Paths.get(canonicalFileName)
     val stagedPath: Path = stagingDir.resolve(path)
 
     Files.createDirectories(stagedPath.getParent)
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
index 291eadb07c4..b87c6742bdc 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
@@ -48,7 +48,7 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
     FileUtils.copyDirectory(artifactPath.toFile, copyDir.toFile)
     val stagingPath = copyDir.resolve("smallJar.jar")
     val remotePath = Paths.get("jars/smallJar.jar")
-    artifactManager.addArtifact(sessionHolder, remotePath, stagingPath)
+    artifactManager.addArtifact(sessionHolder, remotePath, stagingPath, None)
 
     val jarList = spark.sparkContext.listJars()
     assert(jarList.exists(_.contains(remotePath.toString)))
@@ -60,7 +60,7 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
     val stagingPath = copyDir.resolve("smallClassFile.class")
     val remotePath = Paths.get("classes/smallClassFile.class")
     assert(stagingPath.toFile.exists())
-    artifactManager.addArtifact(sessionHolder, remotePath, stagingPath)
+    artifactManager.addArtifact(sessionHolder, remotePath, stagingPath, None)
 
     val classFileDirectory = artifactManager.classArtifactDir
     val movedClassFile = classFileDirectory.resolve("smallClassFile.class").toFile
@@ -73,7 +73,7 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
     val stagingPath = copyDir.resolve("Hello.class")
     val remotePath = Paths.get("classes/Hello.class")
     assert(stagingPath.toFile.exists())
-    artifactManager.addArtifact(sessionHolder, remotePath, stagingPath)
+    artifactManager.addArtifact(sessionHolder, remotePath, stagingPath, None)
 
     val classFileDirectory = artifactManager.classArtifactDir
     val movedClassFile = classFileDirectory.resolve("Hello.class").toFile
@@ -96,7 +96,7 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
     val stagingPath = copyDir.resolve("Hello.class")
     val remotePath = Paths.get("classes/Hello.class")
     assert(stagingPath.toFile.exists())
-    artifactManager.addArtifact(sessionHolder, remotePath, stagingPath)
+    artifactManager.addArtifact(sessionHolder, remotePath, stagingPath, None)
 
     val classFileDirectory = artifactManager.classArtifactDir
     val movedClassFile = classFileDirectory.resolve("Hello.class").toFile
@@ -123,7 +123,7 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
       val blockManager = spark.sparkContext.env.blockManager
       val blockId = CacheId(session.userId, session.sessionId, "abc")
       try {
-        artifactManager.addArtifact(session, remotePath, stagingPath)
+        artifactManager.addArtifact(session, remotePath, stagingPath, None)
         val bytes = blockManager.getLocalBytes(blockId)
         assert(bytes.isDefined)
         val readback = new String(bytes.get.toByteBuffer().array(), StandardCharsets.UTF_8)
@@ -141,7 +141,7 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
       Files.write(path.toPath, "test".getBytes(StandardCharsets.UTF_8))
       val session = sessionHolder()
       val remotePath = Paths.get("pyfiles/abc.zip")
-      artifactManager.addArtifact(session, remotePath, stagingPath)
+      artifactManager.addArtifact(session, remotePath, stagingPath, None)
       assert(artifactManager.getSparkConnectPythonIncludes == Seq("abc.zip"))
     }
   }
diff --git a/python/pyspark/sql/connect/client/artifact.py b/python/pyspark/sql/connect/client/artifact.py
index f06277e5068..64f89119e4f 100644
--- a/python/pyspark/sql/connect/client/artifact.py
+++ b/python/pyspark/sql/connect/client/artifact.py
@@ -38,6 +38,7 @@ import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib
 
 JAR_PREFIX: str = "jars"
 PYFILE_PREFIX: str = "pyfiles"
+ARCHIVE_PREFIX: str = "archives"
 
 
 class LocalData(metaclass=abc.ABCMeta):
@@ -102,6 +103,10 @@ def new_pyfile_artifact(file_name: str, storage: LocalData) -> Artifact:
     return _new_artifact(PYFILE_PREFIX, "", file_name, storage)
 
 
+def new_archive_artifact(file_name: str, storage: LocalData) -> Artifact:
+    return _new_artifact(ARCHIVE_PREFIX, "", file_name, storage)
+
+
 def _new_artifact(
     prefix: str, required_suffix: str, file_name: str, storage: LocalData
 ) -> Artifact:
@@ -136,12 +141,16 @@ class ArtifactManager:
         self._stub = grpc_lib.SparkConnectServiceStub(channel)
         self._session_id = session_id
 
-    def _parse_artifacts(self, path_or_uri: str, pyfile: bool) -> List[Artifact]:
+    def _parse_artifacts(self, path_or_uri: str, pyfile: bool, archive: bool) -> List[Artifact]:
         # Currently only local files with .jar extension is supported.
-        uri = path_or_uri
-        if urlparse(path_or_uri).scheme == "":  # Is path?
-            uri = Path(path_or_uri).absolute().as_uri()
-        parsed = urlparse(uri)
+        parsed = urlparse(path_or_uri)
+        # Check if it is a file from the scheme
+        if parsed.scheme == "":
+            # Similar with Utils.resolveURI.
+            fragment = parsed.fragment
+            parsed = urlparse(Path(url2pathname(parsed.path)).absolute().as_uri())
+            parsed = parsed._replace(fragment=fragment)
+
         if parsed.scheme == "file":
             local_path = url2pathname(parsed.path)
             name = Path(local_path).name
@@ -154,16 +163,37 @@ class ArtifactManager:
                 sys.path.insert(1, local_path)
                 artifact = new_pyfile_artifact(name, LocalFile(local_path))
                 importlib.invalidate_caches()
+            elif archive and (
+                name.endswith(".zip")
+                or name.endswith(".jar")
+                or name.endswith(".tar.gz")
+                or name.endswith(".tgz")
+                or name.endswith(".tar")
+            ):
+                assert any(name.endswith(s) for s in (".zip", ".jar", ".tar.gz", ".tgz", ".tar"))
+
+                if parsed.fragment != "":
+                    # Minimal fix for the workaround of fragment handling in URI.
+                    # This has a limitation - hash(#) in the file name would not work.
+                    if "#" in local_path:
+                        raise ValueError("'#' in the path is not supported for adding an archive.")
+                    name = f"{name}#{parsed.fragment}"
+
+                artifact = new_archive_artifact(name, LocalFile(local_path))
             elif name.endswith(".jar"):
                 artifact = new_jar_artifact(name, LocalFile(local_path))
             else:
                 raise RuntimeError(f"Unsupported file format: {local_path}")
             return [artifact]
-        raise RuntimeError(f"Unsupported scheme: {uri}")
+        raise RuntimeError(f"Unsupported scheme: {parsed.scheme}")
 
-    def _create_requests(self, *path: str, pyfile: bool) -> Iterator[proto.AddArtifactsRequest]:
+    def _create_requests(
+        self, *path: str, pyfile: bool, archive: bool
+    ) -> Iterator[proto.AddArtifactsRequest]:
         """Separated for the testing purpose."""
-        return self._add_artifacts(chain(*(self._parse_artifacts(p, pyfile=pyfile) for p in path)))
+        return self._add_artifacts(
+            chain(*(self._parse_artifacts(p, pyfile=pyfile, archive=archive) for p in path))
+        )
 
     def _retrieve_responses(
         self, requests: Iterator[proto.AddArtifactsRequest]
@@ -171,12 +201,14 @@ class ArtifactManager:
         """Separated for the testing purpose."""
         return self._stub.AddArtifacts(requests)
 
-    def add_artifacts(self, *path: str, pyfile: bool) -> None:
+    def add_artifacts(self, *path: str, pyfile: bool, archive: bool) -> None:
         """
         Add a single artifact to the session.
         Currently only local files with .jar extension is supported.
         """
-        requests: Iterator[proto.AddArtifactsRequest] = self._create_requests(*path, pyfile=pyfile)
+        requests: Iterator[proto.AddArtifactsRequest] = self._create_requests(
+            *path, pyfile=pyfile, archive=archive
+        )
         response: proto.AddArtifactsResponse = self._retrieve_responses(requests)
         summaries: List[proto.AddArtifactsResponse.ArtifactSummary] = []
 
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index e93d9b5c494..544ed5d4183 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -1237,8 +1237,8 @@ class SparkConnectClient(object):
         else:
             raise SparkConnectGrpcException(str(rpc_error)) from None
 
-    def add_artifacts(self, *path: str, pyfile: bool) -> None:
-        self._artifact_manager.add_artifacts(*path, pyfile=pyfile)
+    def add_artifacts(self, *path: str, pyfile: bool, archive: bool) -> None:
+        self._artifact_manager.add_artifacts(*path, pyfile=pyfile, archive=archive)
 
 
 class RetryState:
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 3341b88eded..7932ab54081 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -601,7 +601,7 @@ class SparkSession:
         """
         return self._client
 
-    def addArtifacts(self, *path: str, pyfile: bool = False) -> None:
+    def addArtifacts(self, *path: str, pyfile: bool = False, archive: bool = False) -> None:
         """
         Add artifact(s) to the client session. Currently only local files are supported.
 
@@ -613,8 +613,15 @@ class SparkSession:
             Artifact's URIs to add.
         pyfile : bool
             Whether to add them as Python dependencies such as .py, .egg, .zip or .jar files.
+            The pyfiles are directly inserted into the path when executing Python functions
+            in executors.
+        archive : bool
+            Whether to add them as archives such as .zip, .jar, .tar.gz, .tgz, or .tar files.
+            The archives are unpacked on the executor side automatically.
         """
-        self._client.add_artifacts(*path, pyfile=pyfile)
+        if pyfile and archive:
+            raise ValueError("'pyfile' and 'archive' cannot be True together.")
+        self._client.add_artifacts(*path, pyfile=pyfile, archive=archive)
 
     addArtifact = addArtifacts
 
diff --git a/python/pyspark/sql/tests/connect/client/test_artifact.py b/python/pyspark/sql/tests/connect/client/test_artifact.py
index 73f47486bab..2bff3fd5bc4 100644
--- a/python/pyspark/sql/tests/connect/client/test_artifact.py
+++ b/python/pyspark/sql/tests/connect/client/test_artifact.py
@@ -21,6 +21,7 @@ import os
 
 from pyspark.testing.connectutils import ReusedConnectTestCase, should_test_connect
 from pyspark.testing.utils import SPARK_HOME
+from pyspark import SparkFiles
 from pyspark.sql.functions import udf
 
 if should_test_connect:
@@ -48,7 +49,7 @@ class ArtifactTests(ReusedConnectTestCase):
         file_name = "smallJar"
         small_jar_path = os.path.join(self.artifact_file_path, f"{file_name}.jar")
         response = self.artifact_manager._retrieve_responses(
-            self.artifact_manager._create_requests(small_jar_path, pyfile=False)
+            self.artifact_manager._create_requests(small_jar_path, pyfile=False, archive=False)
         )
         self.assertTrue(response.artifacts[0].name.endswith(f"{file_name}.jar"))
 
@@ -57,7 +58,9 @@ class ArtifactTests(ReusedConnectTestCase):
         small_jar_path = os.path.join(self.artifact_file_path, f"{file_name}.jar")
         small_jar_crc_path = os.path.join(self.artifact_crc_path, f"{file_name}.txt")
 
-        requests = list(self.artifact_manager._create_requests(small_jar_path, pyfile=False))
+        requests = list(
+            self.artifact_manager._create_requests(small_jar_path, pyfile=False, archive=False)
+        )
         self.assertEqual(len(requests), 1)
 
         request = requests[0]
@@ -79,7 +82,9 @@ class ArtifactTests(ReusedConnectTestCase):
         large_jar_path = os.path.join(self.artifact_file_path, f"{file_name}.jar")
         large_jar_crc_path = os.path.join(self.artifact_crc_path, f"{file_name}.txt")
 
-        requests = list(self.artifact_manager._create_requests(large_jar_path, pyfile=False))
+        requests = list(
+            self.artifact_manager._create_requests(large_jar_path, pyfile=False, archive=False)
+        )
         # Expected chunks = roundUp( file_size / chunk_size) = 12
         # File size of `junitLargeJar.jar` is 384581 bytes.
         large_jar_size = os.path.getsize(large_jar_path)
@@ -111,7 +116,9 @@ class ArtifactTests(ReusedConnectTestCase):
         small_jar_crc_path = os.path.join(self.artifact_crc_path, f"{file_name}.txt")
 
         requests = list(
-            self.artifact_manager._create_requests(small_jar_path, small_jar_path, pyfile=False)
+            self.artifact_manager._create_requests(
+                small_jar_path, small_jar_path, pyfile=False, archive=False
+            )
         )
         # Single request containing 2 artifacts.
         self.assertEqual(len(requests), 1)
@@ -147,7 +154,12 @@ class ArtifactTests(ReusedConnectTestCase):
 
         requests = list(
             self.artifact_manager._create_requests(
-                small_jar_path, large_jar_path, small_jar_path, small_jar_path, pyfile=False
+                small_jar_path,
+                large_jar_path,
+                small_jar_path,
+                small_jar_path,
+                pyfile=False,
+                archive=False,
             )
         )
         # There are a total of 14 requests.
@@ -237,6 +249,28 @@ class ArtifactTests(ReusedConnectTestCase):
             self.spark.addArtifacts(f"{package_path}.zip", pyfile=True)
             self.assertEqual(self.spark.range(1).select(func("id")).first()[0], 5)
 
+    def test_add_archive(self):
+        with tempfile.TemporaryDirectory() as d:
+            archive_path = os.path.join(d, "my_archive")
+            os.mkdir(archive_path)
+            pyfile_path = os.path.join(archive_path, "my_file.txt")
+            with open(pyfile_path, "w") as f:
+                _ = f.write("hello world!")
+            shutil.make_archive(archive_path, "zip", d, "my_archive")
+
+            @udf("string")
+            def func(x):
+                with open(
+                    os.path.join(
+                        SparkFiles.getRootDirectory(), "my_files", "my_archive", "my_file.txt"
+                    ),
+                    "r",
+                ) as my_file:
+                    return my_file.read().strip()
+
+            self.spark.addArtifacts(f"{archive_path}.zip#my_files", archive=True)
+            self.assertEqual(self.spark.range(1).select(func("id")).first()[0], "hello world!")
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.client.test_artifact import *  # noqa: F401


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
