This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 775e05f  [SPARK-37228][SQL][PYTHON] Implement DataFrame.mapInArrow in Python
775e05f is described below

commit 775e05f2c3c31fc203cfe4b36df301555ce73ca4
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Mon Nov 15 08:51:23 2021 +0900

    [SPARK-37228][SQL][PYTHON] Implement DataFrame.mapInArrow in Python
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to implement `DataFrame.mapInArrow`, which allows users to
    apply a function that takes and yields PyArrow record batches, for example:
    
    ```python
    def do_something(iterator):
        for arrow_batch in iterator:
            # do something with the `pyarrow.RecordBatch` and create a new `pyarrow.RecordBatch`.
            # ...
            yield arrow_batch
    
    df.mapInArrow(do_something, df.schema).show()
    ```
    
    The general idea is simple: it shares the same codebase as `DataFrame.mapInPandas`
    except for the pandas conversion logic.
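
    To make this concrete, here is a minimal sketch (not part of the PR) of the same
    even-row filter written against both APIs, assuming an active `SparkSession` named
    `spark` with pandas and PyArrow installed:

    ```python
    import pyarrow as pa

    df = spark.range(10)  # single column "id" of type long

    def filter_in_pandas(iterator):
        # mapInPandas: each element is a pandas.DataFrame.
        for pdf in iterator:
            yield pdf[pdf.id % 2 == 0]

    def filter_in_arrow(iterator):
        # mapInArrow: each element is a pyarrow.RecordBatch.
        for batch in iterator:
            pdf = batch.to_pandas()
            yield pa.RecordBatch.from_pandas(pdf[pdf.id % 2 == 0], preserve_index=False)

    df.mapInPandas(filter_in_pandas, df.schema).show()
    df.mapInArrow(filter_in_arrow, df.schema).show()
    ```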
    
    This PR also piggy-backs:
    - Removes the check on `SQL_MAP_PANDAS_ITER_UDF` in `spark.udf.register`. This
      type is only used internally for `DataFrame.mapInPandas` and cannot be
      registered as a SQL UDF.
    - Removes the `pandas_udf` type hints that exist only for internal purposes such
      as `SQL_MAP_PANDAS_ITER_UDF` and `SQL_COGROUPED_MAP_PANDAS_UDF`. Neither can be
      used with `pandas_udf` as a SQL expression, so they should be hidden from end
      users.
    
    Note that documentation will be done in another PR.
    
    ### Why are the changes needed?
    
    For usability, and to address technical problems; both are elaborated in more
    detail in SPARK-37227. Please also see the discussion at
    https://github.com/apache/spark/pull/26783.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, this PR adds a new API:
    
    ```python
    import pyarrow as pa
    
    df = spark.createDataFrame(
        [(1, "foo"), (2, None), (3, "bar"), (4, "bar")], "a int, b string")
    
    def func(iterator):
        for batch in iterator:
            # `batch` is pyarrow.RecordBatch.
            yield batch
    
    df.mapInArrow(func, df.schema).collect()
    ```
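
    Because the function sees plain `pyarrow.RecordBatch`es, it can also stay entirely
    in Arrow instead of round-tripping through pandas. A hedged sketch (not part of the
    PR), reusing the `df` above and assuming `pyarrow.compute` is available:

    ```python
    import pyarrow.compute as pc

    def only_bar(iterator):
        for batch in iterator:
            # Column 1 is "b"; build a boolean mask and filter the batch in Arrow.
            mask = pc.equal(batch.column(1), "bar")
            yield pc.filter(batch, mask)

    df.mapInArrow(only_bar, df.schema).show()
    ```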
    
    ### How was this patch tested?
    
    Manually tested, and unit tests were added.
    
    Closes #34505 from HyukjinKwon/SPARK-37228.
    
    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../org/apache/spark/api/python/PythonRunner.scala |   2 +
 dev/sparktestsupport/modules.py                    |   1 +
 python/pyspark/rdd.py                              |   1 +
 python/pyspark/rdd.pyi                             |   2 +
 python/pyspark/sql/pandas/_typing/__init__.pyi     |   9 +-
 python/pyspark/sql/pandas/functions.py             |   8 +-
 python/pyspark/sql/pandas/functions.pyi            |  42 -------
 python/pyspark/sql/pandas/group_ops.py             |   4 +-
 python/pyspark/sql/pandas/map_ops.py               |  68 +++++++++-
 python/pyspark/sql/pandas/serializers.py           |  45 +++++++
 python/pyspark/sql/tests/test_arrow_map.py         | 138 +++++++++++++++++++++
 python/pyspark/sql/udf.py                          |  12 +-
 python/pyspark/worker.py                           |  30 +++--
 .../catalyst/analysis/DeduplicateRelations.scala   |   6 +
 .../plans/logical/pythonLogicalOperators.scala     |  15 +++
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  14 +++
 .../spark/sql/execution/SparkStrategies.scala      |   2 +
 ...{MapInPandasExec.scala => MapInBatchExec.scala} |  26 ++--
 .../sql/execution/python/MapInPandasExec.scala     |  69 +----------
 .../execution/python/PythonMapInArrowExec.scala    |  38 ++++++
 20 files changed, 385 insertions(+), 147 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index bbe55cb..6a4871b 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -52,6 +52,7 @@ private[spark] object PythonEvalType {
   val SQL_SCALAR_PANDAS_ITER_UDF = 204
   val SQL_MAP_PANDAS_ITER_UDF = 205
   val SQL_COGROUPED_MAP_PANDAS_UDF = 206
+  val SQL_MAP_ARROW_ITER_UDF = 207
 
   def toString(pythonEvalType: Int): String = pythonEvalType match {
     case NON_UDF => "NON_UDF"
@@ -63,6 +64,7 @@ private[spark] object PythonEvalType {
     case SQL_SCALAR_PANDAS_ITER_UDF => "SQL_SCALAR_PANDAS_ITER_UDF"
     case SQL_MAP_PANDAS_ITER_UDF => "SQL_MAP_PANDAS_ITER_UDF"
     case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF"
+    case SQL_MAP_ARROW_ITER_UDF => "SQL_MAP_ARROW_ITER_UDF"
   }
 }
 
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index b87218e..7d3ebb0 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -451,6 +451,7 @@ pyspark_sql = Module(
         "pyspark.sql.tests.test_pandas_cogrouped_map",
         "pyspark.sql.tests.test_pandas_grouped_map",
         "pyspark.sql.tests.test_pandas_map",
+        "pyspark.sql.tests.test_arrow_map",
         "pyspark.sql.tests.test_pandas_udf",
         "pyspark.sql.tests.test_pandas_udf_grouped_agg",
         "pyspark.sql.tests.test_pandas_udf_scalar",
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 7e63fca..b997932 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -90,6 +90,7 @@ class PythonEvalType(object):
     SQL_SCALAR_PANDAS_ITER_UDF = 204
     SQL_MAP_PANDAS_ITER_UDF = 205
     SQL_COGROUPED_MAP_PANDAS_UDF = 206
+    SQL_MAP_ARROW_ITER_UDF = 207
 
 
 def portable_hash(x):
diff --git a/python/pyspark/rdd.pyi b/python/pyspark/rdd.pyi
index 97d36b9..37ba4c2 100644
--- a/python/pyspark/rdd.pyi
+++ b/python/pyspark/rdd.pyi
@@ -43,6 +43,7 @@ from pyspark.sql.pandas._typing import (
     PandasCogroupedMapUDFType,
     PandasGroupedAggUDFType,
     PandasMapIterUDFType,
+    ArrowMapIterUDFType,
 )
 import pyspark.context
 from pyspark.resultiterable import ResultIterable
@@ -83,6 +84,7 @@ class PythonEvalType:
     SQL_SCALAR_PANDAS_ITER_UDF: PandasScalarIterUDFType
     SQL_MAP_PANDAS_ITER_UDF: PandasMapIterUDFType
     SQL_COGROUPED_MAP_PANDAS_UDF: PandasCogroupedMapUDFType
+    SQL_MAP_ARROW_ITER_UDF: ArrowMapIterUDFType
 
 class BoundedFloat(float):
     def __new__(cls, mean: float, confidence: float, low: float, high: float) 
-> BoundedFloat: ...
diff --git a/python/pyspark/sql/pandas/_typing/__init__.pyi 
b/python/pyspark/sql/pandas/_typing/__init__.pyi
index 6e9b97b1..305c322 100644
--- a/python/pyspark/sql/pandas/_typing/__init__.pyi
+++ b/python/pyspark/sql/pandas/_typing/__init__.pyi
@@ -35,6 +35,8 @@ from pyspark.sql.pandas._typing.protocols.series import 
SeriesLike as SeriesLike
 import pandas.core.frame  # type: ignore[import]
 import pandas.core.series  # type: ignore[import]
 
+import pyarrow  # type: ignore[import]
+
 # POC compatibility annotations
 PandasDataFrame: Type[DataFrameLike] = pandas.core.frame.DataFrame
 PandasSeries: Type[SeriesLike] = pandas.core.series.Series
@@ -48,6 +50,7 @@ PandasGroupedMapUDFType = Literal[201]
 PandasCogroupedMapUDFType = Literal[206]
 PandasGroupedAggUDFType = Literal[202]
 PandasMapIterUDFType = Literal[205]
+ArrowMapIterUDFType = Literal[207]
 
 class PandasVariadicScalarToScalarFunction(Protocol):
     def __call__(self, *_: DataFrameOrSeriesLike) -> SeriesLike: ...
@@ -325,10 +328,8 @@ PandasGroupedAggFunction = Union[
 
 PandasMapIterFunction = Callable[[Iterable[DataFrameLike]], 
Iterable[DataFrameLike]]
 
+ArrowMapIterFunction = Callable[[Iterable[pyarrow.RecordBatch]], 
Iterable[pyarrow.RecordBatch]]
+
 PandasCogroupedMapFunction = Callable[[DataFrameLike, DataFrameLike], 
DataFrameLike]
 
-MapIterPandasUserDefinedFunction = NewType("MapIterPandasUserDefinedFunction", 
FunctionType)
 GroupedMapPandasUserDefinedFunction = 
NewType("GroupedMapPandasUserDefinedFunction", FunctionType)
-CogroupedMapPandasUserDefinedFunction = NewType(
-    "CogroupedMapPandasUserDefinedFunction", FunctionType
-)
diff --git a/python/pyspark/sql/pandas/functions.py 
b/python/pyspark/sql/pandas/functions.py
index 97a563499..ca3263b 100644
--- a/python/pyspark/sql/pandas/functions.py
+++ b/python/pyspark/sql/pandas/functions.py
@@ -366,6 +366,7 @@ def pandas_udf(f=None, returnType=None, functionType=None):
         PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
         PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
         PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
+        PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
         PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
         None,
     ]:  # None means it should infer the type from type hints.
@@ -400,12 +401,13 @@ def _create_pandas_udf(f, returnType, evalType):
     elif evalType in [
         PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
         PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
+        PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
         PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
     ]:
-        # In case of 'SQL_GROUPED_MAP_PANDAS_UDF',  deprecation warning is 
being triggered
+        # In case of 'SQL_GROUPED_MAP_PANDAS_UDF', deprecation warning is 
being triggered
         # at `apply` instead.
-        # In case of 'SQL_MAP_PANDAS_ITER_UDF' and 
'SQL_COGROUPED_MAP_PANDAS_UDF', the
-        # evaluation type will always be set.
+        # In case of 'SQL_MAP_PANDAS_ITER_UDF', 'SQL_MAP_ARROW_ITER_UDF' and
+        # 'SQL_COGROUPED_MAP_PANDAS_UDF', the evaluation type will always be 
set.
         pass
     elif len(argspec.annotations) > 0:
         evalType = infer_eval_type(signature(f))
diff --git a/python/pyspark/sql/pandas/functions.pyi 
b/python/pyspark/sql/pandas/functions.pyi
index 10bea30..7ff06be 100644
--- a/python/pyspark/sql/pandas/functions.pyi
+++ b/python/pyspark/sql/pandas/functions.pyi
@@ -25,16 +25,10 @@ from pyspark.sql._typing import (
 )
 from pyspark.sql.pandas._typing import (
     GroupedMapPandasUserDefinedFunction,
-    MapIterPandasUserDefinedFunction,
-    CogroupedMapPandasUserDefinedFunction,
-    PandasCogroupedMapFunction,
-    PandasCogroupedMapUDFType,
     PandasGroupedAggFunction,
     PandasGroupedAggUDFType,
     PandasGroupedMapFunction,
     PandasGroupedMapUDFType,
-    PandasMapIterFunction,
-    PandasMapIterUDFType,
     PandasScalarIterFunction,
     PandasScalarIterUDFType,
     PandasScalarToScalarFunction,
@@ -130,39 +124,3 @@ def pandas_udf(
 def pandas_udf(
     f: Union[AtomicDataTypeOrString, ArrayType], *, functionType: 
PandasGroupedAggUDFType
 ) -> Callable[[PandasGroupedAggFunction], UserDefinedFunctionLike]: ...
-@overload
-def pandas_udf(
-    f: PandasMapIterFunction,
-    returnType: Union[StructType, str],
-    functionType: PandasMapIterUDFType,
-) -> MapIterPandasUserDefinedFunction: ...
-@overload
-def pandas_udf(
-    f: Union[StructType, str], returnType: PandasMapIterUDFType
-) -> Callable[[PandasMapIterFunction], MapIterPandasUserDefinedFunction]: ...
-@overload
-def pandas_udf(
-    *, returnType: Union[StructType, str], functionType: PandasMapIterUDFType
-) -> Callable[[PandasMapIterFunction], MapIterPandasUserDefinedFunction]: ...
-@overload
-def pandas_udf(
-    f: Union[StructType, str], *, functionType: PandasMapIterUDFType
-) -> Callable[[PandasMapIterFunction], MapIterPandasUserDefinedFunction]: ...
-@overload
-def pandas_udf(
-    f: PandasCogroupedMapFunction,
-    returnType: Union[StructType, str],
-    functionType: PandasCogroupedMapUDFType,
-) -> CogroupedMapPandasUserDefinedFunction: ...
-@overload
-def pandas_udf(
-    f: Union[StructType, str], returnType: PandasCogroupedMapUDFType
-) -> Callable[[PandasCogroupedMapFunction], 
CogroupedMapPandasUserDefinedFunction]: ...
-@overload
-def pandas_udf(
-    *, returnType: Union[StructType, str], functionType: 
PandasCogroupedMapUDFType
-) -> Callable[[PandasCogroupedMapFunction], 
CogroupedMapPandasUserDefinedFunction]: ...
-@overload
-def pandas_udf(
-    f: Union[StructType, str], *, functionType: PandasCogroupedMapUDFType
-) -> Callable[[PandasCogroupedMapFunction], 
CogroupedMapPandasUserDefinedFunction]: ...
diff --git a/python/pyspark/sql/pandas/group_ops.py 
b/python/pyspark/sql/pandas/group_ops.py
index 0af313b..593b72d 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -347,9 +347,11 @@ class PandasCogroupedOps(object):
         """
         from pyspark.sql.pandas.functions import pandas_udf
 
+        # The usage of the pandas_udf is internal so type checking is disabled.
         udf = pandas_udf(
             func, returnType=schema, 
functionType=PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF
-        )
+        )  # type: ignore[call-overload]
+
         all_cols = self._extract_cols(self._gd1) + 
self._extract_cols(self._gd2)
         udf_column = udf(*all_cols)
         jdf = self._gd1._jgd.flatMapCoGroupsInPandas(  # type: 
ignore[attr-defined]
diff --git a/python/pyspark/sql/pandas/map_ops.py 
b/python/pyspark/sql/pandas/map_ops.py
index f3bcec9..ce1480f 100644
--- a/python/pyspark/sql/pandas/map_ops.py
+++ b/python/pyspark/sql/pandas/map_ops.py
@@ -22,7 +22,7 @@ from pyspark.sql.types import StructType
 
 if TYPE_CHECKING:
     from pyspark.sql.dataframe import DataFrame
-    from pyspark.sql.pandas._typing import PandasMapIterFunction
+    from pyspark.sql.pandas._typing import PandasMapIterFunction, 
ArrowMapIterFunction
 
 
 class PandasMapOpsMixin(object):
@@ -84,13 +84,77 @@ class PandasMapOpsMixin(object):
 
         assert isinstance(self, DataFrame)
 
+        # The usage of the pandas_udf is internal so type checking is disabled.
         udf = pandas_udf(
             func, returnType=schema, 
functionType=PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
-        )
+        )  # type: ignore[call-overload]
         udf_column = udf(*[self[col] for col in self.columns])
         jdf = self._jdf.mapInPandas(udf_column._jc.expr())  # type: 
ignore[operator]
         return DataFrame(jdf, self.sql_ctx)
 
+    def mapInArrow(
+        self, func: "ArrowMapIterFunction", schema: Union[StructType, str]
+    ) -> "DataFrame":
+        """
+        Maps an iterator of batches in the current :class:`DataFrame` using a 
Python native
+        function that takes and outputs a PyArrow's `RecordBatch`, and returns 
the result as a
+        :class:`DataFrame`.
+
+        The function should take an iterator of `pyarrow.RecordBatch`\\s and 
return
+        another iterator of `pyarrow.RecordBatch`\\s. All columns are passed
+        together as an iterator of `pyarrow.RecordBatch`\\s to the function 
and the
+        returned iterator of `pyarrow.RecordBatch`\\s are combined as a 
:class:`DataFrame`.
+        Each `pyarrow.RecordBatch` size can be controlled by
+        `spark.sql.execution.arrow.maxRecordsPerBatch`.
+
+        .. versionadded:: 3.3.0
+
+        Parameters
+        ----------
+        func : function
+            a Python native function that takes an iterator of 
`pyarrow.RecordBatch`\\s, and
+            outputs an iterator of `pyarrow.RecordBatch`\\s.
+        schema : :class:`pyspark.sql.types.DataType` or str
+            the return type of the `func` in PySpark. The value can be either a
+            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type 
string.
+
+        Examples
+        --------
+        >>> import pyarrow  # doctest: +SKIP
+        >>> df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
+        >>> def filter_func(iterator):
+        ...     for batch in iterator:
+        ...         pdf = batch.to_pandas()
+        ...         yield pyarrow.RecordBatch.from_pandas(pdf[pdf.id == 1])
+        >>> df.mapInArrow(filter_func, df.schema).show()  # doctest: +SKIP
+        +---+---+
+        | id|age|
+        +---+---+
+        |  1| 21|
+        +---+---+
+
+        Notes
+        -----
+        This API is unstable, and for developers.
+
+        See Also
+        --------
+        pyspark.sql.functions.pandas_udf
+        pyspark.sql.DataFrame.mapInPandas
+        """
+        from pyspark.sql import DataFrame
+        from pyspark.sql.pandas.functions import pandas_udf
+
+        assert isinstance(self, DataFrame)
+
+        # The usage of the pandas_udf is internal so type checking is disabled.
+        udf = pandas_udf(
+            func, returnType=schema, 
functionType=PythonEvalType.SQL_MAP_ARROW_ITER_UDF
+        )  # type: ignore[call-overload]
+        udf_column = udf(*[self[col] for col in self.columns])
+        jdf = self._jdf.pythonMapInArrow(udf_column._jc.expr())
+        return DataFrame(jdf, self.sql_ctx)
+
 
 def _test() -> None:
     import doctest
diff --git a/python/pyspark/sql/pandas/serializers.py 
b/python/pyspark/sql/pandas/serializers.py
index 16f5b3d..44276a4 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -100,6 +100,51 @@ class ArrowStreamSerializer(Serializer):
         return "ArrowStreamSerializer"
 
 
+class ArrowStreamUDFSerializer(ArrowStreamSerializer):
+    """
+    Same as :class:`ArrowStreamSerializer` but it flattens the struct to Arrow 
record batch
+    for applying each function with the raw record arrow batch. See also 
`DataFrame.mapInArrow`.
+    """
+
+    def load_stream(self, stream):
+        """
+        Flatten the struct into Arrow's record batches.
+        """
+        import pyarrow as pa
+
+        batches = super(ArrowStreamUDFSerializer, self).load_stream(stream)
+        for batch in batches:
+            struct = batch.column(0)
+            yield [pa.RecordBatch.from_arrays(struct.flatten(), 
schema=pa.schema(struct.type))]
+
+    def dump_stream(self, iterator, stream):
+        """
+        Override because Pandas UDFs require a START_ARROW_STREAM before the 
Arrow stream is sent.
+        This should be sent after creating the first record batch so in case 
of an error, it can
+        be sent back to the JVM before the Arrow stream starts.
+        """
+        import pyarrow as pa
+
+        def wrap_and_init_stream():
+            should_write_start_length = True
+            for batch, _ in iterator:
+                assert isinstance(batch, pa.RecordBatch)
+
+                # Wrap the root struct
+                struct = pa.StructArray.from_arrays(
+                    batch.columns, fields=pa.struct(list(batch.schema))
+                )
+                batch = pa.RecordBatch.from_arrays([struct], ["_0"])
+
+                # Write the first record batch with initialization.
+                if should_write_start_length:
+                    write_int(SpecialLengths.START_ARROW_STREAM, stream)
+                    should_write_start_length = False
+                yield batch
+
+        return super(ArrowStreamUDFSerializer, 
self).dump_stream(wrap_and_init_stream(), stream)
+
+
 class ArrowStreamPandasSerializer(ArrowStreamSerializer):
     """
     Serializes Pandas.Series as Arrow data with Arrow streaming format.
diff --git a/python/pyspark/sql/tests/test_arrow_map.py 
b/python/pyspark/sql/tests/test_arrow_map.py
new file mode 100644
index 0000000..a4c948f
--- /dev/null
+++ b/python/pyspark/sql/tests/test_arrow_map.py
@@ -0,0 +1,138 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+import time
+import unittest
+
+from pyspark.testing.sqlutils import (
+    ReusedSQLTestCase,
+    have_pandas,
+    have_pyarrow,
+    pandas_requirement_message,
+    pyarrow_requirement_message,
+)
+
+if have_pyarrow:
+    import pyarrow as pa
+
+if have_pandas:
+    import pandas as pd
+
+
+@unittest.skipIf(
+    not have_pandas or not have_pyarrow,
+    pandas_requirement_message or pyarrow_requirement_message,  # type: 
ignore[arg-type]
+)
+class MapInArrowTests(ReusedSQLTestCase):
+    @classmethod
+    def setUpClass(cls):
+        ReusedSQLTestCase.setUpClass()
+
+        # Synchronize default timezone between Python and Java
+        cls.tz_prev = os.environ.get("TZ", None)  # save current tz if set
+        tz = "America/Los_Angeles"
+        os.environ["TZ"] = tz
+        time.tzset()
+
+        cls.sc.environment["TZ"] = tz
+        cls.spark.conf.set("spark.sql.session.timeZone", tz)
+
+    @classmethod
+    def tearDownClass(cls):
+        del os.environ["TZ"]
+        if cls.tz_prev is not None:
+            os.environ["TZ"] = cls.tz_prev
+        time.tzset()
+        ReusedSQLTestCase.tearDownClass()
+
+    def test_map_in_arrow(self):
+        def func(iterator):
+            for batch in iterator:
+                assert isinstance(batch, pa.RecordBatch)
+                assert batch.schema.names == ["id"]
+                yield batch
+
+        df = self.spark.range(10)
+        actual = df.mapInArrow(func, "id long").collect()
+        expected = df.collect()
+        self.assertEqual(actual, expected)
+
+    def test_multiple_columns(self):
+        data = [(1, "foo"), (2, None), (3, "bar"), (4, "bar")]
+        df = self.spark.createDataFrame(data, "a int, b string")
+
+        def func(iterator):
+            for batch in iterator:
+                assert isinstance(batch, pa.RecordBatch)
+                assert batch.schema.types == [pa.int32(), pa.string()]
+                yield batch
+
+        actual = df.mapInArrow(func, df.schema).collect()
+        expected = df.collect()
+        self.assertEqual(actual, expected)
+
+    def test_different_output_length(self):
+        def func(iterator):
+            for _ in iterator:
+                yield pa.RecordBatch.from_pandas(pd.DataFrame({"a": 
list(range(100))}))
+
+        df = self.spark.range(10)
+        actual = df.repartition(1).mapInArrow(func, "a long").collect()
+        self.assertEqual(set((r.a for r in actual)), set(range(100)))
+
+    def test_empty_iterator(self):
+        def empty_iter(_):
+            return iter([])
+
+        self.assertEqual(self.spark.range(10).mapInArrow(empty_iter, "a int, b 
string").count(), 0)
+
+    def test_empty_rows(self):
+        def empty_rows(_):
+            return iter([pa.RecordBatch.from_pandas(pd.DataFrame({"a": []}))])
+
+        self.assertEqual(self.spark.range(10).mapInArrow(empty_rows, "a 
int").count(), 0)
+
+    def test_chain_map_in_arrow(self):
+        def func(iterator):
+            for batch in iterator:
+                assert isinstance(batch, pa.RecordBatch)
+                assert batch.schema.names == ["id"]
+                yield batch
+
+        df = self.spark.range(10)
+        actual = df.mapInArrow(func, "id long").mapInArrow(func, "id 
long").collect()
+        expected = df.collect()
+        self.assertEqual(actual, expected)
+
+    def test_self_join(self):
+        df1 = self.spark.range(10)
+        df2 = df1.mapInArrow(lambda iter: iter, "id long")
+        actual = df2.join(df2).collect()
+        expected = df1.join(df1).collect()
+        self.assertEqual(sorted(actual), sorted(expected))
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.test_arrow_map import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index aa74795..886451a 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -157,7 +157,10 @@ class UserDefinedFunction(object):
                     "UDFs or at groupby.applyInPandas: return type must be a "
                     "StructType."
                 )
-        elif self.evalType == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF:
+        elif (
+            self.evalType == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
+            or self.evalType == PythonEvalType.SQL_MAP_ARROW_ITER_UDF
+        ):
             if isinstance(self._returnType_placeholder, StructType):
                 try:
                     to_arrow_type(self._returnType_placeholder)
@@ -168,7 +171,8 @@ class UserDefinedFunction(object):
                     )
             else:
                 raise TypeError(
-                    "Invalid return type in mapInPandas: " "return type must 
be a StructType."
+                    "Invalid return type in mapInPandas/mapInArrow: "
+                    "return type must be a StructType."
                 )
         elif self.evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
             if isinstance(self._returnType_placeholder, StructType):
@@ -405,12 +409,10 @@ class UDFRegistration(object):
                 PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                 PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
                 PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
-                PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
             ]:
                 raise ValueError(
                     "Invalid f: f must be SQL_BATCHED_UDF, 
SQL_SCALAR_PANDAS_UDF, "
-                    "SQL_SCALAR_PANDAS_ITER_UDF, SQL_GROUPED_AGG_PANDAS_UDF or 
"
-                    "SQL_MAP_PANDAS_ITER_UDF."
+                    "SQL_SCALAR_PANDAS_ITER_UDF or SQL_GROUPED_AGG_PANDAS_UDF."
                 )
             register_udf = _create_udf(
                 f.func,
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 5a07582..c2200b2 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -53,7 +53,11 @@ from pyspark.serializers import (
     PickleSerializer,
     BatchedSerializer,
 )
-from pyspark.sql.pandas.serializers import ArrowStreamPandasUDFSerializer, 
CogroupUDFSerializer
+from pyspark.sql.pandas.serializers import (
+    ArrowStreamPandasUDFSerializer,
+    CogroupUDFSerializer,
+    ArrowStreamUDFSerializer,
+)
 from pyspark.sql.pandas.types import to_arrow_type
 from pyspark.sql.types import StructType
 from pyspark.util import fail_on_stopiteration, try_simplify_traceback  # 
type: ignore
@@ -123,7 +127,7 @@ def wrap_scalar_pandas_udf(f, return_type):
     )
 
 
-def wrap_pandas_iter_udf(f, return_type):
+def wrap_batch_iter_udf(f, return_type):
     arrow_return_type = to_arrow_type(return_type)
 
     def verify_result_type(result):
@@ -291,9 +295,11 @@ def read_single_udf(pickleSer, infile, eval_type, 
runner_conf, udf_index):
     if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
         return arg_offsets, wrap_scalar_pandas_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
-        return arg_offsets, wrap_pandas_iter_udf(func, return_type)
+        return arg_offsets, wrap_batch_iter_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF:
-        return arg_offsets, wrap_pandas_iter_udf(func, return_type)
+        return arg_offsets, wrap_batch_iter_udf(func, return_type)
+    elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF:
+        return arg_offsets, wrap_batch_iter_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
         argspec = getfullargspec(chained_func)  # signature was lost when 
wrapping it
         return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, 
argspec)
@@ -318,6 +324,7 @@ def read_udfs(pickleSer, infile, eval_type):
         PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
         PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
         PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
+        PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
         PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
         PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
         PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
@@ -346,6 +353,8 @@ def read_udfs(pickleSer, infile, eval_type):
 
         if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
             ser = CogroupUDFSerializer(timezone, safecheck, 
assign_cols_by_name)
+        elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF:
+            ser = ArrowStreamUDFSerializer()
         else:
             # Scalar Pandas UDF handles struct type arguments as pandas 
DataFrames instead of
             # pandas Series. See SPARK-27240.
@@ -363,13 +372,16 @@ def read_udfs(pickleSer, infile, eval_type):
     num_udfs = read_int(infile)
 
     is_scalar_iter = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
-    is_map_iter = eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
+    is_map_pandas_iter = eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
+    is_map_arrow_iter = eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF
 
-    if is_scalar_iter or is_map_iter:
+    if is_scalar_iter or is_map_pandas_iter or is_map_arrow_iter:
         if is_scalar_iter:
             assert num_udfs == 1, "One SCALAR_ITER UDF expected here."
-        if is_map_iter:
-            assert num_udfs == 1, "One MAP_ITER UDF expected here."
+        if is_map_pandas_iter:
+            assert num_udfs == 1, "One MAP_PANDAS_ITER UDF expected here."
+        if is_map_arrow_iter:
+            assert num_udfs == 1, "One MAP_ARROW_ITER UDF expected here."
 
         arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, 
runner_conf, udf_index=0)
 
@@ -398,7 +410,7 @@ def read_udfs(pickleSer, infile, eval_type):
                 # it's very unlikely the output length is higher than
                 # input length.
                 assert (
-                    is_map_iter or num_output_rows <= num_input_rows
+                    is_map_pandas_iter or is_map_arrow_iter or num_output_rows 
<= num_input_rows
                 ), "Pandas SCALAR_ITER UDF outputted more rows than input 
rows."
                 yield (result_batch, result_type)
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
index 4ff1837..55b1c22 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
@@ -241,6 +241,12 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
         newVersion.copyTagsFrom(oldVersion)
         Seq((oldVersion, newVersion))
 
+      case oldVersion @ PythonMapInArrow(_, output, _)
+        if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
+        val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
+        newVersion.copyTagsFrom(oldVersion)
+        Seq((oldVersion, newVersion))
+
       case oldVersion @ AttachDistributedSequence(sequenceAttr, _)
         if 
oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
         val newVersion = oldVersion.copy(sequenceAttr = 
sequenceAttr.newInstance())
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
index af18540..13a40db 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
@@ -58,6 +58,21 @@ case class MapInPandas(
 }
 
 /**
+ * Map partitions using a udf: iter(pyarrow.RecordBatch) -> 
iter(pyarrow.RecordBatch).
+ * This is used by DataFrame.mapInArrow() in PySpark
+ */
+case class PythonMapInArrow(
+    functionExpr: Expression,
+    output: Seq[Attribute],
+    child: LogicalPlan) extends UnaryNode {
+
+  override val producedAttributes = AttributeSet(output)
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): 
PythonMapInArrow =
+    copy(child = newChild)
+}
+
+/**
  * Flatmap cogroups using a udf: pandas.Dataframe, pandas.Dataframe -> 
pandas.Dataframe
  * This is used by DataFrame.groupby().cogroup().apply().
  */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index c8cdc20..da3eea2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2964,6 +2964,20 @@ class Dataset[T] private[sql](
   }
 
   /**
+   * Applies a function to each partition in Arrow format. The user-defined 
function
+   * defines a transformation: `iter(pyarrow.RecordBatch)` -> 
`iter(pyarrow.RecordBatch)`.
+   * Each partition is each iterator consisting of `pyarrow.RecordBatch`s as 
batches.
+   */
+  private[sql] def pythonMapInArrow(func: PythonUDF): DataFrame = {
+    Dataset.ofRows(
+      sparkSession,
+      PythonMapInArrow(
+        func,
+        func.dataType.asInstanceOf[StructType].toAttributes,
+        logicalPlan))
+  }
+
+  /**
    * (Scala-specific)
    * Returns a new Dataset by first applying a function to all elements of 
this Dataset,
    * and then flattening the results.
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index cdecda4..90c2507 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -755,6 +755,8 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
           func, output, planLater(left), planLater(right)) :: Nil
       case logical.MapInPandas(func, output, child) =>
         execution.python.MapInPandasExec(func, output, planLater(child)) :: Nil
+      case logical.PythonMapInArrow(func, output, child) =>
+        execution.python.PythonMapInArrowExec(func, output, planLater(child)) 
:: Nil
       case logical.AttachDistributedSequence(attr, child) =>
         execution.python.AttachDistributedSequenceExec(attr, planLater(child)) 
:: Nil
       case logical.MapElements(f, _, _, objAttr, child) =>
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
similarity index 84%
copy from 
sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala
copy to 
sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
index 0434710..d25c138 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
@@ -20,31 +20,28 @@ package org.apache.spark.sql.execution.python
 import scala.collection.JavaConverters._
 
 import org.apache.spark.{ContextAwareIterator, TaskContext}
-import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
+import org.apache.spark.api.python.ChainedPythonFunctions
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.UnaryExecNode
 import org.apache.spark.sql.types.{StructField, StructType}
 import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
 
 /**
- * A relation produced by applying a function that takes an iterator of pandas 
DataFrames
- * and outputs an iterator of pandas DataFrames.
+ * A relation produced by applying a function that takes an iterator of batches
+ * such as pandas DataFrame or PyArrow's record batches, and outputs an 
iterator of them.
  *
  * This is somewhat similar with [[FlatMapGroupsInPandasExec]] and
  * `org.apache.spark.sql.catalyst.plans.logical.MapPartitionsInRWithArrow`
- *
  */
-case class MapInPandasExec(
-    func: Expression,
-    output: Seq[Attribute],
-    child: SparkPlan)
-  extends UnaryExecNode {
+trait MapInBatchExec extends UnaryExecNode {
+  protected val func: Expression
+  protected val pythonEvalType: Int
 
-  private val pandasFunction = func.asInstanceOf[PythonUDF].func
+  private val pythonFunction = func.asInstanceOf[PythonUDF].func
 
   override def producedAttributes: AttributeSet = AttributeSet(output)
 
@@ -56,7 +53,7 @@ case class MapInPandasExec(
     child.execute().mapPartitionsInternal { inputIter =>
       // Single function with one struct.
       val argOffsets = Array(Array(0))
-      val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
+      val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction)))
       val sessionLocalTimeZone = conf.sessionLocalTimeZone
       val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
       val outputTypes = child.schema
@@ -74,7 +71,7 @@ case class MapInPandasExec(
 
       val columnarBatchIter = new ArrowPythonRunner(
         chainedFunc,
-        PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
+        pythonEvalType,
         argOffsets,
         StructType(StructField("struct", outputTypes) :: Nil),
         sessionLocalTimeZone,
@@ -93,7 +90,4 @@ case class MapInPandasExec(
       }.map(unsafeProj)
     }
   }
-
-  override protected def withNewChildInternal(newChild: SparkPlan): 
MapInPandasExec =
-    copy(child = newChild)
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala
index 0434710..7a711b5 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInPandasExec.scala
@@ -17,82 +17,21 @@
 
 package org.apache.spark.sql.execution.python
 
-import scala.collection.JavaConverters._
-
-import org.apache.spark.{ContextAwareIterator, TaskContext}
-import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.api.python.PythonEvalType
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
-import org.apache.spark.sql.types.{StructField, StructType}
-import org.apache.spark.sql.util.ArrowUtils
-import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
+import org.apache.spark.sql.execution.SparkPlan
 
 /**
  * A relation produced by applying a function that takes an iterator of pandas 
DataFrames
  * and outputs an iterator of pandas DataFrames.
- *
- * This is somewhat similar with [[FlatMapGroupsInPandasExec]] and
- * `org.apache.spark.sql.catalyst.plans.logical.MapPartitionsInRWithArrow`
- *
  */
 case class MapInPandasExec(
     func: Expression,
     output: Seq[Attribute],
     child: SparkPlan)
-  extends UnaryExecNode {
-
-  private val pandasFunction = func.asInstanceOf[PythonUDF].func
-
-  override def producedAttributes: AttributeSet = AttributeSet(output)
-
-  private val batchSize = conf.arrowMaxRecordsPerBatch
-
-  override def outputPartitioning: Partitioning = child.outputPartitioning
-
-  override protected def doExecute(): RDD[InternalRow] = {
-    child.execute().mapPartitionsInternal { inputIter =>
-      // Single function with one struct.
-      val argOffsets = Array(Array(0))
-      val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
-      val sessionLocalTimeZone = conf.sessionLocalTimeZone
-      val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
-      val outputTypes = child.schema
-
-      val context = TaskContext.get()
-      val contextAwareIterator = new ContextAwareIterator(context, inputIter)
-
-      // Here we wrap it via another row so that Python sides understand it
-      // as a DataFrame.
-      val wrappedIter = contextAwareIterator.map(InternalRow(_))
-
-      // DO NOT use iter.grouped(). See BatchIterator.
-      val batchIter =
-        if (batchSize > 0) new BatchIterator(wrappedIter, batchSize) else 
Iterator(wrappedIter)
-
-      val columnarBatchIter = new ArrowPythonRunner(
-        chainedFunc,
-        PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
-        argOffsets,
-        StructType(StructField("struct", outputTypes) :: Nil),
-        sessionLocalTimeZone,
-        pythonRunnerConf).compute(batchIter, context.partitionId(), context)
-
-      val unsafeProj = UnsafeProjection.create(output, output)
+  extends MapInBatchExec {
 
-      columnarBatchIter.flatMap { batch =>
-        // Scalar Iterator UDF returns a StructType column in ColumnarBatch, 
select
-        // the children here
-        val structVector = batch.column(0).asInstanceOf[ArrowColumnVector]
-        val outputVectors = output.indices.map(structVector.getChild)
-        val flattenedBatch = new ColumnarBatch(outputVectors.toArray)
-        flattenedBatch.setNumRows(batch.numRows())
-        flattenedBatch.rowIterator.asScala
-      }.map(unsafeProj)
-    }
-  }
+  override protected val pythonEvalType: Int = 
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
 
   override protected def withNewChildInternal(newChild: SparkPlan): 
MapInPandasExec =
     copy(child = newChild)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonMapInArrowExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonMapInArrowExec.scala
new file mode 100644
index 0000000..e3c1853
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonMapInArrowExec.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.SparkPlan
+
+/**
+ * A relation produced by applying a function that takes an iterator of 
PyArrow's record batches
+ * and outputs an iterator of PyArrow's record batches.
+ */
+case class PythonMapInArrowExec(
+    func: Expression,
+    output: Seq[Attribute],
+    child: SparkPlan)
+  extends MapInBatchExec {
+
+  override protected val pythonEvalType: Int = 
PythonEvalType.SQL_MAP_ARROW_ITER_UDF
+
+  override protected def withNewChildInternal(newChild: SparkPlan): 
PythonMapInArrowExec =
+    copy(child = newChild)
+}

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
