[ https://issues.apache.org/jira/browse/SPARK-41775?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]

Rithwik Ediga Lakhamsani updated SPARK-41775:
---------------------------------------------
    Component/s: ML

> Implement training functions as input
> -------------------------------------
>
>                 Key: SPARK-41775
>                 URL: https://issues.apache.org/jira/browse/SPARK-41775
>             Project: Spark
>          Issue Type: Sub-task
>          Components: ML, PySpark
>    Affects Versions: 3.4.0
>            Reporter: Rithwik Ediga Lakhamsani
>            Priority: Major
>
> Currently, `Distributor().run(...)` accepts only files as input. We will add
> support for accepting functions as well. This requires the following process
> on each task in the executor nodes:
> 1. Take the input function and its args and pickle them.
> 2. Create a temporary `train.py` file that looks like:
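> ```python
> import os
>
> import cloudpickle
>
> # Assumes train.py is written into the same temp dir as train_input.pkl.
> tempdir = os.path.dirname(os.path.abspath(__file__))
>
> if __name__ == "__main__":
>     # Load the pickled training function and its args (written in step 1).
>     with open(os.path.join(tempdir, "train_input.pkl"), "rb") as f:
>         train, args = cloudpickle.load(f)
>     output = train(*args)
>     if output and os.environ.get("RANK", "") == "0":  # this is for partitionId == 0
>         with open(os.path.join(tempdir, "train_output.pkl"), "wb") as f:
>             cloudpickle.dump(output, f)
> ```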
> 3. Run that `train.py` file with `torchrun`.
> 4. Check whether `train_output.pkl` has been created by the process on
> partitionId == 0; if it has, deserialize it and return that output.
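
The four steps above can be sketched end to end in Python. This is a minimal sketch under stated assumptions, not the Spark implementation: `run`, `TRAIN_TEMPLATE`, and the `launch_cmd` parameter are hypothetical names; a string input is assumed to be a file path launched with its args as command-line arguments; and the default launcher is assumed to be `torchrun --nproc_per_node=1`. If `cloudpickle` is not installed it falls back to the standard-library `pickle` module, in which case only importable functions (not lambdas or closures) can be passed.

```python
import os
import subprocess
import tempfile
import textwrap

try:
    import cloudpickle as pickler  # serializes lambdas/closures by value
except ImportError:
    import pickle as pickler  # fallback: only importable functions work

# Source of the generated train.py (step 2); {tempdir!r} is filled in by
# str.format when the file is written.
TRAIN_TEMPLATE = textwrap.dedent(
    """\
    import os

    try:
        import cloudpickle as pickler
    except ImportError:
        import pickle as pickler

    if __name__ == "__main__":
        tempdir = {tempdir!r}
        with open(os.path.join(tempdir, "train_input.pkl"), "rb") as f:
            train, args = pickler.load(f)
        output = train(*args)
        # this is for partitionId == 0 (RANK is set per worker by torchrun)
        if output and os.environ.get("RANK", "") == "0":
            with open(os.path.join(tempdir, "train_output.pkl"), "wb") as f:
                pickler.dump(output, f)
    """
)


def run(train_object, *args, launch_cmd=("torchrun", "--nproc_per_node=1")):
    """Hypothetical per-task entry point: accepts a file path or a function."""
    if isinstance(train_object, str):
        # File input (assumed): launch the file, passing args on the CLI.
        subprocess.run([*launch_cmd, train_object, *map(str, args)], check=True)
        return None
    if not callable(train_object):
        raise TypeError("train_object must be a file path or a callable")
    with tempfile.TemporaryDirectory() as tempdir:
        # Step 1: pickle the function together with its arguments.
        with open(os.path.join(tempdir, "train_input.pkl"), "wb") as f:
            pickler.dump((train_object, args), f)
        # Step 2: write the temporary train.py next to the pickle.
        train_file = os.path.join(tempdir, "train.py")
        with open(train_file, "w") as f:
            f.write(TRAIN_TEMPLATE.format(tempdir=tempdir))
        # Step 3: run train.py with the launcher (torchrun by default).
        subprocess.run([*launch_cmd, train_file], check=True)
        # Step 4: if rank 0 wrote an output, deserialize and return it.
        output_path = os.path.join(tempdir, "train_output.pkl")
        if not os.path.exists(output_path):
            return None
        with open(output_path, "rb") as f:
            return pickler.load(f)
```

Passing `launch_cmd=(sys.executable,)` runs the generated file with plain Python, which is a convenient way to check the pickling round trip without `torchrun`; in that case `RANK` must be set in the environment for the output to be written. Because of the `if output` check, a falsy return value (such as `0` or an empty list) is never written, so `run` returns `None` for it.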



--
This message was sent by Atlassian Jira
(v8.20.10#820010)
