Repository: beam
Updated Branches:
  refs/heads/master 480e60fa4 -> 7ad45ad8c


[BEAM-2701] ensure objectinputstream uses the right classloader for serialization


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/bccea9dc
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/bccea9dc
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/bccea9dc

Branch: refs/heads/master
Commit: bccea9dc059749cfd43419416e4959a6ab578090
Parents: 480e60f
Author: Romain Manni-Bucau <rmannibu...@gmail.com>
Authored: Wed Sep 20 17:29:53 2017 +0200
Committer: Luke Cwik <lc...@google.com>
Committed: Fri Oct 6 09:26:50 2017 -0700

----------------------------------------------------------------------
 .../org/apache/beam/sdk/util/ApiSurface.java    |  2 +
 .../apache/beam/sdk/util/SerializableUtils.java | 69 ++++++++++++++++++--
 .../apache/beam/sdk/coders/AvroCoderTest.java   | 35 ++--------
 .../sdk/testing/InterceptingUrlClassLoader.java | 57 ++++++++++++++++
 .../beam/sdk/util/SerializableUtilsTest.java    | 60 +++++++++++++++++
 5 files changed, 185 insertions(+), 38 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/bccea9dc/sdks/java/core/src/main/java/org/apache/beam/sdk/util/ApiSurface.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/ApiSurface.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/ApiSurface.java
index 735190b..1266d75 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/ApiSurface.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/ApiSurface.java
@@ -834,6 +834,8 @@ public class ApiSurface {
         .pruningPattern("org[.]apache[.]beam[.].*Test")
         // Exposes Guava, but not intended for users
         .pruningClassName("org.apache.beam.sdk.util.common.ReflectHelpers")
+         // test only
+        .pruningClassName("org.apache.beam.sdk.testing.InterceptingUrlClassLoader")
         .pruningPrefix("java");
   }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/bccea9dc/sdks/java/core/src/main/java/org/apache/beam/sdk/util/SerializableUtils.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/SerializableUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/SerializableUtils.java
index d4bfd0b..cf5a6f3 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/SerializableUtils.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/SerializableUtils.java
@@ -24,12 +24,16 @@ import static org.apache.beam.sdk.util.CoderUtils.encodeToByteArray;
 import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
 import java.io.IOException;
+import java.io.InputStream;
 import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
+import java.io.ObjectStreamClass;
 import java.io.Serializable;
+import java.lang.reflect.Proxy;
 import java.util.Arrays;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.CoderException;
+import org.apache.beam.sdk.util.common.ReflectHelpers;
 import org.xerial.snappy.SnappyInputStream;
 import org.xerial.snappy.SnappyOutputStream;
 
@@ -67,7 +71,7 @@ public class SerializableUtils {
   public static Object deserializeFromByteArray(byte[] encodedValue,
       String description) {
     try {
-      try (ObjectInputStream ois = new ObjectInputStream(
+      try (ObjectInputStream ois = new ContextualObjectInputStream(
           new SnappyInputStream(new ByteArrayInputStream(encodedValue)))) {
         return ois.readObject();
       }
@@ -79,16 +83,31 @@ public class SerializableUtils {
   }
 
   public static <T extends Serializable> T ensureSerializable(T value) {
-    @SuppressWarnings("unchecked")
-    T copy = (T) deserializeFromByteArray(serializeToByteArray(value),
-        value.toString());
-    return copy;
+    return clone(value);
   }
 
   public static <T extends Serializable> T clone(T value) {
+    final Thread thread = Thread.currentThread();
+    final ClassLoader tccl = thread.getContextClassLoader();
+    ClassLoader loader = tccl;
+    try {
+      if (tccl.loadClass(value.getClass().getName()) != value.getClass()) {
+        loader = value.getClass().getClassLoader();
+      }
+    } catch (final NoClassDefFoundError | ClassNotFoundException e) {
+      loader = value.getClass().getClassLoader();
+    }
+    if (loader == null) {
+      loader = tccl; // will likely fail but the best we can do
+    }
+    thread.setContextClassLoader(loader);
     @SuppressWarnings("unchecked")
-    T copy = (T) deserializeFromByteArray(serializeToByteArray(value),
-        value.toString());
+    final T copy;
+    try {
+      copy = (T) deserializeFromByteArray(serializeToByteArray(value), value.toString());
+    } finally {
+      thread.setContextClassLoader(tccl);
+    }
     return copy;
   }
 
@@ -144,4 +163,40 @@ public class SerializableUtils {
             exn);
       }
   }
+
+  private static final class ContextualObjectInputStream extends ObjectInputStream {
+    private ContextualObjectInputStream(final InputStream in) throws IOException {
+      super(in);
+    }
+
+    @Override
+    protected Class<?> resolveClass(final ObjectStreamClass classDesc)
+            throws IOException, ClassNotFoundException {
+      // note: staying aligned on JVM default but can need class filtering here to avoid 0day issue
+      final String n = classDesc.getName();
+      final ClassLoader classloader = ReflectHelpers.findClassLoader();
+      try {
+        return Class.forName(n, false, classloader);
+      } catch (final ClassNotFoundException e) {
+        return super.resolveClass(classDesc);
+      }
+    }
+
+    @Override
+    protected Class resolveProxyClass(final String[] interfaces)
+            throws IOException, ClassNotFoundException {
+      final ClassLoader classloader = ReflectHelpers.findClassLoader();
+
+      final Class[] cinterfaces = new Class[interfaces.length];
+      for (int i = 0; i < interfaces.length; i++) {
+        cinterfaces[i] = classloader.loadClass(interfaces[i]);
+      }
+
+      try {
+        return Proxy.getProxyClass(classloader, cinterfaces);
+      } catch (final IllegalArgumentException e) {
+        throw new ClassNotFoundException(null, e);
+      }
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/bccea9dc/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/AvroCoderTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/AvroCoderTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/AvroCoderTest.java
index 60b3232..deecb96 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/AvroCoderTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/AvroCoderTest.java
@@ -26,10 +26,8 @@ import static org.junit.Assert.fail;
 import com.esotericsoftware.kryo.Kryo;
 import com.esotericsoftware.kryo.io.Input;
 import com.esotericsoftware.kryo.io.Output;
-import com.google.common.io.ByteStreams;
 import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
-import java.io.IOException;
 import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
 import java.util.ArrayList;
@@ -60,6 +58,7 @@ import org.apache.avro.util.Utf8;
 import org.apache.beam.sdk.coders.Coder.Context;
 import org.apache.beam.sdk.coders.Coder.NonDeterministicException;
 import org.apache.beam.sdk.testing.CoderProperties;
+import org.apache.beam.sdk.testing.InterceptingUrlClassLoader;
 import org.apache.beam.sdk.testing.NeedsRunner;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
@@ -166,42 +165,16 @@ public class AvroCoderTest {
   }
 
   /**
-   * A classloader that intercepts loading of Pojo and makes a new one.
-   */
-  private static class InterceptingUrlClassLoader extends ClassLoader {
-
-    private InterceptingUrlClassLoader(ClassLoader parent) {
-      super(parent);
-    }
-
-    @Override
-    public Class<?> loadClass(String name) throws ClassNotFoundException {
-      if (name.equals(AvroCoderTestPojo.class.getName())) {
-        // Quite a hack?
-        try {
-          String classAsResource = name.replace('.', '/') + ".class";
-          byte[] classBytes =
-              ByteStreams.toByteArray(getParent().getResourceAsStream(classAsResource));
-          return defineClass(name, classBytes, 0, classBytes.length);
-        } catch (IOException e) {
-          throw new RuntimeException(e);
-        }
-      } else {
-        return getParent().loadClass(name);
-      }
-    }
-  }
-
-  /**
    * Tests that {@link AvroCoder} works around issues in Avro where cache classes might be
    * from the wrong ClassLoader, causing confusing "Cannot cast X to X" error messages.
    */
   @Test
   public void testTwoClassLoaders() throws Exception {
+    ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader();
     ClassLoader loader1 =
-        new InterceptingUrlClassLoader(Thread.currentThread().getContextClassLoader());
+        new InterceptingUrlClassLoader(contextClassLoader, AvroCoderTestPojo.class.getName());
     ClassLoader loader2 =
-        new InterceptingUrlClassLoader(Thread.currentThread().getContextClassLoader());
+        new InterceptingUrlClassLoader(contextClassLoader, AvroCoderTestPojo.class.getName());
 
     Class<?> pojoClass1 = loader1.loadClass(AvroCoderTestPojo.class.getName());
     Class<?> pojoClass2 = loader2.loadClass(AvroCoderTestPojo.class.getName());

http://git-wip-us.apache.org/repos/asf/beam/blob/bccea9dc/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/InterceptingUrlClassLoader.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/InterceptingUrlClassLoader.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/InterceptingUrlClassLoader.java
new file mode 100644
index 0000000..b5adcb5
--- /dev/null
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/InterceptingUrlClassLoader.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.testing;
+
+import com.google.common.collect.Sets;
+import com.google.common.io.ByteStreams;
+import java.io.IOException;
+import java.util.Set;
+
+/**
+ * A classloader that intercepts loading of specifically named classes. This classloader copies
+ * the original classes definition and is useful for testing code which needs to validate usage
+ * with multiple classloaders..
+ */
+public class InterceptingUrlClassLoader extends ClassLoader {
+    private final Set<String> ownedClasses;
+
+    public InterceptingUrlClassLoader(final ClassLoader parent, final String... ownedClasses) {
+        super(parent);
+        this.ownedClasses = Sets.newHashSet(ownedClasses);
+    }
+
+    @Override
+    public Class<?> loadClass(final String name) throws ClassNotFoundException {
+        final Class<?> alreadyLoaded = super.findLoadedClass(name);
+        if (alreadyLoaded != null) {
+            return alreadyLoaded;
+        }
+
+        if (name != null && ownedClasses.contains(name)) {
+            try {
+                final String classAsResource = name.replace('.', '/') + ".class";
+                final byte[] classBytes =
+                        ByteStreams.toByteArray(getParent().getResourceAsStream(classAsResource));
+                return defineClass(name, classBytes, 0, classBytes.length);
+            } catch (final IOException e) {
+                throw new RuntimeException(e);
+            }
+        }
+        return getParent().loadClass(name);
+    }
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/bccea9dc/sdks/java/core/src/test/java/org/apache/beam/sdk/util/SerializableUtilsTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/SerializableUtilsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/SerializableUtilsTest.java
index 9a80730..c3b0171 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/SerializableUtilsTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/SerializableUtilsTest.java
@@ -18,8 +18,10 @@
 package org.apache.beam.sdk.util;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotSame;
 
 import com.google.common.collect.ImmutableList;
+
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
@@ -28,6 +30,9 @@ import java.util.List;
 import org.apache.beam.sdk.coders.AtomicCoder;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.CoderException;
+import org.apache.beam.sdk.io.BoundedSource;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.testing.InterceptingUrlClassLoader;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
@@ -51,6 +56,30 @@ public class SerializableUtilsTest {
   }
 
   @Test
+  public void customClassLoader() throws Exception {
+    // define a classloader with test-classes in it
+    final ClassLoader testLoader = Thread.currentThread().getContextClassLoader();
+    final ClassLoader loader = new InterceptingUrlClassLoader(testLoader, MySource.class.getName());
+    final Class<?> source = loader.loadClass(
+            "org.apache.beam.sdk.util.SerializableUtilsTest$MySource");
+    assertNotSame(source.getClassLoader(), MySource.class.getClassLoader());
+
+    // validate if the caller set the classloader that it works well
+    final Serializable customLoaderSourceInstance = Serializable.class.cast(
+            source.getConstructor().newInstance());
+    final Thread thread = Thread.currentThread();
+    thread.setContextClassLoader(loader);
+    try {
+      assertSerializationClassLoader(loader, customLoaderSourceInstance);
+    } finally {
+      thread.setContextClassLoader(testLoader);
+    }
+
+    // now let beam be a little be more fancy and try to ensure it by itself from the incoming value
+    assertSerializationClassLoader(loader, customLoaderSourceInstance);
+  }
+
+  @Test
   public void testTranscode() {
     String stringValue = "hi bob";
     int intValue = 42;
@@ -114,4 +143,35 @@ public class SerializableUtilsTest {
     expectedException.expectMessage("unable to serialize");
     SerializableUtils.ensureSerializable(new UnserializableCoderByJava());
   }
+
+  private void assertSerializationClassLoader(
+          final ClassLoader loader, final Serializable customLoaderSourceInstance) {
+    final Serializable copy = SerializableUtils.ensureSerializable(customLoaderSourceInstance);
+    assertEquals(loader, copy.getClass().getClassLoader());
+    assertEquals(
+            copy.getClass().getClassLoader(),
+            customLoaderSourceInstance.getClass().getClassLoader());
+  }
+
+  /**
+   * a sample class to test framework serialization,
+   * {@see SerializableUtilsTest#customClassLoader}.
+   */
+  public static class MySource extends BoundedSource<String> {
+    @Override
+    public List<? extends BoundedSource<String>> split(
+            final long desiredBundleSizeBytes, final PipelineOptions options) throws Exception {
+      return null;
+    }
+
+    @Override
+    public long getEstimatedSizeBytes(final PipelineOptions options) throws Exception {
+      return 0;
+    }
+
+    @Override
+    public BoundedReader<String> createReader(final PipelineOptions options) throws IOException {
+      return null;
+    }
+  }
 }

Reply via email to