This is an automated email from the ASF dual-hosted git repository. xinrong pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 6e56cfeaca8 [SPARK-44150][PYTHON][CONNECT] Explicit Arrow casting for mismatched return type in Arrow Python UDF 6e56cfeaca8 is described below commit 6e56cfeaca884b1ccfaa8524c70f12f118bc840c Author: Xinrong Meng <xinr...@apache.org> AuthorDate: Thu Jun 29 11:46:06 2023 -0700 [SPARK-44150][PYTHON][CONNECT] Explicit Arrow casting for mismatched return type in Arrow Python UDF ### What changes were proposed in this pull request? Apply explicit Arrow casting when the values returned by an Arrow Python UDF do not match its declared return type. ### Why are the changes needed? This gives more standardized and coherent type coercion. Please refer to https://github.com/apache/spark/pull/41706 for a comprehensive comparison of the type coercion rules of Arrow and of Pickle (used by the default Python UDF). See more at [[Design] Type-coercion in Arrow Python UDFs](https://docs.google.com/document/d/e/2PACX-1vTEGElOZfhl9NfgbBw4CTrlm-8F_xQCAKNOXouz-7mg5vYobS7lCGUsGkDZxPY0wV5YkgoZmkYlxccU/pub). ### Does this PR introduce _any_ user-facing change? Yes. FROM ```py >>> df = spark.createDataFrame(['1', '2'], schema='string') >>> df.select(pandas_udf(lambda x: x, 'int')('value')).show() ... org.apache.spark.api.python.PythonException: Traceback (most recent call last): ... pyarrow.lib.ArrowInvalid: Could not convert '1' with type str: tried to convert to int32 ``` TO ```py >>> df = spark.createDataFrame(['1', '2'], schema='string') >>> df.select(pandas_udf(lambda x: x, 'int')('value')).show() +---------------+ |<lambda>(value)| +---------------+ | 1| | 2| +---------------+ ``` ### How was this patch tested? Unit tests. Closes #41503 from xinrong-meng/type_coersion. 
Authored-by: Xinrong Meng <xinr...@apache.org> Signed-off-by: Xinrong Meng <xinr...@apache.org> --- python/pyspark/sql/pandas/serializers.py | 30 ++++++++++++++--- python/pyspark/sql/tests/test_arrow_python_udf.py | 39 +++++++++++++++++++++++ python/pyspark/worker.py | 3 ++ 3 files changed, 67 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 307fcc33752..a99eda9cbea 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -190,7 +190,7 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer): ) return converter(s) - def _create_array(self, series, arrow_type, spark_type=None): + def _create_array(self, series, arrow_type, spark_type=None, arrow_cast=False): """ Create an Arrow Array from the given pandas.Series and optional type. @@ -202,6 +202,9 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer): If None, pyarrow's inferred type will be used spark_type : DataType, optional If None, spark type converted from arrow_type will be used + arrow_cast: bool, optional + Whether to apply Arrow casting when the user-specified return type mismatches the + actual return values. 
Returns ------- @@ -226,7 +229,12 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer): else: mask = series.isnull() try: - return pa.Array.from_pandas(series, mask=mask, type=arrow_type, safe=self._safecheck) + if arrow_cast: + return pa.Array.from_pandas(series, mask=mask).cast( + target_type=arrow_type, safe=self._safecheck + ) + else: + return pa.Array.from_pandas(series, mask=mask, type=arrow_type, safe=self._safecheck) except TypeError as e: error_msg = ( "Exception thrown when converting pandas.Series (%s) " @@ -319,12 +327,14 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer): df_for_struct=False, struct_in_pandas="dict", ndarray_as_list=False, + arrow_cast=False, ): super(ArrowStreamPandasUDFSerializer, self).__init__(timezone, safecheck) self._assign_cols_by_name = assign_cols_by_name self._df_for_struct = df_for_struct self._struct_in_pandas = struct_in_pandas self._ndarray_as_list = ndarray_as_list + self._arrow_cast = arrow_cast def arrow_to_pandas(self, arrow_column): import pyarrow.types as types @@ -386,7 +396,13 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer): # Assign result columns by schema name if user labeled with strings elif self._assign_cols_by_name and any(isinstance(name, str) for name in s.columns): arrs_names = [ - (self._create_array(s[field.name], field.type), field.name) for field in t + ( + self._create_array( + s[field.name], field.type, arrow_cast=self._arrow_cast + ), + field.name, + ) + for field in t ] # Assign result columns by position else: @@ -394,7 +410,11 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer): # the selected series has name '1', so we rename it to field.name # as the name is used by _create_array to provide a meaningful error message ( - self._create_array(s[s.columns[i]].rename(field.name), field.type), + self._create_array( + s[s.columns[i]].rename(field.name), + field.type, + arrow_cast=self._arrow_cast, + ), field.name, ) for i, field in
enumerate(t) @@ -403,7 +423,7 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer): struct_arrs, struct_names = zip(*arrs_names) arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names)) else: - arrs.append(self._create_array(s, t)) + arrs.append(self._create_array(s, t, arrow_cast=self._arrow_cast)) return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))]) diff --git a/python/pyspark/sql/tests/test_arrow_python_udf.py b/python/pyspark/sql/tests/test_arrow_python_udf.py index 0accb0f3cc1..264ea0b901f 100644 --- a/python/pyspark/sql/tests/test_arrow_python_udf.py +++ b/python/pyspark/sql/tests/test_arrow_python_udf.py @@ -17,6 +17,8 @@ import unittest +from pyspark.errors import PythonException +from pyspark.sql import Row from pyspark.sql.functions import udf from pyspark.sql.tests.test_udf import BaseUDFTestsMixin from pyspark.testing.sqlutils import ( @@ -141,6 +143,43 @@ class PythonUDFArrowTestsMixin(BaseUDFTestsMixin): "[[1, 2], [3, 4]]", ) + def test_type_coercion_string_to_numeric(self): + df_int_value = self.spark.createDataFrame(["1", "2"], schema="string") + df_floating_value = self.spark.createDataFrame(["1.1", "2.2"], schema="string") + + int_ddl_types = ["tinyint", "smallint", "int", "bigint"] + floating_ddl_types = ["double", "float"] + + for ddl_type in int_ddl_types: + # df_int_value + res = df_int_value.select(udf(lambda x: x, ddl_type)("value").alias("res")) + self.assertEquals(res.collect(), [Row(res=1), Row(res=2)]) + self.assertEquals(res.dtypes[0][1], ddl_type) + + floating_results = [ + [Row(res=1.1), Row(res=2.2)], + [Row(res=1.100000023841858), Row(res=2.200000047683716)], + ] + for ddl_type, floating_res in zip(floating_ddl_types, floating_results): + # df_int_value + res = df_int_value.select(udf(lambda x: x, ddl_type)("value").alias("res")) + self.assertEquals(res.collect(), [Row(res=1.0), Row(res=2.0)]) + self.assertEquals(res.dtypes[0][1], ddl_type) + # df_floating_value + res = 
df_floating_value.select(udf(lambda x: x, ddl_type)("value").alias("res")) + self.assertEquals(res.collect(), floating_res) + self.assertEquals(res.dtypes[0][1], ddl_type) + + # invalid + with self.assertRaises(PythonException): + df_floating_value.select(udf(lambda x: x, "int")("value").alias("res")).collect() + + with self.assertRaises(PythonException): + df_int_value.select(udf(lambda x: x, "decimal")("value").alias("res")).collect() + + with self.assertRaises(PythonException): + df_floating_value.select(udf(lambda x: x, "decimal")("value").alias("res")).collect() + class PythonUDFArrowTests(PythonUDFArrowTestsMixin, ReusedSQLTestCase): @classmethod diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 71a7ccd15aa..577286a7357 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -598,6 +598,8 @@ def read_udfs(pickleSer, infile, eval_type): "row" if eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF else "dict" ) ndarray_as_list = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF + # Arrow-optimized Python UDF uses explicit Arrow cast for type coercion + arrow_cast = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF ser = ArrowStreamPandasUDFSerializer( timezone, safecheck, @@ -605,6 +607,7 @@ def read_udfs(pickleSer, infile, eval_type): df_for_struct, struct_in_pandas, ndarray_as_list, + arrow_cast, ) else: ser = BatchedSerializer(CPickleSerializer(), 100) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org