This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 31a81983fb2d [SPARK-45780][CONNECT] Propagate all Spark Connect client threadlocals in InheritableThread
31a81983fb2d is described below

commit 31a81983fb2d19e05fadccdf49c37dd4f5c50465
Author: Juliusz Sompolski <ju...@databricks.com>
AuthorDate: Fri Nov 3 17:53:55 2023 -0700

    [SPARK-45780][CONNECT] Propagate all Spark Connect client threadlocals in InheritableThread
    
    ### What changes were proposed in this pull request?
    
    Currently, PySpark's InheritableThread propagates only the Spark Connect
    session.client.thread_local.tags to child threads. Generalize this to
    propagate all thread locals, and make a deep copy of each value, just as
    the Scala equivalent clones them.
    
    ### Why are the changes needed?
    
    Generalize the propagation mechanism for SparkConnectClient.thread_local so that it is not limited to tags.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing test for propagating SparkSession tags should cover this.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #43649 from juliuszsompolski/SPARK-45780.
    
    Authored-by: Juliusz Sompolski <ju...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/util.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index 9c70bac2a3d9..4a828d6bfc94 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -16,6 +16,7 @@
 # limitations under the License.
 #
 
+import copy
 import functools
 import itertools
 import os
@@ -343,14 +344,19 @@ def inheritable_thread_target(f: Optional[Union[Callable, "SparkSession"]] = Non
         assert session is not None, "Spark Connect session must be provided."
 
         def outer(ff: Callable) -> Callable:
-            if not hasattr(session.client.thread_local, "tags"):  # type: ignore[union-attr]
-                session.client.thread_local.tags = set()  # type: ignore[union-attr]
-            tags = set(session.client.thread_local.tags)  # type: ignore[union-attr]
+            session_client_thread_local_attrs = [
+                (attr, copy.deepcopy(value))
+                for (
+                    attr,
+                    value,
+                ) in session.client.thread_local.__dict__.items()  # type: ignore[union-attr]
+            ]
 
             @functools.wraps(ff)
             def inner(*args: Any, **kwargs: Any) -> Any:
-                # Set tags in child thread.
-                session.client.thread_local.tags = tags  # type: ignore[union-attr]
+                # Set thread locals in child thread.
+                for attr, value in session_client_thread_local_attrs:
+                    setattr(session.client.thread_local, attr, value)  # type: ignore[union-attr]
                 return ff(*args, **kwargs)
 
             return inner
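
For readers who want to try the pattern outside Spark, the following is a minimal standalone sketch of what the patch does: snapshot every attribute of a `threading.local` in the parent thread with `copy.deepcopy`, then re-apply the snapshot inside the child thread. The names `thread_local`, `propagate_thread_locals`, and `run_demo` are hypothetical stand-ins invented for illustration (the module-level `thread_local` plays the role of `session.client.thread_local`); this is not the PySpark implementation itself.

```python
import copy
import functools
import threading
from typing import Any, Callable, Dict

# Hypothetical stand-in for session.client.thread_local.
thread_local = threading.local()


def propagate_thread_locals(f: Callable) -> Callable:
    """Wrap f so it sees a deep copy of the caller's thread locals."""
    # Runs in the parent thread: snapshot all attributes currently set
    # on the thread-local object, deep-copying each value.
    snapshot = [
        (attr, copy.deepcopy(value))
        for attr, value in thread_local.__dict__.items()
    ]

    @functools.wraps(f)
    def inner(*args: Any, **kwargs: Any) -> Any:
        # Runs in the child thread: restore the snapshot before calling f.
        for attr, value in snapshot:
            setattr(thread_local, attr, value)
        return f(*args, **kwargs)

    return inner


def run_demo() -> Dict[str, set]:
    """Show that the child inherits tags but cannot mutate the parent's."""
    thread_local.tags = {"a"}
    results: Dict[str, set] = {}

    def child() -> None:
        # Mutates the child's deep copy only.
        thread_local.tags.add("b")
        results["child_tags"] = set(thread_local.tags)

    t = threading.Thread(target=propagate_thread_locals(child))
    t.start()
    t.join()
    results["parent_tags"] = set(thread_local.tags)
    return results
```

Because the values are deep-copied, the child thread can add a tag without the parent observing it. Note that the snapshot is taken once, when the wrapper is created, so repeated calls to the same wrapped function would share the copied values.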


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
