This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8c444f49713 [SPARK-44715][CONNECT] Bring back callUdf and udf function
8c444f49713 is described below

commit 8c444f497137d5abb3a94b576ec0fea55dc18bbc
Author: Herman van Hovell <her...@databricks.com>
AuthorDate: Tue Aug 8 15:41:36 2023 +0200

    [SPARK-44715][CONNECT] Bring back callUdf and udf function
    
    ### What changes were proposed in this pull request?
    This PR adds the `udf` (with a return type) and `callUDF` functions to `functions.scala` for the Spark Connect Scala Client.
    
    ### Why are the changes needed?
    We want the Spark Connect Scala Client to be as compatible as possible with the existing sql/core APIs.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. It adds more exposed functions.
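
    For illustration, a minimal usage sketch of the re-exposed APIs (hypothetical; it assumes a Spark Connect `SparkSession` named `spark`):
    ```scala
    import org.apache.spark.sql.functions.{callUDF, col, udf}
    import org.apache.spark.sql.types.IntegerType

    // Deprecated `udf` variant that takes an explicit return type.
    val plusOne = udf((i: Long) => (i + 1).toInt, IntegerType)
    spark.range(3).select(plusOne(col("id"))).show()

    // `callUDF` delegates to `call_function`; the name is resolved on the
    // server, so a built-in such as `abs` also works in this sketch.
    spark.range(3).select(callUDF("abs", col("id") - 1)).show()
    ```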
    
    ### How was this patch tested?
    Added tests to `UserDefinedFunctionE2ETestSuite` and `FunctionTestSuite`. I have also updated the compatibility checks.
    
    Closes #42387 from hvanhovell/SPARK-44715.
    
    Authored-by: Herman van Hovell <her...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../scala/org/apache/spark/sql/functions.scala     | 40 ++++++++++++++++++++++
 .../org/apache/spark/sql/FunctionTestSuite.scala   |  2 ++
 .../sql/UserDefinedFunctionE2ETestSuite.scala      | 20 +++++++++++
 .../CheckConnectJvmClientCompatibility.scala       |  7 ----
 4 files changed, 62 insertions(+), 7 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
index 89bfc998179..fa8c5782e06 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
@@ -8056,6 +8056,46 @@ object functions {
   }
   // scalastyle:off line.size.limit
 
+  /**
+   * Defines a deterministic user-defined function (UDF) using a Scala closure. For this variant,
+   * the caller must specify the output data type, and there is no automatic input type coercion.
+   * By default the returned UDF is deterministic. To change it to nondeterministic, call the API
+   * `UserDefinedFunction.asNondeterministic()`.
+   *
+   * Note that, although the Scala closure can have primitive-type function argument, it doesn't
+   * work well with null values. Because the Scala closure is passed in as Any type, there is no
+   * type information for the function arguments. Without the type information, Spark may blindly
+   * pass null to the Scala closure with primitive-type argument, and the closure will see the
+   * default value of the Java type for the null argument, e.g. `udf((x: Int) => x, IntegerType)`,
+   * the result is 0 for null input.
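+   * A hypothetical sketch of this caveat (the DataFrame `df` and its nullable integer column
+   * `v` are assumed for illustration):
+   * {{{
+   *   val f = udf((x: Int) => x, IntegerType)
+   *   df.select(f($"v")) // rows where `v` is null yield 0, not null
+   * }}}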
+   *
+   * @param f
+   *   A closure in Scala
+   * @param dataType
+   *   The output data type of the UDF
+   *
+   * @group udf_funcs
+   * @since 3.5.0
+   */
+  @deprecated(
+    "Scala `udf` method with return type parameter is deprecated. " +
+      "Please use Scala `udf` method without return type parameter.",
+    "3.0.0")
+  def udf(f: AnyRef, dataType: DataType): UserDefinedFunction = {
+    ScalarUserDefinedFunction(f, dataType)
+  }
+
+  /**
+   * Call a user-defined function.
+   *
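+   * A hypothetical sketch (assumes a function already registered under the name "simpleUDF"):
+   * {{{
+   *   df.select(callUDF("simpleUDF", $"value"))
+   * }}}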
+   * @group udf_funcs
+   * @since 3.5.0
+   */
+  @scala.annotation.varargs
+  @deprecated("Use call_udf")
+  def callUDF(udfName: String, cols: Column*): Column =
+    call_function(udfName, cols: _*)
+
   /**
    * Call a user-defined function. Example:
    * {{{
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala
index 32004b6bcc1..4a8e108357f 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala
@@ -249,6 +249,8 @@ class FunctionTestSuite extends ConnectFunSuite {
     pbFn.to_protobuf(a, "FakeMessage", "fakeBytes".getBytes(), Map.empty[String, String].asJava),
     pbFn.to_protobuf(a, "FakeMessage", "fakeBytes".getBytes()))
 
+  testEquals("call_udf", callUDF("bob", lit(1)), call_udf("bob", lit(1)))
+
   test("assert_true no message") {
     val e = assert_true(a).expr
     assert(e.hasUnresolvedFunction)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
index 258fa1e7c74..3a931c9a6ba 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
@@ -24,9 +24,11 @@ import java.util.concurrent.atomic.AtomicLong
 import scala.collection.JavaConverters._
 
 import org.apache.spark.api.java.function._
+import org.apache.spark.sql.api.java.UDF2
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveIntEncoder, PrimitiveLongEncoder}
 import org.apache.spark.sql.connect.client.util.QueryTest
 import org.apache.spark.sql.functions.{col, struct, udf}
+import org.apache.spark.sql.types.IntegerType
 
 /**
  * All tests in this class require the client UDFs defined in this test class to be synced with the server.
@@ -250,4 +252,22 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest {
       "b",
       "c")
   }
+
+  test("(deprecated) scala UDF with dataType") {
+    val session: SparkSession = spark
+    import session.implicits._
+    val fn = udf(((i: Long) => (i + 1).toInt), IntegerType)
+    checkDataset(session.range(2).select(fn($"id")).as[Int], 1, 2)
+  }
+
+  test("java UDF") {
+    val session: SparkSession = spark
+    import session.implicits._
+    val fn = udf(
+      new UDF2[Long, Long, Int] {
+        override def call(t1: Long, t2: Long): Int = (t1 + t2 + 1).toInt
+      },
+      IntegerType)
+    checkDataset(session.range(2).select(fn($"id", $"id" + 2)).as[Int], 3, 5)
+  }
 }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 2bf9c41fb2c..d380a1bbb65 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -191,8 +191,6 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.Dataset.javaRDD"),
 
       // functions
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.udf"),
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.callUDF"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.unwrap_udt"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.udaf"),
 
@@ -214,14 +212,11 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.listenerManager"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.experimental"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.udtf"),
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.streams"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.createDataFrame"),
       ProblemFilters.exclude[Problem](
         "org.apache.spark.sql.SparkSession.baseRelationToDataFrame"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.createDataset"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.executeCommand"),
-      // TODO(SPARK-44068): Support positional parameters in Scala connect client
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sql"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.this"),
 
       // SparkSession#implicits
@@ -266,8 +261,6 @@ object CheckConnectJvmClientCompatibility {
         "org.apache.spark.sql.streaming.StreamingQueryException.time"),
 
       // Classes missing from streaming API
-      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.ForeachWriter"),
-      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.GroupState"),
       ProblemFilters.exclude[MissingClassProblem](
         "org.apache.spark.sql.streaming.TestGroupState"),
       ProblemFilters.exclude[MissingClassProblem](

