This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new bdf56c48357 [SPARK-40770][PYTHON] Improved error messages for applyInPandas for schema mismatch
bdf56c48357 is described below

commit bdf56c483574d347e5c283273c170377f20d0f10
Author: Enrico Minack <git...@enrico.minack.dev>
AuthorDate: Thu Feb 9 09:52:11 2023 +0900

    [SPARK-40770][PYTHON] Improved error messages for applyInPandas for schema mismatch
    
    ### What changes were proposed in this pull request?
    Improve the error messages when a Python method provided to `DataFrame.groupby(...).applyInPandas` / `DataFrame.groupby(...).cogroup(...).applyInPandas` returns a Pandas DataFrame that does not match the expected schema.
    
    With
    ```Python
    gdf = spark.range(2).join(spark.range(3).withColumnRenamed("id", "val")).groupby("id")
    ```
    
    **Mismatching column names, matching number of columns:**
    ```Python
    gdf.applyInPandas(lambda pdf: pdf.rename(columns={"val": "v"}), "id long, val long").show()
    # was: KeyError: 'val'
    # now: RuntimeError: Column names of the returned pandas.DataFrame do not match specified schema.
    #      Missing: val  Unexpected: v
    ```
    
    **Mismatching column names, different number of columns:**
    ```Python
    gdf.applyInPandas(lambda pdf: pdf.assign(foo=[3, 3, 3]).rename(columns={"val": "v"}), "id long, val long").show()
    # was: RuntimeError: Number of columns of the returned pandas.DataFrame doesn't match specified schema.
    #      Expected: 2 Actual: 3
    # now: RuntimeError: Column names of the returned pandas.DataFrame do not match specified schema.
    #      Missing: val  Unexpected: foo, v
    ```
    
    **Expected schema contains a duplicate column (`id`), so the number of columns matches while the names do not:**
    ```Python
    gdf.applyInPandas(lambda pdf: pdf.rename(columns={"val": "v"}), "id long, id long").show()
    # was: java.lang.IllegalArgumentException: not all nodes and buffers were consumed.
    #      nodes: [ArrowFieldNode [length=3, nullCount=0]]
    #      buffers: [ArrowBuf[304], address:139860828549160, length:0, ArrowBuf[305], address:139860828549160, length:24]
    # now: RuntimeError: Column names of the returned pandas.DataFrame do not match specified schema.
    #      Unexpected: v
    ```
    
    **In case the returned Pandas DataFrame contains no column names (none of the column labels is a string):**
    ```Python
    gdf.applyInPandas(lambda pdf: pdf.assign(foo=[3, 3, 3]).rename(columns={"id": 0, "val": 1, "foo": 2}), "id long, val long").show()
    # was: RuntimeError: Number of columns of the returned pandas.DataFrame doesn't match specified schema.
    #      Expected: 2 Actual: 3
    # now: RuntimeError: Number of columns of the returned pandas.DataFrame doesn't match specified schema.
    #      Expected: 2  Actual: 3
    ```
    
    **Mismatching types (ValueError and TypeError):**
    ```Python
    gdf.applyInPandas(lambda pdf: pdf, "id int, val string").show()
    # was: pyarrow.lib.ArrowTypeError: Expected a string or bytes dtype, got int64
    # now: pyarrow.lib.ArrowTypeError: Expected a string or bytes dtype, got int64
    #      The above exception was the direct cause of the following exception:
    #      TypeError: Exception thrown when converting pandas.Series (int64) with name 'val' to Arrow Array (string).
    
    gdf.applyInPandas(lambda pdf: pdf.assign(val=pdf["val"].apply(str)), "id int, val double").show()
    # was: pyarrow.lib.ArrowInvalid: Could not convert '0' with type str: tried to convert to double
    # now: pyarrow.lib.ArrowInvalid: Could not convert '0' with type str: tried to convert to double
    #      The above exception was the direct cause of the following exception:
    #      ValueError: Exception thrown when converting pandas.Series (object) with name 'val' to Arrow Array (double).
    
    with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": safely}):
      gdf.applyInPandas(lambda pdf: pdf.assign(val=pdf["val"].apply(str)), "id int, val double").show()
    # was: ValueError: Exception thrown when converting pandas.Series (object) to Arrow Array (double).
    #      It can be caused by overflows or other unsafe conversions warned by Arrow. Arrow safe type check can be disabled
    #      by using SQL config `spark.sql.execution.pandas.convertToArrowArraySafely`.
    # now: ValueError: Exception thrown when converting pandas.Series (object) with name 'val' to Arrow Array (double).
    #      It can be caused by overflows or other unsafe conversions warned by Arrow. Arrow safe type check can be disabled
    #      by using SQL config `spark.sql.execution.pandas.convertToArrowArraySafely`.
    ```
    
    ### Why are the changes needed?
    Existing errors are generic (`KeyError`) or meaningless (`not all nodes and buffers were consumed`). The errors should help users spot the mismatching columns by naming them.
    
    The schema of the returned Pandas DataFrames can only be checked while the DataFrame is being processed, so such errors only surface at run time and are very expensive. Therefore, they should be expressive.
    
    ### Does this PR introduce _any_ user-facing change?
    This only changes error messages, not behaviour.
    
    ### How was this patch tested?
    Tests all cases of schema mismatch for `GroupedData.applyInPandas` and `PandasCogroupedOps.applyInPandas`.
    
    Closes #38223 from EnricoMi/branch-pyspark-apply-in-pandas-schema-mismatch.
    
    Authored-by: Enrico Minack <git...@enrico.minack.dev>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
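The type-mismatch cases above keep the original Arrow error and add a second exception that names the offending series. A minimal sketch of that chaining pattern in plain Python follows. The `convert` function and its message are hypothetical stand-ins for illustration, not the code in `serializers.py`; the only assumption is that the wrapper is raised with `raise ... from e`, as the patch does.

```python
def convert(value, series_name, target_type):
    """Hypothetical stand-in for a conversion step that can fail."""
    try:
        return float(value)
    except ValueError as e:
        # Add context (which column failed) and keep the low-level error as the cause.
        error_msg = "Exception thrown when converting value with name '%s' to %s."
        raise ValueError(error_msg % (series_name, target_type)) from e


try:
    convert("abc", "val", "double")
except ValueError as err:
    print(err)                  # the contextual message naming 'val'
    print(repr(err.__cause__))  # the original low-level ValueError
```

Because the cause is preserved, a traceback shows both errors joined by "The above exception was the direct cause of the following exception", which is the wording quoted in the examples above.
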
 python/pyspark/sql/pandas/serializers.py           |  29 +-
 .../sql/tests/pandas/test_pandas_cogrouped_map.py  | 317 +++++++++++++--------
 .../sql/tests/pandas/test_pandas_grouped_map.py    | 183 ++++++++----
 .../pandas/test_pandas_grouped_map_with_state.py   |   2 +-
 python/pyspark/sql/tests/test_arrow.py             |  17 +-
 python/pyspark/worker.py                           | 108 ++++---
 6 files changed, 424 insertions(+), 232 deletions(-)
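
The "Missing / Unexpected" messages described in the commit message can be thought of as a set difference between the schema's field names and the returned column labels. The sketch below illustrates that idea only. `describe_column_mismatch` is a hypothetical helper, not the function added to `worker.py`; it assumes all column labels are strings and uses the message layout seen in the test expectations of this patch.

```python
def describe_column_mismatch(expected, actual):
    """Return a message naming mismatching columns, or None if the names match.

    Hypothetical helper for illustration; assumes string column labels.
    """
    missing = sorted(set(expected) - set(actual))
    unexpected = sorted(set(actual) - set(expected))
    if not missing and not unexpected:
        return None
    parts = ["Column names of the returned pandas.DataFrame do not match specified schema."]
    if missing:
        parts.append("Missing: {}.".format(", ".join(missing)))
    if unexpected:
        parts.append("Unexpected: {}.".format(", ".join(unexpected)))
    return " ".join(parts)


# Same number of columns, one renamed.
print(describe_column_mismatch(["id", "val"], ["id", "v"]))
# Duplicate name in the expected schema: nothing is missing, but 'v' is unexpected.
print(describe_column_mismatch(["id", "id"], ["id", "v"]))
```

Comparing sets rather than positions means column order does not matter when names are present, which matches the note in the grouped-map tests that named columns may be returned in a different order than the schema.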

diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index ca249c75ea5..30c2d102456 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -231,18 +231,25 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
                 s = s.astype(s.dtypes.categories.dtype)
             try:
                 array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
+            except TypeError as e:
+                error_msg = (
+                    "Exception thrown when converting pandas.Series (%s) "
+                    "with name '%s' to Arrow Array (%s)."
+                )
+                raise TypeError(error_msg % (s.dtype, s.name, t)) from e
             except ValueError as e:
+                error_msg = (
+                    "Exception thrown when converting pandas.Series (%s) "
+                    "with name '%s' to Arrow Array (%s)."
+                )
                 if self._safecheck:
-                    error_msg = (
-                        "Exception thrown when converting pandas.Series (%s) to "
-                        + "Arrow Array (%s). It can be caused by overflows or other "
-                        + "unsafe conversions warned by Arrow. Arrow safe type check "
-                        + "can be disabled by using SQL config "
-                        + "`spark.sql.execution.pandas.convertToArrowArraySafely`."
+                    error_msg = error_msg + (
+                        " It can be caused by overflows or other "
+                        "unsafe conversions warned by Arrow. Arrow safe type check "
+                        "can be disabled by using SQL config "
+                        "`spark.sql.execution.pandas.convertToArrowArraySafely`."
                     )
-                    raise ValueError(error_msg % (s.dtype, t)) from e
-                else:
-                    raise e
+                raise ValueError(error_msg % (s.dtype, s.name, t)) from e
             return array
 
         arrs = []
@@ -265,7 +272,9 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
                 # Assign result columns by  position
                 else:
                     arrs_names = [
-                        (create_array(s[s.columns[i]], field.type), field.name)
+                        # the selected series has name '1', so we rename it to field.name
+                        # as the name is used by create_array to provide a meaningful error message
+                        (create_array(s[s.columns[i]].rename(field.name), field.type), field.name)
                         for i, field in enumerate(t)
                     ]
 
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
index 5cbc9e1caa4..47ed12d2f46 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
@@ -43,7 +43,7 @@ if have_pyarrow:
     not have_pandas or not have_pyarrow,
     cast(str, pandas_requirement_message or pyarrow_requirement_message),
 )
-class CogroupedMapInPandasTests(ReusedSQLTestCase):
+class CogroupedApplyInPandasTests(ReusedSQLTestCase):
     @property
     def data1(self):
         return (
@@ -79,7 +79,9 @@ class CogroupedMapInPandasTests(ReusedSQLTestCase):
 
     def test_different_schemas(self):
         right = self.data2.withColumn("v3", lit("a"))
-        self._test_merge(self.data1, right, "id long, k int, v int, v2 int, v3 string")
+        self._test_merge(
+            self.data1, right, output_schema="id long, k int, v int, v2 int, v3 string"
+        )
 
     def test_different_keys(self):
         left = self.data1
@@ -128,26 +130,7 @@ class CogroupedMapInPandasTests(ReusedSQLTestCase):
         assert_frame_equal(expected, result)
 
     def test_empty_group_by(self):
-        left = self.data1
-        right = self.data2
-
-        def merge_pandas(lft, rgt):
-            return pd.merge(lft, rgt, on=["id", "k"])
-
-        result = (
-            left.groupby()
-            .cogroup(right.groupby())
-            .applyInPandas(merge_pandas, "id long, k int, v int, v2 int")
-            .sort(["id", "k"])
-            .toPandas()
-        )
-
-        left = left.toPandas()
-        right = right.toPandas()
-
-        expected = pd.merge(left, right, on=["id", "k"]).sort_values(by=["id", "k"])
-
-        assert_frame_equal(expected, result)
+        self._test_merge(self.data1, self.data2, by=[])
 
     def test_different_group_key_cardinality(self):
         left = self.data1
@@ -166,29 +149,35 @@ class CogroupedMapInPandasTests(ReusedSQLTestCase):
                 )
 
     def test_apply_in_pandas_not_returning_pandas_dataframe(self):
-        left = self.data1
-        right = self.data2
+        self._test_merge_error(
+            fn=lambda lft, rgt: lft.size + rgt.size,
+            error_class=PythonException,
+            error_message_regex="Return type of the user-defined function "
+            "should be pandas.DataFrame, but is <class 'numpy.int64'>",
+        )
+
+    def test_apply_in_pandas_returning_column_names(self):
+        self._test_merge(fn=lambda lft, rgt: pd.merge(lft, rgt, on=["id", "k"]))
 
+    def test_apply_in_pandas_returning_no_column_names(self):
         def merge_pandas(lft, rgt):
-            return lft.size + rgt.size
+            res = pd.merge(lft, rgt, on=["id", "k"])
+            res.columns = range(res.columns.size)
+            return res
 
-        with QuietTest(self.sc):
-            with self.assertRaisesRegex(
-                PythonException,
-                "Return type of the user-defined function should be pandas.DataFrame, "
-                "but is <class 'numpy.int64'>",
-            ):
-                (
-                    left.groupby("id")
-                    .cogroup(right.groupby("id"))
-                    .applyInPandas(merge_pandas, "id long, k int, v int, v2 int")
-                    .collect()
-                )
+        self._test_merge(fn=merge_pandas)
 
-    def test_apply_in_pandas_returning_wrong_number_of_columns(self):
-        left = self.data1
-        right = self.data2
+    def test_apply_in_pandas_returning_column_names_sometimes(self):
+        def merge_pandas(lft, rgt):
+            res = pd.merge(lft, rgt, on=["id", "k"])
+            if 0 in lft["id"] and lft["id"][0] % 2 == 0:
+                return res
+            res.columns = range(res.columns.size)
+            return res
+
+        self._test_merge(fn=merge_pandas)
 
+    def test_apply_in_pandas_returning_wrong_column_names(self):
         def merge_pandas(lft, rgt):
             if 0 in lft["id"] and lft["id"][0] % 2 == 0:
                 lft["add"] = 0
@@ -196,70 +185,77 @@ class CogroupedMapInPandasTests(ReusedSQLTestCase):
                 rgt["more"] = 1
             return pd.merge(lft, rgt, on=["id", "k"])
 
-        with QuietTest(self.sc):
-            with self.assertRaisesRegex(
-                PythonException,
-                "Number of columns of the returned pandas.DataFrame "
-                "doesn't match specified schema. Expected: 4 Actual: 6",
-            ):
-                (
-                    # merge_pandas returns two columns for even keys while we set schema to four
-                    left.groupby("id")
-                    .cogroup(right.groupby("id"))
-                    .applyInPandas(merge_pandas, "id long, k int, v int, v2 int")
-                    .collect()
-                )
-
-    def test_apply_in_pandas_returning_empty_dataframe(self):
-        left = self.data1
-        right = self.data2
+        self._test_merge_error(
+            fn=merge_pandas,
+            error_class=PythonException,
+            error_message_regex="Column names of the returned pandas.DataFrame "
+            "do not match specified schema. Unexpected: add, more.\n",
+        )
 
+    def test_apply_in_pandas_returning_no_column_names_and_wrong_amount(self):
         def merge_pandas(lft, rgt):
             if 0 in lft["id"] and lft["id"][0] % 2 == 0:
-                return pd.DataFrame([])
+                lft[3] = 0
             if 0 in rgt["id"] and rgt["id"][0] % 3 == 0:
-                return pd.DataFrame([])
-            return pd.merge(lft, rgt, on=["id", "k"])
-
-        result = (
-            left.groupby("id")
-            .cogroup(right.groupby("id"))
-            .applyInPandas(merge_pandas, "id long, k int, v int, v2 int")
-            .sort(["id", "k"])
-            .toPandas()
+                rgt[3] = 1
+            res = pd.merge(lft, rgt, on=["id", "k"])
+            res.columns = range(res.columns.size)
+            return res
+
+        self._test_merge_error(
+            fn=merge_pandas,
+            error_class=PythonException,
+            error_message_regex="Number of columns of the returned pandas.DataFrame "
+            "doesn't match specified schema. Expected: 4 Actual: 6\n",
         )
 
-        left = left.toPandas()
-        right = right.toPandas()
-
-        expected = pd.merge(
-            left[left["id"] % 2 != 0], right[right["id"] % 3 != 0], on=["id", "k"]
-        ).sort_values(by=["id", "k"])
-
-        assert_frame_equal(expected, result)
-
-    def test_apply_in_pandas_returning_empty_dataframe_and_wrong_number_of_columns(self):
-        left = self.data1
-        right = self.data2
-
+    def test_apply_in_pandas_returning_empty_dataframe(self):
         def merge_pandas(lft, rgt):
             if 0 in lft["id"] and lft["id"][0] % 2 == 0:
-                return pd.DataFrame([], columns=["id", "k"])
+                return pd.DataFrame()
+            if 0 in rgt["id"] and rgt["id"][0] % 3 == 0:
+                return pd.DataFrame()
             return pd.merge(lft, rgt, on=["id", "k"])
 
-        with QuietTest(self.sc):
-            with self.assertRaisesRegex(
-                PythonException,
-                "Number of columns of the returned pandas.DataFrame doesn't "
-                "match specified schema. Expected: 4 Actual: 2",
-            ):
-                (
-                    # merge_pandas returns two columns for even keys while we set schema to four
-                    left.groupby("id")
-                    .cogroup(right.groupby("id"))
-                    .applyInPandas(merge_pandas, "id long, k int, v int, v2 int")
-                    .collect()
-                )
+        self._test_merge_empty(fn=merge_pandas)
+
+    def test_apply_in_pandas_returning_incompatible_type(self):
+        for safely in [True, False]:
+            with self.subTest(convertToArrowArraySafely=safely), self.sql_conf(
+                {"spark.sql.execution.pandas.convertToArrowArraySafely": safely}
+            ), QuietTest(self.sc):
+                # sometimes we see ValueErrors
+                with self.subTest(convert="string to double"):
+                    expected = (
+                        r"ValueError: Exception thrown when converting pandas.Series \(object\) "
+                        r"with name 'k' to Arrow Array \(double\)."
+                    )
+                    if safely:
+                        expected = expected + (
+                            " It can be caused by overflows or other "
+                            "unsafe conversions warned by Arrow. Arrow safe type check "
+                            "can be disabled by using SQL config "
+                            "`spark.sql.execution.pandas.convertToArrowArraySafely`."
+                        )
+                    self._test_merge_error(
+                        fn=lambda lft, rgt: pd.DataFrame({"id": [1], "k": ["2.0"]}),
+                        output_schema="id long, k double",
+                        error_class=PythonException,
+                        error_message_regex=expected,
+                    )
+
+                # sometimes we see TypeErrors
+                with self.subTest(convert="double to string"):
+                    expected = (
+                        r"TypeError: Exception thrown when converting pandas.Series \(float64\) "
+                        r"with name 'k' to Arrow Array \(string\).\n"
+                    )
+                    self._test_merge_error(
+                        fn=lambda lft, rgt: pd.DataFrame({"id": [1], "k": [2.0]}),
+                        output_schema="id long, k string",
+                        error_class=PythonException,
+                        error_message_regex=expected,
+                    )
 
     def test_mixed_scalar_udfs_followed_by_cogrouby_apply(self):
         df = self.spark.range(0, 10).toDF("v1")
@@ -312,23 +308,20 @@ class CogroupedMapInPandasTests(ReusedSQLTestCase):
 
     def test_wrong_return_type(self):
         # Test that we get a sensible exception invalid values passed to apply
-        left = self.data1
-        right = self.data2
-        with QuietTest(self.sc):
-            with self.assertRaisesRegex(
-                NotImplementedError, "Invalid return type.*ArrayType.*TimestampType"
-            ):
-                left.groupby("id").cogroup(right.groupby("id")).applyInPandas(
-                    lambda l, r: l, "id long, v array<timestamp>"
-                )
+        self._test_merge_error(
+            fn=lambda l, r: l,
+            output_schema="id long, v array<timestamp>",
+            error_class=NotImplementedError,
+            error_message_regex="Invalid return type.*ArrayType.*TimestampType",
+        )
 
     def test_wrong_args(self):
-        left = self.data1
-        right = self.data2
-        with self.assertRaisesRegex(ValueError, "Invalid function"):
-            left.groupby("id").cogroup(right.groupby("id")).applyInPandas(
-                lambda: 1, StructType([StructField("d", DoubleType())])
-            )
+        self.__test_merge_error(
+            fn=lambda: 1,
+            output_schema=StructType([StructField("d", DoubleType())]),
+            error_class=ValueError,
+            error_message_regex="Invalid function",
+        )
 
     def test_case_insensitive_grouping_column(self):
         # SPARK-31915: case-insensitive grouping column should work.
@@ -434,15 +427,51 @@ class CogroupedMapInPandasTests(ReusedSQLTestCase):
 
         assert_frame_equal(expected, result)
 
-    @staticmethod
-    def _test_merge(left, right, output_schema="id long, k int, v int, v2 int"):
-        def merge_pandas(lft, rgt):
-            return pd.merge(lft, rgt, on=["id", "k"])
+    def _test_merge_empty(self, fn):
+        left = self.data1.toPandas()
+        right = self.data2.toPandas()
+
+        expected = pd.merge(
+            left[left["id"] % 2 != 0], right[right["id"] % 3 != 0], on=["id", "k"]
+        ).sort_values(by=["id", "k"])
+
+        self._test_merge(self.data1, self.data2, fn=fn, expected=expected)
+
+    def _test_merge(
+        self,
+        left=None,
+        right=None,
+        by=["id"],
+        fn=lambda lft, rgt: pd.merge(lft, rgt, on=["id", "k"]),
+        output_schema="id long, k int, v int, v2 int",
+        expected=None,
+    ):
+        def fn_with_key(_, lft, rgt):
+            return fn(lft, rgt)
+
+        # Test fn with and without key argument
+        with self.subTest("without key"):
+            self.__test_merge(left, right, by, fn, output_schema, expected)
+        with self.subTest("with key"):
+            self.__test_merge(left, right, by, fn_with_key, output_schema, expected)
+
+    def __test_merge(
+        self,
+        left=None,
+        right=None,
+        by=["id"],
+        fn=lambda lft, rgt: pd.merge(lft, rgt, on=["id", "k"]),
+        output_schema="id long, k int, v int, v2 int",
+        expected=None,
+    ):
+        # Test fn as is, cf. _test_merge
+        left = self.data1 if left is None else left
+        right = self.data2 if right is None else right
 
         result = (
-            left.groupby("id")
-            .cogroup(right.groupby("id"))
-            .applyInPandas(merge_pandas, output_schema)
+            left.groupby(*by)
+            .cogroup(right.groupby(*by))
+            .applyInPandas(fn, output_schema)
             .sort(["id", "k"])
             .toPandas()
         )
@@ -450,10 +479,64 @@ class CogroupedMapInPandasTests(ReusedSQLTestCase):
         left = left.toPandas()
         right = right.toPandas()
 
-        expected = pd.merge(left, right, on=["id", "k"]).sort_values(by=["id", "k"])
+        expected = (
+            pd.merge(left, right, on=["id", "k"]).sort_values(by=["id", "k"])
+            if expected is None
+            else expected
+        )
 
         assert_frame_equal(expected, result)
 
+    def _test_merge_error(
+        self,
+        error_class,
+        error_message_regex,
+        left=None,
+        right=None,
+        by=["id"],
+        fn=lambda lft, rgt: pd.merge(lft, rgt, on=["id", "k"]),
+        output_schema="id long, k int, v int, v2 int",
+    ):
+        def fn_with_key(_, lft, rgt):
+            return fn(lft, rgt)
+
+        # Test fn with and without key argument
+        with self.subTest("without key"):
+            self.__test_merge_error(
+                left=left,
+                right=right,
+                by=by,
+                fn=fn,
+                output_schema=output_schema,
+                error_class=error_class,
+                error_message_regex=error_message_regex,
+            )
+        with self.subTest("with key"):
+            self.__test_merge_error(
+                left=left,
+                right=right,
+                by=by,
+                fn=fn_with_key,
+                output_schema=output_schema,
+                error_class=error_class,
+                error_message_regex=error_message_regex,
+            )
+
+    def __test_merge_error(
+        self,
+        error_class,
+        error_message_regex,
+        left=None,
+        right=None,
+        by=["id"],
+        fn=lambda lft, rgt: pd.merge(lft, rgt, on=["id", "k"]),
+        output_schema="id long, k int, v int, v2 int",
+    ):
+        # Test fn as is, cf. _test_merge_error
+        with QuietTest(self.sc):
+            with self.assertRaisesRegex(error_class, error_message_regex):
+                self.__test_merge(left, right, by, fn, output_schema)
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.pandas.test_pandas_cogrouped_map import *  # noqa: F401
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
index 5f103c97926..88e68b04303 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
@@ -73,7 +73,7 @@ if have_pyarrow:
     not have_pandas or not have_pyarrow,
     cast(str, pandas_requirement_message or pyarrow_requirement_message),
 )
-class GroupedMapInPandasTests(ReusedSQLTestCase):
+class GroupedApplyInPandasTests(ReusedSQLTestCase):
     @property
     def data(self):
         return (
@@ -270,79 +270,101 @@ class GroupedMapInPandasTests(ReusedSQLTestCase):
         assert_frame_equal(expected, result)
 
     def test_apply_in_pandas_not_returning_pandas_dataframe(self):
-        df = self.data
-
-        def stats(key, _):
-            return key
-
         with QuietTest(self.sc):
             with self.assertRaisesRegex(
                 PythonException,
                 "Return type of the user-defined function should be pandas.DataFrame, "
                 "but is <class 'tuple'>",
             ):
-                df.groupby("id").applyInPandas(stats, schema="id integer, m double").collect()
+                self._test_apply_in_pandas(lambda key, pdf: key)
 
-    def test_apply_in_pandas_returning_wrong_number_of_columns(self):
-        df = self.data
+    @staticmethod
+    def stats_with_column_names(key, pdf):
+        # order of column can be different to applyInPandas schema when column names are given
+        return pd.DataFrame([(pdf.v.mean(),) + key], columns=["mean", "id"])
 
-        def stats(key, pdf):
-            v = pdf.v
-            # returning three columns
-            res = pd.DataFrame([key + (v.mean(), v.std())])
-            return res
+    @staticmethod
+    def stats_with_no_column_names(key, pdf):
+        # columns must be in order of applyInPandas schema when no columns given
+        return pd.DataFrame([key + (pdf.v.mean(),)])
 
-        with QuietTest(self.sc):
-            with self.assertRaisesRegex(
-                PythonException,
-                "Number of columns of the returned pandas.DataFrame doesn't match "
-                "specified schema. Expected: 2 Actual: 3",
-            ):
-                # stats returns three columns while here we set schema with two columns
-                df.groupby("id").applyInPandas(stats, schema="id integer, m double").collect()
+    def test_apply_in_pandas_returning_column_names(self):
+        self._test_apply_in_pandas(GroupedApplyInPandasTests.stats_with_column_names)
 
-    def test_apply_in_pandas_returning_empty_dataframe(self):
-        df = self.data
+    def test_apply_in_pandas_returning_no_column_names(self):
+        self._test_apply_in_pandas(GroupedApplyInPandasTests.stats_with_no_column_names)
 
-        def odd_means(key, pdf):
-            if key[0] % 2 == 0:
-                return pd.DataFrame([])
+    def test_apply_in_pandas_returning_column_names_sometimes(self):
+        def stats(key, pdf):
+            if key[0] % 2:
+                return GroupedApplyInPandasTests.stats_with_column_names(key, pdf)
             else:
-                return pd.DataFrame([key + (pdf.v.mean(),)])
+                return GroupedApplyInPandasTests.stats_with_no_column_names(key, pdf)
 
-        expected_ids = {row[0] for row in self.data.collect() if row[0] % 2 != 0}
+        self._test_apply_in_pandas(stats)
 
-        result = (
-            df.groupby("id")
-            .applyInPandas(odd_means, schema="id integer, m double")
-            .sort("id", "m")
-            .collect()
-        )
-
-        actual_ids = {row[0] for row in result}
-        self.assertSetEqual(expected_ids, actual_ids)
-
-        self.assertEqual(len(expected_ids), len(result))
-        for row in result:
-            self.assertEqual(24.5, row[1])
-
-    def test_apply_in_pandas_returning_empty_dataframe_and_wrong_number_of_columns(self):
-        df = self.data
-
-        def odd_means(key, pdf):
-            if key[0] % 2 == 0:
-                return pd.DataFrame([], columns=["id"])
-            else:
-                return pd.DataFrame([key + (pdf.v.mean(),)])
+    def test_apply_in_pandas_returning_wrong_column_names(self):
+        with QuietTest(self.sc):
+            with self.assertRaisesRegex(
+                PythonException,
+                "Column names of the returned pandas.DataFrame do not match specified schema. "
+                "Missing: mean. Unexpected: median, std.\n",
+            ):
+                self._test_apply_in_pandas(
+                    lambda key, pdf: pd.DataFrame(
+                        [key + (pdf.v.median(), pdf.v.std())], columns=["id", "median", "std"]
+                    )
+                )
 
+    def test_apply_in_pandas_returning_no_column_names_and_wrong_amount(self):
         with QuietTest(self.sc):
             with self.assertRaisesRegex(
                 PythonException,
                 "Number of columns of the returned pandas.DataFrame doesn't match "
-                "specified schema. Expected: 2 Actual: 1",
+                "specified schema. Expected: 2 Actual: 3\n",
             ):
-                # stats returns one column for even keys while here we set schema with two columns
-                df.groupby("id").applyInPandas(odd_means, schema="id integer, m double").collect()
+                self._test_apply_in_pandas(
+                    lambda key, pdf: pd.DataFrame([key + (pdf.v.mean(), pdf.v.std())])
+                )
+
+    def test_apply_in_pandas_returning_empty_dataframe(self):
+        self._test_apply_in_pandas_returning_empty_dataframe(pd.DataFrame())
+
+    def test_apply_in_pandas_returning_incompatible_type(self):
+        for safely in [True, False]:
+            with self.subTest(convertToArrowArraySafely=safely), self.sql_conf(
+                {"spark.sql.execution.pandas.convertToArrowArraySafely": safely}
+            ), QuietTest(self.sc):
+                # sometimes we see ValueErrors
+                with self.subTest(convert="string to double"):
+                    expected = (
+                        r"ValueError: Exception thrown when converting pandas.Series \(object\) "
+                        r"with name 'mean' to Arrow Array \(double\)."
+                    )
+                    if safely:
+                        expected = expected + (
+                            " It can be caused by overflows or other "
+                            "unsafe conversions warned by Arrow. Arrow safe type check "
+                            "can be disabled by using SQL config "
+                            "`spark.sql.execution.pandas.convertToArrowArraySafely`."
+                        )
+                    with self.assertRaisesRegex(PythonException, expected + "\n"):
+                        self._test_apply_in_pandas(
+                            lambda key, pdf: pd.DataFrame([key + (str(pdf.v.mean()),)]),
+                            output_schema="id long, mean double",
+                        )
+
+                # sometimes we see TypeErrors
+                with self.subTest(convert="double to string"):
+                    with self.assertRaisesRegex(
+                        PythonException,
+                        r"TypeError: Exception thrown when converting pandas.Series \(float64\) "
+                        r"with name 'mean' to Arrow Array \(string\).\n",
+                    ):
+                        self._test_apply_in_pandas(
+                            lambda key, pdf: pd.DataFrame([key + (pdf.v.mean(),)]),
+                            output_schema="id long, mean string",
+                        )
 
     def test_datatype_string(self):
         df = self.data
@@ -566,7 +588,11 @@ class GroupedMapInPandasTests(ReusedSQLTestCase):
 
         with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": False}):
             with QuietTest(self.sc):
-                with self.assertRaisesRegex(Exception, "KeyError: 'id'"):
+                with self.assertRaisesRegex(
+                    PythonException,
+                    "RuntimeError: Column names of the returned pandas.DataFrame do not match "
+                    "specified schema. Missing: id. Unexpected: iid.\n",
+                ):
                     grouped_df.apply(column_name_typo).collect()
                 with self.assertRaisesRegex(Exception, "[D|d]ecimal.*got.*date"):
                     grouped_df.apply(invalid_positional_types).collect()
@@ -655,10 +681,11 @@ class GroupedMapInPandasTests(ReusedSQLTestCase):
             df.groupby("group", window("ts", "5 days"))
             .applyInPandas(f, df.schema)
             .select("id", "result")
+            .orderBy("id")
             .collect()
         )
-        for r in result:
-            self.assertListEqual(expected[r[0]], r[1])
+
+        self.assertListEqual([Row(id=key, result=val) for key, val in expected.items()], result)
 
     def test_grouped_over_window_with_key(self):
 
@@ -720,11 +747,11 @@ class GroupedMapInPandasTests(ReusedSQLTestCase):
             df.groupby("group", window("ts", "5 days"))
             .applyInPandas(f, df.schema)
             .select("id", "result")
+            .orderBy("id")
             .collect()
         )
 
-        for r in result:
-            self.assertListEqual(expected[r[0]], r[1])
+        self.assertListEqual([Row(id=key, result=val) for key, val in expected.items()], result)
 
     def test_case_insensitive_grouping_column(self):
         # SPARK-31915: case-insensitive grouping column should work.
@@ -739,6 +766,44 @@ class GroupedMapInPandasTests(ReusedSQLTestCase):
         )
         self.assertEqual(row.asDict(), Row(column=1, score=0.5).asDict())
 
+    def _test_apply_in_pandas(self, f, output_schema="id long, mean double"):
+        df = self.data
+
+        result = (
+            df.groupby("id").applyInPandas(f, schema=output_schema).sort("id", "mean").toPandas()
+        )
+        expected = df.select("id").distinct().withColumn("mean", lit(24.5)).toPandas()
+
+        assert_frame_equal(expected, result)
+
+    def _test_apply_in_pandas_returning_empty_dataframe(self, empty_df):
+        """Tests some returned DataFrames are empty."""
+        df = self.data
+
+        def stats(key, pdf):
+            if key[0] % 2 == 0:
+                return GroupedApplyInPandasTests.stats_with_no_column_names(key, pdf)
+            return empty_df
+
+        result = (
+            df.groupby("id")
+            .applyInPandas(stats, schema="id long, mean double")
+            .sort("id", "mean")
+            .collect()
+        )
+
+        actual_ids = {row[0] for row in result}
+        expected_ids = {row[0] for row in self.data.collect() if row[0] % 2 == 0}
+        self.assertSetEqual(expected_ids, actual_ids)
+        self.assertEqual(len(expected_ids), len(result))
+        for row in result:
+            self.assertEqual(24.5, row[1])
+
+    def _test_apply_in_pandas_returning_empty_dataframe_error(self, empty_df, error):
+        with QuietTest(self.sc):
+            with self.assertRaisesRegex(PythonException, error):
+                self._test_apply_in_pandas_returning_empty_dataframe(empty_df)
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.pandas.test_pandas_grouped_map import *  # noqa: F401
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py
index 655f0bf151d..9600d1e3445 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py
@@ -53,7 +53,7 @@ if have_pyarrow:
     not have_pandas or not have_pyarrow,
     cast(str, pandas_requirement_message or pyarrow_requirement_message),
 )
-class GroupedMapInPandasWithStateTests(ReusedSQLTestCase):
+class GroupedApplyInPandasWithStateTests(ReusedSQLTestCase):
     @classmethod
     def conf(cls):
         cfg = SparkConf()
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index 6083f31ac81..c61994380e6 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -465,9 +465,24 @@ class ArrowTests(ReusedSQLTestCase):
         wrong_schema = StructType(fields)
         with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": False}):
             with QuietTest(self.sc):
-                with self.assertRaisesRegex(Exception, "[D|d]ecimal.*got.*date"):
+                with self.assertRaises(Exception) as context:
                     self.spark.createDataFrame(pdf, schema=wrong_schema)
 
+                # the exception provides us with the column that is incorrect
+                exception = context.exception
+                self.assertTrue(hasattr(exception, "args"))
+                self.assertEqual(len(exception.args), 1)
+                self.assertRegex(
+                    exception.args[0],
+                    "with name '7_date_t' " "to Arrow Array \\(decimal128\\(38, 18\\)\\)",
+                )
+
+                # the inner exception provides us with the incorrect types
+                exception = exception.__context__
+                self.assertTrue(hasattr(exception, "args"))
+                self.assertEqual(len(exception.args), 1)
+                self.assertRegex(exception.args[0], "[D|d]ecimal.*got.*date")
+
     def test_createDataFrame_with_names(self):
         pdf = self.create_pandas_data_frame()
         new_names = list(map(str, range(len(self.schema.fieldNames()))))
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index c1c3669701f..f7d98a9a18c 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -146,7 +146,49 @@ def wrap_batch_iter_udf(f, return_type):
     )
 
 
-def wrap_cogrouped_map_pandas_udf(f, return_type, argspec):
+def verify_pandas_result(result, return_type, assign_cols_by_name):
+    import pandas as pd
+
+    if not isinstance(result, pd.DataFrame):
+        raise TypeError(
+            "Return type of the user-defined function should be "
+            "pandas.DataFrame, but is {}".format(type(result))
+        )
+
+    # check the schema of the result only if it is not empty or has columns
+    if not result.empty or len(result.columns) != 0:
+        # if any column name of the result is a string
+        # the column names of the result have to match the return type
+        #   see create_array in pyspark.sql.pandas.serializers.ArrowStreamPandasSerializer
+        field_names = set([field.name for field in return_type.fields])
+        column_names = set(result.columns)
+        if (
+            assign_cols_by_name
+            and any(isinstance(name, str) for name in result.columns)
+            and column_names != field_names
+        ):
+            missing = sorted(list(field_names.difference(column_names)))
+            missing = f" Missing: {', '.join(missing)}." if missing else ""
+
+            extra = sorted(list(column_names.difference(field_names)))
+            extra = f" Unexpected: {', '.join(extra)}." if extra else ""
+
+            raise RuntimeError(
+                "Column names of the returned pandas.DataFrame do not match specified schema."
+                "{}{}".format(missing, extra)
+            )
+        # otherwise the number of columns of result have to match the return type
+        elif len(result.columns) != len(return_type):
+            raise RuntimeError(
+                "Number of columns of the returned pandas.DataFrame "
+                "doesn't match specified schema. "
+                "Expected: {} Actual: {}".format(len(return_type), len(result.columns))
+            )
+
+
+def wrap_cogrouped_map_pandas_udf(f, return_type, argspec, runner_conf):
+    _assign_cols_by_name = assign_cols_by_name(runner_conf)
+
     def wrapped(left_key_series, left_value_series, right_key_series, right_value_series):
         import pandas as pd
 
@@ -159,27 +201,16 @@ def wrap_cogrouped_map_pandas_udf(f, return_type, argspec):
             key_series = left_key_series if not left_df.empty else right_key_series
             key = tuple(s[0] for s in key_series)
             result = f(key, left_df, right_df)
-        if not isinstance(result, pd.DataFrame):
-            raise TypeError(
-                "Return type of the user-defined function should be "
-                "pandas.DataFrame, but is {}".format(type(result))
-            )
-        # the number of columns of result have to match the return type
-        # but it is fine for result to have no columns at all if it is empty
-        if not (
-            len(result.columns) == len(return_type) or len(result.columns) == 0 and result.empty
-        ):
-            raise RuntimeError(
-                "Number of columns of the returned pandas.DataFrame "
-                "doesn't match specified schema. "
-                "Expected: {} Actual: {}".format(len(return_type), len(result.columns))
-            )
+        verify_pandas_result(result, return_type, _assign_cols_by_name)
+
         return result
 
     return lambda kl, vl, kr, vr: [(wrapped(kl, vl, kr, vr), to_arrow_type(return_type))]
 
 
-def wrap_grouped_map_pandas_udf(f, return_type, argspec):
+def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf):
+    _assign_cols_by_name = assign_cols_by_name(runner_conf)
+
     def wrapped(key_series, value_series):
         import pandas as pd
 
@@ -188,22 +219,8 @@ def wrap_grouped_map_pandas_udf(f, return_type, argspec):
         elif len(argspec.args) == 2:
             key = tuple(s[0] for s in key_series)
             result = f(key, pd.concat(value_series, axis=1))
+        verify_pandas_result(result, return_type, _assign_cols_by_name)
 
-        if not isinstance(result, pd.DataFrame):
-            raise TypeError(
-                "Return type of the user-defined function should be "
-                "pandas.DataFrame, but is {}".format(type(result))
-            )
-        # the number of columns of result have to match the return type
-        # but it is fine for result to have no columns at all if it is empty
-        if not (
-            len(result.columns) == len(return_type) or len(result.columns) == 0 and result.empty
-        ):
-            raise RuntimeError(
-                "Number of columns of the returned pandas.DataFrame "
-                "doesn't match specified schema. "
-                "Expected: {} Actual: {}".format(len(return_type), len(result.columns))
-            )
         return result
 
     return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))]
@@ -396,12 +413,12 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
         return arg_offsets, wrap_batch_iter_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
         argspec = getfullargspec(chained_func)  # signature was lost when wrapping it
-        return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec)
+        return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf)
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
         return arg_offsets, wrap_grouped_map_pandas_udf_with_state(func, return_type)
     elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
         argspec = getfullargspec(chained_func)  # signature was lost when wrapping it
-        return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec)
+        return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec, runner_conf)
     elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
         return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
@@ -412,6 +429,16 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
         raise ValueError("Unknown eval type: {}".format(eval_type))
 
 
+# Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF when returning StructType
+def assign_cols_by_name(runner_conf):
+    return (
+        runner_conf.get(
+            "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true"
+        ).lower()
+        == "true"
+    )
+
+
 def read_udfs(pickleSer, infile, eval_type):
     runner_conf = {}
 
@@ -444,16 +471,9 @@ def read_udfs(pickleSer, infile, eval_type):
             runner_conf.get("spark.sql.execution.pandas.convertToArrowArraySafely", "false").lower()
             == "true"
         )
-        # Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF when returning StructType
-        assign_cols_by_name = (
-            runner_conf.get(
-                "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true"
-            ).lower()
-            == "true"
-        )
 
         if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
-            ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name)
+            ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name(runner_conf))
         elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
             arrow_max_records_per_batch = runner_conf.get(
                 "spark.sql.execution.arrow.maxRecordsPerBatch", 10000
@@ -463,7 +483,7 @@ def read_udfs(pickleSer, infile, eval_type):
             ser = ApplyInPandasWithStateSerializer(
                 timezone,
                 safecheck,
-                assign_cols_by_name,
+                assign_cols_by_name(runner_conf),
                 state_object_schema,
                 arrow_max_records_per_batch,
             )
@@ -478,7 +498,7 @@ def read_udfs(pickleSer, infile, eval_type):
                 or eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
             )
             ser = ArrowStreamPandasUDFSerializer(
-                timezone, safecheck, assign_cols_by_name, df_for_struct
+                timezone, safecheck, assign_cols_by_name(runner_conf), df_for_struct
             )
     else:
         ser = BatchedSerializer(CPickleSerializer(), 100)
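
For readers who want to try the new check outside of Spark, the decision logic of `verify_pandas_result` above can be approximated on plain lists of names. This is only a sketch under stated assumptions: `schema_mismatch_message` and its `has_rows` parameter are hypothetical names that are not part of this commit, the function returns the message instead of raising `RuntimeError`, and it takes column and field names directly rather than a `pandas.DataFrame` and a `StructType`.

```python
def schema_mismatch_message(column_names, field_names, assign_cols_by_name=True, has_rows=True):
    """Return the mismatch message, or None if the returned columns are acceptable.

    Hypothetical helper (not in the commit): mirrors the checks of
    verify_pandas_result on plain lists of names.
    """
    # a result with neither rows nor columns is accepted without further checks
    if not has_rows and len(column_names) == 0:
        return None

    columns, fields = set(column_names), set(field_names)
    # columns are matched by name only if that is enabled
    # and at least one returned column name is a string
    if (
        assign_cols_by_name
        and any(isinstance(name, str) for name in column_names)
        and columns != fields
    ):
        message = "Column names of the returned pandas.DataFrame do not match specified schema."
        missing = sorted(fields - columns)
        extra = sorted(columns - fields)
        if missing:
            message += " Missing: {}.".format(", ".join(missing))
        if extra:
            message += " Unexpected: {}.".format(", ".join(extra))
        return message

    # otherwise only the number of columns has to match
    if len(column_names) != len(field_names):
        return (
            "Number of columns of the returned pandas.DataFrame "
            "doesn't match specified schema. "
            "Expected: {} Actual: {}".format(len(field_names), len(column_names))
        )
    return None
```

With positional (non-string) column names, such as the default integer labels of `pd.DataFrame([...])`, only the column count is compared. This is why the test above that returns three unnamed columns for a two-field schema still expects the `Expected: 2 Actual: 3` message.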


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

