This is an automated email from the ASF dual-hosted git repository.

xinrong pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 000895da3f6 [SPARK-42510][CONNECT][PYTHON] Implement `DataFrame.mapInPandas`
000895da3f6 is described below

commit 000895da3f6c0d17ccfdfe79c0ca34dfb9fb6e7b
Author: Xinrong Meng <xinr...@apache.org>
AuthorDate: Sat Feb 25 07:39:54 2023 +0800

    [SPARK-42510][CONNECT][PYTHON] Implement `DataFrame.mapInPandas`
    
    ### What changes were proposed in this pull request?
    Implement `DataFrame.mapInPandas` and enable the parity tests against vanilla PySpark.
    
    A proto message `FrameMap` is introduced for `mapInPandas` and `mapInArrow` (to be implemented next).
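    
    For reference, the new message (as added to `relations.proto` in this patch) simply pairs the input relation with the user-defined function:
    
    ```proto
    message FrameMap {
      // (Required) Input relation for a Frame Map API: mapInPandas, mapInArrow.
      Relation input = 1;
    
      // (Required) Input user-defined function of a Frame Map API.
      CommonInlineUserDefinedFunction func = 2;
    }
    ```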
    
    ### Why are the changes needed?
    To reach parity with vanilla PySpark.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. `DataFrame.mapInPandas` is now supported. An example is shown below.
    
    ```py
    >>> df = spark.range(2)
    >>> def filter_func(iterator):
    ...   for pdf in iterator:
    ...     yield pdf[pdf.id == 1]
    ...
    >>> df.mapInPandas(filter_func, df.schema)
    DataFrame[id: bigint]
    >>> df.mapInPandas(filter_func, df.schema).show()
    +---+
    | id|
    +---+
    |  1|
    +---+
    ```
    
    ### How was this patch tested?
    Unit tests.
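    
    For example, the new parity module can be run on its own (assuming a working PySpark development environment with the Spark Connect dependencies installed; the file runs `unittest` via its `__main__` block):
    
    ```
    python python/pyspark/sql/tests/connect/test_parity_pandas_map.py
    ```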
    
    Closes #40104 from xinrong-meng/mapInPandas.
    
    Lead-authored-by: Xinrong Meng <xinr...@apache.org>
    Co-authored-by: Xinrong Meng <xinr...@apache.org>
    Signed-off-by: Xinrong Meng <xinr...@apache.org>
    (cherry picked from commit 9abccad1d93a243d7e47e53dcbc85568a460c529)
    Signed-off-by: Xinrong Meng <xinr...@apache.org>
---
 .../main/protobuf/spark/connect/relations.proto    |  10 +
 .../sql/connect/planner/SparkConnectPlanner.scala  |  18 +-
 dev/sparktestsupport/modules.py                    |   1 +
 python/pyspark/sql/connect/_typing.py              |   8 +-
 python/pyspark/sql/connect/client.py               |   2 +-
 python/pyspark/sql/connect/dataframe.py            |  22 +-
 python/pyspark/sql/connect/expressions.py          |   6 +-
 python/pyspark/sql/connect/plan.py                 |  25 ++-
 python/pyspark/sql/connect/proto/relations_pb2.py  | 222 +++++++++++----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi |  36 ++++
 python/pyspark/sql/connect/types.py                |   4 +-
 python/pyspark/sql/connect/udf.py                  |  20 +-
 python/pyspark/sql/pandas/map_ops.py               |   3 +
 .../sql/tests/connect/test_parity_pandas_map.py    |  50 +++++
 python/pyspark/sql/tests/pandas/test_pandas_map.py |  46 +++--
 15 files changed, 331 insertions(+), 142 deletions(-)

diff --git 
a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto 
b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 29fffd65c75..4d96b6b0c7e 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -60,6 +60,7 @@ message Relation {
     Unpivot unpivot = 25;
     ToSchema to_schema = 26;
     RepartitionByExpression repartition_by_expression = 27;
+    FrameMap frame_map = 28;
 
     // NA functions
     NAFill fill_na = 90;
@@ -768,3 +769,12 @@ message RepartitionByExpression {
   // (Optional) number of partitions, must be positive.
   optional int32 num_partitions = 3;
 }
+
+message FrameMap {
+  // (Required) Input relation for a Frame Map API: mapInPandas, mapInArrow.
+  Relation input = 1;
+
+  // (Required) Input user-defined function of a Frame Map API.
+  CommonInlineUserDefinedFunction func = 2;
+}
+
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 268bf02fad9..cc43c1cace3 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -24,7 +24,7 @@ import com.google.common.collect.{Lists, Maps}
 import com.google.protobuf.{Any => ProtoAny}
 
 import org.apache.spark.TaskContext
-import org.apache.spark.api.python.SimplePythonFunction
+import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.{Column, Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, 
FunctionIdentifier}
@@ -106,6 +106,8 @@ class SparkConnectPlanner(val session: SparkSession) {
       case proto.Relation.RelTypeCase.UNPIVOT => 
transformUnpivot(rel.getUnpivot)
       case proto.Relation.RelTypeCase.REPARTITION_BY_EXPRESSION =>
         transformRepartitionByExpression(rel.getRepartitionByExpression)
+      case proto.Relation.RelTypeCase.FRAME_MAP =>
+        transformFrameMap(rel.getFrameMap)
       case proto.Relation.RelTypeCase.RELTYPE_NOT_SET =>
         throw new IndexOutOfBoundsException("Expected Relation to be set, but 
is empty.")
 
@@ -458,6 +460,20 @@ class SparkConnectPlanner(val session: SparkSession) {
       .logicalPlan
   }
 
+  private def transformFrameMap(rel: proto.FrameMap): LogicalPlan = {
+    val commonUdf = rel.getFunc
+    val pythonUdf = transformPythonUDF(commonUdf)
+    pythonUdf.evalType match {
+      case PythonEvalType.SQL_MAP_PANDAS_ITER_UDF =>
+        logical.MapInPandas(
+          pythonUdf,
+          pythonUdf.dataType.asInstanceOf[StructType].toAttributes,
+          transformRelation(rel.getInput))
+      case _ =>
+        throw InvalidPlanInput(s"Function with EvalType: ${pythonUdf.evalType} 
is not supported")
+    }
+  }
+
   private def transformWithColumnsRenamed(rel: proto.WithColumnsRenamed): 
LogicalPlan = {
     Dataset
       .ofRows(session, transformRelation(rel.getInput))
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 75a6b4401b8..b849892e20a 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -532,6 +532,7 @@ pyspark_connect = Module(
         "pyspark.sql.tests.connect.test_parity_readwriter",
         "pyspark.sql.tests.connect.test_parity_udf",
         "pyspark.sql.tests.connect.test_parity_pandas_udf",
+        "pyspark.sql.tests.connect.test_parity_pandas_map",
     ],
     excluded_python_implementations=[
         "PyPy"  # Skip these tests under PyPy since they require numpy, 
pandas, and pyarrow and
diff --git a/python/pyspark/sql/connect/_typing.py 
b/python/pyspark/sql/connect/_typing.py
index 66b08d898fe..c91d4e629d8 100644
--- a/python/pyspark/sql/connect/_typing.py
+++ b/python/pyspark/sql/connect/_typing.py
@@ -22,10 +22,12 @@ if sys.version_info >= (3, 8):
 else:
     from typing_extensions import Protocol
 
-from typing import Any, Callable, Union, Optional
+from typing import Any, Callable, Iterable, Union, Optional
 import datetime
 import decimal
 
+from pandas.core.frame import DataFrame as PandasDataFrame
+
 from pyspark.sql.connect.column import Column
 from pyspark.sql.connect.types import DataType
 
@@ -44,6 +46,10 @@ DateTimeLiteral = Union[datetime.datetime, datetime.date]
 
 DataTypeOrString = Union[DataType, str]
 
+DataFrameLike = PandasDataFrame
+
+PandasMapIterFunction = Callable[[Iterable[DataFrameLike]], 
Iterable[DataFrameLike]]
+
 
 class UserDefinedFunctionLike(Protocol):
     func: Callable[..., Any]
diff --git a/python/pyspark/sql/connect/client.py 
b/python/pyspark/sql/connect/client.py
index 154dd161e92..7ae4645863b 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -494,7 +494,7 @@ class SparkConnectClient(object):
             deterministic=deterministic,
             arguments=[],
             function=py_udf,
-        ).to_command(self)
+        ).to_plan_udf(self)
 
         # construct the request
         req = self._execute_plan_request_with_metadata()
diff --git a/python/pyspark/sql/connect/dataframe.py 
b/python/pyspark/sql/connect/dataframe.py
index 393f7f42ec8..b2253c21b66 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -51,6 +51,7 @@ from pyspark.sql.dataframe import (
 
 from pyspark.errors import PySparkTypeError
 from pyspark.errors.exceptions.connect import SparkConnectException
+from pyspark.rdd import PythonEvalType
 import pyspark.sql.connect.plan as plan
 from pyspark.sql.connect.group import GroupedData
 from pyspark.sql.connect.readwriter import DataFrameWriter, DataFrameWriterV2
@@ -73,6 +74,7 @@ if TYPE_CHECKING:
         LiteralType,
         PrimitiveType,
         OptionalPrimitiveType,
+        PandasMapIterFunction,
     )
     from pyspark.sql.connect.session import SparkSession
 
@@ -1540,8 +1542,24 @@ class DataFrame:
     def storageLevel(self, *args: Any, **kwargs: Any) -> None:
         raise NotImplementedError("storageLevel() is not implemented.")
 
-    def mapInPandas(self, *args: Any, **kwargs: Any) -> None:
-        raise NotImplementedError("mapInPandas() is not implemented.")
+    def mapInPandas(
+        self, func: "PandasMapIterFunction", schema: Union[StructType, str]
+    ) -> "DataFrame":
+        from pyspark.sql.connect.udf import UserDefinedFunction
+
+        if self._plan is None:
+            raise Exception("Cannot mapInPandas when self._plan is empty.")
+
+        udf_obj = UserDefinedFunction(
+            func, returnType=schema, 
evalType=PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
+        )
+
+        return DataFrame.withPlan(
+            plan.FrameMap(child=self._plan, function=udf_obj, 
cols=self.columns),
+            session=self._session,
+        )
+
+    mapInPandas.__doc__ = PySparkDataFrame.mapInPandas.__doc__
 
     def mapInArrow(self, *args: Any, **kwargs: Any) -> None:
         raise NotImplementedError("mapInArrow() is not implemented.")
diff --git a/python/pyspark/sql/connect/expressions.py 
b/python/pyspark/sql/connect/expressions.py
index 76e4252dce7..f3c9e2c70c4 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -549,10 +549,14 @@ class CommonInlineUserDefinedFunction(Expression):
         )
         return expr
 
-    def to_command(self, session: "SparkConnectClient") -> 
"proto.CommonInlineUserDefinedFunction":
+    def to_plan_udf(self, session: "SparkConnectClient") -> 
"proto.CommonInlineUserDefinedFunction":
+        """Compared to `to_plan`, it returns a CommonInlineUserDefinedFunction 
instead of an
+        Expression."""
         expr = proto.CommonInlineUserDefinedFunction()
         expr.function_name = self._function_name
         expr.deterministic = self._deterministic
+        if len(self._arguments) > 0:
+            expr.arguments.extend([arg.to_plan(session) for arg in 
self._arguments])
         expr.python_udf.CopyFrom(self._function.to_plan(session))
         return expr
 
diff --git a/python/pyspark/sql/connect/plan.py 
b/python/pyspark/sql/connect/plan.py
index 0f27b214502..badbb9871ed 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -30,12 +30,17 @@ from pyspark.sql.types import DataType
 
 import pyspark.sql.connect.proto as proto
 from pyspark.sql.connect.column import Column
-from pyspark.sql.connect.expressions import SortOrder, ColumnReference, 
LiteralExpression
+from pyspark.sql.connect.expressions import (
+    SortOrder,
+    ColumnReference,
+    LiteralExpression,
+)
 from pyspark.sql.connect.types import pyspark_types_to_proto_types
 
 if TYPE_CHECKING:
     from pyspark.sql.connect._typing import ColumnOrName
     from pyspark.sql.connect.client import SparkConnectClient
+    from pyspark.sql.connect.udf import UserDefinedFunction
 
 
 class InputValidationError(Exception):
@@ -1863,3 +1868,21 @@ class ListCatalogs(LogicalPlan):
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         return 
proto.Relation(catalog=proto.Catalog(list_catalogs=proto.ListCatalogs()))
+
+
+class FrameMap(LogicalPlan):
+    """Logical plan object for a Frame Map API: mapInPandas, mapInArrow."""
+
+    def __init__(
+        self, child: Optional["LogicalPlan"], function: "UserDefinedFunction", 
cols: List[str]
+    ) -> None:
+        super().__init__(child)
+
+        self._func = function._build_common_inline_user_defined_function(*cols)
+
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        assert self._child is not None
+        plan = self._create_proto_relation()
+        plan.frame_map.input.CopyFrom(self._child.plan(session))
+        plan.frame_map.func.CopyFrom(self._func.to_plan_udf(session))
+        return plan
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py 
b/python/pyspark/sql/connect/proto/relations_pb2.py
index 057b96a8da9..3afdf61e681 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as 
spark_dot_connect_dot_catal
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xf9\x11\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
 
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
 
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 
\x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+    
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xb1\x12\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
 
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
 
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 
\x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
 )
 
 
@@ -91,6 +91,7 @@ _UNPIVOT = DESCRIPTOR.message_types_by_name["Unpivot"]
 _UNPIVOT_VALUES = _UNPIVOT.nested_types_by_name["Values"]
 _TOSCHEMA = DESCRIPTOR.message_types_by_name["ToSchema"]
 _REPARTITIONBYEXPRESSION = 
DESCRIPTOR.message_types_by_name["RepartitionByExpression"]
+_FRAMEMAP = DESCRIPTOR.message_types_by_name["FrameMap"]
 _JOIN_JOINTYPE = _JOIN.enum_types_by_name["JoinType"]
 _SETOPERATION_SETOPTYPE = _SETOPERATION.enum_types_by_name["SetOpType"]
 _AGGREGATE_GROUPTYPE = _AGGREGATE.enum_types_by_name["GroupType"]
@@ -624,6 +625,17 @@ RepartitionByExpression = 
_reflection.GeneratedProtocolMessageType(
 )
 _sym_db.RegisterMessage(RepartitionByExpression)
 
+FrameMap = _reflection.GeneratedProtocolMessageType(
+    "FrameMap",
+    (_message.Message,),
+    {
+        "DESCRIPTOR": _FRAMEMAP,
+        "__module__": "spark.connect.relations_pb2"
+        # @@protoc_insertion_point(class_scope:spark.connect.FrameMap)
+    },
+)
+_sym_db.RegisterMessage(FrameMap)
+
 if _descriptor._USE_C_DESCRIPTORS == False:
 
     DESCRIPTOR._options = None
@@ -635,107 +647,109 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._options = None
     _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_options = b"8\001"
     _RELATION._serialized_start = 165
-    _RELATION._serialized_end = 2462
-    _UNKNOWN._serialized_start = 2464
-    _UNKNOWN._serialized_end = 2473
-    _RELATIONCOMMON._serialized_start = 2475
-    _RELATIONCOMMON._serialized_end = 2566
-    _SQL._serialized_start = 2569
-    _SQL._serialized_end = 2703
-    _SQL_ARGSENTRY._serialized_start = 2648
-    _SQL_ARGSENTRY._serialized_end = 2703
-    _READ._serialized_start = 2706
-    _READ._serialized_end = 3154
-    _READ_NAMEDTABLE._serialized_start = 2848
-    _READ_NAMEDTABLE._serialized_end = 2909
-    _READ_DATASOURCE._serialized_start = 2912
-    _READ_DATASOURCE._serialized_end = 3141
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3072
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3130
-    _PROJECT._serialized_start = 3156
-    _PROJECT._serialized_end = 3273
-    _FILTER._serialized_start = 3275
-    _FILTER._serialized_end = 3387
-    _JOIN._serialized_start = 3390
-    _JOIN._serialized_end = 3861
-    _JOIN_JOINTYPE._serialized_start = 3653
-    _JOIN_JOINTYPE._serialized_end = 3861
-    _SETOPERATION._serialized_start = 3864
-    _SETOPERATION._serialized_end = 4343
-    _SETOPERATION_SETOPTYPE._serialized_start = 4180
-    _SETOPERATION_SETOPTYPE._serialized_end = 4294
-    _LIMIT._serialized_start = 4345
-    _LIMIT._serialized_end = 4421
-    _OFFSET._serialized_start = 4423
-    _OFFSET._serialized_end = 4502
-    _TAIL._serialized_start = 4504
-    _TAIL._serialized_end = 4579
-    _AGGREGATE._serialized_start = 4582
-    _AGGREGATE._serialized_end = 5164
-    _AGGREGATE_PIVOT._serialized_start = 4921
-    _AGGREGATE_PIVOT._serialized_end = 5032
-    _AGGREGATE_GROUPTYPE._serialized_start = 5035
-    _AGGREGATE_GROUPTYPE._serialized_end = 5164
-    _SORT._serialized_start = 5167
-    _SORT._serialized_end = 5327
-    _DROP._serialized_start = 5329
-    _DROP._serialized_end = 5429
-    _DEDUPLICATE._serialized_start = 5432
-    _DEDUPLICATE._serialized_end = 5603
-    _LOCALRELATION._serialized_start = 5605
-    _LOCALRELATION._serialized_end = 5694
-    _SAMPLE._serialized_start = 5697
-    _SAMPLE._serialized_end = 5970
-    _RANGE._serialized_start = 5973
-    _RANGE._serialized_end = 6118
-    _SUBQUERYALIAS._serialized_start = 6120
-    _SUBQUERYALIAS._serialized_end = 6234
-    _REPARTITION._serialized_start = 6237
-    _REPARTITION._serialized_end = 6379
-    _SHOWSTRING._serialized_start = 6382
-    _SHOWSTRING._serialized_end = 6524
-    _STATSUMMARY._serialized_start = 6526
-    _STATSUMMARY._serialized_end = 6618
-    _STATDESCRIBE._serialized_start = 6620
-    _STATDESCRIBE._serialized_end = 6701
-    _STATCROSSTAB._serialized_start = 6703
-    _STATCROSSTAB._serialized_end = 6804
-    _STATCOV._serialized_start = 6806
-    _STATCOV._serialized_end = 6902
-    _STATCORR._serialized_start = 6905
-    _STATCORR._serialized_end = 7042
-    _STATAPPROXQUANTILE._serialized_start = 7045
-    _STATAPPROXQUANTILE._serialized_end = 7209
-    _STATFREQITEMS._serialized_start = 7211
-    _STATFREQITEMS._serialized_end = 7336
-    _STATSAMPLEBY._serialized_start = 7339
-    _STATSAMPLEBY._serialized_end = 7648
-    _STATSAMPLEBY_FRACTION._serialized_start = 7540
-    _STATSAMPLEBY_FRACTION._serialized_end = 7639
-    _NAFILL._serialized_start = 7651
-    _NAFILL._serialized_end = 7785
-    _NADROP._serialized_start = 7788
-    _NADROP._serialized_end = 7922
-    _NAREPLACE._serialized_start = 7925
-    _NAREPLACE._serialized_end = 8221
-    _NAREPLACE_REPLACEMENT._serialized_start = 8080
-    _NAREPLACE_REPLACEMENT._serialized_end = 8221
-    _TODF._serialized_start = 8223
-    _TODF._serialized_end = 8311
-    _WITHCOLUMNSRENAMED._serialized_start = 8314
-    _WITHCOLUMNSRENAMED._serialized_end = 8553
-    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8486
-    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8553
-    _WITHCOLUMNS._serialized_start = 8555
-    _WITHCOLUMNS._serialized_end = 8674
-    _HINT._serialized_start = 8677
-    _HINT._serialized_end = 8809
-    _UNPIVOT._serialized_start = 8812
-    _UNPIVOT._serialized_end = 9139
-    _UNPIVOT_VALUES._serialized_start = 9069
-    _UNPIVOT_VALUES._serialized_end = 9128
-    _TOSCHEMA._serialized_start = 9141
-    _TOSCHEMA._serialized_end = 9247
-    _REPARTITIONBYEXPRESSION._serialized_start = 9250
-    _REPARTITIONBYEXPRESSION._serialized_end = 9453
+    _RELATION._serialized_end = 2518
+    _UNKNOWN._serialized_start = 2520
+    _UNKNOWN._serialized_end = 2529
+    _RELATIONCOMMON._serialized_start = 2531
+    _RELATIONCOMMON._serialized_end = 2622
+    _SQL._serialized_start = 2625
+    _SQL._serialized_end = 2759
+    _SQL_ARGSENTRY._serialized_start = 2704
+    _SQL_ARGSENTRY._serialized_end = 2759
+    _READ._serialized_start = 2762
+    _READ._serialized_end = 3210
+    _READ_NAMEDTABLE._serialized_start = 2904
+    _READ_NAMEDTABLE._serialized_end = 2965
+    _READ_DATASOURCE._serialized_start = 2968
+    _READ_DATASOURCE._serialized_end = 3197
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3128
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3186
+    _PROJECT._serialized_start = 3212
+    _PROJECT._serialized_end = 3329
+    _FILTER._serialized_start = 3331
+    _FILTER._serialized_end = 3443
+    _JOIN._serialized_start = 3446
+    _JOIN._serialized_end = 3917
+    _JOIN_JOINTYPE._serialized_start = 3709
+    _JOIN_JOINTYPE._serialized_end = 3917
+    _SETOPERATION._serialized_start = 3920
+    _SETOPERATION._serialized_end = 4399
+    _SETOPERATION_SETOPTYPE._serialized_start = 4236
+    _SETOPERATION_SETOPTYPE._serialized_end = 4350
+    _LIMIT._serialized_start = 4401
+    _LIMIT._serialized_end = 4477
+    _OFFSET._serialized_start = 4479
+    _OFFSET._serialized_end = 4558
+    _TAIL._serialized_start = 4560
+    _TAIL._serialized_end = 4635
+    _AGGREGATE._serialized_start = 4638
+    _AGGREGATE._serialized_end = 5220
+    _AGGREGATE_PIVOT._serialized_start = 4977
+    _AGGREGATE_PIVOT._serialized_end = 5088
+    _AGGREGATE_GROUPTYPE._serialized_start = 5091
+    _AGGREGATE_GROUPTYPE._serialized_end = 5220
+    _SORT._serialized_start = 5223
+    _SORT._serialized_end = 5383
+    _DROP._serialized_start = 5385
+    _DROP._serialized_end = 5485
+    _DEDUPLICATE._serialized_start = 5488
+    _DEDUPLICATE._serialized_end = 5659
+    _LOCALRELATION._serialized_start = 5661
+    _LOCALRELATION._serialized_end = 5750
+    _SAMPLE._serialized_start = 5753
+    _SAMPLE._serialized_end = 6026
+    _RANGE._serialized_start = 6029
+    _RANGE._serialized_end = 6174
+    _SUBQUERYALIAS._serialized_start = 6176
+    _SUBQUERYALIAS._serialized_end = 6290
+    _REPARTITION._serialized_start = 6293
+    _REPARTITION._serialized_end = 6435
+    _SHOWSTRING._serialized_start = 6438
+    _SHOWSTRING._serialized_end = 6580
+    _STATSUMMARY._serialized_start = 6582
+    _STATSUMMARY._serialized_end = 6674
+    _STATDESCRIBE._serialized_start = 6676
+    _STATDESCRIBE._serialized_end = 6757
+    _STATCROSSTAB._serialized_start = 6759
+    _STATCROSSTAB._serialized_end = 6860
+    _STATCOV._serialized_start = 6862
+    _STATCOV._serialized_end = 6958
+    _STATCORR._serialized_start = 6961
+    _STATCORR._serialized_end = 7098
+    _STATAPPROXQUANTILE._serialized_start = 7101
+    _STATAPPROXQUANTILE._serialized_end = 7265
+    _STATFREQITEMS._serialized_start = 7267
+    _STATFREQITEMS._serialized_end = 7392
+    _STATSAMPLEBY._serialized_start = 7395
+    _STATSAMPLEBY._serialized_end = 7704
+    _STATSAMPLEBY_FRACTION._serialized_start = 7596
+    _STATSAMPLEBY_FRACTION._serialized_end = 7695
+    _NAFILL._serialized_start = 7707
+    _NAFILL._serialized_end = 7841
+    _NADROP._serialized_start = 7844
+    _NADROP._serialized_end = 7978
+    _NAREPLACE._serialized_start = 7981
+    _NAREPLACE._serialized_end = 8277
+    _NAREPLACE_REPLACEMENT._serialized_start = 8136
+    _NAREPLACE_REPLACEMENT._serialized_end = 8277
+    _TODF._serialized_start = 8279
+    _TODF._serialized_end = 8367
+    _WITHCOLUMNSRENAMED._serialized_start = 8370
+    _WITHCOLUMNSRENAMED._serialized_end = 8609
+    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8542
+    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8609
+    _WITHCOLUMNS._serialized_start = 8611
+    _WITHCOLUMNS._serialized_end = 8730
+    _HINT._serialized_start = 8733
+    _HINT._serialized_end = 8865
+    _UNPIVOT._serialized_start = 8868
+    _UNPIVOT._serialized_end = 9195
+    _UNPIVOT_VALUES._serialized_start = 9125
+    _UNPIVOT_VALUES._serialized_end = 9184
+    _TOSCHEMA._serialized_start = 9197
+    _TOSCHEMA._serialized_end = 9303
+    _REPARTITIONBYEXPRESSION._serialized_start = 9306
+    _REPARTITIONBYEXPRESSION._serialized_end = 9509
+    _FRAMEMAP._serialized_start = 9511
+    _FRAMEMAP._serialized_end = 9636
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi 
b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index b7cef7b299d..3f3b9f4c5b0 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -89,6 +89,7 @@ class Relation(google.protobuf.message.Message):
     UNPIVOT_FIELD_NUMBER: builtins.int
     TO_SCHEMA_FIELD_NUMBER: builtins.int
     REPARTITION_BY_EXPRESSION_FIELD_NUMBER: builtins.int
+    FRAME_MAP_FIELD_NUMBER: builtins.int
     FILL_NA_FIELD_NUMBER: builtins.int
     DROP_NA_FIELD_NUMBER: builtins.int
     REPLACE_FIELD_NUMBER: builtins.int
@@ -158,6 +159,8 @@ class Relation(google.protobuf.message.Message):
     @property
     def repartition_by_expression(self) -> global___RepartitionByExpression: 
...
     @property
+    def frame_map(self) -> global___FrameMap: ...
+    @property
     def fill_na(self) -> global___NAFill:
         """NA functions"""
     @property
@@ -221,6 +224,7 @@ class Relation(google.protobuf.message.Message):
         unpivot: global___Unpivot | None = ...,
         to_schema: global___ToSchema | None = ...,
         repartition_by_expression: global___RepartitionByExpression | None = 
...,
+        frame_map: global___FrameMap | None = ...,
         fill_na: global___NAFill | None = ...,
         drop_na: global___NADrop | None = ...,
         replace: global___NAReplace | None = ...,
@@ -267,6 +271,8 @@ class Relation(google.protobuf.message.Message):
             b"fill_na",
             "filter",
             b"filter",
+            "frame_map",
+            b"frame_map",
             "freq_items",
             b"freq_items",
             "hint",
@@ -356,6 +362,8 @@ class Relation(google.protobuf.message.Message):
             b"fill_na",
             "filter",
             b"filter",
+            "frame_map",
+            b"frame_map",
             "freq_items",
             b"freq_items",
             "hint",
@@ -443,6 +451,7 @@ class Relation(google.protobuf.message.Message):
         "unpivot",
         "to_schema",
         "repartition_by_expression",
+        "frame_map",
         "fill_na",
         "drop_na",
         "replace",
@@ -2639,3 +2648,30 @@ class 
RepartitionByExpression(google.protobuf.message.Message):
     ) -> typing_extensions.Literal["num_partitions"] | None: ...
 
 global___RepartitionByExpression = RepartitionByExpression
+
+class FrameMap(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    INPUT_FIELD_NUMBER: builtins.int
+    FUNC_FIELD_NUMBER: builtins.int
+    @property
+    def input(self) -> global___Relation:
+        """(Required) Input relation for a Frame Map API: mapInPandas, 
mapInArrow."""
+    @property
+    def func(self) -> 
pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction:
+        """(Required) Input user-defined function of a Frame Map API."""
+    def __init__(
+        self,
+        *,
+        input: global___Relation | None = ...,
+        func: 
pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction
+        | None = ...,
+    ) -> None: ...
+    def HasField(
+        self, field_name: typing_extensions.Literal["func", b"func", "input", 
b"input"]
+    ) -> builtins.bool: ...
+    def ClearField(
+        self, field_name: typing_extensions.Literal["func", b"func", "input", 
b"input"]
+    ) -> None: ...
+
+global___FrameMap = FrameMap
diff --git a/python/pyspark/sql/connect/types.py 
b/python/pyspark/sql/connect/types.py
index 28eb51d72cc..c73aa43d0ca 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -351,7 +351,9 @@ def parse_data_type(data_type: str) -> DataType:
     return_type_schema = (
         PySparkSession.builder.getOrCreate().createDataFrame(data=[], 
schema=data_type).schema
     )
-    if len(return_type_schema.fields) == 1:
+    with_col_name = " " in data_type.strip()
+    if len(return_type_schema.fields) == 1 and not with_col_name:
+        # To match pyspark.sql.types._parse_datatype_string
         return_type = return_type_schema.fields[0].dataType
     else:
         return_type = return_type_schema
diff --git a/python/pyspark/sql/connect/udf.py 
b/python/pyspark/sql/connect/udf.py
index bfe7006d161..c6bff4a3caa 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -108,11 +108,14 @@ class UserDefinedFunction:
         self.evalType = evalType
         self.deterministic = deterministic
 
-    def __call__(self, *cols: "ColumnOrName") -> Column:
+    def _build_common_inline_user_defined_function(
+        self, *cols: "ColumnOrName"
+    ) -> CommonInlineUserDefinedFunction:
         arg_cols = [
             col if isinstance(col, Column) else Column(ColumnReference(col)) 
for col in cols
         ]
         arg_exprs = [col._expr for col in arg_cols]
+
         data_type_str = (
             self._returnType.json() if isinstance(self._returnType, DataType) 
else self._returnType
         )
@@ -122,15 +125,16 @@ class UserDefinedFunction:
             command=CloudPickleSerializer().dumps((self.func, 
self._returnType)),
             python_ver="%d.%d" % sys.version_info[:2],
         )
-        return Column(
-            CommonInlineUserDefinedFunction(
-                function_name=self._name,
-                deterministic=self.deterministic,
-                arguments=arg_exprs,
-                function=py_udf,
-            )
+        return CommonInlineUserDefinedFunction(
+            function_name=self._name,
+            deterministic=self.deterministic,
+            arguments=arg_exprs,
+            function=py_udf,
         )
 
+    def __call__(self, *cols: "ColumnOrName") -> Column:
+        return Column(self._build_common_inline_user_defined_function(*cols))
+
     # This function is for improving the online help system in the interactive 
interpreter.
     # For example, the built-in help / pydoc.help. It wraps the UDF with the 
docstring and
     # argument annotation. (See: SPARK-19161)
diff --git a/python/pyspark/sql/pandas/map_ops.py 
b/python/pyspark/sql/pandas/map_ops.py
index 5f89577a1b6..2184fdce52d 100644
--- a/python/pyspark/sql/pandas/map_ops.py
+++ b/python/pyspark/sql/pandas/map_ops.py
@@ -48,6 +48,9 @@ class PandasMapOpsMixin:
 
         .. versionadded:: 3.0.0
 
+        .. versionchanged:: 3.4.0
+            Support Spark Connect.
+
         Parameters
         ----------
         func : function
diff --git a/python/pyspark/sql/tests/connect/test_parity_pandas_map.py 
b/python/pyspark/sql/tests/connect/test_parity_pandas_map.py
new file mode 100644
index 00000000000..b8402c564f1
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_pandas_map.py
@@ -0,0 +1,50 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from pyspark.sql.tests.pandas.test_pandas_map import MapInPandasTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class MapInPandasParityTests(MapInPandasTestsMixin, ReusedConnectTestCase):
+    @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test 
depends on it."
+    )
+    def test_empty_dataframes_with_less_columns(self):
+        super().test_empty_dataframes_with_less_columns()
+
+    @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test 
depends on it."
+    )
+    def test_other_than_dataframe(self):
+        super().test_other_than_dataframe()
+
+    @unittest.skip("Spark Connect does not support spark.conf but the test 
depends on it.")
+    def test_map_in_pandas_with_column_vector(self):
+        super().test_map_in_pandas_with_column_vector()
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.connect.test_parity_pandas_map import *  # noqa: 
F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_map.py 
b/python/pyspark/sql/tests/pandas/test_pandas_map.py
index 4a9bd4c6533..e39b97613cf 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_map.py
@@ -41,28 +41,7 @@ if have_pandas:
     not have_pandas or not have_pyarrow,
     cast(str, pandas_requirement_message or pyarrow_requirement_message),
 )
-class MapInPandasTests(ReusedSQLTestCase):
-    @classmethod
-    def setUpClass(cls):
-        ReusedSQLTestCase.setUpClass()
-
-        # Synchronize default timezone between Python and Java
-        cls.tz_prev = os.environ.get("TZ", None)  # save current tz if set
-        tz = "America/Los_Angeles"
-        os.environ["TZ"] = tz
-        time.tzset()
-
-        cls.sc.environment["TZ"] = tz
-        cls.spark.conf.set("spark.sql.session.timeZone", tz)
-
-    @classmethod
-    def tearDownClass(cls):
-        del os.environ["TZ"]
-        if cls.tz_prev is not None:
-            os.environ["TZ"] = cls.tz_prev
-        time.tzset()
-        ReusedSQLTestCase.tearDownClass()
-
+class MapInPandasTestsMixin:
     def test_map_in_pandas(self):
         def func(iterator):
             for pdf in iterator:
@@ -203,6 +182,29 @@ class MapInPandasTests(ReusedSQLTestCase):
             shutil.rmtree(path)
 
 
+class MapInPandasTests(ReusedSQLTestCase, MapInPandasTestsMixin):
+    @classmethod
+    def setUpClass(cls):
+        ReusedSQLTestCase.setUpClass()
+
+        # Synchronize default timezone between Python and Java
+        cls.tz_prev = os.environ.get("TZ", None)  # save current tz if set
+        tz = "America/Los_Angeles"
+        os.environ["TZ"] = tz
+        time.tzset()
+
+        cls.sc.environment["TZ"] = tz
+        cls.spark.conf.set("spark.sql.session.timeZone", tz)
+
+    @classmethod
+    def tearDownClass(cls):
+        del os.environ["TZ"]
+        if cls.tz_prev is not None:
+            os.environ["TZ"] = cls.tz_prev
+        time.tzset()
+        ReusedSQLTestCase.tearDownClass()
+
+
 if __name__ == "__main__":
     from pyspark.sql.tests.pandas.test_pandas_map import *  # noqa: F401
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

