This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ce41ca0848e [SPARK-41343][CONNECT] Move FunctionName parsing to server side
ce41ca0848e is described below

commit ce41ca0848e740026048aa08cb1062cc4d5082d1
Author: Rui Wang <rui.w...@databricks.com>
AuthorDate: Thu Dec 1 13:27:03 2022 +0800

    [SPARK-41343][CONNECT] Move FunctionName parsing to server side
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to change the name of `UnresolvedFunction` from a sequence of name parts to a single name string, which helps move function name parsing to the server side.
    
    For built-in functions, there is no need to even call the SQL parser to parse the name (built-in functions should not belong to any catalog or database).
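
    As an illustration (a hedged sketch, not code from this PR; `Expression` is the generated proto class and `e` stands for some already-built argument expression), a built-in call and a UDF call are now constructed roughly as:

        // Built-in function: the single name is used as-is, no parsing needed.
        val builtin = Expression.UnresolvedFunction
          .newBuilder()
          .setFunctionName("min")
          .addArguments(e)
          .build()

        // User-defined function: a possibly multi-part name travels as one
        // dotted string and is parsed on the server.
        val udf = Expression.UnresolvedFunction
          .newBuilder()
          .setFunctionName("default.hex")
          .setIsUserDefinedFunction(true)
          .addArguments(e)
          .build()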
    
    ### Why are the changes needed?
    
    This will help reduce redundant function name parsing implementations on the client side.
    
    ### Does this PR introduce _any_ user-facing change?
    
    NO
    
    ### How was this patch tested?
    
    UT
    
    Closes #38854 from amaliujia/function_name_parse.
    
    Authored-by: Rui Wang <rui.w...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../main/protobuf/spark/connect/expressions.proto  | 10 ++++++--
 .../org/apache/spark/sql/connect/dsl/package.scala | 11 +++++----
 .../sql/connect/planner/SparkConnectPlanner.scala  | 22 +++++++++--------
 .../connect/planner/SparkConnectPlannerSuite.scala |  4 ++--
 .../connect/planner/SparkConnectProtoSuite.scala   |  4 ----
 .../connect/planner/SparkConnectServiceSuite.scala |  2 +-
 python/pyspark/sql/connect/column.py               |  2 +-
 .../pyspark/sql/connect/proto/expressions_pb2.py   | 20 ++++++++--------
 .../pyspark/sql/connect/proto/expressions_pb2.pyi  | 28 +++++++++++++++-------
 .../connect/test_connect_column_expressions.py     |  4 ++--
 .../sql/tests/connect/test_connect_plan_only.py    |  2 +-
 11 files changed, 63 insertions(+), 46 deletions(-)

diff --git a/connector/connect/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/src/main/protobuf/spark/connect/expressions.proto
index b90f7619b8f..1b93c342381 100644
--- a/connector/connect/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/expressions.proto
@@ -126,11 +126,17 @@ message Expression {
   // An unresolved function is not explicitly bound to one explicit function, but the function
   // is resolved during analysis following Sparks name resolution rules.
   message UnresolvedFunction {
-    // (Required) Names parts for the unresolved function.
-    repeated string parts = 1;
+    // (Required) name (or unparsed name for user defined function) for the unresolved function.
+    string function_name = 1;
 
     // (Optional) Function arguments. Empty arguments are allowed.
     repeated Expression arguments = 2;
+
+    // (Required) Indicate if this is a user defined function.
+    //
+    // When it is not a user defined function, Connect will use the function name directly.
+    // When it is a user defined function, Connect will parse the function name first.
+    bool is_user_defined_function = 3;
   }
 
   // Expression as string.
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 654a4d5ce20..1342842cbc9 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -83,7 +83,7 @@ package object dsl {
           .setUnresolvedFunction(
             Expression.UnresolvedFunction
               .newBuilder()
-              .addParts("<")
+              .setFunctionName("<")
               .addArguments(expr)
               .addArguments(other))
           .build()
@@ -93,14 +93,14 @@ package object dsl {
       Expression
         .newBuilder()
         .setUnresolvedFunction(
-          Expression.UnresolvedFunction.newBuilder().addParts("min").addArguments(e))
+          Expression.UnresolvedFunction.newBuilder().setFunctionName("min").addArguments(e))
         .build()
 
     def proto_explode(e: Expression): Expression =
       Expression
         .newBuilder()
         .setUnresolvedFunction(
-          Expression.UnresolvedFunction.newBuilder().addParts("explode").addArguments(e))
+          Expression.UnresolvedFunction.newBuilder().setFunctionName("explode").addArguments(e))
         .build()
 
     /**
@@ -117,7 +117,8 @@ package object dsl {
         .setUnresolvedFunction(
           Expression.UnresolvedFunction
             .newBuilder()
-            .addAllParts(nameParts.asJava)
+            .setFunctionName(nameParts.mkString("."))
+            .setIsUserDefinedFunction(true)
             .addAllArguments(args.asJava))
         .build()
     }
@@ -136,7 +137,7 @@ package object dsl {
         .setUnresolvedFunction(
           Expression.UnresolvedFunction
             .newBuilder()
-            .addParts(name)
+            .setFunctionName(name)
             .addAllArguments(args.asJava))
         .build()
     }
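
A usage sketch of the dsl change above (helpers such as callFunction, connectTestRelation and "id".protoAttr come from the Connect test suites shown below; this snippet is illustrative, not part of the diff):

    // A Seq of name parts is now joined into one dotted string with the UDF
    // flag set: function_name = "default.hex", is_user_defined_function = true.
    val plan = connectTestRelation.select(
      callFunction(Seq("default", "hex"), Seq("id".protoAttr)))
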
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 6b11cbea7a5..b50fe1b1b60 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -27,9 +27,8 @@ import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.WriteOperation
 import org.apache.spark.sql.{Column, Dataset, SparkSession}
-import org.apache.spark.sql.catalyst.AliasIdentifier
+import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier}
 import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
-import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
@@ -539,14 +538,17 @@ class SparkConnectPlanner(session: SparkSession) {
    * @return
    */
   private def transformScalarFunction(fun: proto.Expression.UnresolvedFunction): Expression = {
-    if (fun.getPartsCount == 1 && fun.getParts(0).contains(".")) {
-      throw new IllegalArgumentException(
-        "Function identifier must be passed as sequence of name parts.")
-    }
-    UnresolvedFunction(
-      fun.getPartsList.asScala.toSeq,
-      fun.getArgumentsList.asScala.map(transformExpression).toSeq,
-      isDistinct = false)
+    if (fun.getIsUserDefinedFunction) {
+      UnresolvedFunction(
+        session.sessionState.sqlParser.parseFunctionIdentifier(fun.getFunctionName),
+        fun.getArgumentsList.asScala.map(transformExpression).toSeq,
+        isDistinct = false)
+    } else {
+      UnresolvedFunction(
+        FunctionIdentifier(fun.getFunctionName),
+        fun.getArgumentsList.asScala.map(transformExpression).toSeq,
+        isDistinct = false)
+    }
   }
 
   private def transformAlias(alias: proto.Expression.Alias): NamedExpression = {
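
For reference, a sketch of what the two branches of transformScalarFunction above resolve to (assuming a live `session`; the parse result shown follows standard Catalyst name resolution for a two-part name):

    import org.apache.spark.sql.catalyst.FunctionIdentifier

    // Built-in path: the raw name becomes the identifier directly.
    val builtinId = FunctionIdentifier("min")

    // UDF path: the server-side SQL parser splits the dotted string.
    val udfId = session.sessionState.sqlParser.parseFunctionIdentifier("default.hex")
    // udfId.funcName == "hex", udfId.database == Some("default")
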
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index 81e5ee3d0ce..362973a90ef 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -235,7 +235,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
 
     val joinCondition = proto.Expression.newBuilder.setUnresolvedFunction(
       proto.Expression.UnresolvedFunction.newBuilder
-        .addAllParts(Seq("==").asJava)
+        .setFunctionName("==")
         .addArguments(unresolvedAttribute)
         .addArguments(unresolvedAttribute)
         .build())
@@ -296,7 +296,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
         .setUnresolvedFunction(
           proto.Expression.UnresolvedFunction
             .newBuilder()
-            .addParts("sum")
+            .setFunctionName("sum")
             .addArguments(unresolvedAttribute))
         .build()
 
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 86253b0016c..5d2bf1d57b2 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -91,10 +91,6 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
   }
 
   test("UnresolvedFunction resolution.") {
-    assertThrows[IllegalArgumentException] {
-      transform(connectTestRelation.select(callFunction("default.hex", Seq("id".protoAttr))))
-    }
-
     val connectPlan =
       connectTestRelation.select(callFunction(Seq("default", "hex"), Seq("id".protoAttr)))
 
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index 9724c876b1b..8f4268b904b 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -211,7 +211,7 @@ class SparkConnectServiceSuite extends SharedSparkSession {
                 .setUnresolvedFunction(
                   proto.Expression.UnresolvedFunction
                     .newBuilder()
-                    .addParts("abs")
+                    .setFunctionName("abs")
                     .addArguments(proto.Expression
                       .newBuilder()
                       .setLiteral(proto.Expression.Literal.newBuilder().setInteger(-1)))))
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index 18ea0961979..f678112a7f4 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -347,7 +347,7 @@ class ScalarFunctionExpression(Expression):
 
     def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
         fun = proto.Expression()
-        fun.unresolved_function.parts.append(self._op)
+        fun.unresolved_function.function_name = self._op
         fun.unresolved_function.arguments.extend([x.to_plan(session) for x in self._args])
         return fun
 
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index a1d9dcb91b0..d563d829769 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -33,7 +33,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\xc0\x12\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_st [...]
+    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\x89\x13\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_st [...]
 )
 
 
@@ -186,7 +186,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     DESCRIPTOR._options = None
     DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
     _EXPRESSION._serialized_start = 78
-    _EXPRESSION._serialized_end = 2446
+    _EXPRESSION._serialized_end = 2519
     _EXPRESSION_LITERAL._serialized_start = 586
     _EXPRESSION_LITERAL._serialized_end = 2044
     _EXPRESSION_LITERAL_DECIMAL._serialized_start = 1482
@@ -203,12 +203,12 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _EXPRESSION_LITERAL_MAP_PAIR._serialized_end = 2028
     _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 2046
     _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 2116
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 2118
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 2217
-    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 2219
-    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 2269
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 2271
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 2311
-    _EXPRESSION_ALIAS._serialized_start = 2313
-    _EXPRESSION_ALIAS._serialized_end = 2433
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 2119
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 2290
+    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 2292
+    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 2342
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 2344
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 2384
+    _EXPRESSION_ALIAS._serialized_start = 2386
+    _EXPRESSION_ALIAS._serialized_end = 2506
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index ddd9338d85d..b8a36ca6827 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -458,13 +458,11 @@ class Expression(google.protobuf.message.Message):
 
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
-        PARTS_FIELD_NUMBER: builtins.int
+        FUNCTION_NAME_FIELD_NUMBER: builtins.int
         ARGUMENTS_FIELD_NUMBER: builtins.int
-        @property
-        def parts(
-            self,
-        ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
-            """(Required) Names parts for the unresolved function."""
+        IS_USER_DEFINED_FUNCTION_FIELD_NUMBER: builtins.int
+        function_name: builtins.str
+        """(Required) name (or unparsed name for user defined function) for the unresolved function."""
         @property
         def arguments(
             self,
@@ -472,15 +470,29 @@ class Expression(google.protobuf.message.Message):
             global___Expression
         ]:
             """(Optional) Function arguments. Empty arguments are allowed."""
+        is_user_defined_function: builtins.bool
+        """(Required) Indicate if this is a user defined function.
+
+        When it is not a user defined function, Connect will use the function name directly.
+        When it is a user defined function, Connect will parse the function name first.
+        """
         def __init__(
             self,
             *,
-            parts: collections.abc.Iterable[builtins.str] | None = ...,
+            function_name: builtins.str = ...,
             arguments: collections.abc.Iterable[global___Expression] | None = ...,
+            is_user_defined_function: builtins.bool = ...,
         ) -> None: ...
         def ClearField(
             self,
-            field_name: typing_extensions.Literal["arguments", b"arguments", "parts", b"parts"],
+            field_name: typing_extensions.Literal[
+                "arguments",
+                b"arguments",
+                "function_name",
+                b"function_name",
+                "is_user_defined_function",
+                b"is_user_defined_function",
+            ],
         ) -> None: ...
 
     class ExpressionString(google.protobuf.message.Message):
diff --git a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
index 6a6e0f84e21..ba7cc026562 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
@@ -197,12 +197,12 @@ class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture):
         expr = fun.lit(10) < fun.lit(10)
         expr_plan = expr.to_plan(None)
         self.assertIsNotNone(expr_plan.unresolved_function)
-        self.assertEqual(expr_plan.unresolved_function.parts[0], "<")
+        self.assertEqual(expr_plan.unresolved_function.function_name, "<")
 
         expr = df.id % fun.lit(10) == fun.lit(10)
         expr_plan = expr.to_plan(None)
         self.assertIsNotNone(expr_plan.unresolved_function)
-        self.assertEqual(expr_plan.unresolved_function.parts[0], "==")
+        self.assertEqual(expr_plan.unresolved_function.function_name, "==")
 
         lit_fun = expr_plan.unresolved_function.arguments[1]
         self.assertIsInstance(lit_fun, ProtoExpression)
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index 60d52a9c05d..bdebd54492e 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -77,7 +77,7 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
                plan.root.filter.condition.unresolved_function, proto.Expression.UnresolvedFunction
             )
         )
-        self.assertEqual(plan.root.filter.condition.unresolved_function.parts, [">"])
+        self.assertEqual(plan.root.filter.condition.unresolved_function.function_name, ">")
         self.assertEqual(len(plan.root.filter.condition.unresolved_function.arguments), 2)
 
     def test_filter_with_string_expr(self):

