This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 775e05f  [SPARK-37228][SQL][PYTHON] Implement DataFrame.mapInArrow in Python
775e05f is described below

commit 775e05f2c3c31fc203cfe4b36df301555ce73ca4
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Mon Nov 15 08:51:23 2021 +0900

[SPARK-37228][SQL][PYTHON] Implement DataFrame.mapInArrow in Python

### What changes were proposed in this pull request?

This PR proposes to implement `DataFrame.mapInArrow`, which allows users to apply a function to PyArrow record batches, such as:

```python
def do_something(iterator):
    for arrow_batch in iterator:
        # Do something with the `pyarrow.RecordBatch` and create a new `pyarrow.RecordBatch`.
        # ...
        yield arrow_batch

df.mapInArrow(do_something, df.schema).show()
```

The general idea is simple: it shares the same codebase as `DataFrame.mapInPandas` except for the pandas conversion logic.

This PR also piggy-backs the following changes:
- Removes the check in `spark.udf.register` on `SQL_MAP_PANDAS_ITER_UDF`. This type is only used internally by `DataFrame.mapInPandas`, and it cannot be registered as a SQL UDF.
- Removes the type hints for `pandas_udf` that are used for internal purposes such as `SQL_MAP_PANDAS_ITER_UDF` and `SQL_COGROUPED_MAP_PANDAS_UDF`. Neither can be used with `pandas_udf` as a SQL expression, and they should be hidden from end users.

Note that documentation will be done in another PR.

### Why are the changes needed?

For usability and to address technical problems. Both are elaborated in more detail in SPARK-37227. Please also see the discussions at https://github.com/apache/spark/pull/26783.

### Does this PR introduce _any_ user-facing change?

Yes, this PR adds a new API:

```python
import pyarrow as pa

df = spark.createDataFrame(
    [(1, "foo"), (2, None), (3, "bar"), (4, "bar")], "a int, b string")

def func(iterator):
    for batch in iterator:
        # `batch` is a pyarrow.RecordBatch.
        yield batch

df.mapInArrow(func, df.schema).collect()
```

### How was this patch tested?

Manually tested, and unit tests were added.

Closes #34505 from HyukjinKwon/SPARK-37228.
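For reference, a slightly fuller usage sketch (not part of this patch) of the kind of Arrow-native transformation `mapInArrow` enables without a pandas round-trip. It assumes an active SparkSession bound to `spark`, and the column names and the `add_one_to_age` helper are hypothetical, chosen only for illustration; the Arrow work is done with PyArrow's `pyarrow.compute` module.

```python
import pyarrow as pa
import pyarrow.compute as pc

# Hypothetical example data; any DataFrame with a numeric column works.
df = spark.createDataFrame([(1, 21), (2, 30), (3, 25)], ("id", "age"))

def add_one_to_age(iterator):
    for batch in iterator:
        # `batch` is a pyarrow.RecordBatch; transform it with Arrow compute
        # kernels instead of converting to pandas.
        age_idx = batch.schema.names.index("age")
        columns = [batch.column(i) for i in range(batch.num_columns)]
        columns[age_idx] = pc.add(columns[age_idx], 1)
        yield pa.RecordBatch.from_arrays(columns, names=batch.schema.names)

df.mapInArrow(add_one_to_age, "id long, age long").show()
```

As the new docstring below notes, the size of each incoming `pyarrow.RecordBatch` is controlled by `spark.sql.execution.arrow.maxRecordsPerBatch`, so the function may be invoked with several batches per partition.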
Authored-by: Hyukjin Kwon <gurwls...@apache.org> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../org/apache/spark/api/python/PythonRunner.scala | 2 + dev/sparktestsupport/modules.py | 1 + python/pyspark/rdd.py | 1 + python/pyspark/rdd.pyi | 2 + python/pyspark/sql/pandas/_typing/__init__.pyi | 9 +- python/pyspark/sql/pandas/functions.py | 8 +- python/pyspark/sql/pandas/functions.pyi | 42 ------- python/pyspark/sql/pandas/group_ops.py | 4 +- python/pyspark/sql/pandas/map_ops.py | 68 +++++++++- python/pyspark/sql/pandas/serializers.py | 45 +++++++ python/pyspark/sql/tests/test_arrow_map.py | 138 +++++++++++++++++++++ python/pyspark/sql/udf.py | 12 +- python/pyspark/worker.py | 30 +++-- .../catalyst/analysis/DeduplicateRelations.scala | 6 + .../plans/logical/pythonLogicalOperators.scala | 15 +++ .../main/scala/org/apache/spark/sql/Dataset.scala | 14 +++ .../spark/sql/execution/SparkStrategies.scala | 2 + ...{MapInPandasExec.scala => MapInBatchExec.scala} | 26 ++-- .../sql/execution/python/MapInPandasExec.scala | 69 +---------- .../execution/python/PythonMapInArrowExec.scala | 38 ++++++ 20 files changed, 385 insertions(+), 147 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index bbe55cb..6a4871b 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -52,6 +52,7 @@ private[spark] object PythonEvalType { val SQL_SCALAR_PANDAS_ITER_UDF = 204 val SQL_MAP_PANDAS_ITER_UDF = 205 val SQL_COGROUPED_MAP_PANDAS_UDF = 206 + val SQL_MAP_ARROW_ITER_UDF = 207 def toString(pythonEvalType: Int): String = pythonEvalType match { case NON_UDF => "NON_UDF" @@ -63,6 +64,7 @@ private[spark] object PythonEvalType { case SQL_SCALAR_PANDAS_ITER_UDF => "SQL_SCALAR_PANDAS_ITER_UDF" case SQL_MAP_PANDAS_ITER_UDF => "SQL_MAP_PANDAS_ITER_UDF" case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF" + case SQL_MAP_ARROW_ITER_UDF => "SQL_MAP_ARROW_ITER_UDF" } } diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index b87218e..7d3ebb0 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -451,6 +451,7 @@ pyspark_sql = Module( "pyspark.sql.tests.test_pandas_cogrouped_map", "pyspark.sql.tests.test_pandas_grouped_map", "pyspark.sql.tests.test_pandas_map", + "pyspark.sql.tests.test_arrow_map", "pyspark.sql.tests.test_pandas_udf", "pyspark.sql.tests.test_pandas_udf_grouped_agg", "pyspark.sql.tests.test_pandas_udf_scalar", diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 7e63fca..b997932 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -90,6 +90,7 @@ class PythonEvalType(object): SQL_SCALAR_PANDAS_ITER_UDF = 204 SQL_MAP_PANDAS_ITER_UDF = 205 SQL_COGROUPED_MAP_PANDAS_UDF = 206 + SQL_MAP_ARROW_ITER_UDF = 207 def portable_hash(x): diff --git a/python/pyspark/rdd.pyi b/python/pyspark/rdd.pyi index 97d36b9..37ba4c2 100644 --- a/python/pyspark/rdd.pyi +++ b/python/pyspark/rdd.pyi @@ -43,6 +43,7 @@ from pyspark.sql.pandas._typing import ( PandasCogroupedMapUDFType, PandasGroupedAggUDFType, PandasMapIterUDFType, + ArrowMapIterUDFType, ) import pyspark.context from pyspark.resultiterable import ResultIterable @@ -83,6 +84,7 @@ class PythonEvalType: SQL_SCALAR_PANDAS_ITER_UDF: PandasScalarIterUDFType SQL_MAP_PANDAS_ITER_UDF: PandasMapIterUDFType SQL_COGROUPED_MAP_PANDAS_UDF: PandasCogroupedMapUDFType + 
SQL_MAP_ARROW_ITER_UDF: ArrowMapIterUDFType class BoundedFloat(float): def __new__(cls, mean: float, confidence: float, low: float, high: float) -> BoundedFloat: ... diff --git a/python/pyspark/sql/pandas/_typing/__init__.pyi b/python/pyspark/sql/pandas/_typing/__init__.pyi index 6e9b97b1..305c322 100644 --- a/python/pyspark/sql/pandas/_typing/__init__.pyi +++ b/python/pyspark/sql/pandas/_typing/__init__.pyi @@ -35,6 +35,8 @@ from pyspark.sql.pandas._typing.protocols.series import SeriesLike as SeriesLike import pandas.core.frame # type: ignore[import] import pandas.core.series # type: ignore[import] +import pyarrow # type: ignore[import] + # POC compatibility annotations PandasDataFrame: Type[DataFrameLike] = pandas.core.frame.DataFrame PandasSeries: Type[SeriesLike] = pandas.core.series.Series @@ -48,6 +50,7 @@ PandasGroupedMapUDFType = Literal[201] PandasCogroupedMapUDFType = Literal[206] PandasGroupedAggUDFType = Literal[202] PandasMapIterUDFType = Literal[205] +ArrowMapIterUDFType = Literal[207] class PandasVariadicScalarToScalarFunction(Protocol): def __call__(self, *_: DataFrameOrSeriesLike) -> SeriesLike: ... @@ -325,10 +328,8 @@ PandasGroupedAggFunction = Union[ PandasMapIterFunction = Callable[[Iterable[DataFrameLike]], Iterable[DataFrameLike]] +ArrowMapIterFunction = Callable[[Iterable[pyarrow.RecordBatch]], Iterable[pyarrow.RecordBatch]] + PandasCogroupedMapFunction = Callable[[DataFrameLike, DataFrameLike], DataFrameLike] -MapIterPandasUserDefinedFunction = NewType("MapIterPandasUserDefinedFunction", FunctionType) GroupedMapPandasUserDefinedFunction = NewType("GroupedMapPandasUserDefinedFunction", FunctionType) -CogroupedMapPandasUserDefinedFunction = NewType( - "CogroupedMapPandasUserDefinedFunction", FunctionType -) diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py index 97a563499..ca3263b 100644 --- a/python/pyspark/sql/pandas/functions.py +++ b/python/pyspark/sql/pandas/functions.py @@ -366,6 +366,7 @@ def pandas_udf(f=None, returnType=None, functionType=None): PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, + PythonEvalType.SQL_MAP_ARROW_ITER_UDF, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, None, ]: # None means it should infer the type from type hints. @@ -400,12 +401,13 @@ def _create_pandas_udf(f, returnType, evalType): elif evalType in [ PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, + PythonEvalType.SQL_MAP_ARROW_ITER_UDF, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, ]: - # In case of 'SQL_GROUPED_MAP_PANDAS_UDF', deprecation warning is being triggered + # In case of 'SQL_GROUPED_MAP_PANDAS_UDF', deprecation warning is being triggered # at `apply` instead. - # In case of 'SQL_MAP_PANDAS_ITER_UDF' and 'SQL_COGROUPED_MAP_PANDAS_UDF', the - # evaluation type will always be set. + # In case of 'SQL_MAP_PANDAS_ITER_UDF', 'SQL_MAP_ARROW_ITER_UDF' and + # 'SQL_COGROUPED_MAP_PANDAS_UDF', the evaluation type will always be set. 
pass elif len(argspec.annotations) > 0: evalType = infer_eval_type(signature(f)) diff --git a/python/pyspark/sql/pandas/functions.pyi b/python/pyspark/sql/pandas/functions.pyi index 10bea30..7ff06be 100644 --- a/python/pyspark/sql/pandas/functions.pyi +++ b/python/pyspark/sql/pandas/functions.pyi @@ -25,16 +25,10 @@ from pyspark.sql._typing import ( ) from pyspark.sql.pandas._typing import ( GroupedMapPandasUserDefinedFunction, - MapIterPandasUserDefinedFunction, - CogroupedMapPandasUserDefinedFunction, - PandasCogroupedMapFunction, - PandasCogroupedMapUDFType, PandasGroupedAggFunction, PandasGroupedAggUDFType, PandasGroupedMapFunction, PandasGroupedMapUDFType, - PandasMapIterFunction, - PandasMapIterUDFType, PandasScalarIterFunction, PandasScalarIterUDFType, PandasScalarToScalarFunction, @@ -130,39 +124,3 @@ def pandas_udf( def pandas_udf( f: Union[AtomicDataTypeOrString, ArrayType], *, functionType: PandasGroupedAggUDFType ) -> Callable[[PandasGroupedAggFunction], UserDefinedFunctionLike]: ... -@overload -def pandas_udf( - f: PandasMapIterFunction, - returnType: Union[StructType, str], - functionType: PandasMapIterUDFType, -) -> MapIterPandasUserDefinedFunction: ... -@overload -def pandas_udf( - f: Union[StructType, str], returnType: PandasMapIterUDFType -) -> Callable[[PandasMapIterFunction], MapIterPandasUserDefinedFunction]: ... -@overload -def pandas_udf( - *, returnType: Union[StructType, str], functionType: PandasMapIterUDFType -) -> Callable[[PandasMapIterFunction], MapIterPandasUserDefinedFunction]: ... -@overload -def pandas_udf( - f: Union[StructType, str], *, functionType: PandasMapIterUDFType -) -> Callable[[PandasMapIterFunction], MapIterPandasUserDefinedFunction]: ... -@overload -def pandas_udf( - f: PandasCogroupedMapFunction, - returnType: Union[StructType, str], - functionType: PandasCogroupedMapUDFType, -) -> CogroupedMapPandasUserDefinedFunction: ... -@overload -def pandas_udf( - f: Union[StructType, str], returnType: PandasCogroupedMapUDFType -) -> Callable[[PandasCogroupedMapFunction], CogroupedMapPandasUserDefinedFunction]: ... -@overload -def pandas_udf( - *, returnType: Union[StructType, str], functionType: PandasCogroupedMapUDFType -) -> Callable[[PandasCogroupedMapFunction], CogroupedMapPandasUserDefinedFunction]: ... -@overload -def pandas_udf( - f: Union[StructType, str], *, functionType: PandasCogroupedMapUDFType -) -> Callable[[PandasCogroupedMapFunction], CogroupedMapPandasUserDefinedFunction]: ... diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index 0af313b..593b72d 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -347,9 +347,11 @@ class PandasCogroupedOps(object): """ from pyspark.sql.pandas.functions import pandas_udf + # The usage of the pandas_udf is internal so type checking is disabled. 
udf = pandas_udf( func, returnType=schema, functionType=PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF - ) + ) # type: ignore[call-overload] + all_cols = self._extract_cols(self._gd1) + self._extract_cols(self._gd2) udf_column = udf(*all_cols) jdf = self._gd1._jgd.flatMapCoGroupsInPandas( # type: ignore[attr-defined] diff --git a/python/pyspark/sql/pandas/map_ops.py b/python/pyspark/sql/pandas/map_ops.py index f3bcec9..ce1480f 100644 --- a/python/pyspark/sql/pandas/map_ops.py +++ b/python/pyspark/sql/pandas/map_ops.py @@ -22,7 +22,7 @@ from pyspark.sql.types import StructType if TYPE_CHECKING: from pyspark.sql.dataframe import DataFrame - from pyspark.sql.pandas._typing import PandasMapIterFunction + from pyspark.sql.pandas._typing import PandasMapIterFunction, ArrowMapIterFunction class PandasMapOpsMixin(object): @@ -84,13 +84,77 @@ class PandasMapOpsMixin(object): assert isinstance(self, DataFrame) + # The usage of the pandas_udf is internal so type checking is disabled. udf = pandas_udf( func, returnType=schema, functionType=PythonEvalType.SQL_MAP_PANDAS_ITER_UDF - ) + ) # type: ignore[call-overload] udf_column = udf(*[self[col] for col in self.columns]) jdf = self._jdf.mapInPandas(udf_column._jc.expr()) # type: ignore[operator] return DataFrame(jdf, self.sql_ctx) + def mapInArrow( + self, func: "ArrowMapIterFunction", schema: Union[StructType, str] + ) -> "DataFrame": + """ + Maps an iterator of batches in the current :class:`DataFrame` using a Python native + function that takes and outputs a PyArrow's `RecordBatch`, and returns the result as a + :class:`DataFrame`. + + The function should take an iterator of `pyarrow.RecordBatch`\\s and return + another iterator of `pyarrow.RecordBatch`\\s. All columns are passed + together as an iterator of `pyarrow.RecordBatch`\\s to the function and the + returned iterator of `pyarrow.RecordBatch`\\s are combined as a :class:`DataFrame`. + Each `pyarrow.RecordBatch` size can be controlled by + `spark.sql.execution.arrow.maxRecordsPerBatch`. + + .. versionadded:: 3.3.0 + + Parameters + ---------- + func : function + a Python native function that takes an iterator of `pyarrow.RecordBatch`\\s, and + outputs an iterator of `pyarrow.RecordBatch`\\s. + schema : :class:`pyspark.sql.types.DataType` or str + the return type of the `func` in PySpark. The value can be either a + :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. + + Examples + -------- + >>> import pyarrow # doctest: +SKIP + >>> df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age")) + >>> def filter_func(iterator): + ... for batch in iterator: + ... pdf = batch.to_pandas() + ... yield pyarrow.RecordBatch.from_pandas(pdf[pdf.id == 1]) + >>> df.mapInArrow(filter_func, df.schema).show() # doctest: +SKIP + +---+---+ + | id|age| + +---+---+ + | 1| 21| + +---+---+ + + Notes + ----- + This API is unstable, and for developers. + + See Also + -------- + pyspark.sql.functions.pandas_udf + pyspark.sql.DataFrame.mapInPandas + """ + from pyspark.sql import DataFrame + from pyspark.sql.pandas.functions import pandas_udf + + assert isinstance(self, DataFrame) + + # The usage of the pandas_udf is internal so type checking is disabled. 
+ udf = pandas_udf( + func, returnType=schema, functionType=PythonEvalType.SQL_MAP_ARROW_ITER_UDF + ) # type: ignore[call-overload] + udf_column = udf(*[self[col] for col in self.columns]) + jdf = self._jdf.pythonMapInArrow(udf_column._jc.expr()) + return DataFrame(jdf, self.sql_ctx) + def _test() -> None: import doctest diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 16f5b3d..44276a4 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -100,6 +100,51 @@ class ArrowStreamSerializer(Serializer): return "ArrowStreamSerializer" +class ArrowStreamUDFSerializer(ArrowStreamSerializer): + """ + Same as :class:`ArrowStreamSerializer` but it flattens the struct to Arrow record batch + for applying each function with the raw record arrow batch. See also `DataFrame.mapInArrow`. + """ + + def load_stream(self, stream): + """ + Flatten the struct into Arrow's record batches. + """ + import pyarrow as pa + + batches = super(ArrowStreamUDFSerializer, self).load_stream(stream) + for batch in batches: + struct = batch.column(0) + yield [pa.RecordBatch.from_arrays(struct.flatten(), schema=pa.schema(struct.type))] + + def dump_stream(self, iterator, stream): + """ + Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent. + This should be sent after creating the first record batch so in case of an error, it can + be sent back to the JVM before the Arrow stream starts. + """ + import pyarrow as pa + + def wrap_and_init_stream(): + should_write_start_length = True + for batch, _ in iterator: + assert isinstance(batch, pa.RecordBatch) + + # Wrap the root struct + struct = pa.StructArray.from_arrays( + batch.columns, fields=pa.struct(list(batch.schema)) + ) + batch = pa.RecordBatch.from_arrays([struct], ["_0"]) + + # Write the first record batch with initialization. + if should_write_start_length: + write_int(SpecialLengths.START_ARROW_STREAM, stream) + should_write_start_length = False + yield batch + + return super(ArrowStreamUDFSerializer, self).dump_stream(wrap_and_init_stream(), stream) + + class ArrowStreamPandasSerializer(ArrowStreamSerializer): """ Serializes Pandas.Series as Arrow data with Arrow streaming format. diff --git a/python/pyspark/sql/tests/test_arrow_map.py b/python/pyspark/sql/tests/test_arrow_map.py new file mode 100644 index 0000000..a4c948f --- /dev/null +++ b/python/pyspark/sql/tests/test_arrow_map.py @@ -0,0 +1,138 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import os +import time +import unittest + +from pyspark.testing.sqlutils import ( + ReusedSQLTestCase, + have_pandas, + have_pyarrow, + pandas_requirement_message, + pyarrow_requirement_message, +) + +if have_pyarrow: + import pyarrow as pa + +if have_pandas: + import pandas as pd + + +@unittest.skipIf( + not have_pandas or not have_pyarrow, + pandas_requirement_message or pyarrow_requirement_message, # type: ignore[arg-type] +) +class MapInArrowTests(ReusedSQLTestCase): + @classmethod + def setUpClass(cls): + ReusedSQLTestCase.setUpClass() + + # Synchronize default timezone between Python and Java + cls.tz_prev = os.environ.get("TZ", None) # save current tz if set + tz = "America/Los_Angeles" + os.environ["TZ"] = tz + time.tzset() + + cls.sc.environment["TZ"] = tz + cls.spark.conf.set("spark.sql.session.timeZone", tz) + + @classmethod + def tearDownClass(cls): + del os.environ["TZ"] + if cls.tz_prev is not None: + os.environ["TZ"] = cls.tz_prev + time.tzset() + ReusedSQLTestCase.tearDownClass() + + def test_map_in_arrow(self): + def func(iterator): + for batch in iterator: + assert isinstance(batch, pa.RecordBatch) + assert batch.schema.names == ["id"] + yield batch + + df = self.spark.range(10) + actual = df.mapInArrow(func, "id long").collect() + expected = df.collect() + self.assertEqual(actual, expected) + + def test_multiple_columns(self): + data = [(1, "foo"), (2, None), (3, "bar"), (4, "bar")] + df = self.spark.createDataFrame(data, "a int, b string") + + def func(iterator): + for batch in iterator: + assert isinstance(batch, pa.RecordBatch) + assert batch.schema.types == [pa.int32(), pa.string()] + yield batch + + actual = df.mapInArrow(func, df.schema).collect() + expected = df.collect() + self.assertEqual(actual, expected) + + def test_different_output_length(self): + def func(iterator): + for _ in iterator: + yield pa.RecordBatch.from_pandas(pd.DataFrame({"a": list(range(100))})) + + df = self.spark.range(10) + actual = df.repartition(1).mapInArrow(func, "a long").collect() + self.assertEqual(set((r.a for r in actual)), set(range(100))) + + def test_empty_iterator(self): + def empty_iter(_): + return iter([]) + + self.assertEqual(self.spark.range(10).mapInArrow(empty_iter, "a int, b string").count(), 0) + + def test_empty_rows(self): + def empty_rows(_): + return iter([pa.RecordBatch.from_pandas(pd.DataFrame({"a": []}))]) + + self.assertEqual(self.spark.range(10).mapInArrow(empty_rows, "a int").count(), 0) + + def test_chain_map_in_arrow(self): + def func(iterator): + for batch in iterator: + assert isinstance(batch, pa.RecordBatch) + assert batch.schema.names == ["id"] + yield batch + + df = self.spark.range(10) + actual = df.mapInArrow(func, "id long").mapInArrow(func, "id long").collect() + expected = df.collect() + self.assertEqual(actual, expected) + + def test_self_join(self): + df1 = self.spark.range(10) + df2 = df1.mapInArrow(lambda iter: iter, "id long") + actual = df2.join(df2).collect() + expected = df1.join(df1).collect() + self.assertEqual(sorted(actual), sorted(expected)) + + +if __name__ == "__main__": + from pyspark.sql.tests.test_arrow_map import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index aa74795..886451a 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -157,7 
+157,10 @@ class UserDefinedFunction(object): "UDFs or at groupby.applyInPandas: return type must be a " "StructType." ) - elif self.evalType == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF: + elif ( + self.evalType == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF + or self.evalType == PythonEvalType.SQL_MAP_ARROW_ITER_UDF + ): if isinstance(self._returnType_placeholder, StructType): try: to_arrow_type(self._returnType_placeholder) @@ -168,7 +171,8 @@ class UserDefinedFunction(object): ) else: raise TypeError( - "Invalid return type in mapInPandas: " "return type must be a StructType." + "Invalid return type in mapInPandas/mapInArrow: " + "return type must be a StructType." ) elif self.evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: if isinstance(self._returnType_placeholder, StructType): @@ -405,12 +409,10 @@ class UDFRegistration(object): PythonEvalType.SQL_SCALAR_PANDAS_UDF, PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, - PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, ]: raise ValueError( "Invalid f: f must be SQL_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF, " - "SQL_SCALAR_PANDAS_ITER_UDF, SQL_GROUPED_AGG_PANDAS_UDF or " - "SQL_MAP_PANDAS_ITER_UDF." + "SQL_SCALAR_PANDAS_ITER_UDF or SQL_GROUPED_AGG_PANDAS_UDF." ) register_udf = _create_udf( f.func, diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 5a07582..c2200b2 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -53,7 +53,11 @@ from pyspark.serializers import ( PickleSerializer, BatchedSerializer, ) -from pyspark.sql.pandas.serializers import ArrowStreamPandasUDFSerializer, CogroupUDFSerializer +from pyspark.sql.pandas.serializers import ( + ArrowStreamPandasUDFSerializer, + CogroupUDFSerializer, + ArrowStreamUDFSerializer, +) from pyspark.sql.pandas.types import to_arrow_type from pyspark.sql.types import StructType from pyspark.util import fail_on_stopiteration, try_simplify_traceback # type: ignore @@ -123,7 +127,7 @@ def wrap_scalar_pandas_udf(f, return_type): ) -def wrap_pandas_iter_udf(f, return_type): +def wrap_batch_iter_udf(f, return_type): arrow_return_type = to_arrow_type(return_type) def verify_result_type(result): @@ -291,9 +295,11 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF: return arg_offsets, wrap_scalar_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF: - return arg_offsets, wrap_pandas_iter_udf(func, return_type) + return arg_offsets, wrap_batch_iter_udf(func, return_type) elif eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF: - return arg_offsets, wrap_pandas_iter_udf(func, return_type) + return arg_offsets, wrap_batch_iter_udf(func, return_type) + elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF: + return arg_offsets, wrap_batch_iter_udf(func, return_type) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: argspec = getfullargspec(chained_func) # signature was lost when wrapping it return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec) @@ -318,6 +324,7 @@ def read_udfs(pickleSer, infile, eval_type): PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, + PythonEvalType.SQL_MAP_ARROW_ITER_UDF, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, @@ -346,6 +353,8 @@ def read_udfs(pickleSer, infile, eval_type): if eval_type == 
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name) + elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF: + ser = ArrowStreamUDFSerializer() else: # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of # pandas Series. See SPARK-27240. @@ -363,13 +372,16 @@ def read_udfs(pickleSer, infile, eval_type): num_udfs = read_int(infile) is_scalar_iter = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF - is_map_iter = eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF + is_map_pandas_iter = eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF + is_map_arrow_iter = eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF - if is_scalar_iter or is_map_iter: + if is_scalar_iter or is_map_pandas_iter or is_map_arrow_iter: if is_scalar_iter: assert num_udfs == 1, "One SCALAR_ITER UDF expected here." - if is_map_iter: - assert num_udfs == 1, "One MAP_ITER UDF expected here." + if is_map_pandas_iter: + assert num_udfs == 1, "One MAP_PANDAS_ITER UDF expected here." + if is_map_arrow_iter: + assert num_udfs == 1, "One MAP_ARROW_ITER UDF expected here." arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0) @@ -398,7 +410,7 @@ def read_udfs(pickleSer, infile, eval_type): # it's very unlikely the output length is higher than # input length. assert ( - is_map_iter or num_output_rows <= num_input_rows + is_map_pandas_iter or is_map_arrow_iter or num_output_rows <= num_input_rows ), "Pandas SCALAR_ITER UDF outputted more rows than input rows." yield (result_batch, result_type) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala index 4ff1837..55b1c22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala @@ -241,6 +241,12 @@ object DeduplicateRelations extends Rule[LogicalPlan] { newVersion.copyTagsFrom(oldVersion) Seq((oldVersion, newVersion)) + case oldVersion @ PythonMapInArrow(_, output, _) + if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => + val newVersion = oldVersion.copy(output = output.map(_.newInstance())) + newVersion.copyTagsFrom(oldVersion) + Seq((oldVersion, newVersion)) + case oldVersion @ AttachDistributedSequence(sequenceAttr, _) if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty => val newVersion = oldVersion.copy(sequenceAttr = sequenceAttr.newInstance()) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala index af18540..13a40db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -58,6 +58,21 @@ case class MapInPandas( } /** + * Map partitions using a udf: iter(pyarrow.RecordBatch) -> iter(pyarrow.RecordBatch). 
+ * This is used by DataFrame.mapInArrow() in PySpark + */ +case class PythonMapInArrow( + functionExpr: Expression, + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + + override val producedAttributes = AttributeSet(output) + + override protected def withNewChildInternal(newChild: LogicalPlan): PythonMapInArrow = + copy(child = newChild) +} + +/** * Flatmap cogroups using a udf: pandas.Dataframe, pandas.Dataframe -> pandas.Dataframe * This is used by DataFrame.groupby().cogroup().apply(). */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index c8cdc20..da3eea2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2964,6 +2964,20 @@ class Dataset[T] private[sql]( } /** + * Applies a function to each partition in Arrow format. The user-defined function + * defines a transformation: `iter(pyarrow.RecordBatch)` -> `iter(pyarrow.RecordBatch)`. + * Each partition is each iterator consisting of `pyarrow.RecordBatch`s as batches. + */ + private[sql] def pythonMapInArrow(func: PythonUDF): DataFrame = { + Dataset.ofRows( + sparkSession, + PythonMapInArrow( + func, + func.dataType.asInstanceOf[StructType].toAttributes, + logicalPlan)) + } + + /** * (Scala-specific) * Returns a new Dataset by first applying a function to all elements of this Dataset, * and then flattening the results. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index cdecda4..90c2507 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -755,6 +755,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { func, output, planLater(left), planLater(right)) :: Nil case logical.MapInPandas(func, output, child) => execution.python.MapInPandasExec(func, output, planLater(child)) :: Nil + case logical.PythonMapInArrow(func, output, child) => + execution.python.PythonMapInArrowExec(func, output, planLater(child)) :: Nil case logical.AttachDistributedSequence(attr, child) => execution.python.AttachDistributedSequenceExec(attr, planLater(child)) :: Nil case logical.MapElements(f, _, _, objAttr, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala similarity index 84% copy from sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala copy to sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala index 0434710..d25c138 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala @@ -20,31 +20,28 @@ package org.apache.spark.sql.execution.python import scala.collection.JavaConverters._ import org.apache.spark.{ContextAwareIterator, TaskContext} -import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.api.python.ChainedPythonFunctions import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ -import 
org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.UnaryExecNode import org.apache.spark.sql.types.{StructField, StructType} import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} /** - * A relation produced by applying a function that takes an iterator of pandas DataFrames - * and outputs an iterator of pandas DataFrames. + * A relation produced by applying a function that takes an iterator of batches + * such as pandas DataFrame or PyArrow's record batches, and outputs an iterator of them. * * This is somewhat similar with [[FlatMapGroupsInPandasExec]] and * `org.apache.spark.sql.catalyst.plans.logical.MapPartitionsInRWithArrow` - * */ -case class MapInPandasExec( - func: Expression, - output: Seq[Attribute], - child: SparkPlan) - extends UnaryExecNode { +trait MapInBatchExec extends UnaryExecNode { + protected val func: Expression + protected val pythonEvalType: Int - private val pandasFunction = func.asInstanceOf[PythonUDF].func + private val pythonFunction = func.asInstanceOf[PythonUDF].func override def producedAttributes: AttributeSet = AttributeSet(output) @@ -56,7 +53,7 @@ case class MapInPandasExec( child.execute().mapPartitionsInternal { inputIter => // Single function with one struct. val argOffsets = Array(Array(0)) - val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) + val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction))) val sessionLocalTimeZone = conf.sessionLocalTimeZone val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) val outputTypes = child.schema @@ -74,7 +71,7 @@ case class MapInPandasExec( val columnarBatchIter = new ArrowPythonRunner( chainedFunc, - PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, + pythonEvalType, argOffsets, StructType(StructField("struct", outputTypes) :: Nil), sessionLocalTimeZone, @@ -93,7 +90,4 @@ case class MapInPandasExec( }.map(unsafeProj) } } - - override protected def withNewChildInternal(newChild: SparkPlan): MapInPandasExec = - copy(child = newChild) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala index 0434710..7a711b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala @@ -17,82 +17,21 @@ package org.apache.spark.sql.execution.python -import scala.collection.JavaConverters._ - -import org.apache.spark.{ContextAwareIterator, TaskContext} -import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} -import org.apache.spark.sql.types.{StructField, StructType} -import org.apache.spark.sql.util.ArrowUtils -import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} +import org.apache.spark.sql.execution.SparkPlan /** * A relation produced by applying a function that takes an iterator of pandas DataFrames * and outputs an iterator of pandas DataFrames. 
- * - * This is somewhat similar with [[FlatMapGroupsInPandasExec]] and - * `org.apache.spark.sql.catalyst.plans.logical.MapPartitionsInRWithArrow` - * */ case class MapInPandasExec( func: Expression, output: Seq[Attribute], child: SparkPlan) - extends UnaryExecNode { - - private val pandasFunction = func.asInstanceOf[PythonUDF].func - - override def producedAttributes: AttributeSet = AttributeSet(output) - - private val batchSize = conf.arrowMaxRecordsPerBatch - - override def outputPartitioning: Partitioning = child.outputPartitioning - - override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitionsInternal { inputIter => - // Single function with one struct. - val argOffsets = Array(Array(0)) - val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) - val sessionLocalTimeZone = conf.sessionLocalTimeZone - val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) - val outputTypes = child.schema - - val context = TaskContext.get() - val contextAwareIterator = new ContextAwareIterator(context, inputIter) - - // Here we wrap it via another row so that Python sides understand it - // as a DataFrame. - val wrappedIter = contextAwareIterator.map(InternalRow(_)) - - // DO NOT use iter.grouped(). See BatchIterator. - val batchIter = - if (batchSize > 0) new BatchIterator(wrappedIter, batchSize) else Iterator(wrappedIter) - - val columnarBatchIter = new ArrowPythonRunner( - chainedFunc, - PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, - argOffsets, - StructType(StructField("struct", outputTypes) :: Nil), - sessionLocalTimeZone, - pythonRunnerConf).compute(batchIter, context.partitionId(), context) - - val unsafeProj = UnsafeProjection.create(output, output) + extends MapInBatchExec { - columnarBatchIter.flatMap { batch => - // Scalar Iterator UDF returns a StructType column in ColumnarBatch, select - // the children here - val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] - val outputVectors = output.indices.map(structVector.getChild) - val flattenedBatch = new ColumnarBatch(outputVectors.toArray) - flattenedBatch.setNumRows(batch.numRows()) - flattenedBatch.rowIterator.asScala - }.map(unsafeProj) - } - } + override protected val pythonEvalType: Int = PythonEvalType.SQL_MAP_PANDAS_ITER_UDF override protected def withNewChildInternal(newChild: SparkPlan): MapInPandasExec = copy(child = newChild) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonMapInArrowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonMapInArrowExec.scala new file mode 100644 index 0000000..e3c1853 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonMapInArrowExec.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.python + +import org.apache.spark.api.python.PythonEvalType +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.SparkPlan + +/** + * A relation produced by applying a function that takes an iterator of PyArrow's record batches + * and outputs an iterator of PyArrow's record batches. + */ +case class PythonMapInArrowExec( + func: Expression, + output: Seq[Attribute], + child: SparkPlan) + extends MapInBatchExec { + + override protected val pythonEvalType: Int = PythonEvalType.SQL_MAP_ARROW_ITER_UDF + + override protected def withNewChildInternal(newChild: SparkPlan): PythonMapInArrowExec = + copy(child = newChild) +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org