This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new ce41ca0848e [SPARK-41343][CONNECT] Move FunctionName parsing to server side ce41ca0848e is described below commit ce41ca0848e740026048aa08cb1062cc4d5082d1 Author: Rui Wang <rui.w...@databricks.com> AuthorDate: Thu Dec 1 13:27:03 2022 +0800 [SPARK-41343][CONNECT] Move FunctionName parsing to server side ### What changes were proposed in this pull request? This PR proposes to change the name of `UnresolvedFunction` from a sequence of name parts to a single name string, which helps to move the function name parsing to the server side. For built-in functions, there is no need to even call the SQL parser to parse the name (built-in functions should not belong to any catalog or database). ### Why are the changes needed? This will help reduce redundant implementation on client sides to parse function names. ### Does this PR introduce _any_ user-facing change? NO ### How was this patch tested? UT Closes #38854 from amaliujia/function_name_parse. 
Authored-by: Rui Wang <rui.w...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../main/protobuf/spark/connect/expressions.proto | 10 ++++++-- .../org/apache/spark/sql/connect/dsl/package.scala | 11 +++++---- .../sql/connect/planner/SparkConnectPlanner.scala | 22 +++++++++-------- .../connect/planner/SparkConnectPlannerSuite.scala | 4 ++-- .../connect/planner/SparkConnectProtoSuite.scala | 4 ---- .../connect/planner/SparkConnectServiceSuite.scala | 2 +- python/pyspark/sql/connect/column.py | 2 +- .../pyspark/sql/connect/proto/expressions_pb2.py | 20 ++++++++-------- .../pyspark/sql/connect/proto/expressions_pb2.pyi | 28 +++++++++++++++------- .../connect/test_connect_column_expressions.py | 4 ++-- .../sql/tests/connect/test_connect_plan_only.py | 2 +- 11 files changed, 63 insertions(+), 46 deletions(-) diff --git a/connector/connect/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/src/main/protobuf/spark/connect/expressions.proto index b90f7619b8f..1b93c342381 100644 --- a/connector/connect/src/main/protobuf/spark/connect/expressions.proto +++ b/connector/connect/src/main/protobuf/spark/connect/expressions.proto @@ -126,11 +126,17 @@ message Expression { // An unresolved function is not explicitly bound to one explicit function, but the function // is resolved during analysis following Sparks name resolution rules. message UnresolvedFunction { - // (Required) Names parts for the unresolved function. - repeated string parts = 1; + // (Required) name (or unparsed name for user defined function) for the unresolved function. + string function_name = 1; // (Optional) Function arguments. Empty arguments are allowed. repeated Expression arguments = 2; + + // (Required) Indicate if this is a user defined function. + // + // When it is not a user defined function, Connect will use the function name directly. + // When it is a user defined function, Connect will parse the function name first. 
+ bool is_user_defined_function = 3; } // Expression as string. diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala index 654a4d5ce20..1342842cbc9 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala @@ -83,7 +83,7 @@ package object dsl { .setUnresolvedFunction( Expression.UnresolvedFunction .newBuilder() - .addParts("<") + .setFunctionName("<") .addArguments(expr) .addArguments(other)) .build() @@ -93,14 +93,14 @@ package object dsl { Expression .newBuilder() .setUnresolvedFunction( - Expression.UnresolvedFunction.newBuilder().addParts("min").addArguments(e)) + Expression.UnresolvedFunction.newBuilder().setFunctionName("min").addArguments(e)) .build() def proto_explode(e: Expression): Expression = Expression .newBuilder() .setUnresolvedFunction( - Expression.UnresolvedFunction.newBuilder().addParts("explode").addArguments(e)) + Expression.UnresolvedFunction.newBuilder().setFunctionName("explode").addArguments(e)) .build() /** @@ -117,7 +117,8 @@ package object dsl { .setUnresolvedFunction( Expression.UnresolvedFunction .newBuilder() - .addAllParts(nameParts.asJava) + .setFunctionName(nameParts.mkString(".")) + .setIsUserDefinedFunction(true) .addAllArguments(args.asJava)) .build() } @@ -136,7 +137,7 @@ package object dsl { .setUnresolvedFunction( Expression.UnresolvedFunction .newBuilder() - .addParts(name) + .setFunctionName(name) .addAllArguments(args.asJava)) .build() } diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 6b11cbea7a5..b50fe1b1b60 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ 
b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -27,9 +27,8 @@ import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction} import org.apache.spark.connect.proto import org.apache.spark.connect.proto.WriteOperation import org.apache.spark.sql.{Column, Dataset, SparkSession} -import org.apache.spark.sql.catalyst.AliasIdentifier +import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier} import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException} @@ -539,14 +538,17 @@ class SparkConnectPlanner(session: SparkSession) { * @return */ private def transformScalarFunction(fun: proto.Expression.UnresolvedFunction): Expression = { - if (fun.getPartsCount == 1 && fun.getParts(0).contains(".")) { - throw new IllegalArgumentException( - "Function identifier must be passed as sequence of name parts.") - } - UnresolvedFunction( - fun.getPartsList.asScala.toSeq, - fun.getArgumentsList.asScala.map(transformExpression).toSeq, - isDistinct = false) + if (fun.getIsUserDefinedFunction) { + UnresolvedFunction( + session.sessionState.sqlParser.parseFunctionIdentifier(fun.getFunctionName), + fun.getArgumentsList.asScala.map(transformExpression).toSeq, + isDistinct = false) + } else { + UnresolvedFunction( + FunctionIdentifier(fun.getFunctionName), + fun.getArgumentsList.asScala.map(transformExpression).toSeq, + isDistinct = false) + } } private def transformAlias(alias: proto.Expression.Alias): NamedExpression = { diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala 
b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index 81e5ee3d0ce..362973a90ef 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -235,7 +235,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { val joinCondition = proto.Expression.newBuilder.setUnresolvedFunction( proto.Expression.UnresolvedFunction.newBuilder - .addAllParts(Seq("==").asJava) + .setFunctionName("==") .addArguments(unresolvedAttribute) .addArguments(unresolvedAttribute) .build()) @@ -296,7 +296,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { .setUnresolvedFunction( proto.Expression.UnresolvedFunction .newBuilder() - .addParts("sum") + .setFunctionName("sum") .addArguments(unresolvedAttribute)) .build() diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index 86253b0016c..5d2bf1d57b2 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -91,10 +91,6 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { } test("UnresolvedFunction resolution.") { - assertThrows[IllegalArgumentException] { - transform(connectTestRelation.select(callFunction("default.hex", Seq("id".protoAttr)))) - } - val connectPlan = connectTestRelation.select(callFunction(Seq("default", "hex"), Seq("id".protoAttr))) diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala 
b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index 9724c876b1b..8f4268b904b 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -211,7 +211,7 @@ class SparkConnectServiceSuite extends SharedSparkSession { .setUnresolvedFunction( proto.Expression.UnresolvedFunction .newBuilder() - .addParts("abs") + .setFunctionName("abs") .addArguments(proto.Expression .newBuilder() .setLiteral(proto.Expression.Literal.newBuilder().setInteger(-1))))) diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py index 18ea0961979..f678112a7f4 100644 --- a/python/pyspark/sql/connect/column.py +++ b/python/pyspark/sql/connect/column.py @@ -347,7 +347,7 @@ class ScalarFunctionExpression(Expression): def to_plan(self, session: "SparkConnectClient") -> proto.Expression: fun = proto.Expression() - fun.unresolved_function.parts.append(self._op) + fun.unresolved_function.function_name = self._op fun.unresolved_function.arguments.extend([x.to_plan(session) for x in self._args]) return fun diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py index a1d9dcb91b0..d563d829769 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.py +++ b/python/pyspark/sql/connect/proto/expressions_pb2.py @@ -33,7 +33,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\xc0\x12\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 
\x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_st [...] + b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto"\x89\x13\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_st [...] ) @@ -186,7 +186,7 @@ if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001" _EXPRESSION._serialized_start = 78 - _EXPRESSION._serialized_end = 2446 + _EXPRESSION._serialized_end = 2519 _EXPRESSION_LITERAL._serialized_start = 586 _EXPRESSION_LITERAL._serialized_end = 2044 _EXPRESSION_LITERAL_DECIMAL._serialized_start = 1482 @@ -203,12 +203,12 @@ if _descriptor._USE_C_DESCRIPTORS == False: _EXPRESSION_LITERAL_MAP_PAIR._serialized_end = 2028 _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 2046 _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 2116 - _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 2118 - _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 2217 - _EXPRESSION_EXPRESSIONSTRING._serialized_start = 2219 - _EXPRESSION_EXPRESSIONSTRING._serialized_end = 2269 - _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 2271 - _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 2311 - _EXPRESSION_ALIAS._serialized_start = 2313 - _EXPRESSION_ALIAS._serialized_end = 2433 + _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 2119 + _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 2290 + _EXPRESSION_EXPRESSIONSTRING._serialized_start = 2292 + 
_EXPRESSION_EXPRESSIONSTRING._serialized_end = 2342 + _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 2344 + _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 2384 + _EXPRESSION_ALIAS._serialized_start = 2386 + _EXPRESSION_ALIAS._serialized_end = 2506 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi index ddd9338d85d..b8a36ca6827 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi +++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi @@ -458,13 +458,11 @@ class Expression(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - PARTS_FIELD_NUMBER: builtins.int + FUNCTION_NAME_FIELD_NUMBER: builtins.int ARGUMENTS_FIELD_NUMBER: builtins.int - @property - def parts( - self, - ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: - """(Required) Names parts for the unresolved function.""" + IS_USER_DEFINED_FUNCTION_FIELD_NUMBER: builtins.int + function_name: builtins.str + """(Required) name (or unparsed name for user defined function) for the unresolved function.""" @property def arguments( self, @@ -472,15 +470,29 @@ class Expression(google.protobuf.message.Message): global___Expression ]: """(Optional) Function arguments. Empty arguments are allowed.""" + is_user_defined_function: builtins.bool + """(Required) Indicate if this is a user defined function. + + When it is not a user defined function, Connect will use the function name directly. + When it is a user defined function, Connect will parse the function name first. + """ def __init__( self, *, - parts: collections.abc.Iterable[builtins.str] | None = ..., + function_name: builtins.str = ..., arguments: collections.abc.Iterable[global___Expression] | None = ..., + is_user_defined_function: builtins.bool = ..., ) -> None: ... 
def ClearField( self, - field_name: typing_extensions.Literal["arguments", b"arguments", "parts", b"parts"], + field_name: typing_extensions.Literal[ + "arguments", + b"arguments", + "function_name", + b"function_name", + "is_user_defined_function", + b"is_user_defined_function", + ], ) -> None: ... class ExpressionString(google.protobuf.message.Message): diff --git a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py index 6a6e0f84e21..ba7cc026562 100644 --- a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py +++ b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py @@ -197,12 +197,12 @@ class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture): expr = fun.lit(10) < fun.lit(10) expr_plan = expr.to_plan(None) self.assertIsNotNone(expr_plan.unresolved_function) - self.assertEqual(expr_plan.unresolved_function.parts[0], "<") + self.assertEqual(expr_plan.unresolved_function.function_name, "<") expr = df.id % fun.lit(10) == fun.lit(10) expr_plan = expr.to_plan(None) self.assertIsNotNone(expr_plan.unresolved_function) - self.assertEqual(expr_plan.unresolved_function.parts[0], "==") + self.assertEqual(expr_plan.unresolved_function.function_name, "==") lit_fun = expr_plan.unresolved_function.arguments[1] self.assertIsInstance(lit_fun, ProtoExpression) diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py index 60d52a9c05d..bdebd54492e 100644 --- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py +++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py @@ -77,7 +77,7 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture): plan.root.filter.condition.unresolved_function, proto.Expression.UnresolvedFunction ) ) - self.assertEqual(plan.root.filter.condition.unresolved_function.parts, [">"]) + 
self.assertEqual(plan.root.filter.condition.unresolved_function.function_name, ">") self.assertEqual(len(plan.root.filter.condition.unresolved_function.arguments), 2) def test_filter_with_string_expr(self): --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org