This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 357bcac823c [SPARK-44388][CONNECT] Fix protobuf cast issue when UDF instance is updated
357bcac823c is described below

commit 357bcac823c0fafbdcb95458327d35e4a492046c
Author: vicennial <venkata.gud...@databricks.com>
AuthorDate: Thu Jul 13 18:55:10 2023 +0900

    [SPARK-44388][CONNECT] Fix protobuf cast issue when UDF instance is updated
    
    ### What changes were proposed in this pull request?
    
    This PR changes the arguments of the `ScalarUserDefinedFunction` case
    class to take `serializedUdfPacket`, `inputTypes` and `outputType`
    directly, instead of computing these values internally from `function`,
    `inputEncoders` and `outputEncoder`. This lets the class copy the
    already-serialized UDF bytes over from the parent instance instead of
    re-creating them. Through this, we avoid hitting a variant of the issue
    mentioned in
    [SPARK-43198](https://issues.apache.org/jira/browse/SPARK-43198)
    which [...]
    
    ### Why are the changes needed?
    
    Bugfix. Consider the following code:
    ```
    class A(x: Int) { def get = x * 7 }
    val myUdf = udf((x: Int) => new A(x).get)
    val modifiedUdf = myUdf.withName("myUdf").asNondeterministic()
    spark.range(5).select(modifiedUdf(col("id"))).as[Int].collect()
    ```
    
    Executing this code currently fails with the following error:
    ```
    java.lang.ClassCastException: org.apache.spark.connect.proto.ScalarScalaUDF cannot be cast to com.google.protobuf.MessageLite
        at com.google.protobuf.GeneratedMessageLite$SerializedForm.readResolve(GeneratedMessageLite.java:1462)
        at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
        at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
        at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
        at java.lang.reflect.Method.invoke(Method.java:498)
        at java.io.ObjectStreamClass.invokeReadResolve(ObjectStreamClass.java:1274)
        at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2196)
    ...
    ```
    If we do not include the `myUdf.withName("myUdf").asNondeterministic()` call, the UDF runs as expected.
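    
    With the fix, the snippet returns each id multiplied by 7; this is the
    output that the new E2E test (shown in the diff below) asserts:
    ```
    Array[Int] = Array(0, 7, 14, 21, 28)
    ```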
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, fixes the bug mentioned above.
    
    ### How was this patch tested?
    
    Added a new E2E test in `ReplE2ESuite`.
    
    Closes #41959 from vicennial/SPARK-44388.
    
    Lead-authored-by: vicennial <venkata.gud...@databricks.com>
    Co-authored-by: Venkata Sai Akhil Gudesa <venkata.gud...@databricks.com>
    Co-authored-by: Venkata Sai Akhil Gudesa <gvs.akhil1...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../sql/expressions/UserDefinedFunction.scala      | 29 +++++++++++-----------
 .../spark/sql/application/ReplE2ESuite.scala       | 11 ++++++++
 2 files changed, 25 insertions(+), 15 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
index d911b7efe29..7bce4b5b31a 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -92,26 +92,23 @@ sealed abstract class UserDefinedFunction {
 /**
  * Holder class for a scalar user-defined function and its input/output encoder(s).
  */
-case class ScalarUserDefinedFunction(
-    function: AnyRef,
-    inputEncoders: Seq[AgnosticEncoder[_]],
-    outputEncoder: AgnosticEncoder[_],
+case class ScalarUserDefinedFunction private (
+    // SPARK-43198: Eagerly serialize to prevent the UDF from containing a reference to this class.
+    serializedUdfPacket: Array[Byte],
+    inputTypes: Seq[proto.DataType],
+    outputType: proto.DataType,
     name: Option[String],
     override val nullable: Boolean,
     override val deterministic: Boolean)
     extends UserDefinedFunction {
 
-  // SPARK-43198: Eagerly serialize to prevent the UDF from containing a reference to this class.
-  private[this] val udf = {
-    val udfPacketBytes =
-      SparkSerDeUtils.serialize(UdfPacket(function, inputEncoders, outputEncoder))
+  private[this] lazy val udf = {
     val scalaUdfBuilder = proto.ScalarScalaUDF
       .newBuilder()
-      .setPayload(ByteString.copyFrom(udfPacketBytes))
+      .setPayload(ByteString.copyFrom(serializedUdfPacket))
       // Send the real inputs and return types to obtain the types without deser the udf bytes.
-      .addAllInputTypes(
-        inputEncoders.map(_.dataType).map(DataTypeProtoConverter.toConnectProtoType).asJava)
-      .setOutputType(DataTypeProtoConverter.toConnectProtoType(outputEncoder.dataType))
+      .addAllInputTypes(inputTypes.asJava)
+      .setOutputType(outputType)
       .setNullable(nullable)
 
     scalaUdfBuilder.build()
@@ -154,10 +151,12 @@ object ScalarUserDefinedFunction {
       function: AnyRef,
       inputEncoders: Seq[AgnosticEncoder[_]],
       outputEncoder: AgnosticEncoder[_]): ScalarUserDefinedFunction = {
+    val udfPacketBytes =
+      SparkSerDeUtils.serialize(UdfPacket(function, inputEncoders, outputEncoder))
     ScalarUserDefinedFunction(
-      function = function,
-      inputEncoders = inputEncoders,
-      outputEncoder = outputEncoder,
+      serializedUdfPacket = udfPacketBytes,
+      inputTypes = inputEncoders.map(_.dataType).map(DataTypeProtoConverter.toConnectProtoType),
+      outputType = DataTypeProtoConverter.toConnectProtoType(outputEncoder.dataType),
       name = None,
       nullable = true,
       deterministic = true)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
index 676ad6b090e..40841aa3b39 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
@@ -158,6 +158,17 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
     assertContains("Array[Int] = Array(5, 47, 89, 131, 173)", output)
   }
 
+  test("Updating UDF properties") {
+    val input = """
+        |class A(x: Int) { def get = x * 7 }
+        |val myUdf = udf((x: Int) => new A(x).get)
+        |val modifiedUdf = myUdf.withName("myUdf").asNondeterministic()
+        |spark.range(5).select(modifiedUdf(col("id"))).as[Int].collect()
+      """.stripMargin
+    val output = runCommandsInShell(input)
+    assertContains("Array[Int] = Array(0, 7, 14, 21, 28)", output)
+  }
+
   test("SPARK-43198: Filter does not throw ammonite-related class 
initialization exception") {
     val input = """
         |spark.range(10).filter(n => n % 2 == 0).collect()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
