This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8e39b7f662cb [SPARK-51197][ML][PYTHON][CONNECT][TESTS] Unit test clean up
8e39b7f662cb is described below

commit 8e39b7f662cb83961b8a57e53e873d9640b379b8
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu Feb 13 14:25:53 2025 +0800

    [SPARK-51197][ML][PYTHON][CONNECT][TESTS] Unit test clean up
    
    ### What changes were proposed in this pull request?
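    Clean up the PySpark ML unit tests. The classic test classes now extend `ReusedSQLTestCase`, and the Spark Connect parity test classes now extend `ReusedConnectTestCase`; previously each class created and stopped its own `SparkSession` in `setUp`/`tearDown`. The shared `df` property of `RegressionTestsMixin` is also removed, and its DataFrame is now built inline in each test that used it.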
    
    ### Why are the changes needed?
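    To remove duplicated `SparkSession` setup/teardown boilerplate from the ML test code by relying on the shared test base classes.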
    
    ### Does this PR introduce _any_ user-facing change?
    no, test-only
    
    ### How was this patch tested?
    ci
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #49927 from zhengruifeng/ml_connect_test_cleanup.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../ml/tests/connect/test_parity_classification.py | 18 +----
 .../ml/tests/connect/test_parity_clustering.py     | 18 +----
 .../ml/tests/connect/test_parity_evaluation.py     | 18 +----
 .../ml/tests/connect/test_parity_regression.py     | 18 +----
 python/pyspark/ml/tests/test_classification.py     | 11 +--
 python/pyspark/ml/tests/test_clustering.py         | 10 +--
 python/pyspark/ml/tests/test_evaluation.py         | 11 +--
 python/pyspark/ml/tests/test_feature.py            |  4 +-
 python/pyspark/ml/tests/test_fpm.py                | 11 +--
 python/pyspark/ml/tests/test_regression.py         | 87 +++++++++++++++-------
 10 files changed, 88 insertions(+), 118 deletions(-)

diff --git a/python/pyspark/ml/tests/connect/test_parity_classification.py b/python/pyspark/ml/tests/connect/test_parity_classification.py
index ae358f70b184..3c7e8ff71a2d 100644
--- a/python/pyspark/ml/tests/connect/test_parity_classification.py
+++ b/python/pyspark/ml/tests/connect/test_parity_classification.py
@@ -15,26 +15,14 @@
 # limitations under the License.
 #
 
-import os
 import unittest
 
 from pyspark.ml.tests.test_classification import ClassificationTestsMixin
-from pyspark.sql import SparkSession
+from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
-class ClassificationParityTests(ClassificationTestsMixin, unittest.TestCase):
-    def setUp(self) -> None:
-        self.spark = SparkSession.builder.remote(
-            os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
-        ).getOrCreate()
-
-    def test_assert_remote_mode(self):
-        from pyspark.sql import is_remote
-
-        self.assertTrue(is_remote())
-
-    def tearDown(self) -> None:
-        self.spark.stop()
+class ClassificationParityTests(ClassificationTestsMixin, ReusedConnectTestCase):
+    pass
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/ml/tests/connect/test_parity_clustering.py b/python/pyspark/ml/tests/connect/test_parity_clustering.py
index 0297ce11c3c1..99714b0d6962 100644
--- a/python/pyspark/ml/tests/connect/test_parity_clustering.py
+++ b/python/pyspark/ml/tests/connect/test_parity_clustering.py
@@ -15,26 +15,14 @@
 # limitations under the License.
 #
 
-import os
 import unittest
 
 from pyspark.ml.tests.test_clustering import ClusteringTestsMixin
-from pyspark.sql import SparkSession
+from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
-class ClusteringParityTests(ClusteringTestsMixin, unittest.TestCase):
-    def setUp(self) -> None:
-        self.spark = SparkSession.builder.remote(
-            os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
-        ).getOrCreate()
-
-    def test_assert_remote_mode(self):
-        from pyspark.sql import is_remote
-
-        self.assertTrue(is_remote())
-
-    def tearDown(self) -> None:
-        self.spark.stop()
+class ClusteringParityTests(ClusteringTestsMixin, ReusedConnectTestCase):
+    pass
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/ml/tests/connect/test_parity_evaluation.py b/python/pyspark/ml/tests/connect/test_parity_evaluation.py
index 9f78313a318e..0325528da37b 100644
--- a/python/pyspark/ml/tests/connect/test_parity_evaluation.py
+++ b/python/pyspark/ml/tests/connect/test_parity_evaluation.py
@@ -15,26 +15,14 @@
 # limitations under the License.
 #
 
-import os
 import unittest
 
 from pyspark.ml.tests.test_evaluation import EvaluatorTestsMixin
-from pyspark.sql import SparkSession
+from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
-class EvaluatorParityTests(EvaluatorTestsMixin, unittest.TestCase):
-    def setUp(self) -> None:
-        self.spark = SparkSession.builder.remote(
-            os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
-        ).getOrCreate()
-
-    def test_assert_remote_mode(self):
-        from pyspark.sql import is_remote
-
-        self.assertTrue(is_remote())
-
-    def tearDown(self) -> None:
-        self.spark.stop()
+class EvaluatorParityTests(EvaluatorTestsMixin, ReusedConnectTestCase):
+    pass
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/ml/tests/connect/test_parity_regression.py b/python/pyspark/ml/tests/connect/test_parity_regression.py
index 67187bb74bd5..7c2743a938fa 100644
--- a/python/pyspark/ml/tests/connect/test_parity_regression.py
+++ b/python/pyspark/ml/tests/connect/test_parity_regression.py
@@ -15,26 +15,14 @@
 # limitations under the License.
 #
 
-import os
 import unittest
 
 from pyspark.ml.tests.test_regression import RegressionTestsMixin
-from pyspark.sql import SparkSession
+from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
-class RegressionParityTests(RegressionTestsMixin, unittest.TestCase):
-    def setUp(self) -> None:
-        self.spark = SparkSession.builder.remote(
-            os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
-        ).getOrCreate()
-
-    def test_assert_remote_mode(self):
-        from pyspark.sql import is_remote
-
-        self.assertTrue(is_remote())
-
-    def tearDown(self) -> None:
-        self.spark.stop()
+class RegressionParityTests(RegressionTestsMixin, ReusedConnectTestCase):
+    pass
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/ml/tests/test_classification.py b/python/pyspark/ml/tests/test_classification.py
index a5b4238b38ee..5c793dc344c7 100644
--- a/python/pyspark/ml/tests/test_classification.py
+++ b/python/pyspark/ml/tests/test_classification.py
@@ -22,7 +22,7 @@ from shutil import rmtree
 import numpy as np
 
 from pyspark.ml.linalg import Vectors, Matrices
-from pyspark.sql import SparkSession, DataFrame, Row
+from pyspark.sql import DataFrame, Row
 from pyspark.ml.classification import (
     NaiveBayes,
     NaiveBayesModel,
@@ -54,6 +54,7 @@ from pyspark.ml.classification import (
     MultilayerPerceptronClassificationTrainingSummary,
 )
 from pyspark.ml.regression import DecisionTreeRegressionModel
+from pyspark.testing.sqlutils import ReusedSQLTestCase
 
 
 class ClassificationTestsMixin:
@@ -978,12 +979,8 @@ class ClassificationTestsMixin:
             self.assertEqual(str(model), str(model2))
 
 
-class ClassificationTests(ClassificationTestsMixin, unittest.TestCase):
-    def setUp(self) -> None:
-        self.spark = SparkSession.builder.master("local[4]").getOrCreate()
-
-    def tearDown(self) -> None:
-        self.spark.stop()
+class ClassificationTests(ClassificationTestsMixin, ReusedSQLTestCase):
+    pass
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/ml/tests/test_clustering.py b/python/pyspark/ml/tests/test_clustering.py
index 8ec1fcc48ca6..c6724a397b59 100644
--- a/python/pyspark/ml/tests/test_clustering.py
+++ b/python/pyspark/ml/tests/test_clustering.py
@@ -21,7 +21,6 @@ import unittest
 import numpy as np
 
 from pyspark.ml.linalg import Vectors, SparseVector
-from pyspark.sql import SparkSession
 from pyspark.ml.clustering import (
     KMeans,
     KMeansModel,
@@ -38,6 +37,7 @@ from pyspark.ml.clustering import (
     DistributedLDAModel,
     PowerIterationClustering,
 )
+from pyspark.testing.sqlutils import ReusedSQLTestCase
 
 
 class ClusteringTestsMixin:
@@ -506,12 +506,8 @@ class ClusteringTestsMixin:
             self.assertEqual(pic.getWeightCol(), pic2.getWeightCol())
 
 
-class ClusteringTests(ClusteringTestsMixin, unittest.TestCase):
-    def setUp(self) -> None:
-        self.spark = SparkSession.builder.master("local[4]").getOrCreate()
-
-    def tearDown(self) -> None:
-        self.spark.stop()
+class ClusteringTests(ClusteringTestsMixin, ReusedSQLTestCase):
+    pass
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/ml/tests/test_evaluation.py b/python/pyspark/ml/tests/test_evaluation.py
index 3cd84ba528a7..ab552412103a 100644
--- a/python/pyspark/ml/tests/test_evaluation.py
+++ b/python/pyspark/ml/tests/test_evaluation.py
@@ -28,7 +28,8 @@ from pyspark.ml.evaluation import (
     RankingEvaluator,
 )
 from pyspark.ml.linalg import Vectors
-from pyspark.sql import Row, SparkSession
+from pyspark.sql import Row
+from pyspark.testing.sqlutils import ReusedSQLTestCase
 
 
 class EvaluatorTestsMixin:
@@ -355,13 +356,7 @@ class EvaluatorTestsMixin:
             self.assertTrue(evaluator.isLargerBetter())
 
 
-class EvaluatorTests(EvaluatorTestsMixin, unittest.TestCase):
-    def setUp(self) -> None:
-        self.spark = SparkSession.builder.master("local[4]").getOrCreate()
-
-    def tearDown(self) -> None:
-        self.spark.stop()
-
+class EvaluatorTests(EvaluatorTestsMixin, ReusedSQLTestCase):
     def test_evaluate_invalid_type(self):
         evaluator = RegressionEvaluator(metricName="r2")
         df = self.spark.createDataFrame([Row(label=1.0, prediction=1.1)])
diff --git a/python/pyspark/ml/tests/test_feature.py b/python/pyspark/ml/tests/test_feature.py
index cd2011c5ec87..c48f645a88f6 100644
--- a/python/pyspark/ml/tests/test_feature.py
+++ b/python/pyspark/ml/tests/test_feature.py
@@ -83,7 +83,7 @@ from pyspark.ml.feature import (
 )
 from pyspark.ml.linalg import DenseVector, SparseVector, Vectors
 from pyspark.sql import Row
-from pyspark.testing.mlutils import SparkSessionTestCase
+from pyspark.testing.sqlutils import ReusedSQLTestCase
 
 
 class FeatureTestsMixin:
@@ -1772,7 +1772,7 @@ class FeatureTestsMixin:
             self.assertEqual(str(model), str(model2))
 
 
-class FeatureTests(FeatureTestsMixin, SparkSessionTestCase):
+class FeatureTests(FeatureTestsMixin, ReusedSQLTestCase):
     pass
 
 
diff --git a/python/pyspark/ml/tests/test_fpm.py b/python/pyspark/ml/tests/test_fpm.py
index 1c2b717b1c85..ea94216c9860 100644
--- a/python/pyspark/ml/tests/test_fpm.py
+++ b/python/pyspark/ml/tests/test_fpm.py
@@ -18,13 +18,14 @@
 import tempfile
 import unittest
 
-from pyspark.sql import SparkSession, Row
+from pyspark.sql import Row
 import pyspark.sql.functions as sf
 from pyspark.ml.fpm import (
     FPGrowth,
     FPGrowthModel,
     PrefixSpan,
 )
+from pyspark.testing.sqlutils import ReusedSQLTestCase
 
 
 class FPMTestsMixin:
@@ -99,12 +100,8 @@ class FPMTestsMixin:
         self.assertEqual(head.freq, 3)
 
 
-class FPMTests(FPMTestsMixin, unittest.TestCase):
-    def setUp(self) -> None:
-        self.spark = SparkSession.builder.master("local[4]").getOrCreate()
-
-    def tearDown(self) -> None:
-        self.spark.stop()
+class FPMTests(FPMTestsMixin, ReusedSQLTestCase):
+    pass
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/ml/tests/test_regression.py b/python/pyspark/ml/tests/test_regression.py
index d450cdf3d15c..8638fb4d6078 100644
--- a/python/pyspark/ml/tests/test_regression.py
+++ b/python/pyspark/ml/tests/test_regression.py
@@ -21,7 +21,6 @@ import unittest
 import numpy as np
 
 from pyspark.ml.linalg import Vectors
-from pyspark.sql import SparkSession
 from pyspark.ml.regression import (
     AFTSurvivalRegression,
     AFTSurvivalRegressionModel,
@@ -44,25 +43,10 @@ from pyspark.ml.regression import (
     GBTRegressor,
     GBTRegressionModel,
 )
+from pyspark.testing.sqlutils import ReusedSQLTestCase
 
 
 class RegressionTestsMixin:
-    @property
-    def df(self):
-        return (
-            self.spark.createDataFrame(
-                [
-                    (1.0, 1.0, Vectors.dense(0.0, 5.0)),
-                    (0.0, 2.0, Vectors.dense(1.0, 2.0)),
-                    (1.5, 3.0, Vectors.dense(2.0, 1.0)),
-                    (0.7, 4.0, Vectors.dense(1.5, 3.0)),
-                ],
-                ["label", "weight", "features"],
-            )
-            .coalesce(1)
-            .sortWithinPartitions("weight")
-        )
-
     def test_aft_survival(self):
         spark = self.spark
         df = spark.createDataFrame(
@@ -162,7 +146,21 @@ class RegressionTestsMixin:
             self.assertEqual(str(model), str(model2))
 
     def test_linear_regression(self):
-        df = self.df
+        spark = self.spark
+        df = (
+            spark.createDataFrame(
+                [
+                    (1.0, 1.0, Vectors.dense(0.0, 5.0)),
+                    (0.0, 2.0, Vectors.dense(1.0, 2.0)),
+                    (1.5, 3.0, Vectors.dense(2.0, 1.0)),
+                    (0.7, 4.0, Vectors.dense(1.5, 3.0)),
+                ],
+                ["label", "weight", "features"],
+            )
+            .coalesce(1)
+            .sortWithinPartitions("weight")
+        )
+
         lr = LinearRegression(
             regParam=0.0,
             maxIter=2,
@@ -434,7 +432,20 @@ class RegressionTestsMixin:
             self.assertEqual(str(model), str(model2))
 
     def test_decision_tree_regressor(self):
-        df = self.df
+        spark = self.spark
+        df = (
+            spark.createDataFrame(
+                [
+                    (1.0, 1.0, Vectors.dense(0.0, 5.0)),
+                    (0.0, 2.0, Vectors.dense(1.0, 2.0)),
+                    (1.5, 3.0, Vectors.dense(2.0, 1.0)),
+                    (0.7, 4.0, Vectors.dense(1.5, 3.0)),
+                ],
+                ["label", "weight", "features"],
+            )
+            .coalesce(1)
+            .sortWithinPartitions("weight")
+        )
 
         dt = DecisionTreeRegressor(
             maxDepth=2,
@@ -490,7 +501,20 @@ class RegressionTestsMixin:
             self.assertEqual(model.toDebugString, model2.toDebugString)
 
     def test_gbt_regressor(self):
-        df = self.df
+        spark = self.spark
+        df = (
+            spark.createDataFrame(
+                [
+                    (1.0, 1.0, Vectors.dense(0.0, 5.0)),
+                    (0.0, 2.0, Vectors.dense(1.0, 2.0)),
+                    (1.5, 3.0, Vectors.dense(2.0, 1.0)),
+                    (0.7, 4.0, Vectors.dense(1.5, 3.0)),
+                ],
+                ["label", "weight", "features"],
+            )
+            .coalesce(1)
+            .sortWithinPartitions("weight")
+        )
 
         gbt = GBTRegressor(
             maxIter=3,
@@ -575,7 +599,20 @@ class RegressionTestsMixin:
             self.assertEqual(model.toDebugString, model2.toDebugString)
 
     def test_random_forest_regressor(self):
-        df = self.df
+        spark = self.spark
+        df = (
+            spark.createDataFrame(
+                [
+                    (1.0, 1.0, Vectors.dense(0.0, 5.0)),
+                    (0.0, 2.0, Vectors.dense(1.0, 2.0)),
+                    (1.5, 3.0, Vectors.dense(2.0, 1.0)),
+                    (0.7, 4.0, Vectors.dense(1.5, 3.0)),
+                ],
+                ["label", "weight", "features"],
+            )
+            .coalesce(1)
+            .sortWithinPartitions("weight")
+        )
 
         rf = RandomForestRegressor(
             numTrees=3,
@@ -643,12 +680,8 @@ class RegressionTestsMixin:
             self.assertEqual(model.toDebugString, model2.toDebugString)
 
 
-class RegressionTests(RegressionTestsMixin, unittest.TestCase):
-    def setUp(self) -> None:
-        self.spark = SparkSession.builder.master("local[4]").getOrCreate()
-
-    def tearDown(self) -> None:
-        self.spark.stop()
+class RegressionTests(RegressionTestsMixin, ReusedSQLTestCase):
+    pass
 
 
 if __name__ == "__main__":


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to