This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 58863dfa1b4b  [SPARK-45394][SPARK-45093][PYTHON][CONNECT] Add retries for artifact API. Improve error handling (follow-up to [])
58863dfa1b4b is described below

commit 58863dfa1b4b84dee5a0d6323265f6f3bb71a763
Author: Alice Sayutina <alice.sayut...@databricks.com>
AuthorDate: Fri Oct 6 11:21:17 2023 +0900

    [SPARK-45394][SPARK-45093][PYTHON][CONNECT] Add retries for artifact API. Improve error handling (follow-up to [])

    ### What changes were proposed in this pull request?

    1. Add retries to the `add_artifact` API in the client.
    2. Slightly change the control flow within `artifact.py` so that client-side errors (e.g. FileNotFoundError) are properly thrown. (Previously we attempted to add logs in https://github.com/apache/spark/pull/42949, but that was an imperfect solution; this should be much better.)
    3. Take proper ownership of the files opened in LocalData, and close those descriptors.

    ### Why are the changes needed?

    Improves the user experience.

    ### Does this PR introduce _any_ user-facing change?

    Improved error handling; retries are added.

    ### How was this patch tested?

    Added test coverage for `add_artifact` when the artifact file does not exist.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #43216 from cdkrot/SPARK-45394.

    Authored-by: Alice Sayutina <alice.sayut...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
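
Change (1) in the list above wraps each upload in the client's attempt-based retry loop (see the core.py hunk below): `_retrying()` yields attempt objects that are used as context managers, and every path gets its own loop, so a failed upload is retried without re-sending the paths that already succeeded. A minimal sketch of that attempt-loop pattern, using a toy `retrying()` helper with naive exponential backoff rather than Spark Connect's actual retry policies:

import time
from typing import Callable, Iterable, Iterator, Optional


class Attempt:
    """Context manager that records whether the wrapped block raised."""

    def __init__(self) -> None:
        self.exception: Optional[BaseException] = None

    def __enter__(self) -> "Attempt":
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        self.exception = exc
        # Swallow the error here; the retrying() loop decides whether to retry or re-raise.
        return exc is not None


def retrying(max_attempts: int = 3, backoff_s: float = 0.1) -> Iterator[Attempt]:
    """Yield Attempt objects until one succeeds or the attempts are exhausted."""
    attempt = Attempt()
    for i in range(max_attempts):
        attempt = Attempt()
        yield attempt
        if attempt.exception is None:
            return  # the with-block finished cleanly; stop retrying
        time.sleep(backoff_s * (2**i))
    assert attempt.exception is not None
    raise attempt.exception


def add_with_retries(paths: Iterable[str], upload: Callable[[str], None]) -> None:
    # Mirrors the shape of the new SparkConnectClient.add_artifacts: one retry loop per path.
    for path in paths:
        for attempt in retrying():
            with attempt:
                upload(path)

Note that the actual change reuses the client's existing `_retrying()` helper rather than introducing a new mechanism, so artifact uploads follow the same retry behaviour as the client's other RPCs.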
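
Change (2) relies on a Python subtlety: the body of a generator function does not run until the generator is iterated, so an exception raised while parsing artifacts inside a generator would only surface in whatever consumes the gRPC request stream. The fix parses eagerly in the caller's thread and returns an inner generator that only produces the requests. A small self-contained illustration (the `parse`, `requests_lazy`, and `requests_eager` names are hypothetical, not the pyspark API):

from typing import Iterator


def parse(path: str) -> str:
    # Stand-in for artifact parsing; a bad path fails immediately.
    if not path.endswith(".py"):
        raise FileNotFoundError(path)
    return path


def requests_lazy(path: str) -> Iterator[str]:
    # Generator function: nothing here runs until the result is iterated,
    # so the error would only show up wherever the stream is consumed.
    yield parse(path)


def requests_eager(path: str) -> Iterator[str]:
    # Parse in the caller's thread first, then return a generator over the result.
    parsed = parse(path)  # raises here, in the caller

    def gen() -> Iterator[str]:
        yield parsed

    return gen()


requests_lazy("missing.txt")  # no error yet: the generator body has not started
try:
    requests_eager("missing.txt")  # raises FileNotFoundError right away
except FileNotFoundError as err:
    print("caught in the caller:", err)

In the patch itself, `_parse_artifacts` is invoked for every path while the `chain(...)` arguments are built, so a bad path raises in the caller; only the `_add_artifacts` consumption happens inside the inner `generator()`.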
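
Change (3) makes each `LocalData.stream()` call hand out a fresh descriptor that the consumer owns and closes, and checks readability up front. A simplified sketch of that shape (not the actual `LocalFile` class), assuming a plain property for `size`:

import os
import zlib
from typing import BinaryIO


class LocalFileSketch:
    """Simplified stand-in for LocalFile: no descriptor stays open between calls."""

    def __init__(self, path: str):
        self.path = path
        # Fail fast: a missing file raises FileNotFoundError while the artifact is
        # being created, not later when the upload consumes the stream.
        with self.stream():
            pass

    @property
    def size(self) -> int:
        return os.path.getsize(self.path)

    def stream(self) -> BinaryIO:
        # A fresh descriptor per call; the caller owns it and closes it via `with`.
        return open(self.path, "rb")


def checksum(data: LocalFileSketch) -> int:
    # Consumption side: every read closes the descriptor it opened.
    with data.stream() as f:
        return zlib.crc32(f.read())

This up-front check is what the new test_add_not_existing_artifact relies on: a non-existent path now fails with FileNotFoundError at artifact creation time.
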
---
 python/pyspark/sql/connect/client/artifact.py      | 94 ++++++++++++----------
 python/pyspark/sql/connect/client/core.py          | 18 ++++-
 .../sql/tests/connect/client/test_artifact.py      |  7 ++
 3 files changed, 72 insertions(+), 47 deletions(-)

diff --git a/python/pyspark/sql/connect/client/artifact.py b/python/pyspark/sql/connect/client/artifact.py
index fb31a57e0f62..5829ec9a8d4d 100644
--- a/python/pyspark/sql/connect/client/artifact.py
+++ b/python/pyspark/sql/connect/client/artifact.py
@@ -52,7 +52,6 @@ class LocalData(metaclass=abc.ABCMeta):
     Payload stored on this machine.
     """
 
-    @cached_property
     @abc.abstractmethod
     def stream(self) -> BinaryIO:
         pass
@@ -70,14 +69,18 @@ class LocalFile(LocalData):
 
     def __init__(self, path: str):
         self.path = path
-        self._size: int
-        self._stream: int
+
+        # Check that the file can be read
+        # so that incorrect references can be discovered during Artifact creation,
+        # and not at the point of consumption.
+
+        with self.stream():
+            pass
 
     @cached_property
     def size(self) -> int:
         return os.path.getsize(self.path)
 
-    @cached_property
     def stream(self) -> BinaryIO:
         return open(self.path, "rb")
 
@@ -89,14 +92,11 @@ class InMemory(LocalData):
 
     def __init__(self, blob: bytes):
         self.blob = blob
-        self._size: int
-        self._stream: int
 
     @cached_property
     def size(self) -> int:
         return len(self.blob)
 
-    @cached_property
     def stream(self) -> BinaryIO:
         return io.BytesIO(self.blob)
 
@@ -244,18 +244,23 @@ class ArtifactManager:
         self, *path: str, pyfile: bool, archive: bool, file: bool
     ) -> Iterator[proto.AddArtifactsRequest]:
         """Separated for the testing purpose."""
-        try:
-            yield from self._add_artifacts(
-                chain(
-                    *(
-                        self._parse_artifacts(p, pyfile=pyfile, archive=archive, file=file)
-                        for p in path
-                    )
-                )
-            )
-        except Exception as e:
-            logger.error(f"Failed to submit addArtifacts request: {e}")
-            raise
+
+        # It's crucial that this function is not a generator, but only returns a generator.
+        # This way the artifact parsing happens in the original caller thread,
+        # and not in the grpc-consuming iterator, allowing for much better error reporting.
+
+        artifacts: Iterator[Artifact] = chain(
+            *(self._parse_artifacts(p, pyfile=pyfile, archive=archive, file=file) for p in path)
+        )
+
+        def generator() -> Iterator[proto.AddArtifactsRequest]:
+            try:
+                yield from self._add_artifacts(artifacts)
+            except Exception as e:
+                logger.error(f"Failed to submit addArtifacts request: {e}")
+                raise
+
+        return generator()
 
     def _retrieve_responses(
         self, requests: Iterator[proto.AddArtifactsRequest]
@@ -279,6 +284,7 @@ class ArtifactManager:
         requests: Iterator[proto.AddArtifactsRequest] = self._create_requests(
             *path, pyfile=pyfile, archive=archive, file=file
        )
+
         self._request_add_artifacts(requests)
 
     def _add_forward_to_fs_artifacts(self, local_path: str, dest_path: str) -> None:
@@ -337,7 +343,8 @@ class ArtifactManager:
         artifact_chunks = []
 
         for artifact in artifacts:
-            binary = artifact.storage.stream.read()
+            with artifact.storage.stream() as stream:
+                binary = stream.read()
             crc32 = zlib.crc32(binary)
             data = proto.AddArtifactsRequest.ArtifactChunk(data=binary, crc=crc32)
             artifact_chunks.append(
@@ -363,31 +370,32 @@ class ArtifactManager:
         )
 
         # Consume stream in chunks until there is no data left to read.
-        for chunk in iter(lambda: artifact.storage.stream.read(ArtifactManager.CHUNK_SIZE), b""):
-            if initial_batch:
-                # First RPC contains the `BeginChunkedArtifact` payload (`begin_chunk`).
-                yield proto.AddArtifactsRequest(
-                    session_id=self._session_id,
-                    user_context=self._user_context,
-                    begin_chunk=proto.AddArtifactsRequest.BeginChunkedArtifact(
-                        name=artifact.path,
-                        total_bytes=artifact.size,
-                        num_chunks=get_num_chunks,
-                        initial_chunk=proto.AddArtifactsRequest.ArtifactChunk(
+        with artifact.storage.stream() as stream:
+            for chunk in iter(lambda: stream.read(ArtifactManager.CHUNK_SIZE), b""):
+                if initial_batch:
+                    # First RPC contains the `BeginChunkedArtifact` payload (`begin_chunk`).
+                    yield proto.AddArtifactsRequest(
+                        session_id=self._session_id,
+                        user_context=self._user_context,
+                        begin_chunk=proto.AddArtifactsRequest.BeginChunkedArtifact(
+                            name=artifact.path,
+                            total_bytes=artifact.size,
+                            num_chunks=get_num_chunks,
+                            initial_chunk=proto.AddArtifactsRequest.ArtifactChunk(
+                                data=chunk, crc=zlib.crc32(chunk)
+                            ),
+                        ),
+                    )
+                    initial_batch = False
+                else:
+                    # Subsequent RPCs contains the `ArtifactChunk` payload (`chunk`).
+                    yield proto.AddArtifactsRequest(
+                        session_id=self._session_id,
+                        user_context=self._user_context,
+                        chunk=proto.AddArtifactsRequest.ArtifactChunk(
                             data=chunk, crc=zlib.crc32(chunk)
                         ),
-                    ),
-                )
-                initial_batch = False
-            else:
-                # Subsequent RPCs contains the `ArtifactChunk` payload (`chunk`).
-                yield proto.AddArtifactsRequest(
-                    session_id=self._session_id,
-                    user_context=self._user_context,
-                    chunk=proto.AddArtifactsRequest.ArtifactChunk(
-                        data=chunk, crc=zlib.crc32(chunk)
-                    ),
-                )
+                    )
 
     def is_cached_artifact(self, hash: str) -> bool:
         """
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index db7f8e6dc75c..9e47379c85e7 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -1538,14 +1538,24 @@ class SparkConnectClient(object):
         else:
             raise SparkConnectGrpcException(str(rpc_error)) from None
 
-    def add_artifacts(self, *path: str, pyfile: bool, archive: bool, file: bool) -> None:
-        self._artifact_manager.add_artifacts(*path, pyfile=pyfile, archive=archive, file=file)
+    def add_artifacts(self, *paths: str, pyfile: bool, archive: bool, file: bool) -> None:
+        for path in paths:
+            for attempt in self._retrying():
+                with attempt:
+                    self._artifact_manager.add_artifacts(
+                        path, pyfile=pyfile, archive=archive, file=file
+                    )
 
     def copy_from_local_to_fs(self, local_path: str, dest_path: str) -> None:
-        self._artifact_manager._add_forward_to_fs_artifacts(local_path, dest_path)
+        for attempt in self._retrying():
+            with attempt:
+                self._artifact_manager._add_forward_to_fs_artifacts(local_path, dest_path)
 
     def cache_artifact(self, blob: bytes) -> str:
-        return self._artifact_manager.cache_artifact(blob)
+        for attempt in self._retrying():
+            with attempt:
+                return self._artifact_manager.cache_artifact(blob)
+        raise SparkConnectException("Invalid state during retry exception handling.")
 
 
 class RetryState:
diff --git a/python/pyspark/sql/tests/connect/client/test_artifact.py b/python/pyspark/sql/tests/connect/client/test_artifact.py
index d45230e926b1..7e9f9dbbf569 100644
--- a/python/pyspark/sql/tests/connect/client/test_artifact.py
+++ b/python/pyspark/sql/tests/connect/client/test_artifact.py
@@ -388,6 +388,13 @@ class ArtifactTests(ReusedConnectTestCase, ArtifactTestsMixin):
         self.assertEqual(actualHash, expected_hash)
         self.assertEqual(self.artifact_manager.is_cached_artifact(expected_hash), True)
 
+    def test_add_not_existing_artifact(self):
+        with tempfile.TemporaryDirectory() as d:
+            with self.assertRaises(FileNotFoundError):
+                self.artifact_manager.add_artifacts(
+                    os.path.join(d, "not_existing"), file=True, pyfile=False, archive=False
+                )
+
 
 class LocalClusterArtifactTests(ReusedConnectTestCase, ArtifactTestsMixin):
     @classmethod

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org