This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 7c7b9585a2a [SPARK-43546][PYTHON][CONNECT][TESTS] Complete parity tests of Pandas UDF 7c7b9585a2a is described below commit 7c7b9585a2aba7bbd52c197b07ed0181ae049c75 Author: Xinrong Meng <xinr...@apache.org> AuthorDate: Wed May 24 15:54:18 2023 +0800 [SPARK-43546][PYTHON][CONNECT][TESTS] Complete parity tests of Pandas UDF ### What changes were proposed in this pull request? Complete parity tests of Pandas UDF. Specifically, parity tests are added referencing ``` test_pandas_udf_grouped_agg.py test_pandas_udf_scalar.py test_pandas_udf_window.py ``` ### Why are the changes needed? Parity with vanilla PySpark. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Unit tests. Closes #41268 from xinrong-meng/more_parity. Authored-by: Xinrong Meng <xinr...@apache.org> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- dev/sparktestsupport/modules.py | 3 + .../connect/test_parity_pandas_udf_grouped_agg.py | 53 ++++++++ .../tests/connect/test_parity_pandas_udf_scalar.py | 69 ++++++++++ .../tests/connect/test_parity_pandas_udf_window.py | 40 ++++++ .../tests/pandas/test_pandas_udf_grouped_agg.py | 126 +++++++++--------- .../sql/tests/pandas/test_pandas_udf_scalar.py | 146 ++++++++++++--------- .../sql/tests/pandas/test_pandas_udf_window.py | 6 +- 7 files changed, 316 insertions(+), 127 deletions(-) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index e68a83643ff..a95d2425136 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -782,6 +782,9 @@ pyspark_connect = Module( "pyspark.sql.tests.connect.streaming.test_parity_streaming", "pyspark.sql.tests.connect.streaming.test_parity_foreach", "pyspark.sql.tests.connect.test_parity_pandas_grouped_map_with_state", + "pyspark.sql.tests.connect.test_parity_pandas_udf_scalar", + "pyspark.sql.tests.connect.test_parity_pandas_udf_grouped_agg", + 
"pyspark.sql.tests.connect.test_parity_pandas_udf_window", # ml doctests "pyspark.ml.connect.functions", # ml unittests diff --git a/python/pyspark/sql/tests/connect/test_parity_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/connect/test_parity_pandas_udf_grouped_agg.py new file mode 100644 index 00000000000..25914a4b5b5 --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_parity_pandas_udf_grouped_agg.py @@ -0,0 +1,53 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import unittest + +from pyspark.sql.tests.pandas.test_pandas_udf_grouped_agg import GroupedAggPandasUDFTestsMixin +from pyspark.testing.connectutils import ReusedConnectTestCase + + +class PandasUDFGroupedAggParityTests(GroupedAggPandasUDFTestsMixin, ReusedConnectTestCase): + def test_unsupported_types(self): + self.check_unsupported_types() + + def test_invalid_args(self): + self.check_invalid_args() + + @unittest.skip("Spark Connect doesn't support RDD but the test depends on it.") + def test_grouped_with_empty_partition(self): + super().test_grouped_with_empty_partition() + + # TODO(SPARK-43727): Parity returnType check in Spark Connect + @unittest.skip("Fails in Spark Connect, should enable.") + def check_unsupported_types(self): + super().check_unsupported_types() + + @unittest.skip("Spark Connect does not support convert UNPARSED to catalyst types.") + def test_manual(self): + super().test_manual() + + +if __name__ == "__main__": + from pyspark.sql.tests.connect.test_parity_pandas_udf_grouped_agg import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/connect/test_parity_pandas_udf_scalar.py b/python/pyspark/sql/tests/connect/test_parity_pandas_udf_scalar.py new file mode 100644 index 00000000000..7a3e0eaf650 --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_parity_pandas_udf_scalar.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import unittest +from pyspark.sql.tests.pandas.test_pandas_udf_scalar import ScalarPandasUDFTestsMixin +from pyspark.testing.connectutils import ReusedConnectTestCase + + +class PandasUDFScalarParityTests(ScalarPandasUDFTestsMixin, ReusedConnectTestCase): + def test_nondeterministic_vectorized_udf_in_aggregate(self): + self.check_nondeterministic_analysis_exception() + + @unittest.skip("Spark Connect doesn't support RDD but the test depends on it.") + def test_vectorized_udf_empty_partition(self): + super().test_vectorized_udf_empty_partition() + + @unittest.skip("Spark Connect doesn't support RDD but the test depends on it.") + def test_vectorized_udf_struct_with_empty_partition(self): + super().test_vectorized_udf_struct_with_empty_partition() + + def test_vectorized_udf_exception(self): + self.check_vectorized_udf_exception() + + def test_vectorized_udf_nested_struct(self): + self.check_vectorized_udf_nested_struct() + + def test_vectorized_udf_return_scalar(self): + self.check_vectorized_udf_return_scalar() + + def test_scalar_iter_udf_close(self): + self.check_scalar_iter_udf_close() + + # TODO(SPARK-43727): Parity returnType check in Spark Connect + @unittest.skip("Fails in Spark Connect, should enable.") + def test_vectorized_udf_wrong_return_type(self): + self.check_vectorized_udf_wrong_return_type() + + # TODO(SPARK-43727): Parity returnType check in Spark Connect + @unittest.skip("Fails in Spark Connect, should enable.") + def check_vectorized_udf_nested_struct(self): + super().check_vectorized_udf_nested_struct() + + def
test_vectorized_udf_invalid_length(self): + self.check_vectorized_udf_invalid_length() + + +if __name__ == "__main__": + from pyspark.sql.tests.connect.test_parity_pandas_udf_scalar import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/connect/test_parity_pandas_udf_window.py b/python/pyspark/sql/tests/connect/test_parity_pandas_udf_window.py new file mode 100644 index 00000000000..98ed2a23df3 --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_parity_pandas_udf_window.py @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import unittest + +from pyspark.sql.tests.pandas.test_pandas_udf_window import WindowPandasUDFTestsMixin +from pyspark.testing.connectutils import ReusedConnectTestCase + + +class PandasUDFWindowParityTests(WindowPandasUDFTestsMixin, ReusedConnectTestCase): + # TODO(SPARK-43734): Expression "<lambda>(v)" within a window function doesn't raise a + # AnalysisException + @unittest.skip("Fails in Spark Connect, should enable.") + def test_invalid_args(self): + super().test_invalid_args() + + +if __name__ == "__main__": + from pyspark.sql.tests.connect.test_parity_pandas_udf_window import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py index bf8d5a83ec4..96f257b4756 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py @@ -52,7 +52,7 @@ if have_pandas: not have_pandas or not have_pyarrow, cast(str, pandas_requirement_message or pyarrow_requirement_message), ) -class GroupedAggPandasUDFTests(ReusedSQLTestCase): +class GroupedAggPandasUDFTestsMixin: @property def data(self): return ( @@ -177,52 +177,50 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase): def test_unsupported_types(self): with QuietTest(self.sc): - with self.assertRaises(PySparkNotImplementedError) as pe: - pandas_udf( - lambda x: x, ArrayType(ArrayType(TimestampType())), PandasUDFType.GROUPED_AGG - ) - - self.check_error( - exception=pe.exception, - error_class="NOT_IMPLEMENTED", - message_parameters={ - "feature": "Invalid return type with grouped aggregate Pandas UDFs: " - "ArrayType(ArrayType(TimestampType(), True), True)" - }, - ) + self.check_unsupported_types() - with 
QuietTest(self.sc): - with self.assertRaises(PySparkNotImplementedError) as pe: - - @pandas_udf("mean double, std double", PandasUDFType.GROUPED_AGG) - def mean_and_std_udf(v): - return v.mean(), v.std() - - self.check_error( - exception=pe.exception, - error_class="NOT_IMPLEMENTED", - message_parameters={ - "feature": "Invalid return type with grouped aggregate Pandas UDFs: " - "StructType([StructField('mean', DoubleType(), True), " - "StructField('std', DoubleType(), True)])" - }, - ) - - with QuietTest(self.sc): - with self.assertRaises(PySparkNotImplementedError) as pe: - - @pandas_udf(ArrayType(TimestampType()), PandasUDFType.GROUPED_AGG) - def mean_and_std_udf(v): # noqa: F811 - return {v.mean(): v.std()} - - self.check_error( - exception=pe.exception, - error_class="NOT_IMPLEMENTED", - message_parameters={ - "feature": "Invalid return type with grouped aggregate Pandas UDFs: " - "ArrayType(TimestampType(), True)" - }, + def check_unsupported_types(self): + with self.assertRaises(PySparkNotImplementedError) as pe: + pandas_udf( + lambda x: x, ArrayType(ArrayType(TimestampType())), PandasUDFType.GROUPED_AGG ) + self.check_error( + exception=pe.exception, + error_class="NOT_IMPLEMENTED", + message_parameters={ + "feature": "Invalid return type with grouped aggregate Pandas UDFs: " + "ArrayType(ArrayType(TimestampType(), True), True)" + }, + ) + with self.assertRaises(PySparkNotImplementedError) as pe: + + @pandas_udf("mean double, std double", PandasUDFType.GROUPED_AGG) + def mean_and_std_udf(v): + return v.mean(), v.std() + + self.check_error( + exception=pe.exception, + error_class="NOT_IMPLEMENTED", + message_parameters={ + "feature": "Invalid return type with grouped aggregate Pandas UDFs: " + "StructType([StructField('mean', DoubleType(), True), " + "StructField('std', DoubleType(), True)])" + }, + ) + with self.assertRaises(PySparkNotImplementedError) as pe: + + @pandas_udf(ArrayType(TimestampType()), PandasUDFType.GROUPED_AGG) + def mean_and_std_udf(v): 
# noqa: F811 + return {v.mean(): v.std()} + + self.check_error( + exception=pe.exception, + error_class="NOT_IMPLEMENTED", + message_parameters={ + "feature": "Invalid return type with grouped aggregate Pandas UDFs: " + "ArrayType(TimestampType(), True)" + }, + ) def test_alias(self): df = self.data @@ -498,27 +496,25 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase): self.assertEqual(result1.first()["v2"], [1.0, 2.0]) def test_invalid_args(self): + with QuietTest(self.sc): + self.check_invalid_args() + + def check_invalid_args(self): df = self.data plus_one = self.python_plus_one mean_udf = self.pandas_agg_mean_udf - - with QuietTest(self.sc): - with self.assertRaisesRegex(AnalysisException, "[MISSING_AGGREGATION]"): - df.groupby(df.id).agg(plus_one(df.v)).collect() - - with QuietTest(self.sc): - with self.assertRaisesRegex( - AnalysisException, "aggregate function.*argument.*aggregate function" - ): - df.groupby(df.id).agg(mean_udf(mean_udf(df.v))).collect() - - with QuietTest(self.sc): - with self.assertRaisesRegex( - AnalysisException, - "The group aggregate pandas UDF `avg` cannot be invoked together with as other, " - "non-pandas aggregate functions.", - ): - df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect() + with self.assertRaisesRegex(AnalysisException, "[MISSING_AGGREGATION]"): + df.groupby(df.id).agg(plus_one(df.v)).collect() + with self.assertRaisesRegex( + AnalysisException, "aggregate function.*argument.*aggregate function" + ): + df.groupby(df.id).agg(mean_udf(mean_udf(df.v))).collect() + with self.assertRaisesRegex( + AnalysisException, + "The group aggregate pandas UDF `avg` cannot be invoked together with as other, " + "non-pandas aggregate functions.", + ): + df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect() def test_register_vectorized_udf_basic(self): sum_pandas_udf = pandas_udf( @@ -575,6 +571,10 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase): assert filtered.collect()[0]["mean"] == 42.0 +class 
GroupedAggPandasUDFTests(GroupedAggPandasUDFTestsMixin, ReusedSQLTestCase): + pass + + if __name__ == "__main__": from pyspark.sql.tests.pandas.test_pandas_udf_grouped_agg import * # noqa: F401 diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py index 8fa6010a62f..b7f76635761 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py @@ -72,28 +72,7 @@ if have_pyarrow: not have_pandas or not have_pyarrow, cast(str, pandas_requirement_message or pyarrow_requirement_message), ) -class ScalarPandasUDFTests(ReusedSQLTestCase): - @classmethod - def setUpClass(cls): - ReusedSQLTestCase.setUpClass() - - # Synchronize default timezone between Python and Java - cls.tz_prev = os.environ.get("TZ", None) # save current tz if set - tz = "America/Los_Angeles" - os.environ["TZ"] = tz - time.tzset() - - cls.sc.environment["TZ"] = tz - cls.spark.conf.set("spark.sql.session.timeZone", tz) - - @classmethod - def tearDownClass(cls): - del os.environ["TZ"] - if cls.tz_prev is not None: - os.environ["TZ"] = cls.tz_prev - time.tzset() - ReusedSQLTestCase.tearDownClass() - +class ScalarPandasUDFTestsMixin: @property def nondeterministic_vectorized_udf(self): import numpy as np @@ -478,6 +457,10 @@ class ScalarPandasUDFTests(ReusedSQLTestCase): self.assertListEqual([i, i + 1], f[1]) def test_vectorized_udf_nested_struct(self): + with QuietTest(self.sc): + self.check_vectorized_udf_nested_struct() + + def check_vectorized_udf_nested_struct(self): nested_type = StructType( [ StructField("id", IntegerType()), @@ -487,13 +470,9 @@ class ScalarPandasUDFTests(ReusedSQLTestCase): ), ] ) - for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]: - with QuietTest(self.sc): - with self.assertRaisesRegex( - Exception, "Invalid return type with scalar Pandas UDFs" - ): - pandas_udf(lambda x: x, returnType=nested_type, 
functionType=udf_type) + with self.assertRaisesRegex(Exception, "Invalid return type with scalar Pandas UDFs"): + pandas_udf(lambda x: x, returnType=nested_type, functionType=udf_type) def test_vectorized_udf_map_type(self): data = [({},), ({"a": 1},), ({"a": 1, "b": 2},), ({"a": 1, "b": 2, "c": 3},)] @@ -543,6 +522,10 @@ class ScalarPandasUDFTests(ReusedSQLTestCase): self.assertEqual(expected.collect(), res.collect()) def test_vectorized_udf_exception(self): + with QuietTest(self.sc): + self.check_vectorized_udf_exception() + + def check_vectorized_udf_exception(self): df = self.spark.range(10) scalar_raise_exception = pandas_udf(lambda x: x * (1 / 0), LongType()) @@ -552,29 +535,30 @@ class ScalarPandasUDFTests(ReusedSQLTestCase): yield x * (1 / 0) for raise_exception in [scalar_raise_exception, iter_raise_exception]: - with QuietTest(self.sc): - with self.assertRaisesRegex(Exception, "division( or modulo)? by zero"): - df.select(raise_exception(col("id"))).collect() + with self.assertRaisesRegex(Exception, "division( or modulo)? 
by zero"): + df.select(raise_exception(col("id"))).collect() def test_vectorized_udf_invalid_length(self): + with QuietTest(self.sc): + self.check_vectorized_udf_invalid_length() + + def check_vectorized_udf_invalid_length(self): df = self.spark.range(10) raise_exception = pandas_udf(lambda _: pd.Series(1), LongType()) - with QuietTest(self.sc): - with self.assertRaisesRegex( - Exception, "Result vector from pandas_udf was not the required length" - ): - df.select(raise_exception(col("id"))).collect() + with self.assertRaisesRegex( + Exception, "Result vector from pandas_udf was not the required length" + ): + df.select(raise_exception(col("id"))).collect() @pandas_udf(LongType(), PandasUDFType.SCALAR_ITER) def iter_udf_wong_output_size(it): for _ in it: yield pd.Series(1) - with QuietTest(self.sc): - with self.assertRaisesRegex( - Exception, "The length of output in Scalar iterator.*" "the length of output was 1" - ): - df.select(iter_udf_wong_output_size(col("id"))).collect() + with self.assertRaisesRegex( + Exception, "The length of output in Scalar iterator.*" "the length of output was 1" + ): + df.select(iter_udf_wong_output_size(col("id"))).collect() @pandas_udf(LongType(), PandasUDFType.SCALAR_ITER) def iter_udf_not_reading_all_input(it): @@ -585,9 +569,8 @@ class ScalarPandasUDFTests(ReusedSQLTestCase): with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 3}): df1 = self.spark.range(10).repartition(1) - with QuietTest(self.sc): - with self.assertRaisesRegex(Exception, "pandas iterator UDF should exhaust"): - df1.select(iter_udf_not_reading_all_input(col("id"))).collect() + with self.assertRaisesRegex(Exception, "pandas iterator UDF should exhaust"): + df1.select(iter_udf_not_reading_all_input(col("id"))).collect() def test_vectorized_udf_chained(self): df = self.spark.range(10) @@ -632,23 +615,29 @@ class ScalarPandasUDFTests(ReusedSQLTestCase): def test_vectorized_udf_wrong_return_type(self): with QuietTest(self.sc): - for udf_type in 
[PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]: - with self.assertRaisesRegex( - NotImplementedError, - "Invalid return type.*scalar Pandas UDF.*ArrayType.*TimestampType", - ): - pandas_udf(lambda x: x, ArrayType(TimestampType()), udf_type) + self.check_vectorized_udf_wrong_return_type() + + def check_vectorized_udf_wrong_return_type(self): + for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]: + with self.assertRaisesRegex( + NotImplementedError, + "Invalid return type.*scalar Pandas UDF.*ArrayType.*TimestampType", + ): + pandas_udf(lambda x: x, ArrayType(TimestampType()), udf_type) def test_vectorized_udf_return_scalar(self): + with QuietTest(self.sc): + self.check_vectorized_udf_return_scalar() + + def check_vectorized_udf_return_scalar(self): df = self.spark.range(10) scalar_f = pandas_udf(lambda x: 1.0, DoubleType()) iter_f = pandas_udf( lambda it: map(lambda x: 1.0, it), DoubleType(), PandasUDFType.SCALAR_ITER ) for f in [scalar_f, iter_f]: - with QuietTest(self.sc): - with self.assertRaisesRegex(Exception, "Return.*type.*Series"): - df.select(f(col("id"))).collect() + with self.assertRaisesRegex(Exception, "Return.*type.*Series"): + df.select(f(col("id"))).collect() def test_vectorized_udf_decorator(self): df = self.spark.range(10) @@ -694,7 +683,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase): self.assertEqual("Doe", row[0]["last"]) def test_vectorized_udf_varargs(self): - df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2)) + df = self.spark.range(start=1, end=2) scalar_f = pandas_udf(lambda *v: v[0], LongType()) @pandas_udf(LongType(), PandasUDFType.SCALAR_ITER) @@ -941,16 +930,19 @@ class ScalarPandasUDFTests(ReusedSQLTestCase): self.assertTrue(result1["plus_ten(rand)"].equals(result1["rand"] + 10)) def test_nondeterministic_vectorized_udf_in_aggregate(self): + with QuietTest(self.sc): + self.check_nondeterministic_analysis_exception() + + def check_nondeterministic_analysis_exception(self): df = self.spark.range(10) 
for random_udf in [ self.nondeterministic_vectorized_udf, self.nondeterministic_vectorized_iter_udf, ]: - with QuietTest(self.sc): - with self.assertRaisesRegex(AnalysisException, "nondeterministic"): - df.groupby(df.id).agg(sum(random_udf(df.id))).collect() - with self.assertRaisesRegex(AnalysisException, "nondeterministic"): - df.agg(sum(random_udf(df.id))).collect() + with self.assertRaisesRegex(AnalysisException, "nondeterministic"): + df.groupby(df.id).agg(sum(random_udf(df.id))).collect() + with self.assertRaisesRegex(AnalysisException, "nondeterministic"): + df.agg(sum(random_udf(df.id))).collect() def test_register_vectorized_udf_basic(self): df = self.spark.range(10).select( @@ -999,6 +991,10 @@ class ScalarPandasUDFTests(ReusedSQLTestCase): ) def test_scalar_iter_udf_close(self): + with QuietTest(self.sc): + self.check_scalar_iter_udf_close() + + def check_scalar_iter_udf_close(self): @pandas_udf("int", PandasUDFType.SCALAR_ITER) def test_close(batch_iter): try: @@ -1007,9 +1003,8 @@ class ScalarPandasUDFTests(ReusedSQLTestCase): finally: raise RuntimeError("reached finally block") - with QuietTest(self.sc): - with self.assertRaisesRegex(Exception, "reached finally block"): - self.spark.range(1).select(test_close(col("id"))).collect() + with self.assertRaisesRegex(Exception, "reached finally block"): + self.spark.range(1).select(test_close(col("id"))).collect() @unittest.skip("LimitPushDown should push limits through Python UDFs so this won't occur") def test_scalar_iter_udf_close_early(self): @@ -1217,6 +1212,8 @@ class ScalarPandasUDFTests(ReusedSQLTestCase): self.assertEqual(expected_multi, df_multi_2.collect()) def test_mixed_udf_and_sql(self): + from pyspark.sql.connect.column import Column as ConnectColumn + df = self.spark.range(0, 1).toDF("v") # Test mixture of UDFs, Pandas UDFs and SQL expression. 
@@ -1227,7 +1224,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase): return x + 1 def f2(x): - assert type(x) == Column + assert type(x) in (Column, ConnectColumn) return x + 10 @pandas_udf("int") @@ -1351,6 +1348,29 @@ class ScalarPandasUDFTests(ReusedSQLTestCase): shutil.rmtree(path) +class ScalarPandasUDFTests(ScalarPandasUDFTestsMixin, ReusedSQLTestCase): + @classmethod + def setUpClass(cls): + ReusedSQLTestCase.setUpClass() + + # Synchronize default timezone between Python and Java + cls.tz_prev = os.environ.get("TZ", None) # save current tz if set + tz = "America/Los_Angeles" + os.environ["TZ"] = tz + time.tzset() + + cls.sc.environment["TZ"] = tz + cls.spark.conf.set("spark.sql.session.timeZone", tz) + + @classmethod + def tearDownClass(cls): + del os.environ["TZ"] + if cls.tz_prev is not None: + os.environ["TZ"] = cls.tz_prev + time.tzset() + ReusedSQLTestCase.tearDownClass() + + if __name__ == "__main__": from pyspark.sql.tests.pandas.test_pandas_udf_scalar import * # noqa: F401 diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py index 8da97dee83b..eb80dccd9b2 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py @@ -50,7 +50,7 @@ if have_pandas: not have_pandas or not have_pyarrow, cast(str, pandas_requirement_message or pyarrow_requirement_message), ) -class WindowPandasUDFTests(ReusedSQLTestCase): +class WindowPandasUDFTestsMixin: @property def data(self): return ( @@ -394,6 +394,10 @@ class WindowPandasUDFTests(ReusedSQLTestCase): assert_frame_equal(expected1.toPandas(), result1.toPandas()) +class WindowPandasUDFTests(WindowPandasUDFTestsMixin, ReusedSQLTestCase): + pass + + if __name__ == "__main__": from pyspark.sql.tests.pandas.test_pandas_udf_window import * # noqa: F401 --------------------------------------------------------------------- To unsubscribe, e-mail: 
commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org