This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 7ecdad5c59c [SPARK-43995][SPARK-43996][CONNECT] Add support for 
UDFRegistration to the Connect Scala Client
7ecdad5c59c is described below

commit 7ecdad5c59ce2eecd4686effeb10819a6d784844
Author: vicennial <venkata.gud...@databricks.com>
AuthorDate: Fri Jul 14 10:52:12 2023 +0900

    [SPARK-43995][SPARK-43996][CONNECT] Add support for UDFRegistration to the 
Connect Scala Client
    
    ### What changes were proposed in this pull request?
    
    This PR adds support for registering a Scala UDF from the Scala/JVM client.
    
    The following APIs are implemented in `UDFRegistration`:
    
    - `def register(name: String, udf: UserDefinedFunction): 
UserDefinedFunction`
    - `def register[RT: TypeTag, A1: TypeTag ...](name: String, func: (A1, ...) 
=> RT): UserDefinedFunction` for 0 to 22 arguments.
    
    The following API is implemented in `functions`:
    
    - `def call_udf(udfName: String, cols: Column*): Column`
    
    Note: This PR is stacked on https://github.com/apache/spark/pull/41959.
    
    ### Why are the changes needed?
    
    To reach parity with classic Spark.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. `spark.udf.register()` is added, as shown below:
    ```scala
    class A(x: Int) { def get = x * 100 }
    val myUdf = udf((x: Int) => new A(x).get)
    spark.udf.register("dummyUdf", myUdf)
    spark.sql("select dummyUdf(id) from range(5)").as[Long].collect()
    ```
    The output:
    ```scala
    Array[Long] = Array(0L, 100L, 200L, 300L, 400L)
    ```
    
    ### How was this patch tested?
    
    New tests in `ReplE2ESuite`.
    
    Closes #41953 from vicennial/SPARK-43995.
    
    Authored-by: vicennial <venkata.gud...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../scala/org/apache/spark/sql/SparkSession.scala  |   31 +
 .../org/apache/spark/sql/UDFRegistration.scala     | 1028 ++++++++++++++++++++
 .../sql/expressions/UserDefinedFunction.scala      |   10 +
 .../scala/org/apache/spark/sql/functions.scala     |   17 +
 .../spark/sql/application/ReplE2ESuite.scala       |   31 +
 .../CheckConnectJvmClientCompatibility.scala       |    1 -
 .../sql/connect/planner/SparkConnectPlanner.scala  |   23 +-
 7 files changed, 1139 insertions(+), 2 deletions(-)

diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index c27f0f32e0d..fb9959c9942 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -417,6 +417,30 @@ class SparkSession private[sql] (
     range(start, end, step, Option(numPartitions))
   }
 
+  /**
+   * A collection of methods for registering user-defined functions (UDF). Scala closures of 0
+   * to 22 arguments and existing `UserDefinedFunction`s can be registered.
+   *
+   * The following example registers a Scala closure as UDF:
+   * {{{
+   *   sparkSession.udf.register("myUDF", (arg1: Int, arg2: String) => arg2 + arg1)
+   * }}}
+   *
+   * The following example registers a UDF created with `functions.udf`:
+   * {{{
+   *   val myUdf = functions.udf((x: Int) => x * 100)
+   *   sparkSession.udf.register("myUDF", myUdf)
+   * }}}
+   *
+   * @note
+   *   The user-defined functions must be deterministic. Due to optimization, duplicate
+   *   invocations may be eliminated or the function may even be invoked more times than it is
+   *   present in the query.
+   *
+   * @since 3.5.0
+   */
+  lazy val udf: UDFRegistration = new UDFRegistration(this)
+
   // scalastyle:off
   // Disable style checker so "implicits" object can start with lowercase i
   /**
@@ -525,6 +549,20 @@ class SparkSession private[sql] (
     client.execute(plan).asScala.toSeq
   }
 
+  /**
+   * Sends a `RegisterFunction` command carrying `udf` to the server.
+   *
+   * The response stream is drained before returning: iterating it waits for the command to
+   * finish, so the function is registered before later queries reference it, and any error the
+   * server reports for the command surfaces to the caller instead of being silently dropped.
+   */
+  private[sql] def registerUdf(udf: proto.CommonInlineUserDefinedFunction): Unit = {
+    val command = proto.Command.newBuilder().setRegisterFunction(udf).build()
+    val plan = proto.Plan.newBuilder().setCommand(command).build()
+
+    client.execute(plan).asScala.foreach(_ => ())
+  }
+
   @DeveloperApi
   def execute(extension: com.google.protobuf.Any): Unit = {
     val command = proto.Command.newBuilder().setExtension(extension).build()
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
new file mode 100644
index 00000000000..426709b8f18
--- /dev/null
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -0,0 +1,1028 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.reflect.runtime.universe.{typeTag, TypeTag}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.expressions.{ScalarUserDefinedFunction, 
UserDefinedFunction}
+
+/**
+ * Functions for registering user-defined functions. Use `SparkSession.udf` to access this:
+ *
+ * {{{
+ *   spark.udf
+ * }}}
+ *
+ * @since 3.5.0
+ */
+class UDFRegistration(session: SparkSession) extends Logging {
+
+  /**
+   * Registers a user-defined function (UDF), for a UDF that's already defined using the Dataset
+   * API (i.e. of type UserDefinedFunction). To change a UDF to nondeterministic, call the API
+   * `UserDefinedFunction.asNondeterministic()`. To change a UDF to nonNullable, call the API
+   * `UserDefinedFunction.asNonNullable()`.
+   *
+   * Example:
+   * {{{
+   *   val foo = udf(() => Math.random())
+   *   spark.udf.register("random", foo.asNondeterministic())
+   *
+   *   val bar = udf(() => "bar")
+   *   spark.udf.register("stringLit", bar.asNonNullable())
+   * }}}
+   *
+   * @param name
+   *   the name of the UDF.
+   * @param udf
+   *   the UDF to register; only scalar UDFs are currently supported, others are rejected.
+   * @return
+   *   the registered UDF, renamed to `name`.
+   *
+   * @since 3.5.0
+   */
+  def register(name: String, udf: UserDefinedFunction): UserDefinedFunction = {
+    udf.withName(name) match {
+      case scalarUdf: ScalarUserDefinedFunction =>
+        session.registerUdf(scalarUdf.toProto)
+        scalarUdf
+      case other =>
+        throw new UnsupportedOperationException(
+          "Registering a UDF of type " +
+            s"${other.getClass.getSimpleName} is currently unsupported.")
+    }
+  }
+
+  // scalastyle:off line.size.limit
+
+  /* register 0-22 follow the template printed by this script, except that each body builds the ScalarUserDefinedFunction directly instead of calling functions.udf:
+    (0 to 22).foreach { x =>
+      val params = (1 to x).map(num => s"A$num").mkString(", ")
+      val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _)
+      println(s"""
+        |/**
+        | * Registers a deterministic Scala closure of $x argument${if (x == 1) "" else "s"} as a user-defined function (UDF).
+        | * @tparam RT return type of UDF.
+        | * @since 3.5.0
+        | */
+        |def register[$typeTags](name: String, func: ($params) => RT): UserDefinedFunction = {
+        |  register(name, functions.udf(func))
+        |}""".stripMargin)
+    }
+   */
+
+  /**
+   * Registers a deterministic Scala closure of 0 arguments as a user-defined function (UDF).
+   * @tparam RT
+   *   return type of UDF.
+   * @since 3.5.0
+   */
+  def register[RT: TypeTag](name: String, func: () => RT): UserDefinedFunction = {
+    val udf = ScalarUserDefinedFunction(func, typeTag[RT])
+    register(name, udf)
+  }
+
+  /**
+   * Registers a deterministic Scala closure of 1 argument as a user-defined function (UDF).
+   * @tparam RT
+   *   return type of UDF.
+   * @since 3.5.0
+   */
+  def register[RT: TypeTag, A1: TypeTag](name: String, func: (A1) => RT): UserDefinedFunction = {
+    val udf = ScalarUserDefinedFunction(func, typeTag[RT], typeTag[A1])
+    register(name, udf)
+  }
+
+  /**
+   * Registers a deterministic Scala closure of 2 arguments as a user-defined function (UDF).
+   * @tparam RT
+   *   return type of UDF.
+   * @since 3.5.0
+   */
+  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](
+      name: String,
+      func: (A1, A2) => RT): UserDefinedFunction = {
+    val udf = ScalarUserDefinedFunction(func, typeTag[RT], typeTag[A1], typeTag[A2])
+    register(name, udf)
+  }
+
+  /**
+   * Registers a deterministic Scala closure of 3 arguments as a user-defined function (UDF).
+   * @tparam RT
+   *   return type of UDF.
+   * @since 3.5.0
+   */
+  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](
+      name: String,
+      func: (A1, A2, A3) => RT): UserDefinedFunction = {
+    val udf = ScalarUserDefinedFunction(func, typeTag[RT], typeTag[A1], typeTag[A2], typeTag[A3])
+    register(name, udf)
+  }
+
+  /**
+   * Registers a deterministic Scala closure of 4 arguments as a user-defined function (UDF).
+   * @tparam RT
+   *   return type of UDF.
+   * @since 3.5.0
+   */
+  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](
+      name: String,
+      func: (A1, A2, A3, A4) => RT): UserDefinedFunction = {
+    val udf = ScalarUserDefinedFunction(
+      func,
+      typeTag[RT],
+      typeTag[A1],
+      typeTag[A2],
+      typeTag[A3],
+      typeTag[A4])
+    register(name, udf)
+  }
+
+  /**
+   * Registers a deterministic Scala closure of 5 arguments as a user-defined function (UDF).
+   * @tparam RT
+   *   return type of UDF.
+   * @since 3.5.0
+   */
+  def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](
+      name: String,
+      func: (A1, A2, A3, A4, A5) => RT): UserDefinedFunction = {
+    val udf = ScalarUserDefinedFunction(
+      func,
+      typeTag[RT],
+      typeTag[A1],
+      typeTag[A2],
+      typeTag[A3],
+      typeTag[A4],
+      typeTag[A5])
+    register(name, udf)
+  }
+
+  /**
+   * Registers a deterministic Scala closure of 6 arguments as a user-defined function (UDF).
+   * @tparam RT
+   *   return type of UDF.
+   * @since 3.5.0
+   */
+  def register[
+      RT: TypeTag,
+      A1: TypeTag,
+      A2: TypeTag,
+      A3: TypeTag,
+      A4: TypeTag,
+      A5: TypeTag,
+      A6: TypeTag](name: String, func: (A1, A2, A3, A4, A5, A6) => RT): UserDefinedFunction = {
+    val udf = ScalarUserDefinedFunction(
+      func,
+      typeTag[RT],
+      typeTag[A1],
+      typeTag[A2],
+      typeTag[A3],
+      typeTag[A4],
+      typeTag[A5],
+      typeTag[A6])
+    register(name, udf)
+  }
+
+  /**
+   * Registers a deterministic Scala closure of 7 arguments as a user-defined function (UDF).
+   * @tparam RT
+   *   return type of UDF.
+   * @since 3.5.0
+   */
+  def register[
+      RT: TypeTag,
+      A1: TypeTag,
+      A2: TypeTag,
+      A3: TypeTag,
+      A4: TypeTag,
+      A5: TypeTag,
+      A6: TypeTag,
+      A7: TypeTag](
+      name: String,
+      func: (A1, A2, A3, A4, A5, A6, A7) => RT): UserDefinedFunction = {
+    val udf = ScalarUserDefinedFunction(
+      func,
+      typeTag[RT],
+      typeTag[A1],
+      typeTag[A2],
+      typeTag[A3],
+      typeTag[A4],
+      typeTag[A5],
+      typeTag[A6],
+      typeTag[A7])
+    register(name, udf)
+  }
+
+  /**
+   * Registers a deterministic Scala closure of 8 arguments as a user-defined function (UDF).
+   * @tparam RT
+   *   return type of UDF.
+   * @since 3.5.0
+   */
+  def register[
+      RT: TypeTag,
+      A1: TypeTag,
+      A2: TypeTag,
+      A3: TypeTag,
+      A4: TypeTag,
+      A5: TypeTag,
+      A6: TypeTag,
+      A7: TypeTag,
+      A8: TypeTag](
+      name: String,
+      func: (A1, A2, A3, A4, A5, A6, A7, A8) => RT): UserDefinedFunction = {
+    val udf = ScalarUserDefinedFunction(
+      func,
+      typeTag[RT],
+      typeTag[A1],
+      typeTag[A2],
+      typeTag[A3],
+      typeTag[A4],
+      typeTag[A5],
+      typeTag[A6],
+      typeTag[A7],
+      typeTag[A8])
+    register(name, udf)
+  }
+
+  /**
+   * Registers a deterministic Scala closure of 9 arguments as a user-defined function (UDF).
+   * @tparam RT
+   *   return type of UDF.
+   * @since 3.5.0
+   */
+  def register[
+      RT: TypeTag,
+      A1: TypeTag,
+      A2: TypeTag,
+      A3: TypeTag,
+      A4: TypeTag,
+      A5: TypeTag,
+      A6: TypeTag,
+      A7: TypeTag,
+      A8: TypeTag,
+      A9: TypeTag](
+      name: String,
+      func: (A1, A2, A3, A4, A5, A6, A7, A8, A9) => RT): UserDefinedFunction = {
+    val udf = ScalarUserDefinedFunction(
+      func,
+      typeTag[RT],
+      typeTag[A1],
+      typeTag[A2],
+      typeTag[A3],
+      typeTag[A4],
+      typeTag[A5],
+      typeTag[A6],
+      typeTag[A7],
+      typeTag[A8],
+      typeTag[A9])
+    register(name, udf)
+  }
+
+  /**
+   * Registers a deterministic Scala closure of 10 arguments as a user-defined function (UDF).
+   * @tparam RT
+   *   return type of UDF.
+   * @since 3.5.0
+   */
+  def register[
+      RT: TypeTag,
+      A1: TypeTag,
+      A2: TypeTag,
+      A3: TypeTag,
+      A4: TypeTag,
+      A5: TypeTag,
+      A6: TypeTag,
+      A7: TypeTag,
+      A8: TypeTag,
+      A9: TypeTag,
+      A10: TypeTag](
+      name: String,
+      func: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10) => RT): UserDefinedFunction = {
+    val udf = ScalarUserDefinedFunction(
+      func,
+      typeTag[RT],
+      typeTag[A1],
+      typeTag[A2],
+      typeTag[A3],
+      typeTag[A4],
+      typeTag[A5],
+      typeTag[A6],
+      typeTag[A7],
+      typeTag[A8],
+      typeTag[A9],
+      typeTag[A10])
+    register(name, udf)
+  }
+
+  /**
+   * Registers a deterministic Scala closure of 11 arguments as a user-defined function (UDF).
+   * @tparam RT
+   *   return type of UDF.
+   * @since 3.5.0
+   */
+  def register[
+      RT: TypeTag,
+      A1: TypeTag,
+      A2: TypeTag,
+      A3: TypeTag,
+      A4: TypeTag,
+      A5: TypeTag,
+      A6: TypeTag,
+      A7: TypeTag,
+      A8: TypeTag,
+      A9: TypeTag,
+      A10: TypeTag,
+      A11: TypeTag](
+      name: String,
+      func: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11) => RT): UserDefinedFunction = {
+    val udf = ScalarUserDefinedFunction(
+      func,
+      typeTag[RT],
+      typeTag[A1],
+      typeTag[A2],
+      typeTag[A3],
+      typeTag[A4],
+      typeTag[A5],
+      typeTag[A6],
+      typeTag[A7],
+      typeTag[A8],
+      typeTag[A9],
+      typeTag[A10],
+      typeTag[A11])
+    register(name, udf)
+  }
+
+  /**
+   * Registers a deterministic Scala closure of 12 arguments as a user-defined function (UDF).
+   * @tparam RT
+   *   return type of UDF.
+   * @since 3.5.0
+   */
+  def register[
+      RT: TypeTag,
+      A1: TypeTag,
+      A2: TypeTag,
+      A3: TypeTag,
+      A4: TypeTag,
+      A5: TypeTag,
+      A6: TypeTag,
+      A7: TypeTag,
+      A8: TypeTag,
+      A9: TypeTag,
+      A10: TypeTag,
+      A11: TypeTag,
+      A12: TypeTag](
+      name: String,
+      func: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12) => RT): UserDefinedFunction = {
+    val udf = ScalarUserDefinedFunction(
+      func,
+      typeTag[RT],
+      typeTag[A1],
+      typeTag[A2],
+      typeTag[A3],
+      typeTag[A4],
+      typeTag[A5],
+      typeTag[A6],
+      typeTag[A7],
+      typeTag[A8],
+      typeTag[A9],
+      typeTag[A10],
+      typeTag[A11],
+      typeTag[A12])
+    register(name, udf)
+  }
+
+  /**
+   * Registers a deterministic Scala closure of 13 arguments as a user-defined function (UDF).
+   * @tparam RT
+   *   return type of UDF.
+   * @since 3.5.0
+   */
+  def register[
+      RT: TypeTag,
+      A1: TypeTag,
+      A2: TypeTag,
+      A3: TypeTag,
+      A4: TypeTag,
+      A5: TypeTag,
+      A6: TypeTag,
+      A7: TypeTag,
+      A8: TypeTag,
+      A9: TypeTag,
+      A10: TypeTag,
+      A11: TypeTag,
+      A12: TypeTag,
+      A13: TypeTag](
+      name: String,
+      func: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13) => RT)
+      : UserDefinedFunction = {
+    val udf = ScalarUserDefinedFunction(
+      func,
+      typeTag[RT],
+      typeTag[A1],
+      typeTag[A2],
+      typeTag[A3],
+      typeTag[A4],
+      typeTag[A5],
+      typeTag[A6],
+      typeTag[A7],
+      typeTag[A8],
+      typeTag[A9],
+      typeTag[A10],
+      typeTag[A11],
+      typeTag[A12],
+      typeTag[A13])
+    register(name, udf)
+  }
+
+  /**
+   * Registers a deterministic Scala closure of 14 arguments as a user-defined function (UDF).
+   * @tparam RT
+   *   return type of UDF.
+   * @since 3.5.0
+   */
+  def register[
+      RT: TypeTag,
+      A1: TypeTag,
+      A2: TypeTag,
+      A3: TypeTag,
+      A4: TypeTag,
+      A5: TypeTag,
+      A6: TypeTag,
+      A7: TypeTag,
+      A8: TypeTag,
+      A9: TypeTag,
+      A10: TypeTag,
+      A11: TypeTag,
+      A12: TypeTag,
+      A13: TypeTag,
+      A14: TypeTag](
+      name: String,
+      func: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14) => RT)
+      : UserDefinedFunction = {
+    val udf = ScalarUserDefinedFunction(
+      func,
+      typeTag[RT],
+      typeTag[A1],
+      typeTag[A2],
+      typeTag[A3],
+      typeTag[A4],
+      typeTag[A5],
+      typeTag[A6],
+      typeTag[A7],
+      typeTag[A8],
+      typeTag[A9],
+      typeTag[A10],
+      typeTag[A11],
+      typeTag[A12],
+      typeTag[A13],
+      typeTag[A14])
+    register(name, udf)
+  }
+
+  /**
+   * Registers a deterministic Scala closure of 15 arguments as a user-defined function (UDF).
+   * @tparam RT
+   *   return type of UDF.
+   * @since 3.5.0
+   */
+  def register[
+      RT: TypeTag,
+      A1: TypeTag,
+      A2: TypeTag,
+      A3: TypeTag,
+      A4: TypeTag,
+      A5: TypeTag,
+      A6: TypeTag,
+      A7: TypeTag,
+      A8: TypeTag,
+      A9: TypeTag,
+      A10: TypeTag,
+      A11: TypeTag,
+      A12: TypeTag,
+      A13: TypeTag,
+      A14: TypeTag,
+      A15: TypeTag](
+      name: String,
+      func: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15) => RT)
+      : UserDefinedFunction = {
+    val udf = ScalarUserDefinedFunction(
+      func,
+      typeTag[RT],
+      typeTag[A1],
+      typeTag[A2],
+      typeTag[A3],
+      typeTag[A4],
+      typeTag[A5],
+      typeTag[A6],
+      typeTag[A7],
+      typeTag[A8],
+      typeTag[A9],
+      typeTag[A10],
+      typeTag[A11],
+      typeTag[A12],
+      typeTag[A13],
+      typeTag[A14],
+      typeTag[A15])
+    register(name, udf)
+  }
+
+  /**
+   * Registers a deterministic Scala closure of 16 arguments as a user-defined function (UDF).
+   * @tparam RT
+   *   return type of UDF.
+   * @since 3.5.0
+   */
+  def register[
+      RT: TypeTag,
+      A1: TypeTag,
+      A2: TypeTag,
+      A3: TypeTag,
+      A4: TypeTag,
+      A5: TypeTag,
+      A6: TypeTag,
+      A7: TypeTag,
+      A8: TypeTag,
+      A9: TypeTag,
+      A10: TypeTag,
+      A11: TypeTag,
+      A12: TypeTag,
+      A13: TypeTag,
+      A14: TypeTag,
+      A15: TypeTag,
+      A16: TypeTag](
+      name: String,
+      func: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16) => RT)
+      : UserDefinedFunction = {
+    val udf = ScalarUserDefinedFunction(
+      func,
+      typeTag[RT],
+      typeTag[A1],
+      typeTag[A2],
+      typeTag[A3],
+      typeTag[A4],
+      typeTag[A5],
+      typeTag[A6],
+      typeTag[A7],
+      typeTag[A8],
+      typeTag[A9],
+      typeTag[A10],
+      typeTag[A11],
+      typeTag[A12],
+      typeTag[A13],
+      typeTag[A14],
+      typeTag[A15],
+      typeTag[A16])
+    register(name, udf)
+  }
+
+  /**
+   * Registers a deterministic Scala closure of 17 arguments as a user-defined function (UDF).
+   * @tparam RT
+   *   return type of UDF.
+   * @since 3.5.0
+   */
+  def register[
+      RT: TypeTag,
+      A1: TypeTag,
+      A2: TypeTag,
+      A3: TypeTag,
+      A4: TypeTag,
+      A5: TypeTag,
+      A6: TypeTag,
+      A7: TypeTag,
+      A8: TypeTag,
+      A9: TypeTag,
+      A10: TypeTag,
+      A11: TypeTag,
+      A12: TypeTag,
+      A13: TypeTag,
+      A14: TypeTag,
+      A15: TypeTag,
+      A16: TypeTag,
+      A17: TypeTag](
+      name: String,
+      func: (A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17) => RT)
+      : UserDefinedFunction = {
+    val udf = ScalarUserDefinedFunction(
+      func,
+      typeTag[RT],
+      typeTag[A1],
+      typeTag[A2],
+      typeTag[A3],
+      typeTag[A4],
+      typeTag[A5],
+      typeTag[A6],
+      typeTag[A7],
+      typeTag[A8],
+      typeTag[A9],
+      typeTag[A10],
+      typeTag[A11],
+      typeTag[A12],
+      typeTag[A13],
+      typeTag[A14],
+      typeTag[A15],
+      typeTag[A16],
+      typeTag[A17])
+    register(name, udf)
+  }
+
+  /**
+   * Registers a deterministic Scala closure of 18 arguments as a user-defined function (UDF).
+   * @tparam RT
+   *   return type of UDF.
+   * @since 3.5.0
+   */
+  def register[
+      RT: TypeTag,
+      A1: TypeTag,
+      A2: TypeTag,
+      A3: TypeTag,
+      A4: TypeTag,
+      A5: TypeTag,
+      A6: TypeTag,
+      A7: TypeTag,
+      A8: TypeTag,
+      A9: TypeTag,
+      A10: TypeTag,
+      A11: TypeTag,
+      A12: TypeTag,
+      A13: TypeTag,
+      A14: TypeTag,
+      A15: TypeTag,
+      A16: TypeTag,
+      A17: TypeTag,
+      A18: TypeTag](
+      name: String,
+      func: (
+          A1,
+          A2,
+          A3,
+          A4,
+          A5,
+          A6,
+          A7,
+          A8,
+          A9,
+          A10,
+          A11,
+          A12,
+          A13,
+          A14,
+          A15,
+          A16,
+          A17,
+          A18) => RT): UserDefinedFunction = {
+    val udf = ScalarUserDefinedFunction(
+      func,
+      typeTag[RT],
+      typeTag[A1],
+      typeTag[A2],
+      typeTag[A3],
+      typeTag[A4],
+      typeTag[A5],
+      typeTag[A6],
+      typeTag[A7],
+      typeTag[A8],
+      typeTag[A9],
+      typeTag[A10],
+      typeTag[A11],
+      typeTag[A12],
+      typeTag[A13],
+      typeTag[A14],
+      typeTag[A15],
+      typeTag[A16],
+      typeTag[A17],
+      typeTag[A18])
+    register(name, udf)
+  }
+
+  /**
+   * Registers a deterministic Scala closure of 19 arguments as a user-defined function (UDF).
+   * @tparam RT
+   *   return type of UDF.
+   * @since 3.5.0
+   */
+  def register[
+      RT: TypeTag,
+      A1: TypeTag,
+      A2: TypeTag,
+      A3: TypeTag,
+      A4: TypeTag,
+      A5: TypeTag,
+      A6: TypeTag,
+      A7: TypeTag,
+      A8: TypeTag,
+      A9: TypeTag,
+      A10: TypeTag,
+      A11: TypeTag,
+      A12: TypeTag,
+      A13: TypeTag,
+      A14: TypeTag,
+      A15: TypeTag,
+      A16: TypeTag,
+      A17: TypeTag,
+      A18: TypeTag,
+      A19: TypeTag](
+      name: String,
+      func: (
+          A1,
+          A2,
+          A3,
+          A4,
+          A5,
+          A6,
+          A7,
+          A8,
+          A9,
+          A10,
+          A11,
+          A12,
+          A13,
+          A14,
+          A15,
+          A16,
+          A17,
+          A18,
+          A19) => RT): UserDefinedFunction = {
+    val udf = ScalarUserDefinedFunction(
+      func,
+      typeTag[RT],
+      typeTag[A1],
+      typeTag[A2],
+      typeTag[A3],
+      typeTag[A4],
+      typeTag[A5],
+      typeTag[A6],
+      typeTag[A7],
+      typeTag[A8],
+      typeTag[A9],
+      typeTag[A10],
+      typeTag[A11],
+      typeTag[A12],
+      typeTag[A13],
+      typeTag[A14],
+      typeTag[A15],
+      typeTag[A16],
+      typeTag[A17],
+      typeTag[A18],
+      typeTag[A19])
+    register(name, udf)
+  }
+
+  /**
+   * Registers a deterministic Scala closure of 20 arguments as a user-defined function (UDF).
+   * @tparam RT
+   *   return type of UDF.
+   * @since 3.5.0
+   */
+  def register[
+      RT: TypeTag,
+      A1: TypeTag,
+      A2: TypeTag,
+      A3: TypeTag,
+      A4: TypeTag,
+      A5: TypeTag,
+      A6: TypeTag,
+      A7: TypeTag,
+      A8: TypeTag,
+      A9: TypeTag,
+      A10: TypeTag,
+      A11: TypeTag,
+      A12: TypeTag,
+      A13: TypeTag,
+      A14: TypeTag,
+      A15: TypeTag,
+      A16: TypeTag,
+      A17: TypeTag,
+      A18: TypeTag,
+      A19: TypeTag,
+      A20: TypeTag](
+      name: String,
+      func: (
+          A1,
+          A2,
+          A3,
+          A4,
+          A5,
+          A6,
+          A7,
+          A8,
+          A9,
+          A10,
+          A11,
+          A12,
+          A13,
+          A14,
+          A15,
+          A16,
+          A17,
+          A18,
+          A19,
+          A20) => RT): UserDefinedFunction = {
+    val udf = ScalarUserDefinedFunction(
+      func,
+      typeTag[RT],
+      typeTag[A1],
+      typeTag[A2],
+      typeTag[A3],
+      typeTag[A4],
+      typeTag[A5],
+      typeTag[A6],
+      typeTag[A7],
+      typeTag[A8],
+      typeTag[A9],
+      typeTag[A10],
+      typeTag[A11],
+      typeTag[A12],
+      typeTag[A13],
+      typeTag[A14],
+      typeTag[A15],
+      typeTag[A16],
+      typeTag[A17],
+      typeTag[A18],
+      typeTag[A19],
+      typeTag[A20])
+    register(name, udf)
+  }
+
+  /**
+   * Registers a deterministic Scala closure of 21 arguments as a user-defined function (UDF).
+   * @tparam RT
+   *   return type of UDF.
+   * @since 3.5.0
+   */
+  def register[
+      RT: TypeTag,
+      A1: TypeTag,
+      A2: TypeTag,
+      A3: TypeTag,
+      A4: TypeTag,
+      A5: TypeTag,
+      A6: TypeTag,
+      A7: TypeTag,
+      A8: TypeTag,
+      A9: TypeTag,
+      A10: TypeTag,
+      A11: TypeTag,
+      A12: TypeTag,
+      A13: TypeTag,
+      A14: TypeTag,
+      A15: TypeTag,
+      A16: TypeTag,
+      A17: TypeTag,
+      A18: TypeTag,
+      A19: TypeTag,
+      A20: TypeTag,
+      A21: TypeTag](
+      name: String,
+      func: (
+          A1,
+          A2,
+          A3,
+          A4,
+          A5,
+          A6,
+          A7,
+          A8,
+          A9,
+          A10,
+          A11,
+          A12,
+          A13,
+          A14,
+          A15,
+          A16,
+          A17,
+          A18,
+          A19,
+          A20,
+          A21) => RT): UserDefinedFunction = {
+    val udf = ScalarUserDefinedFunction(
+      func,
+      typeTag[RT],
+      typeTag[A1],
+      typeTag[A2],
+      typeTag[A3],
+      typeTag[A4],
+      typeTag[A5],
+      typeTag[A6],
+      typeTag[A7],
+      typeTag[A8],
+      typeTag[A9],
+      typeTag[A10],
+      typeTag[A11],
+      typeTag[A12],
+      typeTag[A13],
+      typeTag[A14],
+      typeTag[A15],
+      typeTag[A16],
+      typeTag[A17],
+      typeTag[A18],
+      typeTag[A19],
+      typeTag[A20],
+      typeTag[A21])
+    register(name, udf)
+  }
+
+  /**
+   * Registers a deterministic Scala closure of 22 arguments as a user-defined function (UDF).
+   * @tparam RT
+   *   return type of UDF.
+   * @since 3.5.0
+   */
+  def register[
+      RT: TypeTag,
+      A1: TypeTag,
+      A2: TypeTag,
+      A3: TypeTag,
+      A4: TypeTag,
+      A5: TypeTag,
+      A6: TypeTag,
+      A7: TypeTag,
+      A8: TypeTag,
+      A9: TypeTag,
+      A10: TypeTag,
+      A11: TypeTag,
+      A12: TypeTag,
+      A13: TypeTag,
+      A14: TypeTag,
+      A15: TypeTag,
+      A16: TypeTag,
+      A17: TypeTag,
+      A18: TypeTag,
+      A19: TypeTag,
+      A20: TypeTag,
+      A21: TypeTag,
+      A22: TypeTag](
+      name: String,
+      func: (
+          A1,
+          A2,
+          A3,
+          A4,
+          A5,
+          A6,
+          A7,
+          A8,
+          A9,
+          A10,
+          A11,
+          A12,
+          A13,
+          A14,
+          A15,
+          A16,
+          A17,
+          A18,
+          A19,
+          A20,
+          A21,
+          A22) => RT): UserDefinedFunction = {
+    val udf = ScalarUserDefinedFunction(
+      func,
+      typeTag[RT],
+      typeTag[A1],
+      typeTag[A2],
+      typeTag[A3],
+      typeTag[A4],
+      typeTag[A5],
+      typeTag[A6],
+      typeTag[A7],
+      typeTag[A8],
+      typeTag[A9],
+      typeTag[A10],
+      typeTag[A11],
+      typeTag[A12],
+      typeTag[A13],
+      typeTag[A14],
+      typeTag[A15],
+      typeTag[A16],
+      typeTag[A17],
+      typeTag[A18],
+      typeTag[A19],
+      typeTag[A20],
+      typeTag[A21],
+      typeTag[A22])
+    register(name, udf)
+  }
+  // scalastyle:on line.size.limit
+}
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
index 7bce4b5b31a..18aef8a2e4c 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -130,6 +130,19 @@ case class ScalarUserDefinedFunction private (
   override def asNonNullable(): ScalarUserDefinedFunction = copy(nullable = false)
 
   override def asNondeterministic(): ScalarUserDefinedFunction = copy(deterministic = false)
+
+  /**
+   * Builds the Connect `CommonInlineUserDefinedFunction` message for this UDF: the Scala UDF
+   * payload, the deterministic flag, and the function name when one has been set.
+   */
+  def toProto: proto.CommonInlineUserDefinedFunction = {
+    val message = proto.CommonInlineUserDefinedFunction
+      .newBuilder()
+      .setScalarScalaUdf(udf)
+      .setDeterministic(deterministic)
+    name.foreach(n => message.setFunctionName(n))
+    message.build()
+  }
 }
 
 object ScalarUserDefinedFunction {
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
index b0ae4c9752a..17d1cdca350 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
@@ -7905,6 +7905,27 @@ object functions {
   }
   // scalastyle:off line.size.limit
 
+  /**
+   * Call a user-defined function registered under `udfName`, e.g. via `spark.udf.register`.
+   * Example:
+   * {{{
+   *  import spark.implicits._
+   *
+   *  val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")
+   *  spark.udf.register("simpleUDF", (v: Int) => v * v)
+   *  df.select($"id", call_udf("simpleUDF", $"value"))
+   * }}}
+   *
+   * @param udfName
+   *   name under which the function was registered.
+   * @param cols
+   *   columns passed to the function as its arguments.
+   * @group udf_funcs
+   * @since 3.5.0
+   */
+  @scala.annotation.varargs
+  def call_udf(udfName: String, cols: Column*): Column = call_function(udfName, cols: _*)
+
   /**
    * Call a builtin or temp function.
    *
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
index 40841aa3b39..58758a13840 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
@@ -206,4 +206,35 @@ class ReplE2ESuite extends RemoteSparkSession with 
BeforeAndAfterEach {
     assertContains("Array[Int] = Array(2, 2, 2, 2, 2)", output)
     // scalastyle:on classforname line.size.limit
   }
+
+  test("UDF Registration") {
+    // Builds a UDF with `udf(...)` whose closure uses a class defined in the REPL session
+    // (`A`), registers it by name, and invokes it from SQL over range(5).
+    val input = """
+        |class A(x: Int) { def get = x * 100 }
+        |val myUdf = udf((x: Int) => new A(x).get)
+        |spark.udf.register("dummyUdf", myUdf)
+        |spark.sql("select dummyUdf(id) from range(5)").as[Long].collect()
+      """.stripMargin
+    val output = runCommandsInShell(input)
+    assertContains("Array[Long] = Array(0L, 100L, 200L, 300L, 400L)", output)
+  }
+
+  test("UDF closure registration") {
+    // Registers a Scala closure directly with `spark.udf.register` (no `udf(...)` wrapper);
+    // the closure uses a REPL-defined class, and the UDF is then called from SQL.
+    val input = """
+        |class A(x: Int) { def get = x * 15 }
+        |spark.udf.register("directUdf", (x: Int) => new A(x).get)
+        |spark.sql("select directUdf(id) from range(5)").as[Long].collect()
+      """.stripMargin
+    val output = runCommandsInShell(input)
+    assertContains("Array[Long] = Array(0L, 15L, 30L, 45L, 60L)", output)
+  }
+
+  test("call_udf") {
+    // Registers a squaring UDF by name and invokes it through `functions.call_udf` from the
+    // DataFrame API; the expected rows pair each id with its squared value.
+    val input = """
+        |val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value")
+        |spark.udf.register("simpleUDF", (v: Int) => v * v)
+        |df.select($"id", call_udf("simpleUDF", $"value")).collect()
+      """.stripMargin
+    val output = runCommandsInShell(input)
+    assertContains("Array[org.apache.spark.sql.Row] = Array([id1,1], [id2,16], 
[id3,25])", output)
+  }
 }
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 921381caf53..130d22842b3 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -154,7 +154,6 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[MissingClassProblem](
         "org.apache.spark.sql.SparkSessionExtensionsProvider"),
       
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UDTFRegistration"),
-      
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UDFRegistration"),
       
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UDFRegistration$"),
 
       // DataFrame Reader & Writer
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 48d5e7509c3..e0bee824195 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -70,7 +70,7 @@ import 
org.apache.spark.sql.execution.python.{PythonForeachWriter, UserDefinedPy
 import org.apache.spark.sql.execution.stat.StatFunctions
 import 
org.apache.spark.sql.execution.streaming.GroupStateImpl.groupStateTimeoutFromString
 import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper
-import org.apache.spark.sql.expressions.ReduceAggregator
+import org.apache.spark.sql.expressions.{ReduceAggregator, 
SparkUserDefinedFunction}
 import org.apache.spark.sql.internal.{CatalogImpl, TypedAggUtils}
 import org.apache.spark.sql.protobuf.{CatalystDataToProtobuf, 
ProtobufDataToCatalyst}
 import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, 
StreamingQuery, StreamingQueryListener, StreamingQueryProgress, Trigger}
@@ -1487,6 +1487,20 @@ class SparkConnectPlanner(val sessionHolder: 
SessionHolder) extends Logging {
       udfDeterministic = fun.getDeterministic)
   }
 
+  /**
+   * Rebuilds a `SparkUserDefinedFunction` from its proto representation.
+   *
+   * The function and its input/output encoders come from `unpackUdf(fun)`. The result type
+   * and nullability come from the scalar Scala UDF payload, and determinism and name come from
+   * the enclosing message.
+   *
+   * @param fun inline UDF message; its `getScalarScalaUdf` payload is read unconditionally, so
+   *            it is assumed to carry a scalar Scala UDF
+   * @return the reconstructed UDF
+   */
+  private def transformScalarScalaFunction(
+      fun: proto.CommonInlineUserDefinedFunction): SparkUserDefinedFunction = {
+    val udf = fun.getScalarScalaUdf
+    val udfPacket = unpackUdf(fun)
+    SparkUserDefinedFunction(
+      f = udfPacket.function,
+      dataType = transformDataType(udf.getOutputType),
+      // An input encoder that cannot be turned into an ExpressionEncoder becomes None instead
+      // of failing the whole transformation.
+      inputEncoders = udfPacket.inputEncoders.map(e => 
Try(ExpressionEncoder(e)).toOption),
+      outputEncoder = Option(ExpressionEncoder(udfPacket.outputEncoder)),
+      // NOTE(review): protobuf-java string getters return "" (not null) for an unset field, so
+      // this is presumably always Some(...), even for unnamed UDFs. Confirm whether an empty
+      // name should map to None.
+      name = Option(fun.getFunctionName),
+      nullable = udf.getNullable,
+      deterministic = fun.getDeterministic)
+  }
+
   /**
    * Translates a Python user-defined function from proto to the Catalyst 
expression.
    *
@@ -2415,6 +2429,8 @@ class SparkConnectPlanner(val sessionHolder: 
SessionHolder) extends Logging {
         handleRegisterPythonUDF(fun)
       case proto.CommonInlineUserDefinedFunction.FunctionCase.JAVA_UDF =>
         handleRegisterJavaUDF(fun)
+      case proto.CommonInlineUserDefinedFunction.FunctionCase.SCALAR_SCALA_UDF 
=>
+        handleRegisterScalarScalaUDF(fun)
       case _ =>
         throw InvalidPlanInput(
           s"Function with ID: ${fun.getFunctionCase.getNumber} is not 
supported")
@@ -2448,6 +2464,11 @@ class SparkConnectPlanner(val sessionHolder: 
SessionHolder) extends Logging {
     }
   }
 
+  /**
+   * Registers a scalar Scala UDF in the session's UDF registry under `fun.getFunctionName`.
+   */
+  private def handleRegisterScalarScalaUDF(fun: 
proto.CommonInlineUserDefinedFunction): Unit = {
+    val udf = transformScalarScalaFunction(fun)
+    // NOTE(review): the name is not validated here; an empty name would presumably be passed
+    // straight to `session.udf.register`. Confirm that callers always set it.
+    session.udf.register(fun.getFunctionName, udf)
+  }
+
   private def handleCommandPlugin(extension: ProtoAny): Unit = {
     SparkConnectPluginRegistry.commandRegistry
       // Lazily traverse the collection.


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org


Reply via email to