This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new a86324cb52c Revert "[SPARK-40770][PYTHON] Improved error messages for applyInPandas for schema mismatch"
a86324cb52c is described below

commit a86324cb52ce341339389e1f4079297cd2ec9d76
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Thu Feb 9 09:53:56 2023 +0900

    Revert "[SPARK-40770][PYTHON] Improved error messages for applyInPandas for schema mismatch"
    
    This reverts commit c4c28cfbfe7b3f58f08c93d2c1cd421c302b0cd3.
---
 python/pyspark/sql/pandas/serializers.py           |  29 +-
 .../sql/tests/pandas/test_pandas_cogrouped_map.py  | 317 ++++++++-------------
 .../sql/tests/pandas/test_pandas_grouped_map.py    | 183 ++++--------
 .../pandas/test_pandas_grouped_map_with_state.py   |   2 +-
 python/pyspark/sql/tests/test_arrow.py             |  17 +-
 python/pyspark/worker.py                           | 108 +++----
 6 files changed, 232 insertions(+), 424 deletions(-)

diff --git a/python/pyspark/sql/pandas/serializers.py 
b/python/pyspark/sql/pandas/serializers.py
index 30c2d102456..ca249c75ea5 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -231,25 +231,18 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
                 s = s.astype(s.dtypes.categories.dtype)
             try:
                 array = pa.Array.from_pandas(s, mask=mask, type=t, 
safe=self._safecheck)
-            except TypeError as e:
-                error_msg = (
-                    "Exception thrown when converting pandas.Series (%s) "
-                    "with name '%s' to Arrow Array (%s)."
-                )
-                raise TypeError(error_msg % (s.dtype, s.name, t)) from e
             except ValueError as e:
-                error_msg = (
-                    "Exception thrown when converting pandas.Series (%s) "
-                    "with name '%s' to Arrow Array (%s)."
-                )
                 if self._safecheck:
-                    error_msg = error_msg + (
-                        " It can be caused by overflows or other "
-                        "unsafe conversions warned by Arrow. Arrow safe type 
check "
-                        "can be disabled by using SQL config "
-                        
"`spark.sql.execution.pandas.convertToArrowArraySafely`."
+                    error_msg = (
+                        "Exception thrown when converting pandas.Series (%s) 
to "
+                        + "Arrow Array (%s). It can be caused by overflows or 
other "
+                        + "unsafe conversions warned by Arrow. Arrow safe type 
check "
+                        + "can be disabled by using SQL config "
+                        + 
"`spark.sql.execution.pandas.convertToArrowArraySafely`."
                     )
-                raise ValueError(error_msg % (s.dtype, s.name, t)) from e
+                    raise ValueError(error_msg % (s.dtype, t)) from e
+                else:
+                    raise e
             return array
 
         arrs = []
@@ -272,9 +265,7 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
                 # Assign result columns by  position
                 else:
                     arrs_names = [
-                        # the selected series has name '1', so we rename it to 
field.name
-                        # as the name is used by create_array to provide a 
meaningful error message
-                        (create_array(s[s.columns[i]].rename(field.name), 
field.type), field.name)
+                        (create_array(s[s.columns[i]], field.type), field.name)
                         for i, field in enumerate(t)
                     ]
 
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py 
b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
index 47ed12d2f46..5cbc9e1caa4 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
@@ -43,7 +43,7 @@ if have_pyarrow:
     not have_pandas or not have_pyarrow,
     cast(str, pandas_requirement_message or pyarrow_requirement_message),
 )
-class CogroupedApplyInPandasTests(ReusedSQLTestCase):
+class CogroupedMapInPandasTests(ReusedSQLTestCase):
     @property
     def data1(self):
         return (
@@ -79,9 +79,7 @@ class CogroupedApplyInPandasTests(ReusedSQLTestCase):
 
     def test_different_schemas(self):
         right = self.data2.withColumn("v3", lit("a"))
-        self._test_merge(
-            self.data1, right, output_schema="id long, k int, v int, v2 int, 
v3 string"
-        )
+        self._test_merge(self.data1, right, "id long, k int, v int, v2 int, v3 
string")
 
     def test_different_keys(self):
         left = self.data1
@@ -130,7 +128,26 @@ class CogroupedApplyInPandasTests(ReusedSQLTestCase):
         assert_frame_equal(expected, result)
 
     def test_empty_group_by(self):
-        self._test_merge(self.data1, self.data2, by=[])
+        left = self.data1
+        right = self.data2
+
+        def merge_pandas(lft, rgt):
+            return pd.merge(lft, rgt, on=["id", "k"])
+
+        result = (
+            left.groupby()
+            .cogroup(right.groupby())
+            .applyInPandas(merge_pandas, "id long, k int, v int, v2 int")
+            .sort(["id", "k"])
+            .toPandas()
+        )
+
+        left = left.toPandas()
+        right = right.toPandas()
+
+        expected = pd.merge(left, right, on=["id", "k"]).sort_values(by=["id", 
"k"])
+
+        assert_frame_equal(expected, result)
 
     def test_different_group_key_cardinality(self):
         left = self.data1
@@ -149,35 +166,29 @@ class CogroupedApplyInPandasTests(ReusedSQLTestCase):
                 )
 
     def test_apply_in_pandas_not_returning_pandas_dataframe(self):
-        self._test_merge_error(
-            fn=lambda lft, rgt: lft.size + rgt.size,
-            error_class=PythonException,
-            error_message_regex="Return type of the user-defined function "
-            "should be pandas.DataFrame, but is <class 'numpy.int64'>",
-        )
-
-    def test_apply_in_pandas_returning_column_names(self):
-        self._test_merge(fn=lambda lft, rgt: pd.merge(lft, rgt, on=["id", 
"k"]))
+        left = self.data1
+        right = self.data2
 
-    def test_apply_in_pandas_returning_no_column_names(self):
         def merge_pandas(lft, rgt):
-            res = pd.merge(lft, rgt, on=["id", "k"])
-            res.columns = range(res.columns.size)
-            return res
-
-        self._test_merge(fn=merge_pandas)
+            return lft.size + rgt.size
 
-    def test_apply_in_pandas_returning_column_names_sometimes(self):
-        def merge_pandas(lft, rgt):
-            res = pd.merge(lft, rgt, on=["id", "k"])
-            if 0 in lft["id"] and lft["id"][0] % 2 == 0:
-                return res
-            res.columns = range(res.columns.size)
-            return res
+        with QuietTest(self.sc):
+            with self.assertRaisesRegex(
+                PythonException,
+                "Return type of the user-defined function should be 
pandas.DataFrame, "
+                "but is <class 'numpy.int64'>",
+            ):
+                (
+                    left.groupby("id")
+                    .cogroup(right.groupby("id"))
+                    .applyInPandas(merge_pandas, "id long, k int, v int, v2 
int")
+                    .collect()
+                )
 
-        self._test_merge(fn=merge_pandas)
+    def test_apply_in_pandas_returning_wrong_number_of_columns(self):
+        left = self.data1
+        right = self.data2
 
-    def test_apply_in_pandas_returning_wrong_column_names(self):
         def merge_pandas(lft, rgt):
             if 0 in lft["id"] and lft["id"][0] % 2 == 0:
                 lft["add"] = 0
@@ -185,77 +196,70 @@ class CogroupedApplyInPandasTests(ReusedSQLTestCase):
                 rgt["more"] = 1
             return pd.merge(lft, rgt, on=["id", "k"])
 
-        self._test_merge_error(
-            fn=merge_pandas,
-            error_class=PythonException,
-            error_message_regex="Column names of the returned pandas.DataFrame 
"
-            "do not match specified schema. Unexpected: add, more.\n",
-        )
+        with QuietTest(self.sc):
+            with self.assertRaisesRegex(
+                PythonException,
+                "Number of columns of the returned pandas.DataFrame "
+                "doesn't match specified schema. Expected: 4 Actual: 6",
+            ):
+                (
+                    # merge_pandas returns two columns for even keys while we 
set schema to four
+                    left.groupby("id")
+                    .cogroup(right.groupby("id"))
+                    .applyInPandas(merge_pandas, "id long, k int, v int, v2 
int")
+                    .collect()
+                )
+
+    def test_apply_in_pandas_returning_empty_dataframe(self):
+        left = self.data1
+        right = self.data2
 
-    def test_apply_in_pandas_returning_no_column_names_and_wrong_amount(self):
         def merge_pandas(lft, rgt):
             if 0 in lft["id"] and lft["id"][0] % 2 == 0:
-                lft[3] = 0
+                return pd.DataFrame([])
             if 0 in rgt["id"] and rgt["id"][0] % 3 == 0:
-                rgt[3] = 1
-            res = pd.merge(lft, rgt, on=["id", "k"])
-            res.columns = range(res.columns.size)
-            return res
-
-        self._test_merge_error(
-            fn=merge_pandas,
-            error_class=PythonException,
-            error_message_regex="Number of columns of the returned 
pandas.DataFrame "
-            "doesn't match specified schema. Expected: 4 Actual: 6\n",
+                return pd.DataFrame([])
+            return pd.merge(lft, rgt, on=["id", "k"])
+
+        result = (
+            left.groupby("id")
+            .cogroup(right.groupby("id"))
+            .applyInPandas(merge_pandas, "id long, k int, v int, v2 int")
+            .sort(["id", "k"])
+            .toPandas()
         )
 
-    def test_apply_in_pandas_returning_empty_dataframe(self):
+        left = left.toPandas()
+        right = right.toPandas()
+
+        expected = pd.merge(
+            left[left["id"] % 2 != 0], right[right["id"] % 3 != 0], on=["id", 
"k"]
+        ).sort_values(by=["id", "k"])
+
+        assert_frame_equal(expected, result)
+
+    def 
test_apply_in_pandas_returning_empty_dataframe_and_wrong_number_of_columns(self):
+        left = self.data1
+        right = self.data2
+
         def merge_pandas(lft, rgt):
             if 0 in lft["id"] and lft["id"][0] % 2 == 0:
-                return pd.DataFrame()
-            if 0 in rgt["id"] and rgt["id"][0] % 3 == 0:
-                return pd.DataFrame()
+                return pd.DataFrame([], columns=["id", "k"])
             return pd.merge(lft, rgt, on=["id", "k"])
 
-        self._test_merge_empty(fn=merge_pandas)
-
-    def test_apply_in_pandas_returning_incompatible_type(self):
-        for safely in [True, False]:
-            with self.subTest(convertToArrowArraySafely=safely), self.sql_conf(
-                {"spark.sql.execution.pandas.convertToArrowArraySafely": 
safely}
-            ), QuietTest(self.sc):
-                # sometimes we see ValueErrors
-                with self.subTest(convert="string to double"):
-                    expected = (
-                        r"ValueError: Exception thrown when converting 
pandas.Series \(object\) "
-                        r"with name 'k' to Arrow Array \(double\)."
-                    )
-                    if safely:
-                        expected = expected + (
-                            " It can be caused by overflows or other "
-                            "unsafe conversions warned by Arrow. Arrow safe 
type check "
-                            "can be disabled by using SQL config "
-                            
"`spark.sql.execution.pandas.convertToArrowArraySafely`."
-                        )
-                    self._test_merge_error(
-                        fn=lambda lft, rgt: pd.DataFrame({"id": [1], "k": 
["2.0"]}),
-                        output_schema="id long, k double",
-                        error_class=PythonException,
-                        error_message_regex=expected,
-                    )
-
-                # sometimes we see TypeErrors
-                with self.subTest(convert="double to string"):
-                    expected = (
-                        r"TypeError: Exception thrown when converting 
pandas.Series \(float64\) "
-                        r"with name 'k' to Arrow Array \(string\).\n"
-                    )
-                    self._test_merge_error(
-                        fn=lambda lft, rgt: pd.DataFrame({"id": [1], "k": 
[2.0]}),
-                        output_schema="id long, k string",
-                        error_class=PythonException,
-                        error_message_regex=expected,
-                    )
+        with QuietTest(self.sc):
+            with self.assertRaisesRegex(
+                PythonException,
+                "Number of columns of the returned pandas.DataFrame doesn't "
+                "match specified schema. Expected: 4 Actual: 2",
+            ):
+                (
+                    # merge_pandas returns two columns for even keys while we 
set schema to four
+                    left.groupby("id")
+                    .cogroup(right.groupby("id"))
+                    .applyInPandas(merge_pandas, "id long, k int, v int, v2 
int")
+                    .collect()
+                )
 
     def test_mixed_scalar_udfs_followed_by_cogrouby_apply(self):
         df = self.spark.range(0, 10).toDF("v1")
@@ -308,20 +312,23 @@ class CogroupedApplyInPandasTests(ReusedSQLTestCase):
 
     def test_wrong_return_type(self):
         # Test that we get a sensible exception invalid values passed to apply
-        self._test_merge_error(
-            fn=lambda l, r: l,
-            output_schema="id long, v array<timestamp>",
-            error_class=NotImplementedError,
-            error_message_regex="Invalid return 
type.*ArrayType.*TimestampType",
-        )
+        left = self.data1
+        right = self.data2
+        with QuietTest(self.sc):
+            with self.assertRaisesRegex(
+                NotImplementedError, "Invalid return 
type.*ArrayType.*TimestampType"
+            ):
+                left.groupby("id").cogroup(right.groupby("id")).applyInPandas(
+                    lambda l, r: l, "id long, v array<timestamp>"
+                )
 
     def test_wrong_args(self):
-        self.__test_merge_error(
-            fn=lambda: 1,
-            output_schema=StructType([StructField("d", DoubleType())]),
-            error_class=ValueError,
-            error_message_regex="Invalid function",
-        )
+        left = self.data1
+        right = self.data2
+        with self.assertRaisesRegex(ValueError, "Invalid function"):
+            left.groupby("id").cogroup(right.groupby("id")).applyInPandas(
+                lambda: 1, StructType([StructField("d", DoubleType())])
+            )
 
     def test_case_insensitive_grouping_column(self):
         # SPARK-31915: case-insensitive grouping column should work.
@@ -427,51 +434,15 @@ class CogroupedApplyInPandasTests(ReusedSQLTestCase):
 
         assert_frame_equal(expected, result)
 
-    def _test_merge_empty(self, fn):
-        left = self.data1.toPandas()
-        right = self.data2.toPandas()
-
-        expected = pd.merge(
-            left[left["id"] % 2 != 0], right[right["id"] % 3 != 0], on=["id", 
"k"]
-        ).sort_values(by=["id", "k"])
-
-        self._test_merge(self.data1, self.data2, fn=fn, expected=expected)
-
-    def _test_merge(
-        self,
-        left=None,
-        right=None,
-        by=["id"],
-        fn=lambda lft, rgt: pd.merge(lft, rgt, on=["id", "k"]),
-        output_schema="id long, k int, v int, v2 int",
-        expected=None,
-    ):
-        def fn_with_key(_, lft, rgt):
-            return fn(lft, rgt)
-
-        # Test fn with and without key argument
-        with self.subTest("without key"):
-            self.__test_merge(left, right, by, fn, output_schema, expected)
-        with self.subTest("with key"):
-            self.__test_merge(left, right, by, fn_with_key, output_schema, 
expected)
-
-    def __test_merge(
-        self,
-        left=None,
-        right=None,
-        by=["id"],
-        fn=lambda lft, rgt: pd.merge(lft, rgt, on=["id", "k"]),
-        output_schema="id long, k int, v int, v2 int",
-        expected=None,
-    ):
-        # Test fn as is, cf. _test_merge
-        left = self.data1 if left is None else left
-        right = self.data2 if right is None else right
+    @staticmethod
+    def _test_merge(left, right, output_schema="id long, k int, v int, v2 
int"):
+        def merge_pandas(lft, rgt):
+            return pd.merge(lft, rgt, on=["id", "k"])
 
         result = (
-            left.groupby(*by)
-            .cogroup(right.groupby(*by))
-            .applyInPandas(fn, output_schema)
+            left.groupby("id")
+            .cogroup(right.groupby("id"))
+            .applyInPandas(merge_pandas, output_schema)
             .sort(["id", "k"])
             .toPandas()
         )
@@ -479,64 +450,10 @@ class CogroupedApplyInPandasTests(ReusedSQLTestCase):
         left = left.toPandas()
         right = right.toPandas()
 
-        expected = (
-            pd.merge(left, right, on=["id", "k"]).sort_values(by=["id", "k"])
-            if expected is None
-            else expected
-        )
+        expected = pd.merge(left, right, on=["id", "k"]).sort_values(by=["id", 
"k"])
 
         assert_frame_equal(expected, result)
 
-    def _test_merge_error(
-        self,
-        error_class,
-        error_message_regex,
-        left=None,
-        right=None,
-        by=["id"],
-        fn=lambda lft, rgt: pd.merge(lft, rgt, on=["id", "k"]),
-        output_schema="id long, k int, v int, v2 int",
-    ):
-        def fn_with_key(_, lft, rgt):
-            return fn(lft, rgt)
-
-        # Test fn with and without key argument
-        with self.subTest("without key"):
-            self.__test_merge_error(
-                left=left,
-                right=right,
-                by=by,
-                fn=fn,
-                output_schema=output_schema,
-                error_class=error_class,
-                error_message_regex=error_message_regex,
-            )
-        with self.subTest("with key"):
-            self.__test_merge_error(
-                left=left,
-                right=right,
-                by=by,
-                fn=fn_with_key,
-                output_schema=output_schema,
-                error_class=error_class,
-                error_message_regex=error_message_regex,
-            )
-
-    def __test_merge_error(
-        self,
-        error_class,
-        error_message_regex,
-        left=None,
-        right=None,
-        by=["id"],
-        fn=lambda lft, rgt: pd.merge(lft, rgt, on=["id", "k"]),
-        output_schema="id long, k int, v int, v2 int",
-    ):
-        # Test fn as is, cf. _test_merge_error
-        with QuietTest(self.sc):
-            with self.assertRaisesRegex(error_class, error_message_regex):
-                self.__test_merge(left, right, by, fn, output_schema)
-
 
 if __name__ == "__main__":
     from pyspark.sql.tests.pandas.test_pandas_cogrouped_map import *  # noqa: 
F401
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py 
b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
index 88e68b04303..5f103c97926 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
@@ -73,7 +73,7 @@ if have_pyarrow:
     not have_pandas or not have_pyarrow,
     cast(str, pandas_requirement_message or pyarrow_requirement_message),
 )
-class GroupedApplyInPandasTests(ReusedSQLTestCase):
+class GroupedMapInPandasTests(ReusedSQLTestCase):
     @property
     def data(self):
         return (
@@ -270,101 +270,79 @@ class GroupedApplyInPandasTests(ReusedSQLTestCase):
         assert_frame_equal(expected, result)
 
     def test_apply_in_pandas_not_returning_pandas_dataframe(self):
+        df = self.data
+
+        def stats(key, _):
+            return key
+
         with QuietTest(self.sc):
             with self.assertRaisesRegex(
                 PythonException,
                 "Return type of the user-defined function should be 
pandas.DataFrame, "
                 "but is <class 'tuple'>",
             ):
-                self._test_apply_in_pandas(lambda key, pdf: key)
-
-    @staticmethod
-    def stats_with_column_names(key, pdf):
-        # order of column can be different to applyInPandas schema when column 
names are given
-        return pd.DataFrame([(pdf.v.mean(),) + key], columns=["mean", "id"])
-
-    @staticmethod
-    def stats_with_no_column_names(key, pdf):
-        # columns must be in order of applyInPandas schema when no columns 
given
-        return pd.DataFrame([key + (pdf.v.mean(),)])
+                df.groupby("id").applyInPandas(stats, schema="id integer, m 
double").collect()
 
-    def test_apply_in_pandas_returning_column_names(self):
-        
self._test_apply_in_pandas(GroupedApplyInPandasTests.stats_with_column_names)
-
-    def test_apply_in_pandas_returning_no_column_names(self):
-        
self._test_apply_in_pandas(GroupedApplyInPandasTests.stats_with_no_column_names)
+    def test_apply_in_pandas_returning_wrong_number_of_columns(self):
+        df = self.data
 
-    def test_apply_in_pandas_returning_column_names_sometimes(self):
         def stats(key, pdf):
-            if key[0] % 2:
-                return GroupedApplyInPandasTests.stats_with_column_names(key, 
pdf)
-            else:
-                return 
GroupedApplyInPandasTests.stats_with_no_column_names(key, pdf)
-
-        self._test_apply_in_pandas(stats)
+            v = pdf.v
+            # returning three columns
+            res = pd.DataFrame([key + (v.mean(), v.std())])
+            return res
 
-    def test_apply_in_pandas_returning_wrong_column_names(self):
         with QuietTest(self.sc):
             with self.assertRaisesRegex(
                 PythonException,
-                "Column names of the returned pandas.DataFrame do not match 
specified schema. "
-                "Missing: mean. Unexpected: median, std.\n",
+                "Number of columns of the returned pandas.DataFrame doesn't 
match "
+                "specified schema. Expected: 2 Actual: 3",
             ):
-                self._test_apply_in_pandas(
-                    lambda key, pdf: pd.DataFrame(
-                        [key + (pdf.v.median(), pdf.v.std())], columns=["id", 
"median", "std"]
-                    )
-                )
+                # stats returns three columns while here we set schema with 
two columns
+                df.groupby("id").applyInPandas(stats, schema="id integer, m 
double").collect()
+
+    def test_apply_in_pandas_returning_empty_dataframe(self):
+        df = self.data
+
+        def odd_means(key, pdf):
+            if key[0] % 2 == 0:
+                return pd.DataFrame([])
+            else:
+                return pd.DataFrame([key + (pdf.v.mean(),)])
+
+        expected_ids = {row[0] for row in self.data.collect() if row[0] % 2 != 
0}
+
+        result = (
+            df.groupby("id")
+            .applyInPandas(odd_means, schema="id integer, m double")
+            .sort("id", "m")
+            .collect()
+        )
+
+        actual_ids = {row[0] for row in result}
+        self.assertSetEqual(expected_ids, actual_ids)
+
+        self.assertEqual(len(expected_ids), len(result))
+        for row in result:
+            self.assertEqual(24.5, row[1])
+
+    def 
test_apply_in_pandas_returning_empty_dataframe_and_wrong_number_of_columns(self):
+        df = self.data
+
+        def odd_means(key, pdf):
+            if key[0] % 2 == 0:
+                return pd.DataFrame([], columns=["id"])
+            else:
+                return pd.DataFrame([key + (pdf.v.mean(),)])
 
-    def test_apply_in_pandas_returning_no_column_names_and_wrong_amount(self):
         with QuietTest(self.sc):
             with self.assertRaisesRegex(
                 PythonException,
                 "Number of columns of the returned pandas.DataFrame doesn't 
match "
-                "specified schema. Expected: 2 Actual: 3\n",
+                "specified schema. Expected: 2 Actual: 1",
             ):
-                self._test_apply_in_pandas(
-                    lambda key, pdf: pd.DataFrame([key + (pdf.v.mean(), 
pdf.v.std())])
-                )
-
-    def test_apply_in_pandas_returning_empty_dataframe(self):
-        self._test_apply_in_pandas_returning_empty_dataframe(pd.DataFrame())
-
-    def test_apply_in_pandas_returning_incompatible_type(self):
-        for safely in [True, False]:
-            with self.subTest(convertToArrowArraySafely=safely), self.sql_conf(
-                {"spark.sql.execution.pandas.convertToArrowArraySafely": 
safely}
-            ), QuietTest(self.sc):
-                # sometimes we see ValueErrors
-                with self.subTest(convert="string to double"):
-                    expected = (
-                        r"ValueError: Exception thrown when converting 
pandas.Series \(object\) "
-                        r"with name 'mean' to Arrow Array \(double\)."
-                    )
-                    if safely:
-                        expected = expected + (
-                            " It can be caused by overflows or other "
-                            "unsafe conversions warned by Arrow. Arrow safe 
type check "
-                            "can be disabled by using SQL config "
-                            
"`spark.sql.execution.pandas.convertToArrowArraySafely`."
-                        )
-                    with self.assertRaisesRegex(PythonException, expected + 
"\n"):
-                        self._test_apply_in_pandas(
-                            lambda key, pdf: pd.DataFrame([key + 
(str(pdf.v.mean()),)]),
-                            output_schema="id long, mean double",
-                        )
-
-                # sometimes we see TypeErrors
-                with self.subTest(convert="double to string"):
-                    with self.assertRaisesRegex(
-                        PythonException,
-                        r"TypeError: Exception thrown when converting 
pandas.Series \(float64\) "
-                        r"with name 'mean' to Arrow Array \(string\).\n",
-                    ):
-                        self._test_apply_in_pandas(
-                            lambda key, pdf: pd.DataFrame([key + 
(pdf.v.mean(),)]),
-                            output_schema="id long, mean string",
-                        )
+                # stats returns one column for even keys while here we set 
schema with two columns
+                df.groupby("id").applyInPandas(odd_means, schema="id integer, 
m double").collect()
 
     def test_datatype_string(self):
         df = self.data
@@ -588,11 +566,7 @@ class GroupedApplyInPandasTests(ReusedSQLTestCase):
 
         with 
self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": False}):
             with QuietTest(self.sc):
-                with self.assertRaisesRegex(
-                    PythonException,
-                    "RuntimeError: Column names of the returned 
pandas.DataFrame do not match "
-                    "specified schema. Missing: id. Unexpected: iid.\n",
-                ):
+                with self.assertRaisesRegex(Exception, "KeyError: 'id'"):
                     grouped_df.apply(column_name_typo).collect()
                 with self.assertRaisesRegex(Exception, 
"[D|d]ecimal.*got.*date"):
                     grouped_df.apply(invalid_positional_types).collect()
@@ -681,11 +655,10 @@ class GroupedApplyInPandasTests(ReusedSQLTestCase):
             df.groupby("group", window("ts", "5 days"))
             .applyInPandas(f, df.schema)
             .select("id", "result")
-            .orderBy("id")
             .collect()
         )
-
-        self.assertListEqual([Row(id=key, result=val) for key, val in 
expected.items()], result)
+        for r in result:
+            self.assertListEqual(expected[r[0]], r[1])
 
     def test_grouped_over_window_with_key(self):
 
@@ -747,11 +720,11 @@ class GroupedApplyInPandasTests(ReusedSQLTestCase):
             df.groupby("group", window("ts", "5 days"))
             .applyInPandas(f, df.schema)
             .select("id", "result")
-            .orderBy("id")
             .collect()
         )
 
-        self.assertListEqual([Row(id=key, result=val) for key, val in 
expected.items()], result)
+        for r in result:
+            self.assertListEqual(expected[r[0]], r[1])
 
     def test_case_insensitive_grouping_column(self):
         # SPARK-31915: case-insensitive grouping column should work.
@@ -766,44 +739,6 @@ class GroupedApplyInPandasTests(ReusedSQLTestCase):
         )
         self.assertEqual(row.asDict(), Row(column=1, score=0.5).asDict())
 
-    def _test_apply_in_pandas(self, f, output_schema="id long, mean double"):
-        df = self.data
-
-        result = (
-            df.groupby("id").applyInPandas(f, schema=output_schema).sort("id", 
"mean").toPandas()
-        )
-        expected = df.select("id").distinct().withColumn("mean", 
lit(24.5)).toPandas()
-
-        assert_frame_equal(expected, result)
-
-    def _test_apply_in_pandas_returning_empty_dataframe(self, empty_df):
-        """Tests some returned DataFrames are empty."""
-        df = self.data
-
-        def stats(key, pdf):
-            if key[0] % 2 == 0:
-                return 
GroupedApplyInPandasTests.stats_with_no_column_names(key, pdf)
-            return empty_df
-
-        result = (
-            df.groupby("id")
-            .applyInPandas(stats, schema="id long, mean double")
-            .sort("id", "mean")
-            .collect()
-        )
-
-        actual_ids = {row[0] for row in result}
-        expected_ids = {row[0] for row in self.data.collect() if row[0] % 2 == 
0}
-        self.assertSetEqual(expected_ids, actual_ids)
-        self.assertEqual(len(expected_ids), len(result))
-        for row in result:
-            self.assertEqual(24.5, row[1])
-
-    def _test_apply_in_pandas_returning_empty_dataframe_error(self, empty_df, 
error):
-        with QuietTest(self.sc):
-            with self.assertRaisesRegex(PythonException, error):
-                self._test_apply_in_pandas_returning_empty_dataframe(empty_df)
-
 
 if __name__ == "__main__":
     from pyspark.sql.tests.pandas.test_pandas_grouped_map import *  # noqa: 
F401
diff --git 
a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py 
b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py
index 9600d1e3445..655f0bf151d 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py
@@ -53,7 +53,7 @@ if have_pyarrow:
     not have_pandas or not have_pyarrow,
     cast(str, pandas_requirement_message or pyarrow_requirement_message),
 )
-class GroupedApplyInPandasWithStateTests(ReusedSQLTestCase):
+class GroupedMapInPandasWithStateTests(ReusedSQLTestCase):
     @classmethod
     def conf(cls):
         cfg = SparkConf()
diff --git a/python/pyspark/sql/tests/test_arrow.py 
b/python/pyspark/sql/tests/test_arrow.py
index c61994380e6..6083f31ac81 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -465,24 +465,9 @@ class ArrowTests(ReusedSQLTestCase):
         wrong_schema = StructType(fields)
         with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": False}):
             with QuietTest(self.sc):
-                with self.assertRaises(Exception) as context:
+                with self.assertRaisesRegex(Exception, "[D|d]ecimal.*got.*date"):
                     self.spark.createDataFrame(pdf, schema=wrong_schema)
 
-                # the exception provides us with the column that is incorrect
-                exception = context.exception
-                self.assertTrue(hasattr(exception, "args"))
-                self.assertEqual(len(exception.args), 1)
-                self.assertRegex(
-                    exception.args[0],
-                    "with name '7_date_t' " "to Arrow Array \\(decimal128\\(38, 18\\)\\)",
-                )
-
-                # the inner exception provides us with the incorrect types
-                exception = exception.__context__
-                self.assertTrue(hasattr(exception, "args"))
-                self.assertEqual(len(exception.args), 1)
-                self.assertRegex(exception.args[0], "[D|d]ecimal.*got.*date")
-
     def test_createDataFrame_with_names(self):
         pdf = self.create_pandas_data_frame()
         new_names = list(map(str, range(len(self.schema.fieldNames()))))
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index f7d98a9a18c..c1c3669701f 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -146,49 +146,7 @@ def wrap_batch_iter_udf(f, return_type):
     )
 
 
-def verify_pandas_result(result, return_type, assign_cols_by_name):
-    import pandas as pd
-
-    if not isinstance(result, pd.DataFrame):
-        raise TypeError(
-            "Return type of the user-defined function should be "
-            "pandas.DataFrame, but is {}".format(type(result))
-        )
-
-    # check the schema of the result only if it is not empty or has columns
-    if not result.empty or len(result.columns) != 0:
-        # if any column name of the result is a string
-        # the column names of the result have to match the return type
-        #   see create_array in pyspark.sql.pandas.serializers.ArrowStreamPandasSerializer
-        field_names = set([field.name for field in return_type.fields])
-        column_names = set(result.columns)
-        if (
-            assign_cols_by_name
-            and any(isinstance(name, str) for name in result.columns)
-            and column_names != field_names
-        ):
-            missing = sorted(list(field_names.difference(column_names)))
-            missing = f" Missing: {', '.join(missing)}." if missing else ""
-
-            extra = sorted(list(column_names.difference(field_names)))
-            extra = f" Unexpected: {', '.join(extra)}." if extra else ""
-
-            raise RuntimeError(
-                "Column names of the returned pandas.DataFrame do not match specified schema."
-                "{}{}".format(missing, extra)
-            )
-        # otherwise the number of columns of result have to match the return type
-        elif len(result.columns) != len(return_type):
-            raise RuntimeError(
-                "Number of columns of the returned pandas.DataFrame "
-                "doesn't match specified schema. "
-                "Expected: {} Actual: {}".format(len(return_type), len(result.columns))
-            )
-
-
-def wrap_cogrouped_map_pandas_udf(f, return_type, argspec, runner_conf):
-    _assign_cols_by_name = assign_cols_by_name(runner_conf)
-
+def wrap_cogrouped_map_pandas_udf(f, return_type, argspec):
     def wrapped(left_key_series, left_value_series, right_key_series, right_value_series):
         import pandas as pd
 
@@ -201,16 +159,27 @@ def wrap_cogrouped_map_pandas_udf(f, return_type, argspec, runner_conf):
             key_series = left_key_series if not left_df.empty else right_key_series
             key = tuple(s[0] for s in key_series)
             result = f(key, left_df, right_df)
-        verify_pandas_result(result, return_type, _assign_cols_by_name)
-
+        if not isinstance(result, pd.DataFrame):
+            raise TypeError(
+                "Return type of the user-defined function should be "
+                "pandas.DataFrame, but is {}".format(type(result))
+            )
+        # the number of columns of result have to match the return type
+        # but it is fine for result to have no columns at all if it is empty
+        if not (
+            len(result.columns) == len(return_type) or len(result.columns) == 0 and result.empty
+        ):
+            raise RuntimeError(
+                "Number of columns of the returned pandas.DataFrame "
+                "doesn't match specified schema. "
+                "Expected: {} Actual: {}".format(len(return_type), len(result.columns))
+            )
         return result
 
     return lambda kl, vl, kr, vr: [(wrapped(kl, vl, kr, vr), to_arrow_type(return_type))]
 
 
-def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf):
-    _assign_cols_by_name = assign_cols_by_name(runner_conf)
-
+def wrap_grouped_map_pandas_udf(f, return_type, argspec):
     def wrapped(key_series, value_series):
         import pandas as pd
 
@@ -219,8 +188,22 @@ def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf):
         elif len(argspec.args) == 2:
             key = tuple(s[0] for s in key_series)
             result = f(key, pd.concat(value_series, axis=1))
-        verify_pandas_result(result, return_type, _assign_cols_by_name)
 
+        if not isinstance(result, pd.DataFrame):
+            raise TypeError(
+                "Return type of the user-defined function should be "
+                "pandas.DataFrame, but is {}".format(type(result))
+            )
+        # the number of columns of result have to match the return type
+        # but it is fine for result to have no columns at all if it is empty
+        if not (
+            len(result.columns) == len(return_type) or len(result.columns) == 0 and result.empty
+        ):
+            raise RuntimeError(
+                "Number of columns of the returned pandas.DataFrame "
+                "doesn't match specified schema. "
+                "Expected: {} Actual: {}".format(len(return_type), len(result.columns))
+            )
         return result
 
     return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))]
@@ -413,12 +396,12 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
         return arg_offsets, wrap_batch_iter_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
         argspec = getfullargspec(chained_func)  # signature was lost when wrapping it
-        return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf)
+        return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec)
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
         return arg_offsets, wrap_grouped_map_pandas_udf_with_state(func, return_type)
     elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
         argspec = getfullargspec(chained_func)  # signature was lost when wrapping it
-        return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec, runner_conf)
+        return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec)
     elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
         return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
@@ -429,16 +412,6 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
         raise ValueError("Unknown eval type: {}".format(eval_type))
 
 
-# Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF when returning StructType
-def assign_cols_by_name(runner_conf):
-    return (
-        runner_conf.get(
-            "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true"
-        ).lower()
-        == "true"
-    )
-
-
 def read_udfs(pickleSer, infile, eval_type):
     runner_conf = {}
 
@@ -471,9 +444,16 @@ def read_udfs(pickleSer, infile, eval_type):
             runner_conf.get("spark.sql.execution.pandas.convertToArrowArraySafely", "false").lower()
             == "true"
         )
+        # Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF when returning StructType
+        assign_cols_by_name = (
+            runner_conf.get(
+                "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true"
+            ).lower()
+            == "true"
+        )
 
         if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
-            ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name(runner_conf))
+            ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name)
         elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
             arrow_max_records_per_batch = runner_conf.get(
                 "spark.sql.execution.arrow.maxRecordsPerBatch", 10000
@@ -483,7 +463,7 @@ def read_udfs(pickleSer, infile, eval_type):
             ser = ApplyInPandasWithStateSerializer(
                 timezone,
                 safecheck,
-                assign_cols_by_name(runner_conf),
+                assign_cols_by_name,
                 state_object_schema,
                 arrow_max_records_per_batch,
             )
@@ -498,7 +478,7 @@ def read_udfs(pickleSer, infile, eval_type):
                 or eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
             )
             ser = ArrowStreamPandasUDFSerializer(
-                timezone, safecheck, assign_cols_by_name(runner_conf), df_for_struct
+                timezone, safecheck, assign_cols_by_name, df_for_struct
             )
     else:
         ser = BatchedSerializer(CPickleSerializer(), 100)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org


Reply via email to