This is an automated email from the ASF dual-hosted git repository.

xinrong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new bc6f69a988f [SPARK-43543][PYTHON] Fix nested MapType behavior in Pandas UDF
bc6f69a988f is described below

commit bc6f69a988f13e5e22cb055e60693a545f0cbadb
Author: Xinrong Meng <xinr...@apache.org>
AuthorDate: Fri May 19 14:54:59 2023 -0700

    [SPARK-43543][PYTHON] Fix nested MapType behavior in Pandas UDF
    
    ### What changes were proposed in this pull request?
    Fix nested MapType behavior in Pandas UDF (and Arrow-optimized Python UDF).
    
    Previously, during Arrow-to-pandas conversion, only the outermost map layer was converted to a dictionary; now nested MapTypes are converted to nested dictionaries as well.
    
    That applies to Spark Connect as well.
    
    ### Why are the changes needed?
    Correctness, and consistency with `createDataFrame` and `toPandas` when Arrow is enabled.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes.
    
    Support for nested MapType is corrected in Pandas UDFs:
    
    ```py
    >>> schema = StructType([
    ...      StructField("id", StringType(), True),
    ...      StructField("attributes", MapType(StringType(), MapType(StringType(), StringType())), True)
    ... ])
    >>>
    >>> data = [
    ...    ("1", {"personal": {"name": "John", "city": "New York"}}),
    ... ]
    >>> df = spark.createDataFrame(data, schema)
    >>> @pandas_udf(StringType())
    ... def f(s: pd.Series) -> pd.Series:
    ...    return s.astype(str)
    ...
    >>> df.select(f(df.attributes)).show(truncate=False)
    ```
    
    The result of `df.select(f(df.attributes)).show(truncate=False)` is corrected:
    
    **FROM**
    ```py
    
    +------------------------------------------------------+
    |f(attributes)                                         |
    +------------------------------------------------------+
    |{'personal': [('name', 'John'), ('city', 'New York')]}|
    +------------------------------------------------------+
    ```
    
    **TO**
    ```py
    >>> df.select(f(df.attributes)).show(truncate=False)
    +--------------------------------------------------+
    |f(attributes)                                     |
    +--------------------------------------------------+
    |{'personal': {'name': 'John', 'city': 'New York'}}|
    +--------------------------------------------------+
    
    ```
    
    **Another more obvious example:**
    ```py
    >>> @pandas_udf(StringType())
    ... def extract_name(s: pd.Series) -> pd.Series:
    ...     return s.apply(lambda x: x['personal']['name'])
    ...
    >>> df.select(extract_name(df.attributes)).show(truncate=False)
    ```
    
    The result of `df.select(extract_name(df.attributes)).show(truncate=False)` is corrected:
    
    **FROM**
    ```py
    org.apache.spark.api.python.PythonException: Traceback (most recent call last):
    ...
    TypeError: list indices must be integers or slices, not str
    ```
    
    **TO**
    ```py
    +------------------------+
    |extract_name(attributes)|
    +------------------------+
    |John                    |
    +------------------------+
    ```
    
    ### How was this patch tested?
    Unit tests.
    
    Closes #41147 from xinrong-meng/nestedType.
    
    Authored-by: Xinrong Meng <xinr...@apache.org>
    Signed-off-by: Xinrong Meng <xinr...@apache.org>
---
 python/pyspark/sql/pandas/serializers.py           | 91 ++++------------------
 .../sql/tests/pandas/test_pandas_udf_scalar.py     | 30 +++++++
 2 files changed, 47 insertions(+), 74 deletions(-)

diff --git a/python/pyspark/sql/pandas/serializers.py 
b/python/pyspark/sql/pandas/serializers.py
index 9b5db2d000d..e81d90fc23e 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -21,7 +21,12 @@ Serializers for PyArrow and pandas conversions. See 
`pyspark.serializers` for mo
 
 from pyspark.errors import PySparkTypeError, PySparkValueError
 from pyspark.serializers import Serializer, read_int, write_int, 
UTF8Deserializer, CPickleSerializer
-from pyspark.sql.pandas.types import from_arrow_type, to_arrow_type, 
_create_converter_from_pandas
+from pyspark.sql.pandas.types import (
+    from_arrow_type,
+    to_arrow_type,
+    _create_converter_from_pandas,
+    _create_converter_to_pandas,
+)
 from pyspark.sql.types import StringType, StructType, BinaryType, StructField, 
LongType
 
 
@@ -168,23 +173,21 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
         self._safecheck = safecheck
 
     def arrow_to_pandas(self, arrow_column):
-        from pyspark.sql.pandas.types import (
-            _check_series_localize_timestamps,
-            _convert_map_items_to_dict,
-        )
-        import pyarrow
-
         # If the given column is a date type column, creates a series of 
datetime.date directly
         # instead of creating datetime64[ns] as intermediate data to avoid 
overflow caused by
         # datetime64[ns] type handling.
+        # Cast dates to objects instead of datetime64[ns] dtype to avoid 
overflow.
         s = arrow_column.to_pandas(date_as_object=True)
 
-        if pyarrow.types.is_timestamp(arrow_column.type) and 
arrow_column.type.tz is not None:
-            return _check_series_localize_timestamps(s, self._timezone)
-        elif pyarrow.types.is_map(arrow_column.type):
-            return _convert_map_items_to_dict(s)
-        else:
-            return s
+        # TODO(SPARK-43579): cache the converter for reuse
+        converter = _create_converter_to_pandas(
+            data_type=from_arrow_type(arrow_column.type, 
prefer_timestamp_ntz=True),
+            nullable=True,
+            timezone=self._timezone,
+            struct_in_pandas="dict",
+            error_on_duplicated_field_names=True,
+        )
+        return converter(s)
 
     def _create_array(self, series, arrow_type):
         """
@@ -209,7 +212,7 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
 
         if arrow_type is not None:
             spark_type = from_arrow_type(arrow_type, prefer_timestamp_ntz=True)
-
+            # TODO(SPARK-43579): cache the converter for reuse
             conv = _create_converter_from_pandas(
                 spark_type, timezone=self._timezone, 
error_on_duplicated_field_names=False
             )
@@ -317,66 +320,6 @@ class 
ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
             s = super(ArrowStreamPandasUDFSerializer, 
self).arrow_to_pandas(arrow_column)
         return s
 
-    # To keep the current UDF behavior.
-    def _create_array(self, series, arrow_type):
-        """
-        Create an Arrow Array from the given pandas.Series and optional type.
-
-        Parameters
-        ----------
-        series : pandas.Series
-            A single series
-        arrow_type : pyarrow.DataType, optional
-            If None, pyarrow's inferred type will be used
-
-        Returns
-        -------
-        pyarrow.Array
-        """
-        import pyarrow as pa
-        from pyspark.sql.pandas.types import (
-            _check_series_convert_timestamps_internal,
-            _convert_dict_to_map_items,
-        )
-        from pandas.api.types import is_categorical_dtype
-
-        if hasattr(series.array, "__arrow_array__"):
-            mask = None
-        else:
-            mask = series.isnull()
-        # Ensure timestamp series are in expected form for Spark internal 
representation
-        if (
-            arrow_type is not None
-            and pa.types.is_timestamp(arrow_type)
-            and arrow_type.tz is not None
-        ):
-            series = _check_series_convert_timestamps_internal(series, 
self._timezone)
-        elif arrow_type is not None and pa.types.is_map(arrow_type):
-            series = _convert_dict_to_map_items(series)
-        elif arrow_type is None and is_categorical_dtype(series.dtype):
-            series = series.astype(series.dtypes.categories.dtype)
-        try:
-            return pa.Array.from_pandas(series, mask=mask, type=arrow_type, 
safe=self._safecheck)
-        except TypeError as e:
-            error_msg = (
-                "Exception thrown when converting pandas.Series (%s) "
-                "with name '%s' to Arrow Array (%s)."
-            )
-            raise PySparkTypeError(error_msg % (series.dtype, series.name, 
arrow_type)) from e
-        except ValueError as e:
-            error_msg = (
-                "Exception thrown when converting pandas.Series (%s) "
-                "with name '%s' to Arrow Array (%s)."
-            )
-            if self._safecheck:
-                error_msg = error_msg + (
-                    " It can be caused by overflows or other "
-                    "unsafe conversions warned by Arrow. Arrow safe type check 
"
-                    "can be disabled by using SQL config "
-                    "`spark.sql.execution.pandas.convertToArrowArraySafely`."
-                )
-            raise PySparkValueError(error_msg % (series.dtype, series.name, 
arrow_type)) from e
-
     def _create_batch(self, series):
         """
         Create an Arrow record batch from the given pandas.Series 
pandas.DataFrame
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py 
b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
index 33c957fac58..8fa6010a62f 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
@@ -135,6 +135,36 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
         result = df.select(tokenize("vals").alias("hi"))
         self.assertEqual([Row(hi=[["hi", "boo"]]), Row(hi=[["bye", "boo"]])], 
result.collect())
 
+    def test_pandas_udf_nested_maps(self):
+        schema = StructType(
+            [
+                StructField("id", StringType(), True),
+                StructField(
+                    "attributes", MapType(StringType(), MapType(StringType(), 
StringType())), True
+                ),
+            ]
+        )
+        data = [("1", {"personal": {"name": "John", "city": "New York"}})]
+        df = self.spark.createDataFrame(data, schema)
+
+        @pandas_udf(StringType())
+        def f(s: pd.Series) -> pd.Series:
+            return s.astype(str)
+
+        self.assertEquals(
+            df.select(f(df.attributes).alias("res")).first(),
+            Row(res="{'personal': {'name': 'John', 'city': 'New York'}}"),
+        )
+
+        @pandas_udf(StringType())
+        def extract_name(s: pd.Series) -> pd.Series:
+            return s.apply(lambda x: x["personal"]["name"])
+
+        self.assertEquals(
+            df.select(extract_name(df.attributes).alias("res")).first(),
+            Row(res="John"),
+        )
+
     @unittest.skipIf(
         pyarrow_version_less_than_minimum("2.0.0"),
         "Pyarrow version must be 2.0.0 or higher",


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to