This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 8e39b7f662cb [SPARK-51197][ML][PYTHON][CONNECT][TESTS] Unit test clean up 8e39b7f662cb is described below commit 8e39b7f662cb83961b8a57e53e873d9640b379b8 Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Thu Feb 13 14:25:53 2025 +0800 [SPARK-51197][ML][PYTHON][CONNECT][TESTS] Unit test clean up ### What changes were proposed in this pull request? Unit test clean up ### Why are the changes needed? test code clean up ### Does this PR introduce _any_ user-facing change? no, test-only ### How was this patch tested? ci ### Was this patch authored or co-authored using generative AI tooling? no Closes #49927 from zhengruifeng/ml_connect_test_cleanup. Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- .../ml/tests/connect/test_parity_classification.py | 18 +---- .../ml/tests/connect/test_parity_clustering.py | 18 +---- .../ml/tests/connect/test_parity_evaluation.py | 18 +---- .../ml/tests/connect/test_parity_regression.py | 18 +---- python/pyspark/ml/tests/test_classification.py | 11 +-- python/pyspark/ml/tests/test_clustering.py | 10 +-- python/pyspark/ml/tests/test_evaluation.py | 11 +-- python/pyspark/ml/tests/test_feature.py | 4 +- python/pyspark/ml/tests/test_fpm.py | 11 +-- python/pyspark/ml/tests/test_regression.py | 87 +++++++++++++++------- 10 files changed, 88 insertions(+), 118 deletions(-) diff --git a/python/pyspark/ml/tests/connect/test_parity_classification.py b/python/pyspark/ml/tests/connect/test_parity_classification.py index ae358f70b184..3c7e8ff71a2d 100644 --- a/python/pyspark/ml/tests/connect/test_parity_classification.py +++ b/python/pyspark/ml/tests/connect/test_parity_classification.py @@ -15,26 +15,14 @@ # limitations under the License. # -import os import unittest from pyspark.ml.tests.test_classification import ClassificationTestsMixin -from pyspark.sql import SparkSession +from pyspark.testing.connectutils import ReusedConnectTestCase -class ClassificationParityTests(ClassificationTestsMixin, unittest.TestCase): - def setUp(self) -> None: - self.spark = SparkSession.builder.remote( - os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]") - ).getOrCreate() - - def test_assert_remote_mode(self): - from pyspark.sql import is_remote - - self.assertTrue(is_remote()) - - def tearDown(self) -> None: - self.spark.stop() +class ClassificationParityTests(ClassificationTestsMixin, ReusedConnectTestCase): + pass if __name__ == "__main__": diff --git a/python/pyspark/ml/tests/connect/test_parity_clustering.py b/python/pyspark/ml/tests/connect/test_parity_clustering.py index 0297ce11c3c1..99714b0d6962 100644 --- a/python/pyspark/ml/tests/connect/test_parity_clustering.py +++ b/python/pyspark/ml/tests/connect/test_parity_clustering.py @@ -15,26 +15,14 @@ # limitations under the License. # -import os import unittest from pyspark.ml.tests.test_clustering import ClusteringTestsMixin -from pyspark.sql import SparkSession +from pyspark.testing.connectutils import ReusedConnectTestCase -class ClusteringParityTests(ClusteringTestsMixin, unittest.TestCase): - def setUp(self) -> None: - self.spark = SparkSession.builder.remote( - os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]") - ).getOrCreate() - - def test_assert_remote_mode(self): - from pyspark.sql import is_remote - - self.assertTrue(is_remote()) - - def tearDown(self) -> None: - self.spark.stop() +class ClusteringParityTests(ClusteringTestsMixin, ReusedConnectTestCase): + pass if __name__ == "__main__": diff --git a/python/pyspark/ml/tests/connect/test_parity_evaluation.py b/python/pyspark/ml/tests/connect/test_parity_evaluation.py index 9f78313a318e..0325528da37b 100644 --- a/python/pyspark/ml/tests/connect/test_parity_evaluation.py +++ b/python/pyspark/ml/tests/connect/test_parity_evaluation.py @@ -15,26 +15,14 @@ # limitations under the License. # -import os import unittest from pyspark.ml.tests.test_evaluation import EvaluatorTestsMixin -from pyspark.sql import SparkSession +from pyspark.testing.connectutils import ReusedConnectTestCase -class EvaluatorParityTests(EvaluatorTestsMixin, unittest.TestCase): - def setUp(self) -> None: - self.spark = SparkSession.builder.remote( - os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]") - ).getOrCreate() - - def test_assert_remote_mode(self): - from pyspark.sql import is_remote - - self.assertTrue(is_remote()) - - def tearDown(self) -> None: - self.spark.stop() +class EvaluatorParityTests(EvaluatorTestsMixin, ReusedConnectTestCase): + pass if __name__ == "__main__": diff --git a/python/pyspark/ml/tests/connect/test_parity_regression.py b/python/pyspark/ml/tests/connect/test_parity_regression.py index 67187bb74bd5..7c2743a938fa 100644 --- a/python/pyspark/ml/tests/connect/test_parity_regression.py +++ b/python/pyspark/ml/tests/connect/test_parity_regression.py @@ -15,26 +15,14 @@ # limitations under the License. # -import os import unittest from pyspark.ml.tests.test_regression import RegressionTestsMixin -from pyspark.sql import SparkSession +from pyspark.testing.connectutils import ReusedConnectTestCase -class RegressionParityTests(RegressionTestsMixin, unittest.TestCase): - def setUp(self) -> None: - self.spark = SparkSession.builder.remote( - os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]") - ).getOrCreate() - - def test_assert_remote_mode(self): - from pyspark.sql import is_remote - - self.assertTrue(is_remote()) - - def tearDown(self) -> None: - self.spark.stop() +class RegressionParityTests(RegressionTestsMixin, ReusedConnectTestCase): + pass if __name__ == "__main__": diff --git a/python/pyspark/ml/tests/test_classification.py b/python/pyspark/ml/tests/test_classification.py index a5b4238b38ee..5c793dc344c7 100644 --- a/python/pyspark/ml/tests/test_classification.py +++ b/python/pyspark/ml/tests/test_classification.py @@ -22,7 +22,7 @@ from shutil import rmtree import numpy as np from pyspark.ml.linalg import Vectors, Matrices -from pyspark.sql import SparkSession, DataFrame, Row +from pyspark.sql import DataFrame, Row from pyspark.ml.classification import ( NaiveBayes, NaiveBayesModel, @@ -54,6 +54,7 @@ from pyspark.ml.classification import ( MultilayerPerceptronClassificationTrainingSummary, ) from pyspark.ml.regression import DecisionTreeRegressionModel +from pyspark.testing.sqlutils import ReusedSQLTestCase class ClassificationTestsMixin: @@ -978,12 +979,8 @@ class ClassificationTestsMixin: self.assertEqual(str(model), str(model2)) -class ClassificationTests(ClassificationTestsMixin, unittest.TestCase): - def setUp(self) -> None: - self.spark = SparkSession.builder.master("local[4]").getOrCreate() - - def tearDown(self) -> None: - self.spark.stop() +class ClassificationTests(ClassificationTestsMixin, ReusedSQLTestCase): + pass if __name__ == "__main__": diff --git a/python/pyspark/ml/tests/test_clustering.py b/python/pyspark/ml/tests/test_clustering.py index 8ec1fcc48ca6..c6724a397b59 100644 --- a/python/pyspark/ml/tests/test_clustering.py +++ b/python/pyspark/ml/tests/test_clustering.py @@ -21,7 +21,6 @@ import unittest import numpy as np from pyspark.ml.linalg import Vectors, SparseVector -from pyspark.sql import SparkSession from pyspark.ml.clustering import ( KMeans, KMeansModel, @@ -38,6 +37,7 @@ from pyspark.ml.clustering import ( DistributedLDAModel, PowerIterationClustering, ) +from pyspark.testing.sqlutils import ReusedSQLTestCase class ClusteringTestsMixin: @@ -506,12 +506,8 @@ class ClusteringTestsMixin: self.assertEqual(pic.getWeightCol(), pic2.getWeightCol()) -class ClusteringTests(ClusteringTestsMixin, unittest.TestCase): - def setUp(self) -> None: - self.spark = SparkSession.builder.master("local[4]").getOrCreate() - - def tearDown(self) -> None: - self.spark.stop() +class ClusteringTests(ClusteringTestsMixin, ReusedSQLTestCase): + pass if __name__ == "__main__": diff --git a/python/pyspark/ml/tests/test_evaluation.py b/python/pyspark/ml/tests/test_evaluation.py index 3cd84ba528a7..ab552412103a 100644 --- a/python/pyspark/ml/tests/test_evaluation.py +++ b/python/pyspark/ml/tests/test_evaluation.py @@ -28,7 +28,8 @@ from pyspark.ml.evaluation import ( RankingEvaluator, ) from pyspark.ml.linalg import Vectors -from pyspark.sql import Row, SparkSession +from pyspark.sql import Row +from pyspark.testing.sqlutils import ReusedSQLTestCase class EvaluatorTestsMixin: @@ -355,13 +356,7 @@ class EvaluatorTestsMixin: self.assertTrue(evaluator.isLargerBetter()) -class EvaluatorTests(EvaluatorTestsMixin, unittest.TestCase): - def setUp(self) -> None: - self.spark = SparkSession.builder.master("local[4]").getOrCreate() - - def tearDown(self) -> None: - self.spark.stop() - +class EvaluatorTests(EvaluatorTestsMixin, ReusedSQLTestCase): def test_evaluate_invalid_type(self): evaluator = RegressionEvaluator(metricName="r2") df = self.spark.createDataFrame([Row(label=1.0, prediction=1.1)]) diff --git a/python/pyspark/ml/tests/test_feature.py b/python/pyspark/ml/tests/test_feature.py index cd2011c5ec87..c48f645a88f6 100644 --- a/python/pyspark/ml/tests/test_feature.py +++ b/python/pyspark/ml/tests/test_feature.py @@ -83,7 +83,7 @@ from pyspark.ml.feature import ( ) from pyspark.ml.linalg import DenseVector, SparseVector, Vectors from pyspark.sql import Row -from pyspark.testing.mlutils import SparkSessionTestCase +from pyspark.testing.sqlutils import ReusedSQLTestCase class FeatureTestsMixin: @@ -1772,7 +1772,7 @@ class FeatureTestsMixin: self.assertEqual(str(model), str(model2)) -class FeatureTests(FeatureTestsMixin, SparkSessionTestCase): +class FeatureTests(FeatureTestsMixin, ReusedSQLTestCase): pass diff --git a/python/pyspark/ml/tests/test_fpm.py b/python/pyspark/ml/tests/test_fpm.py index 1c2b717b1c85..ea94216c9860 100644 --- a/python/pyspark/ml/tests/test_fpm.py +++ b/python/pyspark/ml/tests/test_fpm.py @@ -18,13 +18,14 @@ import tempfile import unittest -from pyspark.sql import SparkSession, Row +from pyspark.sql import Row import pyspark.sql.functions as sf from pyspark.ml.fpm import ( FPGrowth, FPGrowthModel, PrefixSpan, ) +from pyspark.testing.sqlutils import ReusedSQLTestCase class FPMTestsMixin: @@ -99,12 +100,8 @@ class FPMTestsMixin: self.assertEqual(head.freq, 3) -class FPMTests(FPMTestsMixin, unittest.TestCase): - def setUp(self) -> None: - self.spark = SparkSession.builder.master("local[4]").getOrCreate() - - def tearDown(self) -> None: - self.spark.stop() +class FPMTests(FPMTestsMixin, ReusedSQLTestCase): + pass if __name__ == "__main__": diff --git a/python/pyspark/ml/tests/test_regression.py b/python/pyspark/ml/tests/test_regression.py index d450cdf3d15c..8638fb4d6078 100644 --- a/python/pyspark/ml/tests/test_regression.py +++ b/python/pyspark/ml/tests/test_regression.py @@ -21,7 +21,6 @@ import unittest import numpy as np from pyspark.ml.linalg import Vectors -from pyspark.sql import SparkSession from pyspark.ml.regression import ( AFTSurvivalRegression, AFTSurvivalRegressionModel, @@ -44,25 +43,10 @@ from pyspark.ml.regression import ( GBTRegressor, GBTRegressionModel, ) +from pyspark.testing.sqlutils import ReusedSQLTestCase class RegressionTestsMixin: - @property - def df(self): - return ( - self.spark.createDataFrame( - [ - (1.0, 1.0, Vectors.dense(0.0, 5.0)), - (0.0, 2.0, Vectors.dense(1.0, 2.0)), - (1.5, 3.0, Vectors.dense(2.0, 1.0)), - (0.7, 4.0, Vectors.dense(1.5, 3.0)), - ], - ["label", "weight", "features"], - ) - .coalesce(1) - .sortWithinPartitions("weight") - ) - def test_aft_survival(self): spark = self.spark df = spark.createDataFrame( @@ -162,7 +146,21 @@ class RegressionTestsMixin: self.assertEqual(str(model), str(model2)) def test_linear_regression(self): - df = self.df + spark = self.spark + df = ( + spark.createDataFrame( + [ + (1.0, 1.0, Vectors.dense(0.0, 5.0)), + (0.0, 2.0, Vectors.dense(1.0, 2.0)), + (1.5, 3.0, Vectors.dense(2.0, 1.0)), + (0.7, 4.0, Vectors.dense(1.5, 3.0)), + ], + ["label", "weight", "features"], + ) + .coalesce(1) + .sortWithinPartitions("weight") + ) + lr = LinearRegression( regParam=0.0, maxIter=2, @@ -434,7 +432,20 @@ class RegressionTestsMixin: self.assertEqual(str(model), str(model2)) def test_decision_tree_regressor(self): - df = self.df + spark = self.spark + df = ( + spark.createDataFrame( + [ + (1.0, 1.0, Vectors.dense(0.0, 5.0)), + (0.0, 2.0, Vectors.dense(1.0, 2.0)), + (1.5, 3.0, Vectors.dense(2.0, 1.0)), + (0.7, 4.0, Vectors.dense(1.5, 3.0)), + ], + ["label", "weight", "features"], + ) + .coalesce(1) + .sortWithinPartitions("weight") + ) dt = DecisionTreeRegressor( maxDepth=2, @@ -490,7 +501,20 @@ class RegressionTestsMixin: self.assertEqual(model.toDebugString, model2.toDebugString) def test_gbt_regressor(self): - df = self.df + spark = self.spark + df = ( + spark.createDataFrame( + [ + (1.0, 1.0, Vectors.dense(0.0, 5.0)), + (0.0, 2.0, Vectors.dense(1.0, 2.0)), + (1.5, 3.0, Vectors.dense(2.0, 1.0)), + (0.7, 4.0, Vectors.dense(1.5, 3.0)), + ], + ["label", "weight", "features"], + ) + .coalesce(1) + .sortWithinPartitions("weight") + ) gbt = GBTRegressor( maxIter=3, @@ -575,7 +599,20 @@ class RegressionTestsMixin: self.assertEqual(model.toDebugString, model2.toDebugString) def test_random_forest_regressor(self): - df = self.df + spark = self.spark + df = ( + spark.createDataFrame( + [ + (1.0, 1.0, Vectors.dense(0.0, 5.0)), + (0.0, 2.0, Vectors.dense(1.0, 2.0)), + (1.5, 3.0, Vectors.dense(2.0, 1.0)), + (0.7, 4.0, Vectors.dense(1.5, 3.0)), + ], + ["label", "weight", "features"], + ) + .coalesce(1) + .sortWithinPartitions("weight") + ) rf = RandomForestRegressor( numTrees=3, @@ -643,12 +680,8 @@ class RegressionTestsMixin: self.assertEqual(model.toDebugString, model2.toDebugString) -class RegressionTests(RegressionTestsMixin, unittest.TestCase): - def setUp(self) -> None: - self.spark = SparkSession.builder.master("local[4]").getOrCreate() - - def tearDown(self) -> None: - self.spark.stop() +class RegressionTests(RegressionTestsMixin, ReusedSQLTestCase): + pass if __name__ == "__main__": --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org