This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 7c7b9585a2a [SPARK-43546][PYTHON][CONNECT][TESTS] Complete parity tests of Pandas UDF
7c7b9585a2a is described below

commit 7c7b9585a2aba7bbd52c197b07ed0181ae049c75
Author: Xinrong Meng <xinr...@apache.org>
AuthorDate: Wed May 24 15:54:18 2023 +0800

    [SPARK-43546][PYTHON][CONNECT][TESTS] Complete parity tests of Pandas UDF
    
    ### What changes were proposed in this pull request?
    Complete parity tests of Pandas UDF.
    
    Specifically, parity tests are added referencing:
    
    ```
    test_pandas_udf_grouped_agg.py
    test_pandas_udf_scalar.py
    test_pandas_udf_window.py
    ```
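
    All three follow the same pattern: the shared test bodies now live in a
    `*TestsMixin` class, and the parity module runs that mixin against a Spark
    Connect session, skipping or overriding only what cannot pass over Connect
    yet. A condensed sketch of the pattern, mirroring the smallest of the new
    files (the full files appear in the diff below):

    ```
    import unittest

    from pyspark.sql.tests.pandas.test_pandas_udf_window import WindowPandasUDFTestsMixin
    from pyspark.testing.connectutils import ReusedConnectTestCase


    class PandasUDFWindowParityTests(WindowPandasUDFTestsMixin, ReusedConnectTestCase):
        # Known gap tracked by SPARK-43734; skipped until it is resolved.
        @unittest.skip("Fails in Spark Connect; should be enabled.")
        def test_invalid_args(self):
            super().test_invalid_args()
    ```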
    
    ### Why are the changes needed?
    Parity with vanilla PySpark.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Unit tests.
    
    Closes #41268 from xinrong-meng/more_parity.
    
    Authored-by: Xinrong Meng <xinr...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 dev/sparktestsupport/modules.py                    |   3 +
 .../connect/test_parity_pandas_udf_grouped_agg.py  |  53 ++++++++
 .../tests/connect/test_parity_pandas_udf_scalar.py |  69 ++++++++++
 .../tests/connect/test_parity_pandas_udf_window.py |  40 ++++++
 .../tests/pandas/test_pandas_udf_grouped_agg.py    | 126 +++++++++---------
 .../sql/tests/pandas/test_pandas_udf_scalar.py     | 146 ++++++++++++---------
 .../sql/tests/pandas/test_pandas_udf_window.py     |   6 +-
 7 files changed, 316 insertions(+), 127 deletions(-)
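
The per-file diffs below repeat one refactor: each suite's methods move into a
`*TestsMixin`, and tests whose assertions must also run over Spark Connect are
split into a thin `test_*` wrapper, which keeps the classic-only
`QuietTest(self.sc)` log silencing, and a `check_*` method holding the actual
assertions that the Connect parity classes call or override. A minimal sketch of
that split, with hypothetical names (`ExamplePandasUDFTestsMixin`,
`test_invalid_input`) rather than code from this commit:

```
from pyspark.testing.utils import QuietTest


class ExamplePandasUDFTestsMixin:
    # test_* keeps the QuietTest wrapper: it needs a SparkContext
    # (self.sc), which exists only in classic PySpark, so the Connect
    # parity subclass overrides test_* to call check_* directly.
    def test_invalid_input(self):
        with QuietTest(self.sc):
            self.check_invalid_input()

    # check_* holds the shared assertions and uses only session APIs,
    # so it runs unchanged under both classic Spark and Spark Connect.
    # (assertRaises comes from unittest.TestCase via the concrete class.)
    def check_invalid_input(self):
        with self.assertRaises(Exception):
            self.spark.sql("SELECT * FROM nonexistent_table").collect()
```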

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index e68a83643ff..a95d2425136 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -782,6 +782,9 @@ pyspark_connect = Module(
         "pyspark.sql.tests.connect.streaming.test_parity_streaming",
         "pyspark.sql.tests.connect.streaming.test_parity_foreach",
         "pyspark.sql.tests.connect.test_parity_pandas_grouped_map_with_state",
+        "pyspark.sql.tests.connect.test_parity_pandas_udf_scalar",
+        "pyspark.sql.tests.connect.test_parity_pandas_udf_grouped_agg",
+        "pyspark.sql.tests.connect.test_parity_pandas_udf_window",
         # ml doctests
         "pyspark.ml.connect.functions",
         # ml unittests
diff --git a/python/pyspark/sql/tests/connect/test_parity_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/connect/test_parity_pandas_udf_grouped_agg.py
new file mode 100644
index 00000000000..25914a4b5b5
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_pandas_udf_grouped_agg.py
@@ -0,0 +1,53 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from pyspark.sql.tests.pandas.test_pandas_udf_grouped_agg import GroupedAggPandasUDFTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class PandasUDFGroupedAggParityTests(GroupedAggPandasUDFTestsMixin, ReusedConnectTestCase):
+    def test_unsupported_types(self):
+        self.check_unsupported_types()
+
+    def test_invalid_args(self):
+        self.check_invalid_args()
+
+    @unittest.skip("Spark Connect doesn't support RDD but the test depends on 
it.")
+    def test_grouped_with_empty_partition(self):
+        super().test_grouped_with_empty_partition()
+
+    # TODO(SPARK-43727): Parity returnType check in Spark Connect
+    @unittest.skip("Fails in Spark Connect, should enable.")
+    def check_unsupported_types(self):
+        super().check_unsupported_types()
+
+    @unittest.skip("Spark Connect does not support convert UNPARSED to 
catalyst types.")
+    def test_manual(self):
+        super().test_manual()
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.connect.test_parity_pandas_udf_grouped_agg import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/connect/test_parity_pandas_udf_scalar.py b/python/pyspark/sql/tests/connect/test_parity_pandas_udf_scalar.py
new file mode 100644
index 00000000000..7a3e0eaf650
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_pandas_udf_scalar.py
@@ -0,0 +1,69 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+from pyspark.sql.tests.pandas.test_pandas_udf_scalar import ScalarPandasUDFTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class PandasUDFScalarParityTests(ScalarPandasUDFTestsMixin, ReusedConnectTestCase):
+    def test_nondeterministic_vectorized_udf_in_aggregate(self):
+        self.check_nondeterministic_analysis_exception()
+
+    @unittest.skip("Spark Connect doesn't support RDD but the test depends on 
it.")
+    def test_vectorized_udf_empty_partition(self):
+        super().test_vectorized_udf_empty_partition()
+
+    @unittest.skip("Spark Connect doesn't support RDD but the test depends on 
it.")
+    def test_vectorized_udf_struct_with_empty_partition(self):
+        super().test_vectorized_udf_struct_with_empty_partition()
+
+    def test_vectorized_udf_exception(self):
+        self.check_vectorized_udf_exception()
+
+    def test_vectorized_udf_nested_struct(self):
+        self.check_vectorized_udf_nested_struct()
+
+    def test_vectorized_udf_return_scalar(self):
+        self.check_vectorized_udf_return_scalar()
+
+    def test_scalar_iter_udf_close(self):
+        self.check_scalar_iter_udf_close()
+
+    # TODO(SPARK-43727): Parity returnType check in Spark Connect
+    @unittest.skip("Fails in Spark Connect, should enable.")
+    def test_vectorized_udf_wrong_return_type(self):
+        self.check_vectorized_udf_wrong_return_type()
+
+    # TODO(SPARK-43727): Parity returnType check in Spark Connect
+    @unittest.skip("SFails in Spark Connect, should enable.")
+    def check_vectorized_udf_nested_struct(self):
+        super.check_vectorized_udf_nested_struct()
+
+    def test_vectorized_udf_invalid_length(self):
+        self.check_vectorized_udf_invalid_length()
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.connect.test_parity_pandas_udf_scalar import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/connect/test_parity_pandas_udf_window.py b/python/pyspark/sql/tests/connect/test_parity_pandas_udf_window.py
new file mode 100644
index 00000000000..98ed2a23df3
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_pandas_udf_window.py
@@ -0,0 +1,40 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from pyspark.sql.tests.pandas.test_pandas_udf_window import WindowPandasUDFTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class PandasUDFWindowParityTests(WindowPandasUDFTestsMixin, ReusedConnectTestCase):
+    # TODO(SPARK-43734): Expression "<lambda>(v)" within a window function doesn't raise
+    #  an AnalysisException
+    @unittest.skip("Fails in Spark Connect, should enable.")
+    def test_invalid_args(self):
+        super().test_invalid_args()
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.connect.test_parity_pandas_udf_window import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
index bf8d5a83ec4..96f257b4756 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
@@ -52,7 +52,7 @@ if have_pandas:
     not have_pandas or not have_pyarrow,
     cast(str, pandas_requirement_message or pyarrow_requirement_message),
 )
-class GroupedAggPandasUDFTests(ReusedSQLTestCase):
+class GroupedAggPandasUDFTestsMixin:
     @property
     def data(self):
         return (
@@ -177,52 +177,50 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
 
     def test_unsupported_types(self):
         with QuietTest(self.sc):
-            with self.assertRaises(PySparkNotImplementedError) as pe:
-                pandas_udf(
-                    lambda x: x, ArrayType(ArrayType(TimestampType())), PandasUDFType.GROUPED_AGG
-                )
-
-            self.check_error(
-                exception=pe.exception,
-                error_class="NOT_IMPLEMENTED",
-                message_parameters={
-                    "feature": "Invalid return type with grouped aggregate 
Pandas UDFs: "
-                    "ArrayType(ArrayType(TimestampType(), True), True)"
-                },
-            )
+            self.check_unsupported_types()
 
-        with QuietTest(self.sc):
-            with self.assertRaises(PySparkNotImplementedError) as pe:
-
-                @pandas_udf("mean double, std double", 
PandasUDFType.GROUPED_AGG)
-                def mean_and_std_udf(v):
-                    return v.mean(), v.std()
-
-            self.check_error(
-                exception=pe.exception,
-                error_class="NOT_IMPLEMENTED",
-                message_parameters={
-                    "feature": "Invalid return type with grouped aggregate 
Pandas UDFs: "
-                    "StructType([StructField('mean', DoubleType(), True), "
-                    "StructField('std', DoubleType(), True)])"
-                },
-            )
-
-        with QuietTest(self.sc):
-            with self.assertRaises(PySparkNotImplementedError) as pe:
-
-                @pandas_udf(ArrayType(TimestampType()), PandasUDFType.GROUPED_AGG)
-                def mean_and_std_udf(v):  # noqa: F811
-                    return {v.mean(): v.std()}
-
-            self.check_error(
-                exception=pe.exception,
-                error_class="NOT_IMPLEMENTED",
-                message_parameters={
-                    "feature": "Invalid return type with grouped aggregate 
Pandas UDFs: "
-                    "ArrayType(TimestampType(), True)"
-                },
+    def check_unsupported_types(self):
+        with self.assertRaises(PySparkNotImplementedError) as pe:
+            pandas_udf(
+                lambda x: x, ArrayType(ArrayType(TimestampType())), PandasUDFType.GROUPED_AGG
             )
+        self.check_error(
+            exception=pe.exception,
+            error_class="NOT_IMPLEMENTED",
+            message_parameters={
+                "feature": "Invalid return type with grouped aggregate Pandas 
UDFs: "
+                "ArrayType(ArrayType(TimestampType(), True), True)"
+            },
+        )
+        with self.assertRaises(PySparkNotImplementedError) as pe:
+
+            @pandas_udf("mean double, std double", PandasUDFType.GROUPED_AGG)
+            def mean_and_std_udf(v):
+                return v.mean(), v.std()
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="NOT_IMPLEMENTED",
+            message_parameters={
+                "feature": "Invalid return type with grouped aggregate Pandas 
UDFs: "
+                "StructType([StructField('mean', DoubleType(), True), "
+                "StructField('std', DoubleType(), True)])"
+            },
+        )
+        with self.assertRaises(PySparkNotImplementedError) as pe:
+
+            @pandas_udf(ArrayType(TimestampType()), PandasUDFType.GROUPED_AGG)
+            def mean_and_std_udf(v):  # noqa: F811
+                return {v.mean(): v.std()}
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="NOT_IMPLEMENTED",
+            message_parameters={
+                "feature": "Invalid return type with grouped aggregate Pandas 
UDFs: "
+                "ArrayType(TimestampType(), True)"
+            },
+        )
 
     def test_alias(self):
         df = self.data
@@ -498,27 +496,25 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
         self.assertEqual(result1.first()["v2"], [1.0, 2.0])
 
     def test_invalid_args(self):
+        with QuietTest(self.sc):
+            self.check_invalid_args()
+
+    def check_invalid_args(self):
         df = self.data
         plus_one = self.python_plus_one
         mean_udf = self.pandas_agg_mean_udf
-
-        with QuietTest(self.sc):
-            with self.assertRaisesRegex(AnalysisException, "[MISSING_AGGREGATION]"):
-                df.groupby(df.id).agg(plus_one(df.v)).collect()
-
-        with QuietTest(self.sc):
-            with self.assertRaisesRegex(
-                AnalysisException, "aggregate function.*argument.*aggregate 
function"
-            ):
-                df.groupby(df.id).agg(mean_udf(mean_udf(df.v))).collect()
-
-        with QuietTest(self.sc):
-            with self.assertRaisesRegex(
-                AnalysisException,
-                "The group aggregate pandas UDF `avg` cannot be invoked 
together with as other, "
-                "non-pandas aggregate functions.",
-            ):
-                df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()
+        with self.assertRaisesRegex(AnalysisException, "[MISSING_AGGREGATION]"):
+            df.groupby(df.id).agg(plus_one(df.v)).collect()
+        with self.assertRaisesRegex(
+            AnalysisException, "aggregate function.*argument.*aggregate 
function"
+        ):
+            df.groupby(df.id).agg(mean_udf(mean_udf(df.v))).collect()
+        with self.assertRaisesRegex(
+            AnalysisException,
+            "The group aggregate pandas UDF `avg` cannot be invoked together 
with as other, "
+            "non-pandas aggregate functions.",
+        ):
+            df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()
 
     def test_register_vectorized_udf_basic(self):
         sum_pandas_udf = pandas_udf(
@@ -575,6 +571,10 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
         assert filtered.collect()[0]["mean"] == 42.0
 
 
+class GroupedAggPandasUDFTests(GroupedAggPandasUDFTestsMixin, ReusedSQLTestCase):
+    pass
+
+
 if __name__ == "__main__":
     from pyspark.sql.tests.pandas.test_pandas_udf_grouped_agg import *  # noqa: F401
 
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
index 8fa6010a62f..b7f76635761 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
@@ -72,28 +72,7 @@ if have_pyarrow:
     not have_pandas or not have_pyarrow,
     cast(str, pandas_requirement_message or pyarrow_requirement_message),
 )
-class ScalarPandasUDFTests(ReusedSQLTestCase):
-    @classmethod
-    def setUpClass(cls):
-        ReusedSQLTestCase.setUpClass()
-
-        # Synchronize default timezone between Python and Java
-        cls.tz_prev = os.environ.get("TZ", None)  # save current tz if set
-        tz = "America/Los_Angeles"
-        os.environ["TZ"] = tz
-        time.tzset()
-
-        cls.sc.environment["TZ"] = tz
-        cls.spark.conf.set("spark.sql.session.timeZone", tz)
-
-    @classmethod
-    def tearDownClass(cls):
-        del os.environ["TZ"]
-        if cls.tz_prev is not None:
-            os.environ["TZ"] = cls.tz_prev
-        time.tzset()
-        ReusedSQLTestCase.tearDownClass()
-
+class ScalarPandasUDFTestsMixin:
     @property
     def nondeterministic_vectorized_udf(self):
         import numpy as np
@@ -478,6 +457,10 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
                 self.assertListEqual([i, i + 1], f[1])
 
     def test_vectorized_udf_nested_struct(self):
+        with QuietTest(self.sc):
+            self.check_vectorized_udf_nested_struct()
+
+    def check_vectorized_udf_nested_struct(self):
         nested_type = StructType(
             [
                 StructField("id", IntegerType()),
@@ -487,13 +470,9 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
                 ),
             ]
         )
-
         for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
-            with QuietTest(self.sc):
-                with self.assertRaisesRegex(
-                    Exception, "Invalid return type with scalar Pandas UDFs"
-                ):
-                    pandas_udf(lambda x: x, returnType=nested_type, functionType=udf_type)
+            with self.assertRaisesRegex(Exception, "Invalid return type with 
scalar Pandas UDFs"):
+                pandas_udf(lambda x: x, returnType=nested_type, 
functionType=udf_type)
 
     def test_vectorized_udf_map_type(self):
         data = [({},), ({"a": 1},), ({"a": 1, "b": 2},), ({"a": 1, "b": 2, "c": 3},)]
@@ -543,6 +522,10 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
             self.assertEqual(expected.collect(), res.collect())
 
     def test_vectorized_udf_exception(self):
+        with QuietTest(self.sc):
+            self.check_vectorized_udf_exception()
+
+    def check_vectorized_udf_exception(self):
         df = self.spark.range(10)
         scalar_raise_exception = pandas_udf(lambda x: x * (1 / 0), LongType())
 
@@ -552,29 +535,30 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
                 yield x * (1 / 0)
 
         for raise_exception in [scalar_raise_exception, iter_raise_exception]:
-            with QuietTest(self.sc):
-                with self.assertRaisesRegex(Exception, "division( or modulo)? 
by zero"):
-                    df.select(raise_exception(col("id"))).collect()
+            with self.assertRaisesRegex(Exception, "division( or modulo)? by 
zero"):
+                df.select(raise_exception(col("id"))).collect()
 
     def test_vectorized_udf_invalid_length(self):
+        with QuietTest(self.sc):
+            self.check_vectorized_udf_invalid_length()
+
+    def check_vectorized_udf_invalid_length(self):
         df = self.spark.range(10)
         raise_exception = pandas_udf(lambda _: pd.Series(1), LongType())
-        with QuietTest(self.sc):
-            with self.assertRaisesRegex(
-                Exception, "Result vector from pandas_udf was not the required 
length"
-            ):
-                df.select(raise_exception(col("id"))).collect()
+        with self.assertRaisesRegex(
+            Exception, "Result vector from pandas_udf was not the required 
length"
+        ):
+            df.select(raise_exception(col("id"))).collect()
 
         @pandas_udf(LongType(), PandasUDFType.SCALAR_ITER)
         def iter_udf_wong_output_size(it):
             for _ in it:
                 yield pd.Series(1)
 
-        with QuietTest(self.sc):
-            with self.assertRaisesRegex(
-                Exception, "The length of output in Scalar iterator.*" "the 
length of output was 1"
-            ):
-                df.select(iter_udf_wong_output_size(col("id"))).collect()
+        with self.assertRaisesRegex(
+            Exception, "The length of output in Scalar iterator.*" "the length 
of output was 1"
+        ):
+            df.select(iter_udf_wong_output_size(col("id"))).collect()
 
         @pandas_udf(LongType(), PandasUDFType.SCALAR_ITER)
         def iter_udf_not_reading_all_input(it):
@@ -585,9 +569,8 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
 
         with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 
3}):
             df1 = self.spark.range(10).repartition(1)
-            with QuietTest(self.sc):
-                with self.assertRaisesRegex(Exception, "pandas iterator UDF 
should exhaust"):
-                    
df1.select(iter_udf_not_reading_all_input(col("id"))).collect()
+            with self.assertRaisesRegex(Exception, "pandas iterator UDF should 
exhaust"):
+                df1.select(iter_udf_not_reading_all_input(col("id"))).collect()
 
     def test_vectorized_udf_chained(self):
         df = self.spark.range(10)
@@ -632,23 +615,29 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
 
     def test_vectorized_udf_wrong_return_type(self):
         with QuietTest(self.sc):
-            for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
-                with self.assertRaisesRegex(
-                    NotImplementedError,
-                    "Invalid return type.*scalar Pandas 
UDF.*ArrayType.*TimestampType",
-                ):
-                    pandas_udf(lambda x: x, ArrayType(TimestampType()), 
udf_type)
+            self.check_vectorized_udf_wrong_return_type()
+
+    def check_vectorized_udf_wrong_return_type(self):
+        for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
+            with self.assertRaisesRegex(
+                NotImplementedError,
+                "Invalid return type.*scalar Pandas 
UDF.*ArrayType.*TimestampType",
+            ):
+                pandas_udf(lambda x: x, ArrayType(TimestampType()), udf_type)
 
     def test_vectorized_udf_return_scalar(self):
+        with QuietTest(self.sc):
+            self.check_vectorized_udf_return_scalar()
+
+    def check_vectorized_udf_return_scalar(self):
         df = self.spark.range(10)
         scalar_f = pandas_udf(lambda x: 1.0, DoubleType())
         iter_f = pandas_udf(
             lambda it: map(lambda x: 1.0, it), DoubleType(), PandasUDFType.SCALAR_ITER
         )
         for f in [scalar_f, iter_f]:
-            with QuietTest(self.sc):
-                with self.assertRaisesRegex(Exception, "Return.*type.*Series"):
-                    df.select(f(col("id"))).collect()
+            with self.assertRaisesRegex(Exception, "Return.*type.*Series"):
+                df.select(f(col("id"))).collect()
 
     def test_vectorized_udf_decorator(self):
         df = self.spark.range(10)
@@ -694,7 +683,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
             self.assertEqual("Doe", row[0]["last"])
 
     def test_vectorized_udf_varargs(self):
-        df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))
+        df = self.spark.range(start=1, end=2)
         scalar_f = pandas_udf(lambda *v: v[0], LongType())
 
         @pandas_udf(LongType(), PandasUDFType.SCALAR_ITER)
@@ -941,16 +930,19 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
             self.assertTrue(result1["plus_ten(rand)"].equals(result1["rand"] + 
10))
 
     def test_nondeterministic_vectorized_udf_in_aggregate(self):
+        with QuietTest(self.sc):
+            self.check_nondeterministic_analysis_exception()
+
+    def check_nondeterministic_analysis_exception(self):
         df = self.spark.range(10)
         for random_udf in [
             self.nondeterministic_vectorized_udf,
             self.nondeterministic_vectorized_iter_udf,
         ]:
-            with QuietTest(self.sc):
-                with self.assertRaisesRegex(AnalysisException, "nondeterministic"):
-                    df.groupby(df.id).agg(sum(random_udf(df.id))).collect()
-                with self.assertRaisesRegex(AnalysisException, "nondeterministic"):
-                    df.agg(sum(random_udf(df.id))).collect()
+            with self.assertRaisesRegex(AnalysisException, "nondeterministic"):
+                df.groupby(df.id).agg(sum(random_udf(df.id))).collect()
+            with self.assertRaisesRegex(AnalysisException, "nondeterministic"):
+                df.agg(sum(random_udf(df.id))).collect()
 
     def test_register_vectorized_udf_basic(self):
         df = self.spark.range(10).select(
@@ -999,6 +991,10 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
             )
 
     def test_scalar_iter_udf_close(self):
+        with QuietTest(self.sc):
+            self.check_scalar_iter_udf_close()
+
+    def check_scalar_iter_udf_close(self):
         @pandas_udf("int", PandasUDFType.SCALAR_ITER)
         def test_close(batch_iter):
             try:
@@ -1007,9 +1003,8 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
             finally:
                 raise RuntimeError("reached finally block")
 
-        with QuietTest(self.sc):
-            with self.assertRaisesRegex(Exception, "reached finally block"):
-                self.spark.range(1).select(test_close(col("id"))).collect()
+        with self.assertRaisesRegex(Exception, "reached finally block"):
+            self.spark.range(1).select(test_close(col("id"))).collect()
 
     @unittest.skip("LimitPushDown should push limits through Python UDFs so 
this won't occur")
     def test_scalar_iter_udf_close_early(self):
@@ -1217,6 +1212,8 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
             self.assertEqual(expected_multi, df_multi_2.collect())
 
     def test_mixed_udf_and_sql(self):
+        from pyspark.sql.connect.column import Column as ConnectColumn
+
         df = self.spark.range(0, 1).toDF("v")
 
         # Test mixture of UDFs, Pandas UDFs and SQL expression.
@@ -1227,7 +1224,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
             return x + 1
 
         def f2(x):
-            assert type(x) == Column
+            assert type(x) in (Column, ConnectColumn)
             return x + 10
 
         @pandas_udf("int")
@@ -1351,6 +1348,29 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
             shutil.rmtree(path)
 
 
+class ScalarPandasUDFTests(ScalarPandasUDFTestsMixin, ReusedSQLTestCase):
+    @classmethod
+    def setUpClass(cls):
+        ReusedSQLTestCase.setUpClass()
+
+        # Synchronize default timezone between Python and Java
+        cls.tz_prev = os.environ.get("TZ", None)  # save current tz if set
+        tz = "America/Los_Angeles"
+        os.environ["TZ"] = tz
+        time.tzset()
+
+        cls.sc.environment["TZ"] = tz
+        cls.spark.conf.set("spark.sql.session.timeZone", tz)
+
+    @classmethod
+    def tearDownClass(cls):
+        del os.environ["TZ"]
+        if cls.tz_prev is not None:
+            os.environ["TZ"] = cls.tz_prev
+        time.tzset()
+        ReusedSQLTestCase.tearDownClass()
+
+
 if __name__ == "__main__":
     from pyspark.sql.tests.pandas.test_pandas_udf_scalar import *  # noqa: F401
 
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
index 8da97dee83b..eb80dccd9b2 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
@@ -50,7 +50,7 @@ if have_pandas:
     not have_pandas or not have_pyarrow,
     cast(str, pandas_requirement_message or pyarrow_requirement_message),
 )
-class WindowPandasUDFTests(ReusedSQLTestCase):
+class WindowPandasUDFTestsMixin:
     @property
     def data(self):
         return (
@@ -394,6 +394,10 @@ class WindowPandasUDFTests(ReusedSQLTestCase):
         assert_frame_equal(expected1.toPandas(), result1.toPandas())
 
 
+class WindowPandasUDFTests(WindowPandasUDFTestsMixin, ReusedSQLTestCase):
+    pass
+
+
 if __name__ == "__main__":
     from pyspark.sql.tests.pandas.test_pandas_udf_window import *  # noqa: F401
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
