This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 8537fa634cd [SPARK-29497][CONNECT] Throw error when UDF is not deserializable
8537fa634cd is described below

commit 8537fa634cd02f46e7b42afd6b35f877f3a2c161
Author: Herman van Hovell <hvanhov...@databricks.com>
AuthorDate: Tue Aug 1 14:53:54 2023 -0400

    [SPARK-29497][CONNECT] Throw error when UDF is not deserializable
    
    ### What changes were proposed in this pull request?
    This PR adds a better error message when a JVM UDF cannot be deserialized.
    
    ### Why are the changes needed?
    In some cases a UDF cannot be deserialized. This happens when a lambda
    references itself (typically through the capturing class). Java cannot
    deserialize such an object graph because `SerializedLambda`s are serialization
    proxies that need the full graph to be deserialized before they can be
    transformed into the actual lambda; this is not possible when there is such a
    cycle. This PR adds a more readable and understandable error when this happens;
    the original Java one is a `ClassCastException` [...]
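
    For illustration only (not part of this patch), a minimal sketch of the failing
    pattern with plain Java serialization; `SelfReferencing` and `ReproSketch` are
    hypothetical stand-ins for the capturing classes exercised by the new tests:

        import java.io._

        case class SelfReferencing() extends Serializable {
          val direct: Int => Int = i => i + 1
          // This lambda captures `this`, and `this` holds the lambda in a field,
          // so the serialized object graph contains a cycle.
          val indirect: Int => Int = i => direct(i)
        }

        object ReproSketch {
          def main(args: Array[String]): Unit = {
            val bos = new ByteArrayOutputStream()
            // Writing succeeds: ObjectOutputStream handles the cycle with back-references.
            new ObjectOutputStream(bos).writeObject(SelfReferencing().indirect)
            val ois = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray))
            // Reading fails with a ClassCastException like "cannot assign instance of
            // java.lang.invoke.SerializedLambda to field SelfReferencing.indirect ...",
            // because the SerializedLambda proxy is assigned to the field before it
            // can be resolved into the actual lambda.
            ois.readObject()
          }
        }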
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. It will throw an error on the client when a UDF is not deserializable.
    The error is better and more actionable than what we got before.
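
    For illustration (a sketch, not from the patch itself), the new client-side
    behavior, using the `SelfRef` helper class defined in the test suite below and
    ScalaTest's `intercept`:

        val e = intercept[SparkException](udf(SelfRef(22).method))
        // Instead of the raw ClassCastException, the error now reads:
        // "UDF cannot be executed on a Spark cluster: it cannot be deserialized. ..."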
    
    ### How was this patch tested?
    Added tests.
    
    Closes #42245 from hvanhovell/SPARK-29497.
    
    Lead-authored-by: Herman van Hovell <hvanhov...@databricks.com>
    Co-authored-by: Herman van Hovell <her...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
    (cherry picked from commit f54b402021785e0b0ec976ec889de67d3b2fdc6e)
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../org/apache/spark/util/SparkSerDeUtils.scala    | 21 ++++++++++-
 .../sql/expressions/UserDefinedFunction.scala      | 24 +++++++++++-
 .../spark/sql/UserDefinedFunctionSuite.scala       | 44 ++++++++++++++++++++--
 .../main/scala/org/apache/spark/util/Utils.scala   | 23 +----------
 4 files changed, 85 insertions(+), 27 deletions(-)

diff --git a/common/utils/src/main/scala/org/apache/spark/util/SparkSerDeUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/SparkSerDeUtils.scala
index 3069e4c36a7..9b6174c47bd 100644
--- a/common/utils/src/main/scala/org/apache/spark/util/SparkSerDeUtils.scala
+++ b/common/utils/src/main/scala/org/apache/spark/util/SparkSerDeUtils.scala
@@ -16,9 +16,9 @@
  */
 package org.apache.spark.util
 
-import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream, ObjectStreamClass}
 
-object SparkSerDeUtils {
+trait SparkSerDeUtils {
   /** Serialize an object using Java serialization */
   def serialize[T](o: T): Array[Byte] = {
     val bos = new ByteArrayOutputStream()
@@ -34,4 +34,21 @@ object SparkSerDeUtils {
     val ois = new ObjectInputStream(bis)
     ois.readObject.asInstanceOf[T]
   }
+
+  /**
+   * Deserialize an object using Java serialization and the given ClassLoader
+   */
+  def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
+    val bis = new ByteArrayInputStream(bytes)
+    val ois = new ObjectInputStream(bis) {
+      override def resolveClass(desc: ObjectStreamClass): Class[_] = {
+        // scalastyle:off classforname
+        Class.forName(desc.getName, false, loader)
+        // scalastyle:on classforname
+      }
+    }
+    ois.readObject.asInstanceOf[T]
+  }
 }
+
+object SparkSerDeUtils extends SparkSerDeUtils
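
Illustrative usage, not part of the patch: the new overload resolves classes through a
caller-supplied ClassLoader, e.g. the context or Spark class loader, instead of
ObjectInputStream's default resolveClass behavior.

    import org.apache.spark.util.{SparkClassUtils, SparkSerDeUtils}

    object SerDeRoundTripSketch {
      def main(args: Array[String]): Unit = {
        val bytes = SparkSerDeUtils.serialize(Seq(1, 2, 3))
        // Every class named in the stream is looked up through the supplied loader.
        val loader = SparkClassUtils.getContextOrSparkClassLoader
        val result = SparkSerDeUtils.deserialize[Seq[Int]](bytes, loader)
        assert(result == Seq(1, 2, 3))
      }
    }
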
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
index 3a38029c265..e060dba0b7e 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -18,16 +18,18 @@ package org.apache.spark.sql.expressions
 
 import scala.collection.JavaConverters._
 import scala.reflect.runtime.universe.TypeTag
+import scala.util.control.NonFatal
 
 import com.google.protobuf.ByteString
 
+import org.apache.spark.SparkException
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
 import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, UdfPacket}
 import org.apache.spark.sql.types.DataType
-import org.apache.spark.util.SparkSerDeUtils
+import org.apache.spark.util.{SparkClassUtils, SparkSerDeUtils}
 
 /**
  * A user-defined function. To create one, use the `udf` functions in `functions`.
@@ -144,6 +146,25 @@ case class ScalarUserDefinedFunction private[sql] (
 }
 
 object ScalarUserDefinedFunction {
+  private val LAMBDA_DESERIALIZATION_ERR_MSG: String =
+    "cannot assign instance of java.lang.invoke.SerializedLambda to field"
+
+  private def checkDeserializable(bytes: Array[Byte]): Unit = {
+    try {
+      SparkSerDeUtils.deserialize(bytes, SparkClassUtils.getContextOrSparkClassLoader)
+    } catch {
+      case e: ClassCastException if e.getMessage.contains(LAMBDA_DESERIALIZATION_ERR_MSG) =>
+        throw new SparkException(
+          "UDF cannot be executed on a Spark cluster: it cannot be deserialized. " +
+            "This is very likely to be caused by the lambda function (the UDF) having a " +
+            "self-reference. This is not supported by java serialization.")
+      case NonFatal(e) =>
+        throw new SparkException(
+          "UDF cannot be executed on a Spark cluster: it cannot be deserialized.",
+          e)
+    }
+  }
+
   private[sql] def apply(
       function: AnyRef,
       returnType: TypeTag[_],
@@ -164,6 +185,7 @@ object ScalarUserDefinedFunction {
       outputEncoder: AgnosticEncoder[_]): ScalarUserDefinedFunction = {
     val udfPacketBytes =
       SparkSerDeUtils.serialize(UdfPacket(function, inputEncoders, outputEncoder))
+    checkDeserializable(udfPacketBytes)
     ScalarUserDefinedFunction(
       serializedUdfPacket = udfPacketBytes,
       inputTypes = inputEncoders.map(_.dataType).map(DataTypeProtoConverter.toConnectProtoType),
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala
index 684f5671e48..76608559866 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala
@@ -18,15 +18,14 @@ package org.apache.spark.sql
 
 import scala.reflect.runtime.universe.typeTag
 
-import org.scalatest.BeforeAndAfterEach
-
+import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.connect.client.util.ConnectFunSuite
 import org.apache.spark.sql.connect.common.UdfPacket
 import org.apache.spark.sql.functions.udf
 import org.apache.spark.util.SparkSerDeUtils
 
-class UserDefinedFunctionSuite extends ConnectFunSuite with BeforeAndAfterEach {
+class UserDefinedFunctionSuite extends ConnectFunSuite {
 
   test("udf and encoder serialization") {
     def func(x: Int): Int = x + 1
@@ -48,4 +47,43 @@ class UserDefinedFunctionSuite extends ConnectFunSuite with BeforeAndAfterEach {
     assert(deSer.outputEncoder == ScalaReflection.encoderFor(typeTag[Int]))
     assert(deSer.inputEncoders == Seq(ScalaReflection.encoderFor(typeTag[Int])))
   }
+
+  private def testNonDeserializable(f: Int => Int): Unit = {
+    val e = intercept[SparkException](udf(f))
+    assert(
+      e.getMessage.contains(
+        "UDF cannot be executed on a Spark cluster: it cannot be 
deserialized."))
+    assert(e.getMessage.contains("This is not supported by java 
serialization."))
+  }
+
+  test("non deserializable UDFs") {
+    testNonDeserializable(Command2(Command1()).indirect)
+    testNonDeserializable(MultipleLambdas().indirect)
+    testNonDeserializable(SelfRef(22).method)
+  }
+
+  test("serializable UDFs") {
+    val direct = (i: Int) => i + 1
+    val indirect = (i: Int) => direct(i)
+    udf(indirect)
+    udf(Command1().direct)
+    udf(MultipleLambdas().direct)
+  }
+}
+
+case class Command1() extends Serializable {
+  val direct: Int => Int = (i: Int) => i + 1
+}
+
+case class Command2(prev: Command1) extends Serializable {
+  val indirect: Int => Int = (i: Int) => prev.direct(i)
+}
+
+case class SelfRef(start: Int) extends Serializable {
+  val method: Int => Int = (i: Int) => i + start
+}
+
+case class MultipleLambdas() extends Serializable {
+  val direct: Int => Int = (i: Int) => i + 1
+  val indirect: Int => Int = (i: Int) => direct(i)
 }
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index a3002eb40f4..a556f03dc09 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -95,7 +95,8 @@ private[spark] object Utils
   extends Logging
   with SparkClassUtils
   with SparkErrorUtils
-  with SparkFileUtils {
+  with SparkFileUtils
+  with SparkSerDeUtils {
 
   private val sparkUncaughtExceptionHandler = new SparkUncaughtExceptionHandler
   @volatile private var cachedLocalDir: String = ""
@@ -121,26 +122,6 @@ private[spark] object Utils
   private val copyBuffer = ThreadLocal.withInitial[Array[Byte]](() => {
     new Array[Byte](COPY_BUFFER_LEN)
   })
-
-  /** Serialize an object using Java serialization */
-  def serialize[T](o: T): Array[Byte] = SparkSerDeUtils.serialize(o)
-
-  /** Deserialize an object using Java serialization */
-  def deserialize[T](bytes: Array[Byte]): T = SparkSerDeUtils.deserialize(bytes)
-
-  /** Deserialize an object using Java serialization and the given ClassLoader */
-  def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
-    val bis = new ByteArrayInputStream(bytes)
-    val ois = new ObjectInputStream(bis) {
-      override def resolveClass(desc: ObjectStreamClass): Class[_] = {
-        // scalastyle:off classforname
-        Class.forName(desc.getName, false, loader)
-        // scalastyle:on classforname
-      }
-    }
-    ois.readObject.asInstanceOf[T]
-  }
-
   /** Deserialize a Long value (used for [[org.apache.spark.api.python.PythonPartitioner]]) */
   def deserializeLongValue(bytes: Array[Byte]) : Long = {
     // Note: we assume that we are given a Long value encoded in network (big-endian) byte order


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
