This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a758d6a0f9d [SPARK-43081][FOLLOW-UP][ML][CONNECT] Make torch dataloader support torch 1.x
a758d6a0f9d is described below

commit a758d6a0f9dfa32881cfcec263da0ab0c02f5c1d
Author: Weichen Xu <weichen...@databricks.com>
AuthorDate: Tue Jun 27 07:57:21 2023 -0700

    [SPARK-43081][FOLLOW-UP][ML][CONNECT] Make torch dataloader support torch 1.x
    
    ### What changes were proposed in this pull request?
    
    Make torch dataloader support torch 1.x.
    Currently, with torch 1.x, creating the DataLoader with num_workers = 0
    raises the following error, because `prefetch_factor` is always passed:
    ```
    ValueError: prefetch_factor option could only be specified in multiprocessing.let num_workers > 0 to enable multiprocessing.
    ```
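    
    The rejected call can be modeled with a small check. This is a
    simplified sketch, not torch's source: `check_prefetch_factor` is a
    hypothetical name, and it assumes torch 1.x treats 2 as the default
    `prefetch_factor` and rejects any other value when `num_workers == 0`.
    
    ```python
    # Assumed torch 1.x default for DataLoader's prefetch_factor.
    DEFAULT_PREFETCH_FACTOR = 2


    def check_prefetch_factor(num_workers, prefetch_factor=DEFAULT_PREFETCH_FACTOR):
        """Hypothetical model of the torch 1.x DataLoader argument check."""
        if num_workers == 0 and prefetch_factor != DEFAULT_PREFETCH_FACTOR:
            # Same text as the error above, including the missing space.
            raise ValueError(
                "prefetch_factor option could only be specified in multiprocessing."
                "let num_workers > 0 to enable multiprocessing."
            )


    check_prefetch_factor(num_workers=4, prefetch_factor=8)  # accepted
    check_prefetch_factor(num_workers=0)  # accepted: default value
    ```
    
    Under this model, `check_prefetch_factor(num_workers=0, prefetch_factor=4)`
    raises the `ValueError`, which is why the patch stops passing
    `prefetch_factor` when `num_workers` is 0.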
    
    ### Why are the changes needed?
    
    Compatibility fix: without it, the data loader cannot be created on torch 1.x
    when num_workers is 0.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Manually ran the unit tests with torch 1.x.
    
    Closes #41751 from WeichenXu123/support-torch-1.x.
    
    Authored-by: Weichen Xu <weichen...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/ml/torch/distributor.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/ml/torch/distributor.py b/python/pyspark/ml/torch/distributor.py
index 9f9636e6b10..8b34acd959e 100644
--- a/python/pyspark/ml/torch/distributor.py
+++ b/python/pyspark/ml/torch/distributor.py
@@ -995,4 +995,11 @@ def _get_spark_partition_data_loader(
 
     dataset = _SparkPartitionTorchDataset(arrow_file, schema, num_samples)
 
-    return DataLoader(dataset, batch_size, num_workers=num_workers, prefetch_factor=prefetch_factor)
+    if num_workers > 0:
+        return DataLoader(
+            dataset, batch_size, num_workers=num_workers, prefetch_factor=prefetch_factor
+        )
+    else:
+        # If num_workers is zero, `prefetch_factor` must not be set, otherwise
+        # torch 1.x raises an error.
+        return DataLoader(dataset, batch_size, num_workers=num_workers)
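
An alternative with a single `DataLoader` call site is to build the keyword arguments first and only add `prefetch_factor` when worker processes are used. A minimal sketch: `dataloader_kwargs` is a hypothetical helper, not part of pyspark or torch, and its result would be passed as `DataLoader(dataset, batch_size, **kwargs)`.

```python
def dataloader_kwargs(num_workers, prefetch_factor):
    """Hypothetical helper: keyword arguments for torch.utils.data.DataLoader.

    prefetch_factor is omitted when num_workers is 0, since torch 1.x
    raises a ValueError if a non-default value is passed without workers.
    """
    kwargs = {"num_workers": num_workers}
    if num_workers > 0:
        # Only worker processes prefetch batches, so only then pass the factor.
        kwargs["prefetch_factor"] = prefetch_factor
    return kwargs


print(dataloader_kwargs(0, 2))  # {'num_workers': 0}
print(dataloader_kwargs(4, 2))  # {'num_workers': 4, 'prefetch_factor': 2}
```

This keeps the torch-version concern in one place, at the cost of a less explicit call site than the two `DataLoader(...)` branches in the patch.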

