This is an automated email from the ASF dual-hosted git repository. xinrong pushed a commit to branch branch-3.4 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push: new 000895da3f6 [SPARK-42510][CONNECT][PYTHON] Implement `DataFrame.mapInPandas` 000895da3f6 is described below commit 000895da3f6c0d17ccfdfe79c0ca34dfb9fb6e7b Author: Xinrong Meng <xinr...@apache.org> AuthorDate: Sat Feb 25 07:39:54 2023 +0800 [SPARK-42510][CONNECT][PYTHON] Implement `DataFrame.mapInPandas` ### What changes were proposed in this pull request? Implement `DataFrame.mapInPandas` and enable parity tests to vanilla PySpark. A proto message `FrameMap` is introduced for `mapInPandas` and `mapInArrow` (to implement next). ### Why are the changes needed? To reach parity with vanilla PySpark. ### Does this PR introduce _any_ user-facing change? Yes. `DataFrame.mapInPandas` is supported. An example is as shown below. ```py >>> df = spark.range(2) >>> def filter_func(iterator): ... for pdf in iterator: ... yield pdf[pdf.id == 1] ... >>> df.mapInPandas(filter_func, df.schema) DataFrame[id: bigint] >>> df.mapInPandas(filter_func, df.schema).show() +---+ | id| +---+ | 1| +---+ ``` ### How was this patch tested? Unit tests. Closes #40104 from xinrong-meng/mapInPandas. 
Lead-authored-by: Xinrong Meng <xinr...@apache.org>] Co-authored-by: Xinrong Meng <xinr...@apache.org> Signed-off-by: Xinrong Meng <xinr...@apache.org> (cherry picked from commit 9abccad1d93a243d7e47e53dcbc85568a460c529) Signed-off-by: Xinrong Meng <xinr...@apache.org> --- .../main/protobuf/spark/connect/relations.proto | 10 + .../sql/connect/planner/SparkConnectPlanner.scala | 18 +- dev/sparktestsupport/modules.py | 1 + python/pyspark/sql/connect/_typing.py | 8 +- python/pyspark/sql/connect/client.py | 2 +- python/pyspark/sql/connect/dataframe.py | 22 +- python/pyspark/sql/connect/expressions.py | 6 +- python/pyspark/sql/connect/plan.py | 25 ++- python/pyspark/sql/connect/proto/relations_pb2.py | 222 +++++++++++---------- python/pyspark/sql/connect/proto/relations_pb2.pyi | 36 ++++ python/pyspark/sql/connect/types.py | 4 +- python/pyspark/sql/connect/udf.py | 20 +- python/pyspark/sql/pandas/map_ops.py | 3 + .../sql/tests/connect/test_parity_pandas_map.py | 50 +++++ python/pyspark/sql/tests/pandas/test_pandas_map.py | 46 +++-- 15 files changed, 331 insertions(+), 142 deletions(-) diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto index 29fffd65c75..4d96b6b0c7e 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto @@ -60,6 +60,7 @@ message Relation { Unpivot unpivot = 25; ToSchema to_schema = 26; RepartitionByExpression repartition_by_expression = 27; + FrameMap frame_map = 28; // NA functions NAFill fill_na = 90; @@ -768,3 +769,12 @@ message RepartitionByExpression { // (Optional) number of partitions, must be positive. optional int32 num_partitions = 3; } + +message FrameMap { + // (Required) Input relation for a Frame Map API: mapInPandas, mapInArrow. + Relation input = 1; + + // (Required) Input user-defined function of a Frame Map API. 
+ CommonInlineUserDefinedFunction func = 2; +} + diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 268bf02fad9..cc43c1cace3 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -24,7 +24,7 @@ import com.google.common.collect.{Lists, Maps} import com.google.protobuf.{Any => ProtoAny} import org.apache.spark.TaskContext -import org.apache.spark.api.python.SimplePythonFunction +import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction} import org.apache.spark.connect.proto import org.apache.spark.sql.{Column, Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier} @@ -106,6 +106,8 @@ class SparkConnectPlanner(val session: SparkSession) { case proto.Relation.RelTypeCase.UNPIVOT => transformUnpivot(rel.getUnpivot) case proto.Relation.RelTypeCase.REPARTITION_BY_EXPRESSION => transformRepartitionByExpression(rel.getRepartitionByExpression) + case proto.Relation.RelTypeCase.FRAME_MAP => + transformFrameMap(rel.getFrameMap) case proto.Relation.RelTypeCase.RELTYPE_NOT_SET => throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.") @@ -458,6 +460,20 @@ class SparkConnectPlanner(val session: SparkSession) { .logicalPlan } + private def transformFrameMap(rel: proto.FrameMap): LogicalPlan = { + val commonUdf = rel.getFunc + val pythonUdf = transformPythonUDF(commonUdf) + pythonUdf.evalType match { + case PythonEvalType.SQL_MAP_PANDAS_ITER_UDF => + logical.MapInPandas( + pythonUdf, + pythonUdf.dataType.asInstanceOf[StructType].toAttributes, + transformRelation(rel.getInput)) + case _ => + throw InvalidPlanInput(s"Function with 
EvalType: ${pythonUdf.evalType} is not supported") + } + } + private def transformWithColumnsRenamed(rel: proto.WithColumnsRenamed): LogicalPlan = { Dataset .ofRows(session, transformRelation(rel.getInput)) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 75a6b4401b8..b849892e20a 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -532,6 +532,7 @@ pyspark_connect = Module( "pyspark.sql.tests.connect.test_parity_readwriter", "pyspark.sql.tests.connect.test_parity_udf", "pyspark.sql.tests.connect.test_parity_pandas_udf", + "pyspark.sql.tests.connect.test_parity_pandas_map", ], excluded_python_implementations=[ "PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and diff --git a/python/pyspark/sql/connect/_typing.py b/python/pyspark/sql/connect/_typing.py index 66b08d898fe..c91d4e629d8 100644 --- a/python/pyspark/sql/connect/_typing.py +++ b/python/pyspark/sql/connect/_typing.py @@ -22,10 +22,12 @@ if sys.version_info >= (3, 8): else: from typing_extensions import Protocol -from typing import Any, Callable, Union, Optional +from typing import Any, Callable, Iterable, Union, Optional import datetime import decimal +from pandas.core.frame import DataFrame as PandasDataFrame + from pyspark.sql.connect.column import Column from pyspark.sql.connect.types import DataType @@ -44,6 +46,10 @@ DateTimeLiteral = Union[datetime.datetime, datetime.date] DataTypeOrString = Union[DataType, str] +DataFrameLike = PandasDataFrame + +PandasMapIterFunction = Callable[[Iterable[DataFrameLike]], Iterable[DataFrameLike]] + class UserDefinedFunctionLike(Protocol): func: Callable[..., Any] diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py index 154dd161e92..7ae4645863b 100644 --- a/python/pyspark/sql/connect/client.py +++ b/python/pyspark/sql/connect/client.py @@ -494,7 +494,7 @@ class SparkConnectClient(object): deterministic=deterministic, arguments=[], 
function=py_udf, - ).to_command(self) + ).to_plan_udf(self) # construct the request req = self._execute_plan_request_with_metadata() diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 393f7f42ec8..b2253c21b66 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -51,6 +51,7 @@ from pyspark.sql.dataframe import ( from pyspark.errors import PySparkTypeError from pyspark.errors.exceptions.connect import SparkConnectException +from pyspark.rdd import PythonEvalType import pyspark.sql.connect.plan as plan from pyspark.sql.connect.group import GroupedData from pyspark.sql.connect.readwriter import DataFrameWriter, DataFrameWriterV2 @@ -73,6 +74,7 @@ if TYPE_CHECKING: LiteralType, PrimitiveType, OptionalPrimitiveType, + PandasMapIterFunction, ) from pyspark.sql.connect.session import SparkSession @@ -1540,8 +1542,24 @@ class DataFrame: def storageLevel(self, *args: Any, **kwargs: Any) -> None: raise NotImplementedError("storageLevel() is not implemented.") - def mapInPandas(self, *args: Any, **kwargs: Any) -> None: - raise NotImplementedError("mapInPandas() is not implemented.") + def mapInPandas( + self, func: "PandasMapIterFunction", schema: Union[StructType, str] + ) -> "DataFrame": + from pyspark.sql.connect.udf import UserDefinedFunction + + if self._plan is None: + raise Exception("Cannot mapInPandas when self._plan is empty.") + + udf_obj = UserDefinedFunction( + func, returnType=schema, evalType=PythonEvalType.SQL_MAP_PANDAS_ITER_UDF + ) + + return DataFrame.withPlan( + plan.FrameMap(child=self._plan, function=udf_obj, cols=self.columns), + session=self._session, + ) + + mapInPandas.__doc__ = PySparkDataFrame.mapInPandas.__doc__ def mapInArrow(self, *args: Any, **kwargs: Any) -> None: raise NotImplementedError("mapInArrow() is not implemented.") diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py index 
76e4252dce7..f3c9e2c70c4 100644 --- a/python/pyspark/sql/connect/expressions.py +++ b/python/pyspark/sql/connect/expressions.py @@ -549,10 +549,14 @@ class CommonInlineUserDefinedFunction(Expression): ) return expr - def to_command(self, session: "SparkConnectClient") -> "proto.CommonInlineUserDefinedFunction": + def to_plan_udf(self, session: "SparkConnectClient") -> "proto.CommonInlineUserDefinedFunction": + """Compared to `to_plan`, it returns a CommonInlineUserDefinedFunction instead of an + Expression.""" expr = proto.CommonInlineUserDefinedFunction() expr.function_name = self._function_name expr.deterministic = self._deterministic + if len(self._arguments) > 0: + expr.arguments.extend([arg.to_plan(session) for arg in self._arguments]) expr.python_udf.CopyFrom(self._function.to_plan(session)) return expr diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 0f27b214502..badbb9871ed 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -30,12 +30,17 @@ from pyspark.sql.types import DataType import pyspark.sql.connect.proto as proto from pyspark.sql.connect.column import Column -from pyspark.sql.connect.expressions import SortOrder, ColumnReference, LiteralExpression +from pyspark.sql.connect.expressions import ( + SortOrder, + ColumnReference, + LiteralExpression, +) from pyspark.sql.connect.types import pyspark_types_to_proto_types if TYPE_CHECKING: from pyspark.sql.connect._typing import ColumnOrName from pyspark.sql.connect.client import SparkConnectClient + from pyspark.sql.connect.udf import UserDefinedFunction class InputValidationError(Exception): @@ -1863,3 +1868,21 @@ class ListCatalogs(LogicalPlan): def plan(self, session: "SparkConnectClient") -> proto.Relation: return proto.Relation(catalog=proto.Catalog(list_catalogs=proto.ListCatalogs())) + + +class FrameMap(LogicalPlan): + """Logical plan object for a Frame Map API: mapInPandas, mapInArrow.""" + + def __init__( + self, 
child: Optional["LogicalPlan"], function: "UserDefinedFunction", cols: List[str] + ) -> None: + super().__init__(child) + + self._func = function._build_common_inline_user_defined_function(*cols) + + def plan(self, session: "SparkConnectClient") -> proto.Relation: + assert self._child is not None + plan = self._create_proto_relation() + plan.frame_map.input.CopyFrom(self._child.plan(session)) + plan.frame_map.func.CopyFrom(self._func.to_plan_udf(session)) + return plan diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py index 057b96a8da9..3afdf61e681 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.py +++ b/python/pyspark/sql/connect/proto/relations_pb2.py @@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xf9\x11\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...] + b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xb1\x12\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...] 
) @@ -91,6 +91,7 @@ _UNPIVOT = DESCRIPTOR.message_types_by_name["Unpivot"] _UNPIVOT_VALUES = _UNPIVOT.nested_types_by_name["Values"] _TOSCHEMA = DESCRIPTOR.message_types_by_name["ToSchema"] _REPARTITIONBYEXPRESSION = DESCRIPTOR.message_types_by_name["RepartitionByExpression"] +_FRAMEMAP = DESCRIPTOR.message_types_by_name["FrameMap"] _JOIN_JOINTYPE = _JOIN.enum_types_by_name["JoinType"] _SETOPERATION_SETOPTYPE = _SETOPERATION.enum_types_by_name["SetOpType"] _AGGREGATE_GROUPTYPE = _AGGREGATE.enum_types_by_name["GroupType"] @@ -624,6 +625,17 @@ RepartitionByExpression = _reflection.GeneratedProtocolMessageType( ) _sym_db.RegisterMessage(RepartitionByExpression) +FrameMap = _reflection.GeneratedProtocolMessageType( + "FrameMap", + (_message.Message,), + { + "DESCRIPTOR": _FRAMEMAP, + "__module__": "spark.connect.relations_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.FrameMap) + }, +) +_sym_db.RegisterMessage(FrameMap) + if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None @@ -635,107 +647,109 @@ if _descriptor._USE_C_DESCRIPTORS == False: _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._options = None _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_options = b"8\001" _RELATION._serialized_start = 165 - _RELATION._serialized_end = 2462 - _UNKNOWN._serialized_start = 2464 - _UNKNOWN._serialized_end = 2473 - _RELATIONCOMMON._serialized_start = 2475 - _RELATIONCOMMON._serialized_end = 2566 - _SQL._serialized_start = 2569 - _SQL._serialized_end = 2703 - _SQL_ARGSENTRY._serialized_start = 2648 - _SQL_ARGSENTRY._serialized_end = 2703 - _READ._serialized_start = 2706 - _READ._serialized_end = 3154 - _READ_NAMEDTABLE._serialized_start = 2848 - _READ_NAMEDTABLE._serialized_end = 2909 - _READ_DATASOURCE._serialized_start = 2912 - _READ_DATASOURCE._serialized_end = 3141 - _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3072 - _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3130 - _PROJECT._serialized_start = 3156 - _PROJECT._serialized_end = 3273 
- _FILTER._serialized_start = 3275 - _FILTER._serialized_end = 3387 - _JOIN._serialized_start = 3390 - _JOIN._serialized_end = 3861 - _JOIN_JOINTYPE._serialized_start = 3653 - _JOIN_JOINTYPE._serialized_end = 3861 - _SETOPERATION._serialized_start = 3864 - _SETOPERATION._serialized_end = 4343 - _SETOPERATION_SETOPTYPE._serialized_start = 4180 - _SETOPERATION_SETOPTYPE._serialized_end = 4294 - _LIMIT._serialized_start = 4345 - _LIMIT._serialized_end = 4421 - _OFFSET._serialized_start = 4423 - _OFFSET._serialized_end = 4502 - _TAIL._serialized_start = 4504 - _TAIL._serialized_end = 4579 - _AGGREGATE._serialized_start = 4582 - _AGGREGATE._serialized_end = 5164 - _AGGREGATE_PIVOT._serialized_start = 4921 - _AGGREGATE_PIVOT._serialized_end = 5032 - _AGGREGATE_GROUPTYPE._serialized_start = 5035 - _AGGREGATE_GROUPTYPE._serialized_end = 5164 - _SORT._serialized_start = 5167 - _SORT._serialized_end = 5327 - _DROP._serialized_start = 5329 - _DROP._serialized_end = 5429 - _DEDUPLICATE._serialized_start = 5432 - _DEDUPLICATE._serialized_end = 5603 - _LOCALRELATION._serialized_start = 5605 - _LOCALRELATION._serialized_end = 5694 - _SAMPLE._serialized_start = 5697 - _SAMPLE._serialized_end = 5970 - _RANGE._serialized_start = 5973 - _RANGE._serialized_end = 6118 - _SUBQUERYALIAS._serialized_start = 6120 - _SUBQUERYALIAS._serialized_end = 6234 - _REPARTITION._serialized_start = 6237 - _REPARTITION._serialized_end = 6379 - _SHOWSTRING._serialized_start = 6382 - _SHOWSTRING._serialized_end = 6524 - _STATSUMMARY._serialized_start = 6526 - _STATSUMMARY._serialized_end = 6618 - _STATDESCRIBE._serialized_start = 6620 - _STATDESCRIBE._serialized_end = 6701 - _STATCROSSTAB._serialized_start = 6703 - _STATCROSSTAB._serialized_end = 6804 - _STATCOV._serialized_start = 6806 - _STATCOV._serialized_end = 6902 - _STATCORR._serialized_start = 6905 - _STATCORR._serialized_end = 7042 - _STATAPPROXQUANTILE._serialized_start = 7045 - _STATAPPROXQUANTILE._serialized_end = 7209 - 
_STATFREQITEMS._serialized_start = 7211 - _STATFREQITEMS._serialized_end = 7336 - _STATSAMPLEBY._serialized_start = 7339 - _STATSAMPLEBY._serialized_end = 7648 - _STATSAMPLEBY_FRACTION._serialized_start = 7540 - _STATSAMPLEBY_FRACTION._serialized_end = 7639 - _NAFILL._serialized_start = 7651 - _NAFILL._serialized_end = 7785 - _NADROP._serialized_start = 7788 - _NADROP._serialized_end = 7922 - _NAREPLACE._serialized_start = 7925 - _NAREPLACE._serialized_end = 8221 - _NAREPLACE_REPLACEMENT._serialized_start = 8080 - _NAREPLACE_REPLACEMENT._serialized_end = 8221 - _TODF._serialized_start = 8223 - _TODF._serialized_end = 8311 - _WITHCOLUMNSRENAMED._serialized_start = 8314 - _WITHCOLUMNSRENAMED._serialized_end = 8553 - _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8486 - _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8553 - _WITHCOLUMNS._serialized_start = 8555 - _WITHCOLUMNS._serialized_end = 8674 - _HINT._serialized_start = 8677 - _HINT._serialized_end = 8809 - _UNPIVOT._serialized_start = 8812 - _UNPIVOT._serialized_end = 9139 - _UNPIVOT_VALUES._serialized_start = 9069 - _UNPIVOT_VALUES._serialized_end = 9128 - _TOSCHEMA._serialized_start = 9141 - _TOSCHEMA._serialized_end = 9247 - _REPARTITIONBYEXPRESSION._serialized_start = 9250 - _REPARTITIONBYEXPRESSION._serialized_end = 9453 + _RELATION._serialized_end = 2518 + _UNKNOWN._serialized_start = 2520 + _UNKNOWN._serialized_end = 2529 + _RELATIONCOMMON._serialized_start = 2531 + _RELATIONCOMMON._serialized_end = 2622 + _SQL._serialized_start = 2625 + _SQL._serialized_end = 2759 + _SQL_ARGSENTRY._serialized_start = 2704 + _SQL_ARGSENTRY._serialized_end = 2759 + _READ._serialized_start = 2762 + _READ._serialized_end = 3210 + _READ_NAMEDTABLE._serialized_start = 2904 + _READ_NAMEDTABLE._serialized_end = 2965 + _READ_DATASOURCE._serialized_start = 2968 + _READ_DATASOURCE._serialized_end = 3197 + _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3128 + 
_READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3186 + _PROJECT._serialized_start = 3212 + _PROJECT._serialized_end = 3329 + _FILTER._serialized_start = 3331 + _FILTER._serialized_end = 3443 + _JOIN._serialized_start = 3446 + _JOIN._serialized_end = 3917 + _JOIN_JOINTYPE._serialized_start = 3709 + _JOIN_JOINTYPE._serialized_end = 3917 + _SETOPERATION._serialized_start = 3920 + _SETOPERATION._serialized_end = 4399 + _SETOPERATION_SETOPTYPE._serialized_start = 4236 + _SETOPERATION_SETOPTYPE._serialized_end = 4350 + _LIMIT._serialized_start = 4401 + _LIMIT._serialized_end = 4477 + _OFFSET._serialized_start = 4479 + _OFFSET._serialized_end = 4558 + _TAIL._serialized_start = 4560 + _TAIL._serialized_end = 4635 + _AGGREGATE._serialized_start = 4638 + _AGGREGATE._serialized_end = 5220 + _AGGREGATE_PIVOT._serialized_start = 4977 + _AGGREGATE_PIVOT._serialized_end = 5088 + _AGGREGATE_GROUPTYPE._serialized_start = 5091 + _AGGREGATE_GROUPTYPE._serialized_end = 5220 + _SORT._serialized_start = 5223 + _SORT._serialized_end = 5383 + _DROP._serialized_start = 5385 + _DROP._serialized_end = 5485 + _DEDUPLICATE._serialized_start = 5488 + _DEDUPLICATE._serialized_end = 5659 + _LOCALRELATION._serialized_start = 5661 + _LOCALRELATION._serialized_end = 5750 + _SAMPLE._serialized_start = 5753 + _SAMPLE._serialized_end = 6026 + _RANGE._serialized_start = 6029 + _RANGE._serialized_end = 6174 + _SUBQUERYALIAS._serialized_start = 6176 + _SUBQUERYALIAS._serialized_end = 6290 + _REPARTITION._serialized_start = 6293 + _REPARTITION._serialized_end = 6435 + _SHOWSTRING._serialized_start = 6438 + _SHOWSTRING._serialized_end = 6580 + _STATSUMMARY._serialized_start = 6582 + _STATSUMMARY._serialized_end = 6674 + _STATDESCRIBE._serialized_start = 6676 + _STATDESCRIBE._serialized_end = 6757 + _STATCROSSTAB._serialized_start = 6759 + _STATCROSSTAB._serialized_end = 6860 + _STATCOV._serialized_start = 6862 + _STATCOV._serialized_end = 6958 + _STATCORR._serialized_start = 6961 + _STATCORR._serialized_end 
= 7098 + _STATAPPROXQUANTILE._serialized_start = 7101 + _STATAPPROXQUANTILE._serialized_end = 7265 + _STATFREQITEMS._serialized_start = 7267 + _STATFREQITEMS._serialized_end = 7392 + _STATSAMPLEBY._serialized_start = 7395 + _STATSAMPLEBY._serialized_end = 7704 + _STATSAMPLEBY_FRACTION._serialized_start = 7596 + _STATSAMPLEBY_FRACTION._serialized_end = 7695 + _NAFILL._serialized_start = 7707 + _NAFILL._serialized_end = 7841 + _NADROP._serialized_start = 7844 + _NADROP._serialized_end = 7978 + _NAREPLACE._serialized_start = 7981 + _NAREPLACE._serialized_end = 8277 + _NAREPLACE_REPLACEMENT._serialized_start = 8136 + _NAREPLACE_REPLACEMENT._serialized_end = 8277 + _TODF._serialized_start = 8279 + _TODF._serialized_end = 8367 + _WITHCOLUMNSRENAMED._serialized_start = 8370 + _WITHCOLUMNSRENAMED._serialized_end = 8609 + _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8542 + _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8609 + _WITHCOLUMNS._serialized_start = 8611 + _WITHCOLUMNS._serialized_end = 8730 + _HINT._serialized_start = 8733 + _HINT._serialized_end = 8865 + _UNPIVOT._serialized_start = 8868 + _UNPIVOT._serialized_end = 9195 + _UNPIVOT_VALUES._serialized_start = 9125 + _UNPIVOT_VALUES._serialized_end = 9184 + _TOSCHEMA._serialized_start = 9197 + _TOSCHEMA._serialized_end = 9303 + _REPARTITIONBYEXPRESSION._serialized_start = 9306 + _REPARTITIONBYEXPRESSION._serialized_end = 9509 + _FRAMEMAP._serialized_start = 9511 + _FRAMEMAP._serialized_end = 9636 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi index b7cef7b299d..3f3b9f4c5b0 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.pyi +++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi @@ -89,6 +89,7 @@ class Relation(google.protobuf.message.Message): UNPIVOT_FIELD_NUMBER: builtins.int TO_SCHEMA_FIELD_NUMBER: builtins.int REPARTITION_BY_EXPRESSION_FIELD_NUMBER: 
builtins.int + FRAME_MAP_FIELD_NUMBER: builtins.int FILL_NA_FIELD_NUMBER: builtins.int DROP_NA_FIELD_NUMBER: builtins.int REPLACE_FIELD_NUMBER: builtins.int @@ -158,6 +159,8 @@ class Relation(google.protobuf.message.Message): @property def repartition_by_expression(self) -> global___RepartitionByExpression: ... @property + def frame_map(self) -> global___FrameMap: ... + @property def fill_na(self) -> global___NAFill: """NA functions""" @property @@ -221,6 +224,7 @@ class Relation(google.protobuf.message.Message): unpivot: global___Unpivot | None = ..., to_schema: global___ToSchema | None = ..., repartition_by_expression: global___RepartitionByExpression | None = ..., + frame_map: global___FrameMap | None = ..., fill_na: global___NAFill | None = ..., drop_na: global___NADrop | None = ..., replace: global___NAReplace | None = ..., @@ -267,6 +271,8 @@ class Relation(google.protobuf.message.Message): b"fill_na", "filter", b"filter", + "frame_map", + b"frame_map", "freq_items", b"freq_items", "hint", @@ -356,6 +362,8 @@ class Relation(google.protobuf.message.Message): b"fill_na", "filter", b"filter", + "frame_map", + b"frame_map", "freq_items", b"freq_items", "hint", @@ -443,6 +451,7 @@ class Relation(google.protobuf.message.Message): "unpivot", "to_schema", "repartition_by_expression", + "frame_map", "fill_na", "drop_na", "replace", @@ -2639,3 +2648,30 @@ class RepartitionByExpression(google.protobuf.message.Message): ) -> typing_extensions.Literal["num_partitions"] | None: ... 
global___RepartitionByExpression = RepartitionByExpression + +class FrameMap(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + INPUT_FIELD_NUMBER: builtins.int + FUNC_FIELD_NUMBER: builtins.int + @property + def input(self) -> global___Relation: + """(Required) Input relation for a Frame Map API: mapInPandas, mapInArrow.""" + @property + def func(self) -> pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction: + """(Required) Input user-defined function of a Frame Map API.""" + def __init__( + self, + *, + input: global___Relation | None = ..., + func: pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction + | None = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["func", b"func", "input", b"input"] + ) -> builtins.bool: ... + def ClearField( + self, field_name: typing_extensions.Literal["func", b"func", "input", b"input"] + ) -> None: ... + +global___FrameMap = FrameMap diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py index 28eb51d72cc..c73aa43d0ca 100644 --- a/python/pyspark/sql/connect/types.py +++ b/python/pyspark/sql/connect/types.py @@ -351,7 +351,9 @@ def parse_data_type(data_type: str) -> DataType: return_type_schema = ( PySparkSession.builder.getOrCreate().createDataFrame(data=[], schema=data_type).schema ) - if len(return_type_schema.fields) == 1: + with_col_name = " " in data_type.strip() + if len(return_type_schema.fields) == 1 and not with_col_name: + # To match pyspark.sql.types._parse_datatype_string return_type = return_type_schema.fields[0].dataType else: return_type = return_type_schema diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py index bfe7006d161..c6bff4a3caa 100644 --- a/python/pyspark/sql/connect/udf.py +++ b/python/pyspark/sql/connect/udf.py @@ -108,11 +108,14 @@ class UserDefinedFunction: self.evalType = evalType self.deterministic = deterministic - def 
__call__(self, *cols: "ColumnOrName") -> Column: + def _build_common_inline_user_defined_function( + self, *cols: "ColumnOrName" + ) -> CommonInlineUserDefinedFunction: arg_cols = [ col if isinstance(col, Column) else Column(ColumnReference(col)) for col in cols ] arg_exprs = [col._expr for col in arg_cols] + data_type_str = ( self._returnType.json() if isinstance(self._returnType, DataType) else self._returnType ) @@ -122,15 +125,16 @@ class UserDefinedFunction: command=CloudPickleSerializer().dumps((self.func, self._returnType)), python_ver="%d.%d" % sys.version_info[:2], ) - return Column( - CommonInlineUserDefinedFunction( - function_name=self._name, - deterministic=self.deterministic, - arguments=arg_exprs, - function=py_udf, - ) + return CommonInlineUserDefinedFunction( + function_name=self._name, + deterministic=self.deterministic, + arguments=arg_exprs, + function=py_udf, ) + def __call__(self, *cols: "ColumnOrName") -> Column: + return Column(self._build_common_inline_user_defined_function(*cols)) + # This function is for improving the online help system in the interactive interpreter. # For example, the built-in help / pydoc.help. It wraps the UDF with the docstring and # argument annotation. (See: SPARK-19161) diff --git a/python/pyspark/sql/pandas/map_ops.py b/python/pyspark/sql/pandas/map_ops.py index 5f89577a1b6..2184fdce52d 100644 --- a/python/pyspark/sql/pandas/map_ops.py +++ b/python/pyspark/sql/pandas/map_ops.py @@ -48,6 +48,9 @@ class PandasMapOpsMixin: .. versionadded:: 3.0.0 + .. versionchanged:: 3.4.0 + Support Spark Connect. 
+ Parameters ---------- func : function diff --git a/python/pyspark/sql/tests/connect/test_parity_pandas_map.py b/python/pyspark/sql/tests/connect/test_parity_pandas_map.py new file mode 100644 index 00000000000..b8402c564f1 --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_parity_pandas_map.py @@ -0,0 +1,50 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import unittest + +from pyspark.sql.tests.pandas.test_pandas_map import MapInPandasTestsMixin +from pyspark.testing.connectutils import ReusedConnectTestCase + + +class MapInPandasParityTests(MapInPandasTestsMixin, ReusedConnectTestCase): + @unittest.skip( + "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it." + ) + def test_empty_dataframes_with_less_columns(self): + super().test_empty_dataframes_with_less_columns() + + @unittest.skip( + "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it." 
+ ) + def test_other_than_dataframe(self): + super().test_other_than_dataframe() + + @unittest.skip("Spark Connect does not support spark.conf but the test depends on it.") + def test_map_in_pandas_with_column_vector(self): + super().test_map_in_pandas_with_column_vector() + + +if __name__ == "__main__": + from pyspark.sql.tests.connect.test_parity_pandas_map import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/pandas/test_pandas_map.py b/python/pyspark/sql/tests/pandas/test_pandas_map.py index 4a9bd4c6533..e39b97613cf 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_map.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_map.py @@ -41,28 +41,7 @@ if have_pandas: not have_pandas or not have_pyarrow, cast(str, pandas_requirement_message or pyarrow_requirement_message), ) -class MapInPandasTests(ReusedSQLTestCase): - @classmethod - def setUpClass(cls): - ReusedSQLTestCase.setUpClass() - - # Synchronize default timezone between Python and Java - cls.tz_prev = os.environ.get("TZ", None) # save current tz if set - tz = "America/Los_Angeles" - os.environ["TZ"] = tz - time.tzset() - - cls.sc.environment["TZ"] = tz - cls.spark.conf.set("spark.sql.session.timeZone", tz) - - @classmethod - def tearDownClass(cls): - del os.environ["TZ"] - if cls.tz_prev is not None: - os.environ["TZ"] = cls.tz_prev - time.tzset() - ReusedSQLTestCase.tearDownClass() - +class MapInPandasTestsMixin: def test_map_in_pandas(self): def func(iterator): for pdf in iterator: @@ -203,6 +182,29 @@ class MapInPandasTests(ReusedSQLTestCase): shutil.rmtree(path) +class MapInPandasTests(ReusedSQLTestCase, MapInPandasTestsMixin): + @classmethod + def setUpClass(cls): + ReusedSQLTestCase.setUpClass() + + # Synchronize default timezone between Python and Java + 
cls.tz_prev = os.environ.get("TZ", None) # save current tz if set + tz = "America/Los_Angeles" + os.environ["TZ"] = tz + time.tzset() + + cls.sc.environment["TZ"] = tz + cls.spark.conf.set("spark.sql.session.timeZone", tz) + + @classmethod + def tearDownClass(cls): + del os.environ["TZ"] + if cls.tz_prev is not None: + os.environ["TZ"] = cls.tz_prev + time.tzset() + ReusedSQLTestCase.tearDownClass() + + if __name__ == "__main__": from pyspark.sql.tests.pandas.test_pandas_map import * # noqa: F401 --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org