This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 46dd3aa9425 [SPARK-44131][SQL] Add call_function and deprecate call_udf for Scala API
46dd3aa9425 is described below

commit 46dd3aa94250343b38d963d74ae10aba255a6a24
Author: Jiaan Geng <belie...@163.com>
AuthorDate: Mon Jul 10 18:11:14 2023 +0800

    [SPARK-44131][SQL] Add call_function and deprecate call_udf for Scala API
    
    ### What changes were proposed in this pull request?
    The Scala API has a method `call_udf` that is used to call user-defined
    functions. In fact, `call_udf` can also call builtin functions, so the
    behavior is confusing for users.
    
    This PR adds `call_function` to replace `call_udf`, and deprecates
    `call_udf` in the Scala API.
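    
    For example, a minimal usage sketch (illustrative only, not part of the
    patch; the temp function name `plusOne` is hypothetical):
    
    ```scala
    import org.apache.spark.sql.functions._
    
    val df = spark.range(1, 4).toDF("id")
    
    // call_function resolves builtin functions...
    df.select(call_function("avg", col("id"))).show()
    
    // ...and registered temporary (user-defined) functions alike.
    spark.udf.register("plusOne", (i: Long) => i + 1)
    df.select(call_function("plusOne", col("id"))).show()
    ```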
    
    ### Why are the changes needed?
    Fixes the confusion around `call_udf`.
    
    ### Does this PR introduce _any_ user-facing change?
    No, this is a new feature.
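    
    Migration for existing callers is mechanical; a hedged sketch (`myUdf`
    and column `a` are illustrative names):
    
    ```scala
    // Before (now deprecated, but still behaves the same):
    df.select(call_udf("myUdf", col("a")))
    // After:
    df.select(call_function("myUdf", col("a")))
    ```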
    
    ### How was this patch tested?
    Existing test cases.
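    
    For instance, the new case added to `DataFrameFunctionsSuite` (see the
    diff below) checks the new entry point against its SQL equivalent:
    
    ```scala
    checkAnswer(testData2.select(call_function("avg", $"a")), testData2.selectExpr("avg(a)"))
    ```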
    
    Closes #41687 from beliefer/SPARK-44131.
    
    Authored-by: Jiaan Geng <belie...@163.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../scala/org/apache/spark/sql/functions.scala     |  12 +++
 .../apache/spark/sql/PlanGenerationTestSuite.scala |   4 +
 .../explain-results/function_call_function.explain |   2 +
 .../queries/function_call_function.json            |  25 ++++++
 .../queries/function_call_function.proto.bin       | Bin 0 -> 174 bytes
 .../source/reference/pyspark.sql/functions.rst     |   5 +-
 python/pyspark/sql/connect/functions.py            |   9 ++-
 python/pyspark/sql/functions.py                    |  53 ++++++++++++
 .../scala/org/apache/spark/sql/functions.scala     |  90 +++++++++------------
 .../apache/spark/sql/DataFrameFunctionsSuite.scala |   7 +-
 10 files changed, 150 insertions(+), 57 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
index 5240cdecb01..b0ae4c9752a 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
@@ -7905,4 +7905,16 @@ object functions {
   }
   // scalastyle:off line.size.limit
 
+  /**
+   * Call a builtin or temp function.
+   *
+   * @param funcName
+   *   function name
+   * @param cols
+   *   the expression parameters of function
+   * @since 3.5.0
+   */
+  @scala.annotation.varargs
+  def call_function(funcName: String, cols: Column*): Column = Column.fn(funcName, cols: _*)
+
 }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 1d679653166..7e4e0f24f4f 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -2873,6 +2873,10 @@ class PlanGenerationTestSuite
     fn.random(lit(1))
   }
 
+  functionTest("call_function") {
+    fn.call_function("lower", fn.col("g"))
+  }
+
   test("hll_sketch_agg with column lgConfigK") {
     binary.select(fn.hll_sketch_agg(fn.col("bytes"), lit(0)))
   }
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_call_function.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_call_function.explain
new file mode 100644
index 00000000000..d905689c35d
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_call_function.explain
@@ -0,0 +1,2 @@
+Project [lower(g#0) AS lower(g)#0]
++- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.json b/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.json
new file mode 100644
index 00000000000..f7fe5beba2c
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.json
@@ -0,0 +1,25 @@
+{
+  "common": {
+    "planId": "1"
+  },
+  "project": {
+    "input": {
+      "common": {
+        "planId": "0"
+      },
+      "localRelation": {
+        "schema": 
"struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e"
+      }
+    },
+    "expressions": [{
+      "unresolvedFunction": {
+        "functionName": "lower",
+        "arguments": [{
+          "unresolvedAttribute": {
+            "unparsedIdentifier": "g"
+          }
+        }]
+      }
+    }]
+  }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.proto.bin
new file mode 100644
index 00000000000..7c736d93f77
Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/function_call_function.proto.bin differ
diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst
index 4ca1ef76049..c5eb92c92a7 100644
--- a/python/docs/source/reference/pyspark.sql/functions.rst
+++ b/python/docs/source/reference/pyspark.sql/functions.rst
@@ -460,11 +460,12 @@ Bitwise Functions
     getbit
 
 
-UDF
----
+Call Functions
+--------------
 .. autosummary::
     :toctree: api/
 
+    call_function
     call_udf
     pandas_udf
     udf
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index 813866edb9b..c6445f110c0 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -3853,7 +3853,7 @@ def bitmap_or_agg(col: "ColumnOrName") -> Column:
 bitmap_or_agg.__doc__ = pysparkfuncs.bitmap_or_agg.__doc__
 
 
-# User Defined Function
+# Call Functions
 
 
 def call_udf(udfName: str, *cols: "ColumnOrName") -> Column:
@@ -3891,6 +3891,13 @@ def udf(
 udf.__doc__ = pysparkfuncs.udf.__doc__
 
 
+def call_function(udfName: str, *cols: "ColumnOrName") -> Column:
+    return _invoke_function(udfName, *[_to_col(c) for c in cols])
+
+
+call_function.__doc__ = pysparkfuncs.call_function.__doc__
+
+
 def _test() -> None:
     import sys
     import doctest
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index b77a41a0f6f..b7d1204deef 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -14394,6 +14394,59 @@ def call_udf(udfName: str, *cols: "ColumnOrName") -> Column:
     return _invoke_function("call_udf", udfName, _to_seq(sc, cols, _to_java_column))
 
 
+@try_remote_functions
+def call_function(udfName: str, *cols: "ColumnOrName") -> Column:
+    """
+    Call a builtin or temp function.
+
+    .. versionadded:: 3.5.0
+
+    Parameters
+    ----------
+    udfName : str
+        name of the function
+    cols : :class:`~pyspark.sql.Column` or str
+        column names or :class:`~pyspark.sql.Column`\\s to be used in the function
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        result of executed function.
+
+    Examples
+    --------
+    >>> from pyspark.sql.functions import call_udf, col
+    >>> from pyspark.sql.types import IntegerType, StringType
+    >>> df = spark.createDataFrame([(1, "a"),(2, "b"), (3, "c")],["id", "name"])
+    >>> _ = spark.udf.register("intX2", lambda i: i * 2, IntegerType())
+    >>> df.select(call_function("intX2", "id")).show()
+    +---------+
+    |intX2(id)|
+    +---------+
+    |        2|
+    |        4|
+    |        6|
+    +---------+
+    >>> _ = spark.udf.register("strX2", lambda s: s * 2, StringType())
+    >>> df.select(call_function("strX2", col("name"))).show()
+    +-----------+
+    |strX2(name)|
+    +-----------+
+    |         aa|
+    |         bb|
+    |         cc|
+    +-----------+
+    >>> df.select(call_function("avg", col("id"))).show()
+    +-------+
+    |avg(id)|
+    +-------+
+    |    2.0|
+    +-------+
+    """
+    sc = get_active_spark_context()
+    return _invoke_function("call_function", udfName, _to_seq(sc, cols, 
_to_java_column))
+
+
 @try_remote_functions
 def unwrap_udt(col: "ColumnOrName") -> Column:
     """
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 7e584db6636..6931cd286ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1936,9 +1936,7 @@ object functions {
    * @group math_funcs
    * @since 3.5.0
    */
-  def try_add(left: Column, right: Column): Column = withExpr {
-    UnresolvedFunction("try_add", Seq(left.expr, right.expr), isDistinct = 
false)
-  }
+  def try_add(left: Column, right: Column): Column = call_function("try_add", 
left, right)
 
   /**
   * Returns the mean calculated from values of a group and the result is null on overflow.
@@ -1957,9 +1955,8 @@ object functions {
    * @group math_funcs
    * @since 3.5.0
    */
-  def try_divide(dividend: Column, divisor: Column): Column = withExpr {
-    UnresolvedFunction("try_divide", Seq(dividend.expr, divisor.expr), 
isDistinct = false)
-  }
+  def try_divide(dividend: Column, divisor: Column): Column =
+    call_function("try_divide", dividend, divisor)
 
   /**
   * Returns `left``*``right` and the result is null on overflow. The acceptable input types are
@@ -1968,9 +1965,8 @@ object functions {
    * @group math_funcs
    * @since 3.5.0
    */
-  def try_multiply(left: Column, right: Column): Column = withExpr {
-    UnresolvedFunction("try_multiply", Seq(left.expr, right.expr), isDistinct 
= false)
-  }
+  def try_multiply(left: Column, right: Column): Column =
+    call_function("try_multiply", left, right)
 
   /**
   * Returns `left``-``right` and the result is null on overflow. The acceptable input types are
@@ -1979,9 +1975,8 @@ object functions {
    * @group math_funcs
    * @since 3.5.0
    */
-  def try_subtract(left: Column, right: Column): Column = withExpr {
-    UnresolvedFunction("try_subtract", Seq(left.expr, right.expr), isDistinct 
= false)
-  }
+  def try_subtract(left: Column, right: Column): Column =
+    call_function("try_subtract", left, right)
 
   /**
   * Returns the sum calculated from values of a group and the result is null on overflow.
@@ -2366,9 +2361,7 @@ object functions {
    * @group math_funcs
    * @since 3.3.0
    */
-  def ceil(e: Column, scale: Column): Column = withExpr {
-    UnresolvedFunction(Seq("ceil"), Seq(e.expr, scale.expr), isDistinct = 
false)
-  }
+  def ceil(e: Column, scale: Column): Column = call_function("ceil", e, scale)
 
   /**
    * Computes the ceiling of the given value of `e` to 0 decimal places.
@@ -2376,9 +2369,7 @@ object functions {
    * @group math_funcs
    * @since 1.4.0
    */
-  def ceil(e: Column): Column = withExpr {
-    UnresolvedFunction(Seq("ceil"), Seq(e.expr), isDistinct = false)
-  }
+  def ceil(e: Column): Column = call_function("ceil", e)
 
   /**
    * Computes the ceiling of the given value of `e` to 0 decimal places.
@@ -2522,9 +2513,7 @@ object functions {
    * @group math_funcs
    * @since 3.3.0
    */
-  def floor(e: Column, scale: Column): Column = withExpr {
-    UnresolvedFunction(Seq("floor"), Seq(e.expr, scale.expr), isDistinct = 
false)
-  }
+  def floor(e: Column, scale: Column): Column = call_function("floor", e, 
scale)
 
   /**
    * Computes the floor of the given value of `e` to 0 decimal places.
@@ -2532,9 +2521,7 @@ object functions {
    * @group math_funcs
    * @since 1.4.0
    */
-  def floor(e: Column): Column = withExpr {
-    UnresolvedFunction(Seq("floor"), Seq(e.expr), isDistinct = false)
-  }
+  def floor(e: Column): Column = call_function("floor", e)
 
   /**
    * Computes the floor of the given column value to 0 decimal places.
@@ -4007,9 +3994,8 @@ object functions {
    * @group string_funcs
    * @since 3.3.0
    */
-  def lpad(str: Column, len: Int, pad: Array[Byte]): Column = withExpr {
-    UnresolvedFunction("lpad", Seq(str.expr, lit(len).expr, lit(pad).expr), 
isDistinct = false)
-  }
+  def lpad(str: Column, len: Int, pad: Array[Byte]): Column =
+    call_function("lpad", str, lit(len), lit(pad))
 
   /**
    * Trim the spaces from left end for the specified string value.
@@ -4190,9 +4176,8 @@ object functions {
    * @group string_funcs
    * @since 3.3.0
    */
-  def rpad(str: Column, len: Int, pad: Array[Byte]): Column = withExpr {
-    UnresolvedFunction("rpad", Seq(str.expr, lit(len).expr, lit(pad).expr), 
isDistinct = false)
-  }
+  def rpad(str: Column, len: Int, pad: Array[Byte]): Column =
+    call_function("rpad", str, lit(len), lit(pad))
 
   /**
    * Repeats a string column n times, and returns it as a new string column.
@@ -4628,9 +4613,7 @@ object functions {
    * @group string_funcs
    * @since 3.5.0
    */
-  def endswith(str: Column, suffix: Column): Column = withExpr {
-    UnresolvedFunction(Seq("endswith"), Seq(str.expr, suffix.expr), isDistinct 
= false)
-  }
+  def endswith(str: Column, suffix: Column): Column = 
call_function("endswith", str, suffix)
 
   /**
    * Returns a boolean. The value is True if str starts with prefix.
@@ -4640,9 +4623,7 @@ object functions {
    * @group string_funcs
    * @since 3.5.0
    */
-  def startswith(str: Column, prefix: Column): Column = withExpr {
-    UnresolvedFunction(Seq("startswith"), Seq(str.expr, prefix.expr), 
isDistinct = false)
-  }
+  def startswith(str: Column, prefix: Column): Column = 
call_function("startswith", str, prefix)
 
   /**
    * Returns the ASCII character having the binary equivalent to `n`.
@@ -4752,9 +4733,7 @@ object functions {
    * @group string_funcs
    * @since 3.5.0
    */
-  def contains(left: Column, right: Column): Column = withExpr {
-    UnresolvedFunction(Seq("contains"), Seq(left.expr, right.expr), isDistinct 
= false)
-  }
+  def contains(left: Column, right: Column): Column = 
call_function("contains", left, right)
 
   /**
    * Returns the `n`-th input, e.g., returns `input2` when `n` is 2.
@@ -5167,9 +5146,7 @@ object functions {
    * @group datetime_funcs
    * @since 3.5.0
    */
-  def extract(field: Column, source: Column): Column = withExpr {
-    UnresolvedFunction("extract", Seq(field.expr, source.expr), isDistinct = 
false)
-  }
+  def extract(field: Column, source: Column): Column = 
call_function("extract", field, source)
 
   /**
    * Extracts a part of the date/timestamp or interval source.
@@ -5181,9 +5158,7 @@ object functions {
    * @group datetime_funcs
    * @since 3.5.0
    */
-  def date_part(field: Column, source: Column): Column = withExpr {
-    UnresolvedFunction("date_part", Seq(field.expr, source.expr), isDistinct = 
false)
-  }
+  def date_part(field: Column, source: Column): Column = 
call_function("date_part", field, source)
 
   /**
    * Extracts a part of the date/timestamp or interval source.
@@ -5195,9 +5170,7 @@ object functions {
    * @group datetime_funcs
    * @since 3.5.0
    */
-  def datepart(field: Column, source: Column): Column = withExpr {
-    UnresolvedFunction("datepart", Seq(field.expr, source.expr), isDistinct = 
false)
-  }
+  def datepart(field: Column, source: Column): Column = 
call_function("datepart", field, source)
 
   /**
    * Returns the last day of the month which the given date belongs to.
@@ -8363,9 +8336,9 @@ object functions {
    * @since 1.5.0
    */
   @scala.annotation.varargs
-  @deprecated("Use call_udf")
+  @deprecated("Use call_function")
   def callUDF(udfName: String, cols: Column*): Column =
-    call_udf(udfName, cols: _*)
+    call_function(udfName, cols: _*)
 
   /**
    * Call an user-defined function.
@@ -8383,9 +8356,20 @@ object functions {
    * @since 3.2.0
    */
   @scala.annotation.varargs
-  def call_udf(udfName: String, cols: Column*): Column = withExpr {
-    UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false)
-  }
+  @deprecated("Use call_function")
+  def call_udf(udfName: String, cols: Column*): Column =
+    call_function(udfName, cols: _*)
+
+  /**
+   * Call a builtin or temp function.
+   *
+   * @param funcName function name
+   * @param cols the expression parameters of function
+   * @since 3.5.0
+   */
+  @scala.annotation.varargs
+  def call_function(funcName: String, cols: Column*): Column =
+    withExpr { UnresolvedFunction(funcName, cols.map(_.expr), false) }
 
   /**
    * Unwrap UDT data type column into its underlying type.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index c28ee3d8483..9781a8e3ff4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -72,7 +72,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
       "countDistinct", "count_distinct", // equivalent to count(distinct foo)
       "sum_distinct", // equivalent to sum(distinct foo)
       "typedLit", "typedlit", // Scala only
-      "udaf", "udf" // create function statement in sql
+      "udaf", "udf", // create function statement in sql
+      "call_function" // moot in SQL as you just call the function directly
     )
 
     val excludedSqlFunctions = Set.empty[String]
@@ -5914,6 +5915,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
       parameters = Map.empty
     )
   }
+
+  test("call_function") {
+    checkAnswer(testData2.select(call_function("avg", $"a")), testData2.selectExpr("avg(a)"))
+  }
 }
 
 object DataFrameFunctionsSuite {

