[ https://issues.apache.org/jira/browse/SPARK-43081?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]
Weichen Xu updated SPARK-43081:
-------------------------------
    Description:
Add a torch distributor data loader that loads data from Spark partition data.

We can add 2 APIs like:

Adds a `TorchDistributor` method API:
{code:java}
    def train_on_dataframe(self, train_function, spark_dataframe, *args, **kwargs):
        """
        Runs distributed training using the provided Spark DataFrame as input data.
        You should ensure the input Spark DataFrame has evenly divided partitions.
        This method starts a barrier Spark job in which each Spark task
        processes one partition of the input Spark DataFrame.

        Parameters
        ----------
        train_function :
            Either a PyTorch function or a PyTorch Lightning function that launches
            distributed training. Note that inside the function, you can call the
            `pyspark.ml.torch.distributor.get_spark_partition_data_loader` API to get
            a torch data loader; the data loader loads data from the corresponding
            partition of the input Spark DataFrame.
        spark_dataframe :
            An input Spark DataFrame that can be used in the PyTorch `train_function`
            function. See the `train_function` argument doc for details.
        args :
            `args` need to be the input parameters to `train_function` function.
            It would look like

            >>> model = distributor.run(train, 1e-3, 64)

            where train is a function and 1e-3 and 64 are regular numeric inputs to
            the function.
        kwargs :
            `kwargs` need to be the keyword input parameters to `train_function` function.
            It would look like

            >>> model = distributor.run(train, tol=1e-3, max_iter=64)

            where train is a function that has 2 arguments `tol` and `max_iter`.

        Returns
        -------
            Returns the output of `train_function` called with args inside the Spark
            rank 0 task.
        """{code}

Adds a loader API:
{code:java}
def get_spark_partition_data_loader(num_samples, batch_size, prefetch=2):
    """
    This function must be called inside the `train_function`, where `train_function`
    is the input argument of `TorchDistributor.train_on_dataframe`.
    The function returns a PyTorch data loader that loads data from
    the corresponding Spark partition data.

    Parameters
    ----------
    num_samples :
        Number of samples to generate per epoch. If `num_samples` is less than the
        number of rows in the Spark partition, it generates the first `num_samples`
        rows of the Spark partition. If `num_samples` is greater than the number of
        rows in the Spark partition, then after the iterator has loaded all rows
        from the partition, it wraps around to the first row.
    batch_size:
        How many samples per batch to load.
    prefetch:
        Number of batches loaded in advance.
    """{code}

  was:
Add a torch distributor data loader that loads data from Spark partition data.

We can add 2 APIs like:

Adds a `TorchDistributor` method API:

```
    def train_on_dataframe(self, train_function, spark_dataframe, *args, **kwargs):
        """
        Runs distributed training using the provided Spark DataFrame as input data.
        You should ensure the input Spark DataFrame has evenly divided partitions.
        This method starts a barrier Spark job in which each Spark task
        processes one partition of the input Spark DataFrame.

        Parameters
        ----------
        train_function :
            Either a PyTorch function or a PyTorch Lightning function that launches
            distributed training. Note that inside the function, you can call the
            `pyspark.ml.torch.distributor.get_spark_partition_data_loader` API to get
            a torch data loader; the data loader loads data from the corresponding
            partition of the input Spark DataFrame.
        spark_dataframe :
            An input Spark DataFrame that can be used in the PyTorch `train_function`
            function. See the `train_function` argument doc for details.
        args :
            `args` need to be the input parameters to `train_function` function.
            It would look like

            >>> model = distributor.run(train, 1e-3, 64)

            where train is a function and 1e-3 and 64 are regular numeric inputs to
            the function.
        kwargs :
            `kwargs` need to be the keyword input parameters to `train_function` function.
            It would look like

            >>> model = distributor.run(train, tol=1e-3, max_iter=64)

            where train is a function that has 2 arguments `tol` and `max_iter`.

        Returns
        -------
            Returns the output of `train_function` called with args inside the Spark
            rank 0 task.
        """
```

Adds a loader API:

```
def get_spark_partition_data_loader(num_samples, batch_size, prefetch=2):
    """
    This function must be called inside the `train_function`, where `train_function`
    is the input argument of `TorchDistributor.train_on_dataframe`.
    The function returns a PyTorch data loader that loads data from
    the corresponding Spark partition data.

    Parameters
    ----------
    num_samples :
        Number of samples to generate per epoch. If `num_samples` is less than the
        number of rows in the Spark partition, it generates the first `num_samples`
        rows of the Spark partition. If `num_samples` is greater than the number of
        rows in the Spark partition, then after the iterator has loaded all rows
        from the partition, it wraps around to the first row.
    batch_size:
        How many samples per batch to load.
    prefetch:
        Number of batches loaded in advance.
    """
```

> Add torch distributor data loader that loads data from spark partition data
> ---------------------------------------------------------------------------
>
>                 Key: SPARK-43081
>                 URL: https://issues.apache.org/jira/browse/SPARK-43081
>             Project: Spark
>          Issue Type: Sub-task
>          Components: Connect, ML, PySpark
>    Affects Versions: 3.5.0
>            Reporter: Weichen Xu
>            Priority: Major
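
The `num_samples` behaviour described for `get_spark_partition_data_loader` (truncate when the partition has more rows, wrap around when it has fewer) can be illustrated with a small pure-Python sketch. This is not the proposed implementation: the helper names `iter_partition_samples` and `iter_batches` are hypothetical, the partition is modelled as an in-memory list, `prefetch` is not modelled, and keeping a final short batch is an assumption the description does not settle.

```python
from itertools import cycle, islice


def iter_partition_samples(rows, num_samples):
    """Return an iterator over exactly `num_samples` rows of one partition.

    Hypothetical helper: `rows` stands in for the rows of a Spark partition.
    """
    if not rows:
        raise ValueError("partition is empty")
    # islice stops early when num_samples < len(rows);
    # cycle restarts from the first row when num_samples > len(rows).
    return islice(cycle(rows), num_samples)


def iter_batches(rows, num_samples, batch_size):
    """Group one epoch of samples into lists of at most `batch_size` rows.

    Assumption: a final, shorter batch is kept rather than dropped.
    """
    samples = iter_partition_samples(rows, num_samples)
    while True:
        batch = list(islice(samples, batch_size))
        if not batch:
            return
        yield batch


partition = ["r0", "r1", "r2"]
print(list(iter_partition_samples(partition, 2)))  # ['r0', 'r1']
print(list(iter_partition_samples(partition, 5)))  # ['r0', 'r1', 'r2', 'r0', 'r1']
print(list(iter_batches(partition, 5, 2)))         # [['r0', 'r1'], ['r2', 'r0'], ['r1']]
```

Fixing `num_samples` per epoch means every task yields the same number of batches even when partition sizes differ slightly, which is presumably why the description asks for evenly divided partitions and defines the wrap-around.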