This is an automated email from the ASF dual-hosted git repository.

weichenxu123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 2ab70576d68 [SPARK-44264][ML][PYTHON] Incorporating FunctionPickler Into TorchDistributor
2ab70576d68 is described below

commit 2ab70576d68d07aa69dc1e5a9264e0c9e9f05df0
Author: Mathew Jacob <mathew.ja...@databricks.com>
AuthorDate: Wed Jul 19 08:29:33 2023 +0800

    [SPARK-44264][ML][PYTHON] Incorporating FunctionPickler Into TorchDistributor

    ### What changes were proposed in this pull request?
    I migrated TorchDistributor to use the FunctionPickler class, rather than its own internal helpers, for the pickling step when running distributed training on functions. This separates responsibilities more cleanly and reuses code that is already written and tested. The FunctionPickler PR is [here](https://github.com/apache/spark/pull/41946).

    ### Why are the changes needed?
    It separates responsibilities, leaving the TorchDistributor class with a more focused role.

    ### Does this PR introduce _any_ user-facing change?
    No, this is all internal.

    ### How was this patch tested?
    Existing TorchDistributor tests.

    Closes #42045 from mathewjacob1002/integrate_fn_pickler.

    Authored-by: Mathew Jacob <mathew.ja...@databricks.com>
    Signed-off-by: Weichen Xu <weichen...@databricks.com>
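For reviewers unfamiliar with the new helper: the FunctionPickler calls introduced in the diff below perform the same round trip that the removed private methods implemented by hand: pickle the function and its arguments, generate a runner script, and read the pickled output back. The following is a minimal sketch of that flow, not code from the patch: the scratch directory, file names, and the plain-python launch are illustrative stand-ins (TorchDistributor actually launches the generated script through torchrun), and the exact semantics of the `""` argument to `pickle_fn_and_save` are not shown in this diff.

```python
# Sketch of the round trip TorchDistributor now delegates to FunctionPickler.
# The FunctionPickler calls mirror the diff below; everything else here
# (paths, file names, subprocess launch) is an illustrative stand-in.
import os
import subprocess
import tempfile

from pyspark.ml.dl_util import FunctionPickler


def train_fn(x: int) -> int:
    # Stand-in for a user-supplied training function.
    return x + 1


save_dir = tempfile.mkdtemp()

# 1. Pickle (train_fn, args, kwargs) to a file under save_dir. The empty
#    string matches the call in the diff; its meaning is internal to
#    FunctionPickler and not shown here.
pickle_path = FunctionPickler.pickle_fn_and_save(train_fn, "", save_dir, 41)

# 2. Generate a runner script that unpickles train_fn, calls it, and pickles
#    its return value to output_path (what _create_torchrun_train_file used
#    to write by hand).
output_path = os.path.join(save_dir, "output.pkl")
script_path = os.path.join(save_dir, "train.py")
train_file = FunctionPickler.create_fn_run_script(pickle_path, output_path, script_path)

# 3. Run the generated script (TorchDistributor would use torchrun here),
#    then read the function's return value back.
subprocess.run(["python", train_file], check=True)
assert FunctionPickler.get_fn_output(output_path) == 42
```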
" @@ -834,43 +836,6 @@ class TorchDistributor(Distributor): def _cleanup_files(save_dir: str) -> None: shutil.rmtree(save_dir, ignore_errors=True) - @staticmethod - def _save_pickled_function( - save_dir: str, train_fn: Union[str, Callable], *args: Any, **kwargs: Any - ) -> str: - saved_pickle_path = os.path.join(save_dir, TorchDistributor._PICKLED_FUNC_FILE) - with open(saved_pickle_path, "wb") as f: - cloudpickle.dump((train_fn, args, kwargs), f) - return saved_pickle_path - - @staticmethod - def _create_torchrun_train_file( - save_dir_path: str, pickle_file_path: str, output_file_path: str - ) -> str: - code = textwrap.dedent( - f""" - from pyspark import cloudpickle - import os - - if __name__ == "__main__": - with open("{pickle_file_path}", "rb") as f: - train_fn, args, kwargs = cloudpickle.load(f) - output = train_fn(*args, **kwargs) - with open("{output_file_path}", "wb") as f: - cloudpickle.dump(output, f) - """ - ) - saved_file_path = os.path.join(save_dir_path, TorchDistributor._TRAIN_FILE) - with open(saved_file_path, "w") as f: - f.write(code) - return saved_file_path - - @staticmethod - def _get_pickled_output(output_file_path: str) -> Any: - with open(output_file_path, "rb") as f: - output = cloudpickle.load(f) - return output - def run(self, train_object: Union[Callable, str], *args: Any, **kwargs: Any) -> Optional[Any]: """Runs distributed training. --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org