This is an automated email from the ASF dual-hosted git repository.

weichenxu123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e1619653895 [SPARK-41593][FOLLOW-UP] Fix the case torch distributor logging server not shut down
e1619653895 is described below

commit e1619653895b4d5e11d7121bdb7906355d8c17bf
Author: Weichen Xu <weichen...@databricks.com>
AuthorDate: Tue May 30 19:13:20 2023 +0800

    [SPARK-41593][FOLLOW-UP] Fix the case torch distributor logging server not shut down
    
    ### What changes were proposed in this pull request?
    
    Fix the case where the torch distributor logging server is not shut down.
    
    `_get_spark_task_function` and `_check_encryption` may raise an exception; in that case the logging server must be shut down, but previously it was not. This PR fixes that by moving these calls inside the `try` block.
    
    ### Why are the changes needed?
    
    Fix the case where the torch distributor logging server is not shut down.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Unit tests.
    
    Closes #41375 from WeichenXu123/improve-torch-distributor-log-server-exception-handling.
    
    Authored-by: Weichen Xu <weichen...@databricks.com>
    Signed-off-by: Weichen Xu <weichen...@databricks.com>
---
 python/pyspark/ml/torch/distributor.py | 26 +++++++++++++-------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/python/pyspark/ml/torch/distributor.py b/python/pyspark/ml/torch/distributor.py
index ad8b4d8cc25..0249e6b4b2c 100644
--- a/python/pyspark/ml/torch/distributor.py
+++ b/python/pyspark/ml/torch/distributor.py
@@ -665,20 +665,20 @@ class TorchDistributor(Distributor):
         time.sleep(1)  # wait for the server to start
         self.log_streaming_server_port = log_streaming_server.port
 
-        spark_task_function = self._get_spark_task_function(
-            framework_wrapper_fn, train_object, spark_dataframe, *args, **kwargs
-        )
-        self._check_encryption()
-        self.logger.info(
-            f"Started distributed training with {self.num_processes} executor processes"
-        )
-        if spark_dataframe is not None:
-            input_df = spark_dataframe
-        else:
-            input_df = self.spark.range(
-                start=0, end=self.num_tasks, step=1, numPartitions=self.num_tasks
-            )
         try:
+            spark_task_function = self._get_spark_task_function(
+                framework_wrapper_fn, train_object, spark_dataframe, *args, **kwargs
+            )
+            self._check_encryption()
+            self.logger.info(
+                f"Started distributed training with {self.num_processes} executor processes"
+            )
+            if spark_dataframe is not None:
+                input_df = spark_dataframe
+            else:
+                input_df = self.spark.range(
+                    start=0, end=self.num_tasks, step=1, numPartitions=self.num_tasks
+                )
             rows = input_df.mapInArrow(
                 func=spark_task_function, schema="chunk binary", barrier=True
             ).collect()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
