This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 90fe41b  [SPARK-36627][CORE] Fix java deserialization of proxy classes
90fe41b is described below

commit 90fe41b70a9d0403418aa05a220d38c20f51c6f9
Author: Samuel Souza <sso...@palantir.com>
AuthorDate: Thu Oct 28 18:15:38 2021 -0500

    [SPARK-36627][CORE] Fix java deserialization of proxy classes
    
    ### Upstream SPARK-XXXXX ticket and PR link (if not applicable, explain)
    https://issues.apache.org/jira/browse/SPARK-36627
    
    ### What changes were proposed in this pull request?
    In JavaSerializer.JavaDeserializationStream we override resolveClass of 
ObjectInputStream to use the thread's contextClassLoader. However, we do not 
override resolveProxyClass, which is used when deserializing Java proxy 
objects. As a result, Spark uses the wrong classloader when deserializing 
such objects, and the job fails with the following exception:
    
    ```
    Caused by: org.apache.spark.SparkException: Job aborted due to stage 
failure: Task 0 in stage 1.0 failed 4 times, most recent failure: Lost task 0.3 
in stage 1.0 (TID 4, <host>, executor 1): java.lang.ClassNotFoundException: 
<class>
        at 
java.base/jdk.internal.loader.BuiltinClassLoader.loadClass(BuiltinClassLoader.java:581)
        at 
java.base/jdk.internal.loader.ClassLoaders$AppClassLoader.loadClass(ClassLoaders.java:178)
        at java.base/java.lang.ClassLoader.loadClass(ClassLoader.java:522)
        at java.base/java.lang.Class.forName0(Native Method)
        at java.base/java.lang.Class.forName(Class.java:398)
        at 
java.base/java.io.ObjectInputStream.resolveProxyClass(ObjectInputStream.java:829)
        at 
java.base/java.io.ObjectInputStream.readProxyDesc(ObjectInputStream.java:1917)
        ...
        at 
org.apache.spark.serializer.JavaDeserializationStream.readObject(JavaSerializer.scala:76)
    ```
    
    ### Why are the changes needed?
    Spark deserialization fails with no recourse for the user.
    
    ### Does this PR introduce any user-facing change?
    No.
    
    ### How was this patch tested?
    Unit tests.
    
    Closes #33879 from fsamuel-bs/SPARK-36627.
    
    Authored-by: Samuel Souza <sso...@palantir.com>
    Signed-off-by: Sean Owen <sro...@gmail.com>
---
 .../apache/spark/serializer/JavaSerializer.scala   | 50 +++++++++++++++-------
 .../spark/serializer/ContainsProxyClass.java       | 50 ++++++++++++++++++++++
 .../spark/serializer/JavaSerializerSuite.scala     | 26 ++++++++++-
 3 files changed, 108 insertions(+), 18 deletions(-)

diff --git 
a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala 
b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
index 077b035..9d76611 100644
--- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -28,8 +28,10 @@ import org.apache.spark.internal.config._
 import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, 
Utils}
 
 private[spark] class JavaSerializationStream(
-    out: OutputStream, counterReset: Int, extraDebugInfo: Boolean)
-  extends SerializationStream {
+    out: OutputStream,
+    counterReset: Int,
+    extraDebugInfo: Boolean)
+    extends SerializationStream {
   private val objOut = new ObjectOutputStream(out)
   private var counter = 0
 
@@ -59,9 +61,10 @@ private[spark] class JavaSerializationStream(
 }
 
 private[spark] class JavaDeserializationStream(in: InputStream, loader: 
ClassLoader)
-  extends DeserializationStream {
+    extends DeserializationStream {
 
   private val objIn = new ObjectInputStream(in) {
+
     override def resolveClass(desc: ObjectStreamClass): Class[_] =
       try {
         // scalastyle:off classforname
@@ -71,6 +74,14 @@ private[spark] class JavaDeserializationStream(in: 
InputStream, loader: ClassLoa
         case e: ClassNotFoundException =>
           JavaDeserializationStream.primitiveMappings.getOrElse(desc.getName, 
throw e)
       }
+
+    override def resolveProxyClass(ifaces: Array[String]): Class[_] = {
+      // scalastyle:off classforname
+      val resolved = ifaces.map(iface => Class.forName(iface, false, loader))
+      // scalastyle:on classforname
+      java.lang.reflect.Proxy.getProxyClass(loader, resolved: _*)
+    }
+
   }
 
   def readObject[T: ClassTag](): T = objIn.readObject().asInstanceOf[T]
@@ -78,6 +89,7 @@ private[spark] class JavaDeserializationStream(in: 
InputStream, loader: ClassLoa
 }
 
 private object JavaDeserializationStream {
+
   val primitiveMappings = Map[String, Class[_]](
     "boolean" -> classOf[Boolean],
     "byte" -> classOf[Byte],
@@ -87,13 +99,15 @@ private object JavaDeserializationStream {
     "long" -> classOf[Long],
     "float" -> classOf[Float],
     "double" -> classOf[Double],
-    "void" -> classOf[Void]
-  )
+    "void" -> classOf[Void])
+
 }
 
 private[spark] class JavaSerializerInstance(
-    counterReset: Int, extraDebugInfo: Boolean, defaultClassLoader: 
ClassLoader)
-  extends SerializerInstance {
+    counterReset: Int,
+    extraDebugInfo: Boolean,
+    defaultClassLoader: ClassLoader)
+    extends SerializerInstance {
 
   override def serialize[T: ClassTag](t: T): ByteBuffer = {
     val bos = new ByteBufferOutputStream()
@@ -126,6 +140,7 @@ private[spark] class JavaSerializerInstance(
   def deserializeStream(s: InputStream, loader: ClassLoader): 
DeserializationStream = {
     new JavaDeserializationStream(s, loader)
   }
+
 }
 
 /**
@@ -141,20 +156,23 @@ class JavaSerializer(conf: SparkConf) extends Serializer 
with Externalizable {
   private var counterReset = conf.get(SERIALIZER_OBJECT_STREAM_RESET)
   private var extraDebugInfo = conf.get(SERIALIZER_EXTRA_DEBUG_INFO)
 
-  protected def this() = this(new SparkConf())  // For deserialization only
+  protected def this() = this(new SparkConf()) // For deserialization only
 
   override def newInstance(): SerializerInstance = {
     val classLoader = 
defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader)
     new JavaSerializerInstance(counterReset, extraDebugInfo, classLoader)
   }
 
-  override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException 
{
-    out.writeInt(counterReset)
-    out.writeBoolean(extraDebugInfo)
-  }
+  override def writeExternal(out: ObjectOutput): Unit =
+    Utils.tryOrIOException {
+      out.writeInt(counterReset)
+      out.writeBoolean(extraDebugInfo)
+    }
+
+  override def readExternal(in: ObjectInput): Unit =
+    Utils.tryOrIOException {
+      counterReset = in.readInt()
+      extraDebugInfo = in.readBoolean()
+    }
 
-  override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
-    counterReset = in.readInt()
-    extraDebugInfo = in.readBoolean()
-  }
 }
diff --git 
a/core/src/test/java/org/apache/spark/serializer/ContainsProxyClass.java 
b/core/src/test/java/org/apache/spark/serializer/ContainsProxyClass.java
new file mode 100644
index 0000000..66b2ba4
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/serializer/ContainsProxyClass.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.serializer;
+
+import java.io.Serializable;
+import java.lang.reflect.InvocationHandler;
+import java.lang.reflect.Method;
+import java.lang.reflect.Proxy;
+
+class ContainsProxyClass implements Serializable {
+  final MyInterface proxy = (MyInterface) Proxy.newProxyInstance(
+    MyInterface.class.getClassLoader(),
+    new Class[]{MyInterface.class},
+    new MyInvocationHandler());
+
+  // Interface needs to be public as classloaders will mismatch.
+  // See ObjectInputStream#resolveProxyClass for details.
+  public interface MyInterface {
+    void myMethod();
+  }
+
+  static class MyClass implements MyInterface, Serializable {
+    @Override
+    public void myMethod() {}
+  }
+
+  class MyInvocationHandler implements InvocationHandler, Serializable {
+    private final MyClass real = new MyClass();
+
+    @Override
+    public Object invoke(Object proxy, Method method, Object[] args) throws 
Throwable {
+      return method.invoke(real, args);
+    }
+  }
+}
diff --git 
a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala 
b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala
index 6a6ea42..77226af 100644
--- a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala
@@ -31,11 +31,33 @@ class JavaSerializerSuite extends SparkFunSuite {
   test("Deserialize object containing a primitive Class as attribute") {
     val serializer = new JavaSerializer(new SparkConf())
     val instance = serializer.newInstance()
-    val obj = instance.deserialize[ContainsPrimitiveClass](instance.serialize(
-      new ContainsPrimitiveClass()))
+    val obj = instance.deserialize[ContainsPrimitiveClass](
+      instance.serialize(new ContainsPrimitiveClass()))
     // enforce class cast
     obj.getClass
   }
+
+  test("SPARK-36627: Deserialize object containing a proxy Class as 
attribute") {
+    var classesLoaded = Set[String]()
+    val outer = Thread.currentThread.getContextClassLoader
+    val inner = new ClassLoader() {
+      override def loadClass(name: String): Class[_] = {
+        classesLoaded = classesLoaded + name
+        outer.loadClass(name)
+      }
+    }
+    Thread.currentThread.setContextClassLoader(inner)
+
+    val serializer = new JavaSerializer(new SparkConf())
+    val instance = serializer.newInstance()
+    val obj =
+      instance.deserialize[ContainsProxyClass](instance.serialize(new 
ContainsProxyClass()))
+    // enforce class cast
+    obj.getClass
+
+    // check that serializer's loader is used to resolve proxied interface.
+    assert(classesLoaded.exists(klass => klass.contains("MyInterface")))
+  }
 }
 
 private class ContainsPrimitiveClass extends Serializable {

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to