This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new d91904e6271 [SPARK-43965][PYTHON][CONNECT] Support Python UDTF in Spark Connect d91904e6271 is described below commit d91904e627101f260933ac42244a75736dc97881 Author: allisonwang-db <allison.w...@databricks.com> AuthorDate: Sun Jul 16 16:23:32 2023 +0900 [SPARK-43965][PYTHON][CONNECT] Support Python UDTF in Spark Connect ### What changes were proposed in this pull request? This PR supports creating and registering Python UDTFs in Spark Connect. ### Why are the changes needed? To make Python UDTF work with Spark Connect. ### Does this PR introduce _any_ user-facing change? Yes. After this PR, users can use Python UDTFs in Spark Connect. ### How was this patch tested? New unit tests Closes #41989 from allisonwang-db/spark-43965-udtf-spark-connect. Authored-by: allisonwang-db <allison.w...@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../src/main/protobuf/spark/connect/commands.proto | 1 + .../main/protobuf/spark/connect/relations.proto | 31 +++ .../sql/connect/planner/SparkConnectPlanner.scala | 65 ++++- python/pyspark/errors/error_classes.py | 5 + python/pyspark/sql/connect/client/core.py | 39 +++ python/pyspark/sql/connect/functions.py | 18 ++ python/pyspark/sql/connect/plan.py | 93 ++++++- python/pyspark/sql/connect/proto/commands_pb2.py | 160 ++++++------ python/pyspark/sql/connect/proto/commands_pb2.pyi | 12 + python/pyspark/sql/connect/proto/relations_pb2.py | 268 +++++++++++---------- python/pyspark/sql/connect/proto/relations_pb2.pyi | 119 +++++++++ python/pyspark/sql/connect/session.py | 11 +- python/pyspark/sql/connect/udtf.py | 205 ++++++++++++++++ python/pyspark/sql/connect/utils.py | 4 + python/pyspark/sql/functions.py | 3 +- .../sql/tests/connect/test_connect_basic.py | 11 - .../sql/tests/connect/test_connect_function.py | 18 ++ .../pyspark/sql/tests/connect/test_parity_udtf.py | 141 +++++++++++ python/pyspark/sql/tests/test_udtf.py | 38 ++- python/pyspark/sql/udtf.py | 111 +++++---- 20 files changed, 1071 insertions(+), 282 deletions(-) diff --git a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto index 6689662fcf8..ce8c1d53943 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto @@ -41,6 +41,7 @@ message Command { StreamingQueryCommand streaming_query_command = 7; GetResourcesCommand get_resources_command = 8; StreamingQueryManagerCommand streaming_query_manager_command = 9; + CommonInlineUserDefinedTableFunction register_table_function = 10; // This field is used to mark extensions to the protocol. When plugins generate arbitrary // Commands they can add them here. During the planning the correct resolution is done. diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto index b6a3e5fa236..8001b3cbcfa 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto @@ -71,6 +71,7 @@ message Relation { HtmlString html_string = 35; CachedLocalRelation cached_local_relation = 36; CachedRemoteRelation cached_remote_relation = 37; + CommonInlineUserDefinedTableFunction common_inline_user_defined_table_function = 38; // NA functions NAFill fill_na = 90; @@ -941,6 +942,36 @@ message ApplyInPandasWithState { string timeout_conf = 7; } +message CommonInlineUserDefinedTableFunction { + // (Required) Name of the user-defined table function. + string function_name = 1; + + // (Optional) Whether the user-defined table function is deterministic. + bool deterministic = 2; + + // (Optional) Function input arguments. Empty arguments are allowed. + repeated Expression arguments = 3; + + // (Required) Type of the user-defined table function. + oneof function { + PythonUDTF python_udtf = 4; + } +} + +message PythonUDTF { + // (Optional) Return type of the Python UDTF. + optional DataType return_type = 1; + + // (Required) EvalType of the Python UDTF. + int32 eval_type = 2; + + // (Required) The encoded commands of the Python UDTF. + bytes command = 3; + + // (Required) Python version being used in the client. + string python_ver = 4; +} + // Collect arbitrary (named) metrics from a dataset. message CollectMetrics { // (Required) The input relation. diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 492396631f3..d414169fd81 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -66,7 +66,7 @@ import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.execution.command.CreateViewCommand import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRelation} -import org.apache.spark.sql.execution.python.{PythonForeachWriter, UserDefinedPythonFunction} +import org.apache.spark.sql.execution.python.{PythonForeachWriter, UserDefinedPythonFunction, UserDefinedPythonTableFunction} import org.apache.spark.sql.execution.stat.StatFunctions import org.apache.spark.sql.execution.streaming.GroupStateImpl.groupStateTimeoutFromString import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper @@ -153,6 +153,8 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { transformCoGroupMap(rel.getCoGroupMap) case proto.Relation.RelTypeCase.APPLY_IN_PANDAS_WITH_STATE => transformApplyInPandasWithState(rel.getApplyInPandasWithState) + case proto.Relation.RelTypeCase.COMMON_INLINE_USER_DEFINED_TABLE_FUNCTION => + transformCommonInlineUserDefinedTableFunction(rel.getCommonInlineUserDefinedTableFunction) case proto.Relation.RelTypeCase.CACHED_REMOTE_RELATION => transformCachedRemoteRelation(rel.getCachedRemoteRelation) case proto.Relation.RelTypeCase.COLLECT_METRICS => @@ -890,6 +892,32 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { .logicalPlan } + private def transformCommonInlineUserDefinedTableFunction( + fun: proto.CommonInlineUserDefinedTableFunction): LogicalPlan = { + fun.getFunctionCase match { + case proto.CommonInlineUserDefinedTableFunction.FunctionCase.PYTHON_UDTF => + val function = createPythonUserDefinedTableFunction(fun) + function.builder(fun.getArgumentsList.asScala.map(transformExpression).toSeq) + case _ => + throw InvalidPlanInput( + s"Function with ID: ${fun.getFunctionCase.getNumber} is not supported") + } + } + + private def transformPythonTableFunction(fun: proto.PythonUDTF): SimplePythonFunction = { + SimplePythonFunction( + command = fun.getCommand.toByteArray, + // Empty environment variables + envVars = Maps.newHashMap(), + pythonIncludes = sessionHolder.artifactManager.getSparkConnectPythonIncludes.asJava, + pythonExec = pythonExec, + pythonVer = fun.getPythonVer, + // Empty broadcast variables + broadcastVars = Lists.newArrayList(), + // Null accumulator + accumulator = null) + } + private def transformCachedRemoteRelation(rel: proto.CachedRemoteRelation): LogicalPlan = { sessionHolder .getDataFrameOrThrow(rel.getRelationId) @@ -2311,6 +2339,8 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { command.getCommandTypeCase match { case proto.Command.CommandTypeCase.REGISTER_FUNCTION => handleRegisterUserDefinedFunction(command.getRegisterFunction) + case proto.Command.CommandTypeCase.REGISTER_TABLE_FUNCTION => + handleRegisterUserDefinedTableFunction(command.getRegisterTableFunction) case proto.Command.CommandTypeCase.WRITE_OPERATION => handleWriteOperation(command.getWriteOperation) case proto.Command.CommandTypeCase.CREATE_DATAFRAME_VIEW => @@ -2437,6 +2467,39 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { } } + private def handleRegisterUserDefinedTableFunction( + fun: proto.CommonInlineUserDefinedTableFunction): Unit = { + fun.getFunctionCase match { + case proto.CommonInlineUserDefinedTableFunction.FunctionCase.PYTHON_UDTF => + val function = createPythonUserDefinedTableFunction(fun) + session.udtf.registerPython(fun.getFunctionName, function) + case _ => + throw InvalidPlanInput( + s"Function with ID: ${fun.getFunctionCase.getNumber} is not supported") + } + } + + private def createPythonUserDefinedTableFunction( + fun: proto.CommonInlineUserDefinedTableFunction): UserDefinedPythonTableFunction = { + val udtf = fun.getPythonUdtf + // Currently return type is required for Python UDTFs. + // TODO(SPARK-44380): support `analyze` in Python UDTFs + assert(udtf.hasReturnType) + val returnType = transformDataType(udtf.getReturnType) + if (!returnType.isInstanceOf[StructType]) { + throw InvalidPlanInput( + "Invalid Python user-defined table function return type. " + + s"Expect a struct type, but got ${returnType.typeName}.") + } + + UserDefinedPythonTableFunction( + name = fun.getFunctionName, + func = transformPythonTableFunction(udtf), + returnType = returnType.asInstanceOf[StructType], + pythonEvalType = udtf.getEvalType, + udfDeterministic = fun.getDeterministic) + } + private def handleRegisterPythonUDF(fun: proto.CommonInlineUserDefinedFunction): Unit = { val udf = fun.getPythonUdf val function = transformPythonFunction(udf) diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py index 8c51024bf06..56b166b53c5 100644 --- a/python/pyspark/errors/error_classes.py +++ b/python/pyspark/errors/error_classes.py @@ -668,6 +668,11 @@ ERROR_CLASSES_JSON = """ "The number of columns in the result does not match the specified schema. Expected column count: <expected>, Actual column count: <actual>. Please make sure the values returned by the function have the same number of columns as specified in the output schema." ] }, + "UDTF_RETURN_TYPE_MISMATCH" : { + "message" : [ + "Mismatch in return type for the UDTF '<name>'. Expected a 'StructType', but got '<return_type>'. Please ensure the return type is a correctly formatted StructType." + ] + }, "UNEXPECTED_RESPONSE_FROM_SERVER" : { "message" : [ "Unexpected response from iterator server." diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 537ab0a6140..00f2a85d602 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -75,6 +75,11 @@ from pyspark.sql.connect.expressions import ( CommonInlineUserDefinedFunction, JavaUDF, ) +from pyspark.sql.connect.plan import ( + CommonInlineUserDefinedTableFunction, + PythonUDTF, +) +from pyspark.sql.connect.utils import get_python_ver from pyspark.sql.pandas.types import _create_converter_to_pandas, from_arrow_schema from pyspark.sql.types import DataType, StructType, TimestampType, _has_type from pyspark.rdd import PythonEvalType @@ -641,6 +646,40 @@ class SparkConnectClient(object): self._execute(req) return name + def register_udtf( + self, + function: Any, + return_type: "DataTypeOrString", + name: str, + eval_type: int = PythonEvalType.SQL_TABLE_UDF, + deterministic: bool = True, + ) -> str: + """ + Register a user-defined table function (UDTF) in the session catalog + as a temporary function. The return type, if specified, must be a + struct type and it's validated when building the proto message + for the PythonUDTF. + """ + udtf = PythonUDTF( + func=function, + return_type=return_type, + eval_type=eval_type, + python_ver=get_python_ver(), + ) + + func = CommonInlineUserDefinedTableFunction( + function_name=name, + function=udtf, + deterministic=deterministic, + arguments=[], + ).udtf_plan(self) + + req = self._execute_plan_request_with_metadata() + req.plan.command.register_table_function.CopyFrom(func) + + self._execute(req) + return name + def register_java( self, name: str, diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py index 1be759d9b6e..a1c0516ee0d 100644 --- a/python/pyspark/sql/connect/functions.py +++ b/python/pyspark/sql/connect/functions.py @@ -31,6 +31,7 @@ from typing import ( overload, Optional, Tuple, + Type, Callable, ValuesView, cast, @@ -52,6 +53,7 @@ from pyspark.sql.connect.expressions import ( UnresolvedNamedLambdaVariable, ) from pyspark.sql.connect.udf import _create_py_udf +from pyspark.sql.connect.udtf import _create_py_udtf from pyspark.sql import functions as pysparkfuncs from pyspark.sql.types import _from_numpy_type, DataType, StructType, ArrayType, StringType @@ -67,6 +69,7 @@ if TYPE_CHECKING: UserDefinedFunctionLike, ) from pyspark.sql.connect.dataframe import DataFrame + from pyspark.sql.connect.udtf import UserDefinedTableFunction def _to_col_with_plan_id(col: str, plan_id: Optional[int]) -> Column: @@ -3891,6 +3894,21 @@ def udf( udf.__doc__ = pysparkfuncs.udf.__doc__ +def udtf( + cls: Optional[Type] = None, + *, + returnType: Union[StructType, str], + useArrow: Optional[bool] = None, +) -> Union["UserDefinedTableFunction", Callable[[Type], "UserDefinedTableFunction"]]: + if cls is None: + return functools.partial(_create_py_udtf, returnType=returnType, useArrow=useArrow) + else: + return _create_py_udtf(cls=cls, returnType=returnType, useArrow=useArrow) + + +udtf.__doc__ = pysparkfuncs.udtf.__doc__ + + def call_function(udfName: str, *cols: "ColumnOrName") -> Column: return _invoke_function(udfName, *[_to_col(c) for c in cols]) diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 97348d4863e..3390faa04de 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -18,7 +18,7 @@ from pyspark.sql.connect.utils import check_dependencies check_dependencies(__name__) -from typing import Any, List, Optional, Sequence, Union, cast, TYPE_CHECKING, Mapping, Dict +from typing import Any, List, Optional, Type, Sequence, Union, cast, TYPE_CHECKING, Mapping, Dict import functools import json from threading import Lock @@ -26,6 +26,7 @@ from inspect import signature, isclass import pyarrow as pa +from pyspark.serializers import CloudPickleSerializer from pyspark.storagelevel import StorageLevel from pyspark.sql.types import DataType @@ -33,11 +34,12 @@ import pyspark.sql.connect.proto as proto from pyspark.sql.connect.conversion import storage_level_to_proto from pyspark.sql.connect.column import Column from pyspark.sql.connect.expressions import ( + Expression, SortOrder, ColumnReference, LiteralExpression, ) -from pyspark.sql.connect.types import pyspark_types_to_proto_types +from pyspark.sql.connect.types import pyspark_types_to_proto_types, UnparsedDataType from pyspark.errors import PySparkTypeError, PySparkNotImplementedError if TYPE_CHECKING: @@ -2173,6 +2175,93 @@ class ApplyInPandasWithState(LogicalPlan): return plan +class PythonUDTF: + """Represents a Python user-defined table function.""" + + def __init__( + self, + func: Type, + return_type: Union[DataType, str], + eval_type: int, + python_ver: str, + ) -> None: + self._func = func + self._name = func.__name__ + self._return_type: DataType = ( + UnparsedDataType(return_type) if isinstance(return_type, str) else return_type + ) + self._eval_type = eval_type + self._python_ver = python_ver + + def to_plan(self, session: "SparkConnectClient") -> proto.PythonUDTF: + udtf = proto.PythonUDTF() + # Currently the return type cannot be None. + # TODO(SPARK-44380): support `analyze` in Python UDTFs + assert self._return_type is not None + udtf.return_type.CopyFrom(pyspark_types_to_proto_types(self._return_type)) + udtf.eval_type = self._eval_type + udtf.command = CloudPickleSerializer().dumps(self._func) + udtf.python_ver = self._python_ver + return udtf + + def __repr__(self) -> str: + return ( + f"PythonUDTF({self._name}, {self._return_type}, " + f"{self._eval_type}, {self._python_ver})" + ) + + +class CommonInlineUserDefinedTableFunction(LogicalPlan): + """ + Logical plan object for a user-defined table function with + an inlined defined function body. + """ + + def __init__( + self, + function_name: str, + function: PythonUDTF, + deterministic: bool, + arguments: Sequence[Expression], + ) -> None: + super().__init__(None) + self._function_name = function_name + self._deterministic = deterministic + self._arguments = arguments + self._function = function + + def plan(self, session: "SparkConnectClient") -> proto.Relation: + plan = self._create_proto_relation() + plan.common_inline_user_defined_table_function.function_name = self._function_name + plan.common_inline_user_defined_table_function.deterministic = self._deterministic + if len(self._arguments) > 0: + plan.common_inline_user_defined_table_function.arguments.extend( + [arg.to_plan(session) for arg in self._arguments] + ) + plan.common_inline_user_defined_table_function.python_udtf.CopyFrom( + self._function.to_plan(session) + ) + return plan + + def udtf_plan( + self, session: "SparkConnectClient" + ) -> "proto.CommonInlineUserDefinedTableFunction": + """ + Compared to `plan`, it returns a `proto.CommonInlineUserDefinedTableFunction` + instead of a `proto.Relation`. + """ + plan = proto.CommonInlineUserDefinedTableFunction() + plan.function_name = self._function_name + plan.deterministic = self._deterministic + if len(self._arguments) > 0: + plan.arguments.extend([arg.to_plan(session) for arg in self._arguments]) + plan.python_udtf.CopyFrom(cast(proto.PythonUDF, self._function.to_plan(session))) + return plan + + def __repr__(self) -> str: + return f"{self._function_name}({', '.join([str(arg) for arg in self._arguments])})" + + class CachedRelation(LogicalPlan): def __init__(self, plan: proto.Relation) -> None: super(CachedRelation, self).__init__(None) diff --git a/python/pyspark/sql/connect/proto/commands_pb2.py b/python/pyspark/sql/connect/proto/commands_pb2.py index 3947d172ed4..c852cebd140 100644 --- a/python/pyspark/sql/connect/proto/commands_pb2.py +++ b/python/pyspark/sql/connect/proto/commands_pb2.py @@ -35,7 +35,7 @@ from pyspark.sql.connect.proto import relations_pb2 as spark_dot_connect_dot_rel DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto"\x86\x07\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x [...] + b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto"\xf5\x07\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x [...] ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -59,83 +59,83 @@ if _descriptor._USE_C_DESCRIPTORS == False: _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._options = None _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_options = b"8\001" _COMMAND._serialized_start = 167 - _COMMAND._serialized_end = 1069 - _SQLCOMMAND._serialized_start = 1072 - _SQLCOMMAND._serialized_end = 1313 - _SQLCOMMAND_ARGSENTRY._serialized_start = 1223 - _SQLCOMMAND_ARGSENTRY._serialized_end = 1313 - _CREATEDATAFRAMEVIEWCOMMAND._serialized_start = 1316 - _CREATEDATAFRAMEVIEWCOMMAND._serialized_end = 1466 - _WRITEOPERATION._serialized_start = 1469 - _WRITEOPERATION._serialized_end = 2520 - _WRITEOPERATION_OPTIONSENTRY._serialized_start = 1944 - _WRITEOPERATION_OPTIONSENTRY._serialized_end = 2002 - _WRITEOPERATION_SAVETABLE._serialized_start = 2005 - _WRITEOPERATION_SAVETABLE._serialized_end = 2263 - _WRITEOPERATION_SAVETABLE_TABLESAVEMETHOD._serialized_start = 2139 - _WRITEOPERATION_SAVETABLE_TABLESAVEMETHOD._serialized_end = 2263 - _WRITEOPERATION_BUCKETBY._serialized_start = 2265 - _WRITEOPERATION_BUCKETBY._serialized_end = 2356 - _WRITEOPERATION_SAVEMODE._serialized_start = 2359 - _WRITEOPERATION_SAVEMODE._serialized_end = 2496 - _WRITEOPERATIONV2._serialized_start = 2523 - _WRITEOPERATIONV2._serialized_end = 3336 - _WRITEOPERATIONV2_OPTIONSENTRY._serialized_start = 1944 - _WRITEOPERATIONV2_OPTIONSENTRY._serialized_end = 2002 - _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 3095 - _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 3161 - _WRITEOPERATIONV2_MODE._serialized_start = 3164 - _WRITEOPERATIONV2_MODE._serialized_end = 3323 - _WRITESTREAMOPERATIONSTART._serialized_start = 3339 - _WRITESTREAMOPERATIONSTART._serialized_end = 4139 - _WRITESTREAMOPERATIONSTART_OPTIONSENTRY._serialized_start = 1944 - _WRITESTREAMOPERATIONSTART_OPTIONSENTRY._serialized_end = 2002 - _STREAMINGFOREACHFUNCTION._serialized_start = 4142 - _STREAMINGFOREACHFUNCTION._serialized_end = 4321 - _WRITESTREAMOPERATIONSTARTRESULT._serialized_start = 4323 - _WRITESTREAMOPERATIONSTARTRESULT._serialized_end = 4444 - _STREAMINGQUERYINSTANCEID._serialized_start = 4446 - _STREAMINGQUERYINSTANCEID._serialized_end = 4511 - _STREAMINGQUERYCOMMAND._serialized_start = 4514 - _STREAMINGQUERYCOMMAND._serialized_end = 5146 - _STREAMINGQUERYCOMMAND_EXPLAINCOMMAND._serialized_start = 5013 - _STREAMINGQUERYCOMMAND_EXPLAINCOMMAND._serialized_end = 5057 - _STREAMINGQUERYCOMMAND_AWAITTERMINATIONCOMMAND._serialized_start = 5059 - _STREAMINGQUERYCOMMAND_AWAITTERMINATIONCOMMAND._serialized_end = 5135 - _STREAMINGQUERYCOMMANDRESULT._serialized_start = 5149 - _STREAMINGQUERYCOMMANDRESULT._serialized_end = 6290 - _STREAMINGQUERYCOMMANDRESULT_STATUSRESULT._serialized_start = 5732 - _STREAMINGQUERYCOMMANDRESULT_STATUSRESULT._serialized_end = 5902 - _STREAMINGQUERYCOMMANDRESULT_RECENTPROGRESSRESULT._serialized_start = 5904 - _STREAMINGQUERYCOMMANDRESULT_RECENTPROGRESSRESULT._serialized_end = 5976 - _STREAMINGQUERYCOMMANDRESULT_EXPLAINRESULT._serialized_start = 5978 - _STREAMINGQUERYCOMMANDRESULT_EXPLAINRESULT._serialized_end = 6017 - _STREAMINGQUERYCOMMANDRESULT_EXCEPTIONRESULT._serialized_start = 6020 - _STREAMINGQUERYCOMMANDRESULT_EXCEPTIONRESULT._serialized_end = 6217 - _STREAMINGQUERYCOMMANDRESULT_AWAITTERMINATIONRESULT._serialized_start = 6219 - _STREAMINGQUERYCOMMANDRESULT_AWAITTERMINATIONRESULT._serialized_end = 6275 - _STREAMINGQUERYMANAGERCOMMAND._serialized_start = 6293 - _STREAMINGQUERYMANAGERCOMMAND._serialized_end = 6990 - _STREAMINGQUERYMANAGERCOMMAND_AWAITANYTERMINATIONCOMMAND._serialized_start = 6824 - _STREAMINGQUERYMANAGERCOMMAND_AWAITANYTERMINATIONCOMMAND._serialized_end = 6903 - _STREAMINGQUERYMANAGERCOMMAND_STREAMINGQUERYLISTENERCOMMAND._serialized_start = 6905 - _STREAMINGQUERYMANAGERCOMMAND_STREAMINGQUERYLISTENERCOMMAND._serialized_end = 6979 - _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_start = 6993 - _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_end = 8147 - _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_start = 7601 - _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_end = 7728 - _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_start = 7730 - _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_end = 7845 - _STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_start = 7847 - _STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_end = 7906 - _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYLISTENERINSTANCE._serialized_start = 7908 - _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYLISTENERINSTANCE._serialized_end = 7983 - _STREAMINGQUERYMANAGERCOMMANDRESULT_LISTSTREAMINGQUERYLISTENERRESULT._serialized_start = 7986 - _STREAMINGQUERYMANAGERCOMMANDRESULT_LISTSTREAMINGQUERYLISTENERRESULT._serialized_end = 8132 - _GETRESOURCESCOMMAND._serialized_start = 8149 - _GETRESOURCESCOMMAND._serialized_end = 8170 - _GETRESOURCESCOMMANDRESULT._serialized_start = 8173 - _GETRESOURCESCOMMANDRESULT._serialized_end = 8385 - _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_start = 8289 - _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_end = 8385 + _COMMAND._serialized_end = 1180 + _SQLCOMMAND._serialized_start = 1183 + _SQLCOMMAND._serialized_end = 1424 + _SQLCOMMAND_ARGSENTRY._serialized_start = 1334 + _SQLCOMMAND_ARGSENTRY._serialized_end = 1424 + _CREATEDATAFRAMEVIEWCOMMAND._serialized_start = 1427 + _CREATEDATAFRAMEVIEWCOMMAND._serialized_end = 1577 + _WRITEOPERATION._serialized_start = 1580 + _WRITEOPERATION._serialized_end = 2631 + _WRITEOPERATION_OPTIONSENTRY._serialized_start = 2055 + _WRITEOPERATION_OPTIONSENTRY._serialized_end = 2113 + _WRITEOPERATION_SAVETABLE._serialized_start = 2116 + _WRITEOPERATION_SAVETABLE._serialized_end = 2374 + _WRITEOPERATION_SAVETABLE_TABLESAVEMETHOD._serialized_start = 2250 + _WRITEOPERATION_SAVETABLE_TABLESAVEMETHOD._serialized_end = 2374 + _WRITEOPERATION_BUCKETBY._serialized_start = 2376 + _WRITEOPERATION_BUCKETBY._serialized_end = 2467 + _WRITEOPERATION_SAVEMODE._serialized_start = 2470 + _WRITEOPERATION_SAVEMODE._serialized_end = 2607 + _WRITEOPERATIONV2._serialized_start = 2634 + _WRITEOPERATIONV2._serialized_end = 3447 + _WRITEOPERATIONV2_OPTIONSENTRY._serialized_start = 2055 + _WRITEOPERATIONV2_OPTIONSENTRY._serialized_end = 2113 + _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 3206 + _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 3272 + _WRITEOPERATIONV2_MODE._serialized_start = 3275 + _WRITEOPERATIONV2_MODE._serialized_end = 3434 + _WRITESTREAMOPERATIONSTART._serialized_start = 3450 + _WRITESTREAMOPERATIONSTART._serialized_end = 4250 + _WRITESTREAMOPERATIONSTART_OPTIONSENTRY._serialized_start = 2055 + _WRITESTREAMOPERATIONSTART_OPTIONSENTRY._serialized_end = 2113 + _STREAMINGFOREACHFUNCTION._serialized_start = 4253 + _STREAMINGFOREACHFUNCTION._serialized_end = 4432 + _WRITESTREAMOPERATIONSTARTRESULT._serialized_start = 4434 + _WRITESTREAMOPERATIONSTARTRESULT._serialized_end = 4555 + _STREAMINGQUERYINSTANCEID._serialized_start = 4557 + _STREAMINGQUERYINSTANCEID._serialized_end = 4622 + _STREAMINGQUERYCOMMAND._serialized_start = 4625 + _STREAMINGQUERYCOMMAND._serialized_end = 5257 + _STREAMINGQUERYCOMMAND_EXPLAINCOMMAND._serialized_start = 5124 + _STREAMINGQUERYCOMMAND_EXPLAINCOMMAND._serialized_end = 5168 + _STREAMINGQUERYCOMMAND_AWAITTERMINATIONCOMMAND._serialized_start = 5170 + _STREAMINGQUERYCOMMAND_AWAITTERMINATIONCOMMAND._serialized_end = 5246 + _STREAMINGQUERYCOMMANDRESULT._serialized_start = 5260 + _STREAMINGQUERYCOMMANDRESULT._serialized_end = 6401 + _STREAMINGQUERYCOMMANDRESULT_STATUSRESULT._serialized_start = 5843 + _STREAMINGQUERYCOMMANDRESULT_STATUSRESULT._serialized_end = 6013 + _STREAMINGQUERYCOMMANDRESULT_RECENTPROGRESSRESULT._serialized_start = 6015 + _STREAMINGQUERYCOMMANDRESULT_RECENTPROGRESSRESULT._serialized_end = 6087 + _STREAMINGQUERYCOMMANDRESULT_EXPLAINRESULT._serialized_start = 6089 + _STREAMINGQUERYCOMMANDRESULT_EXPLAINRESULT._serialized_end = 6128 + _STREAMINGQUERYCOMMANDRESULT_EXCEPTIONRESULT._serialized_start = 6131 + _STREAMINGQUERYCOMMANDRESULT_EXCEPTIONRESULT._serialized_end = 6328 + _STREAMINGQUERYCOMMANDRESULT_AWAITTERMINATIONRESULT._serialized_start = 6330 + _STREAMINGQUERYCOMMANDRESULT_AWAITTERMINATIONRESULT._serialized_end = 6386 + _STREAMINGQUERYMANAGERCOMMAND._serialized_start = 6404 + _STREAMINGQUERYMANAGERCOMMAND._serialized_end = 7101 + _STREAMINGQUERYMANAGERCOMMAND_AWAITANYTERMINATIONCOMMAND._serialized_start = 6935 + _STREAMINGQUERYMANAGERCOMMAND_AWAITANYTERMINATIONCOMMAND._serialized_end = 7014 + _STREAMINGQUERYMANAGERCOMMAND_STREAMINGQUERYLISTENERCOMMAND._serialized_start = 7016 + _STREAMINGQUERYMANAGERCOMMAND_STREAMINGQUERYLISTENERCOMMAND._serialized_end = 7090 + _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_start = 7104 + _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_end = 8258 + _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_start = 7712 + _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_end = 7839 + _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_start = 7841 + _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_end = 7956 + _STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_start = 7958 + _STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_end = 8017 + _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYLISTENERINSTANCE._serialized_start = 8019 + _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYLISTENERINSTANCE._serialized_end = 8094 + _STREAMINGQUERYMANAGERCOMMANDRESULT_LISTSTREAMINGQUERYLISTENERRESULT._serialized_start = 8097 + _STREAMINGQUERYMANAGERCOMMANDRESULT_LISTSTREAMINGQUERYLISTENERRESULT._serialized_end = 8243 + _GETRESOURCESCOMMAND._serialized_start = 8260 + _GETRESOURCESCOMMAND._serialized_end = 8281 + _GETRESOURCESCOMMANDRESULT._serialized_start = 8284 + _GETRESOURCESCOMMANDRESULT._serialized_end = 8496 + _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_start = 8400 + _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_end = 8496 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/commands_pb2.pyi b/python/pyspark/sql/connect/proto/commands_pb2.pyi index fe472b3140d..5b44e3da7a0 100644 --- a/python/pyspark/sql/connect/proto/commands_pb2.pyi +++ b/python/pyspark/sql/connect/proto/commands_pb2.pyi @@ -69,6 +69,7 @@ class Command(google.protobuf.message.Message): STREAMING_QUERY_COMMAND_FIELD_NUMBER: builtins.int GET_RESOURCES_COMMAND_FIELD_NUMBER: builtins.int STREAMING_QUERY_MANAGER_COMMAND_FIELD_NUMBER: builtins.int + REGISTER_TABLE_FUNCTION_FIELD_NUMBER: builtins.int EXTENSION_FIELD_NUMBER: builtins.int @property def register_function( @@ -91,6 +92,10 @@ class Command(google.protobuf.message.Message): @property def streaming_query_manager_command(self) -> global___StreamingQueryManagerCommand: ... @property + def register_table_function( + self, + ) -> pyspark.sql.connect.proto.relations_pb2.CommonInlineUserDefinedTableFunction: ... + @property def extension(self) -> google.protobuf.any_pb2.Any: """This field is used to mark extensions to the protocol. When plugins generate arbitrary Commands they can add them here. During the planning the correct resolution is done. @@ -108,6 +113,8 @@ class Command(google.protobuf.message.Message): streaming_query_command: global___StreamingQueryCommand | None = ..., get_resources_command: global___GetResourcesCommand | None = ..., streaming_query_manager_command: global___StreamingQueryManagerCommand | None = ..., + register_table_function: pyspark.sql.connect.proto.relations_pb2.CommonInlineUserDefinedTableFunction + | None = ..., extension: google.protobuf.any_pb2.Any | None = ..., ) -> None: ... def HasField( @@ -123,6 +130,8 @@ class Command(google.protobuf.message.Message): b"get_resources_command", "register_function", b"register_function", + "register_table_function", + b"register_table_function", "sql_command", b"sql_command", "streaming_query_command", @@ -150,6 +159,8 @@ class Command(google.protobuf.message.Message): b"get_resources_command", "register_function", b"register_function", + "register_table_function", + b"register_table_function", "sql_command", b"sql_command", "streaming_query_command", @@ -176,6 +187,7 @@ class Command(google.protobuf.message.Message): "streaming_query_command", "get_resources_command", "streaming_query_manager_command", + "register_table_function", "extension", ] | None: ... diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py index d370bc6fd0c..3a0a7ff71fd 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.py +++ b/python/pyspark/sql/connect/proto/relations_pb2.py @@ -35,7 +35,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xd0\x17\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...] + b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xe1\x18\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...] ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -57,135 +57,139 @@ if _descriptor._USE_C_DESCRIPTORS == False: _PARSE_OPTIONSENTRY._options = None _PARSE_OPTIONSENTRY._serialized_options = b"8\001" _RELATION._serialized_start = 165 - _RELATION._serialized_end = 3189 - _UNKNOWN._serialized_start = 3191 - _UNKNOWN._serialized_end = 3200 - _RELATIONCOMMON._serialized_start = 3202 - _RELATIONCOMMON._serialized_end = 3293 - _SQL._serialized_start = 3296 - _SQL._serialized_end = 3527 - _SQL_ARGSENTRY._serialized_start = 3437 - _SQL_ARGSENTRY._serialized_end = 3527 - _READ._serialized_start = 3530 - _READ._serialized_end = 4193 - _READ_NAMEDTABLE._serialized_start = 3708 - _READ_NAMEDTABLE._serialized_end = 3900 - _READ_NAMEDTABLE_OPTIONSENTRY._serialized_start = 3842 - _READ_NAMEDTABLE_OPTIONSENTRY._serialized_end = 3900 - _READ_DATASOURCE._serialized_start = 3903 - _READ_DATASOURCE._serialized_end = 4180 - _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3842 - _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3900 - _PROJECT._serialized_start = 4195 - _PROJECT._serialized_end = 4312 - _FILTER._serialized_start = 4314 - _FILTER._serialized_end = 4426 - _JOIN._serialized_start = 4429 - _JOIN._serialized_end = 5090 - _JOIN_JOINDATATYPE._serialized_start = 4768 - _JOIN_JOINDATATYPE._serialized_end = 4860 - _JOIN_JOINTYPE._serialized_start = 4863 - _JOIN_JOINTYPE._serialized_end = 5071 - _SETOPERATION._serialized_start = 5093 - _SETOPERATION._serialized_end = 5572 - _SETOPERATION_SETOPTYPE._serialized_start = 5409 - _SETOPERATION_SETOPTYPE._serialized_end = 5523 - _LIMIT._serialized_start = 5574 - _LIMIT._serialized_end = 5650 - _OFFSET._serialized_start = 5652 - _OFFSET._serialized_end = 5731 - _TAIL._serialized_start = 5733 - _TAIL._serialized_end = 5808 - _AGGREGATE._serialized_start = 5811 - _AGGREGATE._serialized_end = 6393 - _AGGREGATE_PIVOT._serialized_start = 6150 - _AGGREGATE_PIVOT._serialized_end = 6261 - _AGGREGATE_GROUPTYPE._serialized_start = 6264 - _AGGREGATE_GROUPTYPE._serialized_end = 6393 - _SORT._serialized_start = 6396 - _SORT._serialized_end = 6556 - _DROP._serialized_start = 6559 - _DROP._serialized_end = 6700 - _DEDUPLICATE._serialized_start = 6703 - _DEDUPLICATE._serialized_end = 6943 - _LOCALRELATION._serialized_start = 6945 - _LOCALRELATION._serialized_end = 7034 - _CACHEDLOCALRELATION._serialized_start = 7036 - _CACHEDLOCALRELATION._serialized_end = 7131 - _CACHEDREMOTERELATION._serialized_start = 7133 - _CACHEDREMOTERELATION._serialized_end = 7188 - _SAMPLE._serialized_start = 7191 - _SAMPLE._serialized_end = 7464 - _RANGE._serialized_start = 7467 - _RANGE._serialized_end = 7612 - _SUBQUERYALIAS._serialized_start = 7614 - _SUBQUERYALIAS._serialized_end = 7728 - _REPARTITION._serialized_start = 7731 - _REPARTITION._serialized_end = 7873 - _SHOWSTRING._serialized_start = 7876 - _SHOWSTRING._serialized_end = 8018 - _HTMLSTRING._serialized_start = 8020 - _HTMLSTRING._serialized_end = 8134 - _STATSUMMARY._serialized_start = 8136 - _STATSUMMARY._serialized_end = 8228 - _STATDESCRIBE._serialized_start = 8230 - _STATDESCRIBE._serialized_end = 8311 - _STATCROSSTAB._serialized_start = 8313 - _STATCROSSTAB._serialized_end = 8414 - _STATCOV._serialized_start = 8416 - _STATCOV._serialized_end = 8512 - _STATCORR._serialized_start = 8515 - _STATCORR._serialized_end = 8652 - _STATAPPROXQUANTILE._serialized_start = 8655 - _STATAPPROXQUANTILE._serialized_end = 8819 - _STATFREQITEMS._serialized_start = 8821 - _STATFREQITEMS._serialized_end = 8946 - _STATSAMPLEBY._serialized_start = 8949 - _STATSAMPLEBY._serialized_end = 9258 - _STATSAMPLEBY_FRACTION._serialized_start = 9150 - _STATSAMPLEBY_FRACTION._serialized_end = 9249 - _NAFILL._serialized_start = 9261 - _NAFILL._serialized_end = 9395 - _NADROP._serialized_start = 9398 - _NADROP._serialized_end = 9532 - _NAREPLACE._serialized_start = 9535 - _NAREPLACE._serialized_end = 9831 - _NAREPLACE_REPLACEMENT._serialized_start = 9690 - _NAREPLACE_REPLACEMENT._serialized_end = 9831 - _TODF._serialized_start = 9833 - _TODF._serialized_end = 9921 - _WITHCOLUMNSRENAMED._serialized_start = 9924 - _WITHCOLUMNSRENAMED._serialized_end = 10163 - _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 10096 - _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 10163 - _WITHCOLUMNS._serialized_start = 10165 - _WITHCOLUMNS._serialized_end = 10284 - _WITHWATERMARK._serialized_start = 10287 - _WITHWATERMARK._serialized_end = 10421 - _HINT._serialized_start = 10424 - _HINT._serialized_end = 10556 - _UNPIVOT._serialized_start = 10559 - _UNPIVOT._serialized_end = 10886 - _UNPIVOT_VALUES._serialized_start = 10816 - _UNPIVOT_VALUES._serialized_end = 10875 - _TOSCHEMA._serialized_start = 10888 - _TOSCHEMA._serialized_end = 10994 - _REPARTITIONBYEXPRESSION._serialized_start = 10997 - _REPARTITIONBYEXPRESSION._serialized_end = 11200 - _MAPPARTITIONS._serialized_start = 11203 - _MAPPARTITIONS._serialized_end = 11384 - _GROUPMAP._serialized_start = 11387 - _GROUPMAP._serialized_end = 12022 - _COGROUPMAP._serialized_start = 12025 - _COGROUPMAP._serialized_end = 12551 - _APPLYINPANDASWITHSTATE._serialized_start = 12554 - _APPLYINPANDASWITHSTATE._serialized_end = 12911 - _COLLECTMETRICS._serialized_start = 12914 - _COLLECTMETRICS._serialized_end = 13050 - _PARSE._serialized_start = 13053 - _PARSE._serialized_end = 13441 - _PARSE_OPTIONSENTRY._serialized_start = 3842 - _PARSE_OPTIONSENTRY._serialized_end = 3900 - _PARSE_PARSEFORMAT._serialized_start = 13342 - _PARSE_PARSEFORMAT._serialized_end = 13430 + _RELATION._serialized_end = 3334 + _UNKNOWN._serialized_start = 3336 + _UNKNOWN._serialized_end = 3345 + _RELATIONCOMMON._serialized_start = 3347 + _RELATIONCOMMON._serialized_end = 3438 + _SQL._serialized_start = 3441 + _SQL._serialized_end = 3672 + _SQL_ARGSENTRY._serialized_start = 3582 + _SQL_ARGSENTRY._serialized_end = 3672 + _READ._serialized_start = 3675 + _READ._serialized_end = 4338 + _READ_NAMEDTABLE._serialized_start = 3853 + _READ_NAMEDTABLE._serialized_end = 4045 + _READ_NAMEDTABLE_OPTIONSENTRY._serialized_start = 3987 + _READ_NAMEDTABLE_OPTIONSENTRY._serialized_end = 4045 + _READ_DATASOURCE._serialized_start = 4048 + _READ_DATASOURCE._serialized_end = 4325 + _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3987 + _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 4045 + _PROJECT._serialized_start = 4340 + _PROJECT._serialized_end = 4457 + _FILTER._serialized_start = 4459 + _FILTER._serialized_end = 4571 + _JOIN._serialized_start = 4574 + _JOIN._serialized_end = 5235 + _JOIN_JOINDATATYPE._serialized_start = 4913 + _JOIN_JOINDATATYPE._serialized_end = 5005 + _JOIN_JOINTYPE._serialized_start = 5008 + _JOIN_JOINTYPE._serialized_end = 5216 + _SETOPERATION._serialized_start = 5238 + _SETOPERATION._serialized_end = 5717 + _SETOPERATION_SETOPTYPE._serialized_start = 5554 + _SETOPERATION_SETOPTYPE._serialized_end = 5668 + _LIMIT._serialized_start = 5719 + _LIMIT._serialized_end = 5795 + _OFFSET._serialized_start = 5797 + _OFFSET._serialized_end = 5876 + _TAIL._serialized_start = 5878 + _TAIL._serialized_end = 5953 + _AGGREGATE._serialized_start = 5956 + _AGGREGATE._serialized_end = 6538 + _AGGREGATE_PIVOT._serialized_start = 6295 + _AGGREGATE_PIVOT._serialized_end = 6406 + _AGGREGATE_GROUPTYPE._serialized_start = 6409 + _AGGREGATE_GROUPTYPE._serialized_end = 6538 + _SORT._serialized_start = 6541 + _SORT._serialized_end = 6701 + _DROP._serialized_start = 6704 + _DROP._serialized_end = 6845 + _DEDUPLICATE._serialized_start = 6848 + _DEDUPLICATE._serialized_end = 7088 + _LOCALRELATION._serialized_start = 7090 + _LOCALRELATION._serialized_end = 7179 + _CACHEDLOCALRELATION._serialized_start = 7181 + _CACHEDLOCALRELATION._serialized_end = 7276 + _CACHEDREMOTERELATION._serialized_start = 7278 + _CACHEDREMOTERELATION._serialized_end = 7333 + _SAMPLE._serialized_start = 7336 + _SAMPLE._serialized_end = 7609 + _RANGE._serialized_start = 7612 + _RANGE._serialized_end = 7757 + _SUBQUERYALIAS._serialized_start = 7759 + _SUBQUERYALIAS._serialized_end = 7873 + _REPARTITION._serialized_start = 7876 + _REPARTITION._serialized_end = 8018 + _SHOWSTRING._serialized_start = 8021 + _SHOWSTRING._serialized_end = 8163 + _HTMLSTRING._serialized_start = 8165 + _HTMLSTRING._serialized_end = 8279 + _STATSUMMARY._serialized_start = 8281 + _STATSUMMARY._serialized_end = 8373 + _STATDESCRIBE._serialized_start = 8375 + _STATDESCRIBE._serialized_end = 8456 + _STATCROSSTAB._serialized_start = 8458 + _STATCROSSTAB._serialized_end = 8559 + _STATCOV._serialized_start = 8561 + _STATCOV._serialized_end = 8657 + _STATCORR._serialized_start = 8660 + _STATCORR._serialized_end = 8797 + _STATAPPROXQUANTILE._serialized_start = 8800 + _STATAPPROXQUANTILE._serialized_end = 8964 + _STATFREQITEMS._serialized_start = 8966 + _STATFREQITEMS._serialized_end = 9091 + _STATSAMPLEBY._serialized_start = 9094 + _STATSAMPLEBY._serialized_end = 9403 + _STATSAMPLEBY_FRACTION._serialized_start = 9295 + _STATSAMPLEBY_FRACTION._serialized_end = 9394 + _NAFILL._serialized_start = 9406 + _NAFILL._serialized_end = 9540 + _NADROP._serialized_start = 9543 + _NADROP._serialized_end = 9677 + _NAREPLACE._serialized_start = 9680 + _NAREPLACE._serialized_end = 9976 + _NAREPLACE_REPLACEMENT._serialized_start = 9835 + _NAREPLACE_REPLACEMENT._serialized_end = 9976 + _TODF._serialized_start = 9978 + _TODF._serialized_end = 10066 + _WITHCOLUMNSRENAMED._serialized_start = 10069 + _WITHCOLUMNSRENAMED._serialized_end = 10308 + _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 10241 + _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 10308 + _WITHCOLUMNS._serialized_start = 10310 + _WITHCOLUMNS._serialized_end = 10429 + _WITHWATERMARK._serialized_start = 10432 + _WITHWATERMARK._serialized_end = 10566 + _HINT._serialized_start = 10569 + _HINT._serialized_end = 10701 + _UNPIVOT._serialized_start = 10704 + _UNPIVOT._serialized_end = 11031 + _UNPIVOT_VALUES._serialized_start = 10961 + _UNPIVOT_VALUES._serialized_end = 11020 + _TOSCHEMA._serialized_start = 11033 + _TOSCHEMA._serialized_end = 11139 + _REPARTITIONBYEXPRESSION._serialized_start = 11142 + _REPARTITIONBYEXPRESSION._serialized_end = 11345 + _MAPPARTITIONS._serialized_start = 11348 + _MAPPARTITIONS._serialized_end = 11529 + _GROUPMAP._serialized_start = 11532 + _GROUPMAP._serialized_end = 12167 + _COGROUPMAP._serialized_start = 12170 + _COGROUPMAP._serialized_end = 12696 + _APPLYINPANDASWITHSTATE._serialized_start = 12699 + _APPLYINPANDASWITHSTATE._serialized_end = 13056 + _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_start = 13059 + _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_end = 13303 + _PYTHONUDTF._serialized_start = 13306 + _PYTHONUDTF._serialized_end = 13483 + _COLLECTMETRICS._serialized_start = 13486 + _COLLECTMETRICS._serialized_end = 13622 + _PARSE._serialized_start = 13625 + _PARSE._serialized_end = 14013 + _PARSE_OPTIONSENTRY._serialized_start = 3987 + _PARSE_OPTIONSENTRY._serialized_end = 4045 + _PARSE_PARSEFORMAT._serialized_start = 13914 + _PARSE_PARSEFORMAT._serialized_end = 14002 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi index 1f15ed2c6c7..9cadd4acc52 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.pyi +++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi @@ -99,6 +99,7 @@ class Relation(google.protobuf.message.Message): HTML_STRING_FIELD_NUMBER: builtins.int CACHED_LOCAL_RELATION_FIELD_NUMBER: builtins.int CACHED_REMOTE_RELATION_FIELD_NUMBER: builtins.int + COMMON_INLINE_USER_DEFINED_TABLE_FUNCTION_FIELD_NUMBER: builtins.int FILL_NA_FIELD_NUMBER: builtins.int DROP_NA_FIELD_NUMBER: builtins.int REPLACE_FIELD_NUMBER: builtins.int @@ -188,6 +189,10 @@ class Relation(google.protobuf.message.Message): @property def cached_remote_relation(self) -> global___CachedRemoteRelation: ... @property + def common_inline_user_defined_table_function( + self, + ) -> global___CommonInlineUserDefinedTableFunction: ... + @property def fill_na(self) -> global___NAFill: """NA functions""" @property @@ -261,6 +266,8 @@ class Relation(google.protobuf.message.Message): html_string: global___HtmlString | None = ..., cached_local_relation: global___CachedLocalRelation | None = ..., cached_remote_relation: global___CachedRemoteRelation | None = ..., + common_inline_user_defined_table_function: global___CommonInlineUserDefinedTableFunction + | None = ..., fill_na: global___NAFill | None = ..., drop_na: global___NADrop | None = ..., replace: global___NAReplace | None = ..., @@ -297,6 +304,8 @@ class Relation(google.protobuf.message.Message): b"collect_metrics", "common", b"common", + "common_inline_user_defined_table_function", + b"common_inline_user_defined_table_function", "corr", b"corr", "cov", @@ -406,6 +415,8 @@ class Relation(google.protobuf.message.Message): b"collect_metrics", "common", b"common", + "common_inline_user_defined_table_function", + b"common_inline_user_defined_table_function", "corr", b"corr", "cov", @@ -533,6 +544,7 @@ class Relation(google.protobuf.message.Message): "html_string", "cached_local_relation", "cached_remote_relation", + "common_inline_user_defined_table_function", "fill_na", "drop_na", "replace", @@ -3378,6 +3390,113 @@ class ApplyInPandasWithState(google.protobuf.message.Message): global___ApplyInPandasWithState = ApplyInPandasWithState +class CommonInlineUserDefinedTableFunction(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + FUNCTION_NAME_FIELD_NUMBER: builtins.int + DETERMINISTIC_FIELD_NUMBER: builtins.int + ARGUMENTS_FIELD_NUMBER: builtins.int + PYTHON_UDTF_FIELD_NUMBER: builtins.int + function_name: builtins.str + """(Required) Name of the user-defined table function.""" + deterministic: builtins.bool + """(Optional) Whether the user-defined table function is deterministic.""" + @property + def arguments( + self, + ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ + pyspark.sql.connect.proto.expressions_pb2.Expression + ]: + """(Optional) Function input arguments. Empty arguments are allowed.""" + @property + def python_udtf(self) -> global___PythonUDTF: ... + def __init__( + self, + *, + function_name: builtins.str = ..., + deterministic: builtins.bool = ..., + arguments: collections.abc.Iterable[pyspark.sql.connect.proto.expressions_pb2.Expression] + | None = ..., + python_udtf: global___PythonUDTF | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "function", b"function", "python_udtf", b"python_udtf" + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "arguments", + b"arguments", + "deterministic", + b"deterministic", + "function", + b"function", + "function_name", + b"function_name", + "python_udtf", + b"python_udtf", + ], + ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["function", b"function"] + ) -> typing_extensions.Literal["python_udtf"] | None: ... + +global___CommonInlineUserDefinedTableFunction = CommonInlineUserDefinedTableFunction + +class PythonUDTF(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + RETURN_TYPE_FIELD_NUMBER: builtins.int + EVAL_TYPE_FIELD_NUMBER: builtins.int + COMMAND_FIELD_NUMBER: builtins.int + PYTHON_VER_FIELD_NUMBER: builtins.int + @property + def return_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: + """(Optional) Return type of the Python UDTF.""" + eval_type: builtins.int + """(Required) EvalType of the Python UDTF.""" + command: builtins.bytes + """(Required) The encoded commands of the Python UDTF.""" + python_ver: builtins.str + """(Required) Python version being used in the client.""" + def __init__( + self, + *, + return_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ..., + eval_type: builtins.int = ..., + command: builtins.bytes = ..., + python_ver: builtins.str = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "_return_type", b"_return_type", "return_type", b"return_type" + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "_return_type", + b"_return_type", + "command", + b"command", + "eval_type", + b"eval_type", + "python_ver", + b"python_ver", + "return_type", + b"return_type", + ], + ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_return_type", b"_return_type"] + ) -> typing_extensions.Literal["return_type"] | None: ... + +global___PythonUDTF = PythonUDTF + class CollectMetrics(google.protobuf.message.Message): """Collect arbitrary (named) metrics from a dataset.""" diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index ea88d60d760..3f9d46a22f4 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -88,6 +88,7 @@ if TYPE_CHECKING: from pyspark.sql.connect._typing import OptionalPrimitiveType from pyspark.sql.connect.catalog import Catalog from pyspark.sql.connect.udf import UDFRegistration + from pyspark.sql.connect.udtf import UDTFRegistration # `_active_spark_session` stores the active spark connect session created by @@ -599,7 +600,7 @@ class SparkSession: raise PySparkAttributeError( error_class="JVM_ATTRIBUTE_NOT_SUPPORTED", message_parameters={"attr_name": name} ) - elif name in ["newSession", "sparkContext", "udtf"]: + elif name in ["newSession", "sparkContext"]: raise PySparkNotImplementedError( error_class="NOT_IMPLEMENTED", message_parameters={"feature": f"{name}()"} ) @@ -613,6 +614,14 @@ class SparkSession: udf.__doc__ = PySparkSession.udf.__doc__ + @property + def udtf(self) -> "UDTFRegistration": + from pyspark.sql.connect.udtf import UDTFRegistration + + return UDTFRegistration(self) + + udtf.__doc__ = PySparkSession.udtf.__doc__ + @property def version(self) -> str: result = self._client._analyze(method="spark_version").spark_version diff --git a/python/pyspark/sql/connect/udtf.py b/python/pyspark/sql/connect/udtf.py new file mode 100644 index 00000000000..1fe8e1024ee --- /dev/null +++ b/python/pyspark/sql/connect/udtf.py @@ -0,0 +1,205 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +User-defined table function related classes and functions +""" +from pyspark.sql.connect.utils import check_dependencies + +check_dependencies(__name__) + +import warnings +from typing import Type, TYPE_CHECKING, Optional, Union + +from pyspark.rdd import PythonEvalType +from pyspark.sql.connect.column import Column +from pyspark.sql.connect.expressions import ColumnReference +from pyspark.sql.connect.plan import ( + CommonInlineUserDefinedTableFunction, + PythonUDTF, +) +from pyspark.sql.connect.types import UnparsedDataType +from pyspark.sql.connect.utils import get_python_ver +from pyspark.sql.udtf import UDTFRegistration as PySparkUDTFRegistration +from pyspark.sql.udtf import _validate_udtf_handler +from pyspark.sql.types import DataType, StructType +from pyspark.errors import PySparkRuntimeError, PySparkTypeError + + +if TYPE_CHECKING: + from pyspark.sql.connect._typing import ColumnOrName + from pyspark.sql.connect.dataframe import DataFrame + from pyspark.sql.connect.session import SparkSession + + +def _create_udtf( + cls: Type, + returnType: Union[StructType, str], + name: Optional[str] = None, + evalType: int = PythonEvalType.SQL_TABLE_UDF, + deterministic: bool = True, +) -> "UserDefinedTableFunction": + udtf_obj = UserDefinedTableFunction( + cls, returnType=returnType, name=name, evalType=evalType, deterministic=deterministic + ) + return udtf_obj + + +def _create_py_udtf( + cls: Type, + returnType: Union[StructType, str], + name: Optional[str] = None, + deterministic: bool = True, + useArrow: Optional[bool] = None, +) -> "UserDefinedTableFunction": + if useArrow is not None: + arrow_enabled = useArrow + else: + from pyspark.sql.connect.session import _active_spark_session + + arrow_enabled = ( + _active_spark_session.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled") == "true" + if _active_spark_session is not None + else True + ) + + # Create a regular Python UDTF and check for invalid handler class. + regular_udtf = _create_udtf(cls, returnType, name, PythonEvalType.SQL_TABLE_UDF, deterministic) + + if not arrow_enabled: + return regular_udtf + + from pyspark.sql.pandas.utils import ( + require_minimum_pandas_version, + require_minimum_pyarrow_version, + ) + + try: + require_minimum_pandas_version() + require_minimum_pyarrow_version() + except ImportError as e: + warnings.warn( + f"Arrow optimization for Python UDTFs cannot be enabled: {str(e)}. " + f"Falling back to using regular Python UDTFs.", + UserWarning, + ) + return regular_udtf + + from pyspark.sql.udtf import _vectorize_udtf + + vectorized_udtf = _vectorize_udtf(cls) + return _create_udtf( + vectorized_udtf, returnType, name, PythonEvalType.SQL_ARROW_TABLE_UDF, deterministic + ) + + +class UserDefinedTableFunction: + """ + User defined function in Python + + Notes + ----- + The constructor of this class is not supposed to be directly called. + Use :meth:`pyspark.sql.functions.udtf` to create this instance. + """ + + def __init__( + self, + func: Type, + returnType: Union[StructType, str], + name: Optional[str] = None, + evalType: int = PythonEvalType.SQL_TABLE_UDF, + deterministic: bool = True, + ) -> None: + self.func = func + self.returnType: DataType = ( + UnparsedDataType(returnType) if isinstance(returnType, str) else returnType + ) + self._name = name or func.__name__ + self.evalType = evalType + self.deterministic = deterministic + + _validate_udtf_handler(func) + + def _build_common_inline_user_defined_table_function( + self, *cols: "ColumnOrName" + ) -> CommonInlineUserDefinedTableFunction: + arg_cols = [ + col if isinstance(col, Column) else Column(ColumnReference(col)) for col in cols + ] + arg_exprs = [col._expr for col in arg_cols] + + udtf = PythonUDTF( + func=self.func, + return_type=self.returnType, + eval_type=self.evalType, + python_ver=get_python_ver(), + ) + return CommonInlineUserDefinedTableFunction( + function_name=self._name, + function=udtf, + deterministic=self.deterministic, + arguments=arg_exprs, + ) + + def __call__(self, *cols: "ColumnOrName") -> "DataFrame": + from pyspark.sql.connect.dataframe import DataFrame + from pyspark.sql.connect.session import _active_spark_session + + if _active_spark_session is None: + raise PySparkRuntimeError( + "An active SparkSession is required for " + "executing a Python user-defined table function." + ) + + plan = self._build_common_inline_user_defined_table_function(*cols) + return DataFrame.withPlan(plan, _active_spark_session) + + def asNondeterministic(self) -> "UserDefinedTableFunction": + self.deterministic = False + return self + + +class UDTFRegistration: + """ + Wrapper for user-defined table function registration. + + .. versionadded:: 3.5.0 + """ + + def __init__(self, sparkSession: "SparkSession"): + self.sparkSession = sparkSession + + def register( + self, + name: str, + f: "UserDefinedTableFunction", + ) -> "UserDefinedTableFunction": + if f.evalType not in [PythonEvalType.SQL_TABLE_UDF, PythonEvalType.SQL_ARROW_TABLE_UDF]: + raise PySparkTypeError( + error_class="INVALID_UDTF_EVAL_TYPE", + message_parameters={ + "name": name, + "eval_type": "SQL_TABLE_UDF, SQL_ARROW_TABLE_UDF", + }, + ) + + self.sparkSession._client.register_udtf( + f.func, f.returnType, name, f.evalType, f.deterministic + ) + return f + + register.__doc__ = PySparkUDTFRegistration.register.__doc__ diff --git a/python/pyspark/sql/connect/utils.py b/python/pyspark/sql/connect/utils.py index 25a94676551..8872ba50633 100644 --- a/python/pyspark/sql/connect/utils.py +++ b/python/pyspark/sql/connect/utils.py @@ -52,3 +52,7 @@ def require_minimum_grpc_version() -> None: "grpcio >= %s must be installed; however, " "your version was %s." % (minimum_grpc_version, grpc.__version__) ) + + +def get_python_ver() -> str: + return "%d.%d" % sys.version_info[:2] diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b2017627598..f566fcee0e3 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -15510,12 +15510,13 @@ def udf( return _create_py_udf(f=f, returnType=returnType, useArrow=useArrow) +@try_remote_functions def udtf( cls: Optional[Type] = None, *, returnType: Union[StructType, str], useArrow: Optional[bool] = None, -) -> Union[UserDefinedTableFunction, functools.partial]: +) -> Union["UserDefinedTableFunction", Callable[[Type], "UserDefinedTableFunction"]]: """Creates a user defined table function (UDTF). .. versionadded:: 3.5.0 diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 268011ef1e4..c1235620990 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -27,7 +27,6 @@ from collections import defaultdict from pyspark.errors import ( PySparkAttributeError, - PySparkNotImplementedError, PySparkTypeError, PySparkException, PySparkValueError, @@ -3186,16 +3185,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase): rows = [cols] * row_count self.assertEqual(row_count, self.connect.createDataFrame(data=rows).count()) - def test_unsupported_udtf(self): - with self.assertRaises(PySparkNotImplementedError) as e: - self.connect.udtf.register() - - self.check_error( - exception=e.exception, - error_class="NOT_IMPLEMENTED", - message_parameters={"feature": "udtf()"}, - ) - def test_unsupported_jvm_attribute(self): # Unsupported jvm attributes for Spark session. unsupported_attrs = ["_jsc", "_jconf", "_jvm", "_jsparkSession"] diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py index 3e3b4dd5b16..a5d330fe1a7 100644 --- a/python/pyspark/sql/tests/connect/test_connect_function.py +++ b/python/pyspark/sql/tests/connect/test_connect_function.py @@ -20,6 +20,7 @@ import unittest from pyspark.errors import PySparkTypeError, PySparkValueError from pyspark.sql import SparkSession as PySparkSession from pyspark.sql.types import StringType, StructType, StructField, ArrayType, IntegerType +from pyspark.testing import assertDataFrameEqual from pyspark.testing.pandasutils import PandasOnSparkTestUtils from pyspark.testing.connectutils import ReusedConnectTestCase, should_test_connect from pyspark.testing.sqlutils import SQLTestUtils @@ -2342,6 +2343,23 @@ class SparkConnectFunctionTests(ReusedConnectTestCase, PandasOnSparkTestUtils, S sdf.withColumn("A", sfun(sdf.c)).toPandas(), ) + def test_udtf(self): + class TestUDTF: + def eval(self, x: int, y: int): + yield x, x + 1 + yield y, y + 1 + + sfunc = SF.udtf(TestUDTF, returnType="a: int, b: int") + cfunc = CF.udtf(TestUDTF, returnType="a: int, b: int") + + assertDataFrameEqual(sfunc(SF.lit(1), SF.lit(1)), cfunc(CF.lit(1), CF.lit(1))) + + self.spark.udtf.register("test_udtf", sfunc) + self.connect.udtf.register("test_udtf", cfunc) + + query = "select * from test_udtf(1, 2)" + assertDataFrameEqual(self.spark.sql(query), self.connect.sql(query)) + def test_pandas_udf_import(self): self.assert_eq(getattr(CF, "pandas_udf"), getattr(SF, "pandas_udf")) diff --git a/python/pyspark/sql/tests/connect/test_parity_udtf.py b/python/pyspark/sql/tests/connect/test_parity_udtf.py new file mode 100644 index 00000000000..f5f37f1f5c0 --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_parity_udtf.py @@ -0,0 +1,141 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from pyspark.testing.connectutils import should_test_connect + +if should_test_connect: + from pyspark import sql + from pyspark.sql.connect.udtf import UserDefinedTableFunction + + sql.udtf.UserDefinedTableFunction = UserDefinedTableFunction + +from pyspark.sql.connect.functions import lit, udtf +from pyspark.sql.tests.test_udtf import BaseUDTFTestsMixin, UDTFArrowTestsMixin +from pyspark.testing.connectutils import ReusedConnectTestCase +from pyspark.errors.exceptions.connect import SparkConnectGrpcException + + +class UDTFParityTests(BaseUDTFTestsMixin, ReusedConnectTestCase): + @classmethod + def setUpClass(cls): + super(UDTFParityTests, cls).setUpClass() + cls.spark.conf.set("spark.sql.execution.pythonUDTF.arrow.enabled", "false") + + @classmethod + def tearDownClass(cls): + try: + cls.spark.conf.unset("spark.sql.execution.pythonUDTF.arrow.enabled") + finally: + super(UDTFParityTests, cls).tearDownClass() + + # TODO: use PySpark error classes instead of SparkConnectGrpcException + + def test_udtf_with_invalid_return_type(self): + @udtf(returnType="int") + class TestUDTF: + def eval(self, a: int): + yield a + 1, + + with self.assertRaisesRegex( + SparkConnectGrpcException, "Invalid Python user-defined table function return type." + ): + TestUDTF(lit(1)).collect() + + def test_udtf_with_wrong_num_output(self): + err_msg = ( + "java.lang.IllegalStateException: Input row doesn't have expected number of " + + "values required by the schema." + ) + + @udtf(returnType="a: int, b: int") + class TestUDTF: + def eval(self, a: int): + yield a, + + with self.assertRaisesRegex(SparkConnectGrpcException, err_msg): + TestUDTF(lit(1)).collect() + + @udtf(returnType="a: int") + class TestUDTF: + def eval(self, a: int): + yield a, a + 1 + + with self.assertRaisesRegex(SparkConnectGrpcException, err_msg): + TestUDTF(lit(1)).collect() + + def test_udtf_terminate_with_wrong_num_output(self): + err_msg = ( + "java.lang.IllegalStateException: Input row doesn't have expected number of " + "values required by the schema." + ) + + @udtf(returnType="a: int, b: int") + class TestUDTF: + def eval(self, a: int): + yield a, a + 1 + + def terminate(self): + yield 1, 2, 3 + + with self.assertRaisesRegex(SparkConnectGrpcException, err_msg): + TestUDTF(lit(1)).show() + + @udtf(returnType="a: int, b: int") + class TestUDTF: + def eval(self, a: int): + yield a, a + 1 + + def terminate(self): + yield 1, + + with self.assertRaisesRegex(SparkConnectGrpcException, err_msg): + TestUDTF(lit(1)).show() + + def test_udtf_with_empty_yield(self): + @udtf(returnType="a: int") + class TestUDTF: + def eval(self, a: int): + yield + + with self.assertRaisesRegex(SparkConnectGrpcException, "java.lang.NullPointerException"): + TestUDTF(lit(1)).collect() + + +class ArrowUDTFParityTests(UDTFArrowTestsMixin, UDTFParityTests): + @classmethod + def setUpClass(cls): + super(ArrowUDTFParityTests, cls).setUpClass() + cls.spark.conf.set("spark.sql.execution.pythonUDTF.arrow.enabled", "true") + + @classmethod + def tearDownClass(cls): + try: + cls.spark.conf.unset("spark.sql.execution.pythonUDTF.arrow.enabled") + finally: + super(ArrowUDTFParityTests, cls).tearDownClass() + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.connect.test_parity_udtf import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index 8f42b4123e4..f109302dec5 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -30,6 +30,8 @@ from pyspark.errors import ( from pyspark.rdd import PythonEvalType from pyspark.sql.functions import lit, udf, udtf from pyspark.sql.types import Row +from pyspark.testing import assertDataFrameEqual +from pyspark.sql.types import MapType, StringType, IntegerType from pyspark.testing.sqlutils import ( have_pandas, have_pyarrow, @@ -207,8 +209,8 @@ class BaseUDTFTestsMixin: self.assertEqual(TestUDTF(lit(1)).collect(), [Row(a=1), Row(a=None)]) df = self.spark.createDataFrame([(0, 1), (1, 2)], schema=["a", "b"]) self.assertEqual(TestUDTF(lit(1)).join(df, "a", "inner").collect(), [Row(a=1, b=2)]) - self.assertEqual( - TestUDTF(lit(1)).join(df, "a", "left").collect(), [Row(a=None, b=None), Row(a=1, b=2)] + assertDataFrameEqual( + TestUDTF(lit(1)).join(df, "a", "left"), [Row(a=None, b=None), Row(a=1, b=2)] ) def test_udtf_with_none_input(self): @@ -381,6 +383,38 @@ class BaseUDTFTestsMixin: ): TestUDTF(rand(0) * 100).collect() + def test_udtf_with_invalid_return_type(self): + @udtf(returnType="int") + class TestUDTF: + def eval(self, a: int): + yield a + 1, + + with self.assertRaises(PySparkTypeError) as e: + TestUDTF(lit(1)).collect() + + self.check_error( + exception=e.exception, + error_class="UDTF_RETURN_TYPE_MISMATCH", + message_parameters={"name": "TestUDTF", "return_type": "IntegerType()"}, + ) + + @udtf(returnType=MapType(StringType(), IntegerType())) + class TestUDTF: + def eval(self, a: int): + yield a + 1, + + with self.assertRaises(PySparkTypeError) as e: + TestUDTF(lit(1)).collect() + + self.check_error( + exception=e.exception, + error_class="UDTF_RETURN_TYPE_MISMATCH", + message_parameters={ + "name": "TestUDTF", + "return_type": "MapType(StringType(), IntegerType(), True)", + }, + ) + def test_udtf_with_struct_input_type(self): @udtf(returnType="x: string") class TestUDTF: diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py index b0f373430be..3ab74193093 100644 --- a/python/pyspark/sql/udtf.py +++ b/python/pyspark/sql/udtf.py @@ -19,7 +19,7 @@ User-defined table function related classes and functions """ import sys import warnings -from typing import Iterator, Type, TYPE_CHECKING, Optional, Union +from typing import Any, Iterator, Type, TYPE_CHECKING, Optional, Union from py4j.java_gateway import JavaObject @@ -77,27 +77,35 @@ def _create_py_udtf( # Create a regular Python UDTF and check for invalid handler class. regular_udtf = _create_udtf(cls, returnType, name, PythonEvalType.SQL_TABLE_UDF, deterministic) - if arrow_enabled: - try: - require_minimum_pandas_version() - require_minimum_pyarrow_version() - except ImportError as e: - warnings.warn( - f"Arrow optimization for Python UDTFs cannot be enabled: {str(e)}. " - f"Falling back to using regular Python UDTFs.", - UserWarning, - ) - return regular_udtf - return _create_arrow_udtf(regular_udtf) - else: + if not arrow_enabled: return regular_udtf + # Return the regular UDTF if the required dependencies are not satisfied. + try: + require_minimum_pandas_version() + require_minimum_pyarrow_version() + except ImportError as e: + warnings.warn( + f"Arrow optimization for Python UDTFs cannot be enabled: {str(e)}. " + f"Falling back to using regular Python UDTFs.", + UserWarning, + ) + return regular_udtf -def _create_arrow_udtf(regular_udtf: "UserDefinedTableFunction") -> "UserDefinedTableFunction": - """Create an Arrow-optimized Python UDTF.""" - import pandas as pd + # Return the vectorized UDTF. + vectorized_udtf = _vectorize_udtf(cls) + return _create_udtf( + cls=vectorized_udtf, + returnType=returnType, + name=name, + evalType=PythonEvalType.SQL_ARROW_TABLE_UDF, + deterministic=regular_udtf.deterministic, + ) - cls = regular_udtf.func + +def _vectorize_udtf(cls: Type) -> Type: + """Vectorize a Python UDTF handler class.""" + import pandas as pd class VectorizedUDTF: def __init__(self) -> None: @@ -126,13 +134,24 @@ def _create_arrow_udtf(regular_udtf: "UserDefinedTableFunction") -> "UserDefined if hasattr(cls, "terminate"): getattr(vectorized_udtf, "terminate").__doc__ = getattr(cls, "terminate").__doc__ - return _create_udtf( - cls=vectorized_udtf, - returnType=regular_udtf.returnType, - name=regular_udtf._name, - evalType=PythonEvalType.SQL_ARROW_TABLE_UDF, - deterministic=regular_udtf.deterministic, - ) + return vectorized_udtf + + +def _validate_udtf_handler(cls: Any) -> None: + """Validate the handler class of a UDTF.""" + # TODO(SPARK-43968): add more compile time checks for UDTFs + + if not isinstance(cls, type): + raise PySparkTypeError( + f"Invalid user defined table function: the function handler " + f"must be a class, but got {type(cls).__name__}. Please provide " + "a class as the handler." + ) + + if not hasattr(cls, "eval"): + raise PySparkAttributeError( + error_class="INVALID_UDTF_NO_EVAL", message_parameters={"name": cls.__name__} + ) class UserDefinedTableFunction: @@ -157,15 +176,6 @@ class UserDefinedTableFunction: evalType: int = PythonEvalType.SQL_TABLE_UDF, deterministic: bool = True, ): - - if not isinstance(func, type): - raise PySparkTypeError( - f"Invalid user defined table function: the function handler " - f"must be a class, but got {type(func).__name__}. Please provide " - "a class as the handler." - ) - - # TODO(SPARK-43968): add more compile time checks for UDTFs self.func = func self._returnType = returnType self._returnType_placeholder: Optional[StructType] = None @@ -175,29 +185,26 @@ class UserDefinedTableFunction: self.evalType = evalType self.deterministic = deterministic - if not hasattr(func, "eval"): - raise PySparkAttributeError( - error_class="INVALID_UDTF_NO_EVAL", message_parameters={"name": self._name} - ) + _validate_udtf_handler(func) @property def returnType(self) -> StructType: # `_parse_datatype_string` accesses to JVM for parsing a DDL formatted string. # This makes sure this is called after SparkContext is initialized. if self._returnType_placeholder is None: - if isinstance(self._returnType, StructType): - self._returnType_placeholder = self._returnType - else: - assert isinstance(self._returnType, str) + if isinstance(self._returnType, str): parsed = _parse_datatype_string(self._returnType) - if not isinstance(parsed, StructType): - raise PySparkTypeError( - f"Invalid return type for the user defined table function " - f"'{self._name}': {self._returnType}. The return type of a " - f"UDTF must be a 'StructType'. Please ensure the return " - "type is a correctly formatted 'StructType' string." - ) - self._returnType_placeholder = parsed + else: + parsed = self._returnType + if not isinstance(parsed, StructType): + raise PySparkTypeError( + error_class="UDTF_RETURN_TYPE_MISMATCH", + message_parameters={ + "name": self._name, + "return_type": f"{parsed}", + }, + ) + self._returnType_placeholder = parsed return self._returnType_placeholder @property @@ -254,8 +261,8 @@ class UDTFRegistration: def register( self, name: str, - f: UserDefinedTableFunction, - ) -> UserDefinedTableFunction: + f: "UserDefinedTableFunction", + ) -> "UserDefinedTableFunction": """Register a Python user-defined table function as a SQL table function. .. versionadded:: 3.5.0 --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org