This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 58863dfa1b4b [SPARK-45394][SPARK-45093][PYTHON][CONNECT] Add retries for artifact API. Improve error handling (follow-up to [])
58863dfa1b4b is described below

commit 58863dfa1b4b84dee5a0d6323265f6f3bb71a763
Author: Alice Sayutina <alice.sayut...@databricks.com>
AuthorDate: Fri Oct 6 11:21:17 2023 +0900

    [SPARK-45394][SPARK-45093][PYTHON][CONNECT] Add retries for artifact API. Improve error handling (follow-up to [])
    
    ### What changes were proposed in this pull request?
    
    1. Add retries to the `add_artifact` API in the client (a retry-loop sketch appears further below).
    
    2. Slightly change the control flow within `artifact.py` so that client-side errors (e.g. `FileNotFoundError`) are properly thrown; a short sketch of this pattern follows the list. Previously we attempted to add logs in https://github.com/apache/spark/pull/42949, but that was an imperfect solution; this should be much better.
    
    3. Take proper ownership of the files opened in `LocalData`, and close those descriptors.
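    
    A self-contained sketch of the control-flow idea in (2) and the file handling in (3), using hypothetical helper names (`read_artifact`, `create_requests`) rather than the actual pyspark internals: artifacts are parsed eagerly in the caller, so a bad path raises `FileNotFoundError` immediately, and only the already-validated streaming work is wrapped in the generator that the transport later consumes.
    
        from typing import Iterator, List
    
        def read_artifact(path: str) -> bytes:
            # Eager validation: a bad path raises FileNotFoundError right here,
            # in the caller thread, not inside the request iterator. The context
            # manager also guarantees the descriptor is closed, as in (3).
            with open(path, "rb") as f:
                return f.read()
    
        def create_requests(paths: List[str]) -> Iterator[bytes]:
            # Not a generator function: this line runs as soon as it is called.
            payloads = [read_artifact(p) for p in paths]
    
            def generator() -> Iterator[bytes]:
                # Only the lazy streaming of validated payloads lives here.
                yield from payloads
    
            return generator()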
    
    ### Why are the changes needed?
    
    Improves the user experience.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Improves error handling and adds retries.
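    
    The retries follow the attempt/context-manager style visible in `core.py` below. A minimal, self-contained stand-in for that loop, shown only to illustrate the shape (not the actual pyspark retry helper, which presumably applies its own backoff policy and error classification):
    
        import time
        from typing import Iterator
    
        class Attempt:
            """One try of the wrapped call; swallows the error unless it is the last try."""
    
            def __init__(self, is_last: bool):
                self._is_last = is_last
                self.failed = False
    
            def __enter__(self) -> "Attempt":
                return self
    
            def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
                if exc_val is None:
                    return False
                self.failed = True
                return not self._is_last  # re-raise only on the final attempt
    
        def retrying(max_attempts: int = 3, backoff_s: float = 0.1) -> Iterator[Attempt]:
            for i in range(max_attempts):
                attempt = Attempt(is_last=(i == max_attempts - 1))
                yield attempt
                if not attempt.failed:
                    return
                time.sleep(backoff_s)
    
        # Usage mirrors the client code in the diff:
        #     for attempt in retrying():
        #         with attempt:
        #             do_rpc()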
    
    ### How was this patch tested?
    
    Added test coverage for `add_artifact` when the artifact file does not exist.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    NO
    
    Closes #43216 from cdkrot/SPARK-45394.
    
    Authored-by: Alice Sayutina <alice.sayut...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/connect/client/artifact.py      | 94 ++++++++++++----------
 python/pyspark/sql/connect/client/core.py          | 18 ++++-
 .../sql/tests/connect/client/test_artifact.py      |  7 ++
 3 files changed, 72 insertions(+), 47 deletions(-)

diff --git a/python/pyspark/sql/connect/client/artifact.py b/python/pyspark/sql/connect/client/artifact.py
index fb31a57e0f62..5829ec9a8d4d 100644
--- a/python/pyspark/sql/connect/client/artifact.py
+++ b/python/pyspark/sql/connect/client/artifact.py
@@ -52,7 +52,6 @@ class LocalData(metaclass=abc.ABCMeta):
     Payload stored on this machine.
     """
 
-    @cached_property
     @abc.abstractmethod
     def stream(self) -> BinaryIO:
         pass
@@ -70,14 +69,18 @@ class LocalFile(LocalData):
 
     def __init__(self, path: str):
         self.path = path
-        self._size: int
-        self._stream: int
+
+        # Check that the file can be read
+        # so that incorrect references can be discovered during Artifact creation,
+        # and not at the point of consumption.
+
+        with self.stream():
+            pass
 
     @cached_property
     def size(self) -> int:
         return os.path.getsize(self.path)
 
-    @cached_property
     def stream(self) -> BinaryIO:
         return open(self.path, "rb")
 
@@ -89,14 +92,11 @@ class InMemory(LocalData):
 
     def __init__(self, blob: bytes):
         self.blob = blob
-        self._size: int
-        self._stream: int
 
     @cached_property
     def size(self) -> int:
         return len(self.blob)
 
-    @cached_property
     def stream(self) -> BinaryIO:
         return io.BytesIO(self.blob)
 
@@ -244,18 +244,23 @@ class ArtifactManager:
         self, *path: str, pyfile: bool, archive: bool, file: bool
     ) -> Iterator[proto.AddArtifactsRequest]:
         """Separated for the testing purpose."""
-        try:
-            yield from self._add_artifacts(
-                chain(
-                    *(
-                        self._parse_artifacts(p, pyfile=pyfile, archive=archive, file=file)
-                        for p in path
-                    )
-                )
-            )
-        except Exception as e:
-            logger.error(f"Failed to submit addArtifacts request: {e}")
-            raise
+
+        # It's crucial that this function is not a generator, but only returns a generator.
+        # This way we do the artifact parsing within the original caller thread,
+        # and not during the gRPC-consuming iterator, allowing for much better error reporting.
+
+        artifacts: Iterator[Artifact] = chain(
+            *(self._parse_artifacts(p, pyfile=pyfile, archive=archive, file=file) for p in path)
+        )
+
+        def generator() -> Iterator[proto.AddArtifactsRequest]:
+            try:
+                yield from self._add_artifacts(artifacts)
+            except Exception as e:
+                logger.error(f"Failed to submit addArtifacts request: {e}")
+                raise
+
+        return generator()
 
     def _retrieve_responses(
         self, requests: Iterator[proto.AddArtifactsRequest]
@@ -279,6 +284,7 @@ class ArtifactManager:
         requests: Iterator[proto.AddArtifactsRequest] = self._create_requests(
             *path, pyfile=pyfile, archive=archive, file=file
         )
+
         self._request_add_artifacts(requests)
 
     def _add_forward_to_fs_artifacts(self, local_path: str, dest_path: str) -> None:
@@ -337,7 +343,8 @@ class ArtifactManager:
         artifact_chunks = []
 
         for artifact in artifacts:
-            binary = artifact.storage.stream.read()
+            with artifact.storage.stream() as stream:
+                binary = stream.read()
             crc32 = zlib.crc32(binary)
             data = proto.AddArtifactsRequest.ArtifactChunk(data=binary, crc=crc32)
             artifact_chunks.append(
@@ -363,31 +370,32 @@ class ArtifactManager:
         )
 
         # Consume stream in chunks until there is no data left to read.
-        for chunk in iter(lambda: artifact.storage.stream.read(ArtifactManager.CHUNK_SIZE), b""):
-            if initial_batch:
-                # First RPC contains the `BeginChunkedArtifact` payload (`begin_chunk`).
-                yield proto.AddArtifactsRequest(
-                    session_id=self._session_id,
-                    user_context=self._user_context,
-                    begin_chunk=proto.AddArtifactsRequest.BeginChunkedArtifact(
-                        name=artifact.path,
-                        total_bytes=artifact.size,
-                        num_chunks=get_num_chunks,
-                        initial_chunk=proto.AddArtifactsRequest.ArtifactChunk(
+        with artifact.storage.stream() as stream:
+            for chunk in iter(lambda: stream.read(ArtifactManager.CHUNK_SIZE), b""):
+                if initial_batch:
+                    # First RPC contains the `BeginChunkedArtifact` payload (`begin_chunk`).
+                    yield proto.AddArtifactsRequest(
+                        session_id=self._session_id,
+                        user_context=self._user_context,
+                        begin_chunk=proto.AddArtifactsRequest.BeginChunkedArtifact(
+                            name=artifact.path,
+                            total_bytes=artifact.size,
+                            num_chunks=get_num_chunks,
+                            initial_chunk=proto.AddArtifactsRequest.ArtifactChunk(
+                                data=chunk, crc=zlib.crc32(chunk)
+                            ),
+                        ),
+                    )
+                    initial_batch = False
+                else:
+                    # Subsequent RPCs contains the `ArtifactChunk` payload (`chunk`).
+                    yield proto.AddArtifactsRequest(
+                        session_id=self._session_id,
+                        user_context=self._user_context,
+                        chunk=proto.AddArtifactsRequest.ArtifactChunk(
                             data=chunk, crc=zlib.crc32(chunk)
                         ),
-                    ),
-                )
-                initial_batch = False
-            else:
-                # Subsequent RPCs contains the `ArtifactChunk` payload (`chunk`).
-                yield proto.AddArtifactsRequest(
-                    session_id=self._session_id,
-                    user_context=self._user_context,
-                    chunk=proto.AddArtifactsRequest.ArtifactChunk(
-                        data=chunk, crc=zlib.crc32(chunk)
-                    ),
-                )
+                    )
 
     def is_cached_artifact(self, hash: str) -> bool:
         """
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index db7f8e6dc75c..9e47379c85e7 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -1538,14 +1538,24 @@ class SparkConnectClient(object):
         else:
             raise SparkConnectGrpcException(str(rpc_error)) from None
 
-    def add_artifacts(self, *path: str, pyfile: bool, archive: bool, file: bool) -> None:
-        self._artifact_manager.add_artifacts(*path, pyfile=pyfile, archive=archive, file=file)
+    def add_artifacts(self, *paths: str, pyfile: bool, archive: bool, file: bool) -> None:
+        for path in paths:
+            for attempt in self._retrying():
+                with attempt:
+                    self._artifact_manager.add_artifacts(
+                        path, pyfile=pyfile, archive=archive, file=file
+                    )
 
     def copy_from_local_to_fs(self, local_path: str, dest_path: str) -> None:
-        self._artifact_manager._add_forward_to_fs_artifacts(local_path, dest_path)
+        for attempt in self._retrying():
+            with attempt:
+                self._artifact_manager._add_forward_to_fs_artifacts(local_path, dest_path)
 
     def cache_artifact(self, blob: bytes) -> str:
-        return self._artifact_manager.cache_artifact(blob)
+        for attempt in self._retrying():
+            with attempt:
+                return self._artifact_manager.cache_artifact(blob)
+        raise SparkConnectException("Invalid state during retry exception handling.")
 
 
 class RetryState:
diff --git a/python/pyspark/sql/tests/connect/client/test_artifact.py b/python/pyspark/sql/tests/connect/client/test_artifact.py
index d45230e926b1..7e9f9dbbf569 100644
--- a/python/pyspark/sql/tests/connect/client/test_artifact.py
+++ b/python/pyspark/sql/tests/connect/client/test_artifact.py
@@ -388,6 +388,13 @@ class ArtifactTests(ReusedConnectTestCase, ArtifactTestsMixin):
         self.assertEqual(actualHash, expected_hash)
         self.assertEqual(self.artifact_manager.is_cached_artifact(expected_hash), True)
 
+    def test_add_not_existing_artifact(self):
+        with tempfile.TemporaryDirectory() as d:
+            with self.assertRaises(FileNotFoundError):
+                self.artifact_manager.add_artifacts(
+                    os.path.join(d, "not_existing"), file=True, pyfile=False, archive=False
+                )
+
 
 class LocalClusterArtifactTests(ReusedConnectTestCase, ArtifactTestsMixin):
     @classmethod


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
