This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new d85f3bc5a54 [SPARK-40770][PYTHON][FOLLOW-UP][3.5] Improved error messages for mapInPandas for schema mismatch
d85f3bc5a54 is described below

commit d85f3bc5a54302469824b2d8b5e71ebbcfc7c4c4
Author: Enrico Minack <git...@enrico.minack.dev>
AuthorDate: Fri Aug 4 10:38:35 2023 +0900

    [SPARK-40770][PYTHON][FOLLOW-UP][3.5] Improved error messages for mapInPandas for schema mismatch
    
    ### What changes were proposed in this pull request?
    This merges #39952 into the 3.5 branch.
    
    Similar to #38223, improve the error messages when a Python function provided to `DataFrame.mapInPandas` returns a pandas DataFrame that does not match the expected schema.
    
    With
    ```Python
    df = spark.range(2).withColumn("v", col("id"))
    ```
    
    **Mismatching column names:**
    ```Python
    df.mapInPandas(lambda it: it, "id long, val long").show()
    # was: KeyError: 'val'
    # now: RuntimeError: Column names of the returned pandas.DataFrame do not match specified schema.
    #      Missing: val  Unexpected: v
    ```
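
The Missing/Unexpected part of the new message is derived from two sets of names. The sketch below is minimal and assumes plain lists stand in for the schema's field names and for `pandas.DataFrame.columns`. `describe_column_mismatch` is a hypothetical helper, not a PySpark API. It loosely mirrors the name check in `verify_pandas_result` in the `worker.py` hunk, with `assign_cols_by_name` fixed to true. It models the `truncate_return_schema` slicing but not the separate column-count check. The message text follows the format asserted by the tests in this diff.

```python
from typing import Optional, Sequence


def describe_column_mismatch(
    field_names: Sequence[str],
    column_names: Sequence[object],
    truncate_return_schema: bool = True,
) -> Optional[str]:
    """Hypothetical helper: return the mismatch message, or None if names are acceptable."""
    expected = set(field_names)
    # When truncating, only the first len(field_names) returned columns are compared,
    # so extra trailing columns are tolerated.
    compared = list(column_names)
    if truncate_return_schema:
        compared = compared[: len(expected)]
    actual = set(compared)
    # Names are matched only if at least one returned column name is a string;
    # otherwise (e.g. 0, 1, 2) columns are assigned by position.
    if not any(isinstance(name, str) for name in column_names) or actual == expected:
        return None
    missing = sorted(expected - actual)
    extra = sorted(actual - expected, key=str)  # key=str tolerates mixed name types
    missing_part = f" Missing: {', '.join(missing)}." if missing else ""
    extra_part = f" Unexpected: {', '.join(map(str, extra))}." if extra else ""
    return (
        "Column names of the returned pandas.DataFrame do not match "
        "specified schema." + missing_part + extra_part
    )
```

For the example above, `describe_column_mismatch(["id", "val"], ["id", "v"])` yields a message that ends in `Missing: val. Unexpected: v.`. A returned frame with columns `0, 1` skips the name check entirely, because its columns are assigned by position.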
    
    **Python function not returning an iterator:**
    ```Python
    df.mapInPandas(lambda it: 1, "id long").show()
    # was: TypeError: 'int' object is not iterable
    # now: TypeError: Return type of the user-defined function should be iterator of pandas.DataFrame, but is <class 'int'>
    ```
    
    **Python function not returning an iterator of pandas.DataFrame:**
    ```Python
    df.mapInPandas(lambda it: [1], "id long").show()
    # was: TypeError: Return type of the user-defined function should be Pandas.DataFrame, but is <class 'int'>
    # now: TypeError: Return type of the user-defined function should be iterator of pandas.DataFrame, but is iterator of <class 'int'>
    # sometimes: ValueError: A field of type StructType expects a pandas.DataFrame, but got: <class 'list'>
    # now: TypeError: Return type of the user-defined function should be iterator of pandas.DataFrame, but is iterator of <class 'list'>
    ```
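
Both type errors come from checks wrapped around the UDF's return value: one on the returned object itself and one on each element. The sketch below is minimal and assumes a stand-in class replaces `pandas.DataFrame`, so no pandas is needed. `verify_batch_iter` and `FakeFrame` are hypothetical names loosely modelled on `verify_result` and `verify_element` in `wrap_pandas_batch_iter_udf`. The message wording follows the tests in this diff (`but is int.`) rather than the `<class 'int'>` form shown above.

```python
from collections.abc import Iterator
from typing import Any


def verify_batch_iter(result: Any, elem_type: type, label: str) -> Iterator:
    """Hypothetical sketch of the checks wrapped around a map UDF's result."""
    template = "Return type of the user-defined function should be {expected}, but is {actual}."
    expected = f"iterator of {label}"
    # Eager check: the UDF must return something iterable; a list is accepted too.
    if not isinstance(result, Iterator) and not hasattr(result, "__iter__"):
        raise TypeError(template.format(expected=expected, actual=type(result).__name__))

    def verify_element(elem: Any) -> Any:
        # Lazy check: each element is validated only when it is consumed.
        if not isinstance(elem, elem_type):
            raise TypeError(
                template.format(expected=expected, actual=f"iterator of {type(elem).__name__}")
            )
        return elem

    return map(verify_element, result)


class FakeFrame:
    """Hypothetical stand-in for pandas.DataFrame so the sketch needs no pandas."""
```

`verify_batch_iter([1], FakeFrame, "pandas.DataFrame")` returns without raising. The `TypeError` appears only once the returned `map` is consumed, which is also why such errors surface only while the data is being processed.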
    
    **Mismatching types (ValueError and TypeError):**
    ```Python
    df.mapInPandas(lambda it: it, "id int, v string").show()
    # was: pyarrow.lib.ArrowTypeError: Expected a string or bytes dtype, got int64
    # now: pyarrow.lib.ArrowTypeError: Expected a string or bytes dtype, got int64
    #      The above exception was the direct cause of the following exception:
    #      TypeError: Exception thrown when converting pandas.Series (int64) with name 'v' to Arrow Array (string).
    
    df.mapInPandas(lambda it: [pdf.assign(v=pdf["v"].apply(str)) for pdf in it], "id int, v double").show()
    # was: pyarrow.lib.ArrowInvalid: Could not convert '0' with type str: tried to convert to double
    # now: pyarrow.lib.ArrowInvalid: Could not convert '0' with type str: tried to convert to double
    #      The above exception was the direct cause of the following exception:
    #      ValueError: Exception thrown when converting pandas.Series (object) with name 'v' to Arrow Array (double).
    
    with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": True}):
      df.mapInPandas(lambda it: [pdf.assign(v=pdf["v"].apply(str)) for pdf in it], "id int, v double").show()
    # was: ValueError: Exception thrown when converting pandas.Series (object) to Arrow Array (double).
    #      It can be caused by overflows or other unsafe conversions warned by Arrow. Arrow safe type check can be disabled
    #      by using SQL config `spark.sql.execution.pandas.convertToArrowArraySafely`.
    # now: ValueError: Exception thrown when converting pandas.Series (object) with name 'v' to Arrow Array (double).
    #      It can be caused by overflows or other unsafe conversions warned by Arrow. Arrow safe type check can be disabled
    #      by using SQL config `spark.sql.execution.pandas.convertToArrowArraySafely`.
    ```
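
The "direct cause" lines in the new messages come from Python exception chaining (`raise ... from ...`), which keeps the original Arrow error attached as `__cause__`. The sketch below is minimal and assumes `float()` stands in for the pandas-to-Arrow conversion, with only a `double` target modelled. `convert_column` is a hypothetical helper, not the code in `pyspark.sql.pandas.serializers`. It keeps the error category, mirroring how `ArrowTypeError` surfaces as `TypeError` and `ArrowInvalid` as `ValueError` in the examples above.

```python
from typing import List, Sequence


def convert_column(values: Sequence[object], name: str, dtype: str, arrow_type: str) -> List[float]:
    """Hypothetical stand-in for converting one pandas.Series to an Arrow array."""
    try:
        # float() plays the role of the Arrow conversion (only a double target is modelled).
        return [float(v) for v in values]
    except (TypeError, ValueError) as e:
        # Keep the category of the original error: TypeError stays TypeError,
        # anything else becomes ValueError.
        error_cls = TypeError if isinstance(e, TypeError) else ValueError
        # "raise ... from e" sets __cause__, so the traceback prints the original error,
        # then "The above exception was the direct cause of the following exception:".
        raise error_cls(
            f"Exception thrown when converting pandas.Series ({dtype}) "
            f"with name '{name}' to Arrow Array ({arrow_type})."
        ) from e
```

Here `convert_column(["1.5", "x"], "v", "object", "double")` raises a `ValueError` whose `__cause__` is the original `float()` error, and the traceback prints both.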
    
    ### Why are the changes needed?
    Existing errors are generic (`KeyError`) or meaningless (`'int' object is not iterable`). The errors should help users spot the mismatching columns by naming them.
    
    The schema of the returned pandas DataFrames can only be checked while the DataFrame is being processed, so hitting such an error is expensive. Therefore, the errors should be expressive.
    
    ### Does this PR introduce _any_ user-facing change?
    This only changes error messages, not behaviour.
    
    ### How was this patch tested?
    Tests all cases of schema mismatch for `DataFrame.mapInPandas`.
    
    Closes #42316 from EnricoMi/branch-pyspark-map-in-pandas-schema-mismatch-3.5.
    
    Authored-by: Enrico Minack <git...@enrico.minack.dev>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/errors/error_classes.py             |   5 +
 python/pyspark/pandas/frame.py                     |   2 +-
 python/pyspark/sql/pandas/serializers.py           |   2 +-
 .../sql/tests/connect/test_parity_arrow_map.py     |   3 +-
 .../sql/tests/connect/test_parity_pandas_map.py    |  23 +-
 .../sql/tests/pandas/test_pandas_cogrouped_map.py  |   4 +-
 .../sql/tests/pandas/test_pandas_grouped_map.py    |   4 +-
 python/pyspark/sql/tests/pandas/test_pandas_map.py | 265 ++++++++++++++++++---
 python/pyspark/sql/tests/test_arrow_map.py         |  27 +++
 python/pyspark/worker.py                           | 205 +++++++++++-----
 10 files changed, 440 insertions(+), 100 deletions(-)

diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py
index db80705e7d2..2a3f454452e 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -713,6 +713,11 @@ ERROR_CLASSES_JSON = """
       "Expected <expected> values for `<item>`, got <actual>."
     ]
   },
+  "UDF_RETURN_TYPE" : {
+    "message" : [
+      "Return type of the user-defined function should be <expected>, but is <actual>."
+    ]
+  },
   "UDTF_ARROW_TYPE_CAST_ERROR" : {
     "message" : [
       "Cannot convert the output value of the column '<col_name>' with type '<col_type>' to the specified return type of the column: '<arrow_type>'. Please check if the data types match and try again."
diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index 6f2c8389a4c..d8a3f812c33 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -399,7 +399,7 @@ class DataFrame(Frame, Generic[T]):
         `compute.ops_on_diff_frames` should be turned on;
         2, when `data` is a local dataset (Pandas DataFrame/numpy ndarray/list/etc),
         it will first collect the `index` to driver if necessary, and then apply
-        the `Pandas.DataFrame(...)` creation internally;
+        the `pandas.DataFrame(...)` creation internally;
 
     Examples
     --------
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 993bacbed67..f3037c8b39c 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -163,7 +163,7 @@ class ArrowStreamUDFSerializer(ArrowStreamSerializer):
 
 class ArrowStreamPandasSerializer(ArrowStreamSerializer):
     """
-    Serializes Pandas.Series as Arrow data with Arrow streaming format.
+    Serializes pandas.Series as Arrow data with Arrow streaming format.
 
     Parameters
     ----------
diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow_map.py b/python/pyspark/sql/tests/connect/test_parity_arrow_map.py
index ed51d0d3d19..868aeaeff7f 100644
--- a/python/pyspark/sql/tests/connect/test_parity_arrow_map.py
+++ b/python/pyspark/sql/tests/connect/test_parity_arrow_map.py
@@ -22,7 +22,8 @@ from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
 class ArrowMapParityTests(MapInArrowTestsMixin, ReusedConnectTestCase):
-    pass
+    def test_other_than_recordbatch_iter(self):
+        self.check_other_than_recordbatch_iter()
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/connect/test_parity_pandas_map.py b/python/pyspark/sql/tests/connect/test_parity_pandas_map.py
index 539fd98266b..6ff9b0cb33b 100644
--- a/python/pyspark/sql/tests/connect/test_parity_pandas_map.py
+++ b/python/pyspark/sql/tests/connect/test_parity_pandas_map.py
@@ -14,16 +14,35 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import unittest
+
+
 from pyspark.sql.tests.pandas.test_pandas_map import MapInPandasTestsMixin
 from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
 class MapInPandasParityTests(MapInPandasTestsMixin, ReusedConnectTestCase):
+    def test_other_than_dataframe_iter(self):
+        self.check_other_than_dataframe_iter()
+
+    def test_dataframes_with_other_column_names(self):
+        self.check_dataframes_with_other_column_names()
+
+    def test_dataframes_with_duplicate_column_names(self):
+        self.check_dataframes_with_duplicate_column_names()
+
+    def test_dataframes_with_less_columns(self):
+        self.check_dataframes_with_less_columns()
+
+    @unittest.skip("Fails in Spark Connect, should enable.")
+    def test_dataframes_with_incompatible_types(self):
+        self.check_dataframes_with_incompatible_types()
+
     def test_empty_dataframes_with_less_columns(self):
         self.check_empty_dataframes_with_less_columns()
 
-    def test_other_than_dataframe(self):
-        self.check_other_than_dataframe()
+    def test_empty_dataframes_with_other_columns(self):
+        self.check_empty_dataframes_with_other_columns()
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
index 8def08323be..b867156e71a 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
@@ -56,7 +56,6 @@ class CogroupedApplyInPandasTestsMixin:
     def data1(self):
         return (
             self.spark.range(10)
-            .toDF("id")
             .withColumn("ks", array([lit(i) for i in range(20, 30)]))
             .withColumn("k", explode(col("ks")))
             .withColumn("v", col("k") * 10)
@@ -67,7 +66,6 @@ class CogroupedApplyInPandasTestsMixin:
     def data2(self):
         return (
             self.spark.range(10)
-            .toDF("id")
             .withColumn("ks", array([lit(i) for i in range(20, 30)]))
             .withColumn("k", explode(col("ks")))
             .withColumn("v2", col("k") * 100)
@@ -168,7 +166,7 @@ class CogroupedApplyInPandasTestsMixin:
             fn=lambda lft, rgt: lft.size + rgt.size,
             error_class=PythonException,
             error_message_regex="Return type of the user-defined function "
-            "should be pandas.DataFrame, but is <class 'numpy.int64'>",
+            "should be pandas.DataFrame, but is int64.",
         )
 
     def test_apply_in_pandas_returning_column_names(self):
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
index 84e61d42843..742b3657f6e 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
@@ -79,7 +79,6 @@ class GroupedApplyInPandasTestsMixin:
     def data(self):
         return (
             self.spark.range(10)
-            .toDF("id")
             .withColumn("vs", array([lit(i) for i in range(20, 30)]))
             .withColumn("v", explode(col("vs")))
             .drop("vs")
@@ -287,8 +286,7 @@ class GroupedApplyInPandasTestsMixin:
     def check_apply_in_pandas_not_returning_pandas_dataframe(self):
         with self.assertRaisesRegex(
             PythonException,
-            "Return type of the user-defined function should be pandas.DataFrame, "
-            "but is <class 'tuple'>",
+            "Return type of the user-defined function should be pandas.DataFrame, but is tuple.",
         ):
             self._test_apply_in_pandas(lambda key, pdf: key)
 
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_map.py b/python/pyspark/sql/tests/pandas/test_pandas_map.py
index 3d9a90bc81c..fb2f9214c5d 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_map.py
@@ -42,15 +42,46 @@ if have_pandas:
     cast(str, pandas_requirement_message or pyarrow_requirement_message),
 )
 class MapInPandasTestsMixin:
-    def test_map_in_pandas(self):
+    @staticmethod
+    def identity_dataframes_iter(*columns: str):
         def func(iterator):
             for pdf in iterator:
                 assert isinstance(pdf, pd.DataFrame)
-                assert pdf.columns == ["id"]
+                assert pdf.columns.tolist() == list(columns)
                 yield pdf
 
+        return func
+
+    @staticmethod
+    def identity_dataframes_wo_column_names_iter(*columns: str):
+        def func(iterator):
+            for pdf in iterator:
+                assert isinstance(pdf, pd.DataFrame)
+                assert pdf.columns.tolist() == list(columns)
+                yield pdf.rename(columns=list(pdf.columns).index)
+
+        return func
+
+    @staticmethod
+    def dataframes_and_empty_dataframe_iter(*columns: str):
+        def func(iterator):
+            for pdf in iterator:
+                yield pdf
+            # after yielding all elements, also yield an empty dataframe with given columns
+            yield pd.DataFrame([], columns=list(columns))
+
+        return func
+
+    def test_map_in_pandas(self):
+        # test returning iterator of DataFrames
+        df = self.spark.range(10, numPartitions=3)
+        actual = df.mapInPandas(self.identity_dataframes_iter("id"), "id long").collect()
+        expected = df.collect()
+        self.assertEqual(actual, expected)
+
+        # test returning list of DataFrames
         df = self.spark.range(10, numPartitions=3)
-        actual = df.mapInPandas(func, "id long").collect()
+        actual = df.mapInPandas(lambda it: [pdf for pdf in it], "id long").collect()
         expected = df.collect()
         self.assertEqual(actual, expected)
 
@@ -85,6 +116,18 @@ class MapInPandasTestsMixin:
             expected = df.collect()
             self.assertEqual(actual, expected)
 
+    def test_no_column_names(self):
+        data = [(1, "foo"), (2, None), (3, "bar"), (4, "bar")]
+        df = self.spark.createDataFrame(data, "a int, b string")
+
+        def func(iterator):
+            for pdf in iterator:
+                yield pdf.rename(columns=list(pdf.columns).index)
+
+        actual = df.mapInPandas(func, df.schema).collect()
+        expected = df.collect()
+        self.assertEqual(actual, expected)
+
     def test_different_output_length(self):
         def func(iterator):
             for _ in iterator:
@@ -94,20 +137,161 @@ class MapInPandasTestsMixin:
         actual = df.repartition(1).mapInPandas(func, "a long").collect()
         self.assertEqual(set((r.a for r in actual)), set(range(100)))
 
-    def test_other_than_dataframe(self):
+    def test_other_than_dataframe_iter(self):
         with QuietTest(self.sc):
-            self.check_other_than_dataframe()
+            self.check_other_than_dataframe_iter()
 
-    def check_other_than_dataframe(self):
-        def bad_iter(_):
+    def check_other_than_dataframe_iter(self):
+        def no_iter(_):
+            return 1
+
+        def bad_iter_elem(_):
             return iter([1])
 
         with self.assertRaisesRegex(
             PythonException,
-            "Return type of the user-defined function should be Pandas.DataFrame, "
-            "but is <class 'int'>",
+            "Return type of the user-defined function should be iterator of pandas.DataFrame, "
+            "but is int.",
+        ):
+            (self.spark.range(10, numPartitions=3).mapInPandas(no_iter, "a int").count())
+
+        with self.assertRaisesRegex(
+            PythonException,
+            "Return type of the user-defined function should be iterator of pandas.DataFrame, "
+            "but is iterator of int.",
+        ):
+            (self.spark.range(10, numPartitions=3).mapInPandas(bad_iter_elem, "a int").count())
+
+    def test_dataframes_with_other_column_names(self):
+        with QuietTest(self.sc):
+            self.check_dataframes_with_other_column_names()
+
+    def check_dataframes_with_other_column_names(self):
+        def dataframes_with_other_column_names(iterator):
+            for pdf in iterator:
+                yield pdf.rename(columns={"id": "iid"})
+
+        with self.assertRaisesRegex(
+            PythonException,
+            "PySparkRuntimeError: \\[RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF\\] "
+            "Column names of the returned pandas.DataFrame do not match "
+            "specified schema. Missing: id. Unexpected: iid.\n",
+        ):
+            (
+                self.spark.range(10, numPartitions=3)
+                .withColumn("value", lit(0))
+                .mapInPandas(dataframes_with_other_column_names, "id int, value int")
+                .collect()
+            )
+
+    def test_dataframes_with_duplicate_column_names(self):
+        with QuietTest(self.sc):
+            self.check_dataframes_with_duplicate_column_names()
+
+    def check_dataframes_with_duplicate_column_names(self):
+        def dataframes_with_other_column_names(iterator):
+            for pdf in iterator:
+                yield pdf.rename(columns={"id2": "id"})
+
+        with self.assertRaisesRegex(
+            PythonException,
+            "PySparkRuntimeError: \\[RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF\\] "
+            "Column names of the returned pandas.DataFrame do not match "
+            "specified schema. Missing: id2.\n",
         ):
-            self.spark.range(10, numPartitions=3).mapInPandas(bad_iter, "a int, b string").count()
+            (
+                self.spark.range(10, numPartitions=3)
+                .withColumn("id2", lit(0))
+                .withColumn("value", lit(1))
+                .mapInPandas(dataframes_with_other_column_names, "id int, id2 long, value int")
+                .collect()
+            )
+
+    def test_dataframes_with_less_columns(self):
+        with QuietTest(self.sc):
+            self.check_dataframes_with_less_columns()
+
+    def check_dataframes_with_less_columns(self):
+        df = self.spark.range(10, numPartitions=3).withColumn("value", lit(0))
+
+        with self.assertRaisesRegex(
+            PythonException,
+            "PySparkRuntimeError: \\[RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF\\] "
+            "Column names of the returned pandas.DataFrame do not match "
+            "specified schema. Missing: id2.\n",
+        ):
+            f = self.identity_dataframes_iter("id", "value")
+            (df.mapInPandas(f, "id int, id2 long, value int").collect())
+
+        with self.assertRaisesRegex(
+            PythonException,
+            "PySparkRuntimeError: \\[RESULT_LENGTH_MISMATCH_FOR_PANDAS_UDF\\] "
+            "Number of columns of the returned pandas.DataFrame doesn't match "
+            "specified schema. Expected: 3 Actual: 2\n",
+        ):
+            f = self.identity_dataframes_wo_column_names_iter("id", "value")
+            (df.mapInPandas(f, "id int, id2 long, value int").collect())
+
+    def test_dataframes_with_more_columns(self):
+        df = self.spark.range(10, numPartitions=3).select(
+            "id", col("id").alias("value"), col("id").alias("extra")
+        )
+        expected = df.select("id", "value").collect()
+
+        f = self.identity_dataframes_iter("id", "value", "extra")
+        actual = df.repartition(1).mapInPandas(f, "id long, value long").collect()
+        self.assertEqual(actual, expected)
+
+        f = self.identity_dataframes_wo_column_names_iter("id", "value", "extra")
+        actual = df.repartition(1).mapInPandas(f, "id long, value long").collect()
+        self.assertEqual(actual, expected)
+
+    def test_dataframes_with_incompatible_types(self):
+        with QuietTest(self.sc):
+            self.check_dataframes_with_incompatible_types()
+
+    def check_dataframes_with_incompatible_types(self):
+        def func(iterator):
+            for pdf in iterator:
+                yield pdf.assign(id=pdf["id"].apply(str))
+
+        for safely in [True, False]:
+            with self.subTest(convertToArrowArraySafely=safely), self.sql_conf(
+                {"spark.sql.execution.pandas.convertToArrowArraySafely": safely}
+            ):
+                # sometimes we see ValueErrors
+                with self.subTest(convert="string to double"):
+                    expected = (
+                        r"ValueError: Exception thrown when converting pandas.Series "
+                        r"\(object\) with name 'id' to Arrow Array \(double\)."
+                    )
+                    if safely:
+                        expected = expected + (
+                            " It can be caused by overflows or other "
+                            "unsafe conversions warned by Arrow. Arrow safe type check "
+                            "can be disabled by using SQL config "
+                            "`spark.sql.execution.pandas.convertToArrowArraySafely`."
+                        )
+                    with self.assertRaisesRegex(PythonException, expected + "\n"):
+                        (
+                            self.spark.range(10, numPartitions=3)
+                            .mapInPandas(func, "id double")
+                            .collect()
+                        )
+
+                # sometimes we see TypeErrors
+                with self.subTest(convert="double to string"):
+                    with self.assertRaisesRegex(
+                        PythonException,
+                        r"TypeError: Exception thrown when converting pandas.Series "
+                        r"\(float64\) with name 'id' to Arrow Array \(string\).\n",
+                    ):
+                        (
+                            self.spark.range(10, numPartitions=3)
+                            .select(col("id").cast("double"))
+                            .mapInPandas(self.identity_dataframes_iter("id"), "id string")
+                            .collect()
+                        )
 
     def test_empty_iterator(self):
         def empty_iter(_):
@@ -124,16 +308,8 @@ class MapInPandasTestsMixin:
         self.assertEqual(mapped.count(), 0)
 
     def test_empty_dataframes_without_columns(self):
-        def empty_dataframes_wo_columns(iterator):
-            for pdf in iterator:
-                yield pdf
-            # after yielding all elements of the iterator, also yield one dataframe without columns
-            yield pd.DataFrame([])
-
-        mapped = (
-            self.spark.range(10, numPartitions=3)
-            .toDF("id")
-            .mapInPandas(empty_dataframes_wo_columns, "id int")
+        mapped = self.spark.range(10, numPartitions=3).mapInPandas(
+            self.dataframes_and_empty_dataframe_iter(), "id int"
         )
         self.assertEqual(mapped.count(), 10)
 
@@ -142,16 +318,47 @@ class MapInPandasTestsMixin:
             self.check_empty_dataframes_with_less_columns()
 
     def check_empty_dataframes_with_less_columns(self):
-        def empty_dataframes_with_less_columns(iterator):
-            for pdf in iterator:
-                yield pdf
-            # after yielding all elements of the iterator, also yield a dataframe with less columns
-            yield pd.DataFrame([(1,)], columns=["id"])
+        with self.assertRaisesRegex(
+            PythonException,
+            "PySparkRuntimeError: \\[RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF\\] "
+            "Column names of the returned pandas.DataFrame do not match "
+            "specified schema. Missing: value.\n",
+        ):
+            f = self.dataframes_and_empty_dataframe_iter("id")
+            (
+                self.spark.range(10, numPartitions=3)
+                .withColumn("value", lit(0))
+                .mapInPandas(f, "id int, value int")
+                .collect()
+            )
 
-        with self.assertRaisesRegex(PythonException, "KeyError: 'value'"):
-            self.spark.range(10, numPartitions=3).withColumn("value", lit(0)).toDF(
-                "id", "value"
-            ).mapInPandas(empty_dataframes_with_less_columns, "id int, value int").collect()
+    def test_empty_dataframes_with_more_columns(self):
+        mapped = self.spark.range(10, numPartitions=3).mapInPandas(
+            self.dataframes_and_empty_dataframe_iter("id", "extra"), "id int"
+        )
+        self.assertEqual(mapped.count(), 10)
+
+    def test_empty_dataframes_with_other_columns(self):
+        with QuietTest(self.sc):
+            self.check_empty_dataframes_with_other_columns()
+
+    def check_empty_dataframes_with_other_columns(self):
+        def empty_dataframes_with_other_columns(iterator):
+            for _ in iterator:
+                yield pd.DataFrame({"iid": [], "value": []})
+
+        with self.assertRaisesRegex(
+            PythonException,
+            "PySparkRuntimeError: \\[RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF\\] "
+            "Column names of the returned pandas.DataFrame do not match "
+            "specified schema. Missing: id. Unexpected: iid.\n",
+        ):
+            (
+                self.spark.range(10, numPartitions=3)
+                .withColumn("value", lit(0))
+                .mapInPandas(empty_dataframes_with_other_columns, "id int, value int")
+                .collect()
+            )
 
     def test_chain_map_partitions_in_pandas(self):
         def func(iterator):
diff --git a/python/pyspark/sql/tests/test_arrow_map.py b/python/pyspark/sql/tests/test_arrow_map.py
index 050f2c32665..15367743585 100644
--- a/python/pyspark/sql/tests/test_arrow_map.py
+++ b/python/pyspark/sql/tests/test_arrow_map.py
@@ -18,6 +18,7 @@ import os
 import time
 import unittest
 
+from pyspark.sql.utils import PythonException
 from pyspark.testing.sqlutils import (
     ReusedSQLTestCase,
     have_pandas,
@@ -25,6 +26,7 @@ from pyspark.testing.sqlutils import (
     pandas_requirement_message,
     pyarrow_requirement_message,
 )
+from pyspark.testing.utils import QuietTest
 
 if have_pyarrow:
     import pyarrow as pa
@@ -88,6 +90,31 @@ class MapInArrowTestsMixin(object):
         actual = df.repartition(1).mapInArrow(func, "a long").collect()
         self.assertEqual(set((r.a for r in actual)), set(range(100)))
 
+    def test_other_than_recordbatch_iter(self):
+        with QuietTest(self.sc):
+            self.check_other_than_recordbatch_iter()
+
+    def check_other_than_recordbatch_iter(self):
+        def not_iter(_):
+            return 1
+
+        def bad_iter_elem(_):
+            return iter([1])
+
+        with self.assertRaisesRegex(
+            PythonException,
+            "Return type of the user-defined function should be iterator "
+            "of pyarrow.RecordBatch, but is int.",
+        ):
+            (self.spark.range(10, numPartitions=3).mapInArrow(not_iter, "a int").count())
+
+        with self.assertRaisesRegex(
+            PythonException,
+            "Return type of the user-defined function should be iterator "
+            "of pyarrow.RecordBatch, but is iterator of int.",
+        ):
+            (self.spark.range(10, numPartitions=3).mapInArrow(bad_iter_elem, "a int").count())
+
     def test_empty_iterator(self):
         def empty_iter(_):
             return iter([])
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index cbc9faad47c..3dffdf2c642 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -24,6 +24,7 @@ import time
 from inspect import currentframe, getframeinfo, getfullargspec
 import importlib
 import json
+from typing import Iterator
 
 # 'resource' is a Unix specific module.
 has_resource_module = True
@@ -110,10 +111,13 @@ def wrap_scalar_pandas_udf(f, return_type):
 
     def verify_result_type(result):
         if not hasattr(result, "__len__"):
-            pd_type = "Pandas.DataFrame" if type(return_type) == StructType else "Pandas.Series"
-            raise TypeError(
-                "Return type of the user-defined function should be "
-                "{}, but is {}".format(pd_type, type(result))
+            pd_type = "pandas.DataFrame" if type(return_type) == StructType else "pandas.Series"
+            raise PySparkTypeError(
+                error_class="UDF_RETURN_TYPE",
+                message_parameters={
+                    "expected": pd_type,
+                    "actual": type(result).__name__,
+                },
             )
         return result
 
@@ -134,67 +138,136 @@ def wrap_scalar_pandas_udf(f, return_type):
     )
 
 
-def wrap_batch_iter_udf(f, return_type):
+def wrap_pandas_batch_iter_udf(f, return_type):
     arrow_return_type = to_arrow_type(return_type)
+    iter_type_label = "pandas.DataFrame" if type(return_type) == StructType else "pandas.Series"
 
-    def verify_result_type(result):
-        if not hasattr(result, "__len__"):
-            pd_type = "Pandas.DataFrame" if type(return_type) == StructType else "Pandas.Series"
-            raise TypeError(
-                "Return type of the user-defined function should be "
-                "{}, but is {}".format(pd_type, type(result))
+    def verify_result(result):
+        if not isinstance(result, Iterator) and not hasattr(result, "__iter__"):
+            raise PySparkTypeError(
+                error_class="UDF_RETURN_TYPE",
+                message_parameters={
+                    "expected": "iterator of {}".format(iter_type_label),
+                    "actual": type(result).__name__,
+                },
             )
         return result
 
+    def verify_element(elem):
+        import pandas as pd
+
+        if not isinstance(elem, pd.DataFrame if type(return_type) == StructType else pd.Series):
+            raise PySparkTypeError(
+                error_class="UDF_RETURN_TYPE",
+                message_parameters={
+                    "expected": "iterator of {}".format(iter_type_label),
+                    "actual": "iterator of {}".format(type(elem).__name__),
+                },
+            )
+
+        verify_pandas_result(
+            elem, return_type, assign_cols_by_name=True, truncate_return_schema=True
+        )
+
+        return elem
+
     return lambda *iterator: map(
-        lambda res: (res, arrow_return_type), map(verify_result_type, f(*iterator))
+        lambda res: (res, arrow_return_type), map(verify_element, verify_result(f(*iterator)))
     )
 
 
-def verify_pandas_result(result, return_type, assign_cols_by_name):
+def verify_pandas_result(result, return_type, assign_cols_by_name, truncate_return_schema):
     import pandas as pd
 
-    if not isinstance(result, pd.DataFrame):
-        raise TypeError(
-            "Return type of the user-defined function should be "
-            "pandas.DataFrame, but is {}".format(type(result))
-        )
+    if type(return_type) == StructType:
+        if not isinstance(result, pd.DataFrame):
+            raise PySparkTypeError(
+                error_class="UDF_RETURN_TYPE",
+                message_parameters={
+                    "expected": "pandas.DataFrame",
+                    "actual": type(result).__name__,
+                },
+            )
+
+        # check the schema of the result only if it is not empty or has columns
+        if not result.empty or len(result.columns) != 0:
+            # if any column name of the result is a string
+            # the column names of the result have to match the return type
+            #   see create_array in pyspark.sql.pandas.serializers.ArrowStreamPandasSerializer
+            field_names = set([field.name for field in return_type.fields])
+            # only the first len(field_names) result columns are considered
+            # when truncating the return schema
+            result_columns = (
+                result.columns[: len(field_names)] if truncate_return_schema else result.columns
+            )
+            column_names = set(result_columns)
+            if (
+                assign_cols_by_name
+                and any(isinstance(name, str) for name in result.columns)
+                and column_names != field_names
+            ):
+                missing = sorted(list(field_names.difference(column_names)))
+                missing = f" Missing: {', '.join(missing)}." if missing else ""
 
-    # check the schema of the result only if it is not empty or has columns
-    if not result.empty or len(result.columns) != 0:
-        # if any column name of the result is a string
-        # the column names of the result have to match the return type
-        #   see create_array in pyspark.sql.pandas.serializers.ArrowStreamPandasSerializer
-        field_names = set([field.name for field in return_type.fields])
-        column_names = set(result.columns)
-        if (
-            assign_cols_by_name
-            and any(isinstance(name, str) for name in result.columns)
-            and column_names != field_names
-        ):
-            missing = sorted(list(field_names.difference(column_names)))
-            missing = f" Missing: {', '.join(missing)}." if missing else ""
-
-            extra = sorted(list(column_names.difference(field_names)))
-            extra = f" Unexpected: {', '.join(extra)}." if extra else ""
+                extra = sorted(list(column_names.difference(field_names)))
+                extra = f" Unexpected: {', '.join(extra)}." if extra else ""
 
-            raise PySparkRuntimeError(
-                error_class="RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF",
+                raise PySparkRuntimeError(
+                    error_class="RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF",
+                    message_parameters={
+                        "missing": missing,
+                        "extra": extra,
+                    },
+                )
+            # otherwise the number of columns of result have to match the return type
+            elif len(result_columns) != len(return_type):
+                raise PySparkRuntimeError(
+                    error_class="RESULT_LENGTH_MISMATCH_FOR_PANDAS_UDF",
+                    message_parameters={
+                        "expected": str(len(return_type)),
+                        "actual": str(len(result.columns)),
+                    },
+                )
+    else:
+        if not isinstance(result, pd.Series):
+            raise PySparkTypeError(
+                error_class="UDF_RETURN_TYPE",
+                message_parameters={"expected": "pandas.Series", "actual": type(result).__name__},
+            )
+
+
+def wrap_arrow_batch_iter_udf(f, return_type):
+    arrow_return_type = to_arrow_type(return_type)
+
+    def verify_result(result):
+        if not isinstance(result, Iterator) and not hasattr(result, "__iter__"):
+            raise PySparkTypeError(
+                error_class="UDF_RETURN_TYPE",
                 message_parameters={
-                    "missing": missing,
-                    "extra": extra,
+                    "expected": "iterator of pyarrow.RecordBatch",
+                    "actual": type(result).__name__,
                 },
             )
-        # otherwise the number of columns of result have to match the return type
-        elif len(result.columns) != len(return_type):
-            raise PySparkRuntimeError(
-                error_class="RESULT_LENGTH_MISMATCH_FOR_PANDAS_UDF",
+        return result
+
+    def verify_element(elem):
+        import pyarrow as pa
+
+        if not isinstance(elem, pa.RecordBatch):
+            raise PySparkTypeError(
+                error_class="UDF_RETURN_TYPE",
                 message_parameters={
-                    "expected": str(len(return_type)),
-                    "actual": str(len(result.columns)),
+                    "expected": "iterator of pyarrow.RecordBatch",
+                    "actual": "iterator of {}".format(type(elem).__name__),
                 },
             )
 
+        return elem
+
+    return lambda *iterator: map(
+        lambda res: (res, arrow_return_type), map(verify_element, verify_result(f(*iterator)))
+    )
+
 
 def wrap_cogrouped_map_pandas_udf(f, return_type, argspec, runner_conf):
     _assign_cols_by_name = assign_cols_by_name(runner_conf)
@@ -211,7 +284,9 @@ def wrap_cogrouped_map_pandas_udf(f, return_type, argspec, runner_conf):
             key_series = left_key_series if not left_df.empty else right_key_series
             key = tuple(s[0] for s in key_series)
             result = f(key, left_df, right_df)
-        verify_pandas_result(result, return_type, _assign_cols_by_name)
+        verify_pandas_result(
+            result, return_type, _assign_cols_by_name, truncate_return_schema=False
+        )
 
         return result
 
@@ -229,7 +304,9 @@ def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf):
         elif len(argspec.args) == 2:
             key = tuple(s[0] for s in key_series)
             result = f(key, pd.concat(value_series, axis=1))
-        verify_pandas_result(result, return_type, _assign_cols_by_name)
+        verify_pandas_result(
+            result, return_type, _assign_cols_by_name, truncate_return_schema=False
+        )
 
         return result
 
@@ -278,9 +355,12 @@ def wrap_grouped_map_pandas_udf_with_state(f, return_type):
 
         def verify_element(result):
             if not isinstance(result, pd.DataFrame):
-                raise TypeError(
-                    "The type of element in return iterator of the user-defined function "
-                    "should be pandas.DataFrame, but is {}".format(type(result))
+                raise PySparkTypeError(
+                    error_class="UDF_RETURN_TYPE",
+                    message_parameters={
+                        "expected": "iterator of pandas.DataFrame",
+                        "actual": "iterator of {}".format(type(result).__name__),
+                    },
                 )
             # the number of columns of result have to match the return type
             # but it is fine for result to have no columns at all if it is empty
@@ -299,17 +379,20 @@ def wrap_grouped_map_pandas_udf_with_state(f, return_type):
             return result
 
         if isinstance(result_iter, pd.DataFrame):
-            raise TypeError(
-                "Return type of the user-defined function should be "
-                "iterable of pandas.DataFrame, but is {}".format(type(result_iter))
+            raise PySparkTypeError(
+                error_class="UDF_RETURN_TYPE",
+                message_parameters={
+                    "expected": "iterable of pandas.DataFrame",
+                    "actual": type(result_iter).__name__,
+                },
             )
 
         try:
             iter(result_iter)
         except TypeError:
-            raise TypeError(
-                "Return type of the user-defined function should be "
-                "iterable, but is {}".format(type(result_iter))
+            raise PySparkTypeError(
+                error_class="UDF_RETURN_TYPE",
+                message_parameters={"expected": "iterable", "actual": type(result_iter).__name__},
             )
 
         result_iter_with_validation = (verify_element(x) for x in result_iter)
@@ -423,11 +506,11 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
     if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, PythonEvalType.SQL_ARROW_BATCHED_UDF):
         return arg_offsets, wrap_scalar_pandas_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
-        return arg_offsets, wrap_batch_iter_udf(func, return_type)
+        return arg_offsets, wrap_pandas_batch_iter_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF:
-        return arg_offsets, wrap_batch_iter_udf(func, return_type)
+        return arg_offsets, wrap_pandas_batch_iter_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF:
-        return arg_offsets, wrap_batch_iter_udf(func, return_type)
+        return arg_offsets, wrap_arrow_batch_iter_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
         argspec = getfullargspec(chained_func)  # signature was lost when wrapping it
         return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf)
@@ -547,7 +630,9 @@ def read_udtf(pickleSer, infile, eval_type):
                         )
 
                 # Verify the type and the schema of the result.
-                verify_pandas_result(result, return_type, assign_cols_by_name=False)
+                verify_pandas_result(
+                    result, return_type, assign_cols_by_name=False, truncate_return_schema=False
+                )
                 return result
 
             return lambda *a: map(lambda res: (res, arrow_return_type), map(verify_result, f(*a)))

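The column-name check that `verify_pandas_result` now performs for `StructType` return types, including the new `truncate_return_schema` flag used by `mapInPandas`, can be sketched in plain Python. This is a minimal sketch, not the PySpark implementation: the helper name `column_mismatch_suffix` is hypothetical, column names are passed as plain sequences instead of a `pandas.DataFrame`, `assign_cols_by_name` is assumed to be `True`, and the empty-result guard is omitted.

```python
from typing import Optional, Sequence


def column_mismatch_suffix(
    result_columns: Sequence[object],
    field_names: Sequence[str],
    truncate_return_schema: bool,
) -> Optional[str]:
    """Hypothetical helper mirroring the name check in verify_pandas_result.

    Returns the " Missing: ... Unexpected: ..." text that would be reported,
    or None when the names are accepted (or not compared at all).
    """
    expected = set(field_names)
    # With truncation, only the first len(field_names) result columns count.
    considered = list(result_columns)
    if truncate_return_schema:
        considered = considered[: len(expected)]
    actual = set(considered)
    # Names are compared only if at least one result column name is a string;
    # otherwise the diff falls back to comparing column counts.
    if not any(isinstance(name, str) for name in result_columns):
        return None
    if actual == expected:
        return None
    missing = sorted(expected - actual)
    extra = sorted(str(name) for name in actual - expected)
    missing_text = f" Missing: {', '.join(missing)}." if missing else ""
    extra_text = f" Unexpected: {', '.join(extra)}." if extra else ""
    return missing_text + extra_text
```

For example, `column_mismatch_suffix(["id", "v"], ["id", "val"], True)` yields `" Missing: val. Unexpected: v."`, the same case as the "Mismatching column names" example in the commit message. With truncation enabled, surplus trailing columns such as `["id", "v", "x"]` against `["id", "v"]` are accepted; without it they are reported as unexpected.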

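Both `wrap_pandas_batch_iter_udf` and the new `wrap_arrow_batch_iter_udf` split validation into `verify_result`, which checks that the UDF returned something iterable, and `verify_element`, which checks each yielded batch. The two checks are composed with `map`. The sketch below shows that pattern without PySpark. `UDFReturnTypeError` and `wrap_batch_iter` are hypothetical names; the error text only approximates the `UDF_RETURN_TYPE` wording shown in the commit message. The `(res, arrow_return_type)` pairing from the diff is also left out.

```python
from collections.abc import Iterator
from typing import Any, Callable


class UDFReturnTypeError(TypeError):
    """Hypothetical stand-in for PySparkTypeError(error_class="UDF_RETURN_TYPE")."""

    def __init__(self, expected: str, actual: str) -> None:
        super().__init__(
            f"Return type of the user-defined function should be {expected}, but is {actual}."
        )
        self.expected = expected
        self.actual = actual


def wrap_batch_iter(f: Callable[..., Any], element_type: type, label: str) -> Callable[..., Iterator]:
    """Hypothetical wrapper showing the verify_result/verify_element split."""

    def verify_result(result: Any) -> Any:
        # Stage 1: the UDF itself must return an iterator or other iterable.
        if not isinstance(result, Iterator) and not hasattr(result, "__iter__"):
            raise UDFReturnTypeError(f"iterator of {label}", type(result).__name__)
        return result

    def verify_element(elem: Any) -> Any:
        # Stage 2: every element the UDF yields must have the expected type.
        if not isinstance(elem, element_type):
            raise UDFReturnTypeError(
                f"iterator of {label}", f"iterator of {type(elem).__name__}"
            )
        return elem

    # verify_result runs as soon as the wrapper is called; map() is lazy, so
    # verify_element runs only when the output iterator is consumed.
    return lambda *iterator: map(verify_element, verify_result(f(*iterator)))
```

Because `verify_result` runs eagerly, a UDF such as `lambda it: 1` fails as soon as it is invoked. A UDF such as `lambda it: [1]` fails only when its first element is pulled, and the error reports `iterator of int`.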
