This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 71d59a85081 [SPARK-43768][PYTHON][CONNECT] Python dependency management support in Python Spark Connect 71d59a85081 is described below commit 71d59a85081f20cd179f5282e19aebcefa59174b Author: Hyukjin Kwon <gurwls...@apache.org> AuthorDate: Thu May 25 20:03:37 2023 +0900 [SPARK-43768][PYTHON][CONNECT] Python dependency management support in Python Spark Connect ### What changes were proposed in this pull request? This PR proposes to add support for archives (`.zip`, `.jar`, `.tar.gz`, `.tgz`, or `.tar` files) in `SparkSession.addArtifacts` so we can support Python dependency management in Python Spark Connect. ### Why are the changes needed? So that end users can add dependencies and archive files in the Python Spark Connect client. This PR enables the Python dependency management (https://www.databricks.com/blog/2020/12/22/how-to-manage-python-dependencies-in-pyspark.html) use case in Spark Connect. See below for how to do this with the Spark Connect Python client: #### Precondition Assume that we have a Spark Connect server already running, e.g., by: ```bash ./sbin/start-connect-server.sh --jars `ls connector/connect/server/target/**/spark-connect*SNAPSHOT.jar` --master "local-cluster[2,2,1024]" ``` and assume that you already have a dev env: ```bash # Notice that you should install `conda-pack`. conda create -y -n pyspark_conda_env -c conda-forge conda-pack python=3.9 conda activate pyspark_conda_env pip install --upgrade -r dev/requirements.txt ``` #### Dependency management ```bash ./bin/pyspark --remote "sc://localhost:15002" ``` ```python import conda_pack import os # Pack the current environment ('pyspark_conda_env') to 'pyspark_conda_env.tar.gz'. # Or you can run 'conda pack' in your shell. 
conda_pack.pack() spark.addArtifact(f"{os.environ.get('CONDA_DEFAULT_ENV')}.tar.gz#environment", archive=True) spark.conf.set("spark.sql.execution.pyspark.python", "environment/bin/python") # From now on, Python workers on executors use `pyspark_conda_env` Conda environment. ``` Run your Python UDFs: ```python import pandas as pd from pyspark.sql.functions import pandas_udf @pandas_udf("long") def plug_one(s: pd.Series) -> pd.Series: return s + 1 spark.range(10).select(plug_one("id")).show() ``` ### Does this PR introduce _any_ user-facing change? Yes, it adds support for archives (`.zip`, `.jar`, `.tar.gz`, `.tgz`, or `.tar` files) in `SparkSession.addArtifacts`. ### How was this patch tested? Manually tested as described above, and added a unit test. Also manually verified in `local-cluster` mode with the code below: ```python import sys from pyspark.sql.functions import udf spark.range(1).select(udf(lambda x: sys.executable)("id")).show(truncate=False) ``` ``` +----------------------------------------------------------------+ |<lambda>(id) | +----------------------------------------------------------------+ |/.../spark/work/app-20230524132024-0000/1/environment/bin/python| +----------------------------------------------------------------+ ``` Closes #41292 from HyukjinKwon/python-addArchive. 
Authored-by: Hyukjin Kwon <gurwls...@apache.org> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../artifact/SparkConnectArtifactManager.scala | 22 ++++++--- .../service/SparkConnectAddArtifactsHandler.scala | 19 +++++++- .../connect/artifact/ArtifactManagerSuite.scala | 12 ++--- python/pyspark/sql/connect/client/artifact.py | 52 +++++++++++++++++----- python/pyspark/sql/connect/client/core.py | 4 +- python/pyspark/sql/connect/session.py | 11 ++++- .../sql/tests/connect/client/test_artifact.py | 44 +++++++++++++++--- 7 files changed, 130 insertions(+), 34 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala index 7a36c46c672..604108f68d2 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql.connect.artifact +import java.io.File import java.net.{URL, URLClassLoader} import java.nio.file.{Files, Path, Paths, StandardCopyOption} import java.util.concurrent.CopyOnWriteArrayList +import javax.ws.rs.core.UriBuilder import scala.collection.JavaConverters._ import scala.reflect.ClassTag @@ -99,16 +101,17 @@ class SparkConnectArtifactManager private[connect] { private[connect] def addArtifact( sessionHolder: SessionHolder, remoteRelativePath: Path, - serverLocalStagingPath: Path): Unit = { + serverLocalStagingPath: Path, + fragment: Option[String]): Unit = { require(!remoteRelativePath.isAbsolute) - if (remoteRelativePath.startsWith("cache/")) { + if (remoteRelativePath.startsWith(s"cache${File.separator}")) { val tmpFile = serverLocalStagingPath.toFile Utils.tryWithSafeFinallyAndFailureCallbacks { val blockManager = 
sessionHolder.session.sparkContext.env.blockManager val blockId = CacheId( userId = sessionHolder.userId, sessionId = sessionHolder.sessionId, - hash = remoteRelativePath.toString.stripPrefix("cache/")) + hash = remoteRelativePath.toString.stripPrefix(s"cache${File.separator}")) val updater = blockManager.TempFileBasedBlockStoreUpdater( blockId = blockId, level = StorageLevel.MEMORY_AND_DISK_SER, @@ -118,9 +121,10 @@ class SparkConnectArtifactManager private[connect] { tellMaster = false) updater.save() }(catchBlock = { tmpFile.delete() }) - } else if (remoteRelativePath.startsWith("classes/")) { + } else if (remoteRelativePath.startsWith(s"classes${File.separator}")) { // Move class files to common location (shared among all users) - val target = classArtifactDir.resolve(remoteRelativePath.toString.stripPrefix("classes/")) + val target = classArtifactDir.resolve( + remoteRelativePath.toString.stripPrefix(s"classes${File.separator}")) Files.createDirectories(target.getParent) // Allow overwriting class files to capture updates to classes. 
Files.move(serverLocalStagingPath, target, StandardCopyOption.REPLACE_EXISTING) @@ -135,17 +139,21 @@ class SparkConnectArtifactManager private[connect] { s"Jars cannot be overwritten.") } Files.move(serverLocalStagingPath, target) - if (remoteRelativePath.startsWith("jars/")) { + if (remoteRelativePath.startsWith(s"jars${File.separator}")) { // Adding Jars to the underlying spark context (visible to all users) sessionHolder.session.sessionState.resourceLoader.addJar(target.toString) jarsList.add(target) - } else if (remoteRelativePath.startsWith("pyfiles/")) { + } else if (remoteRelativePath.startsWith(s"pyfiles${File.separator}")) { sessionHolder.session.sparkContext.addFile(target.toString) val stringRemotePath = remoteRelativePath.toString if (stringRemotePath.endsWith(".zip") || stringRemotePath.endsWith( ".egg") || stringRemotePath.endsWith(".jar")) { pythonIncludeList.add(target.getFileName.toString) } + } else if (remoteRelativePath.startsWith(s"archives${File.separator}")) { + val canonicalUri = + fragment.map(UriBuilder.fromUri(target.toUri).fragment).getOrElse(target.toUri) + sessionHolder.session.sparkContext.addArchive(canonicalUri.toString) } } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala index 99e92e42fff..f8bdb58ed85 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.connect.service +import java.io.File import java.nio.file.{Files, Path, Paths} import java.util.zip.{CheckedOutputStream, CRC32} @@ -85,7 +86,7 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr } protected def 
addStagedArtifactToArtifactManager(artifact: StagedArtifact): Unit = { - artifactManager.addArtifact(holder, artifact.path, artifact.stagedPath) + artifactManager.addArtifact(holder, artifact.path, artifact.stagedPath, artifact.fragment) } /** @@ -148,7 +149,21 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr * Handles rebuilding an artifact from bytes sent over the wire. */ class StagedArtifact(val name: String) { - val path: Path = Paths.get(name) + // Workaround to keep the fragment. + val (canonicalFileName: String, fragment: Option[String]) = + if (name.startsWith(s"archives${File.separator}")) { + val splits = name.split("#") + assert(splits.length <= 2, "'#' in the path is not supported for adding an archive.") + if (splits.length == 2) { + (splits(0), Some(splits(1))) + } else { + (splits(0), None) + } + } else { + (name, None) + } + + val path: Path = Paths.get(canonicalFileName) val stagedPath: Path = stagingDir.resolve(path) Files.createDirectories(stagedPath.getParent) diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala index 291eadb07c4..b87c6742bdc 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala @@ -48,7 +48,7 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { FileUtils.copyDirectory(artifactPath.toFile, copyDir.toFile) val stagingPath = copyDir.resolve("smallJar.jar") val remotePath = Paths.get("jars/smallJar.jar") - artifactManager.addArtifact(sessionHolder, remotePath, stagingPath) + artifactManager.addArtifact(sessionHolder, remotePath, stagingPath, None) val jarList = spark.sparkContext.listJars() 
assert(jarList.exists(_.contains(remotePath.toString))) @@ -60,7 +60,7 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { val stagingPath = copyDir.resolve("smallClassFile.class") val remotePath = Paths.get("classes/smallClassFile.class") assert(stagingPath.toFile.exists()) - artifactManager.addArtifact(sessionHolder, remotePath, stagingPath) + artifactManager.addArtifact(sessionHolder, remotePath, stagingPath, None) val classFileDirectory = artifactManager.classArtifactDir val movedClassFile = classFileDirectory.resolve("smallClassFile.class").toFile @@ -73,7 +73,7 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { val stagingPath = copyDir.resolve("Hello.class") val remotePath = Paths.get("classes/Hello.class") assert(stagingPath.toFile.exists()) - artifactManager.addArtifact(sessionHolder, remotePath, stagingPath) + artifactManager.addArtifact(sessionHolder, remotePath, stagingPath, None) val classFileDirectory = artifactManager.classArtifactDir val movedClassFile = classFileDirectory.resolve("Hello.class").toFile @@ -96,7 +96,7 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { val stagingPath = copyDir.resolve("Hello.class") val remotePath = Paths.get("classes/Hello.class") assert(stagingPath.toFile.exists()) - artifactManager.addArtifact(sessionHolder, remotePath, stagingPath) + artifactManager.addArtifact(sessionHolder, remotePath, stagingPath, None) val classFileDirectory = artifactManager.classArtifactDir val movedClassFile = classFileDirectory.resolve("Hello.class").toFile @@ -123,7 +123,7 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { val blockManager = spark.sparkContext.env.blockManager val blockId = CacheId(session.userId, session.sessionId, "abc") try { - artifactManager.addArtifact(session, remotePath, stagingPath) + artifactManager.addArtifact(session, remotePath, stagingPath, None) val bytes = blockManager.getLocalBytes(blockId) 
assert(bytes.isDefined) val readback = new String(bytes.get.toByteBuffer().array(), StandardCharsets.UTF_8) @@ -141,7 +141,7 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { Files.write(path.toPath, "test".getBytes(StandardCharsets.UTF_8)) val session = sessionHolder() val remotePath = Paths.get("pyfiles/abc.zip") - artifactManager.addArtifact(session, remotePath, stagingPath) + artifactManager.addArtifact(session, remotePath, stagingPath, None) assert(artifactManager.getSparkConnectPythonIncludes == Seq("abc.zip")) } } diff --git a/python/pyspark/sql/connect/client/artifact.py b/python/pyspark/sql/connect/client/artifact.py index f06277e5068..64f89119e4f 100644 --- a/python/pyspark/sql/connect/client/artifact.py +++ b/python/pyspark/sql/connect/client/artifact.py @@ -38,6 +38,7 @@ import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib JAR_PREFIX: str = "jars" PYFILE_PREFIX: str = "pyfiles" +ARCHIVE_PREFIX: str = "archives" class LocalData(metaclass=abc.ABCMeta): @@ -102,6 +103,10 @@ def new_pyfile_artifact(file_name: str, storage: LocalData) -> Artifact: return _new_artifact(PYFILE_PREFIX, "", file_name, storage) +def new_archive_artifact(file_name: str, storage: LocalData) -> Artifact: + return _new_artifact(ARCHIVE_PREFIX, "", file_name, storage) + + def _new_artifact( prefix: str, required_suffix: str, file_name: str, storage: LocalData ) -> Artifact: @@ -136,12 +141,16 @@ class ArtifactManager: self._stub = grpc_lib.SparkConnectServiceStub(channel) self._session_id = session_id - def _parse_artifacts(self, path_or_uri: str, pyfile: bool) -> List[Artifact]: + def _parse_artifacts(self, path_or_uri: str, pyfile: bool, archive: bool) -> List[Artifact]: # Currently only local files with .jar extension is supported. - uri = path_or_uri - if urlparse(path_or_uri).scheme == "": # Is path? 
- uri = Path(path_or_uri).absolute().as_uri() - parsed = urlparse(uri) + parsed = urlparse(path_or_uri) + # Check if it is a file from the scheme + if parsed.scheme == "": + # Similar with Utils.resolveURI. + fragment = parsed.fragment + parsed = urlparse(Path(url2pathname(parsed.path)).absolute().as_uri()) + parsed = parsed._replace(fragment=fragment) + if parsed.scheme == "file": local_path = url2pathname(parsed.path) name = Path(local_path).name @@ -154,16 +163,37 @@ class ArtifactManager: sys.path.insert(1, local_path) artifact = new_pyfile_artifact(name, LocalFile(local_path)) importlib.invalidate_caches() + elif archive and ( + name.endswith(".zip") + or name.endswith(".jar") + or name.endswith(".tar.gz") + or name.endswith(".tgz") + or name.endswith(".tar") + ): + assert any(name.endswith(s) for s in (".zip", ".jar", ".tar.gz", ".tgz", ".tar")) + + if parsed.fragment != "": + # Minimal fix for the workaround of fragment handling in URI. + # This has a limitation - hash(#) in the file name would not work. 
+ if "#" in local_path: + raise ValueError("'#' in the path is not supported for adding an archive.") + name = f"{name}#{parsed.fragment}" + + artifact = new_archive_artifact(name, LocalFile(local_path)) elif name.endswith(".jar"): artifact = new_jar_artifact(name, LocalFile(local_path)) else: raise RuntimeError(f"Unsupported file format: {local_path}") return [artifact] - raise RuntimeError(f"Unsupported scheme: {uri}") + raise RuntimeError(f"Unsupported scheme: {parsed.scheme}") - def _create_requests(self, *path: str, pyfile: bool) -> Iterator[proto.AddArtifactsRequest]: + def _create_requests( + self, *path: str, pyfile: bool, archive: bool + ) -> Iterator[proto.AddArtifactsRequest]: """Separated for the testing purpose.""" - return self._add_artifacts(chain(*(self._parse_artifacts(p, pyfile=pyfile) for p in path))) + return self._add_artifacts( + chain(*(self._parse_artifacts(p, pyfile=pyfile, archive=archive) for p in path)) + ) def _retrieve_responses( self, requests: Iterator[proto.AddArtifactsRequest] @@ -171,12 +201,14 @@ class ArtifactManager: """Separated for the testing purpose.""" return self._stub.AddArtifacts(requests) - def add_artifacts(self, *path: str, pyfile: bool) -> None: + def add_artifacts(self, *path: str, pyfile: bool, archive: bool) -> None: """ Add a single artifact to the session. Currently only local files with .jar extension is supported. 
""" - requests: Iterator[proto.AddArtifactsRequest] = self._create_requests(*path, pyfile=pyfile) + requests: Iterator[proto.AddArtifactsRequest] = self._create_requests( + *path, pyfile=pyfile, archive=archive + ) response: proto.AddArtifactsResponse = self._retrieve_responses(requests) summaries: List[proto.AddArtifactsResponse.ArtifactSummary] = [] diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index e93d9b5c494..544ed5d4183 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -1237,8 +1237,8 @@ class SparkConnectClient(object): else: raise SparkConnectGrpcException(str(rpc_error)) from None - def add_artifacts(self, *path: str, pyfile: bool) -> None: - self._artifact_manager.add_artifacts(*path, pyfile=pyfile) + def add_artifacts(self, *path: str, pyfile: bool, archive: bool) -> None: + self._artifact_manager.add_artifacts(*path, pyfile=pyfile, archive=archive) class RetryState: diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 3341b88eded..7932ab54081 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -601,7 +601,7 @@ class SparkSession: """ return self._client - def addArtifacts(self, *path: str, pyfile: bool = False) -> None: + def addArtifacts(self, *path: str, pyfile: bool = False, archive: bool = False) -> None: """ Add artifact(s) to the client session. Currently only local files are supported. @@ -613,8 +613,15 @@ class SparkSession: Artifact's URIs to add. pyfile : bool Whether to add them as Python dependencies such as .py, .egg, .zip or .jar files. + The pyfiles are directly inserted into the path when executing Python functions + in executors. + archive : bool + Whether to add them as archives such as .zip, .jar, .tar.gz, .tgz, or .tar files. + The archives are unpacked on the executor side automatically. 
""" - self._client.add_artifacts(*path, pyfile=pyfile) + if pyfile and archive: + raise ValueError("'pyfile' and 'archive' cannot be True together.") + self._client.add_artifacts(*path, pyfile=pyfile, archive=archive) addArtifact = addArtifacts diff --git a/python/pyspark/sql/tests/connect/client/test_artifact.py b/python/pyspark/sql/tests/connect/client/test_artifact.py index 73f47486bab..2bff3fd5bc4 100644 --- a/python/pyspark/sql/tests/connect/client/test_artifact.py +++ b/python/pyspark/sql/tests/connect/client/test_artifact.py @@ -21,6 +21,7 @@ import os from pyspark.testing.connectutils import ReusedConnectTestCase, should_test_connect from pyspark.testing.utils import SPARK_HOME +from pyspark import SparkFiles from pyspark.sql.functions import udf if should_test_connect: @@ -48,7 +49,7 @@ class ArtifactTests(ReusedConnectTestCase): file_name = "smallJar" small_jar_path = os.path.join(self.artifact_file_path, f"{file_name}.jar") response = self.artifact_manager._retrieve_responses( - self.artifact_manager._create_requests(small_jar_path, pyfile=False) + self.artifact_manager._create_requests(small_jar_path, pyfile=False, archive=False) ) self.assertTrue(response.artifacts[0].name.endswith(f"{file_name}.jar")) @@ -57,7 +58,9 @@ class ArtifactTests(ReusedConnectTestCase): small_jar_path = os.path.join(self.artifact_file_path, f"{file_name}.jar") small_jar_crc_path = os.path.join(self.artifact_crc_path, f"{file_name}.txt") - requests = list(self.artifact_manager._create_requests(small_jar_path, pyfile=False)) + requests = list( + self.artifact_manager._create_requests(small_jar_path, pyfile=False, archive=False) + ) self.assertEqual(len(requests), 1) request = requests[0] @@ -79,7 +82,9 @@ class ArtifactTests(ReusedConnectTestCase): large_jar_path = os.path.join(self.artifact_file_path, f"{file_name}.jar") large_jar_crc_path = os.path.join(self.artifact_crc_path, f"{file_name}.txt") - requests = list(self.artifact_manager._create_requests(large_jar_path, 
pyfile=False)) + requests = list( + self.artifact_manager._create_requests(large_jar_path, pyfile=False, archive=False) + ) # Expected chunks = roundUp( file_size / chunk_size) = 12 # File size of `junitLargeJar.jar` is 384581 bytes. large_jar_size = os.path.getsize(large_jar_path) @@ -111,7 +116,9 @@ class ArtifactTests(ReusedConnectTestCase): small_jar_crc_path = os.path.join(self.artifact_crc_path, f"{file_name}.txt") requests = list( - self.artifact_manager._create_requests(small_jar_path, small_jar_path, pyfile=False) + self.artifact_manager._create_requests( + small_jar_path, small_jar_path, pyfile=False, archive=False + ) ) # Single request containing 2 artifacts. self.assertEqual(len(requests), 1) @@ -147,7 +154,12 @@ class ArtifactTests(ReusedConnectTestCase): requests = list( self.artifact_manager._create_requests( - small_jar_path, large_jar_path, small_jar_path, small_jar_path, pyfile=False + small_jar_path, + large_jar_path, + small_jar_path, + small_jar_path, + pyfile=False, + archive=False, ) ) # There are a total of 14 requests. 
@@ -237,6 +249,28 @@ class ArtifactTests(ReusedConnectTestCase): self.spark.addArtifacts(f"{package_path}.zip", pyfile=True) self.assertEqual(self.spark.range(1).select(func("id")).first()[0], 5) + def test_add_archive(self): + with tempfile.TemporaryDirectory() as d: + archive_path = os.path.join(d, "my_archive") + os.mkdir(archive_path) + pyfile_path = os.path.join(archive_path, "my_file.txt") + with open(pyfile_path, "w") as f: + _ = f.write("hello world!") + shutil.make_archive(archive_path, "zip", d, "my_archive") + + @udf("string") + def func(x): + with open( + os.path.join( + SparkFiles.getRootDirectory(), "my_files", "my_archive", "my_file.txt" + ), + "r", + ) as my_file: + return my_file.read().strip() + + self.spark.addArtifacts(f"{archive_path}.zip#my_files", archive=True) + self.assertEqual(self.spark.range(1).select(func("id")).first()[0], "hello world!") + if __name__ == "__main__": from pyspark.sql.tests.connect.client.test_artifact import * # noqa: F401 --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org