This is an automated email from the ASF dual-hosted git repository.

xinrong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new bc6f69a988f  [SPARK-43543][PYTHON] Fix nested MapType behavior in Pandas UDF
bc6f69a988f is described below

commit bc6f69a988f13e5e22cb055e60693a545f0cbadb
Author: Xinrong Meng <xinr...@apache.org>
AuthorDate: Fri May 19 14:54:59 2023 -0700

[SPARK-43543][PYTHON] Fix nested MapType behavior in Pandas UDF

### What changes were proposed in this pull request?

Fix nested MapType behavior in Pandas UDFs (and Arrow-optimized Python UDFs). Previously, during Arrow-to-pandas conversion, only the outermost map was converted to a dictionary; now nested MapType values are converted to nested dictionaries as well. The same applies to Spark Connect.

### Why are the changes needed?

Correctness and consistency with `createDataFrame` and `toPandas` when Arrow is enabled.

### Does this PR introduce _any_ user-facing change?

Yes. Nested MapType support is corrected in Pandas UDFs.

```py
>>> import pandas as pd
>>> from pyspark.sql.functions import pandas_udf
>>> from pyspark.sql.types import MapType, StringType, StructField, StructType
>>>
>>> schema = StructType([
...     StructField("id", StringType(), True),
...     StructField("attributes", MapType(StringType(), MapType(StringType(), StringType())), True)
... ])
>>>
>>> data = [
...     ("1", {"personal": {"name": "John", "city": "New York"}}),
... ]
>>> df = spark.createDataFrame(data, schema)
>>> @pandas_udf(StringType())
... def f(s: pd.Series) -> pd.Series:
...     return s.astype(str)
...
>>> df.select(f(df.attributes)).show(truncate=False)
```

The result of `df.select(f(df.attributes)).show(truncate=False)` is corrected

**FROM**
```py
+------------------------------------------------------+
|f(attributes)                                         |
+------------------------------------------------------+
|{'personal': [('name', 'John'), ('city', 'New York')]}|
+------------------------------------------------------+
```

**TO**
```py
>>> df.select(f(df.attributes)).show(truncate=False)
+--------------------------------------------------+
|f(attributes)                                     |
+--------------------------------------------------+
|{'personal': {'name': 'John', 'city': 'New York'}}|
+--------------------------------------------------+
```

**Another, more obvious example:**
```py
>>> @pandas_udf(StringType())
... def extract_name(s: pd.Series) -> pd.Series:
...     return s.apply(lambda x: x['personal']['name'])
...
>>> df.select(extract_name(df.attributes)).show(truncate=False)
```

`df.select(extract_name(df.attributes)).show(truncate=False)` is corrected

**FROM**
```py
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
...
TypeError: list indices must be integers or slices, not str
```

**TO**
```py
+------------------------+
|extract_name(attributes)|
+------------------------+
|John                    |
+------------------------+
```

### How was this patch tested?

Unit tests.

Closes #41147 from xinrong-meng/nestedType.

Authored-by: Xinrong Meng <xinr...@apache.org>
Signed-off-by: Xinrong Meng <xinr...@apache.org>
---
 python/pyspark/sql/pandas/serializers.py           | 91 ++++------------------
 .../sql/tests/pandas/test_pandas_udf_scalar.py     | 30 +++++++
 2 files changed, 47 insertions(+), 74 deletions(-)

diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 9b5db2d000d..e81d90fc23e 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -21,7 +21,12 @@ Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for mo
 
 from pyspark.errors import PySparkTypeError, PySparkValueError
 from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer, CPickleSerializer
-from pyspark.sql.pandas.types import from_arrow_type, to_arrow_type, _create_converter_from_pandas
+from pyspark.sql.pandas.types import (
+    from_arrow_type,
+    to_arrow_type,
+    _create_converter_from_pandas,
+    _create_converter_to_pandas,
+)
 from pyspark.sql.types import StringType, StructType, BinaryType, StructField, LongType
 
 
@@ -168,23 +173,21 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
         self._safecheck = safecheck
 
     def arrow_to_pandas(self, arrow_column):
-        from pyspark.sql.pandas.types import (
-            _check_series_localize_timestamps,
-            _convert_map_items_to_dict,
-        )
-        import pyarrow
-
         # If the given column is a date type column, creates a series of datetime.date directly
         # instead of creating datetime64[ns] as intermediate data to avoid overflow caused by
         # datetime64[ns] type handling.
+        # Cast dates to objects instead of datetime64[ns] dtype to avoid overflow.
         s = arrow_column.to_pandas(date_as_object=True)
 
-        if pyarrow.types.is_timestamp(arrow_column.type) and arrow_column.type.tz is not None:
-            return _check_series_localize_timestamps(s, self._timezone)
-        elif pyarrow.types.is_map(arrow_column.type):
-            return _convert_map_items_to_dict(s)
-        else:
-            return s
+        # TODO(SPARK-43579): cache the converter for reuse
+        converter = _create_converter_to_pandas(
+            data_type=from_arrow_type(arrow_column.type, prefer_timestamp_ntz=True),
+            nullable=True,
+            timezone=self._timezone,
+            struct_in_pandas="dict",
+            error_on_duplicated_field_names=True,
+        )
+        return converter(s)
 
     def _create_array(self, series, arrow_type):
         """
@@ -209,7 +212,7 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
 
         if arrow_type is not None:
             spark_type = from_arrow_type(arrow_type, prefer_timestamp_ntz=True)
-
+            # TODO(SPARK-43579): cache the converter for reuse
             conv = _create_converter_from_pandas(
                 spark_type, timezone=self._timezone, error_on_duplicated_field_names=False
             )
@@ -317,66 +320,6 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
             s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(arrow_column)
         return s
 
-    # To keep the current UDF behavior.
-    def _create_array(self, series, arrow_type):
-        """
-        Create an Arrow Array from the given pandas.Series and optional type.
-
-        Parameters
-        ----------
-        series : pandas.Series
-            A single series
-        arrow_type : pyarrow.DataType, optional
-            If None, pyarrow's inferred type will be used
-
-        Returns
-        -------
-        pyarrow.Array
-        """
-        import pyarrow as pa
-        from pyspark.sql.pandas.types import (
-            _check_series_convert_timestamps_internal,
-            _convert_dict_to_map_items,
-        )
-        from pandas.api.types import is_categorical_dtype
-
-        if hasattr(series.array, "__arrow_array__"):
-            mask = None
-        else:
-            mask = series.isnull()
-        # Ensure timestamp series are in expected form for Spark internal representation
-        if (
-            arrow_type is not None
-            and pa.types.is_timestamp(arrow_type)
-            and arrow_type.tz is not None
-        ):
-            series = _check_series_convert_timestamps_internal(series, self._timezone)
-        elif arrow_type is not None and pa.types.is_map(arrow_type):
-            series = _convert_dict_to_map_items(series)
-        elif arrow_type is None and is_categorical_dtype(series.dtype):
-            series = series.astype(series.dtypes.categories.dtype)
-        try:
-            return pa.Array.from_pandas(series, mask=mask, type=arrow_type, safe=self._safecheck)
-        except TypeError as e:
-            error_msg = (
-                "Exception thrown when converting pandas.Series (%s) "
-                "with name '%s' to Arrow Array (%s)."
-            )
-            raise PySparkTypeError(error_msg % (series.dtype, series.name, arrow_type)) from e
-        except ValueError as e:
-            error_msg = (
-                "Exception thrown when converting pandas.Series (%s) "
-                "with name '%s' to Arrow Array (%s)."
-            )
-            if self._safecheck:
-                error_msg = error_msg + (
-                    " It can be caused by overflows or other "
-                    "unsafe conversions warned by Arrow. Arrow safe type check "
-                    "can be disabled by using SQL config "
-                    "`spark.sql.execution.pandas.convertToArrowArraySafely`."
-                )
-            raise PySparkValueError(error_msg % (series.dtype, series.name, arrow_type)) from e
-
     def _create_batch(self, series):
         """
         Create an Arrow record batch from the given pandas.Series pandas.DataFrame
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
index 33c957fac58..8fa6010a62f 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
@@ -135,6 +135,36 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
         result = df.select(tokenize("vals").alias("hi"))
         self.assertEqual([Row(hi=[["hi", "boo"]]), Row(hi=[["bye", "boo"]])], result.collect())
 
+    def test_pandas_udf_nested_maps(self):
+        schema = StructType(
+            [
+                StructField("id", StringType(), True),
+                StructField(
+                    "attributes", MapType(StringType(), MapType(StringType(), StringType())), True
+                ),
+            ]
+        )
+        data = [("1", {"personal": {"name": "John", "city": "New York"}})]
+        df = self.spark.createDataFrame(data, schema)
+
+        @pandas_udf(StringType())
+        def f(s: pd.Series) -> pd.Series:
+            return s.astype(str)
+
+        self.assertEquals(
+            df.select(f(df.attributes).alias("res")).first(),
+            Row(res="{'personal': {'name': 'John', 'city': 'New York'}}"),
+        )
+
+        @pandas_udf(StringType())
+        def extract_name(s: pd.Series) -> pd.Series:
+            return s.apply(lambda x: x["personal"]["name"])
+
+        self.assertEquals(
+            df.select(extract_name(df.attributes).alias("res")).first(),
+            Row(res="John"),
+        )
+
     @unittest.skipIf(
         pyarrow_version_less_than_minimum("2.0.0"),
         "Pyarrow version must be 2.0.0 or higher",

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
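P.S. For anyone who wants to try the corrected behavior outside the test suite, below is a minimal, self-contained sketch of the example from the PR description above. It assumes a PySpark build that already includes this fix, with pandas and PyArrow installed; the `local[1]` master and the app name are illustrative choices, not part of the change.

```py
# Minimal standalone sketch of the nested-MapType example from the PR description.
# Assumes a PySpark build that includes SPARK-43543, plus pandas and PyArrow installed.
import pandas as pd

from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import MapType, StringType, StructField, StructType

spark = SparkSession.builder.master("local[1]").appName("nested-map-udf").getOrCreate()

schema = StructType(
    [
        StructField("id", StringType(), True),
        StructField(
            "attributes", MapType(StringType(), MapType(StringType(), StringType())), True
        ),
    ]
)
data = [("1", {"personal": {"name": "John", "city": "New York"}})]
df = spark.createDataFrame(data, schema)


@pandas_udf(StringType())
def extract_name(s: pd.Series) -> pd.Series:
    # With the fix, each element of `s` arrives as a nested dict,
    # so plain dictionary indexing reaches the inner map as well.
    return s.apply(lambda x: x["personal"]["name"])


df.select(extract_name(df.attributes).alias("name")).show()
spark.stop()
```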