This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d4629563492 [SPARK-44749][SQL][PYTHON] Support named arguments in Python UDTF
d4629563492 is described below

commit d4629563492ec3090b5bd5b924790507c42f4e86
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Mon Aug 14 08:57:55 2023 -0700

    [SPARK-44749][SQL][PYTHON] Support named arguments in Python UDTF
    
    ### What changes were proposed in this pull request?
    
    Supports named arguments in Python UDTF.
    
    For example:
    
    ```py
    >>> @udtf(returnType="a: int")
    ... class TestUDTF:
    ...     def eval(self, a, b):
    ...         yield a,
    ...
    >>> spark.udtf.register("test_udtf", TestUDTF)
    
    >>> TestUDTF(a=lit(10), b=lit("x")).show()
    +---+
    |  a|
    +---+
    | 10|
    +---+
    
    >>> TestUDTF(b=lit("x"), a=lit(10)).show()
    +---+
    |  a|
    +---+
    | 10|
    +---+
    
    >>> spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')").show()
    +---+
    |  a|
    +---+
    | 10|
    +---+
    
    >>> spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)").show()
    +---+
    |  a|
    +---+
    | 10|
    +---+
    ```
    
    or:
    
    ```py
    >>> @udtf
    ... class TestUDTF:
    ...     @staticmethod
    ...     def analyze(**kwargs: AnalyzeArgument) -> AnalyzeResult:
    ...         return AnalyzeResult(
    ...             StructType(
    ...                 [StructField(key, arg.data_type) for key, arg in sorted(kwargs.items())]
    ...             )
    ...         )
    ...     def eval(self, **kwargs):
    ...         yield tuple(value for _, value in sorted(kwargs.items()))
    ...
    >>> spark.udtf.register("test_udtf", TestUDTF)
    
    >>> spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x', x=>100.0)").show()
    +---+---+-----+
    |  a|  b|    x|
    +---+---+-----+
    | 10|  x|100.0|
    +---+---+-----+
    
    >>> spark.sql("SELECT * FROM test_udtf(x=>10, a=>'x', z=>100.0)").show()
    +---+---+-----+
    |  a|  x|    z|
    +---+---+-----+
    |  x| 10|100.0|
    +---+---+-----+
    ```
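
    Positional and named arguments can also be mixed. The following is a minimal sketch adapted from the doctest added to `pyspark.sql.functions.udtf` in this change (the class name `TestUDTFWithKwargs` comes from that docstring):

    ```py
    >>> @udtf
    ... class TestUDTFWithKwargs:
    ...     @staticmethod
    ...     def analyze(
    ...         a: AnalyzeArgument, b: AnalyzeArgument, **kwargs: AnalyzeArgument
    ...     ) -> AnalyzeResult:
    ...         return AnalyzeResult(
    ...             StructType().add("a", a.data_type)
    ...                 .add("b", b.data_type)
    ...                 .add("x", kwargs["x"].data_type)
    ...         )
    ...
    ...     def eval(self, a, b, **kwargs):
    ...         yield a, b, kwargs["x"]
    ...
    >>> TestUDTFWithKwargs(lit(1), x=lit("x"), b=lit("b")).show()
    +---+---+---+
    |  a|  b|  x|
    +---+---+---+
    |  1|  b|  x|
    +---+---+---+

    >>> _ = spark.udtf.register("test_udtf", TestUDTFWithKwargs)
    >>> spark.sql("SELECT * FROM test_udtf(1, x=>'x', b=>'b')").show()
    +---+---+---+
    |  a|  b|  x|
    +---+---+---+
    |  1|  b|  x|
    +---+---+---+
    ```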
    
    ### Why are the changes needed?
    
    Now that named arguments are supported
    (https://github.com/apache/spark/pull/41796, https://github.com/apache/spark/pull/42020),
    Python UDTFs should support them as well.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, named arguments will be available for Python UDTF.
    
    ### How was this patch tested?
    
    Added related tests.
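
    For example, one of the added tests (excerpted from `python/pyspark/sql/tests/test_udtf.py` in this diff) verifies that the same result comes back whether the named arguments are passed via SQL or via the DataFrame API, in either order:

    ```py
    def test_udtf_with_named_arguments(self):
        @udtf(returnType="a: int")
        class TestUDTF:
            def eval(self, a, b):
                yield a,

        self.spark.udtf.register("test_udtf", TestUDTF)

        for i, df in enumerate(
            [
                self.spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')"),
                self.spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)"),
                TestUDTF(a=lit(10), b=lit("x")),
                TestUDTF(b=lit("x"), a=lit(10)),
            ]
        ):
            with self.subTest(query_no=i):
                assertDataFrameEqual(df, [Row(a=10)])
    ```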
    
    Closes #42422 from ueshin/issues/SPARK-44749/kwargs.
    
    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Takuya UESHIN <ues...@databricks.com>
---
 .../main/protobuf/spark/connect/expressions.proto  |   9 ++
 .../sql/connect/planner/SparkConnectPlanner.scala  |   7 ++
 python/pyspark/sql/column.py                       |  14 +++
 python/pyspark/sql/connect/expressions.py          |  20 ++++
 .../pyspark/sql/connect/proto/expressions_pb2.py   | 122 +++++++++++----------
 .../pyspark/sql/connect/proto/expressions_pb2.pyi  |  34 ++++++
 python/pyspark/sql/connect/udtf.py                 |  18 +--
 python/pyspark/sql/functions.py                    |  38 +++++++
 python/pyspark/sql/tests/test_udtf.py              |  88 +++++++++++++++
 python/pyspark/sql/udtf.py                         |  36 ++++--
 python/pyspark/sql/worker/analyze_udtf.py          |  20 +++-
 python/pyspark/worker.py                           |  29 ++++-
 .../plans/logical/FunctionBuilderBase.scala        |  67 +++++++----
 .../execution/python/ArrowEvalPythonUDTFExec.scala |   5 +-
 .../execution/python/ArrowPythonUDTFRunner.scala   |   8 +-
 .../execution/python/BatchEvalPythonUDTFExec.scala |  30 +++--
 .../sql/execution/python/EvalPythonUDTFExec.scala  |  33 ++++--
 .../python/UserDefinedPythonFunction.scala         |  42 +++++--
 18 files changed, 472 insertions(+), 148 deletions(-)

diff --git 
a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto 
b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 557b9db9123..b222f663cd0 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -47,6 +47,7 @@ message Expression {
     UnresolvedNamedLambdaVariable unresolved_named_lambda_variable = 14;
     CommonInlineUserDefinedFunction common_inline_user_defined_function = 15;
     CallFunction call_function = 16;
+    NamedArgumentExpression named_argument_expression = 17;
 
     // This field is used to mark extensions to the protocol. When plugins 
generate arbitrary
     // relations they can add them here. During the planning the correct 
resolution is done.
@@ -380,3 +381,11 @@ message CallFunction {
   // (Optional) Function arguments. Empty arguments are allowed.
   repeated Expression arguments = 2;
 }
+
+message NamedArgumentExpression {
+  // (Required) The key of the named argument.
+  string key = 1;
+
+  // (Required) The value expression of the named argument.
+  Expression value = 2;
+}
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index e6305cd9d1a..4ab24cb058b 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1384,6 +1384,8 @@ class SparkConnectPlanner(val sessionHolder: 
SessionHolder) extends Logging {
         
transformCommonInlineUserDefinedFunction(exp.getCommonInlineUserDefinedFunction)
       case proto.Expression.ExprTypeCase.CALL_FUNCTION =>
         transformCallFunction(exp.getCallFunction)
+      case proto.Expression.ExprTypeCase.NAMED_ARGUMENT_EXPRESSION =>
+        transformNamedArgumentExpression(exp.getNamedArgumentExpression)
       case _ =>
         throw InvalidPlanInput(
           s"Expression with ID: ${exp.getExprTypeCase.getNumber} is not 
supported")
@@ -1505,6 +1507,11 @@ class SparkConnectPlanner(val sessionHolder: 
SessionHolder) extends Logging {
       false)
   }
 
+  private def transformNamedArgumentExpression(
+      namedArg: proto.NamedArgumentExpression): Expression = {
+    NamedArgumentExpression(namedArg.getKey, 
transformExpression(namedArg.getValue))
+  }
+
   private def unpackUdf(fun: proto.CommonInlineUserDefinedFunction): UdfPacket 
= {
     unpackScalarScalaUDF[UdfPacket](fun.getScalarScalaUdf)
   }
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index 087cfaaa20b..3a6d6e1cea7 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -73,9 +73,23 @@ def _to_java_expr(col: "ColumnOrName") -> JavaObject:
     return _to_java_column(col).expr()
 
 
+@overload
+def _to_seq(sc: SparkContext, cols: Iterable[JavaObject]) -> JavaObject:
+    pass
+
+
+@overload
 def _to_seq(
     sc: SparkContext,
     cols: Iterable["ColumnOrName"],
+    converter: Optional[Callable[["ColumnOrName"], JavaObject]],
+) -> JavaObject:
+    pass
+
+
+def _to_seq(
+    sc: SparkContext,
+    cols: Union[Iterable["ColumnOrName"], Iterable[JavaObject]],
     converter: Optional[Callable[["ColumnOrName"], JavaObject]] = None,
 ) -> JavaObject:
     """
diff --git a/python/pyspark/sql/connect/expressions.py 
b/python/pyspark/sql/connect/expressions.py
index d0a9b1d69ae..34aa4da1117 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -1054,3 +1054,23 @@ class CallFunction(Expression):
             return f"CallFunction('{self._name}', {', '.join([str(arg) for arg 
in self._args])})"
         else:
             return f"CallFunction('{self._name}')"
+
+
+class NamedArgumentExpression(Expression):
+    def __init__(self, key: str, value: Expression):
+        super().__init__()
+
+        assert isinstance(key, str)
+        self._key = key
+
+        assert isinstance(value, Expression)
+        self._value = value
+
+    def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
+        expr = proto.Expression()
+        expr.named_argument_expression.key = self._key
+        
expr.named_argument_expression.value.CopyFrom(self._value.to_plan(session))
+        return expr
+
+    def __repr__(self) -> str:
+        return f"{self._key} => {self._value}"
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py 
b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 51d1a5d48a1..51ad47bb1c8 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -33,7 +33,7 @@ from pyspark.sql.connect.proto import types_pb2 as 
spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    
b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xd9+\n\nExpression\x12=\n\x07literal\x18\x01
 
\x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02
 
\x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03
 
\x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct
 [...]
+    
b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xbf,\n\nExpression\x12=\n\x07literal\x18\x01
 
\x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02
 
\x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03
 
\x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct
 [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -45,63 +45,65 @@ if _descriptor._USE_C_DESCRIPTORS == False:
         b"\n\036org.apache.spark.connect.protoP\001Z\022internal/generated"
     )
     _EXPRESSION._serialized_start = 105
-    _EXPRESSION._serialized_end = 5698
-    _EXPRESSION_WINDOW._serialized_start = 1543
-    _EXPRESSION_WINDOW._serialized_end = 2326
-    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1833
-    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 2326
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 2100
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 2245
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 2247
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 2326
-    _EXPRESSION_SORTORDER._serialized_start = 2329
-    _EXPRESSION_SORTORDER._serialized_end = 2754
-    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2559
-    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2667
-    _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2669
-    _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2754
-    _EXPRESSION_CAST._serialized_start = 2757
-    _EXPRESSION_CAST._serialized_end = 2902
-    _EXPRESSION_LITERAL._serialized_start = 2905
-    _EXPRESSION_LITERAL._serialized_end = 4468
-    _EXPRESSION_LITERAL_DECIMAL._serialized_start = 3740
-    _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3857
-    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3859
-    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3957
-    _EXPRESSION_LITERAL_ARRAY._serialized_start = 3960
-    _EXPRESSION_LITERAL_ARRAY._serialized_end = 4090
-    _EXPRESSION_LITERAL_MAP._serialized_start = 4093
-    _EXPRESSION_LITERAL_MAP._serialized_end = 4320
-    _EXPRESSION_LITERAL_STRUCT._serialized_start = 4323
-    _EXPRESSION_LITERAL_STRUCT._serialized_end = 4452
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 4470
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 4582
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 4585
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 4789
-    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 4791
-    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 4841
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 4843
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4925
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4927
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 5013
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 5016
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 5148
-    _EXPRESSION_UPDATEFIELDS._serialized_start = 5151
-    _EXPRESSION_UPDATEFIELDS._serialized_end = 5338
-    _EXPRESSION_ALIAS._serialized_start = 5340
-    _EXPRESSION_ALIAS._serialized_end = 5460
-    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 5463
-    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 5621
-    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 5623
-    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 5685
-    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 5701
-    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 6065
-    _PYTHONUDF._serialized_start = 6068
-    _PYTHONUDF._serialized_end = 6223
-    _SCALARSCALAUDF._serialized_start = 6226
-    _SCALARSCALAUDF._serialized_end = 6410
-    _JAVAUDF._serialized_start = 6413
-    _JAVAUDF._serialized_end = 6562
-    _CALLFUNCTION._serialized_start = 6564
-    _CALLFUNCTION._serialized_end = 6672
+    _EXPRESSION._serialized_end = 5800
+    _EXPRESSION_WINDOW._serialized_start = 1645
+    _EXPRESSION_WINDOW._serialized_end = 2428
+    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1935
+    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 2428
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 2202
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 2347
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 2349
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 2428
+    _EXPRESSION_SORTORDER._serialized_start = 2431
+    _EXPRESSION_SORTORDER._serialized_end = 2856
+    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2661
+    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2769
+    _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2771
+    _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2856
+    _EXPRESSION_CAST._serialized_start = 2859
+    _EXPRESSION_CAST._serialized_end = 3004
+    _EXPRESSION_LITERAL._serialized_start = 3007
+    _EXPRESSION_LITERAL._serialized_end = 4570
+    _EXPRESSION_LITERAL_DECIMAL._serialized_start = 3842
+    _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3959
+    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3961
+    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 4059
+    _EXPRESSION_LITERAL_ARRAY._serialized_start = 4062
+    _EXPRESSION_LITERAL_ARRAY._serialized_end = 4192
+    _EXPRESSION_LITERAL_MAP._serialized_start = 4195
+    _EXPRESSION_LITERAL_MAP._serialized_end = 4422
+    _EXPRESSION_LITERAL_STRUCT._serialized_start = 4425
+    _EXPRESSION_LITERAL_STRUCT._serialized_end = 4554
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 4572
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 4684
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 4687
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 4891
+    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 4893
+    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 4943
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 4945
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 5027
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 5029
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 5115
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 5118
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 5250
+    _EXPRESSION_UPDATEFIELDS._serialized_start = 5253
+    _EXPRESSION_UPDATEFIELDS._serialized_end = 5440
+    _EXPRESSION_ALIAS._serialized_start = 5442
+    _EXPRESSION_ALIAS._serialized_end = 5562
+    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 5565
+    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 5723
+    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 5725
+    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 5787
+    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 5803
+    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 6167
+    _PYTHONUDF._serialized_start = 6170
+    _PYTHONUDF._serialized_end = 6325
+    _SCALARSCALAUDF._serialized_start = 6328
+    _SCALARSCALAUDF._serialized_end = 6512
+    _JAVAUDF._serialized_start = 6515
+    _JAVAUDF._serialized_end = 6664
+    _CALLFUNCTION._serialized_start = 6666
+    _CALLFUNCTION._serialized_end = 6774
+    _NAMEDARGUMENTEXPRESSION._serialized_start = 6776
+    _NAMEDARGUMENTEXPRESSION._serialized_end = 6868
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi 
b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index b9b16ce35e3..2b418ef23f6 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -1102,6 +1102,7 @@ class Expression(google.protobuf.message.Message):
     UNRESOLVED_NAMED_LAMBDA_VARIABLE_FIELD_NUMBER: builtins.int
     COMMON_INLINE_USER_DEFINED_FUNCTION_FIELD_NUMBER: builtins.int
     CALL_FUNCTION_FIELD_NUMBER: builtins.int
+    NAMED_ARGUMENT_EXPRESSION_FIELD_NUMBER: builtins.int
     EXTENSION_FIELD_NUMBER: builtins.int
     @property
     def literal(self) -> global___Expression.Literal: ...
@@ -1138,6 +1139,8 @@ class Expression(google.protobuf.message.Message):
     @property
     def call_function(self) -> global___CallFunction: ...
     @property
+    def named_argument_expression(self) -> global___NamedArgumentExpression: 
...
+    @property
     def extension(self) -> google.protobuf.any_pb2.Any:
         """This field is used to mark extensions to the protocol. When plugins 
generate arbitrary
         relations they can add them here. During the planning the correct 
resolution is done.
@@ -1162,6 +1165,7 @@ class Expression(google.protobuf.message.Message):
         | None = ...,
         common_inline_user_defined_function: 
global___CommonInlineUserDefinedFunction | None = ...,
         call_function: global___CallFunction | None = ...,
+        named_argument_expression: global___NamedArgumentExpression | None = 
...,
         extension: google.protobuf.any_pb2.Any | None = ...,
     ) -> None: ...
     def HasField(
@@ -1185,6 +1189,8 @@ class Expression(google.protobuf.message.Message):
             b"lambda_function",
             "literal",
             b"literal",
+            "named_argument_expression",
+            b"named_argument_expression",
             "sort_order",
             b"sort_order",
             "unresolved_attribute",
@@ -1226,6 +1232,8 @@ class Expression(google.protobuf.message.Message):
             b"lambda_function",
             "literal",
             b"literal",
+            "named_argument_expression",
+            b"named_argument_expression",
             "sort_order",
             b"sort_order",
             "unresolved_attribute",
@@ -1265,6 +1273,7 @@ class Expression(google.protobuf.message.Message):
         "unresolved_named_lambda_variable",
         "common_inline_user_defined_function",
         "call_function",
+        "named_argument_expression",
         "extension",
     ] | None: ...
 
@@ -1505,3 +1514,28 @@ class CallFunction(google.protobuf.message.Message):
     ) -> None: ...
 
 global___CallFunction = CallFunction
+
+class NamedArgumentExpression(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    KEY_FIELD_NUMBER: builtins.int
+    VALUE_FIELD_NUMBER: builtins.int
+    key: builtins.str
+    """(Required) The key of the named argument."""
+    @property
+    def value(self) -> global___Expression:
+        """(Required) The value expression of the named argument."""
+    def __init__(
+        self,
+        *,
+        key: builtins.str = ...,
+        value: global___Expression | None = ...,
+    ) -> None: ...
+    def HasField(
+        self, field_name: typing_extensions.Literal["value", b"value"]
+    ) -> builtins.bool: ...
+    def ClearField(
+        self, field_name: typing_extensions.Literal["key", b"key", "value", 
b"value"]
+    ) -> None: ...
+
+global___NamedArgumentExpression = NamedArgumentExpression
diff --git a/python/pyspark/sql/connect/udtf.py 
b/python/pyspark/sql/connect/udtf.py
index c8495626292..ce37832854c 100644
--- a/python/pyspark/sql/connect/udtf.py
+++ b/python/pyspark/sql/connect/udtf.py
@@ -22,11 +22,11 @@ from pyspark.sql.connect.utils import check_dependencies
 check_dependencies(__name__)
 
 import warnings
-from typing import Type, TYPE_CHECKING, Optional, Union
+from typing import List, Type, TYPE_CHECKING, Optional, Union
 
 from pyspark.rdd import PythonEvalType
 from pyspark.sql.connect.column import Column
-from pyspark.sql.connect.expressions import ColumnReference
+from pyspark.sql.connect.expressions import ColumnReference, Expression, 
NamedArgumentExpression
 from pyspark.sql.connect.plan import (
     CommonInlineUserDefinedTableFunction,
     PythonUDTF,
@@ -146,12 +146,14 @@ class UserDefinedTableFunction:
         self.deterministic = deterministic
 
     def _build_common_inline_user_defined_table_function(
-        self, *cols: "ColumnOrName"
+        self, *args: "ColumnOrName", **kwargs: "ColumnOrName"
     ) -> CommonInlineUserDefinedTableFunction:
-        arg_cols = [
-            col if isinstance(col, Column) else Column(ColumnReference(col)) 
for col in cols
+        def to_expr(col: "ColumnOrName") -> Expression:
+            return col._expr if isinstance(col, Column) else 
ColumnReference(col)
+
+        arg_exprs: List[Expression] = [to_expr(arg) for arg in args] + [
+            NamedArgumentExpression(key, to_expr(value)) for key, value in 
kwargs.items()
         ]
-        arg_exprs = [col._expr for col in arg_cols]
 
         udtf = PythonUDTF(
             func=self.func,
@@ -166,13 +168,13 @@ class UserDefinedTableFunction:
             arguments=arg_exprs,
         )
 
-    def __call__(self, *cols: "ColumnOrName") -> "DataFrame":
+    def __call__(self, *args: "ColumnOrName", **kwargs: "ColumnOrName") -> 
"DataFrame":
         from pyspark.sql.connect.session import SparkSession
         from pyspark.sql.connect.dataframe import DataFrame
 
         session = SparkSession.active()
 
-        plan = self._build_common_inline_user_defined_table_function(*cols)
+        plan = self._build_common_inline_user_defined_table_function(*args, 
**kwargs)
         return DataFrame.withPlan(plan, session)
 
     def asNondeterministic(self) -> "UserDefinedTableFunction":
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index fdb4ec8111e..9cc364cc1f8 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -15547,6 +15547,12 @@ def udtf(
 
     .. versionadded:: 3.5.0
 
+    .. versionchanged:: 4.0.0
+        Supports Python side analysis.
+
+    .. versionchanged:: 4.0.0
+        Supports keyword-arguments.
+
     Parameters
     ----------
     cls : class
@@ -15623,6 +15629,38 @@ def udtf(
     |  1|  x|
     +---+---+
 
+    UDTF can use keyword arguments:
+
+    >>> @udtf
+    ... class TestUDTFWithKwargs:
+    ...     @staticmethod
+    ...     def analyze(
+    ...         a: AnalyzeArgument, b: AnalyzeArgument, **kwargs: 
AnalyzeArgument
+    ...     ) -> AnalyzeResult:
+    ...         return AnalyzeResult(
+    ...             StructType().add("a", a.data_type)
+    ...                 .add("b", b.data_type)
+    ...                 .add("x", kwargs["x"].data_type)
+    ...         )
+    ...
+    ...     def eval(self, a, b, **kwargs):
+    ...         yield a, b, kwargs["x"]
+    ...
+    >>> TestUDTFWithKwargs(lit(1), x=lit("x"), b=lit("b")).show()
+    +---+---+---+
+    |  a|  b|  x|
+    +---+---+---+
+    |  1|  b|  x|
+    +---+---+---+
+
+    >>> _ = spark.udtf.register("test_udtf", TestUDTFWithKwargs)
+    >>> spark.sql("SELECT * FROM test_udtf(1, x=>'x', b=>'b')").show()
+    +---+---+---+
+    |  a|  b|  x|
+    +---+---+---+
+    |  1|  b|  x|
+    +---+---+---+
+
     Arrow optimization can be explicitly enabled when creating UDTFs:
 
     >>> @udtf(returnType="c1: int, c2: int", useArrow=True)
diff --git a/python/pyspark/sql/tests/test_udtf.py 
b/python/pyspark/sql/tests/test_udtf.py
index 300067716e9..cd0604ccace 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -1565,6 +1565,7 @@ class BaseUDTFTestsMixin:
         expected = [Row(c1="hello", c2="world")]
         assertDataFrameEqual(TestUDTF(), expected)
         assertDataFrameEqual(self.spark.sql("SELECT * FROM test_udtf()"), 
expected)
+        assertDataFrameEqual(self.spark.sql("SELECT * FROM test_udtf(a=>1)"), 
expected)
 
         with self.assertRaisesRegex(
             AnalysisException, r"analyze\(\) takes 0 positional arguments but 
1 was given"
@@ -1795,6 +1796,93 @@ class BaseUDTFTestsMixin:
                     assertSchemaEqual(df.schema, StructType().add("col1", 
IntegerType()))
                     assertDataFrameEqual(df, [Row(col1=10), Row(col1=100)])
 
+    def test_udtf_with_named_arguments(self):
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, a, b):
+                yield a,
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+
+        for i, df in enumerate(
+            [
+                self.spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')"),
+                self.spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)"),
+                TestUDTF(a=lit(10), b=lit("x")),
+                TestUDTF(b=lit("x"), a=lit(10)),
+            ]
+        ):
+            with self.subTest(query_no=i):
+                assertDataFrameEqual(df, [Row(a=10)])
+
+    def test_udtf_with_named_arguments_negative(self):
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, a, b):
+                yield a,
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+
+        with self.assertRaisesRegex(
+            AnalysisException,
+            
"DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+        ):
+            self.spark.sql("SELECT * FROM test_udtf(a=>10, a=>100)").show()
+
+        with self.assertRaisesRegex(AnalysisException, 
"UNEXPECTED_POSITIONAL_ARGUMENT"):
+            self.spark.sql("SELECT * FROM test_udtf(a=>10, 'x')").show()
+
+        with self.assertRaisesRegex(
+            PythonException, r"eval\(\) got an unexpected keyword argument 'c'"
+        ):
+            self.spark.sql("SELECT * FROM test_udtf(c=>'x')").show()
+
+    def test_udtf_with_kwargs(self):
+        @udtf(returnType="a: int, b: string")
+        class TestUDTF:
+            def eval(self, **kwargs):
+                yield kwargs["a"], kwargs["b"]
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+
+        for i, df in enumerate(
+            [
+                self.spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')"),
+                self.spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)"),
+                TestUDTF(a=lit(10), b=lit("x")),
+                TestUDTF(b=lit("x"), a=lit(10)),
+            ]
+        ):
+            with self.subTest(query_no=i):
+                assertDataFrameEqual(df, [Row(a=10, b="x")])
+
+    def test_udtf_with_analyze_kwargs(self):
+        @udtf
+        class TestUDTF:
+            @staticmethod
+            def analyze(**kwargs: AnalyzeArgument) -> AnalyzeResult:
+                return AnalyzeResult(
+                    StructType(
+                        [StructField(key, arg.data_type) for key, arg in 
sorted(kwargs.items())]
+                    )
+                )
+
+            def eval(self, **kwargs):
+                yield tuple(value for _, value in sorted(kwargs.items()))
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+
+        for i, df in enumerate(
+            [
+                self.spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')"),
+                self.spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)"),
+                TestUDTF(a=lit(10), b=lit("x")),
+                TestUDTF(b=lit("x"), a=lit(10)),
+            ]
+        ):
+            with self.subTest(query_no=i):
+                assertDataFrameEqual(df, [Row(a=10, b="x")])
+
 
 class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py
index 027a2646a46..1ca87aae758 100644
--- a/python/pyspark/sql/udtf.py
+++ b/python/pyspark/sql/udtf.py
@@ -29,7 +29,7 @@ from py4j.java_gateway import JavaObject
 
 from pyspark.errors import PySparkAttributeError, PySparkRuntimeError, 
PySparkTypeError
 from pyspark.rdd import PythonEvalType
-from pyspark.sql.column import _to_java_column, _to_seq
+from pyspark.sql.column import _to_java_column, _to_java_expr, _to_seq
 from pyspark.sql.pandas.utils import require_minimum_pandas_version, 
require_minimum_pyarrow_version
 from pyspark.sql.types import DataType, StructType, _parse_datatype_string
 from pyspark.sql.udf import _wrap_function
@@ -148,9 +148,9 @@ def _vectorize_udtf(cls: Type) -> Type:
     # Wrap the exception thrown from the UDTF in a PySparkRuntimeError.
     def wrap_func(f: Callable[..., Any]) -> Callable[..., Any]:
         @wraps(f)
-        def evaluate(*a: Any) -> Any:
+        def evaluate(*a: Any, **kw: Any) -> Any:
             try:
-                return f(*a)
+                return f(*a, **kw)
             except Exception as e:
                 raise PySparkRuntimeError(
                     error_class="UDTF_EXEC_ERROR",
@@ -168,18 +168,22 @@ def _vectorize_udtf(cls: Type) -> Type:
         ):
 
             @staticmethod
-            def analyze(*args: AnalyzeArgument) -> AnalyzeResult:
-                return cls.analyze(*args)
+            def analyze(*args: AnalyzeArgument, **kwargs: AnalyzeArgument) -> 
AnalyzeResult:
+                return cls.analyze(*args, **kwargs)
 
-        def eval(self, *args: pd.Series) -> Iterator[pd.DataFrame]:
-            if len(args) == 0:
+        def eval(self, *args: pd.Series, **kwargs: pd.Series) -> 
Iterator[pd.DataFrame]:
+            if len(args) == 0 and len(kwargs) == 0:
                 yield pd.DataFrame(wrap_func(self.func.eval)())
             else:
                 # Create tuples from the input pandas Series, each tuple
                 # represents a row across all Series.
-                row_tuples = zip(*args)
+                keys = list(kwargs.keys())
+                len_args = len(args)
+                row_tuples = zip(*args, *[kwargs[key] for key in keys])
                 for row in row_tuples:
-                    res = wrap_func(self.func.eval)(*row)
+                    res = wrap_func(self.func.eval)(
+                        *row[:len_args], **{key: row[len_args + i] for i, key 
in enumerate(keys)}
+                    )
                     if res is not None and not isinstance(res, Iterable):
                         raise PySparkRuntimeError(
                             error_class="UDTF_RETURN_NOT_ITERABLE",
@@ -339,14 +343,24 @@ class UserDefinedTableFunction:
             )
         return judtf
 
-    def __call__(self, *cols: "ColumnOrName") -> "DataFrame":
+    def __call__(self, *args: "ColumnOrName", **kwargs: "ColumnOrName") -> 
"DataFrame":
         from pyspark.sql import DataFrame, SparkSession
 
         spark = SparkSession._getActiveSessionOrCreate()
         sc = spark.sparkContext
 
+        assert sc._jvm is not None
+        jcols = [_to_java_column(arg) for arg in args] + [
+            sc._jvm.Column(
+                
sc._jvm.org.apache.spark.sql.catalyst.expressions.NamedArgumentExpression(
+                    key, _to_java_expr(value)
+                )
+            )
+            for key, value in kwargs.items()
+        ]
+
         judtf = self._judtf
-        jPythonUDTF = judtf.apply(spark._jsparkSession, _to_seq(sc, cols, 
_to_java_column))
+        jPythonUDTF = judtf.apply(spark._jsparkSession, _to_seq(sc, jcols))
         return DataFrame(jPythonUDTF, spark)
 
     def asNondeterministic(self) -> "UserDefinedTableFunction":
diff --git a/python/pyspark/sql/worker/analyze_udtf.py 
b/python/pyspark/sql/worker/analyze_udtf.py
index 9ffa03541e6..7ba0789fa7b 100644
--- a/python/pyspark/sql/worker/analyze_udtf.py
+++ b/python/pyspark/sql/worker/analyze_udtf.py
@@ -19,7 +19,7 @@ import inspect
 import os
 import sys
 import traceback
-from typing import List, IO
+from typing import Dict, List, IO, Tuple
 
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.errors import PySparkRuntimeError, PySparkValueError
@@ -69,11 +69,12 @@ def read_udtf(infile: IO) -> type:
     return handler
 
 
-def read_arguments(infile: IO) -> List[AnalyzeArgument]:
+def read_arguments(infile: IO) -> Tuple[List[AnalyzeArgument], Dict[str, 
AnalyzeArgument]]:
     """Reads the arguments for `analyze` static method."""
     # Receive arguments
     num_args = read_int(infile)
     args: List[AnalyzeArgument] = []
+    kwargs: Dict[str, AnalyzeArgument] = {}
     for _ in range(num_args):
         dt = _parse_datatype_json_string(utf8_deserializer.loads(infile))
         if read_bool(infile):  # is foldable
@@ -83,8 +84,15 @@ def read_arguments(infile: IO) -> List[AnalyzeArgument]:
         else:
             value = None
         is_table = read_bool(infile)  # is table argument
-        args.append(AnalyzeArgument(data_type=dt, value=value, 
is_table=is_table))
-    return args
+        argument = AnalyzeArgument(data_type=dt, value=value, 
is_table=is_table)
+
+        is_named_arg = read_bool(infile)
+        if is_named_arg:
+            name = utf8_deserializer.loads(infile)
+            kwargs[name] = argument
+        else:
+            args.append(argument)
+    return args, kwargs
 
 
 def main(infile: IO, outfile: IO) -> None:
@@ -107,9 +115,9 @@ def main(infile: IO, outfile: IO) -> None:
         _accumulatorRegistry.clear()
 
         handler = read_udtf(infile)
-        args = read_arguments(infile)
+        args, kwargs = read_arguments(infile)
 
-        result = handler.analyze(*args)  # type: ignore[attr-defined]
+        result = handler.analyze(*args, **kwargs)  # type: ignore[attr-defined]
 
         if not isinstance(result, AnalyzeResult):
             raise PySparkValueError(
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 6f27400387e..8916a794001 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -550,7 +550,16 @@ def read_udtf(pickleSer, infile, eval_type):
 
     # See `PythonUDTFRunner.PythonUDFWriterThread.writeCommand'
     num_arg = read_int(infile)
-    arg_offsets = [read_int(infile) for _ in range(num_arg)]
+    args_offsets = []
+    kwargs_offsets = {}
+    for _ in range(num_arg):
+        offset = read_int(infile)
+        if read_bool(infile):
+            name = utf8_deserializer.loads(infile)
+            kwargs_offsets[name] = offset
+        else:
+            args_offsets.append(offset)
+
     handler = read_command(pickleSer, infile)
     if not isinstance(handler, type):
         raise PySparkRuntimeError(
@@ -619,7 +628,9 @@ def read_udtf(pickleSer, infile, eval_type):
                 )
                 return result
 
-            return lambda *a: map(lambda res: (res, arrow_return_type), 
map(verify_result, f(*a)))
+            return lambda *a, **kw: map(
+                lambda res: (res, arrow_return_type), map(verify_result, f(*a, 
**kw))
+            )
 
         eval = wrap_arrow_udtf(getattr(udtf, "eval"), return_type)
 
@@ -633,7 +644,10 @@ def read_udtf(pickleSer, infile, eval_type):
                 for a in it:
                     # The eval function yields an iterator. Each element 
produced by this
                     # iterator is a tuple in the form of (pandas.DataFrame, 
arrow_return_type).
-                    yield from eval(*[a[o] for o in arg_offsets])
+                    yield from eval(
+                        *[a[o] for o in args_offsets],
+                        **{k: a[o] for k, o in kwargs_offsets.items()},
+                    )
             finally:
                 if terminate is not None:
                     yield from terminate()
@@ -667,9 +681,9 @@ def read_udtf(pickleSer, infile, eval_type):
                 return toInternal(result)
 
             # Evaluate the function and return a tuple back to the executor.
-            def evaluate(*a) -> tuple:
+            def evaluate(*a, **kw) -> tuple:
                 try:
-                    res = f(*a)
+                    res = f(*a, **kw)
                 except Exception as e:
                     raise PySparkRuntimeError(
                         error_class="UDTF_EXEC_ERROR",
@@ -705,7 +719,10 @@ def read_udtf(pickleSer, infile, eval_type):
         def mapper(_, it):
             try:
                 for a in it:
-                    yield eval(*[a[o] for o in arg_offsets])
+                    yield eval(
+                        *[a[o] for o in args_offsets],
+                        **{k: a[o] for k, o in kwargs_offsets.items()},
+                    )
             finally:
                 if terminate is not None:
                     yield terminate()
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala
index 1088655f60c..d13bfab6d70 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala
@@ -72,6 +72,38 @@ trait FunctionBuilderBase[T] {
 }
 
 object NamedParametersSupport {
+  /**
+   * This method splits named arguments from the argument list.
+   * It also checks that:
+   * - positional arguments don't appear once keyword arguments start
+   * - the named arguments don't use duplicated names
+   *
+   * @param args The argument list provided in the function invocation
+   * @param functionName The name of the function, used in error messages
+   * @return A tuple of a list of positional arguments and a list of named arguments
+   */
+  def splitAndCheckNamedArguments(
+      args: Seq[Expression],
+      functionName: String): (Seq[Expression], Seq[NamedArgumentExpression]) = 
{
+    val (positionalArgs, namedArgs) = 
args.span(!_.isInstanceOf[NamedArgumentExpression])
+
+    val namedParametersSet = collection.mutable.Set[String]()
+
+    (positionalArgs,
+      namedArgs.zipWithIndex.map {
+        case (namedArg @ NamedArgumentExpression(parameterName, _), _) =>
+          if (namedParametersSet.contains(parameterName)) {
+            throw QueryCompilationErrors.doubleNamedArgumentReference(
+              functionName, parameterName)
+          }
+          namedParametersSet.add(parameterName)
+          namedArg
+        case (_, index) =>
+          throw QueryCompilationErrors.unexpectedPositionalArgument(
+            functionName, namedArgs(index - 
1).asInstanceOf[NamedArgumentExpression].key)
+      })
+  }
+
   /**
    * This method is the default routine which rearranges the arguments in 
positional order according
    * to the function signature provided. This will also fill in any default 
values that exists for
@@ -93,7 +125,7 @@ object NamedParametersSupport {
         functionName, functionSignature)
     }
 
-    val (positionalArgs, namedArgs) = 
args.span(!_.isInstanceOf[NamedArgumentExpression])
+    val (positionalArgs, namedArgs) = splitAndCheckNamedArguments(args, 
functionName)
     val namedParameters: Seq[InputParameter] = 
parameters.drop(positionalArgs.size)
 
     // The following loop checks for the following:
@@ -102,28 +134,16 @@ object NamedParametersSupport {
     val allParameterNames: Seq[String] = parameters.map(_.name)
     val parameterNamesSet: Set[String] = allParameterNames.toSet
     val positionalParametersSet = 
allParameterNames.take(positionalArgs.size).toSet
-    val namedParametersSet = collection.mutable.Set[String]()
 
-    namedArgs.zipWithIndex.foreach { case (arg, index) =>
-      arg match {
-        case namedArg: NamedArgumentExpression =>
-          val parameterName = namedArg.key
-          if (!parameterNamesSet.contains(parameterName)) {
-            throw 
QueryCompilationErrors.unrecognizedParameterName(functionName, namedArg.key,
-              parameterNamesSet.toSeq)
-          }
-          if (positionalParametersSet.contains(parameterName)) {
-            throw 
QueryCompilationErrors.positionalAndNamedArgumentDoubleReference(
-              functionName, namedArg.key)
-          }
-          if (namedParametersSet.contains(parameterName)) {
-            throw QueryCompilationErrors.doubleNamedArgumentReference(
-              functionName, namedArg.key)
-          }
-          namedParametersSet.add(namedArg.key)
-        case _ =>
-          throw QueryCompilationErrors.unexpectedPositionalArgument(
-            functionName, namedArgs(index - 
1).asInstanceOf[NamedArgumentExpression].key)
+    namedArgs.foreach { namedArg =>
+      val parameterName = namedArg.key
+      if (!parameterNamesSet.contains(parameterName)) {
+        throw QueryCompilationErrors.unrecognizedParameterName(functionName, 
namedArg.key,
+          parameterNamesSet.toSeq)
+      }
+      if (positionalParametersSet.contains(parameterName)) {
+        throw QueryCompilationErrors.positionalAndNamedArgumentDoubleReference(
+          functionName, namedArg.key)
       }
     }
 
@@ -136,8 +156,7 @@ object NamedParametersSupport {
     }
 
     // This constructs a map from argument name to value for argument 
rearrangement.
-    val namedArgMap = namedArgs.map { arg =>
-      val namedArg = arg.asInstanceOf[NamedArgumentExpression]
+    val namedArgMap = namedArgs.map { namedArg =>
       namedArg.key -> namedArg.value
     }.toMap
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
index 9c0addfd2ae..8ebd8a3a106 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
@@ -23,6 +23,7 @@ import org.apache.spark.{JobArtifactSet, TaskContext}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
+import 
org.apache.spark.sql.execution.python.EvalPythonUDTFExec.ArgumentMetadata
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
 
@@ -52,7 +53,7 @@ case class ArrowEvalPythonUDTFExec(
   private[this] val jobArtifactUUID = 
JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
 
   override protected def evaluate(
-      argOffsets: Array[Int],
+      argMetas: Array[ArgumentMetadata],
       iter: Iterator[InternalRow],
       schema: StructType,
       context: TaskContext): Iterator[Iterator[InternalRow]] = {
@@ -64,7 +65,7 @@ case class ArrowEvalPythonUDTFExec(
     val columnarBatchIter = new ArrowPythonUDTFRunner(
       udtf,
       evalType,
-      argOffsets,
+      argMetas,
       schema,
       sessionLocalTimeZone,
       largeVarTypes,
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
index 1dd06c2dc73..c0fa8b58bee 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
@@ -23,6 +23,7 @@ import org.apache.spark.api.python._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.PythonUDTF
 import org.apache.spark.sql.execution.metric.SQLMetric
+import 
org.apache.spark.sql.execution.python.EvalPythonUDTFExec.ArgumentMetadata
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -33,7 +34,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
 class ArrowPythonUDTFRunner(
     udtf: PythonUDTF,
     evalType: Int,
-    offsets: Array[Int],
+    argMetas: Array[ArgumentMetadata],
     protected override val schema: StructType,
     protected override val timeZoneId: String,
     protected override val largeVarTypes: Boolean,
@@ -41,7 +42,8 @@ class ArrowPythonUDTFRunner(
     val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String])
   extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](
-      Seq(ChainedPythonFunctions(Seq(udtf.func))), evalType, Array(offsets), 
jobArtifactUUID)
+      Seq(ChainedPythonFunctions(Seq(udtf.func))),
+      evalType, Array(argMetas.map(_.offset)), jobArtifactUUID)
   with BasicPythonArrowInput
   with BasicPythonArrowOutput {
 
@@ -49,7 +51,7 @@ class ArrowPythonUDTFRunner(
       dataOut: DataOutputStream,
       funcs: Seq[ChainedPythonFunctions],
       argOffsets: Array[Array[Int]]): Unit = {
-    PythonUDTFRunner.writeUDTF(dataOut, udtf, offsets)
+    PythonUDTFRunner.writeUDTF(dataOut, udtf, argMetas)
   }
 
   override val pythonExec: String =
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
index 6c8412f8b37..cbc90f34a37 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.metric.SQLMetric
+import 
org.apache.spark.sql.execution.python.EvalPythonUDTFExec.ArgumentMetadata
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -55,7 +56,7 @@ case class BatchEvalPythonUDTFExec(
    * an iterator of internal rows for every input row.
    */
   override protected def evaluate(
-      argOffsets: Array[Int],
+      argMetas: Array[ArgumentMetadata],
       iter: Iterator[InternalRow],
       schema: StructType,
       context: TaskContext): Iterator[Iterator[InternalRow]] = {
@@ -66,7 +67,7 @@ case class BatchEvalPythonUDTFExec(
 
     // Output iterator for results from Python.
     val outputIterator =
-      new PythonUDTFRunner(udtf, argOffsets, pythonMetrics, jobArtifactUUID)
+      new PythonUDTFRunner(udtf, argMetas, pythonMetrics, jobArtifactUUID)
         .compute(inputIterator, context.partitionId(), context)
 
     val unpickle = new Unpickler
@@ -93,12 +94,12 @@ case class BatchEvalPythonUDTFExec(
 
 class PythonUDTFRunner(
     udtf: PythonUDTF,
-    argOffsets: Array[Int],
+    argMetas: Array[ArgumentMetadata],
     pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String])
   extends BasePythonUDFRunner(
     Seq(ChainedPythonFunctions(Seq(udtf.func))),
-    PythonEvalType.SQL_TABLE_UDF, Array(argOffsets), pythonMetrics, 
jobArtifactUUID) {
+    PythonEvalType.SQL_TABLE_UDF, Array(argMetas.map(_.offset)), 
pythonMetrics, jobArtifactUUID) {
 
   protected override def newWriter(
       env: SparkEnv,
@@ -109,7 +110,7 @@ class PythonUDTFRunner(
     new PythonUDFWriter(env, worker, inputIterator, partitionIndex, context) {
 
       protected override def writeCommand(dataOut: DataOutputStream): Unit = {
-        PythonUDTFRunner.writeUDTF(dataOut, udtf, argOffsets)
+        PythonUDTFRunner.writeUDTF(dataOut, udtf, argMetas)
       }
     }
   }
@@ -117,10 +118,21 @@ class PythonUDTFRunner(
 
 object PythonUDTFRunner {
 
-  def writeUDTF(dataOut: DataOutputStream, udtf: PythonUDTF, argOffsets: 
Array[Int]): Unit = {
-    dataOut.writeInt(argOffsets.length)
-    argOffsets.foreach { offset =>
-      dataOut.writeInt(offset)
+  def writeUDTF(
+      dataOut: DataOutputStream,
+      udtf: PythonUDTF,
+      argMetas: Array[ArgumentMetadata]): Unit = {
+    dataOut.writeInt(argMetas.length)
+    argMetas.foreach {
+      case ArgumentMetadata(offset, name) =>
+        dataOut.writeInt(offset)
+        name match {
+          case Some(name) =>
+            dataOut.writeBoolean(true)
+            PythonWorkerUtils.writeUTF(name, dataOut)
+          case _ =>
+            dataOut.writeBoolean(false)
+        }
     }
     dataOut.writeInt(udtf.func.command.length)
     dataOut.write(udtf.func.command.toArray)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala
index fab417a0f86..410209e0ada 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala
@@ -26,9 +26,20 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.UnaryExecNode
+import 
org.apache.spark.sql.execution.python.EvalPythonUDTFExec.ArgumentMetadata
 import org.apache.spark.sql.types.{DataType, StructField, StructType}
 import org.apache.spark.util.Utils
 
+object EvalPythonUDTFExec {
+  /**
+   * Metadata for arguments of Python UDTF.
+   *
+   * @param offset the offset of the argument
+   * @param name the name of the argument if it's a `NamedArgumentExpression`
+   */
+  case class ArgumentMetadata(offset: Int, name: Option[String])
+}
+
 /**
  * A physical plan that evaluates a [[PythonUDTF]], one partition of tuples at 
a time.
  * This is similar to [[EvalPythonExec]].
@@ -45,7 +56,7 @@ trait EvalPythonUDTFExec extends UnaryExecNode {
   override def producedAttributes: AttributeSet = AttributeSet(resultAttrs)
 
   protected def evaluate(
-      argOffsets: Array[Int],
+      argMetas: Array[ArgumentMetadata],
       iter: Iterator[InternalRow],
       schema: StructType,
       context: TaskContext): Iterator[Iterator[InternalRow]]
@@ -68,13 +79,19 @@ trait EvalPythonUDTFExec extends UnaryExecNode {
       // flatten all the arguments
       val allInputs = new ArrayBuffer[Expression]
       val dataTypes = new ArrayBuffer[DataType]
-      val argOffsets = udtf.children.map { e =>
-        if (allInputs.exists(_.semanticEquals(e))) {
-          allInputs.indexWhere(_.semanticEquals(e))
+      val argMetas = udtf.children.map { e =>
+        val (key, value) = e match {
+          case NamedArgumentExpression(key, value) =>
+            (Some(key), value)
+          case _ =>
+            (None, e)
+        }
+        if (allInputs.exists(_.semanticEquals(value))) {
+          ArgumentMetadata(allInputs.indexWhere(_.semanticEquals(value)), key)
         } else {
-          allInputs += e
-          dataTypes += e.dataType
-          allInputs.length - 1
+          allInputs += value
+          dataTypes += value.dataType
+          ArgumentMetadata(allInputs.length - 1, key)
         }
       }.toArray
       val projection = MutableProjection.create(allInputs.toSeq, child.output)
@@ -93,7 +110,7 @@ trait EvalPythonUDTFExec extends UnaryExecNode {
         projection(inputRow)
       }
 
-      val outputRowIterator = evaluate(argOffsets, projectedRowIter, schema, 
context)
+      val outputRowIterator = evaluate(argMetas, projectedRowIter, schema, 
context)
 
       val pruneChildForResult: InternalRow => InternalRow =
         if (child.outputSet == AttributeSet(requiredChildOutput)) {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index 5fa9c89b3d1..38d521c16d5 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -33,7 +33,7 @@ import org.apache.spark.internal.config.BUFFER_SIZE
 import org.apache.spark.internal.config.Python._
 import org.apache.spark.sql.{Column, DataFrame, Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.{Expression, 
FunctionTableSubqueryArgumentExpression, NamedArgumentExpression, PythonUDAF, 
PythonUDF, PythonUDTF, UnresolvedPolymorphicPythonUDTF}
-import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan, 
OneRowRelation}
+import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan, 
NamedParametersSupport, OneRowRelation}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, StructType}
@@ -101,6 +101,13 @@ case class UserDefinedPythonTableFunction(
   }
 
   def builder(exprs: Seq[Expression]): LogicalPlan = {
+    /*
+     * Check that the named arguments:
+     * - don't have duplicated names
+     * - aren't followed by positional arguments
+     */
+    NamedParametersSupport.splitAndCheckNamedArguments(exprs, name)
+
     val udtf = returnType match {
       case Some(rt) =>
         PythonUDTF(
@@ -213,8 +220,6 @@ object UserDefinedPythonTableFunction {
     val bufferStream = new DirectByteBufferOutputStream()
     try {
       val dataOut = new DataOutputStream(new 
BufferedOutputStream(bufferStream, bufferSize))
-      val dataIn = new DataInputStream(new BufferedInputStream(
-        new WorkerInputStream(worker, bufferStream), bufferSize))
 
       PythonWorkerUtils.writePythonVersion(pythonVer, dataOut)
       PythonWorkerUtils.writeSparkFiles(jobArtifactUUID, pythonIncludes, 
dataOut)
@@ -237,11 +242,22 @@ object UserDefinedPythonTableFunction {
           dataOut.writeBoolean(false)
         }
         dataOut.writeBoolean(is_table)
+        // If the expr is NamedArgumentExpression, send its name.
+        expr match {
+          case NamedArgumentExpression(key, _) =>
+            dataOut.writeBoolean(true)
+            PythonWorkerUtils.writeUTF(key, dataOut)
+          case _ =>
+            dataOut.writeBoolean(false)
+        }
       }
 
       dataOut.writeInt(SpecialLengths.END_OF_STREAM)
       dataOut.flush()
 
+      val dataIn = new DataInputStream(new BufferedInputStream(
+        new WorkerInputStream(worker, bufferStream.toByteBuffer), bufferSize))
+
       // Receive the schema
       val schema = dataIn.readInt() match {
         case length if length >= 0 =>
@@ -273,9 +289,13 @@ object UserDefinedPythonTableFunction {
       case eof: EOFException =>
         throw new SparkException("Python worker exited unexpectedly 
(crashed)", eof)
     } finally {
-      if (!releasedOrClosed) {
-        // An error happened. Force to close the worker.
-        env.destroyPythonWorker(pythonExec, workerModule, 
envVars.asScala.toMap, worker)
+      try {
+        bufferStream.close()
+      } finally {
+        if (!releasedOrClosed) {
+          // An error happened. Force to close the worker.
+          env.destroyPythonWorker(pythonExec, workerModule, 
envVars.asScala.toMap, worker)
+        }
       }
     }
   }
@@ -288,8 +308,7 @@ object UserDefinedPythonTableFunction {
    * This is a port and simplified version of `PythonRunner.ReaderInputStream`,
    * and only supports to write all at once and then read all.
    */
-  private class WorkerInputStream(
-      worker: PythonWorker, bufferStream: DirectByteBufferOutputStream) 
extends InputStream {
+  private class WorkerInputStream(worker: PythonWorker, buffer: ByteBuffer) 
extends InputStream {
 
     private[this] val temp = new Array[Byte](1)
 
@@ -312,14 +331,15 @@ object UserDefinedPythonTableFunction {
           n = worker.channel.read(buf)
         }
         if (worker.selectionKey.isWritable) {
-          val buffer = bufferStream.toByteBuffer
           var acceptsInput = true
           while (acceptsInput && buffer.hasRemaining) {
             val n = worker.channel.write(buffer)
             acceptsInput = n > 0
           }
-          // We no longer have any data to write to the socket.
-          worker.selectionKey.interestOps(SelectionKey.OP_READ)
+          if (!buffer.hasRemaining) {
+            // We no longer have any data to write to the socket.
+            worker.selectionKey.interestOps(SelectionKey.OP_READ)
+          }
         }
       }
       n

