This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 31a81983fb2d [SPARK-45780][CONNECT] Propagate all Spark Connect client threadlocals in InheritableThread

31a81983fb2d is described below

commit 31a81983fb2d19e05fadccdf49c37dd4f5c50465
Author: Juliusz Sompolski <ju...@databricks.com>
AuthorDate: Fri Nov 3 17:53:55 2023 -0700

    [SPARK-45780][CONNECT] Propagate all Spark Connect client threadlocals in InheritableThread

    ### What changes were proposed in this pull request?

    Currently, PySpark's InheritableThread propagates only the Spark Connect session.client.thread_local.tags to child threads. Generalize this to propagate all thread-local attributes, and deep-copy each value, just as the Scala equivalent clones them.

    ### Why are the changes needed?

    To generalize the propagation mechanism of SparkConnectClient.thread_local, so that every thread-local attribute, not just the tags, is inherited by child threads.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    The existing test for propagating SparkSession tags should cover this.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #43649 from juliuszsompolski/SPARK-45780.

    Authored-by: Juliusz Sompolski <ju...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
--- python/pyspark/util.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 9c70bac2a3d9..4a828d6bfc94 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -16,6 +16,7 @@ # limitations under the License. # +import copy import functools import itertools import os @@ -343,14 +344,19 @@ def inheritable_thread_target(f: Optional[Union[Callable, "SparkSession"]] = Non assert session is not None, "Spark Connect session must be provided."
def outer(ff: Callable) -> Callable: - if not hasattr(session.client.thread_local, "tags"): # type: ignore[union-attr] - session.client.thread_local.tags = set() # type: ignore[union-attr] - tags = set(session.client.thread_local.tags) # type: ignore[union-attr] + session_client_thread_local_attrs = [ + (attr, copy.deepcopy(value)) + for ( + attr, + value, + ) in session.client.thread_local.__dict__.items() # type: ignore[union-attr] + ] @functools.wraps(ff) def inner(*args: Any, **kwargs: Any) -> Any: - # Set tags in child thread. - session.client.thread_local.tags = tags # type: ignore[union-attr] + # Set thread locals in child thread. + for attr, value in session_client_thread_local_attrs: + setattr(session.client.thread_local, attr, value) # type: ignore[union-attr] return ff(*args, **kwargs) return inner --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org