This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8d28839  [SPARK-35946][PYTHON] Respect Py4J server in InheritableThread API
8d28839 is described below

commit 8d28839689614b497be06743ef04a70f815ae0cb
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Tue Jun 29 22:18:54 2021 -0700

    [SPARK-35946][PYTHON] Respect Py4J server in InheritableThread API
    
    ### What changes were proposed in this pull request?
    
    Currently, we rely on the environment variable `PYSPARK_PIN_THREAD` on the client side of the `InheritableThread` API for Py4J (`python/pyspark/util.py`). If the Py4J gateway is created somewhere else (e.g., by Zeppelin), this can cause a breakage at:
    
    ```python
    from pyspark import SparkContext
    jvm = SparkContext._jvm
    thread_connection = jvm._gateway_client.get_thread_connection()
    # `AttributeError: 'GatewayClient' object has no attribute 'get_thread_connection'` (non-pinned thread mode)
    # `get_thread_connection` is only in 'ClientServer' (pinned thread mode)
    ```
    
    This PR proposes to check the type of the gateway that was actually created and to apply the pinned thread mode behaviour only when it is a `ClientServer`. This avoids breakage when the Py4J server/gateway is created separately elsewhere without pinned thread mode.
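    
    A minimal sketch of why the gateway's type is a more reliable signal than the environment variable. `JavaGateway` and `ClientServer` here are hypothetical stand-ins for the Py4J classes of the same names (in Py4J, `ClientServer` subclasses `JavaGateway`, which the stand-ins mirror), and `pinned_by_env`/`pinned_by_gateway` are illustrative helpers, not PySpark functions:
    
    ```python
    import os


    # Hypothetical stand-ins so the sketch runs without Py4J installed.
    class JavaGateway:
        """Non-pinned gateway, e.g. one created outside PySpark."""


    class ClientServer(JavaGateway):
        """Pinned-thread gateway; subclasses JavaGateway as in Py4J."""


    def pinned_by_env(environ=os.environ):
        # Old check: trusts the environment variable, defaulting to pinned mode.
        return environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true"


    def pinned_by_gateway(gateway):
        # New check: inspects the gateway that was actually created.
        # isinstance(gateway, JavaGateway) is True for both classes,
        # so the test has to be against the ClientServer subclass.
        return isinstance(gateway, ClientServer)


    # A non-pinned gateway created elsewhere, with PYSPARK_PIN_THREAD unset.
    external = JavaGateway()
    print(pinned_by_env({}))                  # True: the default claims pinned mode
    print(pinned_by_gateway(external))        # False: not a ClientServer
    print(pinned_by_gateway(ClientServer()))  # True
    ```
    
    The `"true"` default is what makes the old check misfire: with the variable unset, it assumes pinned mode even when the gateway in use is not a `ClientServer`.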
    
    ### Why are the changes needed?
    
    To avoid breakage when the Py4J gateway is created outside PySpark without pinned thread mode.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, the affected code exists only in the master branch (https://github.com/apache/spark/commit/fdd7ca5f4e35a906090f3c6b160bdba9ac9fd0ca).
    
    ### How was this patch tested?
    
    This is a partial revert of https://github.com/apache/spark/commit/fdd7ca5f4e35a906090f3c6b160bdba9ac9fd0ca. As long as the existing tests pass, this should be fine.
    
    I also manually tested to make doubly sure:
    
    **Before**:
    
    ```python
    >>> from pyspark import InheritableThread, inheritable_thread_target
    >>> InheritableThread(lambda: 1).start()
    >>> inheritable_thread_target(lambda: 1)()
    Traceback (most recent call last):
      File "/.../python3.8/lib/python3.8/threading.py", line 932, in _bootstrap_inner
        self.run()
      File "/.../python3.8/lib/python3.8/threading.py", line 870, in run
        self._target(*self._args, **self._kwargs)
      File "/.../spark/python/pyspark/util.py", line 361, in copy_local_properties
        InheritableThread._clean_py4j_conn_for_current_thread()
      File "/.../spark/python/pyspark/util.py", line 381, in _clean_py4j_conn_for_current_thread
        thread_connection = jvm._gateway_client.get_thread_connection()
    AttributeError: 'GatewayClient' object has no attribute 'get_thread_connection'
    
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      File "/.../spark/python/pyspark/util.py", line 324, in wrapped
        InheritableThread._clean_py4j_conn_for_current_thread()
      File "/.../spark/python/pyspark/util.py", line 381, in _clean_py4j_conn_for_current_thread
        thread_connection = jvm._gateway_client.get_thread_connection()
    AttributeError: 'GatewayClient' object has no attribute 'get_thread_connection'
    ```
    
    **After**:
    
    ```python
    >>> from pyspark import InheritableThread, inheritable_thread_target
    >>> InheritableThread(lambda: 1).start()
    >>> inheritable_thread_target(lambda: 1)()
    1
    ```
    
    Closes #33147 from HyukjinKwon/SPARK-35946.
    
    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Dongjoon Hyun <dongj...@apache.org>
---
 python/pyspark/util.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index 5b4a0b3..e075b04 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -26,6 +26,8 @@ import threading
 import traceback
 import types
 
+from py4j.clientserver import ClientServer
+
 __all__ = []  # type: ignore
 
 
@@ -308,7 +310,9 @@ def inheritable_thread_target(f):
     """
     from pyspark import SparkContext
 
-    if os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true":
+    if isinstance(SparkContext._gateway, ClientServer):
+        # Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on.
+
         # NOTICE the internal difference vs `InheritableThread`. `InheritableThread`
         # copies local properties when the thread starts but `inheritable_thread_target`
         # copies when the function is wrapped.
@@ -350,7 +354,8 @@ class InheritableThread(threading.Thread):
     def __init__(self, target, *args, **kwargs):
         from pyspark import SparkContext
 
-        if os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true":
+        if isinstance(SparkContext._gateway, ClientServer):
+            # Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on.
             def copy_local_properties(*a, **k):
                 # self._props is set before starting the thread to match the behavior with JVM.
                 assert hasattr(self, "_props")
@@ -368,7 +373,9 @@ class InheritableThread(threading.Thread):
     def start(self, *args, **kwargs):
         from pyspark import SparkContext
 
-        if os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true":
+        if isinstance(SparkContext._gateway, ClientServer):
+            # Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on.
+
             # Local property copy should happen in Thread.start to mimic JVM's behavior.
             self._props = SparkContext._active_spark_context._jsc.sc().getLocalProperties().clone()
         return super(InheritableThread, self).start(*args, **kwargs)
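
The `NOTICE` comment in the `inheritable_thread_target` hunk says that `inheritable_thread_target` copies local properties when the function is wrapped, while `InheritableThread` copies them in `start()`. The sketch below models only that timing difference in plain Python, with no Spark or Py4J involved; `local_props`, `wrap_target` and `SnapshotOnStartThread` are hypothetical stand-ins (the dict plays the role of the JVM-side local properties), not PySpark APIs:

```python
import threading

# Hypothetical stand-in for the JVM-side SparkContext local properties.
local_props = {"job.group": "A"}


def wrap_target(f):
    # Mirrors inheritable_thread_target: snapshot the properties at wrap time.
    props = dict(local_props)

    def wrapped(*args, **kwargs):
        return f(props, *args, **kwargs)

    return wrapped


class SnapshotOnStartThread(threading.Thread):
    # Mirrors InheritableThread: snapshot the properties in start(), like the JVM.
    def __init__(self, target):
        def run_with_props():
            # self._props exists here because start() sets it before running.
            target(self._props)

        super().__init__(target=run_with_props)

    def start(self):
        self._props = dict(local_props)
        super().start()


seen = {}
wrapped = wrap_target(lambda props: seen.update(target=props["job.group"]))
thread = SnapshotOnStartThread(lambda props: seen.update(thread=props["job.group"]))

# Change the properties after wrapping/constructing but before starting.
local_props["job.group"] = "B"

plain = threading.Thread(target=wrapped)
plain.start()
plain.join()
thread.start()
thread.join()
print(seen)  # {'target': 'A', 'thread': 'B'}
```

The wrapped function still sees the value from wrap time, while the thread sees the value current at `start()`.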
