Repository: spark
Updated Branches:
  refs/heads/master 8141c3e3d -> 63c5bf13c
[SPARK-23334][SQL][PYTHON] Fix pandas_udf with return type StringType() to handle str type properly in Python 2.

## What changes were proposed in this pull request?

In Python 2, execution fails when a `pandas_udf` with a string return type returns values created inside the UDF as plain `".."` string literals (Python 2 `str`, not `unicode`). For example:

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, col
import pandas as pd

# Run under Python 2; in the pyspark shell, `spark` already exists.
spark = SparkSession.builder.getOrCreate()
df = spark.range(10)
# "%s" % i yields a byte string (`str`) in Python 2, not `unicode`.
str_f = pandas_udf(lambda x: pd.Series(["%s" % i for i in x]), "string")
df.select(str_f(col('id'))).show()
```

raises the following exception:

```
...

java.lang.AssertionError: assertion failed: Invalid schema from pandas_udf: expected StringType, got BinaryType
	at scala.Predef$.assert(Predef.scala:170)
	at org.apache.spark.sql.execution.python.ArrowEvalPythonExec$$anon$2.<init>(ArrowEvalPythonExec.scala:93)

...
```

It seems that pyarrow ignores the `type` parameter of `pa.Array.from_pandas()` and treats the values as binary when the requested type is string but the values are `str` rather than `unicode` in Python 2.

This PR adds a workaround for this case: under Python 2, `str` values are decoded as UTF-8 before they are converted to Arrow.

## How was this patch tested?

Added a new test; existing tests were also run.

Author: Takuya UESHIN <ues...@databricks.com>

Closes #20507 from ueshin/issues/SPARK-23334.
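The workaround added to `_create_batch` amounts to decoding byte strings before Arrow ever sees them. Below is a minimal sketch of that single step, written for Python 3 and without pandas or pyarrow. It assumes the usual correspondence: Python 2's byte-string `str` is Python 3's `bytes`, and Python 2's `unicode` is Python 3's `str`. The helper name `decode_utf8_values` is hypothetical and is not part of PySpark.

```python
def decode_utf8_values(values, encoding="utf-8"):
    """Hypothetical helper: return a list with every bytes value decoded to text.

    Text values and None (used for nulls) pass through unchanged, mirroring
    the lambda that the patch applies via Series.apply before calling
    pa.Array.from_pandas(..., type=t).
    """
    return [v.decode(encoding) if isinstance(v, bytes) else v for v in values]


if __name__ == "__main__":
    raw = [b"0", b"1", u"2", None]
    print(decode_utf8_values(raw))  # ['0', '1', '2', None]
```

Once every value is text, Arrow can honor the requested string type rather than inferring binary from byte strings.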
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/63c5bf13
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/63c5bf13
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/63c5bf13

Branch: refs/heads/master
Commit: 63c5bf13ce5cd3b8d7e7fb88de881ed207fde720
Parents: 8141c3e
Author: Takuya UESHIN <ues...@databricks.com>
Authored: Tue Feb 6 18:30:50 2018 +0900
Committer: hyukjinkwon <gurwls...@gmail.com>
Committed: Tue Feb 6 18:30:50 2018 +0900

----------------------------------------------------------------------
 python/pyspark/serializers.py | 4 ++++
 python/pyspark/sql/tests.py   | 9 +++++++++
 2 files changed, 13 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/63c5bf13/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index e870325..91a7f09 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -230,6 +230,10 @@ def _create_batch(series, timezone):
             s = _check_series_convert_timestamps_internal(s.fillna(0), timezone)
             # TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2
             return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False)
+        elif t is not None and pa.types.is_string(t) and sys.version < '3':
+            # TODO: need decode before converting to Arrow in Python 2
+            return pa.Array.from_pandas(s.apply(
+                lambda v: v.decode("utf-8") if isinstance(v, str) else v), mask=mask, type=t)
         return pa.Array.from_pandas(s, mask=mask, type=t)
 
     arrs = [create_array(s, t) for s, t in series]


http://git-wip-us.apache.org/repos/asf/spark/blob/63c5bf13/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 545ec5a..89b7c21 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3922,6 +3922,15 @@ class ScalarPandasUDF(ReusedSQLTestCase):
         res = df.select(str_f(col('str')))
         self.assertEquals(df.collect(), res.collect())
 
+    def test_vectorized_udf_string_in_udf(self):
+        from pyspark.sql.functions import pandas_udf, col
+        import pandas as pd
+        df = self.spark.range(10)
+        str_f = pandas_udf(lambda x: pd.Series(map(str, x)), StringType())
+        actual = df.select(str_f(col('id')))
+        expected = df.select(col('id').cast('string'))
+        self.assertEquals(expected.collect(), actual.collect())
+
     def test_vectorized_udf_datatype_string(self):
         from pyspark.sql.functions import pandas_udf, col
         df = self.spark.range(10).select(