This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 46dd3aa9425 [SPARK-44131][SQL] Add call_function and deprecate call_udf for Scala API 46dd3aa9425 is described below commit 46dd3aa94250343b38d963d74ae10aba255a6a24 Author: Jiaan Geng <belie...@163.com> AuthorDate: Mon Jul 10 18:11:14 2023 +0800 [SPARK-44131][SQL] Add call_function and deprecate call_udf for Scala API ### What changes were proposed in this pull request? The Scala API has a method, `call_udf`, that is meant for calling user-defined functions. In fact, `call_udf` can also call builtin functions, and this behavior is confusing for users. This PR adds `call_function` to replace `call_udf`, and deprecates `call_udf` (and `callUDF`) in the Scala API. `call_function` is also added to the Python API, both classic PySpark and Spark Connect. ### Why are the changes needed? To remove the confusion caused by `call_udf`, whose name suggests that it can only call user-defined functions. ### Does this PR introduce _any_ user-facing change? No breaking change. It adds a new function, `call_function`. Existing `call_udf` calls keep working but are deprecated in Scala. ### How was this patch tested? Existing test cases, plus new `call_function` tests in `PlanGenerationTestSuite` and `DataFrameFunctionsSuite`. Closes #41687 from beliefer/SPARK-44131. Authored-by: Jiaan Geng <belie...@163.com> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- .../scala/org/apache/spark/sql/functions.scala | 12 +++ .../apache/spark/sql/PlanGenerationTestSuite.scala | 4 + .../explain-results/function_call_function.explain | 2 + .../queries/function_call_function.json | 25 ++++++ .../queries/function_call_function.proto.bin | Bin 0 -> 174 bytes .../source/reference/pyspark.sql/functions.rst | 5 +- python/pyspark/sql/connect/functions.py | 9 ++- python/pyspark/sql/functions.py | 53 ++++++++++++ .../scala/org/apache/spark/sql/functions.scala | 90 +++++++++------------ .../apache/spark/sql/DataFrameFunctionsSuite.scala | 7 +- 10 files changed, 150 insertions(+), 57 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala index 5240cdecb01..b0ae4c9752a 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala +++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala @@ -7905,4 +7905,16 @@ object functions { } // scalastyle:off line.size.limit + /** + * Call a builtin or temp function. + * + * @param funcName + * function name + * @param cols + * the expression parameters of function + * @since 3.5.0 + */ + @scala.annotation.varargs + def call_function(funcName: String, cols: Column*): Column = Column.fn(funcName, cols: _*) + } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index 1d679653166..7e4e0f24f4f 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -2873,6 +2873,10 @@ class PlanGenerationTestSuite fn.random(lit(1)) } + functionTest("call_function") { + fn.call_function("lower", fn.col("g")) + } + test("hll_sketch_agg with column lgConfigK") { binary.select(fn.hll_sketch_agg(fn.col("bytes"), lit(0))) } diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_call_function.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_call_function.explain new file mode 100644 index 00000000000..d905689c35d --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_call_function.explain @@ -0,0 +1,2 @@ +Project [lower(g#0) AS lower(g)#0] ++- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.json b/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.json new file mode 100644 index 00000000000..f7fe5beba2c --- /dev/null +++ 
b/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.json @@ -0,0 +1,25 @@ +{ + "common": { + "planId": "1" + }, + "project": { + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e" + } + }, + "expressions": [{ + "unresolvedFunction": { + "functionName": "lower", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "g" + } + }] + } + }] + } +} \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.proto.bin new file mode 100644 index 00000000000..7c736d93f77 Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.proto.bin differ diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst index 4ca1ef76049..c5eb92c92a7 100644 --- a/python/docs/source/reference/pyspark.sql/functions.rst +++ b/python/docs/source/reference/pyspark.sql/functions.rst @@ -460,11 +460,12 @@ Bitwise Functions getbit -UDF ---- +Call Functions +-------------- .. 
autosummary:: :toctree: api/ + call_function call_udf pandas_udf udf diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py index 813866edb9b..c6445f110c0 100644 --- a/python/pyspark/sql/connect/functions.py +++ b/python/pyspark/sql/connect/functions.py @@ -3853,7 +3853,7 @@ def bitmap_or_agg(col: "ColumnOrName") -> Column: bitmap_or_agg.__doc__ = pysparkfuncs.bitmap_or_agg.__doc__ -# User Defined Function +# Call Functions def call_udf(udfName: str, *cols: "ColumnOrName") -> Column: @@ -3891,6 +3891,13 @@ def udf( udf.__doc__ = pysparkfuncs.udf.__doc__ +def call_function(udfName: str, *cols: "ColumnOrName") -> Column: + return _invoke_function(udfName, *[_to_col(c) for c in cols]) + + +call_function.__doc__ = pysparkfuncs.call_function.__doc__ + + def _test() -> None: import sys import doctest diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b77a41a0f6f..b7d1204deef 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -14394,6 +14394,59 @@ def call_udf(udfName: str, *cols: "ColumnOrName") -> Column: return _invoke_function("call_udf", udfName, _to_seq(sc, cols, _to_java_column)) +@try_remote_functions +def call_function(udfName: str, *cols: "ColumnOrName") -> Column: + """ + Call a builtin or temp function. + + .. versionadded:: 3.5.0 + + Parameters + ---------- + udfName : str + name of the function + cols : :class:`~pyspark.sql.Column` or str + column names or :class:`~pyspark.sql.Column`\\s to be used in the function + + Returns + ------- + :class:`~pyspark.sql.Column` + result of executed function. 
+ + Examples + -------- + >>> from pyspark.sql.functions import call_udf, col + >>> from pyspark.sql.types import IntegerType, StringType + >>> df = spark.createDataFrame([(1, "a"),(2, "b"), (3, "c")],["id", "name"]) + >>> _ = spark.udf.register("intX2", lambda i: i * 2, IntegerType()) + >>> df.select(call_function("intX2", "id")).show() + +---------+ + |intX2(id)| + +---------+ + | 2| + | 4| + | 6| + +---------+ + >>> _ = spark.udf.register("strX2", lambda s: s * 2, StringType()) + >>> df.select(call_function("strX2", col("name"))).show() + +-----------+ + |strX2(name)| + +-----------+ + | aa| + | bb| + | cc| + +-----------+ + >>> df.select(call_function("avg", col("id"))).show() + +-------+ + |avg(id)| + +-------+ + | 2.0| + +-------+ + """ + sc = get_active_spark_context() + return _invoke_function("call_function", udfName, _to_seq(sc, cols, _to_java_column)) + + @try_remote_functions def unwrap_udt(col: "ColumnOrName") -> Column: """ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 7e584db6636..6931cd286ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1936,9 +1936,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def try_add(left: Column, right: Column): Column = withExpr { - UnresolvedFunction("try_add", Seq(left.expr, right.expr), isDistinct = false) - } + def try_add(left: Column, right: Column): Column = call_function("try_add", left, right) /** * Returns the mean calculated from values of a group and the result is null on overflow. 
@@ -1957,9 +1955,8 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def try_divide(dividend: Column, divisor: Column): Column = withExpr { - UnresolvedFunction("try_divide", Seq(dividend.expr, divisor.expr), isDistinct = false) - } + def try_divide(dividend: Column, divisor: Column): Column = + call_function("try_divide", dividend, divisor) /** * Returns `left``*``right` and the result is null on overflow. The acceptable input types are @@ -1968,9 +1965,8 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def try_multiply(left: Column, right: Column): Column = withExpr { - UnresolvedFunction("try_multiply", Seq(left.expr, right.expr), isDistinct = false) - } + def try_multiply(left: Column, right: Column): Column = + call_function("try_multiply", left, right) /** * Returns `left``-``right` and the result is null on overflow. The acceptable input types are @@ -1979,9 +1975,8 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def try_subtract(left: Column, right: Column): Column = withExpr { - UnresolvedFunction("try_subtract", Seq(left.expr, right.expr), isDistinct = false) - } + def try_subtract(left: Column, right: Column): Column = + call_function("try_subtract", left, right) /** * Returns the sum calculated from values of a group and the result is null on overflow. @@ -2366,9 +2361,7 @@ object functions { * @group math_funcs * @since 3.3.0 */ - def ceil(e: Column, scale: Column): Column = withExpr { - UnresolvedFunction(Seq("ceil"), Seq(e.expr, scale.expr), isDistinct = false) - } + def ceil(e: Column, scale: Column): Column = call_function("ceil", e, scale) /** * Computes the ceiling of the given value of `e` to 0 decimal places. 
@@ -2376,9 +2369,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def ceil(e: Column): Column = withExpr { - UnresolvedFunction(Seq("ceil"), Seq(e.expr), isDistinct = false) - } + def ceil(e: Column): Column = call_function("ceil", e) /** * Computes the ceiling of the given value of `e` to 0 decimal places. @@ -2522,9 +2513,7 @@ object functions { * @group math_funcs * @since 3.3.0 */ - def floor(e: Column, scale: Column): Column = withExpr { - UnresolvedFunction(Seq("floor"), Seq(e.expr, scale.expr), isDistinct = false) - } + def floor(e: Column, scale: Column): Column = call_function("floor", e, scale) /** * Computes the floor of the given value of `e` to 0 decimal places. @@ -2532,9 +2521,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def floor(e: Column): Column = withExpr { - UnresolvedFunction(Seq("floor"), Seq(e.expr), isDistinct = false) - } + def floor(e: Column): Column = call_function("floor", e) /** * Computes the floor of the given column value to 0 decimal places. @@ -4007,9 +3994,8 @@ object functions { * @group string_funcs * @since 3.3.0 */ - def lpad(str: Column, len: Int, pad: Array[Byte]): Column = withExpr { - UnresolvedFunction("lpad", Seq(str.expr, lit(len).expr, lit(pad).expr), isDistinct = false) - } + def lpad(str: Column, len: Int, pad: Array[Byte]): Column = + call_function("lpad", str, lit(len), lit(pad)) /** * Trim the spaces from left end for the specified string value. @@ -4190,9 +4176,8 @@ object functions { * @group string_funcs * @since 3.3.0 */ - def rpad(str: Column, len: Int, pad: Array[Byte]): Column = withExpr { - UnresolvedFunction("rpad", Seq(str.expr, lit(len).expr, lit(pad).expr), isDistinct = false) - } + def rpad(str: Column, len: Int, pad: Array[Byte]): Column = + call_function("rpad", str, lit(len), lit(pad)) /** * Repeats a string column n times, and returns it as a new string column. 
@@ -4628,9 +4613,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def endswith(str: Column, suffix: Column): Column = withExpr { - UnresolvedFunction(Seq("endswith"), Seq(str.expr, suffix.expr), isDistinct = false) - } + def endswith(str: Column, suffix: Column): Column = call_function("endswith", str, suffix) /** * Returns a boolean. The value is True if str starts with prefix. @@ -4640,9 +4623,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def startswith(str: Column, prefix: Column): Column = withExpr { - UnresolvedFunction(Seq("startswith"), Seq(str.expr, prefix.expr), isDistinct = false) - } + def startswith(str: Column, prefix: Column): Column = call_function("startswith", str, prefix) /** * Returns the ASCII character having the binary equivalent to `n`. @@ -4752,9 +4733,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def contains(left: Column, right: Column): Column = withExpr { - UnresolvedFunction(Seq("contains"), Seq(left.expr, right.expr), isDistinct = false) - } + def contains(left: Column, right: Column): Column = call_function("contains", left, right) /** * Returns the `n`-th input, e.g., returns `input2` when `n` is 2. @@ -5167,9 +5146,7 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def extract(field: Column, source: Column): Column = withExpr { - UnresolvedFunction("extract", Seq(field.expr, source.expr), isDistinct = false) - } + def extract(field: Column, source: Column): Column = call_function("extract", field, source) /** * Extracts a part of the date/timestamp or interval source. 
@@ -5181,9 +5158,7 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def date_part(field: Column, source: Column): Column = withExpr { - UnresolvedFunction("date_part", Seq(field.expr, source.expr), isDistinct = false) - } + def date_part(field: Column, source: Column): Column = call_function("date_part", field, source) /** * Extracts a part of the date/timestamp or interval source. @@ -5195,9 +5170,7 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def datepart(field: Column, source: Column): Column = withExpr { - UnresolvedFunction("datepart", Seq(field.expr, source.expr), isDistinct = false) - } + def datepart(field: Column, source: Column): Column = call_function("datepart", field, source) /** * Returns the last day of the month which the given date belongs to. @@ -8363,9 +8336,9 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - @deprecated("Use call_udf") + @deprecated("Use call_function") def callUDF(udfName: String, cols: Column*): Column = - call_udf(udfName, cols: _*) + call_function(udfName, cols: _*) /** * Call an user-defined function. @@ -8383,9 +8356,20 @@ object functions { * @since 3.2.0 */ @scala.annotation.varargs - def call_udf(udfName: String, cols: Column*): Column = withExpr { - UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false) - } + @deprecated("Use call_function") + def call_udf(udfName: String, cols: Column*): Column = + call_function(udfName, cols: _*) + + /** + * Call a builtin or temp function. + * + * @param funcName function name + * @param cols the expression parameters of function + * @since 3.5.0 + */ + @scala.annotation.varargs + def call_function(funcName: String, cols: Column*): Column = + withExpr { UnresolvedFunction(funcName, cols.map(_.expr), false) } /** * Unwrap UDT data type column into its underlying type. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index c28ee3d8483..9781a8e3ff4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -72,7 +72,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "countDistinct", "count_distinct", // equivalent to count(distinct foo) "sum_distinct", // equivalent to sum(distinct foo) "typedLit", "typedlit", // Scala only - "udaf", "udf" // create function statement in sql + "udaf", "udf", // create function statement in sql + "call_function" // moot in SQL as you just call the function directly ) val excludedSqlFunctions = Set.empty[String] @@ -5914,6 +5915,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map.empty ) } + + test("call_function") { + checkAnswer(testData2.select(call_function("avg", $"a")), testData2.selectExpr("avg(a)")) + } } object DataFrameFunctionsSuite { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org