This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d91904e6271 [SPARK-43965][PYTHON][CONNECT] Support Python UDTF in Spark Connect
d91904e6271 is described below

commit d91904e627101f260933ac42244a75736dc97881
Author: allisonwang-db <allison.w...@databricks.com>
AuthorDate: Sun Jul 16 16:23:32 2023 +0900

    [SPARK-43965][PYTHON][CONNECT] Support Python UDTF in Spark Connect
    
    ### What changes were proposed in this pull request?
    
    This PR supports creating and registering Python UDTFs in Spark Connect.
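
    For illustration, a minimal sketch of what this enables on the client
    side (the class and values below are invented for the example, not part
    of this patch):

    ```python
    from pyspark.sql.functions import lit, udtf

    @udtf(returnType="a: int, b: int")
    class PlusOne:
        def eval(self, x: int):
            yield x, x + 1

    # With a Spark Connect session active, this now runs on the server.
    PlusOne(lit(1)).show()
    ```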
    
    ### Why are the changes needed?
    
    To make Python UDTF work with Spark Connect.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. After this PR, users can use Python UDTFs in Spark Connect.
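
    Registration for use in SQL works end-to-end as well (again a sketch,
    reusing the hypothetical `PlusOne` class from the example above):

    ```python
    spark.udtf.register("plus_one", PlusOne)
    spark.sql("SELECT * FROM plus_one(1)").show()
    ```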
    
    ### How was this patch tested?
    
    New unit tests
    
    Closes #41989 from allisonwang-db/spark-43965-udtf-spark-connect.
    
    Authored-by: allisonwang-db <allison.w...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../src/main/protobuf/spark/connect/commands.proto |   1 +
 .../main/protobuf/spark/connect/relations.proto    |  31 +++
 .../sql/connect/planner/SparkConnectPlanner.scala  |  65 ++++-
 python/pyspark/errors/error_classes.py             |   5 +
 python/pyspark/sql/connect/client/core.py          |  39 +++
 python/pyspark/sql/connect/functions.py            |  18 ++
 python/pyspark/sql/connect/plan.py                 |  93 ++++++-
 python/pyspark/sql/connect/proto/commands_pb2.py   | 160 ++++++------
 python/pyspark/sql/connect/proto/commands_pb2.pyi  |  12 +
 python/pyspark/sql/connect/proto/relations_pb2.py  | 268 +++++++++++----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi | 119 +++++++++
 python/pyspark/sql/connect/session.py              |  11 +-
 python/pyspark/sql/connect/udtf.py                 | 205 ++++++++++++++++
 python/pyspark/sql/connect/utils.py                |   4 +
 python/pyspark/sql/functions.py                    |   3 +-
 .../sql/tests/connect/test_connect_basic.py        |  11 -
 .../sql/tests/connect/test_connect_function.py     |  18 ++
 .../pyspark/sql/tests/connect/test_parity_udtf.py  | 141 +++++++++++
 python/pyspark/sql/tests/test_udtf.py              |  38 ++-
 python/pyspark/sql/udtf.py                         | 111 +++++----
 20 files changed, 1071 insertions(+), 282 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
index 6689662fcf8..ce8c1d53943 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
@@ -41,6 +41,7 @@ message Command {
     StreamingQueryCommand streaming_query_command = 7;
     GetResourcesCommand get_resources_command = 8;
     StreamingQueryManagerCommand streaming_query_manager_command = 9;
+    CommonInlineUserDefinedTableFunction register_table_function = 10;
 
     // This field is used to mark extensions to the protocol. When plugins generate arbitrary
     // Commands they can add them here. During the planning the correct resolution is done.
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index b6a3e5fa236..8001b3cbcfa 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -71,6 +71,7 @@ message Relation {
     HtmlString html_string = 35;
     CachedLocalRelation cached_local_relation = 36;
     CachedRemoteRelation cached_remote_relation = 37;
+    CommonInlineUserDefinedTableFunction common_inline_user_defined_table_function = 38;
 
     // NA functions
     NAFill fill_na = 90;
@@ -941,6 +942,36 @@ message ApplyInPandasWithState {
   string timeout_conf = 7;
 }
 
+message CommonInlineUserDefinedTableFunction {
+  // (Required) Name of the user-defined table function.
+  string function_name = 1;
+
+  // (Optional) Whether the user-defined table function is deterministic.
+  bool deterministic = 2;
+
+  // (Optional) Function input arguments. Empty arguments are allowed.
+  repeated Expression arguments = 3;
+
+  // (Required) Type of the user-defined table function.
+  oneof function {
+    PythonUDTF python_udtf = 4;
+  }
+}
+
+message PythonUDTF {
+  // (Optional) Return type of the Python UDTF.
+  optional DataType return_type = 1;
+
+  // (Required) EvalType of the Python UDTF.
+  int32 eval_type = 2;
+
+  // (Required) The encoded commands of the Python UDTF.
+  bytes command = 3;
+
+  // (Required) Python version being used in the client.
+  string python_ver = 4;
+}
+
 // Collect arbitrary (named) metrics from a dataset.
 message CollectMetrics {
   // (Required) The input relation.
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 492396631f3..d414169fd81 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -66,7 +66,7 @@ import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.execution.command.CreateViewCommand
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRelation}
-import org.apache.spark.sql.execution.python.{PythonForeachWriter, UserDefinedPythonFunction}
+import org.apache.spark.sql.execution.python.{PythonForeachWriter, UserDefinedPythonFunction, UserDefinedPythonTableFunction}
 import org.apache.spark.sql.execution.stat.StatFunctions
 import org.apache.spark.sql.execution.streaming.GroupStateImpl.groupStateTimeoutFromString
 import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper
@@ -153,6 +153,8 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
         transformCoGroupMap(rel.getCoGroupMap)
       case proto.Relation.RelTypeCase.APPLY_IN_PANDAS_WITH_STATE =>
         transformApplyInPandasWithState(rel.getApplyInPandasWithState)
+      case proto.Relation.RelTypeCase.COMMON_INLINE_USER_DEFINED_TABLE_FUNCTION =>
+        transformCommonInlineUserDefinedTableFunction(rel.getCommonInlineUserDefinedTableFunction)
       case proto.Relation.RelTypeCase.CACHED_REMOTE_RELATION =>
         transformCachedRemoteRelation(rel.getCachedRemoteRelation)
       case proto.Relation.RelTypeCase.COLLECT_METRICS =>
@@ -890,6 +892,32 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
       .logicalPlan
   }
 
+  private def transformCommonInlineUserDefinedTableFunction(
+      fun: proto.CommonInlineUserDefinedTableFunction): LogicalPlan = {
+    fun.getFunctionCase match {
+      case proto.CommonInlineUserDefinedTableFunction.FunctionCase.PYTHON_UDTF =>
+        val function = createPythonUserDefinedTableFunction(fun)
+        function.builder(fun.getArgumentsList.asScala.map(transformExpression).toSeq)
+      case _ =>
+        throw InvalidPlanInput(
+          s"Function with ID: ${fun.getFunctionCase.getNumber} is not supported")
+    }
+  }
+
+  private def transformPythonTableFunction(fun: proto.PythonUDTF): SimplePythonFunction = {
+    SimplePythonFunction(
+      command = fun.getCommand.toByteArray,
+      // Empty environment variables
+      envVars = Maps.newHashMap(),
+      pythonIncludes = sessionHolder.artifactManager.getSparkConnectPythonIncludes.asJava,
+      pythonExec = pythonExec,
+      pythonVer = fun.getPythonVer,
+      // Empty broadcast variables
+      broadcastVars = Lists.newArrayList(),
+      // Null accumulator
+      accumulator = null)
+  }
+
   private def transformCachedRemoteRelation(rel: proto.CachedRemoteRelation): LogicalPlan = {
     sessionHolder
       .getDataFrameOrThrow(rel.getRelationId)
@@ -2311,6 +2339,8 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
     command.getCommandTypeCase match {
       case proto.Command.CommandTypeCase.REGISTER_FUNCTION =>
         handleRegisterUserDefinedFunction(command.getRegisterFunction)
+      case proto.Command.CommandTypeCase.REGISTER_TABLE_FUNCTION =>
+        handleRegisterUserDefinedTableFunction(command.getRegisterTableFunction)
       case proto.Command.CommandTypeCase.WRITE_OPERATION =>
         handleWriteOperation(command.getWriteOperation)
       case proto.Command.CommandTypeCase.CREATE_DATAFRAME_VIEW =>
@@ -2437,6 +2467,39 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
     }
   }
 
+  private def handleRegisterUserDefinedTableFunction(
+      fun: proto.CommonInlineUserDefinedTableFunction): Unit = {
+    fun.getFunctionCase match {
+      case proto.CommonInlineUserDefinedTableFunction.FunctionCase.PYTHON_UDTF =>
+        val function = createPythonUserDefinedTableFunction(fun)
+        session.udtf.registerPython(fun.getFunctionName, function)
+      case _ =>
+        throw InvalidPlanInput(
+          s"Function with ID: ${fun.getFunctionCase.getNumber} is not supported")
+    }
+  }
+
+  private def createPythonUserDefinedTableFunction(
+      fun: proto.CommonInlineUserDefinedTableFunction): UserDefinedPythonTableFunction = {
+    val udtf = fun.getPythonUdtf
+    // Currently return type is required for Python UDTFs.
+    // TODO(SPARK-44380): support `analyze` in Python UDTFs
+    assert(udtf.hasReturnType)
+    val returnType = transformDataType(udtf.getReturnType)
+    if (!returnType.isInstanceOf[StructType]) {
+      throw InvalidPlanInput(
+        "Invalid Python user-defined table function return type. " +
+          s"Expect a struct type, but got ${returnType.typeName}.")
+    }
+
+    UserDefinedPythonTableFunction(
+      name = fun.getFunctionName,
+      func = transformPythonTableFunction(udtf),
+      returnType = returnType.asInstanceOf[StructType],
+      pythonEvalType = udtf.getEvalType,
+      udfDeterministic = fun.getDeterministic)
+  }
+
   private def handleRegisterPythonUDF(fun: proto.CommonInlineUserDefinedFunction): Unit = {
     val udf = fun.getPythonUdf
     val function = transformPythonFunction(udf)
diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py
index 8c51024bf06..56b166b53c5 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -668,6 +668,11 @@ ERROR_CLASSES_JSON = """
       "The number of columns in the result does not match the specified 
schema. Expected column count: <expected>, Actual column count: <actual>. 
Please make sure the values returned by the function have the same number of 
columns as specified in the output schema."
     ]
   },
+  "UDTF_RETURN_TYPE_MISMATCH" : {
+    "message" : [
+      "Mismatch in return type for the UDTF '<name>'. Expected a 'StructType', 
but got '<return_type>'. Please ensure the return type is a correctly formatted 
StructType."
+    ]
+  },
   "UNEXPECTED_RESPONSE_FROM_SERVER" : {
     "message" : [
       "Unexpected response from iterator server."
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 537ab0a6140..00f2a85d602 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -75,6 +75,11 @@ from pyspark.sql.connect.expressions import (
     CommonInlineUserDefinedFunction,
     JavaUDF,
 )
+from pyspark.sql.connect.plan import (
+    CommonInlineUserDefinedTableFunction,
+    PythonUDTF,
+)
+from pyspark.sql.connect.utils import get_python_ver
 from pyspark.sql.pandas.types import _create_converter_to_pandas, from_arrow_schema
 from pyspark.sql.types import DataType, StructType, TimestampType, _has_type
 from pyspark.rdd import PythonEvalType
@@ -641,6 +646,40 @@ class SparkConnectClient(object):
         self._execute(req)
         return name
 
+    def register_udtf(
+        self,
+        function: Any,
+        return_type: "DataTypeOrString",
+        name: str,
+        eval_type: int = PythonEvalType.SQL_TABLE_UDF,
+        deterministic: bool = True,
+    ) -> str:
+        """
+        Register a user-defined table function (UDTF) in the session catalog
+        as a temporary function. The return type, if specified, must be a
+        struct type and it's validated when building the proto message
+        for the PythonUDTF.
+        """
+        udtf = PythonUDTF(
+            func=function,
+            return_type=return_type,
+            eval_type=eval_type,
+            python_ver=get_python_ver(),
+        )
+
+        func = CommonInlineUserDefinedTableFunction(
+            function_name=name,
+            function=udtf,
+            deterministic=deterministic,
+            arguments=[],
+        ).udtf_plan(self)
+
+        req = self._execute_plan_request_with_metadata()
+        req.plan.command.register_table_function.CopyFrom(func)
+
+        self._execute(req)
+        return name
+
     def register_java(
         self,
         name: str,
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index 1be759d9b6e..a1c0516ee0d 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -31,6 +31,7 @@ from typing import (
     overload,
     Optional,
     Tuple,
+    Type,
     Callable,
     ValuesView,
     cast,
@@ -52,6 +53,7 @@ from pyspark.sql.connect.expressions import (
     UnresolvedNamedLambdaVariable,
 )
 from pyspark.sql.connect.udf import _create_py_udf
+from pyspark.sql.connect.udtf import _create_py_udtf
 from pyspark.sql import functions as pysparkfuncs
 from pyspark.sql.types import _from_numpy_type, DataType, StructType, ArrayType, StringType
 
@@ -67,6 +69,7 @@ if TYPE_CHECKING:
         UserDefinedFunctionLike,
     )
     from pyspark.sql.connect.dataframe import DataFrame
+    from pyspark.sql.connect.udtf import UserDefinedTableFunction
 
 
 def _to_col_with_plan_id(col: str, plan_id: Optional[int]) -> Column:
@@ -3891,6 +3894,21 @@ def udf(
 udf.__doc__ = pysparkfuncs.udf.__doc__
 
 
+def udtf(
+    cls: Optional[Type] = None,
+    *,
+    returnType: Union[StructType, str],
+    useArrow: Optional[bool] = None,
+) -> Union["UserDefinedTableFunction", Callable[[Type], 
"UserDefinedTableFunction"]]:
+    if cls is None:
+        return functools.partial(_create_py_udtf, returnType=returnType, 
useArrow=useArrow)
+    else:
+        return _create_py_udtf(cls=cls, returnType=returnType, 
useArrow=useArrow)
+
+
+udtf.__doc__ = pysparkfuncs.udtf.__doc__
+
+
 def call_function(udfName: str, *cols: "ColumnOrName") -> Column:
     return _invoke_function(udfName, *[_to_col(c) for c in cols])
 
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 97348d4863e..3390faa04de 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -18,7 +18,7 @@ from pyspark.sql.connect.utils import check_dependencies
 
 check_dependencies(__name__)
 
-from typing import Any, List, Optional, Sequence, Union, cast, TYPE_CHECKING, Mapping, Dict
+from typing import Any, List, Optional, Type, Sequence, Union, cast, TYPE_CHECKING, Mapping, Dict
 import functools
 import json
 from threading import Lock
@@ -26,6 +26,7 @@ from inspect import signature, isclass
 
 import pyarrow as pa
 
+from pyspark.serializers import CloudPickleSerializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.sql.types import DataType
 
@@ -33,11 +34,12 @@ import pyspark.sql.connect.proto as proto
 from pyspark.sql.connect.conversion import storage_level_to_proto
 from pyspark.sql.connect.column import Column
 from pyspark.sql.connect.expressions import (
+    Expression,
     SortOrder,
     ColumnReference,
     LiteralExpression,
 )
-from pyspark.sql.connect.types import pyspark_types_to_proto_types
+from pyspark.sql.connect.types import pyspark_types_to_proto_types, UnparsedDataType
 from pyspark.errors import PySparkTypeError, PySparkNotImplementedError
 
 if TYPE_CHECKING:
@@ -2173,6 +2175,93 @@ class ApplyInPandasWithState(LogicalPlan):
         return plan
 
 
+class PythonUDTF:
+    """Represents a Python user-defined table function."""
+
+    def __init__(
+        self,
+        func: Type,
+        return_type: Union[DataType, str],
+        eval_type: int,
+        python_ver: str,
+    ) -> None:
+        self._func = func
+        self._name = func.__name__
+        self._return_type: DataType = (
+            UnparsedDataType(return_type) if isinstance(return_type, str) else return_type
+        )
+        self._eval_type = eval_type
+        self._python_ver = python_ver
+
+    def to_plan(self, session: "SparkConnectClient") -> proto.PythonUDTF:
+        udtf = proto.PythonUDTF()
+        # Currently the return type cannot be None.
+        # TODO(SPARK-44380): support `analyze` in Python UDTFs
+        assert self._return_type is not None
+        udtf.return_type.CopyFrom(pyspark_types_to_proto_types(self._return_type))
+        udtf.eval_type = self._eval_type
+        udtf.command = CloudPickleSerializer().dumps(self._func)
+        udtf.python_ver = self._python_ver
+        return udtf
+
+    def __repr__(self) -> str:
+        return (
+            f"PythonUDTF({self._name}, {self._return_type}, "
+            f"{self._eval_type}, {self._python_ver})"
+        )
+
+
+class CommonInlineUserDefinedTableFunction(LogicalPlan):
+    """
+    Logical plan object for a user-defined table function with
+    an inlined function body.
+    """
+
+    def __init__(
+        self,
+        function_name: str,
+        function: PythonUDTF,
+        deterministic: bool,
+        arguments: Sequence[Expression],
+    ) -> None:
+        super().__init__(None)
+        self._function_name = function_name
+        self._deterministic = deterministic
+        self._arguments = arguments
+        self._function = function
+
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        plan = self._create_proto_relation()
+        plan.common_inline_user_defined_table_function.function_name = self._function_name
+        plan.common_inline_user_defined_table_function.deterministic = self._deterministic
+        if len(self._arguments) > 0:
+            plan.common_inline_user_defined_table_function.arguments.extend(
+                [arg.to_plan(session) for arg in self._arguments]
+            )
+        plan.common_inline_user_defined_table_function.python_udtf.CopyFrom(
+            self._function.to_plan(session)
+        )
+        return plan
+
+    def udtf_plan(
+        self, session: "SparkConnectClient"
+    ) -> "proto.CommonInlineUserDefinedTableFunction":
+        """
+        Compared to `plan`, it returns a `proto.CommonInlineUserDefinedTableFunction`
+        instead of a `proto.Relation`.
+        """
+        plan = proto.CommonInlineUserDefinedTableFunction()
+        plan.function_name = self._function_name
+        plan.deterministic = self._deterministic
+        if len(self._arguments) > 0:
+            plan.arguments.extend([arg.to_plan(session) for arg in self._arguments])
+        plan.python_udtf.CopyFrom(cast(proto.PythonUDTF, self._function.to_plan(session)))
+        return plan
+
+    def __repr__(self) -> str:
+        return f"{self._function_name}({', '.join([str(arg) for arg in 
self._arguments])})"
+
+
 class CachedRelation(LogicalPlan):
     def __init__(self, plan: proto.Relation) -> None:
         super(CachedRelation, self).__init__(None)
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.py b/python/pyspark/sql/connect/proto/commands_pb2.py
index 3947d172ed4..c852cebd140 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.py
+++ b/python/pyspark/sql/connect/proto/commands_pb2.py
@@ -35,7 +35,7 @@ from pyspark.sql.connect.proto import relations_pb2 as spark_dot_connect_dot_rel
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto"\x86\x07\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x [...]
+    b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto"\xf5\x07\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -59,83 +59,83 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._options = None
     _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_options = b"8\001"
     _COMMAND._serialized_start = 167
-    _COMMAND._serialized_end = 1069
-    _SQLCOMMAND._serialized_start = 1072
-    _SQLCOMMAND._serialized_end = 1313
-    _SQLCOMMAND_ARGSENTRY._serialized_start = 1223
-    _SQLCOMMAND_ARGSENTRY._serialized_end = 1313
-    _CREATEDATAFRAMEVIEWCOMMAND._serialized_start = 1316
-    _CREATEDATAFRAMEVIEWCOMMAND._serialized_end = 1466
-    _WRITEOPERATION._serialized_start = 1469
-    _WRITEOPERATION._serialized_end = 2520
-    _WRITEOPERATION_OPTIONSENTRY._serialized_start = 1944
-    _WRITEOPERATION_OPTIONSENTRY._serialized_end = 2002
-    _WRITEOPERATION_SAVETABLE._serialized_start = 2005
-    _WRITEOPERATION_SAVETABLE._serialized_end = 2263
-    _WRITEOPERATION_SAVETABLE_TABLESAVEMETHOD._serialized_start = 2139
-    _WRITEOPERATION_SAVETABLE_TABLESAVEMETHOD._serialized_end = 2263
-    _WRITEOPERATION_BUCKETBY._serialized_start = 2265
-    _WRITEOPERATION_BUCKETBY._serialized_end = 2356
-    _WRITEOPERATION_SAVEMODE._serialized_start = 2359
-    _WRITEOPERATION_SAVEMODE._serialized_end = 2496
-    _WRITEOPERATIONV2._serialized_start = 2523
-    _WRITEOPERATIONV2._serialized_end = 3336
-    _WRITEOPERATIONV2_OPTIONSENTRY._serialized_start = 1944
-    _WRITEOPERATIONV2_OPTIONSENTRY._serialized_end = 2002
-    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 3095
-    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 3161
-    _WRITEOPERATIONV2_MODE._serialized_start = 3164
-    _WRITEOPERATIONV2_MODE._serialized_end = 3323
-    _WRITESTREAMOPERATIONSTART._serialized_start = 3339
-    _WRITESTREAMOPERATIONSTART._serialized_end = 4139
-    _WRITESTREAMOPERATIONSTART_OPTIONSENTRY._serialized_start = 1944
-    _WRITESTREAMOPERATIONSTART_OPTIONSENTRY._serialized_end = 2002
-    _STREAMINGFOREACHFUNCTION._serialized_start = 4142
-    _STREAMINGFOREACHFUNCTION._serialized_end = 4321
-    _WRITESTREAMOPERATIONSTARTRESULT._serialized_start = 4323
-    _WRITESTREAMOPERATIONSTARTRESULT._serialized_end = 4444
-    _STREAMINGQUERYINSTANCEID._serialized_start = 4446
-    _STREAMINGQUERYINSTANCEID._serialized_end = 4511
-    _STREAMINGQUERYCOMMAND._serialized_start = 4514
-    _STREAMINGQUERYCOMMAND._serialized_end = 5146
-    _STREAMINGQUERYCOMMAND_EXPLAINCOMMAND._serialized_start = 5013
-    _STREAMINGQUERYCOMMAND_EXPLAINCOMMAND._serialized_end = 5057
-    _STREAMINGQUERYCOMMAND_AWAITTERMINATIONCOMMAND._serialized_start = 5059
-    _STREAMINGQUERYCOMMAND_AWAITTERMINATIONCOMMAND._serialized_end = 5135
-    _STREAMINGQUERYCOMMANDRESULT._serialized_start = 5149
-    _STREAMINGQUERYCOMMANDRESULT._serialized_end = 6290
-    _STREAMINGQUERYCOMMANDRESULT_STATUSRESULT._serialized_start = 5732
-    _STREAMINGQUERYCOMMANDRESULT_STATUSRESULT._serialized_end = 5902
-    _STREAMINGQUERYCOMMANDRESULT_RECENTPROGRESSRESULT._serialized_start = 5904
-    _STREAMINGQUERYCOMMANDRESULT_RECENTPROGRESSRESULT._serialized_end = 5976
-    _STREAMINGQUERYCOMMANDRESULT_EXPLAINRESULT._serialized_start = 5978
-    _STREAMINGQUERYCOMMANDRESULT_EXPLAINRESULT._serialized_end = 6017
-    _STREAMINGQUERYCOMMANDRESULT_EXCEPTIONRESULT._serialized_start = 6020
-    _STREAMINGQUERYCOMMANDRESULT_EXCEPTIONRESULT._serialized_end = 6217
-    _STREAMINGQUERYCOMMANDRESULT_AWAITTERMINATIONRESULT._serialized_start = 6219
-    _STREAMINGQUERYCOMMANDRESULT_AWAITTERMINATIONRESULT._serialized_end = 6275
-    _STREAMINGQUERYMANAGERCOMMAND._serialized_start = 6293
-    _STREAMINGQUERYMANAGERCOMMAND._serialized_end = 6990
-    _STREAMINGQUERYMANAGERCOMMAND_AWAITANYTERMINATIONCOMMAND._serialized_start = 6824
-    _STREAMINGQUERYMANAGERCOMMAND_AWAITANYTERMINATIONCOMMAND._serialized_end = 6903
-    _STREAMINGQUERYMANAGERCOMMAND_STREAMINGQUERYLISTENERCOMMAND._serialized_start = 6905
-    _STREAMINGQUERYMANAGERCOMMAND_STREAMINGQUERYLISTENERCOMMAND._serialized_end = 6979
-    _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_start = 6993
-    _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_end = 8147
-    _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_start = 7601
-    _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_end = 7728
-    _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_start = 7730
-    _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_end = 7845
-    _STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_start = 7847
-    _STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_end = 7906
-    _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYLISTENERINSTANCE._serialized_start = 7908
-    _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYLISTENERINSTANCE._serialized_end = 7983
-    _STREAMINGQUERYMANAGERCOMMANDRESULT_LISTSTREAMINGQUERYLISTENERRESULT._serialized_start = 7986
-    _STREAMINGQUERYMANAGERCOMMANDRESULT_LISTSTREAMINGQUERYLISTENERRESULT._serialized_end = 8132
-    _GETRESOURCESCOMMAND._serialized_start = 8149
-    _GETRESOURCESCOMMAND._serialized_end = 8170
-    _GETRESOURCESCOMMANDRESULT._serialized_start = 8173
-    _GETRESOURCESCOMMANDRESULT._serialized_end = 8385
-    _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_start = 8289
-    _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_end = 8385
+    _COMMAND._serialized_end = 1180
+    _SQLCOMMAND._serialized_start = 1183
+    _SQLCOMMAND._serialized_end = 1424
+    _SQLCOMMAND_ARGSENTRY._serialized_start = 1334
+    _SQLCOMMAND_ARGSENTRY._serialized_end = 1424
+    _CREATEDATAFRAMEVIEWCOMMAND._serialized_start = 1427
+    _CREATEDATAFRAMEVIEWCOMMAND._serialized_end = 1577
+    _WRITEOPERATION._serialized_start = 1580
+    _WRITEOPERATION._serialized_end = 2631
+    _WRITEOPERATION_OPTIONSENTRY._serialized_start = 2055
+    _WRITEOPERATION_OPTIONSENTRY._serialized_end = 2113
+    _WRITEOPERATION_SAVETABLE._serialized_start = 2116
+    _WRITEOPERATION_SAVETABLE._serialized_end = 2374
+    _WRITEOPERATION_SAVETABLE_TABLESAVEMETHOD._serialized_start = 2250
+    _WRITEOPERATION_SAVETABLE_TABLESAVEMETHOD._serialized_end = 2374
+    _WRITEOPERATION_BUCKETBY._serialized_start = 2376
+    _WRITEOPERATION_BUCKETBY._serialized_end = 2467
+    _WRITEOPERATION_SAVEMODE._serialized_start = 2470
+    _WRITEOPERATION_SAVEMODE._serialized_end = 2607
+    _WRITEOPERATIONV2._serialized_start = 2634
+    _WRITEOPERATIONV2._serialized_end = 3447
+    _WRITEOPERATIONV2_OPTIONSENTRY._serialized_start = 2055
+    _WRITEOPERATIONV2_OPTIONSENTRY._serialized_end = 2113
+    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 3206
+    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 3272
+    _WRITEOPERATIONV2_MODE._serialized_start = 3275
+    _WRITEOPERATIONV2_MODE._serialized_end = 3434
+    _WRITESTREAMOPERATIONSTART._serialized_start = 3450
+    _WRITESTREAMOPERATIONSTART._serialized_end = 4250
+    _WRITESTREAMOPERATIONSTART_OPTIONSENTRY._serialized_start = 2055
+    _WRITESTREAMOPERATIONSTART_OPTIONSENTRY._serialized_end = 2113
+    _STREAMINGFOREACHFUNCTION._serialized_start = 4253
+    _STREAMINGFOREACHFUNCTION._serialized_end = 4432
+    _WRITESTREAMOPERATIONSTARTRESULT._serialized_start = 4434
+    _WRITESTREAMOPERATIONSTARTRESULT._serialized_end = 4555
+    _STREAMINGQUERYINSTANCEID._serialized_start = 4557
+    _STREAMINGQUERYINSTANCEID._serialized_end = 4622
+    _STREAMINGQUERYCOMMAND._serialized_start = 4625
+    _STREAMINGQUERYCOMMAND._serialized_end = 5257
+    _STREAMINGQUERYCOMMAND_EXPLAINCOMMAND._serialized_start = 5124
+    _STREAMINGQUERYCOMMAND_EXPLAINCOMMAND._serialized_end = 5168
+    _STREAMINGQUERYCOMMAND_AWAITTERMINATIONCOMMAND._serialized_start = 5170
+    _STREAMINGQUERYCOMMAND_AWAITTERMINATIONCOMMAND._serialized_end = 5246
+    _STREAMINGQUERYCOMMANDRESULT._serialized_start = 5260
+    _STREAMINGQUERYCOMMANDRESULT._serialized_end = 6401
+    _STREAMINGQUERYCOMMANDRESULT_STATUSRESULT._serialized_start = 5843
+    _STREAMINGQUERYCOMMANDRESULT_STATUSRESULT._serialized_end = 6013
+    _STREAMINGQUERYCOMMANDRESULT_RECENTPROGRESSRESULT._serialized_start = 6015
+    _STREAMINGQUERYCOMMANDRESULT_RECENTPROGRESSRESULT._serialized_end = 6087
+    _STREAMINGQUERYCOMMANDRESULT_EXPLAINRESULT._serialized_start = 6089
+    _STREAMINGQUERYCOMMANDRESULT_EXPLAINRESULT._serialized_end = 6128
+    _STREAMINGQUERYCOMMANDRESULT_EXCEPTIONRESULT._serialized_start = 6131
+    _STREAMINGQUERYCOMMANDRESULT_EXCEPTIONRESULT._serialized_end = 6328
+    _STREAMINGQUERYCOMMANDRESULT_AWAITTERMINATIONRESULT._serialized_start = 6330
+    _STREAMINGQUERYCOMMANDRESULT_AWAITTERMINATIONRESULT._serialized_end = 6386
+    _STREAMINGQUERYMANAGERCOMMAND._serialized_start = 6404
+    _STREAMINGQUERYMANAGERCOMMAND._serialized_end = 7101
+    _STREAMINGQUERYMANAGERCOMMAND_AWAITANYTERMINATIONCOMMAND._serialized_start = 6935
+    _STREAMINGQUERYMANAGERCOMMAND_AWAITANYTERMINATIONCOMMAND._serialized_end = 7014
+    _STREAMINGQUERYMANAGERCOMMAND_STREAMINGQUERYLISTENERCOMMAND._serialized_start = 7016
+    _STREAMINGQUERYMANAGERCOMMAND_STREAMINGQUERYLISTENERCOMMAND._serialized_end = 7090
+    _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_start = 7104
+    _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_end = 8258
+    _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_start = 7712
+    _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_end = 7839
+    _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_start = 7841
+    _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_end = 7956
+    _STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_start = 7958
+    _STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_end = 8017
+    _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYLISTENERINSTANCE._serialized_start = 8019
+    _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYLISTENERINSTANCE._serialized_end = 8094
+    _STREAMINGQUERYMANAGERCOMMANDRESULT_LISTSTREAMINGQUERYLISTENERRESULT._serialized_start = 8097
+    _STREAMINGQUERYMANAGERCOMMANDRESULT_LISTSTREAMINGQUERYLISTENERRESULT._serialized_end = 8243
+    _GETRESOURCESCOMMAND._serialized_start = 8260
+    _GETRESOURCESCOMMAND._serialized_end = 8281
+    _GETRESOURCESCOMMANDRESULT._serialized_start = 8284
+    _GETRESOURCESCOMMANDRESULT._serialized_end = 8496
+    _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_start = 8400
+    _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_end = 8496
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.pyi b/python/pyspark/sql/connect/proto/commands_pb2.pyi
index fe472b3140d..5b44e3da7a0 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/commands_pb2.pyi
@@ -69,6 +69,7 @@ class Command(google.protobuf.message.Message):
     STREAMING_QUERY_COMMAND_FIELD_NUMBER: builtins.int
     GET_RESOURCES_COMMAND_FIELD_NUMBER: builtins.int
     STREAMING_QUERY_MANAGER_COMMAND_FIELD_NUMBER: builtins.int
+    REGISTER_TABLE_FUNCTION_FIELD_NUMBER: builtins.int
     EXTENSION_FIELD_NUMBER: builtins.int
     @property
     def register_function(
@@ -91,6 +92,10 @@ class Command(google.protobuf.message.Message):
     @property
     def streaming_query_manager_command(self) -> global___StreamingQueryManagerCommand: ...
     @property
+    def register_table_function(
+        self,
+    ) -> pyspark.sql.connect.proto.relations_pb2.CommonInlineUserDefinedTableFunction: ...
+    @property
     def extension(self) -> google.protobuf.any_pb2.Any:
         """This field is used to mark extensions to the protocol. When plugins 
generate arbitrary
         Commands they can add them here. During the planning the correct 
resolution is done.
@@ -108,6 +113,8 @@ class Command(google.protobuf.message.Message):
         streaming_query_command: global___StreamingQueryCommand | None = ...,
         get_resources_command: global___GetResourcesCommand | None = ...,
         streaming_query_manager_command: global___StreamingQueryManagerCommand | None = ...,
+        register_table_function: pyspark.sql.connect.proto.relations_pb2.CommonInlineUserDefinedTableFunction
+        | None = ...,
         extension: google.protobuf.any_pb2.Any | None = ...,
     ) -> None: ...
     def HasField(
@@ -123,6 +130,8 @@ class Command(google.protobuf.message.Message):
             b"get_resources_command",
             "register_function",
             b"register_function",
+            "register_table_function",
+            b"register_table_function",
             "sql_command",
             b"sql_command",
             "streaming_query_command",
@@ -150,6 +159,8 @@ class Command(google.protobuf.message.Message):
             b"get_resources_command",
             "register_function",
             b"register_function",
+            "register_table_function",
+            b"register_table_function",
             "sql_command",
             b"sql_command",
             "streaming_query_command",
@@ -176,6 +187,7 @@ class Command(google.protobuf.message.Message):
         "streaming_query_command",
         "get_resources_command",
         "streaming_query_manager_command",
+        "register_table_function",
         "extension",
     ] | None: ...
 
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index d370bc6fd0c..3a0a7ff71fd 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -35,7 +35,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xd0\x17\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xe1\x18\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -57,135 +57,139 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _PARSE_OPTIONSENTRY._options = None
     _PARSE_OPTIONSENTRY._serialized_options = b"8\001"
     _RELATION._serialized_start = 165
-    _RELATION._serialized_end = 3189
-    _UNKNOWN._serialized_start = 3191
-    _UNKNOWN._serialized_end = 3200
-    _RELATIONCOMMON._serialized_start = 3202
-    _RELATIONCOMMON._serialized_end = 3293
-    _SQL._serialized_start = 3296
-    _SQL._serialized_end = 3527
-    _SQL_ARGSENTRY._serialized_start = 3437
-    _SQL_ARGSENTRY._serialized_end = 3527
-    _READ._serialized_start = 3530
-    _READ._serialized_end = 4193
-    _READ_NAMEDTABLE._serialized_start = 3708
-    _READ_NAMEDTABLE._serialized_end = 3900
-    _READ_NAMEDTABLE_OPTIONSENTRY._serialized_start = 3842
-    _READ_NAMEDTABLE_OPTIONSENTRY._serialized_end = 3900
-    _READ_DATASOURCE._serialized_start = 3903
-    _READ_DATASOURCE._serialized_end = 4180
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3842
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3900
-    _PROJECT._serialized_start = 4195
-    _PROJECT._serialized_end = 4312
-    _FILTER._serialized_start = 4314
-    _FILTER._serialized_end = 4426
-    _JOIN._serialized_start = 4429
-    _JOIN._serialized_end = 5090
-    _JOIN_JOINDATATYPE._serialized_start = 4768
-    _JOIN_JOINDATATYPE._serialized_end = 4860
-    _JOIN_JOINTYPE._serialized_start = 4863
-    _JOIN_JOINTYPE._serialized_end = 5071
-    _SETOPERATION._serialized_start = 5093
-    _SETOPERATION._serialized_end = 5572
-    _SETOPERATION_SETOPTYPE._serialized_start = 5409
-    _SETOPERATION_SETOPTYPE._serialized_end = 5523
-    _LIMIT._serialized_start = 5574
-    _LIMIT._serialized_end = 5650
-    _OFFSET._serialized_start = 5652
-    _OFFSET._serialized_end = 5731
-    _TAIL._serialized_start = 5733
-    _TAIL._serialized_end = 5808
-    _AGGREGATE._serialized_start = 5811
-    _AGGREGATE._serialized_end = 6393
-    _AGGREGATE_PIVOT._serialized_start = 6150
-    _AGGREGATE_PIVOT._serialized_end = 6261
-    _AGGREGATE_GROUPTYPE._serialized_start = 6264
-    _AGGREGATE_GROUPTYPE._serialized_end = 6393
-    _SORT._serialized_start = 6396
-    _SORT._serialized_end = 6556
-    _DROP._serialized_start = 6559
-    _DROP._serialized_end = 6700
-    _DEDUPLICATE._serialized_start = 6703
-    _DEDUPLICATE._serialized_end = 6943
-    _LOCALRELATION._serialized_start = 6945
-    _LOCALRELATION._serialized_end = 7034
-    _CACHEDLOCALRELATION._serialized_start = 7036
-    _CACHEDLOCALRELATION._serialized_end = 7131
-    _CACHEDREMOTERELATION._serialized_start = 7133
-    _CACHEDREMOTERELATION._serialized_end = 7188
-    _SAMPLE._serialized_start = 7191
-    _SAMPLE._serialized_end = 7464
-    _RANGE._serialized_start = 7467
-    _RANGE._serialized_end = 7612
-    _SUBQUERYALIAS._serialized_start = 7614
-    _SUBQUERYALIAS._serialized_end = 7728
-    _REPARTITION._serialized_start = 7731
-    _REPARTITION._serialized_end = 7873
-    _SHOWSTRING._serialized_start = 7876
-    _SHOWSTRING._serialized_end = 8018
-    _HTMLSTRING._serialized_start = 8020
-    _HTMLSTRING._serialized_end = 8134
-    _STATSUMMARY._serialized_start = 8136
-    _STATSUMMARY._serialized_end = 8228
-    _STATDESCRIBE._serialized_start = 8230
-    _STATDESCRIBE._serialized_end = 8311
-    _STATCROSSTAB._serialized_start = 8313
-    _STATCROSSTAB._serialized_end = 8414
-    _STATCOV._serialized_start = 8416
-    _STATCOV._serialized_end = 8512
-    _STATCORR._serialized_start = 8515
-    _STATCORR._serialized_end = 8652
-    _STATAPPROXQUANTILE._serialized_start = 8655
-    _STATAPPROXQUANTILE._serialized_end = 8819
-    _STATFREQITEMS._serialized_start = 8821
-    _STATFREQITEMS._serialized_end = 8946
-    _STATSAMPLEBY._serialized_start = 8949
-    _STATSAMPLEBY._serialized_end = 9258
-    _STATSAMPLEBY_FRACTION._serialized_start = 9150
-    _STATSAMPLEBY_FRACTION._serialized_end = 9249
-    _NAFILL._serialized_start = 9261
-    _NAFILL._serialized_end = 9395
-    _NADROP._serialized_start = 9398
-    _NADROP._serialized_end = 9532
-    _NAREPLACE._serialized_start = 9535
-    _NAREPLACE._serialized_end = 9831
-    _NAREPLACE_REPLACEMENT._serialized_start = 9690
-    _NAREPLACE_REPLACEMENT._serialized_end = 9831
-    _TODF._serialized_start = 9833
-    _TODF._serialized_end = 9921
-    _WITHCOLUMNSRENAMED._serialized_start = 9924
-    _WITHCOLUMNSRENAMED._serialized_end = 10163
-    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 10096
-    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 10163
-    _WITHCOLUMNS._serialized_start = 10165
-    _WITHCOLUMNS._serialized_end = 10284
-    _WITHWATERMARK._serialized_start = 10287
-    _WITHWATERMARK._serialized_end = 10421
-    _HINT._serialized_start = 10424
-    _HINT._serialized_end = 10556
-    _UNPIVOT._serialized_start = 10559
-    _UNPIVOT._serialized_end = 10886
-    _UNPIVOT_VALUES._serialized_start = 10816
-    _UNPIVOT_VALUES._serialized_end = 10875
-    _TOSCHEMA._serialized_start = 10888
-    _TOSCHEMA._serialized_end = 10994
-    _REPARTITIONBYEXPRESSION._serialized_start = 10997
-    _REPARTITIONBYEXPRESSION._serialized_end = 11200
-    _MAPPARTITIONS._serialized_start = 11203
-    _MAPPARTITIONS._serialized_end = 11384
-    _GROUPMAP._serialized_start = 11387
-    _GROUPMAP._serialized_end = 12022
-    _COGROUPMAP._serialized_start = 12025
-    _COGROUPMAP._serialized_end = 12551
-    _APPLYINPANDASWITHSTATE._serialized_start = 12554
-    _APPLYINPANDASWITHSTATE._serialized_end = 12911
-    _COLLECTMETRICS._serialized_start = 12914
-    _COLLECTMETRICS._serialized_end = 13050
-    _PARSE._serialized_start = 13053
-    _PARSE._serialized_end = 13441
-    _PARSE_OPTIONSENTRY._serialized_start = 3842
-    _PARSE_OPTIONSENTRY._serialized_end = 3900
-    _PARSE_PARSEFORMAT._serialized_start = 13342
-    _PARSE_PARSEFORMAT._serialized_end = 13430
+    _RELATION._serialized_end = 3334
+    _UNKNOWN._serialized_start = 3336
+    _UNKNOWN._serialized_end = 3345
+    _RELATIONCOMMON._serialized_start = 3347
+    _RELATIONCOMMON._serialized_end = 3438
+    _SQL._serialized_start = 3441
+    _SQL._serialized_end = 3672
+    _SQL_ARGSENTRY._serialized_start = 3582
+    _SQL_ARGSENTRY._serialized_end = 3672
+    _READ._serialized_start = 3675
+    _READ._serialized_end = 4338
+    _READ_NAMEDTABLE._serialized_start = 3853
+    _READ_NAMEDTABLE._serialized_end = 4045
+    _READ_NAMEDTABLE_OPTIONSENTRY._serialized_start = 3987
+    _READ_NAMEDTABLE_OPTIONSENTRY._serialized_end = 4045
+    _READ_DATASOURCE._serialized_start = 4048
+    _READ_DATASOURCE._serialized_end = 4325
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3987
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 4045
+    _PROJECT._serialized_start = 4340
+    _PROJECT._serialized_end = 4457
+    _FILTER._serialized_start = 4459
+    _FILTER._serialized_end = 4571
+    _JOIN._serialized_start = 4574
+    _JOIN._serialized_end = 5235
+    _JOIN_JOINDATATYPE._serialized_start = 4913
+    _JOIN_JOINDATATYPE._serialized_end = 5005
+    _JOIN_JOINTYPE._serialized_start = 5008
+    _JOIN_JOINTYPE._serialized_end = 5216
+    _SETOPERATION._serialized_start = 5238
+    _SETOPERATION._serialized_end = 5717
+    _SETOPERATION_SETOPTYPE._serialized_start = 5554
+    _SETOPERATION_SETOPTYPE._serialized_end = 5668
+    _LIMIT._serialized_start = 5719
+    _LIMIT._serialized_end = 5795
+    _OFFSET._serialized_start = 5797
+    _OFFSET._serialized_end = 5876
+    _TAIL._serialized_start = 5878
+    _TAIL._serialized_end = 5953
+    _AGGREGATE._serialized_start = 5956
+    _AGGREGATE._serialized_end = 6538
+    _AGGREGATE_PIVOT._serialized_start = 6295
+    _AGGREGATE_PIVOT._serialized_end = 6406
+    _AGGREGATE_GROUPTYPE._serialized_start = 6409
+    _AGGREGATE_GROUPTYPE._serialized_end = 6538
+    _SORT._serialized_start = 6541
+    _SORT._serialized_end = 6701
+    _DROP._serialized_start = 6704
+    _DROP._serialized_end = 6845
+    _DEDUPLICATE._serialized_start = 6848
+    _DEDUPLICATE._serialized_end = 7088
+    _LOCALRELATION._serialized_start = 7090
+    _LOCALRELATION._serialized_end = 7179
+    _CACHEDLOCALRELATION._serialized_start = 7181
+    _CACHEDLOCALRELATION._serialized_end = 7276
+    _CACHEDREMOTERELATION._serialized_start = 7278
+    _CACHEDREMOTERELATION._serialized_end = 7333
+    _SAMPLE._serialized_start = 7336
+    _SAMPLE._serialized_end = 7609
+    _RANGE._serialized_start = 7612
+    _RANGE._serialized_end = 7757
+    _SUBQUERYALIAS._serialized_start = 7759
+    _SUBQUERYALIAS._serialized_end = 7873
+    _REPARTITION._serialized_start = 7876
+    _REPARTITION._serialized_end = 8018
+    _SHOWSTRING._serialized_start = 8021
+    _SHOWSTRING._serialized_end = 8163
+    _HTMLSTRING._serialized_start = 8165
+    _HTMLSTRING._serialized_end = 8279
+    _STATSUMMARY._serialized_start = 8281
+    _STATSUMMARY._serialized_end = 8373
+    _STATDESCRIBE._serialized_start = 8375
+    _STATDESCRIBE._serialized_end = 8456
+    _STATCROSSTAB._serialized_start = 8458
+    _STATCROSSTAB._serialized_end = 8559
+    _STATCOV._serialized_start = 8561
+    _STATCOV._serialized_end = 8657
+    _STATCORR._serialized_start = 8660
+    _STATCORR._serialized_end = 8797
+    _STATAPPROXQUANTILE._serialized_start = 8800
+    _STATAPPROXQUANTILE._serialized_end = 8964
+    _STATFREQITEMS._serialized_start = 8966
+    _STATFREQITEMS._serialized_end = 9091
+    _STATSAMPLEBY._serialized_start = 9094
+    _STATSAMPLEBY._serialized_end = 9403
+    _STATSAMPLEBY_FRACTION._serialized_start = 9295
+    _STATSAMPLEBY_FRACTION._serialized_end = 9394
+    _NAFILL._serialized_start = 9406
+    _NAFILL._serialized_end = 9540
+    _NADROP._serialized_start = 9543
+    _NADROP._serialized_end = 9677
+    _NAREPLACE._serialized_start = 9680
+    _NAREPLACE._serialized_end = 9976
+    _NAREPLACE_REPLACEMENT._serialized_start = 9835
+    _NAREPLACE_REPLACEMENT._serialized_end = 9976
+    _TODF._serialized_start = 9978
+    _TODF._serialized_end = 10066
+    _WITHCOLUMNSRENAMED._serialized_start = 10069
+    _WITHCOLUMNSRENAMED._serialized_end = 10308
+    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 10241
+    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 10308
+    _WITHCOLUMNS._serialized_start = 10310
+    _WITHCOLUMNS._serialized_end = 10429
+    _WITHWATERMARK._serialized_start = 10432
+    _WITHWATERMARK._serialized_end = 10566
+    _HINT._serialized_start = 10569
+    _HINT._serialized_end = 10701
+    _UNPIVOT._serialized_start = 10704
+    _UNPIVOT._serialized_end = 11031
+    _UNPIVOT_VALUES._serialized_start = 10961
+    _UNPIVOT_VALUES._serialized_end = 11020
+    _TOSCHEMA._serialized_start = 11033
+    _TOSCHEMA._serialized_end = 11139
+    _REPARTITIONBYEXPRESSION._serialized_start = 11142
+    _REPARTITIONBYEXPRESSION._serialized_end = 11345
+    _MAPPARTITIONS._serialized_start = 11348
+    _MAPPARTITIONS._serialized_end = 11529
+    _GROUPMAP._serialized_start = 11532
+    _GROUPMAP._serialized_end = 12167
+    _COGROUPMAP._serialized_start = 12170
+    _COGROUPMAP._serialized_end = 12696
+    _APPLYINPANDASWITHSTATE._serialized_start = 12699
+    _APPLYINPANDASWITHSTATE._serialized_end = 13056
+    _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_start = 13059
+    _COMMONINLINEUSERDEFINEDTABLEFUNCTION._serialized_end = 13303
+    _PYTHONUDTF._serialized_start = 13306
+    _PYTHONUDTF._serialized_end = 13483
+    _COLLECTMETRICS._serialized_start = 13486
+    _COLLECTMETRICS._serialized_end = 13622
+    _PARSE._serialized_start = 13625
+    _PARSE._serialized_end = 14013
+    _PARSE_OPTIONSENTRY._serialized_start = 3987
+    _PARSE_OPTIONSENTRY._serialized_end = 4045
+    _PARSE_PARSEFORMAT._serialized_start = 13914
+    _PARSE_PARSEFORMAT._serialized_end = 14002
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 1f15ed2c6c7..9cadd4acc52 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -99,6 +99,7 @@ class Relation(google.protobuf.message.Message):
     HTML_STRING_FIELD_NUMBER: builtins.int
     CACHED_LOCAL_RELATION_FIELD_NUMBER: builtins.int
     CACHED_REMOTE_RELATION_FIELD_NUMBER: builtins.int
+    COMMON_INLINE_USER_DEFINED_TABLE_FUNCTION_FIELD_NUMBER: builtins.int
     FILL_NA_FIELD_NUMBER: builtins.int
     DROP_NA_FIELD_NUMBER: builtins.int
     REPLACE_FIELD_NUMBER: builtins.int
@@ -188,6 +189,10 @@ class Relation(google.protobuf.message.Message):
     @property
     def cached_remote_relation(self) -> global___CachedRemoteRelation: ...
     @property
+    def common_inline_user_defined_table_function(
+        self,
+    ) -> global___CommonInlineUserDefinedTableFunction: ...
+    @property
     def fill_na(self) -> global___NAFill:
         """NA functions"""
     @property
@@ -261,6 +266,8 @@ class Relation(google.protobuf.message.Message):
         html_string: global___HtmlString | None = ...,
         cached_local_relation: global___CachedLocalRelation | None = ...,
         cached_remote_relation: global___CachedRemoteRelation | None = ...,
+        common_inline_user_defined_table_function: global___CommonInlineUserDefinedTableFunction
+        | None = ...,
         fill_na: global___NAFill | None = ...,
         drop_na: global___NADrop | None = ...,
         replace: global___NAReplace | None = ...,
@@ -297,6 +304,8 @@ class Relation(google.protobuf.message.Message):
             b"collect_metrics",
             "common",
             b"common",
+            "common_inline_user_defined_table_function",
+            b"common_inline_user_defined_table_function",
             "corr",
             b"corr",
             "cov",
@@ -406,6 +415,8 @@ class Relation(google.protobuf.message.Message):
             b"collect_metrics",
             "common",
             b"common",
+            "common_inline_user_defined_table_function",
+            b"common_inline_user_defined_table_function",
             "corr",
             b"corr",
             "cov",
@@ -533,6 +544,7 @@ class Relation(google.protobuf.message.Message):
         "html_string",
         "cached_local_relation",
         "cached_remote_relation",
+        "common_inline_user_defined_table_function",
         "fill_na",
         "drop_na",
         "replace",
@@ -3378,6 +3390,113 @@ class ApplyInPandasWithState(google.protobuf.message.Message):
 
 global___ApplyInPandasWithState = ApplyInPandasWithState
 
+class CommonInlineUserDefinedTableFunction(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    FUNCTION_NAME_FIELD_NUMBER: builtins.int
+    DETERMINISTIC_FIELD_NUMBER: builtins.int
+    ARGUMENTS_FIELD_NUMBER: builtins.int
+    PYTHON_UDTF_FIELD_NUMBER: builtins.int
+    function_name: builtins.str
+    """(Required) Name of the user-defined table function."""
+    deterministic: builtins.bool
+    """(Optional) Whether the user-defined table function is deterministic."""
+    @property
+    def arguments(
+        self,
+    ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+        pyspark.sql.connect.proto.expressions_pb2.Expression
+    ]:
+        """(Optional) Function input arguments. Empty arguments are allowed."""
+    @property
+    def python_udtf(self) -> global___PythonUDTF: ...
+    def __init__(
+        self,
+        *,
+        function_name: builtins.str = ...,
+        deterministic: builtins.bool = ...,
+        arguments: collections.abc.Iterable[pyspark.sql.connect.proto.expressions_pb2.Expression]
+        | None = ...,
+        python_udtf: global___PythonUDTF | None = ...,
+    ) -> None: ...
+    def HasField(
+        self,
+        field_name: typing_extensions.Literal[
+            "function", b"function", "python_udtf", b"python_udtf"
+        ],
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "arguments",
+            b"arguments",
+            "deterministic",
+            b"deterministic",
+            "function",
+            b"function",
+            "function_name",
+            b"function_name",
+            "python_udtf",
+            b"python_udtf",
+        ],
+    ) -> None: ...
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["function", b"function"]
+    ) -> typing_extensions.Literal["python_udtf"] | None: ...
+
+global___CommonInlineUserDefinedTableFunction = CommonInlineUserDefinedTableFunction
+
+class PythonUDTF(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    RETURN_TYPE_FIELD_NUMBER: builtins.int
+    EVAL_TYPE_FIELD_NUMBER: builtins.int
+    COMMAND_FIELD_NUMBER: builtins.int
+    PYTHON_VER_FIELD_NUMBER: builtins.int
+    @property
+    def return_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
+        """(Optional) Return type of the Python UDTF."""
+    eval_type: builtins.int
+    """(Required) EvalType of the Python UDTF."""
+    command: builtins.bytes
+    """(Required) The encoded commands of the Python UDTF."""
+    python_ver: builtins.str
+    """(Required) Python version being used in the client."""
+    def __init__(
+        self,
+        *,
+        return_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
+        eval_type: builtins.int = ...,
+        command: builtins.bytes = ...,
+        python_ver: builtins.str = ...,
+    ) -> None: ...
+    def HasField(
+        self,
+        field_name: typing_extensions.Literal[
+            "_return_type", b"_return_type", "return_type", b"return_type"
+        ],
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "_return_type",
+            b"_return_type",
+            "command",
+            b"command",
+            "eval_type",
+            b"eval_type",
+            "python_ver",
+            b"python_ver",
+            "return_type",
+            b"return_type",
+        ],
+    ) -> None: ...
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_return_type", 
b"_return_type"]
+    ) -> typing_extensions.Literal["return_type"] | None: ...
+
+global___PythonUDTF = PythonUDTF
+
 class CollectMetrics(google.protobuf.message.Message):
     """Collect arbitrary (named) metrics from a dataset."""
 
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index ea88d60d760..3f9d46a22f4 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -88,6 +88,7 @@ if TYPE_CHECKING:
     from pyspark.sql.connect._typing import OptionalPrimitiveType
     from pyspark.sql.connect.catalog import Catalog
     from pyspark.sql.connect.udf import UDFRegistration
+    from pyspark.sql.connect.udtf import UDTFRegistration
 
 
 # `_active_spark_session` stores the active spark connect session created by
@@ -599,7 +600,7 @@ class SparkSession:
             raise PySparkAttributeError(
                 error_class="JVM_ATTRIBUTE_NOT_SUPPORTED", 
message_parameters={"attr_name": name}
             )
-        elif name in ["newSession", "sparkContext", "udtf"]:
+        elif name in ["newSession", "sparkContext"]:
             raise PySparkNotImplementedError(
                 error_class="NOT_IMPLEMENTED", message_parameters={"feature": 
f"{name}()"}
             )
@@ -613,6 +614,14 @@ class SparkSession:
 
     udf.__doc__ = PySparkSession.udf.__doc__
 
+    @property
+    def udtf(self) -> "UDTFRegistration":
+        from pyspark.sql.connect.udtf import UDTFRegistration
+
+        return UDTFRegistration(self)
+
+    udtf.__doc__ = PySparkSession.udtf.__doc__
+
     @property
     def version(self) -> str:
         result = self._client._analyze(method="spark_version").spark_version
diff --git a/python/pyspark/sql/connect/udtf.py b/python/pyspark/sql/connect/udtf.py
new file mode 100644
index 00000000000..1fe8e1024ee
--- /dev/null
+++ b/python/pyspark/sql/connect/udtf.py
@@ -0,0 +1,205 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+User-defined table function related classes and functions
+"""
+from pyspark.sql.connect.utils import check_dependencies
+
+check_dependencies(__name__)
+
+import warnings
+from typing import Type, TYPE_CHECKING, Optional, Union
+
+from pyspark.rdd import PythonEvalType
+from pyspark.sql.connect.column import Column
+from pyspark.sql.connect.expressions import ColumnReference
+from pyspark.sql.connect.plan import (
+    CommonInlineUserDefinedTableFunction,
+    PythonUDTF,
+)
+from pyspark.sql.connect.types import UnparsedDataType
+from pyspark.sql.connect.utils import get_python_ver
+from pyspark.sql.udtf import UDTFRegistration as PySparkUDTFRegistration
+from pyspark.sql.udtf import _validate_udtf_handler
+from pyspark.sql.types import DataType, StructType
+from pyspark.errors import PySparkRuntimeError, PySparkTypeError
+
+
+if TYPE_CHECKING:
+    from pyspark.sql.connect._typing import ColumnOrName
+    from pyspark.sql.connect.dataframe import DataFrame
+    from pyspark.sql.connect.session import SparkSession
+
+
+def _create_udtf(
+    cls: Type,
+    returnType: Union[StructType, str],
+    name: Optional[str] = None,
+    evalType: int = PythonEvalType.SQL_TABLE_UDF,
+    deterministic: bool = True,
+) -> "UserDefinedTableFunction":
+    udtf_obj = UserDefinedTableFunction(
+        cls, returnType=returnType, name=name, evalType=evalType, deterministic=deterministic
+    )
+    return udtf_obj
+
+
+def _create_py_udtf(
+    cls: Type,
+    returnType: Union[StructType, str],
+    name: Optional[str] = None,
+    deterministic: bool = True,
+    useArrow: Optional[bool] = None,
+) -> "UserDefinedTableFunction":
+    if useArrow is not None:
+        arrow_enabled = useArrow
+    else:
+        from pyspark.sql.connect.session import _active_spark_session
+
+        arrow_enabled = (
+            _active_spark_session.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled") == "true"
+            if _active_spark_session is not None
+            else True
+        )
+
+    # Create a regular Python UDTF and check for invalid handler class.
+    regular_udtf = _create_udtf(cls, returnType, name, PythonEvalType.SQL_TABLE_UDF, deterministic)
+
+    if not arrow_enabled:
+        return regular_udtf
+
+    from pyspark.sql.pandas.utils import (
+        require_minimum_pandas_version,
+        require_minimum_pyarrow_version,
+    )
+
+    try:
+        require_minimum_pandas_version()
+        require_minimum_pyarrow_version()
+    except ImportError as e:
+        warnings.warn(
+            f"Arrow optimization for Python UDTFs cannot be enabled: {str(e)}. 
"
+            f"Falling back to using regular Python UDTFs.",
+            UserWarning,
+        )
+        return regular_udtf
+
+    from pyspark.sql.udtf import _vectorize_udtf
+
+    vectorized_udtf = _vectorize_udtf(cls)
+    return _create_udtf(
+        vectorized_udtf, returnType, name, PythonEvalType.SQL_ARROW_TABLE_UDF, deterministic
+    )
+
+
+class UserDefinedTableFunction:
+    """
+    User-defined table function in Python
+
+    Notes
+    -----
+    The constructor of this class is not supposed to be directly called.
+    Use :meth:`pyspark.sql.functions.udtf` to create this instance.
+    """
+
+    def __init__(
+        self,
+        func: Type,
+        returnType: Union[StructType, str],
+        name: Optional[str] = None,
+        evalType: int = PythonEvalType.SQL_TABLE_UDF,
+        deterministic: bool = True,
+    ) -> None:
+        self.func = func
+        self.returnType: DataType = (
+            UnparsedDataType(returnType) if isinstance(returnType, str) else returnType
+        )
+        self._name = name or func.__name__
+        self.evalType = evalType
+        self.deterministic = deterministic
+
+        _validate_udtf_handler(func)
+
+    def _build_common_inline_user_defined_table_function(
+        self, *cols: "ColumnOrName"
+    ) -> CommonInlineUserDefinedTableFunction:
+        arg_cols = [
+            col if isinstance(col, Column) else Column(ColumnReference(col)) for col in cols
+        ]
+        arg_exprs = [col._expr for col in arg_cols]
+
+        udtf = PythonUDTF(
+            func=self.func,
+            return_type=self.returnType,
+            eval_type=self.evalType,
+            python_ver=get_python_ver(),
+        )
+        return CommonInlineUserDefinedTableFunction(
+            function_name=self._name,
+            function=udtf,
+            deterministic=self.deterministic,
+            arguments=arg_exprs,
+        )
+
+    def __call__(self, *cols: "ColumnOrName") -> "DataFrame":
+        from pyspark.sql.connect.dataframe import DataFrame
+        from pyspark.sql.connect.session import _active_spark_session
+
+        if _active_spark_session is None:
+            raise PySparkRuntimeError(
+                "An active SparkSession is required for "
+                "executing a Python user-defined table function."
+            )
+
+        plan = self._build_common_inline_user_defined_table_function(*cols)
+        return DataFrame.withPlan(plan, _active_spark_session)
+
+    def asNondeterministic(self) -> "UserDefinedTableFunction":
+        self.deterministic = False
+        return self
+
+
+class UDTFRegistration:
+    """
+    Wrapper for user-defined table function registration.
+
+    .. versionadded:: 3.5.0
+    """
+
+    def __init__(self, sparkSession: "SparkSession"):
+        self.sparkSession = sparkSession
+
+    def register(
+        self,
+        name: str,
+        f: "UserDefinedTableFunction",
+    ) -> "UserDefinedTableFunction":
+        if f.evalType not in [PythonEvalType.SQL_TABLE_UDF, PythonEvalType.SQL_ARROW_TABLE_UDF]:
+            raise PySparkTypeError(
+                error_class="INVALID_UDTF_EVAL_TYPE",
+                message_parameters={
+                    "name": name,
+                    "eval_type": "SQL_TABLE_UDF, SQL_ARROW_TABLE_UDF",
+                },
+            )
+
+        self.sparkSession._client.register_udtf(
+            f.func, f.returnType, name, f.evalType, f.deterministic
+        )
+        return f
+
+    register.__doc__ = PySparkUDTFRegistration.register.__doc__
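
Taken together, the new module mirrors the classic `pyspark.sql.udtf` API. A
minimal sketch of the flow it enables (hypothetical `PlusOne` handler,
assuming an active Connect session):

    from pyspark.sql.connect.functions import lit, udtf

    @udtf(returnType="a: int, b: int")
    class PlusOne:
        def eval(self, x: int):
            yield x, x + 1

    # __call__ lazily builds a CommonInlineUserDefinedTableFunction plan;
    # collect() ships the pickled handler to the server for execution.
    rows = PlusOne(lit(1)).collect()  # [Row(a=1, b=2)]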
diff --git a/python/pyspark/sql/connect/utils.py b/python/pyspark/sql/connect/utils.py
index 25a94676551..8872ba50633 100644
--- a/python/pyspark/sql/connect/utils.py
+++ b/python/pyspark/sql/connect/utils.py
@@ -52,3 +52,7 @@ def require_minimum_grpc_version() -> None:
             "grpcio >= %s must be installed; however, "
             "your version was %s." % (minimum_grpc_version, grpc.__version__)
         )
+
+
+def get_python_ver() -> str:
+    return "%d.%d" % sys.version_info[:2]
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index b2017627598..f566fcee0e3 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -15510,12 +15510,13 @@ def udf(
         return _create_py_udf(f=f, returnType=returnType, useArrow=useArrow)
 
 
+@try_remote_functions
 def udtf(
     cls: Optional[Type] = None,
     *,
     returnType: Union[StructType, str],
     useArrow: Optional[bool] = None,
-) -> Union[UserDefinedTableFunction, functools.partial]:
+) -> Union["UserDefinedTableFunction", Callable[[Type], 
"UserDefinedTableFunction"]]:
     """Creates a user defined table function (UDTF).
 
     .. versionadded:: 3.5.0
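
The widened annotation matches the two supported call styles, sketched below
with hypothetical handlers:

    from pyspark.sql.functions import udtf

    # Decorator style: @udtf(returnType=...) returns a callable that wraps
    # the decorated class into a UserDefinedTableFunction.
    @udtf(returnType="x: int")
    class Ones:
        def eval(self):
            yield 1,

    # Direct style: pass the handler class explicitly.
    class Twos:
        def eval(self):
            yield 2,

    twos = udtf(Twos, returnType="x: int")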
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 268011ef1e4..c1235620990 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -27,7 +27,6 @@ from collections import defaultdict
 
 from pyspark.errors import (
     PySparkAttributeError,
-    PySparkNotImplementedError,
     PySparkTypeError,
     PySparkException,
     PySparkValueError,
@@ -3186,16 +3185,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         rows = [cols] * row_count
         self.assertEqual(row_count, self.connect.createDataFrame(data=rows).count())
 
-    def test_unsupported_udtf(self):
-        with self.assertRaises(PySparkNotImplementedError) as e:
-            self.connect.udtf.register()
-
-        self.check_error(
-            exception=e.exception,
-            error_class="NOT_IMPLEMENTED",
-            message_parameters={"feature": "udtf()"},
-        )
-
     def test_unsupported_jvm_attribute(self):
         # Unsupported jvm attributes for Spark session.
         unsupported_attrs = ["_jsc", "_jconf", "_jvm", "_jsparkSession"]
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py
index 3e3b4dd5b16..a5d330fe1a7 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -20,6 +20,7 @@ import unittest
 from pyspark.errors import PySparkTypeError, PySparkValueError
 from pyspark.sql import SparkSession as PySparkSession
 from pyspark.sql.types import StringType, StructType, StructField, ArrayType, IntegerType
+from pyspark.testing import assertDataFrameEqual
 from pyspark.testing.pandasutils import PandasOnSparkTestUtils
 from pyspark.testing.connectutils import ReusedConnectTestCase, should_test_connect
 from pyspark.testing.sqlutils import SQLTestUtils
@@ -2342,6 +2343,23 @@ class SparkConnectFunctionTests(ReusedConnectTestCase, PandasOnSparkTestUtils, S
             sdf.withColumn("A", sfun(sdf.c)).toPandas(),
         )
 
+    def test_udtf(self):
+        class TestUDTF:
+            def eval(self, x: int, y: int):
+                yield x, x + 1
+                yield y, y + 1
+
+        sfunc = SF.udtf(TestUDTF, returnType="a: int, b: int")
+        cfunc = CF.udtf(TestUDTF, returnType="a: int, b: int")
+
+        assertDataFrameEqual(sfunc(SF.lit(1), SF.lit(1)), cfunc(CF.lit(1), CF.lit(1)))
+
+        self.spark.udtf.register("test_udtf", sfunc)
+        self.connect.udtf.register("test_udtf", cfunc)
+
+        query = "select * from test_udtf(1, 2)"
+        assertDataFrameEqual(self.spark.sql(query), self.connect.sql(query))
+
     def test_pandas_udf_import(self):
         self.assert_eq(getattr(CF, "pandas_udf"), getattr(SF, "pandas_udf"))
 
diff --git a/python/pyspark/sql/tests/connect/test_parity_udtf.py b/python/pyspark/sql/tests/connect/test_parity_udtf.py
new file mode 100644
index 00000000000..f5f37f1f5c0
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_udtf.py
@@ -0,0 +1,141 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from pyspark.testing.connectutils import should_test_connect
+
+if should_test_connect:
+    from pyspark import sql
+    from pyspark.sql.connect.udtf import UserDefinedTableFunction
+
+    sql.udtf.UserDefinedTableFunction = UserDefinedTableFunction
+
+from pyspark.sql.connect.functions import lit, udtf
+from pyspark.sql.tests.test_udtf import BaseUDTFTestsMixin, UDTFArrowTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.errors.exceptions.connect import SparkConnectGrpcException
+
+
+class UDTFParityTests(BaseUDTFTestsMixin, ReusedConnectTestCase):
+    @classmethod
+    def setUpClass(cls):
+        super(UDTFParityTests, cls).setUpClass()
+        cls.spark.conf.set("spark.sql.execution.pythonUDTF.arrow.enabled", "false")
+
+    @classmethod
+    def tearDownClass(cls):
+        try:
+            cls.spark.conf.unset("spark.sql.execution.pythonUDTF.arrow.enabled")
+        finally:
+            super(UDTFParityTests, cls).tearDownClass()
+
+    # TODO: use PySpark error classes instead of SparkConnectGrpcException
+
+    def test_udtf_with_invalid_return_type(self):
+        @udtf(returnType="int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a + 1,
+
+        with self.assertRaisesRegex(
+            SparkConnectGrpcException, "Invalid Python user-defined table function return type."
+        ):
+            TestUDTF(lit(1)).collect()
+
+    def test_udtf_with_wrong_num_output(self):
+        err_msg = (
+            "java.lang.IllegalStateException: Input row doesn't have expected 
number of "
+            + "values required by the schema."
+        )
+
+        @udtf(returnType="a: int, b: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a,
+
+        with self.assertRaisesRegex(SparkConnectGrpcException, err_msg):
+            TestUDTF(lit(1)).collect()
+
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a, a + 1
+
+        with self.assertRaisesRegex(SparkConnectGrpcException, err_msg):
+            TestUDTF(lit(1)).collect()
+
+    def test_udtf_terminate_with_wrong_num_output(self):
+        err_msg = (
+            "java.lang.IllegalStateException: Input row doesn't have expected 
number of "
+            "values required by the schema."
+        )
+
+        @udtf(returnType="a: int, b: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a, a + 1
+
+            def terminate(self):
+                yield 1, 2, 3
+
+        with self.assertRaisesRegex(SparkConnectGrpcException, err_msg):
+            TestUDTF(lit(1)).show()
+
+        @udtf(returnType="a: int, b: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a, a + 1
+
+            def terminate(self):
+                yield 1,
+
+        with self.assertRaisesRegex(SparkConnectGrpcException, err_msg):
+            TestUDTF(lit(1)).show()
+
+    def test_udtf_with_empty_yield(self):
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield
+
+        with self.assertRaisesRegex(SparkConnectGrpcException, "java.lang.NullPointerException"):
+            TestUDTF(lit(1)).collect()
+
+
+class ArrowUDTFParityTests(UDTFArrowTestsMixin, UDTFParityTests):
+    @classmethod
+    def setUpClass(cls):
+        super(ArrowUDTFParityTests, cls).setUpClass()
+        cls.spark.conf.set("spark.sql.execution.pythonUDTF.arrow.enabled", "true")
+
+    @classmethod
+    def tearDownClass(cls):
+        try:
+            cls.spark.conf.unset("spark.sql.execution.pythonUDTF.arrow.enabled")
+        finally:
+            super(ArrowUDTFParityTests, cls).tearDownClass()
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.sql.tests.connect.test_parity_udtf import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index 8f42b4123e4..f109302dec5 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -30,6 +30,8 @@ from pyspark.errors import (
 from pyspark.rdd import PythonEvalType
 from pyspark.sql.functions import lit, udf, udtf
 from pyspark.sql.types import Row
+from pyspark.testing import assertDataFrameEqual
+from pyspark.sql.types import MapType, StringType, IntegerType
 from pyspark.testing.sqlutils import (
     have_pandas,
     have_pyarrow,
@@ -207,8 +209,8 @@ class BaseUDTFTestsMixin:
         self.assertEqual(TestUDTF(lit(1)).collect(), [Row(a=1), Row(a=None)])
         df = self.spark.createDataFrame([(0, 1), (1, 2)], schema=["a", "b"])
         self.assertEqual(TestUDTF(lit(1)).join(df, "a", "inner").collect(), [Row(a=1, b=2)])
-        self.assertEqual(
-            TestUDTF(lit(1)).join(df, "a", "left").collect(), [Row(a=None, b=None), Row(a=1, b=2)]
+        assertDataFrameEqual(
+            TestUDTF(lit(1)).join(df, "a", "left"), [Row(a=None, b=None), Row(a=1, b=2)]
         )
 
     def test_udtf_with_none_input(self):
@@ -381,6 +383,38 @@ class BaseUDTFTestsMixin:
         ):
             TestUDTF(rand(0) * 100).collect()
 
+    def test_udtf_with_invalid_return_type(self):
+        @udtf(returnType="int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a + 1,
+
+        with self.assertRaises(PySparkTypeError) as e:
+            TestUDTF(lit(1)).collect()
+
+        self.check_error(
+            exception=e.exception,
+            error_class="UDTF_RETURN_TYPE_MISMATCH",
+            message_parameters={"name": "TestUDTF", "return_type": 
"IntegerType()"},
+        )
+
+        @udtf(returnType=MapType(StringType(), IntegerType()))
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a + 1,
+
+        with self.assertRaises(PySparkTypeError) as e:
+            TestUDTF(lit(1)).collect()
+
+        self.check_error(
+            exception=e.exception,
+            error_class="UDTF_RETURN_TYPE_MISMATCH",
+            message_parameters={
+                "name": "TestUDTF",
+                "return_type": "MapType(StringType(), IntegerType(), True)",
+            },
+        )
+
     def test_udtf_with_struct_input_type(self):
         @udtf(returnType="x: string")
         class TestUDTF:
diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py
index b0f373430be..3ab74193093 100644
--- a/python/pyspark/sql/udtf.py
+++ b/python/pyspark/sql/udtf.py
@@ -19,7 +19,7 @@ User-defined table function related classes and functions
 """
 import sys
 import warnings
-from typing import Iterator, Type, TYPE_CHECKING, Optional, Union
+from typing import Any, Iterator, Type, TYPE_CHECKING, Optional, Union
 
 from py4j.java_gateway import JavaObject
 
@@ -77,27 +77,35 @@ def _create_py_udtf(
     # Create a regular Python UDTF and check for invalid handler class.
     regular_udtf = _create_udtf(cls, returnType, name, PythonEvalType.SQL_TABLE_UDF, deterministic)
 
-    if arrow_enabled:
-        try:
-            require_minimum_pandas_version()
-            require_minimum_pyarrow_version()
-        except ImportError as e:
-            warnings.warn(
-                f"Arrow optimization for Python UDTFs cannot be enabled: 
{str(e)}. "
-                f"Falling back to using regular Python UDTFs.",
-                UserWarning,
-            )
-            return regular_udtf
-        return _create_arrow_udtf(regular_udtf)
-    else:
+    if not arrow_enabled:
         return regular_udtf
 
+    # Return the regular UDTF if the required dependencies are not satisfied.
+    try:
+        require_minimum_pandas_version()
+        require_minimum_pyarrow_version()
+    except ImportError as e:
+        warnings.warn(
+            f"Arrow optimization for Python UDTFs cannot be enabled: {str(e)}. 
"
+            f"Falling back to using regular Python UDTFs.",
+            UserWarning,
+        )
+        return regular_udtf
 
-def _create_arrow_udtf(regular_udtf: "UserDefinedTableFunction") -> "UserDefinedTableFunction":
-    """Create an Arrow-optimized Python UDTF."""
-    import pandas as pd
+    # Return the vectorized UDTF.
+    vectorized_udtf = _vectorize_udtf(cls)
+    return _create_udtf(
+        cls=vectorized_udtf,
+        returnType=returnType,
+        name=name,
+        evalType=PythonEvalType.SQL_ARROW_TABLE_UDF,
+        deterministic=regular_udtf.deterministic,
+    )
 
-    cls = regular_udtf.func
+
+def _vectorize_udtf(cls: Type) -> Type:
+    """Vectorize a Python UDTF handler class."""
+    import pandas as pd
 
     class VectorizedUDTF:
         def __init__(self) -> None:
@@ -126,13 +134,24 @@ def _create_arrow_udtf(regular_udtf: "UserDefinedTableFunction") -> "UserDefined
     if hasattr(cls, "terminate"):
         getattr(vectorized_udtf, "terminate").__doc__ = getattr(cls, "terminate").__doc__
 
-    return _create_udtf(
-        cls=vectorized_udtf,
-        returnType=regular_udtf.returnType,
-        name=regular_udtf._name,
-        evalType=PythonEvalType.SQL_ARROW_TABLE_UDF,
-        deterministic=regular_udtf.deterministic,
-    )
+    return vectorized_udtf
+
+
+def _validate_udtf_handler(cls: Any) -> None:
+    """Validate the handler class of a UDTF."""
+    # TODO(SPARK-43968): add more compile time checks for UDTFs
+
+    if not isinstance(cls, type):
+        raise PySparkTypeError(
+            f"Invalid user defined table function: the function handler "
+            f"must be a class, but got {type(cls).__name__}. Please provide "
+            "a class as the handler."
+        )
+
+    if not hasattr(cls, "eval"):
+        raise PySparkAttributeError(
+            error_class="INVALID_UDTF_NO_EVAL", message_parameters={"name": 
cls.__name__}
+        )
 
 
 class UserDefinedTableFunction:
@@ -157,15 +176,6 @@ class UserDefinedTableFunction:
         evalType: int = PythonEvalType.SQL_TABLE_UDF,
         deterministic: bool = True,
     ):
-
-        if not isinstance(func, type):
-            raise PySparkTypeError(
-                f"Invalid user defined table function: the function handler "
-                f"must be a class, but got {type(func).__name__}. Please 
provide "
-                "a class as the handler."
-            )
-
-        # TODO(SPARK-43968): add more compile time checks for UDTFs
         self.func = func
         self._returnType = returnType
         self._returnType_placeholder: Optional[StructType] = None
@@ -175,29 +185,26 @@ class UserDefinedTableFunction:
         self.evalType = evalType
         self.deterministic = deterministic
 
-        if not hasattr(func, "eval"):
-            raise PySparkAttributeError(
-                error_class="INVALID_UDTF_NO_EVAL", 
message_parameters={"name": self._name}
-            )
+        _validate_udtf_handler(func)
 
     @property
     def returnType(self) -> StructType:
         # `_parse_datatype_string` accesses the JVM for parsing a DDL formatted string.
         # This makes sure this is called after SparkContext is initialized.
         if self._returnType_placeholder is None:
-            if isinstance(self._returnType, StructType):
-                self._returnType_placeholder = self._returnType
-            else:
-                assert isinstance(self._returnType, str)
+            if isinstance(self._returnType, str):
                 parsed = _parse_datatype_string(self._returnType)
-                if not isinstance(parsed, StructType):
-                    raise PySparkTypeError(
-                        f"Invalid return type for the user defined table 
function "
-                        f"'{self._name}': {self._returnType}. The return type 
of a "
-                        f"UDTF must be a 'StructType'. Please ensure the 
return "
-                        "type is a correctly formatted 'StructType' string."
-                    )
-                self._returnType_placeholder = parsed
+            else:
+                parsed = self._returnType
+            if not isinstance(parsed, StructType):
+                raise PySparkTypeError(
+                    error_class="UDTF_RETURN_TYPE_MISMATCH",
+                    message_parameters={
+                        "name": self._name,
+                        "return_type": f"{parsed}",
+                    },
+                )
+            self._returnType_placeholder = parsed
         return self._returnType_placeholder
 
     @property
@@ -254,8 +261,8 @@ class UDTFRegistration:
     def register(
         self,
         name: str,
-        f: UserDefinedTableFunction,
-    ) -> UserDefinedTableFunction:
+        f: "UserDefinedTableFunction",
+    ) -> "UserDefinedTableFunction":
         """Register a Python user-defined table function as a SQL table 
function.
 
         .. versionadded:: 3.5.0
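
With the consolidated check above, a non-struct return type, whether given as
a DDL string or as a DataType instance, now raises the same
UDTF_RETURN_TYPE_MISMATCH error class, as exercised by the updated tests. A
minimal sketch (hypothetical handler):

    from pyspark.sql.functions import lit, udtf

    @udtf(returnType="int")  # a scalar type is not a valid UDTF schema
    class Identity:
        def eval(self, a: int):
            yield a,

    # returnType is parsed lazily, so the error surfaces on execution:
    # PySparkTypeError with error_class="UDTF_RETURN_TYPE_MISMATCH".
    Identity(lit(1)).collect()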

