This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b37daf5695e [SPARK-44740][CONNECT][FOLLOW] Fix metadata values for Artifacts
b37daf5695e is described below

commit b37daf5695e59ef2f29c6e084230ac89153cca26
Author: Martin Grund <martin.gr...@databricks.com>
AuthorDate: Fri Aug 18 20:30:57 2023 +0800

    [SPARK-44740][CONNECT][FOLLOW] Fix metadata values for Artifacts
    
    ### What changes were proposed in this pull request?
    This is a follow-up to a previous fix where we did not properly propagate the metadata from the main client into the dependent stubs.
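    
    For illustration, a minimal sketch of the pattern being fixed (the helper class and RPC name are hypothetical; only the idea of forwarding the main client's metadata to a dependent stub reflects this patch):
    
    ```python
    from typing import Iterable, Tuple
    
    Metadata = Iterable[Tuple[str, str]]
    
    class DependentStubClient:
        """Hypothetical helper that owns its own gRPC stub."""
    
        def __init__(self, stub, metadata: Metadata):
            self._stub = stub
            # Headers forwarded from the main client; stubs created from
            # the same channel do not inherit per-call metadata implicitly.
            self._metadata = metadata
    
        def call(self, request):
            # Without metadata=..., the server never receives the
            # client's headers on this call.
            return self._stub.SomeRpc(request, metadata=self._metadata)
    ```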
    
    ### Why are the changes needed?
    Compatibility: the artifact-related RPCs should carry the same client metadata (headers) as the other requests issued by the client.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing UT
    
    Closes #42537 from grundprinzip/spark-44740-follow.
    
    Authored-by: Martin Grund <martin.gr...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/connect/client/artifact.py | 17 +++++++++++++----
 python/pyspark/sql/connect/client/core.py     |  4 +++-
 2 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/sql/connect/client/artifact.py b/python/pyspark/sql/connect/client/artifact.py
index cad030e0d5b..c858768ccbf 100644
--- a/python/pyspark/sql/connect/client/artifact.py
+++ b/python/pyspark/sql/connect/client/artifact.py
@@ -25,7 +25,7 @@ import sys
 import os
 import zlib
 from itertools import chain
-from typing import List, Iterable, BinaryIO, Iterator, Optional
+from typing import List, Iterable, BinaryIO, Iterator, Optional, Tuple
 import abc
 from pathlib import Path
 from urllib.parse import urlparse
@@ -162,12 +162,19 @@ class ArtifactManager:
     # https://github.com/grpc/grpc.github.io/issues/371.
     CHUNK_SIZE: int = 32 * 1024
 
-    def __init__(self, user_id: Optional[str], session_id: str, channel: grpc.Channel):
+    def __init__(
+        self,
+        user_id: Optional[str],
+        session_id: str,
+        channel: grpc.Channel,
+        metadata: Iterable[Tuple[str, str]],
+    ):
         self._user_context = proto.UserContext()
         if user_id is not None:
             self._user_context.user_id = user_id
         self._stub = grpc_lib.SparkConnectServiceStub(channel)
         self._session_id = session_id
+        self._metadata = metadata
 
     def _parse_artifacts(
         self, path_or_uri: str, pyfile: bool, archive: bool, file: bool
@@ -246,7 +253,7 @@ class ArtifactManager:
         self, requests: Iterator[proto.AddArtifactsRequest]
     ) -> proto.AddArtifactsResponse:
         """Separated for the testing purpose."""
-        return self._stub.AddArtifacts(requests)
+        return self._stub.AddArtifacts(requests, metadata=self._metadata)
 
     def _request_add_artifacts(self, requests: Iterator[proto.AddArtifactsRequest]) -> None:
         response: proto.AddArtifactsResponse = self._retrieve_responses(requests)
@@ -382,7 +389,9 @@ class ArtifactManager:
         request = proto.ArtifactStatusesRequest(
             user_context=self._user_context, session_id=self._session_id, names=[artifactName]
         )
-        resp: proto.ArtifactStatusesResponse = self._stub.ArtifactStatus(request)
+        resp: proto.ArtifactStatusesResponse = self._stub.ArtifactStatus(
+            request, metadata=self._metadata
+        )
         status = resp.statuses.get(artifactName)
         return status.exists if status is not None else False
 
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 02afe2c50e7..1e439b8c0f6 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -672,7 +672,9 @@ class SparkConnectClient(object):
         self._channel = self._builder.toChannel()
         self._closed = False
         self._stub = grpc_lib.SparkConnectServiceStub(self._channel)
-        self._artifact_manager = ArtifactManager(self._user_id, self._session_id, self._channel)
+        self._artifact_manager = ArtifactManager(
+            self._user_id, self._session_id, self._channel, self._builder.metadata()
+        )
         self._use_reattachable_execute = use_reattachable_execute
         # Configure logging for the SparkConnect client.
 

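For reference, a minimal sketch of how the patched constructor can be exercised. The endpoint and header values below are assumptions for illustration; the import path and keyword arguments follow the diff above:

```python
import grpc

from pyspark.sql.connect.client.artifact import ArtifactManager

# Assumed wiring: the real client derives both values from its builder;
# here a channel and metadata tuples are constructed by hand.
channel = grpc.insecure_channel("localhost:15002")  # assumed endpoint
metadata = [("x-custom-header", "value")]  # assumed headers

manager = ArtifactManager(
    user_id="example-user",
    session_id="00000000-0000-0000-0000-000000000000",
    channel=channel,
    metadata=metadata,
)
# AddArtifacts and ArtifactStatus calls issued by this manager now
# attach the same metadata as the main client's other requests.
```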
