This is an automated email from the ASF dual-hosted git repository.

weichenxu123 pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 4f61662e91a [SPARK-44909][ML] Skip starting torch distributor log streaming server when it is not available
4f61662e91a is described below

commit 4f61662e91a77315a9d3d7454884f47c8bb9a6b1
Author: Weichen Xu <weichen...@databricks.com>
AuthorDate: Wed Aug 23 15:30:50 2023 +0800

    [SPARK-44909][ML] Skip starting torch distributor log streaming server when it is not available
    
    ### What changes were proposed in this pull request?
    
    Skip starting the torch distributor log streaming server when it is not available.
    
    In some cases, e.g. in a Databricks Connect cluster, a network limitation causes starting the log streaming server to fail, but this should not break the torch distributor training routine.
    
    This PR catches the exception raised by the log server's `start` method and sets the server port to -1 if `start` fails.
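    
    For illustration, this follows a common "optional service" fallback pattern: try to start an auxiliary server and, if that fails, record a sentinel port so downstream code can skip it. A minimal sketch of the pattern (hypothetical `start_optional_server` helper and generic `server` object, not the actual `LogStreamingServer` API) could look like:
    
    ```python
    import logging
    import time
    
    logger = logging.getLogger(__name__)
    
    def start_optional_server(server) -> int:
        """Try to start an auxiliary server; return its port, or -1 if it is unavailable."""
        try:
            server.start()
            time.sleep(1)  # give the server a moment to bind its port
            return server.port
        except Exception as e:
            # The server is optional: warn and continue with a sentinel port.
            logger.warning("Starting the auxiliary server failed; continuing without it. Error: %r.", e)
            return -1
    ```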
    
    ### Why are the changes needed?
    
    In some cases, e.g. in a Databricks Connect cluster, a network limitation causes starting the log streaming server to fail, but this should not break the torch distributor training routine.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #42606 from WeichenXu123/fix-torch-log-server-in-connect-mode.
    
    Authored-by: Weichen Xu <weichen...@databricks.com>
    Signed-off-by: Weichen Xu <weichen...@databricks.com>
    (cherry picked from commit 80668dc1a36ac0def80f3c18f981fbdacfb2904d)
    Signed-off-by: Weichen Xu <weichen...@databricks.com>
---
 python/pyspark/ml/torch/distributor.py       | 16 +++++++++++++---
 python/pyspark/ml/torch/log_communication.py |  3 +++
 2 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/python/pyspark/ml/torch/distributor.py b/python/pyspark/ml/torch/distributor.py
index b407672ac48..d0979f53d41 100644
--- a/python/pyspark/ml/torch/distributor.py
+++ b/python/pyspark/ml/torch/distributor.py
@@ -765,9 +765,19 @@ class TorchDistributor(Distributor):
         log_streaming_server = LogStreamingServer()
         self.driver_address = _get_conf(self.spark, "spark.driver.host", "")
         assert self.driver_address != ""
-        log_streaming_server.start(spark_host_address=self.driver_address)
-        time.sleep(1)  # wait for the server to start
-        self.log_streaming_server_port = log_streaming_server.port
+        try:
+            log_streaming_server.start(spark_host_address=self.driver_address)
+            time.sleep(1)  # wait for the server to start
+            self.log_streaming_server_port = log_streaming_server.port
+        except Exception as e:
+            # If starting the log streaming server fails, don't break the
+            # distributor training; emit a warning instead.
+            self.log_streaming_server_port = -1
+            self.logger.warning(
+                "Starting the torch distributor log streaming server failed; "
+                "you will not receive logs sent from distributor workers. "
+                f"Error: {repr(e)}."
+            )
 
         try:
             spark_task_function = self._get_spark_task_function(
diff --git a/python/pyspark/ml/torch/log_communication.py b/python/pyspark/ml/torch/log_communication.py
index ca91121d3e3..8efa83e62c3 100644
--- a/python/pyspark/ml/torch/log_communication.py
+++ b/python/pyspark/ml/torch/log_communication.py
@@ -156,6 +156,9 @@ class LogStreamingClient(LogStreamingClientBase):
         warnings.warn(f"{error_msg}: {traceback.format_exc()}\n")
 
     def _connect(self) -> None:
+        if self.port == -1:
+            self._fail("Log streaming server is not available.")
+            return
         try:
             sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
             sock.settimeout(self.timeout)
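
Taken together, the two hunks mean the driver may hand workers a port of -1, and the worker-side client should treat that value as "no log streaming server" and degrade gracefully. A rough, hypothetical end-to-end sketch of such a worker-side guard (an `OptionalLogClient` class invented here for illustration, not the actual `LogStreamingClient` internals):

```python
import socket

class OptionalLogClient:
    """Sketch of a worker-side client that degrades gracefully when the server is absent."""

    def __init__(self, host: str, port: int, timeout: float = 10.0):
        self.host = host
        self.port = port
        self.timeout = timeout
        self.sock = None
        self.failed = False

    def connect(self) -> None:
        if self.port == -1:
            # Sentinel from the driver: the log streaming server never started.
            self.failed = True
            return
        try:
            self.sock = socket.create_connection((self.host, self.port), timeout=self.timeout)
        except OSError:
            self.failed = True

    def send(self, message: str) -> None:
        if self.failed or self.sock is None:
            return  # drop the log line instead of breaking training
        self.sock.sendall(message.encode("utf-8"))
```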


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
