[ 
https://issues.apache.org/jira/browse/SPARK-41775?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel
 ]

Rithwik Ediga Lakhamsani updated SPARK-41775:
---------------------------------------------
    Description: 
Sidenote: make formatting updates described in 
https://github.com/apache/spark/pull/39188

 

Currently, `Distributor().run(...)` takes only files as input. We will now add 
functionality to accept functions as well. This will require the following 
process on each task in the executor nodes:
1. Take the input function and args and pickle them
2. Create a temp train.py file that looks like
{code:python}
import cloudpickle
import os

if __name__ == "__main__":
    with open(f"{tempdir}/train_input.pkl", "rb") as f:
        train, args = cloudpickle.load(f)
    output = train(*args)
    if output is not None and os.environ.get("RANK", "") == "0":  # partitionId == 0
        with open(f"{tempdir}/train_output.pkl", "wb") as f:
            cloudpickle.dump(output, f) {code}
3. Run that train.py file with `torchrun`

4. Check whether `train_output.pkl` has been created by the process with 
partitionId == 0; if it has, deserialize it and return that output through 
`.collect()`
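Taken together, steps 1-4 amount to a small driver-side loop: pickle the function and args, materialize the temp train.py, launch it, and read back the result written by rank 0. Below is a minimal single-process sketch of that flow. It is illustrative only: it uses stdlib `pickle` in place of `cloudpickle` and plain `python` in place of `torchrun`, and the names `run_with_pickled_input` and `TRAIN_PY` are hypothetical, not part of the Distributor API.

```python
import os
import pickle  # the real flow uses cloudpickle to handle arbitrary closures
import subprocess
import sys
import tempfile

# Body of the generated temp train.py (step 2), as a template string.
# It loads the pickled (function, args) pair, runs it, and writes the
# output only on the RANK == 0 process.
TRAIN_PY = """
import os, pickle, sys
tempdir = sys.argv[1]
with open(f"{tempdir}/train_input.pkl", "rb") as f:
    train, args = pickle.load(f)
output = train(*args)
if output is not None and os.environ.get("RANK", "") == "0":
    with open(f"{tempdir}/train_output.pkl", "wb") as f:
        pickle.dump(output, f)
"""

def run_with_pickled_input(train_fn, *args):
    """Hypothetical driver-side sketch of steps 1-4."""
    with tempfile.TemporaryDirectory() as tempdir:
        # 1. pickle the input function and its args together
        with open(os.path.join(tempdir, "train_input.pkl"), "wb") as f:
            pickle.dump((train_fn, args), f)
        # 2. create the temp train.py file
        script = os.path.join(tempdir, "train.py")
        with open(script, "w") as f:
            f.write(TRAIN_PY)
        # 3. launch it (torchrun in the real flow; plain python here,
        #    with RANK set manually to exercise the rank-0 output path)
        env = dict(os.environ, RANK="0")
        subprocess.run([sys.executable, script, tempdir], check=True, env=env)
        # 4. if the rank-0 process wrote an output file, deserialize it
        out_path = os.path.join(tempdir, "train_output.pkl")
        if os.path.exists(out_path):
            with open(out_path, "rb") as f:
                return pickle.load(f)
        return None
```

Note that stdlib `pickle` serializes functions by reference (module + qualified name), so this sketch only works for importable functions; `cloudpickle` is what lets the real implementation ship user-defined closures to the executors.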

  was:
Currently, `Distributor().run(...)` takes only files as input. We will now add 
functionality to accept functions as well. This will require the following 
process on each task in the executor nodes:
1. Take the input function and args and pickle them
2. Create a temp train.py file that looks like
{code:python}
import cloudpickle
import os

if __name__ == "__main__":
    with open(f"{tempdir}/train_input.pkl", "rb") as f:
        train, args = cloudpickle.load(f)
    output = train(*args)
    if output is not None and os.environ.get("RANK", "") == "0":  # partitionId == 0
        with open(f"{tempdir}/train_output.pkl", "wb") as f:
            cloudpickle.dump(output, f) {code}
3. Run that train.py file with `torchrun`

4. Check whether `train_output.pkl` has been created by the process with 
partitionId == 0; if it has, deserialize it and return that output through 
`.collect()`


> Implement training functions as input
> -------------------------------------
>
>                 Key: SPARK-41775
>                 URL: https://issues.apache.org/jira/browse/SPARK-41775
>             Project: Spark
>          Issue Type: Sub-task
>          Components: ML, PySpark
>    Affects Versions: 3.4.0
>            Reporter: Rithwik Ediga Lakhamsani
>            Priority: Major
>
> Sidenote: make formatting updates described in 
> https://github.com/apache/spark/pull/39188
>  
> Currently, `Distributor().run(...)` takes only files as input. We will now 
> add functionality to accept functions as well. This will require the 
> following process on each task in the executor nodes:
> 1. Take the input function and args and pickle them
> 2. Create a temp train.py file that looks like
> {code:python}
> import cloudpickle
> import os
> 
> if __name__ == "__main__":
>     with open(f"{tempdir}/train_input.pkl", "rb") as f:
>         train, args = cloudpickle.load(f)
>     output = train(*args)
>     if output is not None and os.environ.get("RANK", "") == "0":  # partitionId == 0
>         with open(f"{tempdir}/train_output.pkl", "wb") as f:
>             cloudpickle.dump(output, f) {code}
> 3. Run that train.py file with `torchrun`
> 4. Check whether `train_output.pkl` has been created by the process with 
> partitionId == 0; if it has, deserialize it and return that output through 
> `.collect()`



--
This message was sent by Atlassian Jira
(v8.20.10#820010)

---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscr...@spark.apache.org
For additional commands, e-mail: issues-h...@spark.apache.org
