This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/master by this push:
     new a758d6a0f9d [SPARK-43081][FOLLOW-UP][ML][CONNECT] Make torch dataloader support torch 1.x
a758d6a0f9d is described below

commit a758d6a0f9dfa32881cfcec263da0ab0c02f5c1d
Author: Weichen Xu <weichen...@databricks.com>
AuthorDate: Tue Jun 27 07:57:21 2023 -0700

    [SPARK-43081][FOLLOW-UP][ML][CONNECT] Make torch dataloader support torch 1.x

    ### What changes were proposed in this pull request?

    Make torch dataloader support torch 1.x.
    Currently, when running with torch 1.x with num_workers = 0, `prefetch_factor` is still passed to `DataLoader` and an error is raised like:
    ```
    ValueError: prefetch_factor option could only be specified in multiprocessing.let num_workers > 0 to enable multiprocessing.
    ```

    ### Why are the changes needed?

    Compatibility fix.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    Manually ran unit tests with torch 1.x.

    Closes #41751 from WeichenXu123/support-torch-1.x.

    Authored-by: Weichen Xu <weichen...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/ml/torch/distributor.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/ml/torch/distributor.py b/python/pyspark/ml/torch/distributor.py
index 9f9636e6b10..8b34acd959e 100644
--- a/python/pyspark/ml/torch/distributor.py
+++ b/python/pyspark/ml/torch/distributor.py
@@ -995,4 +995,11 @@ def _get_spark_partition_data_loader(
 
     dataset = _SparkPartitionTorchDataset(arrow_file, schema, num_samples)
 
-    return DataLoader(dataset, batch_size, num_workers=num_workers, prefetch_factor=prefetch_factor)
+    if num_workers > 0:
+        return DataLoader(
+            dataset, batch_size, num_workers=num_workers, prefetch_factor=prefetch_factor
+        )
+    else:
+        # if num_workers is zero, we cannot set `prefetch_factor` otherwise
+        # torch will raise error.
+        return DataLoader(dataset, batch_size, num_workers=num_workers)
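
For readers who want to apply the same guard outside of distributor.py, the following is a minimal sketch, not the code from this commit. It assumes a hypothetical helper named `build_loader_kwargs` that collects the keyword arguments and includes `prefetch_factor` only when worker processes are requested. The helper does not import torch, so it can be read and run on its own.

```python
from typing import Any, Dict, Optional


def build_loader_kwargs(
    batch_size: int, num_workers: int, prefetch_factor: Optional[int] = 2
) -> Dict[str, Any]:
    """Hypothetical helper (not part of pyspark): build DataLoader keyword
    arguments, passing `prefetch_factor` only for multiprocess loading."""
    kwargs: Dict[str, Any] = {"batch_size": batch_size, "num_workers": num_workers}
    if num_workers > 0 and prefetch_factor is not None:
        # Same condition as the patch: only set prefetch_factor when
        # num_workers is greater than zero.
        kwargs["prefetch_factor"] = prefetch_factor
    return kwargs


# With zero workers the option is left out entirely.
print(build_loader_kwargs(32, 0))
# With workers the option is forwarded unchanged.
print(build_loader_kwargs(32, 4))
```

Under these assumptions, the result would be unpacked into the constructor as `DataLoader(dataset, **build_loader_kwargs(batch_size, num_workers, prefetch_factor))`. Building the arguments in one place avoids duplicating the constructor call in two branches, at the cost of a slightly less explicit call site than the patch above.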