This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c303b042966b [SPARK-47808][PYTHON][ML][TESTS] Make pyspark.ml.connect tests running without optional dependencies
c303b042966b is described below

commit c303b042966bb3851da6649fc1d7f03de5db20be
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Thu Apr 11 16:42:23 2024 +0900

    [SPARK-47808][PYTHON][ML][TESTS] Make pyspark.ml.connect tests running without optional dependencies
    
    ### What changes were proposed in this pull request?
    
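    This PR makes the `pyspark.ml.connect` tests run even when optional
    dependencies (`torch`, `scikit-learn`) are not installed: `torch` is now
    imported lazily inside the functions that use it, and the affected test
    classes are skipped when the dependency is missing.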
    
    ### Why are the changes needed?
    
    A missing optional dependency should not make the test run fail; the
    tests that need it should be skipped instead.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, test-only.
    
    ### How was this patch tested?
    
    This will be tested together with https://github.com/apache/spark/pull/45941.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #45996 from HyukjinKwon/SPARK-47808.
    
    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/ml/connect/classification.py              | 16 ++++++++++++++--
 .../ml/tests/connect/test_connect_classification.py      |  4 +++-
 python/pyspark/ml/tests/connect/test_connect_feature.py  | 13 ++++++++++++-
 python/pyspark/ml/tests/connect/test_connect_pipeline.py | 15 +++++++++++++--
 4 files changed, 42 insertions(+), 6 deletions(-)

diff --git a/python/pyspark/ml/connect/classification.py 
b/python/pyspark/ml/connect/classification.py
index 8b816f51ca27..8d8c6227eac3 100644
--- a/python/pyspark/ml/connect/classification.py
+++ b/python/pyspark/ml/connect/classification.py
@@ -17,8 +17,6 @@
 from typing import Any, Dict, Union, List, Tuple, Callable, Optional
 import math
 
-import torch
-import torch.nn as torch_nn
 import numpy as np
 import pandas as pd
 
@@ -87,6 +85,8 @@ def _train_logistic_regression_model_worker_fn(
     seed: int,
 ) -> Any:
     from pyspark.ml.torch.distributor import _get_spark_partition_data_loader
+    import torch
+    import torch.nn as torch_nn
     from torch.nn.parallel import DistributedDataParallel as DDP
     import torch.distributed
     import torch.optim as optim
@@ -216,6 +216,9 @@ class LogisticRegression(
         self._set(**kwargs)
 
     def _fit(self, dataset: Union[DataFrame, pd.DataFrame]) -> 
"LogisticRegressionModel":
+        import torch
+        import torch.nn as torch_nn
+
         if isinstance(dataset, pd.DataFrame):
             # TODO: support pandas dataframe fitting
             raise NotImplementedError("Fitting pandas dataframe is not 
supported yet.")
@@ -316,6 +319,10 @@ class LogisticRegressionModel(
         return output_cols
 
     def _get_transform_fn(self) -> Callable[["pd.Series"], Any]:
+        import torch
+
+        import torch.nn as torch_nn
+
         model_state_dict = self.torch_model.state_dict()
         num_features = self.num_features
         num_classes = self.num_classes
@@ -357,6 +364,9 @@ class LogisticRegressionModel(
         return self.__class__.__name__ + ".torch"
 
     def _save_core_model(self, path: str) -> None:
+        import torch
+        import torch.nn as torch_nn
+
         lor_torch_model = torch_nn.Sequential(
             self.torch_model,
             torch_nn.Softmax(dim=1),
@@ -364,6 +374,8 @@ class LogisticRegressionModel(
         torch.save(lor_torch_model, path)
 
     def _load_core_model(self, path: str) -> None:
+        import torch
+
         lor_torch_model = torch.load(path)
         self.torch_model = lor_torch_model[0]
 
diff --git a/python/pyspark/ml/tests/connect/test_connect_classification.py 
b/python/pyspark/ml/tests/connect/test_connect_classification.py
index 1f811c774cbd..ebc1745874d9 100644
--- a/python/pyspark/ml/tests/connect/test_connect_classification.py
+++ b/python/pyspark/ml/tests/connect/test_connect_classification.py
@@ -21,6 +21,7 @@ import unittest
 from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
 
+torch_requirement_message = "torch is required"
 have_torch = True
 try:
     import torch  # noqa: F401
@@ -32,7 +33,8 @@ if should_test_connect:
 
 
 @unittest.skipIf(
-    not should_test_connect or not have_torch, connect_requirement_message or 
"torch is required"
+    not should_test_connect or not have_torch,
+    connect_requirement_message or torch_requirement_message,
 )
 class ClassificationTestsOnConnect(ClassificationTestsMixin, 
unittest.TestCase):
     def setUp(self) -> None:
diff --git a/python/pyspark/ml/tests/connect/test_connect_feature.py 
b/python/pyspark/ml/tests/connect/test_connect_feature.py
index cf450cc743ae..04b1744c4995 100644
--- a/python/pyspark/ml/tests/connect/test_connect_feature.py
+++ b/python/pyspark/ml/tests/connect/test_connect_feature.py
@@ -20,11 +20,22 @@ import unittest
 from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
 
+have_sklearn = True
+sklearn_requirement_message = None
+try:
+    from sklearn.datasets import load_breast_cancer  # noqa: F401
+except ImportError:
+    have_sklearn = False
+    sklearn_requirement_message = "No sklearn found"
+
 if should_test_connect:
     from pyspark.ml.tests.connect.test_legacy_mode_feature import 
FeatureTestsMixin
 
 
-@unittest.skipIf(not should_test_connect, connect_requirement_message)
+@unittest.skipIf(
+    not should_test_connect or not have_sklearn,
+    connect_requirement_message or sklearn_requirement_message,
+)
 class FeatureTestsOnConnect(FeatureTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
         self.spark = SparkSession.builder.remote("local[2]").getOrCreate()
diff --git a/python/pyspark/ml/tests/connect/test_connect_pipeline.py 
b/python/pyspark/ml/tests/connect/test_connect_pipeline.py
index 6a895e892397..45d19f2bcdde 100644
--- a/python/pyspark/ml/tests/connect/test_connect_pipeline.py
+++ b/python/pyspark/ml/tests/connect/test_connect_pipeline.py
@@ -24,8 +24,19 @@ from pyspark.testing.connectutils import 
should_test_connect, connect_requiremen
 if should_test_connect:
     from pyspark.ml.tests.connect.test_legacy_mode_pipeline import 
PipelineTestsMixin
 
-
-@unittest.skipIf(not should_test_connect, connect_requirement_message)
+torch_requirement_message = None
+have_torch = True
+try:
+    import torch  # noqa: F401
+except ImportError:
+    have_torch = False
+    torch_requirement_message = "torch is required"
+
+
+@unittest.skipIf(
+    not should_test_connect or not have_torch,
+    connect_requirement_message or torch_requirement_message,
+)
 class PipelineTestsOnConnect(PipelineTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
         self.spark = (


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to