This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new d4629563492 [SPARK-44749][SQL][PYTHON] Support named arguments in Python UDTF
d4629563492 is described below

commit d4629563492ec3090b5bd5b924790507c42f4e86
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Mon Aug 14 08:57:55 2023 -0700

[SPARK-44749][SQL][PYTHON] Support named arguments in Python UDTF

### What changes were proposed in this pull request?

Supports named arguments in Python UDTF. For example:

```py
>>> @udtf(returnType="a: int")
... class TestUDTF:
...     def eval(self, a, b):
...         yield a,
...
>>> spark.udtf.register("test_udtf", TestUDTF)

>>> TestUDTF(a=lit(10), b=lit("x")).show()
+---+
|  a|
+---+
| 10|
+---+

>>> TestUDTF(b=lit("x"), a=lit(10)).show()
+---+
|  a|
+---+
| 10|
+---+

>>> spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')").show()
+---+
|  a|
+---+
| 10|
+---+

>>> spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)").show()
+---+
|  a|
+---+
| 10|
+---+
```

or:

```py
>>> @udtf
... class TestUDTF:
...     @staticmethod
...     def analyze(**kwargs: AnalyzeArgument) -> AnalyzeResult:
...         return AnalyzeResult(
...             StructType(
...                 [StructField(key, arg.data_type) for key, arg in sorted(kwargs.items())]
...             )
...         )
...     def eval(self, **kwargs):
...         yield tuple(value for _, value in sorted(kwargs.items()))
...
>>> spark.udtf.register("test_udtf", TestUDTF)

>>> spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x', x=>100.0)").show()
+---+---+-----+
|  a|  b|    x|
+---+---+-----+
| 10|  x|100.0|
+---+---+-----+

>>> spark.sql("SELECT * FROM test_udtf(x=>10, a=>'x', z=>100.0)").show()
+---+---+-----+
|  a|  x|    z|
+---+---+-----+
|  x| 10|100.0|
+---+---+-----+
```

### Why are the changes needed?

Named arguments are now supported for SQL function calls (https://github.com/apache/spark/pull/41796, https://github.com/apache/spark/pull/42020), so Python UDTFs should support them as well.

### Does this PR introduce _any_ user-facing change?

Yes, named arguments will be available for Python UDTFs.

### How was this patch tested?

Added related tests.

Closes #42422 from ueshin/issues/SPARK-44749/kwargs.
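For readers trying this out, a minimal sketch of mixing positional and named arguments in one call, based on the `functions.py` doctest added in this change (assumes a Spark build containing this commit and an active `spark` session; `MixedArgsUDTF` is an illustrative name):

```py
from pyspark.sql.functions import lit, udtf

@udtf(returnType="a: int, b: string, x: string")
class MixedArgsUDTF:
    def eval(self, a, b, x):
        # Positional arguments bind first; the remaining parameters
        # are matched to keyword arguments by name.
        yield a, b, x

# DataFrame API: positional `a` first, then named arguments in any order.
MixedArgsUDTF(lit(1), x=lit("x"), b=lit("b")).show()

# SQL: named arguments use the `=>` syntax.
spark.udtf.register("mixed_args_udtf", MixedArgsUDTF)
spark.sql("SELECT * FROM mixed_args_udtf(1, x=>'x', b=>'b')").show()
```

Per the tests added in this patch, duplicate named arguments (`a=>10, a=>100`) and a positional argument following a named one (`a=>10, 'x'`) raise `AnalysisException` with `DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE` and `UNEXPECTED_POSITIONAL_ARGUMENT`, respectively.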
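Deeper in the patch, the JVM side (`PythonUDTFRunner.writeUDTF`) serializes each argument as a column offset plus an optional name, and `pyspark/worker.py` splits these into positional offsets and a name-to-offset map. A rough sketch of the reading side, simplified from the `read_udtf` changes below (`read_argument_offsets` is an illustrative wrapper name; the actual logic lives inline in `read_udtf`, and `infile` is the worker's input stream):

```py
from pyspark.serializers import read_bool, read_int, UTF8Deserializer

utf8_deserializer = UTF8Deserializer()

def read_argument_offsets(infile):
    """Split serialized UDTF arguments into positional offsets and a
    name -> offset map, mirroring the read_udtf changes in worker.py."""
    args_offsets, kwargs_offsets = [], {}
    for _ in range(read_int(infile)):   # number of arguments
        offset = read_int(infile)       # column offset in the projected row
        if read_bool(infile):           # True if it is a named argument
            name = utf8_deserializer.loads(infile)
            kwargs_offsets[name] = offset
        else:
            args_offsets.append(offset)
    return args_offsets, kwargs_offsets

# Each projected row `a` is then rebound to eval() as:
#   eval(*[a[o] for o in args_offsets],
#        **{k: a[o] for k, o in kwargs_offsets.items()})
```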
Authored-by: Takuya UESHIN <ues...@databricks.com> Signed-off-by: Takuya UESHIN <ues...@databricks.com> --- .../main/protobuf/spark/connect/expressions.proto | 9 ++ .../sql/connect/planner/SparkConnectPlanner.scala | 7 ++ python/pyspark/sql/column.py | 14 +++ python/pyspark/sql/connect/expressions.py | 20 ++++ .../pyspark/sql/connect/proto/expressions_pb2.py | 122 +++++++++++---------- .../pyspark/sql/connect/proto/expressions_pb2.pyi | 34 ++++++ python/pyspark/sql/connect/udtf.py | 18 +-- python/pyspark/sql/functions.py | 38 +++++++ python/pyspark/sql/tests/test_udtf.py | 88 +++++++++++++++ python/pyspark/sql/udtf.py | 36 ++++-- python/pyspark/sql/worker/analyze_udtf.py | 20 +++- python/pyspark/worker.py | 29 ++++- .../plans/logical/FunctionBuilderBase.scala | 67 +++++++---- .../execution/python/ArrowEvalPythonUDTFExec.scala | 5 +- .../execution/python/ArrowPythonUDTFRunner.scala | 8 +- .../execution/python/BatchEvalPythonUDTFExec.scala | 30 +++-- .../sql/execution/python/EvalPythonUDTFExec.scala | 33 ++++-- .../python/UserDefinedPythonFunction.scala | 42 +++++-- 18 files changed, 472 insertions(+), 148 deletions(-) diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto index 557b9db9123..b222f663cd0 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto @@ -47,6 +47,7 @@ message Expression { UnresolvedNamedLambdaVariable unresolved_named_lambda_variable = 14; CommonInlineUserDefinedFunction common_inline_user_defined_function = 15; CallFunction call_function = 16; + NamedArgumentExpression named_argument_expression = 17; // This field is used to mark extensions to the protocol. When plugins generate arbitrary // relations they can add them here. During the planning the correct resolution is done. @@ -380,3 +381,11 @@ message CallFunction { // (Optional) Function arguments. Empty arguments are allowed. repeated Expression arguments = 2; } + +message NamedArgumentExpression { + // (Required) The key of the named argument. + string key = 1; + + // (Required) The value expression of the named argument. 
+ Expression value = 2; +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index e6305cd9d1a..4ab24cb058b 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -1384,6 +1384,8 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { transformCommonInlineUserDefinedFunction(exp.getCommonInlineUserDefinedFunction) case proto.Expression.ExprTypeCase.CALL_FUNCTION => transformCallFunction(exp.getCallFunction) + case proto.Expression.ExprTypeCase.NAMED_ARGUMENT_EXPRESSION => + transformNamedArgumentExpression(exp.getNamedArgumentExpression) case _ => throw InvalidPlanInput( s"Expression with ID: ${exp.getExprTypeCase.getNumber} is not supported") @@ -1505,6 +1507,11 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { false) } + private def transformNamedArgumentExpression( + namedArg: proto.NamedArgumentExpression): Expression = { + NamedArgumentExpression(namedArg.getKey, transformExpression(namedArg.getValue)) + } + private def unpackUdf(fun: proto.CommonInlineUserDefinedFunction): UdfPacket = { unpackScalarScalaUDF[UdfPacket](fun.getScalarScalaUdf) } diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 087cfaaa20b..3a6d6e1cea7 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -73,9 +73,23 @@ def _to_java_expr(col: "ColumnOrName") -> JavaObject: return _to_java_column(col).expr() +@overload +def _to_seq(sc: SparkContext, cols: Iterable[JavaObject]) -> JavaObject: + pass + + +@overload def _to_seq( sc: SparkContext, cols: Iterable["ColumnOrName"], + converter: Optional[Callable[["ColumnOrName"], JavaObject]], +) -> JavaObject: + pass + + +def _to_seq( + sc: SparkContext, + cols: Union[Iterable["ColumnOrName"], Iterable[JavaObject]], converter: Optional[Callable[["ColumnOrName"], JavaObject]] = None, ) -> JavaObject: """ diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py index d0a9b1d69ae..34aa4da1117 100644 --- a/python/pyspark/sql/connect/expressions.py +++ b/python/pyspark/sql/connect/expressions.py @@ -1054,3 +1054,23 @@ class CallFunction(Expression): return f"CallFunction('{self._name}', {', '.join([str(arg) for arg in self._args])})" else: return f"CallFunction('{self._name}')" + + +class NamedArgumentExpression(Expression): + def __init__(self, key: str, value: Expression): + super().__init__() + + assert isinstance(key, str) + self._key = key + + assert isinstance(value, Expression) + self._value = value + + def to_plan(self, session: "SparkConnectClient") -> "proto.Expression": + expr = proto.Expression() + expr.named_argument_expression.key = self._key + expr.named_argument_expression.value.CopyFrom(self._value.to_plan(session)) + return expr + + def __repr__(self) -> str: + return f"{self._key} => {self._value}" diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py index 51d1a5d48a1..51ad47bb1c8 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.py +++ b/python/pyspark/sql/connect/proto/expressions_pb2.py @@ -33,7 +33,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__ DESCRIPTOR 
= _descriptor_pool.Default().AddSerializedFile( - b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xd9+\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...] + b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xbf,\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...] ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -45,63 +45,65 @@ if _descriptor._USE_C_DESCRIPTORS == False: b"\n\036org.apache.spark.connect.protoP\001Z\022internal/generated" ) _EXPRESSION._serialized_start = 105 - _EXPRESSION._serialized_end = 5698 - _EXPRESSION_WINDOW._serialized_start = 1543 - _EXPRESSION_WINDOW._serialized_end = 2326 - _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1833 - _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 2326 - _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 2100 - _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 2245 - _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 2247 - _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 2326 - _EXPRESSION_SORTORDER._serialized_start = 2329 - _EXPRESSION_SORTORDER._serialized_end = 2754 - _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2559 - _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2667 - _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2669 - _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2754 - _EXPRESSION_CAST._serialized_start = 2757 - _EXPRESSION_CAST._serialized_end = 2902 - _EXPRESSION_LITERAL._serialized_start = 2905 - _EXPRESSION_LITERAL._serialized_end = 4468 - _EXPRESSION_LITERAL_DECIMAL._serialized_start = 3740 - _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3857 - _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3859 - _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3957 - _EXPRESSION_LITERAL_ARRAY._serialized_start = 3960 - _EXPRESSION_LITERAL_ARRAY._serialized_end = 4090 - _EXPRESSION_LITERAL_MAP._serialized_start = 4093 - _EXPRESSION_LITERAL_MAP._serialized_end = 4320 - _EXPRESSION_LITERAL_STRUCT._serialized_start = 4323 - _EXPRESSION_LITERAL_STRUCT._serialized_end = 4452 - _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 4470 - _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 4582 - _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 4585 - _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 4789 - _EXPRESSION_EXPRESSIONSTRING._serialized_start = 4791 - _EXPRESSION_EXPRESSIONSTRING._serialized_end = 4841 - _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 4843 - _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4925 - _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4927 - _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 5013 - _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 5016 - _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 5148 - 
_EXPRESSION_UPDATEFIELDS._serialized_start = 5151 - _EXPRESSION_UPDATEFIELDS._serialized_end = 5338 - _EXPRESSION_ALIAS._serialized_start = 5340 - _EXPRESSION_ALIAS._serialized_end = 5460 - _EXPRESSION_LAMBDAFUNCTION._serialized_start = 5463 - _EXPRESSION_LAMBDAFUNCTION._serialized_end = 5621 - _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 5623 - _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 5685 - _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 5701 - _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 6065 - _PYTHONUDF._serialized_start = 6068 - _PYTHONUDF._serialized_end = 6223 - _SCALARSCALAUDF._serialized_start = 6226 - _SCALARSCALAUDF._serialized_end = 6410 - _JAVAUDF._serialized_start = 6413 - _JAVAUDF._serialized_end = 6562 - _CALLFUNCTION._serialized_start = 6564 - _CALLFUNCTION._serialized_end = 6672 + _EXPRESSION._serialized_end = 5800 + _EXPRESSION_WINDOW._serialized_start = 1645 + _EXPRESSION_WINDOW._serialized_end = 2428 + _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1935 + _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 2428 + _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 2202 + _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 2347 + _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 2349 + _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 2428 + _EXPRESSION_SORTORDER._serialized_start = 2431 + _EXPRESSION_SORTORDER._serialized_end = 2856 + _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2661 + _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2769 + _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2771 + _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2856 + _EXPRESSION_CAST._serialized_start = 2859 + _EXPRESSION_CAST._serialized_end = 3004 + _EXPRESSION_LITERAL._serialized_start = 3007 + _EXPRESSION_LITERAL._serialized_end = 4570 + _EXPRESSION_LITERAL_DECIMAL._serialized_start = 3842 + _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3959 + _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3961 + _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 4059 + _EXPRESSION_LITERAL_ARRAY._serialized_start = 4062 + _EXPRESSION_LITERAL_ARRAY._serialized_end = 4192 + _EXPRESSION_LITERAL_MAP._serialized_start = 4195 + _EXPRESSION_LITERAL_MAP._serialized_end = 4422 + _EXPRESSION_LITERAL_STRUCT._serialized_start = 4425 + _EXPRESSION_LITERAL_STRUCT._serialized_end = 4554 + _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 4572 + _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 4684 + _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 4687 + _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 4891 + _EXPRESSION_EXPRESSIONSTRING._serialized_start = 4893 + _EXPRESSION_EXPRESSIONSTRING._serialized_end = 4943 + _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 4945 + _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 5027 + _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 5029 + _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 5115 + _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 5118 + _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 5250 + _EXPRESSION_UPDATEFIELDS._serialized_start = 5253 + _EXPRESSION_UPDATEFIELDS._serialized_end = 5440 + _EXPRESSION_ALIAS._serialized_start = 5442 + _EXPRESSION_ALIAS._serialized_end = 5562 + _EXPRESSION_LAMBDAFUNCTION._serialized_start = 5565 + _EXPRESSION_LAMBDAFUNCTION._serialized_end = 5723 + _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 5725 + _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 5787 + 
_COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 5803 + _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 6167 + _PYTHONUDF._serialized_start = 6170 + _PYTHONUDF._serialized_end = 6325 + _SCALARSCALAUDF._serialized_start = 6328 + _SCALARSCALAUDF._serialized_end = 6512 + _JAVAUDF._serialized_start = 6515 + _JAVAUDF._serialized_end = 6664 + _CALLFUNCTION._serialized_start = 6666 + _CALLFUNCTION._serialized_end = 6774 + _NAMEDARGUMENTEXPRESSION._serialized_start = 6776 + _NAMEDARGUMENTEXPRESSION._serialized_end = 6868 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi index b9b16ce35e3..2b418ef23f6 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi +++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi @@ -1102,6 +1102,7 @@ class Expression(google.protobuf.message.Message): UNRESOLVED_NAMED_LAMBDA_VARIABLE_FIELD_NUMBER: builtins.int COMMON_INLINE_USER_DEFINED_FUNCTION_FIELD_NUMBER: builtins.int CALL_FUNCTION_FIELD_NUMBER: builtins.int + NAMED_ARGUMENT_EXPRESSION_FIELD_NUMBER: builtins.int EXTENSION_FIELD_NUMBER: builtins.int @property def literal(self) -> global___Expression.Literal: ... @@ -1138,6 +1139,8 @@ class Expression(google.protobuf.message.Message): @property def call_function(self) -> global___CallFunction: ... @property + def named_argument_expression(self) -> global___NamedArgumentExpression: ... + @property def extension(self) -> google.protobuf.any_pb2.Any: """This field is used to mark extensions to the protocol. When plugins generate arbitrary relations they can add them here. During the planning the correct resolution is done. @@ -1162,6 +1165,7 @@ class Expression(google.protobuf.message.Message): | None = ..., common_inline_user_defined_function: global___CommonInlineUserDefinedFunction | None = ..., call_function: global___CallFunction | None = ..., + named_argument_expression: global___NamedArgumentExpression | None = ..., extension: google.protobuf.any_pb2.Any | None = ..., ) -> None: ... def HasField( @@ -1185,6 +1189,8 @@ class Expression(google.protobuf.message.Message): b"lambda_function", "literal", b"literal", + "named_argument_expression", + b"named_argument_expression", "sort_order", b"sort_order", "unresolved_attribute", @@ -1226,6 +1232,8 @@ class Expression(google.protobuf.message.Message): b"lambda_function", "literal", b"literal", + "named_argument_expression", + b"named_argument_expression", "sort_order", b"sort_order", "unresolved_attribute", @@ -1265,6 +1273,7 @@ class Expression(google.protobuf.message.Message): "unresolved_named_lambda_variable", "common_inline_user_defined_function", "call_function", + "named_argument_expression", "extension", ] | None: ... @@ -1505,3 +1514,28 @@ class CallFunction(google.protobuf.message.Message): ) -> None: ... global___CallFunction = CallFunction + +class NamedArgumentExpression(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: builtins.str + """(Required) The key of the named argument.""" + @property + def value(self) -> global___Expression: + """(Required) The value expression of the named argument.""" + def __init__( + self, + *, + key: builtins.str = ..., + value: global___Expression | None = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["value", b"value"] + ) -> builtins.bool: ... 
+ def ClearField( + self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"] + ) -> None: ... + +global___NamedArgumentExpression = NamedArgumentExpression diff --git a/python/pyspark/sql/connect/udtf.py b/python/pyspark/sql/connect/udtf.py index c8495626292..ce37832854c 100644 --- a/python/pyspark/sql/connect/udtf.py +++ b/python/pyspark/sql/connect/udtf.py @@ -22,11 +22,11 @@ from pyspark.sql.connect.utils import check_dependencies check_dependencies(__name__) import warnings -from typing import Type, TYPE_CHECKING, Optional, Union +from typing import List, Type, TYPE_CHECKING, Optional, Union from pyspark.rdd import PythonEvalType from pyspark.sql.connect.column import Column -from pyspark.sql.connect.expressions import ColumnReference +from pyspark.sql.connect.expressions import ColumnReference, Expression, NamedArgumentExpression from pyspark.sql.connect.plan import ( CommonInlineUserDefinedTableFunction, PythonUDTF, @@ -146,12 +146,14 @@ class UserDefinedTableFunction: self.deterministic = deterministic def _build_common_inline_user_defined_table_function( - self, *cols: "ColumnOrName" + self, *args: "ColumnOrName", **kwargs: "ColumnOrName" ) -> CommonInlineUserDefinedTableFunction: - arg_cols = [ - col if isinstance(col, Column) else Column(ColumnReference(col)) for col in cols + def to_expr(col: "ColumnOrName") -> Expression: + return col._expr if isinstance(col, Column) else ColumnReference(col) + + arg_exprs: List[Expression] = [to_expr(arg) for arg in args] + [ + NamedArgumentExpression(key, to_expr(value)) for key, value in kwargs.items() ] - arg_exprs = [col._expr for col in arg_cols] udtf = PythonUDTF( func=self.func, @@ -166,13 +168,13 @@ class UserDefinedTableFunction: arguments=arg_exprs, ) - def __call__(self, *cols: "ColumnOrName") -> "DataFrame": + def __call__(self, *args: "ColumnOrName", **kwargs: "ColumnOrName") -> "DataFrame": from pyspark.sql.connect.session import SparkSession from pyspark.sql.connect.dataframe import DataFrame session = SparkSession.active() - plan = self._build_common_inline_user_defined_table_function(*cols) + plan = self._build_common_inline_user_defined_table_function(*args, **kwargs) return DataFrame.withPlan(plan, session) def asNondeterministic(self) -> "UserDefinedTableFunction": diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index fdb4ec8111e..9cc364cc1f8 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -15547,6 +15547,12 @@ def udtf( .. versionadded:: 3.5.0 + .. versionchanged:: 4.0.0 + Supports Python side analysis. + + .. versionchanged:: 4.0.0 + Supports keyword-arguments. + Parameters ---------- cls : class @@ -15623,6 +15629,38 @@ def udtf( | 1| x| +---+---+ + UDTF can use keyword arguments: + + >>> @udtf + ... class TestUDTFWithKwargs: + ... @staticmethod + ... def analyze( + ... a: AnalyzeArgument, b: AnalyzeArgument, **kwargs: AnalyzeArgument + ... ) -> AnalyzeResult: + ... return AnalyzeResult( + ... StructType().add("a", a.data_type) + ... .add("b", b.data_type) + ... .add("x", kwargs["x"].data_type) + ... ) + ... + ... def eval(self, a, b, **kwargs): + ... yield a, b, kwargs["x"] + ... 
+ >>> TestUDTFWithKwargs(lit(1), x=lit("x"), b=lit("b")).show() + +---+---+---+ + | a| b| x| + +---+---+---+ + | 1| b| x| + +---+---+---+ + + >>> _ = spark.udtf.register("test_udtf", TestUDTFWithKwargs) + >>> spark.sql("SELECT * FROM test_udtf(1, x=>'x', b=>'b')").show() + +---+---+---+ + | a| b| x| + +---+---+---+ + | 1| b| x| + +---+---+---+ + Arrow optimization can be explicitly enabled when creating UDTFs: >>> @udtf(returnType="c1: int, c2: int", useArrow=True) diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index 300067716e9..cd0604ccace 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -1565,6 +1565,7 @@ class BaseUDTFTestsMixin: expected = [Row(c1="hello", c2="world")] assertDataFrameEqual(TestUDTF(), expected) assertDataFrameEqual(self.spark.sql("SELECT * FROM test_udtf()"), expected) + assertDataFrameEqual(self.spark.sql("SELECT * FROM test_udtf(a=>1)"), expected) with self.assertRaisesRegex( AnalysisException, r"analyze\(\) takes 0 positional arguments but 1 was given" @@ -1795,6 +1796,93 @@ class BaseUDTFTestsMixin: assertSchemaEqual(df.schema, StructType().add("col1", IntegerType())) assertDataFrameEqual(df, [Row(col1=10), Row(col1=100)]) + def test_udtf_with_named_arguments(self): + @udtf(returnType="a: int") + class TestUDTF: + def eval(self, a, b): + yield a, + + self.spark.udtf.register("test_udtf", TestUDTF) + + for i, df in enumerate( + [ + self.spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')"), + self.spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)"), + TestUDTF(a=lit(10), b=lit("x")), + TestUDTF(b=lit("x"), a=lit(10)), + ] + ): + with self.subTest(query_no=i): + assertDataFrameEqual(df, [Row(a=10)]) + + def test_udtf_with_named_arguments_negative(self): + @udtf(returnType="a: int") + class TestUDTF: + def eval(self, a, b): + yield a, + + self.spark.udtf.register("test_udtf", TestUDTF) + + with self.assertRaisesRegex( + AnalysisException, + "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE", + ): + self.spark.sql("SELECT * FROM test_udtf(a=>10, a=>100)").show() + + with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"): + self.spark.sql("SELECT * FROM test_udtf(a=>10, 'x')").show() + + with self.assertRaisesRegex( + PythonException, r"eval\(\) got an unexpected keyword argument 'c'" + ): + self.spark.sql("SELECT * FROM test_udtf(c=>'x')").show() + + def test_udtf_with_kwargs(self): + @udtf(returnType="a: int, b: string") + class TestUDTF: + def eval(self, **kwargs): + yield kwargs["a"], kwargs["b"] + + self.spark.udtf.register("test_udtf", TestUDTF) + + for i, df in enumerate( + [ + self.spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')"), + self.spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)"), + TestUDTF(a=lit(10), b=lit("x")), + TestUDTF(b=lit("x"), a=lit(10)), + ] + ): + with self.subTest(query_no=i): + assertDataFrameEqual(df, [Row(a=10, b="x")]) + + def test_udtf_with_analyze_kwargs(self): + @udtf + class TestUDTF: + @staticmethod + def analyze(**kwargs: AnalyzeArgument) -> AnalyzeResult: + return AnalyzeResult( + StructType( + [StructField(key, arg.data_type) for key, arg in sorted(kwargs.items())] + ) + ) + + def eval(self, **kwargs): + yield tuple(value for _, value in sorted(kwargs.items())) + + self.spark.udtf.register("test_udtf", TestUDTF) + + for i, df in enumerate( + [ + self.spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')"), + self.spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)"), + TestUDTF(a=lit(10), 
b=lit("x")), + TestUDTF(b=lit("x"), a=lit(10)), + ] + ): + with self.subTest(query_no=i): + assertDataFrameEqual(df, [Row(a=10, b="x")]) + class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase): @classmethod diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py index 027a2646a46..1ca87aae758 100644 --- a/python/pyspark/sql/udtf.py +++ b/python/pyspark/sql/udtf.py @@ -29,7 +29,7 @@ from py4j.java_gateway import JavaObject from pyspark.errors import PySparkAttributeError, PySparkRuntimeError, PySparkTypeError from pyspark.rdd import PythonEvalType -from pyspark.sql.column import _to_java_column, _to_seq +from pyspark.sql.column import _to_java_column, _to_java_expr, _to_seq from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version from pyspark.sql.types import DataType, StructType, _parse_datatype_string from pyspark.sql.udf import _wrap_function @@ -148,9 +148,9 @@ def _vectorize_udtf(cls: Type) -> Type: # Wrap the exception thrown from the UDTF in a PySparkRuntimeError. def wrap_func(f: Callable[..., Any]) -> Callable[..., Any]: @wraps(f) - def evaluate(*a: Any) -> Any: + def evaluate(*a: Any, **kw: Any) -> Any: try: - return f(*a) + return f(*a, **kw) except Exception as e: raise PySparkRuntimeError( error_class="UDTF_EXEC_ERROR", @@ -168,18 +168,22 @@ def _vectorize_udtf(cls: Type) -> Type: ): @staticmethod - def analyze(*args: AnalyzeArgument) -> AnalyzeResult: - return cls.analyze(*args) + def analyze(*args: AnalyzeArgument, **kwargs: AnalyzeArgument) -> AnalyzeResult: + return cls.analyze(*args, **kwargs) - def eval(self, *args: pd.Series) -> Iterator[pd.DataFrame]: - if len(args) == 0: + def eval(self, *args: pd.Series, **kwargs: pd.Series) -> Iterator[pd.DataFrame]: + if len(args) == 0 and len(kwargs) == 0: yield pd.DataFrame(wrap_func(self.func.eval)()) else: # Create tuples from the input pandas Series, each tuple # represents a row across all Series. 
- row_tuples = zip(*args) + keys = list(kwargs.keys()) + len_args = len(args) + row_tuples = zip(*args, *[kwargs[key] for key in keys]) for row in row_tuples: - res = wrap_func(self.func.eval)(*row) + res = wrap_func(self.func.eval)( + *row[:len_args], **{key: row[len_args + i] for i, key in enumerate(keys)} + ) if res is not None and not isinstance(res, Iterable): raise PySparkRuntimeError( error_class="UDTF_RETURN_NOT_ITERABLE", @@ -339,14 +343,24 @@ class UserDefinedTableFunction: ) return judtf - def __call__(self, *cols: "ColumnOrName") -> "DataFrame": + def __call__(self, *args: "ColumnOrName", **kwargs: "ColumnOrName") -> "DataFrame": from pyspark.sql import DataFrame, SparkSession spark = SparkSession._getActiveSessionOrCreate() sc = spark.sparkContext + assert sc._jvm is not None + jcols = [_to_java_column(arg) for arg in args] + [ + sc._jvm.Column( + sc._jvm.org.apache.spark.sql.catalyst.expressions.NamedArgumentExpression( + key, _to_java_expr(value) + ) + ) + for key, value in kwargs.items() + ] + judtf = self._judtf - jPythonUDTF = judtf.apply(spark._jsparkSession, _to_seq(sc, cols, _to_java_column)) + jPythonUDTF = judtf.apply(spark._jsparkSession, _to_seq(sc, jcols)) return DataFrame(jPythonUDTF, spark) def asNondeterministic(self) -> "UserDefinedTableFunction": diff --git a/python/pyspark/sql/worker/analyze_udtf.py b/python/pyspark/sql/worker/analyze_udtf.py index 9ffa03541e6..7ba0789fa7b 100644 --- a/python/pyspark/sql/worker/analyze_udtf.py +++ b/python/pyspark/sql/worker/analyze_udtf.py @@ -19,7 +19,7 @@ import inspect import os import sys import traceback -from typing import List, IO +from typing import Dict, List, IO, Tuple from pyspark.accumulators import _accumulatorRegistry from pyspark.errors import PySparkRuntimeError, PySparkValueError @@ -69,11 +69,12 @@ def read_udtf(infile: IO) -> type: return handler -def read_arguments(infile: IO) -> List[AnalyzeArgument]: +def read_arguments(infile: IO) -> Tuple[List[AnalyzeArgument], Dict[str, AnalyzeArgument]]: """Reads the arguments for `analyze` static method.""" # Receive arguments num_args = read_int(infile) args: List[AnalyzeArgument] = [] + kwargs: Dict[str, AnalyzeArgument] = {} for _ in range(num_args): dt = _parse_datatype_json_string(utf8_deserializer.loads(infile)) if read_bool(infile): # is foldable @@ -83,8 +84,15 @@ def read_arguments(infile: IO) -> List[AnalyzeArgument]: else: value = None is_table = read_bool(infile) # is table argument - args.append(AnalyzeArgument(data_type=dt, value=value, is_table=is_table)) - return args + argument = AnalyzeArgument(data_type=dt, value=value, is_table=is_table) + + is_named_arg = read_bool(infile) + if is_named_arg: + name = utf8_deserializer.loads(infile) + kwargs[name] = argument + else: + args.append(argument) + return args, kwargs def main(infile: IO, outfile: IO) -> None: @@ -107,9 +115,9 @@ def main(infile: IO, outfile: IO) -> None: _accumulatorRegistry.clear() handler = read_udtf(infile) - args = read_arguments(infile) + args, kwargs = read_arguments(infile) - result = handler.analyze(*args) # type: ignore[attr-defined] + result = handler.analyze(*args, **kwargs) # type: ignore[attr-defined] if not isinstance(result, AnalyzeResult): raise PySparkValueError( diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 6f27400387e..8916a794001 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -550,7 +550,16 @@ def read_udtf(pickleSer, infile, eval_type): # See `PythonUDTFRunner.PythonUDFWriterThread.writeCommand' num_arg = 
read_int(infile) - arg_offsets = [read_int(infile) for _ in range(num_arg)] + args_offsets = [] + kwargs_offsets = {} + for _ in range(num_arg): + offset = read_int(infile) + if read_bool(infile): + name = utf8_deserializer.loads(infile) + kwargs_offsets[name] = offset + else: + args_offsets.append(offset) + handler = read_command(pickleSer, infile) if not isinstance(handler, type): raise PySparkRuntimeError( @@ -619,7 +628,9 @@ def read_udtf(pickleSer, infile, eval_type): ) return result - return lambda *a: map(lambda res: (res, arrow_return_type), map(verify_result, f(*a))) + return lambda *a, **kw: map( + lambda res: (res, arrow_return_type), map(verify_result, f(*a, **kw)) + ) eval = wrap_arrow_udtf(getattr(udtf, "eval"), return_type) @@ -633,7 +644,10 @@ def read_udtf(pickleSer, infile, eval_type): for a in it: # The eval function yields an iterator. Each element produced by this # iterator is a tuple in the form of (pandas.DataFrame, arrow_return_type). - yield from eval(*[a[o] for o in arg_offsets]) + yield from eval( + *[a[o] for o in args_offsets], + **{k: a[o] for k, o in kwargs_offsets.items()}, + ) finally: if terminate is not None: yield from terminate() @@ -667,9 +681,9 @@ def read_udtf(pickleSer, infile, eval_type): return toInternal(result) # Evaluate the function and return a tuple back to the executor. - def evaluate(*a) -> tuple: + def evaluate(*a, **kw) -> tuple: try: - res = f(*a) + res = f(*a, **kw) except Exception as e: raise PySparkRuntimeError( error_class="UDTF_EXEC_ERROR", @@ -705,7 +719,10 @@ def read_udtf(pickleSer, infile, eval_type): def mapper(_, it): try: for a in it: - yield eval(*[a[o] for o in arg_offsets]) + yield eval( + *[a[o] for o in args_offsets], + **{k: a[o] for k, o in kwargs_offsets.items()}, + ) finally: if terminate is not None: yield terminate() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala index 1088655f60c..d13bfab6d70 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala @@ -72,6 +72,38 @@ trait FunctionBuilderBase[T] { } object NamedParametersSupport { + /** + * This method splits named arguments from the argument list. 
+ * Also checks if: + * - the named arguments don't contains positional arguments once keyword arguments start + * - the named arguments don't use the duplicated names + * + * @param functionSignature The function signature that defines the positional ordering + * @param args The argument list provided in function invocation + * @return A tuple of a list of positional arguments and a list of keyword arguments + */ + def splitAndCheckNamedArguments( + args: Seq[Expression], + functionName: String): (Seq[Expression], Seq[NamedArgumentExpression]) = { + val (positionalArgs, namedArgs) = args.span(!_.isInstanceOf[NamedArgumentExpression]) + + val namedParametersSet = collection.mutable.Set[String]() + + (positionalArgs, + namedArgs.zipWithIndex.map { + case (namedArg @ NamedArgumentExpression(parameterName, _), _) => + if (namedParametersSet.contains(parameterName)) { + throw QueryCompilationErrors.doubleNamedArgumentReference( + functionName, parameterName) + } + namedParametersSet.add(parameterName) + namedArg + case (_, index) => + throw QueryCompilationErrors.unexpectedPositionalArgument( + functionName, namedArgs(index - 1).asInstanceOf[NamedArgumentExpression].key) + }) + } + /** * This method is the default routine which rearranges the arguments in positional order according * to the function signature provided. This will also fill in any default values that exists for @@ -93,7 +125,7 @@ object NamedParametersSupport { functionName, functionSignature) } - val (positionalArgs, namedArgs) = args.span(!_.isInstanceOf[NamedArgumentExpression]) + val (positionalArgs, namedArgs) = splitAndCheckNamedArguments(args, functionName) val namedParameters: Seq[InputParameter] = parameters.drop(positionalArgs.size) // The following loop checks for the following: @@ -102,28 +134,16 @@ object NamedParametersSupport { val allParameterNames: Seq[String] = parameters.map(_.name) val parameterNamesSet: Set[String] = allParameterNames.toSet val positionalParametersSet = allParameterNames.take(positionalArgs.size).toSet - val namedParametersSet = collection.mutable.Set[String]() - namedArgs.zipWithIndex.foreach { case (arg, index) => - arg match { - case namedArg: NamedArgumentExpression => - val parameterName = namedArg.key - if (!parameterNamesSet.contains(parameterName)) { - throw QueryCompilationErrors.unrecognizedParameterName(functionName, namedArg.key, - parameterNamesSet.toSeq) - } - if (positionalParametersSet.contains(parameterName)) { - throw QueryCompilationErrors.positionalAndNamedArgumentDoubleReference( - functionName, namedArg.key) - } - if (namedParametersSet.contains(parameterName)) { - throw QueryCompilationErrors.doubleNamedArgumentReference( - functionName, namedArg.key) - } - namedParametersSet.add(namedArg.key) - case _ => - throw QueryCompilationErrors.unexpectedPositionalArgument( - functionName, namedArgs(index - 1).asInstanceOf[NamedArgumentExpression].key) + namedArgs.foreach { namedArg => + val parameterName = namedArg.key + if (!parameterNamesSet.contains(parameterName)) { + throw QueryCompilationErrors.unrecognizedParameterName(functionName, namedArg.key, + parameterNamesSet.toSeq) + } + if (positionalParametersSet.contains(parameterName)) { + throw QueryCompilationErrors.positionalAndNamedArgumentDoubleReference( + functionName, namedArg.key) } } @@ -136,8 +156,7 @@ object NamedParametersSupport { } // This constructs a map from argument name to value for argument rearrangement. 
- val namedArgMap = namedArgs.map { arg => - val namedArg = arg.asInstanceOf[NamedArgumentExpression] + val namedArgMap = namedArgs.map { namedArg => namedArg.key -> namedArg.value }.toMap diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala index 9c0addfd2ae..8ebd8a3a106 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala @@ -23,6 +23,7 @@ import org.apache.spark.{JobArtifactSet, TaskContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.python.EvalPythonUDTFExec.ArgumentMetadata import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} @@ -52,7 +53,7 @@ case class ArrowEvalPythonUDTFExec( private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) override protected def evaluate( - argOffsets: Array[Int], + argMetas: Array[ArgumentMetadata], iter: Iterator[InternalRow], schema: StructType, context: TaskContext): Iterator[Iterator[InternalRow]] = { @@ -64,7 +65,7 @@ case class ArrowEvalPythonUDTFExec( val columnarBatchIter = new ArrowPythonUDTFRunner( udtf, evalType, - argOffsets, + argMetas, schema, sessionLocalTimeZone, largeVarTypes, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala index 1dd06c2dc73..c0fa8b58bee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala @@ -23,6 +23,7 @@ import org.apache.spark.api.python._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.PythonUDTF import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.python.EvalPythonUDTFExec.ArgumentMetadata import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch @@ -33,7 +34,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch class ArrowPythonUDTFRunner( udtf: PythonUDTF, evalType: Int, - offsets: Array[Int], + argMetas: Array[ArgumentMetadata], protected override val schema: StructType, protected override val timeZoneId: String, protected override val largeVarTypes: Boolean, @@ -41,7 +42,8 @@ class ArrowPythonUDTFRunner( val pythonMetrics: Map[String, SQLMetric], jobArtifactUUID: Option[String]) extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch]( - Seq(ChainedPythonFunctions(Seq(udtf.func))), evalType, Array(offsets), jobArtifactUUID) + Seq(ChainedPythonFunctions(Seq(udtf.func))), + evalType, Array(argMetas.map(_.offset)), jobArtifactUUID) with BasicPythonArrowInput with BasicPythonArrowOutput { @@ -49,7 +51,7 @@ class ArrowPythonUDTFRunner( dataOut: DataOutputStream, funcs: Seq[ChainedPythonFunctions], argOffsets: Array[Array[Int]]): Unit = { - PythonUDTFRunner.writeUDTF(dataOut, udtf, offsets) + PythonUDTFRunner.writeUDTF(dataOut, udtf, argMetas) } override val pythonExec: String = diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala index 6c8412f8b37..cbc90f34a37 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.python.EvalPythonUDTFExec.ArgumentMetadata import org.apache.spark.sql.types.StructType /** @@ -55,7 +56,7 @@ case class BatchEvalPythonUDTFExec( * an iterator of internal rows for every input row. */ override protected def evaluate( - argOffsets: Array[Int], + argMetas: Array[ArgumentMetadata], iter: Iterator[InternalRow], schema: StructType, context: TaskContext): Iterator[Iterator[InternalRow]] = { @@ -66,7 +67,7 @@ case class BatchEvalPythonUDTFExec( // Output iterator for results from Python. val outputIterator = - new PythonUDTFRunner(udtf, argOffsets, pythonMetrics, jobArtifactUUID) + new PythonUDTFRunner(udtf, argMetas, pythonMetrics, jobArtifactUUID) .compute(inputIterator, context.partitionId(), context) val unpickle = new Unpickler @@ -93,12 +94,12 @@ case class BatchEvalPythonUDTFExec( class PythonUDTFRunner( udtf: PythonUDTF, - argOffsets: Array[Int], + argMetas: Array[ArgumentMetadata], pythonMetrics: Map[String, SQLMetric], jobArtifactUUID: Option[String]) extends BasePythonUDFRunner( Seq(ChainedPythonFunctions(Seq(udtf.func))), - PythonEvalType.SQL_TABLE_UDF, Array(argOffsets), pythonMetrics, jobArtifactUUID) { + PythonEvalType.SQL_TABLE_UDF, Array(argMetas.map(_.offset)), pythonMetrics, jobArtifactUUID) { protected override def newWriter( env: SparkEnv, @@ -109,7 +110,7 @@ class PythonUDTFRunner( new PythonUDFWriter(env, worker, inputIterator, partitionIndex, context) { protected override def writeCommand(dataOut: DataOutputStream): Unit = { - PythonUDTFRunner.writeUDTF(dataOut, udtf, argOffsets) + PythonUDTFRunner.writeUDTF(dataOut, udtf, argMetas) } } } @@ -117,10 +118,21 @@ class PythonUDTFRunner( object PythonUDTFRunner { - def writeUDTF(dataOut: DataOutputStream, udtf: PythonUDTF, argOffsets: Array[Int]): Unit = { - dataOut.writeInt(argOffsets.length) - argOffsets.foreach { offset => - dataOut.writeInt(offset) + def writeUDTF( + dataOut: DataOutputStream, + udtf: PythonUDTF, + argMetas: Array[ArgumentMetadata]): Unit = { + dataOut.writeInt(argMetas.length) + argMetas.foreach { + case ArgumentMetadata(offset, name) => + dataOut.writeInt(offset) + name match { + case Some(name) => + dataOut.writeBoolean(true) + PythonWorkerUtils.writeUTF(name, dataOut) + case _ => + dataOut.writeBoolean(false) + } } dataOut.writeInt(udtf.func.command.length) dataOut.write(udtf.func.command.toArray) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala index fab417a0f86..410209e0ada 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala @@ -26,9 +26,20 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import 
org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.UnaryExecNode +import org.apache.spark.sql.execution.python.EvalPythonUDTFExec.ArgumentMetadata import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.util.Utils +object EvalPythonUDTFExec { + /** + * Metadata for arguments of Python UDTF. + * + * @param offset the offset of the argument + * @param name the name of the argument if it's a `NamedArgumentExpression` + */ + case class ArgumentMetadata(offset: Int, name: Option[String]) +} + /** * A physical plan that evaluates a [[PythonUDTF]], one partition of tuples at a time. * This is similar to [[EvalPythonExec]]. @@ -45,7 +56,7 @@ trait EvalPythonUDTFExec extends UnaryExecNode { override def producedAttributes: AttributeSet = AttributeSet(resultAttrs) protected def evaluate( - argOffsets: Array[Int], + argMetas: Array[ArgumentMetadata], iter: Iterator[InternalRow], schema: StructType, context: TaskContext): Iterator[Iterator[InternalRow]] @@ -68,13 +79,19 @@ trait EvalPythonUDTFExec extends UnaryExecNode { // flatten all the arguments val allInputs = new ArrayBuffer[Expression] val dataTypes = new ArrayBuffer[DataType] - val argOffsets = udtf.children.map { e => - if (allInputs.exists(_.semanticEquals(e))) { - allInputs.indexWhere(_.semanticEquals(e)) + val argMetas = udtf.children.map { e => + val (key, value) = e match { + case NamedArgumentExpression(key, value) => + (Some(key), value) + case _ => + (None, e) + } + if (allInputs.exists(_.semanticEquals(value))) { + ArgumentMetadata(allInputs.indexWhere(_.semanticEquals(value)), key) } else { - allInputs += e - dataTypes += e.dataType - allInputs.length - 1 + allInputs += value + dataTypes += value.dataType + ArgumentMetadata(allInputs.length - 1, key) } }.toArray val projection = MutableProjection.create(allInputs.toSeq, child.output) @@ -93,7 +110,7 @@ trait EvalPythonUDTFExec extends UnaryExecNode { projection(inputRow) } - val outputRowIterator = evaluate(argOffsets, projectedRowIter, schema, context) + val outputRowIterator = evaluate(argMetas, projectedRowIter, schema, context) val pruneChildForResult: InternalRow => InternalRow = if (child.outputSet == AttributeSet(requiredChildOutput)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index 5fa9c89b3d1..38d521c16d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -33,7 +33,7 @@ import org.apache.spark.internal.config.BUFFER_SIZE import org.apache.spark.internal.config.Python._ import org.apache.spark.sql.{Column, DataFrame, Dataset, SparkSession} import org.apache.spark.sql.catalyst.expressions.{Expression, FunctionTableSubqueryArgumentExpression, NamedArgumentExpression, PythonUDAF, PythonUDF, PythonUDTF, UnresolvedPolymorphicPythonUDTF} -import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan, OneRowRelation} +import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan, NamedParametersSupport, OneRowRelation} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} @@ -101,6 +101,13 @@ case class UserDefinedPythonTableFunction( } def builder(exprs: 
Seq[Expression]): LogicalPlan = { + /* + * Check if the named arguments: + * - don't have duplicated names + * - don't contain positional arguments + */ + NamedParametersSupport.splitAndCheckNamedArguments(exprs, name) + val udtf = returnType match { case Some(rt) => PythonUDTF( @@ -213,8 +220,6 @@ object UserDefinedPythonTableFunction { val bufferStream = new DirectByteBufferOutputStream() try { val dataOut = new DataOutputStream(new BufferedOutputStream(bufferStream, bufferSize)) - val dataIn = new DataInputStream(new BufferedInputStream( - new WorkerInputStream(worker, bufferStream), bufferSize)) PythonWorkerUtils.writePythonVersion(pythonVer, dataOut) PythonWorkerUtils.writeSparkFiles(jobArtifactUUID, pythonIncludes, dataOut) @@ -237,11 +242,22 @@ object UserDefinedPythonTableFunction { dataOut.writeBoolean(false) } dataOut.writeBoolean(is_table) + // If the expr is NamedArgumentExpression, send its name. + expr match { + case NamedArgumentExpression(key, _) => + dataOut.writeBoolean(true) + PythonWorkerUtils.writeUTF(key, dataOut) + case _ => + dataOut.writeBoolean(false) + } } dataOut.writeInt(SpecialLengths.END_OF_STREAM) dataOut.flush() + val dataIn = new DataInputStream(new BufferedInputStream( + new WorkerInputStream(worker, bufferStream.toByteBuffer), bufferSize)) + // Receive the schema val schema = dataIn.readInt() match { case length if length >= 0 => @@ -273,9 +289,13 @@ object UserDefinedPythonTableFunction { case eof: EOFException => throw new SparkException("Python worker exited unexpectedly (crashed)", eof) } finally { - if (!releasedOrClosed) { - // An error happened. Force to close the worker. - env.destroyPythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker) + try { + bufferStream.close() + } finally { + if (!releasedOrClosed) { + // An error happened. Force to close the worker. + env.destroyPythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker) + } } } } @@ -288,8 +308,7 @@ object UserDefinedPythonTableFunction { * This is a port and simplified version of `PythonRunner.ReaderInputStream`, * and only supports to write all at once and then read all. */ - private class WorkerInputStream( - worker: PythonWorker, bufferStream: DirectByteBufferOutputStream) extends InputStream { + private class WorkerInputStream(worker: PythonWorker, buffer: ByteBuffer) extends InputStream { private[this] val temp = new Array[Byte](1) @@ -312,14 +331,15 @@ object UserDefinedPythonTableFunction { n = worker.channel.read(buf) } if (worker.selectionKey.isWritable) { - val buffer = bufferStream.toByteBuffer var acceptsInput = true while (acceptsInput && buffer.hasRemaining) { val n = worker.channel.write(buffer) acceptsInput = n > 0 } - // We no longer have any data to write to the socket. - worker.selectionKey.interestOps(SelectionKey.OP_READ) + if (!buffer.hasRemaining) { + // We no longer have any data to write to the socket. + worker.selectionKey.interestOps(SelectionKey.OP_READ) + } } } n --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org