This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new f1377a856e8  [SPARK-44131][SQL][PYTHON][CONNECT][FOLLOWUP] Support qualified function name for call_function
f1377a856e8 is described below

commit f1377a856e85977aafe3bf13cce1da7b4d4ed195
Author: Jiaan Geng <belie...@163.com>
AuthorDate: Tue Jul 25 08:54:00 2023 +0800

    [SPARK-44131][SQL][PYTHON][CONNECT][FOLLOWUP] Support qualified function name for call_function

    ### What changes were proposed in this pull request?
    https://github.com/apache/spark/pull/41687 added `call_function` and deprecated `call_udf` in the Scala API. Sometimes the function name is qualified; we should let users pass such a name to `call_function` so it can invoke persistent functions as well.

    ### Why are the changes needed?
    Support qualified function names in `call_function`.

    ### Does this PR introduce _any_ user-facing change?
    No. New feature.

    ### How was this patch tested?
    New test cases.

    Closes #41932 from beliefer/SPARK-44131_followup.

    Authored-by: Jiaan Geng <belie...@163.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit d97a4e214c7e11bcc9b7d6e126bf06e214a29988)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../scala/org/apache/spark/sql/functions.scala     |  10 +-
 .../spark/sql/application/ReplE2ESuite.scala       |  10 ++
 .../main/protobuf/spark/connect/expressions.proto  |   9 ++
 .../queries/function_call_function.json            |   2 +-
 .../queries/function_call_function.proto.bin       | Bin 174 -> 175 bytes
 .../sql/connect/planner/SparkConnectPlanner.scala  |  19 ++++
 python/pyspark/sql/connect/expressions.py          |  24 +++++
 python/pyspark/sql/connect/functions.py            |   6 +-
 .../pyspark/sql/connect/proto/expressions_pb2.py   | 118 +++++++++++----------
 .../pyspark/sql/connect/proto/expressions_pb2.pyi  |  36 +++++++
 python/pyspark/sql/functions.py                    |  23 +++-
 .../scala/org/apache/spark/sql/functions.scala     |  22 ++--
 .../apache/spark/sql/DataFrameFunctionsSuite.scala |  20 ++++
 .../spark/sql/hive/execution/HiveUDFSuite.scala    |  15 ++-
 14 files changed, 238 insertions(+), 76 deletions(-)
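For context, a minimal PySpark sketch of what this follow-up enables (not part of the patch; `custom_avg` is the persistent test UDAF created in the new doctests further down):

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import call_function, col

    spark = SparkSession.builder.getOrCreate()
    df = spark.range(10)

    # Unqualified names already worked: builtin or temporary functions.
    df.select(call_function("avg", col("id"))).show()

    # New: the name may be qualified, so persistent functions resolve too,
    # assuming the function was created first, e.g. via
    #   spark.sql("CREATE FUNCTION custom_avg AS 'test.org.apache.spark.sql.MyDoubleAvg'")
    df.select(call_function("spark_catalog.default.custom_avg", col("id"))).show()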
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
index 17d1cdca350..eac3f652320 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
@@ -7923,15 +7923,19 @@ object functions {
   def call_udf(udfName: String, cols: Column*): Column = call_function(udfName, cols: _*)
 
   /**
-   * Call a builtin or temp function.
+   * Call a SQL function.
    *
    * @param funcName
-   *   function name
+   *   function name that follows the SQL identifier syntax (can be quoted, can be qualified)
    * @param cols
    *   the expression parameters of function
    * @since 3.5.0
    */
   @scala.annotation.varargs
-  def call_function(funcName: String, cols: Column*): Column = Column.fn(funcName, cols: _*)
+  def call_function(funcName: String, cols: Column*): Column = Column { builder =>
+    builder.getCallFunctionBuilder
+      .setFunctionName(funcName)
+      .addAllArguments(cols.map(_.expr).asJava)
+  }
 }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
index 800ce43a60d..ad2ca383e4f 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
@@ -239,4 +239,14 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
     val output = runCommandsInShell(input)
     assertContains("Array[org.apache.spark.sql.Row] = Array([id1,1], [id2,16], [id3,25])", output)
   }
+
+  test("call_function") {
+    val input = """
+        |val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")
+        |spark.udf.register("simpleUDF", (v: Int) => v * v)
+        |df.select($"id", call_function("simpleUDF", $"value")).collect()
+      """.stripMargin
+    val output = runCommandsInShell(input)
+    assertContains("Array[org.apache.spark.sql.Row] = Array([id1,1], [id2,16], [id3,25])", output)
+  }
 }
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 37a8778865d..557b9db9123 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -46,6 +46,7 @@ message Expression {
     UpdateFields update_fields = 13;
     UnresolvedNamedLambdaVariable unresolved_named_lambda_variable = 14;
     CommonInlineUserDefinedFunction common_inline_user_defined_function = 15;
+    CallFunction call_function = 16;
 
     // This field is used to mark extensions to the protocol. When plugins generate arbitrary
    // relations they can add them here. During the planning the correct resolution is done.
@@ -371,3 +372,11 @@ message JavaUDF {
   // (Required) Indicate if the Java user-defined function is an aggregate function
   bool aggregate = 3;
 }
+
+message CallFunction {
+  // (Required) Unparsed name of the SQL function.
+  string function_name = 1;
+
+  // (Optional) Function arguments. Empty arguments are allowed.
+  repeated Expression arguments = 2;
+}
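To make the new message concrete, a sketch of building it by hand with the generated Python bindings (mirrors the `to_plan` logic added to `expressions.py` below; the `unresolved_attribute` field is the pre-existing proto field used by column references):

    import pyspark.sql.connect.proto as proto

    # One argument: a column reference carried as an unresolved attribute.
    arg = proto.Expression()
    arg.unresolved_attribute.unparsed_identifier = "id"

    # The function name travels unparsed; the server splits it into parts.
    expr = proto.Expression()
    expr.call_function.function_name = "spark_catalog.default.custom_avg"
    expr.call_function.arguments.append(arg)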
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.json b/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.json
index f7fe5beba2c..6db0a614682 100644
--- a/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.json
+++ b/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.json
@@ -12,7 +12,7 @@
     }
   },
   "expressions": [{
-    "unresolvedFunction": {
+    "callFunction": {
       "functionName": "lower",
       "arguments": [{
         "unresolvedAttribute": {
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.proto.bin
index 7c736d93f77..ef985e42131 100644
Binary files a/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.proto.bin differ
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 92a9524f67a..36037cce7eb 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1380,6 +1380,8 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
         transformExpressionPlugin(exp.getExtension)
       case proto.Expression.ExprTypeCase.COMMON_INLINE_USER_DEFINED_FUNCTION =>
         transformCommonInlineUserDefinedFunction(exp.getCommonInlineUserDefinedFunction)
+      case proto.Expression.ExprTypeCase.CALL_FUNCTION =>
+        transformCallFunction(exp.getCallFunction)
       case _ =>
         throw InvalidPlanInput(
           s"Expression with ID: ${exp.getExprTypeCase.getNumber} is not supported")
@@ -1484,6 +1486,23 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
     }
   }
 
+  /**
+   * Translates a SQL function from proto to the Catalyst expression.
+   *
+   * @param fun
+   *   Proto representation of the function call.
+   * @return
+   *   Expression.
+   */
+  private def transformCallFunction(fun: proto.CallFunction): Expression = {
+    val funcName = fun.getFunctionName
+    val nameParts = session.sessionState.sqlParser.parseMultipartIdentifier(funcName)
+    UnresolvedFunction(
+      nameParts,
+      fun.getArgumentsList.asScala.map(transformExpression).toSeq,
+      false)
+  }
+
   private def unpackUdf(fun: proto.CommonInlineUserDefinedFunction): UdfPacket = {
     Utils.deserialize[UdfPacket](
       fun.getScalarScalaUdf.getPayload.toByteArray,
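Because the server splits the name with the session's SQL parser, resolution matches SQL exactly. A hedged sketch of the equivalence, reusing the `custom_avg` function and imports from the sketch above:

    # These resolve identically once the multipart name is split server-side:
    spark.sql("SELECT spark_catalog.default.custom_avg(id) FROM range(10)").show()
    spark.range(10).select(
        call_function("spark_catalog.default.custom_avg", col("id"))
    ).show()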
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index e1b648c7bb8..44e6e174f70 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -1027,3 +1027,27 @@ class DistributedSequenceID(Expression):
 
     def __repr__(self) -> str:
         return "DistributedSequenceID()"
+
+
+class CallFunction(Expression):
+    def __init__(self, name: str, args: Sequence["Expression"]):
+        super().__init__()
+
+        assert isinstance(name, str)
+        self._name = name
+
+        assert isinstance(args, list) and all(isinstance(arg, Expression) for arg in args)
+        self._args = args
+
+    def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
+        expr = proto.Expression()
+        expr.call_function.function_name = self._name
+        if len(self._args) > 0:
+            expr.call_function.arguments.extend([arg.to_plan(session) for arg in self._args])
+        return expr
+
+    def __repr__(self) -> str:
+        if len(self._args) > 0:
+            return f"CallFunction('{self._name}', {', '.join([str(arg) for arg in self._args])})"
+        else:
+            return f"CallFunction('{self._name}')"
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index a1c0516ee0d..a92f89c0f6c 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -51,6 +51,7 @@ from pyspark.sql.connect.expressions import (
     SQLExpression,
     LambdaFunction,
     UnresolvedNamedLambdaVariable,
+    CallFunction,
 )
 from pyspark.sql.connect.udf import _create_py_udf
 from pyspark.sql.connect.udtf import _create_py_udtf
@@ -3909,8 +3910,9 @@ def udtf(
 udtf.__doc__ = pysparkfuncs.udtf.__doc__
 
 
-def call_function(udfName: str, *cols: "ColumnOrName") -> Column:
-    return _invoke_function(udfName, *[_to_col(c) for c in cols])
+def call_function(funcName: str, *cols: "ColumnOrName") -> Column:
+    expressions = [_to_col(c)._expr for c in cols]
+    return Column(CallFunction(funcName, expressions))
 
 
 call_function.__doc__ = pysparkfuncs.call_function.__doc__
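On a Spark Connect session the column now wraps a `CallFunction` node rather than an `UnresolvedFunction`; a rough illustration (the exact repr text may differ):

    from pyspark.sql.connect.functions import call_function, col

    c = call_function("lower", col("name"))
    # c._expr is the CallFunction expression defined above, so repr(c._expr)
    # looks roughly like: CallFunction('lower', ...)
    print(c._expr)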
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 7a68d831a99..51d1a5d48a1 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -33,7 +33,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\x95+\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...]
+    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xd9+\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -45,61 +45,63 @@ if _descriptor._USE_C_DESCRIPTORS == False:
         b"\n\036org.apache.spark.connect.protoP\001Z\022internal/generated"
     )
     _EXPRESSION._serialized_start = 105
-    _EXPRESSION._serialized_end = 5630
-    _EXPRESSION_WINDOW._serialized_start = 1475
-    _EXPRESSION_WINDOW._serialized_end = 2258
-    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1765
-    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 2258
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 2032
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 2177
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 2179
-    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 2258
-    _EXPRESSION_SORTORDER._serialized_start = 2261
-    _EXPRESSION_SORTORDER._serialized_end = 2686
-    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2491
-    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2599
-    _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2601
-    _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2686
-    _EXPRESSION_CAST._serialized_start = 2689
-    _EXPRESSION_CAST._serialized_end = 2834
-    _EXPRESSION_LITERAL._serialized_start = 2837
-    _EXPRESSION_LITERAL._serialized_end = 4400
-    _EXPRESSION_LITERAL_DECIMAL._serialized_start = 3672
-    _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3789
-    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3791
-    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3889
-    _EXPRESSION_LITERAL_ARRAY._serialized_start = 3892
-    _EXPRESSION_LITERAL_ARRAY._serialized_end = 4022
-    _EXPRESSION_LITERAL_MAP._serialized_start = 4025
-    _EXPRESSION_LITERAL_MAP._serialized_end = 4252
-    _EXPRESSION_LITERAL_STRUCT._serialized_start = 4255
-    _EXPRESSION_LITERAL_STRUCT._serialized_end = 4384
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 4402
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 4514
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 4517
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 4721
-    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 4723
-    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 4773
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 4775
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4857
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4859
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4945
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4948
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 5080
-    _EXPRESSION_UPDATEFIELDS._serialized_start = 5083
-    _EXPRESSION_UPDATEFIELDS._serialized_end = 5270
-    _EXPRESSION_ALIAS._serialized_start = 5272
-    _EXPRESSION_ALIAS._serialized_end = 5392
-    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 5395
-    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 5553
-    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 5555
-    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 5617
-    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 5633
-    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5997
-    _PYTHONUDF._serialized_start = 6000
-    _PYTHONUDF._serialized_end = 6155
-    _SCALARSCALAUDF._serialized_start = 6158
-    _SCALARSCALAUDF._serialized_end = 6342
-    _JAVAUDF._serialized_start = 6345
-    _JAVAUDF._serialized_end = 6494
+    _EXPRESSION._serialized_end = 5698
+    _EXPRESSION_WINDOW._serialized_start = 1543
+    _EXPRESSION_WINDOW._serialized_end = 2326
+    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1833
+    _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 2326
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 2100
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 2245
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 2247
+    _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 2326
+    _EXPRESSION_SORTORDER._serialized_start = 2329
+    _EXPRESSION_SORTORDER._serialized_end = 2754
+    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2559
+    _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2667
+    _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2669
+    _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2754
+    _EXPRESSION_CAST._serialized_start = 2757
+    _EXPRESSION_CAST._serialized_end = 2902
+    _EXPRESSION_LITERAL._serialized_start = 2905
+    _EXPRESSION_LITERAL._serialized_end = 4468
+    _EXPRESSION_LITERAL_DECIMAL._serialized_start = 3740
+    _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3857
+    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3859
+    _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3957
+    _EXPRESSION_LITERAL_ARRAY._serialized_start = 3960
+    _EXPRESSION_LITERAL_ARRAY._serialized_end = 4090
+    _EXPRESSION_LITERAL_MAP._serialized_start = 4093
+    _EXPRESSION_LITERAL_MAP._serialized_end = 4320
+    _EXPRESSION_LITERAL_STRUCT._serialized_start = 4323
+    _EXPRESSION_LITERAL_STRUCT._serialized_end = 4452
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 4470
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 4582
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 4585
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 4789
+    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 4791
+    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 4841
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 4843
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4925
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4927
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 5013
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 5016
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 5148
+    _EXPRESSION_UPDATEFIELDS._serialized_start = 5151
+    _EXPRESSION_UPDATEFIELDS._serialized_end = 5338
+    _EXPRESSION_ALIAS._serialized_start = 5340
+    _EXPRESSION_ALIAS._serialized_end = 5460
+    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 5463
+    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 5621
+    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 5623
+    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 5685
+    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 5701
+    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 6065
+    _PYTHONUDF._serialized_start = 6068
+    _PYTHONUDF._serialized_end = 6223
+    _SCALARSCALAUDF._serialized_start = 6226
+    _SCALARSCALAUDF._serialized_end = 6410
+    _JAVAUDF._serialized_start = 6413
+    _JAVAUDF._serialized_end = 6562
+    _CALLFUNCTION._serialized_start = 6564
+    _CALLFUNCTION._serialized_end = 6672
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index bef87203b55..b9b16ce35e3 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -1101,6 +1101,7 @@ class Expression(google.protobuf.message.Message):
     UPDATE_FIELDS_FIELD_NUMBER: builtins.int
     UNRESOLVED_NAMED_LAMBDA_VARIABLE_FIELD_NUMBER: builtins.int
     COMMON_INLINE_USER_DEFINED_FUNCTION_FIELD_NUMBER: builtins.int
+    CALL_FUNCTION_FIELD_NUMBER: builtins.int
     EXTENSION_FIELD_NUMBER: builtins.int
     @property
     def literal(self) -> global___Expression.Literal: ...
@@ -1135,6 +1136,8 @@ class Expression(google.protobuf.message.Message):
     @property
     def common_inline_user_defined_function(self) -> global___CommonInlineUserDefinedFunction: ...
     @property
+    def call_function(self) -> global___CallFunction: ...
+    @property
     def extension(self) -> google.protobuf.any_pb2.Any:
         """This field is used to mark extensions to the protocol. When plugins generate arbitrary
         relations they can add them here. During the planning the correct resolution is done.
         """
@@ -1158,6 +1161,7 @@ class Expression(google.protobuf.message.Message):
         unresolved_named_lambda_variable: global___Expression.UnresolvedNamedLambdaVariable
         | None = ...,
         common_inline_user_defined_function: global___CommonInlineUserDefinedFunction | None = ...,
+        call_function: global___CallFunction | None = ...,
         extension: google.protobuf.any_pb2.Any | None = ...,
     ) -> None: ...
     def HasField(
@@ -1165,6 +1169,8 @@ class Expression(google.protobuf.message.Message):
         field_name: typing_extensions.Literal[
             "alias",
             b"alias",
+            "call_function",
+            b"call_function",
             "cast",
             b"cast",
             "common_inline_user_defined_function",
@@ -1204,6 +1210,8 @@ class Expression(google.protobuf.message.Message):
         field_name: typing_extensions.Literal[
             "alias",
             b"alias",
+            "call_function",
+            b"call_function",
             "cast",
             b"cast",
             "common_inline_user_defined_function",
@@ -1256,6 +1264,7 @@ class Expression(google.protobuf.message.Message):
             "update_fields",
             "unresolved_named_lambda_variable",
             "common_inline_user_defined_function",
+            "call_function",
             "extension",
         ]
         | None: ...
@@ -1469,3 +1478,30 @@ class JavaUDF(google.protobuf.message.Message):
     ) -> typing_extensions.Literal["output_type"] | None: ...
 
 global___JavaUDF = JavaUDF
+
+class CallFunction(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    FUNCTION_NAME_FIELD_NUMBER: builtins.int
+    ARGUMENTS_FIELD_NUMBER: builtins.int
+    function_name: builtins.str
+    """(Required) Unparsed name of the SQL function."""
+    @property
+    def arguments(
+        self,
+    ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Expression]:
+        """(Optional) Function arguments. Empty arguments are allowed."""
+    def __init__(
+        self,
+        *,
+        function_name: builtins.str = ...,
+        arguments: collections.abc.Iterable[global___Expression] | None = ...,
+    ) -> None: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "arguments", b"arguments", "function_name", b"function_name"
+        ],
+    ) -> None: ...
+
+global___CallFunction = CallFunction
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index f566fcee0e3..b45e1daa0fd 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -14395,16 +14395,16 @@ def call_udf(udfName: str, *cols: "ColumnOrName") -> Column:
 
 
 @try_remote_functions
-def call_function(udfName: str, *cols: "ColumnOrName") -> Column:
+def call_function(funcName: str, *cols: "ColumnOrName") -> Column:
     """
-    Call a builtin or temp function.
+    Call a SQL function.
 
     .. versionadded:: 3.5.0
 
     Parameters
     ----------
-    udfName : str
-        name of the function
+    funcName : str
+        function name that follows the SQL identifier syntax (can be quoted, can be qualified)
     cols : :class:`~pyspark.sql.Column` or str
         column names or :class:`~pyspark.sql.Column`\\s to be used in the function
 
@@ -14442,9 +14442,22 @@ def call_function(funcName: str, *cols: "ColumnOrName") -> Column:
     +-------+
     |    2.0|
     +-------+
+    >>> _ = spark.sql("CREATE FUNCTION custom_avg AS 'test.org.apache.spark.sql.MyDoubleAvg'")
+    >>> df.select(call_function("custom_avg", col("id"))).show()
+    +------------------------------------+
+    |spark_catalog.default.custom_avg(id)|
+    +------------------------------------+
+    |                               102.0|
+    +------------------------------------+
+    >>> df.select(call_function("spark_catalog.default.custom_avg", col("id"))).show()
+    +------------------------------------+
+    |spark_catalog.default.custom_avg(id)|
+    +------------------------------------+
+    |                               102.0|
+    +------------------------------------+
     """
     sc = get_active_spark_context()
-    return _invoke_function("call_function", udfName, _to_seq(sc, cols, _to_java_column))
+    return _invoke_function("call_function", funcName, _to_seq(sc, cols, _to_java_column))
 
 
 @try_remote_functions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 2a8cfd250c9..ca5e4422ca9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -8338,7 +8338,7 @@ object functions {
   @scala.annotation.varargs
   @deprecated("Use call_udf")
   def callUDF(udfName: String, cols: Column*): Column =
-    call_function(udfName, cols: _*)
+    call_function(Seq(udfName), cols: _*)
 
   /**
    * Call an user-defined function.
@@ -8357,18 +8357,28 @@ object functions {
    */
   @scala.annotation.varargs
   def call_udf(udfName: String, cols: Column*): Column =
-    call_function(udfName, cols: _*)
+    call_function(Seq(udfName), cols: _*)
 
   /**
-   * Call a builtin or temp function.
+   * Call a SQL function.
    *
-   * @param funcName function name
+   * @param funcName function name that follows the SQL identifier syntax
+   *                 (can be quoted, can be qualified)
    * @param cols the expression parameters of function
    * @since 3.5.0
    */
   @scala.annotation.varargs
-  def call_function(funcName: String, cols: Column*): Column =
-    withExpr { UnresolvedFunction(funcName, cols.map(_.expr), false) }
+  def call_function(funcName: String, cols: Column*): Column = {
+    val parser = SparkSession.getActiveSession.map(_.sessionState.sqlParser).getOrElse {
+      new SparkSqlParser()
+    }
+    val nameParts = parser.parseMultipartIdentifier(funcName)
+    call_function(nameParts, cols: _*)
+  }
+
+  private def call_function(nameParts: Seq[String], cols: Column*): Column = withExpr {
+    UnresolvedFunction(nameParts, cols.map(_.expr), false)
+  }
 
   /**
    * Unwrap UDT data type column into its underlying type.
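The parsing step above makes quoting significant: backticks keep a dotted name as one identifier, while unquoted dots form a multipart name. A sketch matching the test cases below, given some DataFrame `df` and functions registered as in those tests:

    from pyspark.sql.functions import call_function

    # Unquoted dots -> multipart name, resolved as a persistent function:
    df.select(call_function("spark_catalog.default.custom_sum", df.a))

    # Backticks -> a single identifier that merely contains dots, i.e. a
    # temporary function registered literally as "default.custom_func":
    df.select(call_function("`default.custom_func`", df.a))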
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 9781a8e3ff4..c7dcb575ff0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -5918,6 +5918,26 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
 
   test("call_function") {
     checkAnswer(testData2.select(call_function("avg", $"a")), testData2.selectExpr("avg(a)"))
+
+    withUserDefinedFunction("custom_func" -> true, "custom_sum" -> false) {
+      spark.udf.register("custom_func", (i: Int) => { i + 2 })
+      checkAnswer(
+        testData2.select(call_function("custom_func", $"a")),
+        Seq(Row(3), Row(3), Row(4), Row(4), Row(5), Row(5)))
+      spark.udf.register("default.custom_func", (i: Int) => { i + 2 })
+      checkAnswer(
+        testData2.select(call_function("`default.custom_func`", $"a")),
+        Seq(Row(3), Row(3), Row(4), Row(4), Row(5), Row(5)))
+
+      sql("CREATE FUNCTION custom_sum AS 'test.org.apache.spark.sql.MyDoubleSum'")
+      checkAnswer(
+        testData2.select(
+          call_function("custom_sum", $"a"),
+          call_function("default.custom_sum", $"a"),
+          call_function("spark_catalog.default.custom_sum", $"a")),
+        Row(12.0, 12.0, 12.0))
+    }
+  }
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index ef430f4b6a2..d12ebae0f5f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -37,7 +37,7 @@ import org.apache.spark.{SparkException, SparkFiles, TestUtils}
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.catalyst.plans.logical.Project
 import org.apache.spark.sql.execution.WholeStageCodegenExec
-import org.apache.spark.sql.functions.max
+import org.apache.spark.sql.functions.{call_function, max}
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SQLTestUtils
@@ -552,6 +552,19 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
     }
   }
 
+  test("Invoke a persist hive function with call_function") {
+    val testData = spark.range(5).repartition(1)
+    withUserDefinedFunction("custom_avg" -> false) {
+      sql(s"CREATE FUNCTION custom_avg AS '${classOf[GenericUDAFAverage].getName}'")
+      checkAnswer(
+        testData.select(
+          call_function("custom_avg", $"id"),
+          call_function("default.custom_avg", $"id"),
+          call_function("spark_catalog.default.custom_avg", $"id")),
+        Row(2.0, 2.0, 2.0))
+    }
+  }
+
   test("Temp function has dots in the names") {
     withUserDefinedFunction("test_avg" -> false, "`default.test_avg`" -> true) {
       sql(s"CREATE FUNCTION test_avg AS '${classOf[GenericUDAFAverage].getName}'")

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org