This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 71cb9306085b [SPARK-48088][PYTHON][CONNECT][TESTS][3.5] Skip tests that fail in 3.5 client <> 4.0 server
71cb9306085b is described below

commit 71cb9306085b07b63f2474e05144334cb7e4109d
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Fri May 3 15:20:29 2024 +0900

    [SPARK-48088][PYTHON][CONNECT][TESTS][3.5] Skip tests that fail in 3.5 client <> 4.0 server

    ### What changes were proposed in this pull request?

    This PR proposes to skip the tests that fail with a 3.5 client and a 4.0 server in Spark Connect (by adding `SPARK_SKIP_CONNECT_COMPAT_TESTS`). This is base work for https://github.com/apache/spark/pull/46298 and partially backports https://github.com/apache/spark/pull/45870.

    This PR also adds the `SPARK_CONNECT_TESTING_REMOTE` environment variable so developers can run the PySpark unit tests against a Spark Connect server.

    ### Why are the changes needed?

    To set up the CI that tests a 3.5 client against a 4.0 server in Spark Connect.

    ### Does this PR introduce _any_ user-facing change?

    No, test-only.

    ### How was this patch tested?

    Tested in my fork; see https://github.com/apache/spark/pull/46298.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #46334 from HyukjinKwon/SPARK-48088.

    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/ml/connect/functions.py | 3 +-
 .../tests/connect/test_connect_classification.py | 8 +-
 .../ml/tests/connect/test_connect_evaluation.py | 10 ++-
 .../ml/tests/connect/test_connect_feature.py | 7 +-
 .../ml/tests/connect/test_connect_function.py | 1 +
 .../ml/tests/connect/test_connect_pipeline.py | 5 +-
 .../ml/tests/connect/test_connect_summarizer.py | 7 +-
 .../ml/tests/connect/test_connect_tuning.py | 5 +-
 .../tests/connect/test_parity_torch_data_loader.py | 3 +-
 .../tests/connect/test_parity_torch_distributor.py | 4 +-
 python/pyspark/ml/torch/tests/test_data_loader.py | 3 +-
 python/pyspark/sql/connect/avro/functions.py | 2 +-
 python/pyspark/sql/connect/catalog.py | 5 +-
 python/pyspark/sql/connect/column.py | 5 +-
 python/pyspark/sql/connect/conf.py | 5 +-
 python/pyspark/sql/connect/dataframe.py | 2 +-
 python/pyspark/sql/connect/functions.py | 3 +-
 python/pyspark/sql/connect/group.py | 5 +-
 python/pyspark/sql/connect/protobuf/functions.py | 2 +-
 python/pyspark/sql/connect/readwriter.py | 3 +-
 python/pyspark/sql/connect/session.py | 5 +-
 python/pyspark/sql/connect/streaming/query.py | 2 +-
 python/pyspark/sql/connect/streaming/readwriter.py | 3 +-
 python/pyspark/sql/connect/window.py | 5 +-
 python/pyspark/sql/dataframe.py | 8 +-
 python/pyspark/sql/functions.py | 6 +-
 .../sql/tests/connect/client/test_artifact.py | 4 +-
 .../connect/streaming/test_parity_listener.py | 6 +-
 .../connect/streaming/test_parity_streaming.py | 9 ++-
 .../sql/tests/connect/test_connect_basic.py | 7 +-
 .../sql/tests/connect/test_connect_function.py | 1 +
 .../tests/connect/test_parity_pandas_udf_scalar.py | 5 ++
 .../pyspark/sql/tests/connect/test_parity_udtf.py | 86 ++++++++++++++++++++++
 python/pyspark/sql/tests/connect/test_utils.py | 8 +-
 python/pyspark/sql/tests/pandas/test_pandas_map.py | 2 +-
 python/pyspark/sql/tests/pandas/test_pandas_udf.py | 10 ++-
 python/pyspark/sql/tests/test_datasources.py | 6 +-
 python/pyspark/sql/tests/test_types.py | 3 +
 python/pyspark/testing/connectutils.py | 2 +-
 python/run-tests.py | 25 ++++---
 40 files changed, 234 insertions(+), 57 deletions(-)

diff --git
a/python/pyspark/ml/connect/functions.py b/python/pyspark/ml/connect/functions.py index ab7e3ab3c9ad..d8aa54dcf9be 100644 --- a/python/pyspark/ml/connect/functions.py +++ b/python/pyspark/ml/connect/functions.py @@ -39,6 +39,7 @@ array_to_vector.__doc__ = PyMLFunctions.array_to_vector.__doc__ def _test() -> None: + import os import sys import doctest from pyspark.sql import SparkSession as PySparkSession @@ -54,7 +55,7 @@ def _test() -> None: globs["spark"] = ( PySparkSession.builder.appName("ml.connect.functions tests") - .remote("local[4]") + .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) .getOrCreate() ) diff --git a/python/pyspark/ml/tests/connect/test_connect_classification.py b/python/pyspark/ml/tests/connect/test_connect_classification.py index f3e621c19f0f..2763d3f613ae 100644 --- a/python/pyspark/ml/tests/connect/test_connect_classification.py +++ b/python/pyspark/ml/tests/connect/test_connect_classification.py @@ -15,12 +15,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # - import unittest +import os + from pyspark.sql import SparkSession from pyspark.ml.tests.connect.test_legacy_mode_classification import ClassificationTestsMixin -have_torch = True +# TODO(SPARK-48083): Reenable this test case +have_torch = "SPARK_SKIP_CONNECT_COMPAT_TESTS" not in os.environ try: import torch # noqa: F401 except ImportError: @@ -31,7 +33,7 @@ except ImportError: class ClassificationTestsOnConnect(ClassificationTestsMixin, unittest.TestCase): def setUp(self) -> None: self.spark = ( - SparkSession.builder.remote("local[2]") + SparkSession.builder.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")) .config("spark.connect.copyFromLocalToFs.allowDestLocal", "true") .getOrCreate() ) diff --git a/python/pyspark/ml/tests/connect/test_connect_evaluation.py b/python/pyspark/ml/tests/connect/test_connect_evaluation.py index ce7cf03049d3..35af54605ca8 100644 --- a/python/pyspark/ml/tests/connect/test_connect_evaluation.py +++ b/python/pyspark/ml/tests/connect/test_connect_evaluation.py @@ -14,12 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +import os import unittest + from pyspark.sql import SparkSession from pyspark.ml.tests.connect.test_legacy_mode_evaluation import EvaluationTestsMixin -have_torcheval = True +# TODO(SPARK-48084): Reenable this test case +have_torcheval = "SPARK_SKIP_CONNECT_COMPAT_TESTS" not in os.environ try: import torcheval # noqa: F401 except ImportError: @@ -29,7 +31,9 @@ except ImportError: @unittest.skipIf(not have_torcheval, "torcheval is required") class EvaluationTestsOnConnect(EvaluationTestsMixin, unittest.TestCase): def setUp(self) -> None: - self.spark = SparkSession.builder.remote("local[2]").getOrCreate() + self.spark = SparkSession.builder.remote( + os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]") + ).getOrCreate() def tearDown(self) -> None: self.spark.stop() diff --git a/python/pyspark/ml/tests/connect/test_connect_feature.py b/python/pyspark/ml/tests/connect/test_connect_feature.py index d7698c377220..49021f6e82c5 100644 --- a/python/pyspark/ml/tests/connect/test_connect_feature.py +++ b/python/pyspark/ml/tests/connect/test_connect_feature.py @@ -14,15 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# - +import os import unittest + from pyspark.sql import SparkSession from pyspark.ml.tests.connect.test_legacy_mode_feature import FeatureTestsMixin class FeatureTestsOnConnect(FeatureTestsMixin, unittest.TestCase): def setUp(self) -> None: - self.spark = SparkSession.builder.remote("local[2]").getOrCreate() + self.spark = SparkSession.builder.remote( + os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]") + ).getOrCreate() def tearDown(self) -> None: self.spark.stop() diff --git a/python/pyspark/ml/tests/connect/test_connect_function.py b/python/pyspark/ml/tests/connect/test_connect_function.py index 7da3d3f1addd..fc3344ecebfe 100644 --- a/python/pyspark/ml/tests/connect/test_connect_function.py +++ b/python/pyspark/ml/tests/connect/test_connect_function.py @@ -33,6 +33,7 @@ if should_test_connect: from pyspark.ml.connect import functions as CF +@unittest.skipIf("SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Requires JVM access") class SparkConnectMLFunctionTests(ReusedConnectTestCase, PandasOnSparkTestUtils, SQLTestUtils): """These test cases exercise the interface to the proto plan generation but do not call Spark.""" diff --git a/python/pyspark/ml/tests/connect/test_connect_pipeline.py b/python/pyspark/ml/tests/connect/test_connect_pipeline.py index e676c8bfee95..dc7490bf14b1 100644 --- a/python/pyspark/ml/tests/connect/test_connect_pipeline.py +++ b/python/pyspark/ml/tests/connect/test_connect_pipeline.py @@ -15,8 +15,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +import os import unittest + from pyspark.sql import SparkSession from pyspark.ml.tests.connect.test_legacy_mode_pipeline import PipelineTestsMixin @@ -24,7 +25,7 @@ from pyspark.ml.tests.connect.test_legacy_mode_pipeline import PipelineTestsMixi class PipelineTestsOnConnect(PipelineTestsMixin, unittest.TestCase): def setUp(self) -> None: self.spark = ( - SparkSession.builder.remote("local[2]") + SparkSession.builder.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")) .config("spark.connect.copyFromLocalToFs.allowDestLocal", "true") .getOrCreate() ) diff --git a/python/pyspark/ml/tests/connect/test_connect_summarizer.py b/python/pyspark/ml/tests/connect/test_connect_summarizer.py index 0b0537dfee3c..28cfa4b4dc1b 100644 --- a/python/pyspark/ml/tests/connect/test_connect_summarizer.py +++ b/python/pyspark/ml/tests/connect/test_connect_summarizer.py @@ -14,15 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +import os import unittest + from pyspark.sql import SparkSession from pyspark.ml.tests.connect.test_legacy_mode_summarizer import SummarizerTestsMixin class SummarizerTestsOnConnect(SummarizerTestsMixin, unittest.TestCase): def setUp(self) -> None: - self.spark = SparkSession.builder.remote("local[2]").getOrCreate() + self.spark = SparkSession.builder.remote( + os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]") + ).getOrCreate() def tearDown(self) -> None: self.spark.stop() diff --git a/python/pyspark/ml/tests/connect/test_connect_tuning.py b/python/pyspark/ml/tests/connect/test_connect_tuning.py index 18673d4b26be..901367e44d20 100644 --- a/python/pyspark/ml/tests/connect/test_connect_tuning.py +++ b/python/pyspark/ml/tests/connect/test_connect_tuning.py @@ -15,16 +15,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# - +import os import unittest from pyspark.sql import SparkSession from pyspark.ml.tests.connect.test_legacy_mode_tuning import CrossValidatorTestsMixin +@unittest.skipIf("SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Requires JVM access") class CrossValidatorTestsOnConnect(CrossValidatorTestsMixin, unittest.TestCase): def setUp(self) -> None: self.spark = ( - SparkSession.builder.remote("local[2]") + SparkSession.builder.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")) .config("spark.connect.copyFromLocalToFs.allowDestLocal", "true") .getOrCreate() ) diff --git a/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py b/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py index 18556633d89f..60f683bf726c 100644 --- a/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py +++ b/python/pyspark/ml/tests/connect/test_parity_torch_data_loader.py @@ -15,10 +15,11 @@ # limitations under the License. # +import os import unittest from pyspark.sql import SparkSession -have_torch = True +have_torch = "SPARK_SKIP_CONNECT_COMPAT_TESTS" not in os.environ try: import torch # noqa: F401 except ImportError: diff --git a/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py b/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py index b855332f96c4..238775ded2a2 100644 --- a/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py +++ b/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py @@ -19,7 +19,7 @@ import os import shutil import unittest -have_torch = True +have_torch = "SPARK_SKIP_CONNECT_COMPAT_TESTS" not in os.environ try: import torch # noqa: F401 except ImportError: @@ -81,7 +81,7 @@ class TorchDistributorLocalUnitTestsOnConnect( ] -@unittest.skipIf(not have_torch, "torch is required") +@unittest.skipIf("SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Requires JVM access") class TorchDistributorLocalUnitTestsIIOnConnect( TorchDistributorLocalUnitTestsMixin, unittest.TestCase ): diff --git a/python/pyspark/ml/torch/tests/test_data_loader.py b/python/pyspark/ml/torch/tests/test_data_loader.py index 67ab6e378cea..f7814f819541 100644 --- a/python/pyspark/ml/torch/tests/test_data_loader.py +++ b/python/pyspark/ml/torch/tests/test_data_loader.py @@ -15,10 +15,11 @@ # limitations under the License. 
# +import os import numpy as np import unittest -have_torch = True +have_torch = "SPARK_SKIP_CONNECT_COMPAT_TESTS" not in os.environ try: import torch # noqa: F401 except ImportError: diff --git a/python/pyspark/sql/connect/avro/functions.py b/python/pyspark/sql/connect/avro/functions.py index bf019ef8fe7d..821660fdbd30 100644 --- a/python/pyspark/sql/connect/avro/functions.py +++ b/python/pyspark/sql/connect/avro/functions.py @@ -85,7 +85,7 @@ def _test() -> None: globs["spark"] = ( PySparkSession.builder.appName("sql.connect.avro.functions tests") - .remote("local[4]") + .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) .getOrCreate() ) diff --git a/python/pyspark/sql/connect/catalog.py b/python/pyspark/sql/connect/catalog.py index 2a54a0d727af..069a8d013ff3 100644 --- a/python/pyspark/sql/connect/catalog.py +++ b/python/pyspark/sql/connect/catalog.py @@ -326,6 +326,7 @@ Catalog.__doc__ = PySparkCatalog.__doc__ def _test() -> None: + import os import sys import doctest from pyspark.sql import SparkSession as PySparkSession @@ -333,7 +334,9 @@ def _test() -> None: globs = pyspark.sql.connect.catalog.__dict__.copy() globs["spark"] = ( - PySparkSession.builder.appName("sql.connect.catalog tests").remote("local[4]").getOrCreate() + PySparkSession.builder.appName("sql.connect.catalog tests") + .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) + .getOrCreate() ) (failure_count, test_count) = doctest.testmod( diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py index 052929381633..464f5397b85b 100644 --- a/python/pyspark/sql/connect/column.py +++ b/python/pyspark/sql/connect/column.py @@ -483,6 +483,7 @@ Column.__doc__ = PySparkColumn.__doc__ def _test() -> None: + import os import sys import doctest from pyspark.sql import SparkSession as PySparkSession @@ -490,7 +491,9 @@ def _test() -> None: globs = pyspark.sql.connect.column.__dict__.copy() globs["spark"] = ( - PySparkSession.builder.appName("sql.connect.column tests").remote("local[4]").getOrCreate() + PySparkSession.builder.appName("sql.connect.column tests") + .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) + .getOrCreate() ) (failure_count, test_count) = doctest.testmod( diff --git a/python/pyspark/sql/connect/conf.py b/python/pyspark/sql/connect/conf.py index d323de716c46..cb296a750e62 100644 --- a/python/pyspark/sql/connect/conf.py +++ b/python/pyspark/sql/connect/conf.py @@ -97,6 +97,7 @@ RuntimeConf.__doc__ = PySparkRuntimeConfig.__doc__ def _test() -> None: + import os import sys import doctest from pyspark.sql import SparkSession as PySparkSession @@ -104,7 +105,9 @@ def _test() -> None: globs = pyspark.sql.connect.conf.__dict__.copy() globs["spark"] = ( - PySparkSession.builder.appName("sql.connect.conf tests").remote("local[4]").getOrCreate() + PySparkSession.builder.appName("sql.connect.conf tests") + .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) + .getOrCreate() ) (failure_count, test_count) = doctest.testmod( diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 7b326538a8e0..ff6191642025 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -2150,7 +2150,7 @@ def _test() -> None: globs["spark"] = ( PySparkSession.builder.appName("sql.connect.dataframe tests") - .remote("local[4]") + .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) .getOrCreate() ) diff --git a/python/pyspark/sql/connect/functions.py 
b/python/pyspark/sql/connect/functions.py index e2583f84c417..ecb800bbee93 100644 --- a/python/pyspark/sql/connect/functions.py +++ b/python/pyspark/sql/connect/functions.py @@ -3906,6 +3906,7 @@ call_function.__doc__ = pysparkfuncs.call_function.__doc__ def _test() -> None: import sys + import os import doctest from pyspark.sql import SparkSession as PySparkSession import pyspark.sql.connect.functions @@ -3914,7 +3915,7 @@ def _test() -> None: globs["spark"] = ( PySparkSession.builder.appName("sql.connect.functions tests") - .remote("local[4]") + .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) .getOrCreate() ) diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py index a393d2cb37e8..2d5a66fd6ef9 100644 --- a/python/pyspark/sql/connect/group.py +++ b/python/pyspark/sql/connect/group.py @@ -388,6 +388,7 @@ PandasCogroupedOps.__doc__ = PySparkPandasCogroupedOps.__doc__ def _test() -> None: + import os import sys import doctest from pyspark.sql import SparkSession as PySparkSession @@ -396,7 +397,9 @@ def _test() -> None: globs = pyspark.sql.connect.group.__dict__.copy() globs["spark"] = ( - PySparkSession.builder.appName("sql.connect.group tests").remote("local[4]").getOrCreate() + PySparkSession.builder.appName("sql.connect.group tests") + .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) + .getOrCreate() ) (failure_count, test_count) = doctest.testmod( diff --git a/python/pyspark/sql/connect/protobuf/functions.py b/python/pyspark/sql/connect/protobuf/functions.py index 56119f4bc4eb..c8e12640b313 100644 --- a/python/pyspark/sql/connect/protobuf/functions.py +++ b/python/pyspark/sql/connect/protobuf/functions.py @@ -144,7 +144,7 @@ def _test() -> None: globs["spark"] = ( PySparkSession.builder.appName("sql.protobuf.functions tests") - .remote("local[2]") + .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")) .getOrCreate() ) diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py index cfcbcede3487..7cfdf9910d7e 100644 --- a/python/pyspark/sql/connect/readwriter.py +++ b/python/pyspark/sql/connect/readwriter.py @@ -830,6 +830,7 @@ class DataFrameWriterV2(OptionUtils): def _test() -> None: import sys + import os import doctest from pyspark.sql import SparkSession as PySparkSession import pyspark.sql.connect.readwriter @@ -838,7 +839,7 @@ def _test() -> None: globs["spark"] = ( PySparkSession.builder.appName("sql.connect.readwriter tests") - .remote("local[4]") + .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) .getOrCreate() ) diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 1307c8bdd84e..10d599ca397b 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -910,6 +910,7 @@ SparkSession.__doc__ = PySparkSession.__doc__ def _test() -> None: + import os import sys import doctest from pyspark.sql import SparkSession as PySparkSession @@ -917,7 +918,9 @@ def _test() -> None: globs = pyspark.sql.connect.session.__dict__.copy() globs["spark"] = ( - PySparkSession.builder.appName("sql.connect.session tests").remote("local[4]").getOrCreate() + PySparkSession.builder.appName("sql.connect.session tests") + .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) + .getOrCreate() ) # Uses PySpark session to test builder. 
diff --git a/python/pyspark/sql/connect/streaming/query.py b/python/pyspark/sql/connect/streaming/query.py index 021d27e939de..7d968b175f28 100644 --- a/python/pyspark/sql/connect/streaming/query.py +++ b/python/pyspark/sql/connect/streaming/query.py @@ -276,7 +276,7 @@ def _test() -> None: globs["spark"] = ( PySparkSession.builder.appName("sql.connect.streaming.query tests") - .remote("local[4]") + .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) .getOrCreate() ) diff --git a/python/pyspark/sql/connect/streaming/readwriter.py b/python/pyspark/sql/connect/streaming/readwriter.py index 89097fcf43a0..afee833fda4e 100644 --- a/python/pyspark/sql/connect/streaming/readwriter.py +++ b/python/pyspark/sql/connect/streaming/readwriter.py @@ -586,6 +586,7 @@ class DataStreamWriter: def _test() -> None: + import os import sys import doctest from pyspark.sql import SparkSession as PySparkSession @@ -595,7 +596,7 @@ def _test() -> None: globs["spark"] = ( PySparkSession.builder.appName("sql.connect.streaming.readwriter tests") - .remote("local[4]") + .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) .getOrCreate() ) diff --git a/python/pyspark/sql/connect/window.py b/python/pyspark/sql/connect/window.py index ad082c6e265d..922a641c2428 100644 --- a/python/pyspark/sql/connect/window.py +++ b/python/pyspark/sql/connect/window.py @@ -235,6 +235,7 @@ Window.__doc__ = PySparkWindow.__doc__ def _test() -> None: + import os import sys import doctest from pyspark.sql import SparkSession as PySparkSession @@ -242,7 +243,9 @@ def _test() -> None: globs = pyspark.sql.connect.window.__dict__.copy() globs["spark"] = ( - PySparkSession.builder.appName("sql.connect.window tests").remote("local[4]").getOrCreate() + PySparkSession.builder.appName("sql.connect.window tests") + .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) + .getOrCreate() ) (failure_count, test_count) = doctest.testmod( diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 97f60967da70..afa979dab019 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1513,7 +1513,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): >>> df.cache() DataFrame[id: bigint] - >>> df.explain() + >>> df.explain() # doctest: +SKIP == Physical Plan == AdaptiveSparkPlan isFinalPlan=false +- InMemoryTableScan ... @@ -1556,7 +1556,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): >>> df.persist() DataFrame[id: bigint] - >>> df.explain() + >>> df.explain() # doctest: +SKIP == Physical Plan == AdaptiveSparkPlan isFinalPlan=false +- InMemoryTableScan ... 
@@ -3887,8 +3887,8 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): >>> df2 = spark.createDataFrame([(3, "Charlie"), (4, "Dave")], ["id", "name"]) >>> df1 = df1.withColumn("age", lit(30)) >>> df2 = df2.withColumn("age", lit(40)) - >>> df3 = df1.union(df2) - >>> df3.show() + >>> df3 = df1.union(df2) # doctest: +SKIP + >>> df3.show() # doctest: +SKIP +-----+-------+---+ | name| id|age| +-----+-------+---+ diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 06cb3063d1b1..7e1a8faf0017 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7973,7 +7973,7 @@ def to_unix_timestamp( >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") >>> df = spark.createDataFrame([("2016-04-08",)], ["e"]) - >>> df.select(to_unix_timestamp(df.e).alias('r')).collect() + >>> df.select(to_unix_timestamp(df.e).alias('r')).collect() # doctest: +SKIP [Row(r=None)] >>> spark.conf.unset("spark.sql.session.timeZone") """ @@ -8084,7 +8084,7 @@ def current_database() -> Column: Examples -------- - >>> spark.range(1).select(current_database()).show() + >>> spark.range(1).select(current_database()).show() # doctest: +SKIP +------------------+ |current_database()| +------------------+ @@ -8103,7 +8103,7 @@ def current_schema() -> Column: Examples -------- >>> import pyspark.sql.functions as sf - >>> spark.range(1).select(sf.current_schema()).show() + >>> spark.range(1).select(sf.current_schema()).show() # doctest: +SKIP +------------------+ |current_database()| +------------------+ diff --git a/python/pyspark/sql/tests/connect/client/test_artifact.py b/python/pyspark/sql/tests/connect/client/test_artifact.py index d45230e926b1..cf3eea0b5560 100644 --- a/python/pyspark/sql/tests/connect/client/test_artifact.py +++ b/python/pyspark/sql/tests/connect/client/test_artifact.py @@ -146,6 +146,7 @@ class ArtifactTestsMixin: ) +@unittest.skipIf("SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Requires JVM access") class ArtifactTests(ReusedConnectTestCase, ArtifactTestsMixin): @classmethod def root(cls): @@ -389,6 +390,7 @@ class ArtifactTests(ReusedConnectTestCase, ArtifactTestsMixin): self.assertEqual(self.artifact_manager.is_cached_artifact(expected_hash), True) +@unittest.skipIf("SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Requires local-cluster") class LocalClusterArtifactTests(ReusedConnectTestCase, ArtifactTestsMixin): @classmethod def conf(cls): @@ -403,7 +405,7 @@ class LocalClusterArtifactTests(ReusedConnectTestCase, ArtifactTestsMixin): @classmethod def master(cls): - return "local-cluster[2,2,512]" + return os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local-cluster[2,2,512]") if __name__ == "__main__": diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py index 5069a76cfdb7..35ca2681cc97 100644 --- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# - +import os import unittest import time @@ -45,6 +45,10 @@ class TestListener(StreamingQueryListener): df.write.mode("append").saveAsTable("listener_terminated_events") +# TODO(SPARK-48089): Reenable this test case +@unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" +) class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTestCase): def test_listener_events(self): test_listener = TestListener() diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py b/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py index 6fe2b8940801..e7c1958064bb 100644 --- a/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py @@ -14,13 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import os +import unittest from pyspark.sql.tests.streaming.test_streaming import StreamingTestsMixin from pyspark.testing.connectutils import ReusedConnectTestCase class StreamingParityTests(StreamingTestsMixin, ReusedConnectTestCase): - pass + # TODO(SPARK-48090): Reenable this test case + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) + def test_stream_exception(self): + super(StreamingParityTests, self).test_stream_exception() if __name__ == "__main__": diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 2904eb42587e..48e5248e28f5 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -80,6 +80,7 @@ if should_test_connect: from pyspark.sql.connect.client.core import Retrying, SparkConnectClient +@unittest.skipIf("SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Requires JVM access") class SparkConnectSQLTestCase(ReusedConnectTestCase, SQLTestUtils, PandasOnSparkTestUtils): """Parent test fixture class for all Spark Connect related test cases.""" @@ -3250,12 +3251,15 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase): self.assertTrue(df.is_cached) +@unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Session creation different from local mode" +) class SparkConnectSessionTests(ReusedConnectTestCase): def setUp(self) -> None: self.spark = ( PySparkSession.builder.config(conf=self.conf()) .appName(self.__class__.__name__) - .remote("local[4]") + .remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")) .getOrCreate() ) @@ -3347,6 +3351,7 @@ class SparkConnectSessionTests(ReusedConnectTestCase): self.assertIn("Create a new SparkSession is only supported with SparkConnect.", str(e)) +@unittest.skipIf("SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Requires JVM access") class SparkConnectSessionWithOptionsTest(unittest.TestCase): def setUp(self) -> None: self.spark = ( diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py index a5d330fe1a7e..dc101e98e01d 100644 --- a/python/pyspark/sql/tests/connect/test_connect_function.py +++ b/python/pyspark/sql/tests/connect/test_connect_function.py @@ -36,6 +36,7 @@ if should_test_connect: from pyspark.sql.connect.dataframe import DataFrame as CDF +@unittest.skipIf("SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Requires JVM access") class SparkConnectFunctionTests(ReusedConnectTestCase, 
PandasOnSparkTestUtils, SQLTestUtils): """These test cases exercise the interface to the proto plan generation but do not call Spark.""" diff --git a/python/pyspark/sql/tests/connect/test_parity_pandas_udf_scalar.py b/python/pyspark/sql/tests/connect/test_parity_pandas_udf_scalar.py index c950ca2e17c3..960f7f11e873 100644 --- a/python/pyspark/sql/tests/connect/test_parity_pandas_udf_scalar.py +++ b/python/pyspark/sql/tests/connect/test_parity_pandas_udf_scalar.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import os import unittest from pyspark.sql.tests.pandas.test_pandas_udf_scalar import ScalarPandasUDFTestsMixin from pyspark.testing.connectutils import ReusedConnectTestCase @@ -31,6 +32,10 @@ class PandasUDFScalarParityTests(ScalarPandasUDFTestsMixin, ReusedConnectTestCas def test_vectorized_udf_struct_with_empty_partition(self): super().test_vectorized_udf_struct_with_empty_partition() + # TODO(SPARK-48086): Reenable this test case + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) def test_vectorized_udf_exception(self): self.check_vectorized_udf_exception() diff --git a/python/pyspark/sql/tests/connect/test_parity_udtf.py b/python/pyspark/sql/tests/connect/test_parity_udtf.py index 1222b1bb5b44..5955b502e48b 100644 --- a/python/pyspark/sql/tests/connect/test_parity_udtf.py +++ b/python/pyspark/sql/tests/connect/test_parity_udtf.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import os +import unittest from pyspark.testing.connectutils import should_test_connect if should_test_connect: @@ -57,6 +59,90 @@ class UDTFParityTests(BaseUDTFTestsMixin, ReusedConnectTestCase): ): TestUDTF(lit(1)).collect() + # TODO(SPARK-48087): Reenable this test case + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) + def test_udtf_init_with_additional_args(self): + super(UDTFParityTests, self).test_udtf_init_with_additional_args() + + # TODO(SPARK-48087): Reenable this test case + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) + def test_udtf_with_wrong_num_input(self): + super(UDTFParityTests, self).test_udtf_with_wrong_num_input() + + # TODO(SPARK-48087): Reenable this test case + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) + def test_array_output_type_casting(self): + super(UDTFParityTests, self).test_array_output_type_casting() + + # TODO(SPARK-48087): Reenable this test case + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) + def test_map_output_type_casting(self): + super(UDTFParityTests, self).test_map_output_type_casting() + + # TODO(SPARK-48087): Reenable this test case + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) + def test_numeric_output_type_casting(self): + super(UDTFParityTests, self).test_numeric_output_type_casting() + + # TODO(SPARK-48087): Reenable this test case + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) + def test_numeric_output_type_casting(self): + super(UDTFParityTests, self).test_numeric_output_type_casting() + + # TODO(SPARK-48087): Reenable this test case 
+ @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) + def test_numeric_string_output_type_casting(self): + super(UDTFParityTests, self).test_numeric_string_output_type_casting() + + # TODO(SPARK-48087): Reenable this test case + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) + def test_string_output_type_casting(self): + super(UDTFParityTests, self).test_string_output_type_casting() + + # TODO(SPARK-48087): Reenable this test case + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) + def test_string_output_type_casting(self): + super(UDTFParityTests, self).test_string_output_type_casting() + + # TODO(SPARK-48087): Reenable this test case + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) + def test_struct_output_type_casting_dict(self): + super(UDTFParityTests, self).test_struct_output_type_casting_dict() + + # TODO(SPARK-48087): Reenable this test case + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) + def test_udtf_init_with_additional_args(self): + super(UDTFParityTests, self).test_udtf_init_with_additional_args() + + # TODO(SPARK-48087): Reenable this test case + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) + def test_udtf_with_wrong_num_input(self): + super(UDTFParityTests, self).test_udtf_with_wrong_num_input() + class ArrowUDTFParityTests(UDTFArrowTestsMixin, UDTFParityTests): @classmethod diff --git a/python/pyspark/sql/tests/connect/test_utils.py b/python/pyspark/sql/tests/connect/test_utils.py index 917cb58057f7..19fa9cd93f32 100644 --- a/python/pyspark/sql/tests/connect/test_utils.py +++ b/python/pyspark/sql/tests/connect/test_utils.py @@ -14,13 +14,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import os +import unittest from pyspark.testing.connectutils import ReusedConnectTestCase from pyspark.sql.tests.test_utils import UtilsTestsMixin class ConnectUtilsTests(ReusedConnectTestCase, UtilsTestsMixin): - pass + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) + def test_assert_approx_equal_decimaltype_custom_rtol_pass(self): + super(ConnectUtilsTests, self).test_assert_approx_equal_decimaltype_custom_rtol_pass() if __name__ == "__main__": diff --git a/python/pyspark/sql/tests/pandas/test_pandas_map.py b/python/pyspark/sql/tests/pandas/test_pandas_map.py index fb2f9214c5d8..c3ba7b3e93a0 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_map.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_map.py @@ -110,7 +110,7 @@ class MapInPandasTestsMixin: df = ( self.spark.range(10, numPartitions=3) .select(col("id").cast("string").alias("str")) - .withColumn("bin", encode(col("str"), "utf8")) + .withColumn("bin", encode(col("str"), "utf-8")) ) actual = df.mapInPandas(func, "str string, bin binary").collect() expected = df.collect() diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf.py b/python/pyspark/sql/tests/pandas/test_pandas_udf.py index 34cd9c235819..4673375ccf69 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_udf.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_udf.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +import os import unittest import datetime from typing import cast @@ -262,6 +262,10 @@ class PandasUDFTestsMixin: .collect, ) + # TODO(SPARK-48086): Reenable this test case + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) def test_pandas_udf_detect_unsafe_type_conversion(self): import pandas as pd import numpy as np @@ -285,6 +289,10 @@ class PandasUDFTestsMixin: with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": False}): df.select(["A"]).withColumn("udf", udf("A")).collect() + # TODO(SPARK-48086): Reenable this test case + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) def test_pandas_udf_arrow_overflow(self): import pandas as pd diff --git a/python/pyspark/sql/tests/test_datasources.py b/python/pyspark/sql/tests/test_datasources.py index 6418983b06a4..c920fa75f4b2 100644 --- a/python/pyspark/sql/tests/test_datasources.py +++ b/python/pyspark/sql/tests/test_datasources.py @@ -14,7 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# - +import os +import unittest import shutil import tempfile import uuid @@ -146,6 +147,9 @@ class DataSourcesTestsMixin: schema = self.spark.read.option("inferSchema", True).csv(rdd, samplingRatio=0.5).schema self.assertEqual(schema, StructType([StructField("_c0", IntegerType(), True)])) + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) def test_checking_csv_header(self): path = tempfile.mkdtemp() shutil.rmtree(path) diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index 90ecfd657765..00bd1d9a6f83 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -812,6 +812,9 @@ class TypesTestsMixin: self.assertRaises(IndexError, lambda: struct1[9]) self.assertRaises(TypeError, lambda: struct1[9.9]) + @unittest.skipIf( + "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server" + ) def test_parse_datatype_string(self): from pyspark.sql.types import _all_atomic_types, _parse_datatype_string diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py index ba81c7836728..a063f27c9ea2 100644 --- a/python/pyspark/testing/connectutils.py +++ b/python/pyspark/testing/connectutils.py @@ -178,7 +178,7 @@ class ReusedConnectTestCase(unittest.TestCase, SQLTestUtils, PySparkErrorTestUti @classmethod def master(cls): - return "local[4]" + return os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]") @classmethod def setUpClass(cls): diff --git a/python/run-tests.py b/python/run-tests.py index b9031765d943..ca8ddb5ff863 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -60,15 +60,17 @@ LOG_FILE = os.path.join(SPARK_HOME, "python/unit-tests.log") FAILURE_REPORTING_LOCK = Lock() LOGGER = logging.getLogger() -# Find out where the assembly jars are located. -# TODO: revisit for Scala 2.13 -for scala in ["2.12", "2.13"]: - build_dir = os.path.join(SPARK_HOME, "assembly", "target", "scala-" + scala) - if os.path.isdir(build_dir): - SPARK_DIST_CLASSPATH = os.path.join(build_dir, "jars", "*") - break -else: - raise RuntimeError("Cannot find assembly build directory, please build Spark first.") +SPARK_DIST_CLASSPATH = "" +if "SPARK_SKIP_CONNECT_COMPAT_TESTS" not in os.environ: + # Find out where the assembly jars are located. + # TODO: revisit for Scala 2.13 + for scala in ["2.12", "2.13"]: + build_dir = os.path.join(SPARK_HOME, "assembly", "target", "scala-" + scala) + if os.path.isdir(build_dir): + SPARK_DIST_CLASSPATH = os.path.join(build_dir, "jars", "*") + break + else: + raise RuntimeError("Cannot find assembly build directory, please build Spark first.") def run_individual_python_test(target_dir, test_name, pyspark_python, keep_test_output): @@ -98,6 +100,11 @@ def run_individual_python_test(target_dir, test_name, pyspark_python, keep_test_ 'PYARROW_IGNORE_TIMEZONE': '1', }) + if "SPARK_CONNECT_TESTING_REMOTE" in os.environ: + env.update({"SPARK_CONNECT_TESTING_REMOTE": os.environ["SPARK_CONNECT_TESTING_REMOTE"]}) + if "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ: + env.update({"SPARK_SKIP_JVM_REQUIRED_TESTS": os.environ["SPARK_SKIP_CONNECT_COMPAT_TESTS"]}) + # Create a unique temp directory under 'target/' for each run. The TMPDIR variable is # recognized by the tempfile module to override the default system temp directory. 
tmp_dir = os.path.join(target_dir, str(uuid.uuid4()))
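As a usage sketch (not part of the commit), the snippet below illustrates how the two environment variables introduced by this patch are consumed by a test: `SPARK_CONNECT_TESTING_REMOTE` overrides the default local master with the address of an external Spark Connect server, and `SPARK_SKIP_CONNECT_COMPAT_TESTS` skips tests known to break when a 3.5 client talks to a 4.0 server. The class name, test body, and the `sc://localhost:15002` address are assumptions for illustration only.

```python
import os
import unittest

from pyspark.sql import SparkSession


# Skip pattern used throughout this patch: opt out of tests that are known to
# fail across mismatched client/server versions.
@unittest.skipIf(
    "SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Failed with different Client <> Server"
)
class ExampleCompatTest(unittest.TestCase):  # hypothetical test, for illustration only
    def test_roundtrip(self):
        # Fall back to a local Spark Connect server unless the developer points the
        # suite at an external one, e.g. SPARK_CONNECT_TESTING_REMOTE=sc://localhost:15002
        remote = os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")
        spark = SparkSession.builder.appName("compat-check").remote(remote).getOrCreate()
        try:
            self.assertEqual(spark.range(3).count(), 3)
        finally:
            spark.stop()


if __name__ == "__main__":
    unittest.main()
```

Note that the `python/run-tests.py` change in this patch forwards `SPARK_CONNECT_TESTING_REMOTE` to each test process and maps `SPARK_SKIP_CONNECT_COMPAT_TESTS` onto `SPARK_SKIP_JVM_REQUIRED_TESTS`, so in the cross-version CI both variables would typically be exported before invoking the test runner.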