This is an automated email from the ASF dual-hosted git repository.

weichenxu123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2ab70576d68 [SPARK-44264][ML][PYTHON] Incorporating FunctionPickler Into TorchDistributor
2ab70576d68 is described below

commit 2ab70576d68d07aa69dc1e5a9264e0c9e9f05df0
Author: Mathew Jacob <mathew.ja...@databricks.com>
AuthorDate: Wed Jul 19 08:29:33 2023 +0800

    [SPARK-44264][ML][PYTHON] Incorporating FunctionPickler Into TorchDistributor
    
    ### What changes were proposed in this pull request?
    For pickling the training function when running distributed training, I migrated TorchDistributor to use the FunctionPickler class instead of its own internal helper functions. This separates responsibilities more cleanly and reuses code that is already written and tested.
    
    FunctionPickler PR is [here](https://github.com/apache/spark/pull/41946).
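    
    As a rough illustration of the new API surface, the diff below replaces three internal helpers with three FunctionPickler calls: pickle the function, generate a runner script, and read back the output. The snippet is a minimal standalone sketch that assumes the FunctionPickler signatures exactly as invoked in the diff; `train_fn`, the file names, and the example arguments are hypothetical.
    
    ```python
    import os
    import tempfile

    from pyspark.ml.dl_util import FunctionPickler

    # Hypothetical training function; stands in for the user-provided callable.
    def train_fn(x: int, y: int) -> int:
        return x + y

    save_dir = tempfile.mkdtemp()

    # Pickle the function together with its args/kwargs into save_dir.
    # The "" second argument mirrors the call made in the diff below.
    pickle_file_path = FunctionPickler.pickle_fn_and_save(train_fn, "", save_dir, 2, 3)

    # Generate a small script that unpickles train_fn, runs it, and writes
    # its return value to output_file_path.
    output_file_path = os.path.join(save_dir, "output.pkl")
    script_path = os.path.join(save_dir, "train.py")
    train_file_path = FunctionPickler.create_fn_run_script(
        pickle_file_path, output_file_path, script_path
    )

    # After the generated script has been executed (e.g. via torchrun),
    # the result can be read back:
    # output = FunctionPickler.get_fn_output(output_file_path)
    ```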
    
    ### Why are the changes needed?
    This better separates responsibilities, giving the TorchDistributor class a narrower, more focused role.
    
    ### Does this PR introduce _any_ user-facing change?
    No, this is all internal.
    
    ### How was this patch tested?
    Existing TorchDistributor tests.
    
    Closes #42045 from mathewjacob1002/integrate_fn_pickler.
    
    Authored-by: Mathew Jacob <mathew.ja...@databricks.com>
    Signed-off-by: Weichen Xu <weichen...@databricks.com>
---
 python/pyspark/ml/torch/distributor.py | 49 +++++-----------------------------
 1 file changed, 7 insertions(+), 42 deletions(-)

diff --git a/python/pyspark/ml/torch/distributor.py b/python/pyspark/ml/torch/distributor.py
index 5b5d7c2288a..81c9d03abcf 100644
--- a/python/pyspark/ml/torch/distributor.py
+++ b/python/pyspark/ml/torch/distributor.py
@@ -48,6 +48,7 @@ from pyspark.ml.torch.log_communication import (  # type: ignore
     LogStreamingClient,
     LogStreamingServer,
 )
+from pyspark.ml.dl_util import FunctionPickler
 from pyspark.ml.util import _get_active_session
 
 
@@ -746,12 +747,13 @@ class TorchDistributor(Distributor):
         train_fn: Callable, *args: Any, **kwargs: Any
     ) -> Generator[Tuple[str, str], None, None]:
         save_dir = TorchDistributor._create_save_dir()
-        pickle_file_path = TorchDistributor._save_pickled_function(
-            save_dir, train_fn, *args, **kwargs
+        pickle_file_path = FunctionPickler.pickle_fn_and_save(
+            train_fn, "", save_dir, *args, **kwargs
         )
         output_file_path = os.path.join(save_dir, TorchDistributor._PICKLED_OUTPUT_FILE)
-        train_file_path = TorchDistributor._create_torchrun_train_file(
-            save_dir, pickle_file_path, output_file_path
+        script_path = os.path.join(save_dir, TorchDistributor._TRAIN_FILE)
+        train_file_path = FunctionPickler.create_fn_run_script(
+            pickle_file_path, output_file_path, script_path
         )
         try:
             yield (train_file_path, output_file_path)
@@ -817,7 +819,7 @@ class TorchDistributor(Distributor):
                     "View stdout logs for detailed error message."
                 )
             try:
-                output = TorchDistributor._get_pickled_output(output_file_path)
+                output = FunctionPickler.get_fn_output(output_file_path)
             except Exception as e:
                 raise RuntimeError(
                     "TorchDistributor failed due to a pickling error. "
@@ -834,43 +836,6 @@ class TorchDistributor(Distributor):
     def _cleanup_files(save_dir: str) -> None:
         shutil.rmtree(save_dir, ignore_errors=True)
 
-    @staticmethod
-    def _save_pickled_function(
-        save_dir: str, train_fn: Union[str, Callable], *args: Any, **kwargs: Any
-    ) -> str:
-        saved_pickle_path = os.path.join(save_dir, TorchDistributor._PICKLED_FUNC_FILE)
-        with open(saved_pickle_path, "wb") as f:
-            cloudpickle.dump((train_fn, args, kwargs), f)
-        return saved_pickle_path
-
-    @staticmethod
-    def _create_torchrun_train_file(
-        save_dir_path: str, pickle_file_path: str, output_file_path: str
-    ) -> str:
-        code = textwrap.dedent(
-            f"""
-                    from pyspark import cloudpickle
-                    import os
-
-                    if __name__ == "__main__":
-                        with open("{pickle_file_path}", "rb") as f:
-                            train_fn, args, kwargs = cloudpickle.load(f)
-                        output = train_fn(*args, **kwargs)
-                        with open("{output_file_path}", "wb") as f:
-                            cloudpickle.dump(output, f)
-                    """
-        )
-        saved_file_path = os.path.join(save_dir_path, TorchDistributor._TRAIN_FILE)
-        with open(saved_file_path, "w") as f:
-            f.write(code)
-        return saved_file_path
-
-    @staticmethod
-    def _get_pickled_output(output_file_path: str) -> Any:
-        with open(output_file_path, "rb") as f:
-            output = cloudpickle.load(f)
-        return output
-
     def run(self, train_object: Union[Callable, str], *args: Any, **kwargs: Any) -> Optional[Any]:
         """Runs distributed training.
 

