Repository: beam Updated Branches: refs/heads/master 480e60fa4 -> 7ad45ad8c
[BEAM-2701] ensure objectinputstream uses the right classloader for serialization Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/bccea9dc Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/bccea9dc Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/bccea9dc Branch: refs/heads/master Commit: bccea9dc059749cfd43419416e4959a6ab578090 Parents: 480e60f Author: Romain Manni-Bucau <rmannibu...@gmail.com> Authored: Wed Sep 20 17:29:53 2017 +0200 Committer: Luke Cwik <lc...@google.com> Committed: Fri Oct 6 09:26:50 2017 -0700 ---------------------------------------------------------------------- .../org/apache/beam/sdk/util/ApiSurface.java | 2 + .../apache/beam/sdk/util/SerializableUtils.java | 69 ++++++++++++++++++-- .../apache/beam/sdk/coders/AvroCoderTest.java | 35 ++-------- .../sdk/testing/InterceptingUrlClassLoader.java | 57 ++++++++++++++++ .../beam/sdk/util/SerializableUtilsTest.java | 60 +++++++++++++++++ 5 files changed, 185 insertions(+), 38 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/bccea9dc/sdks/java/core/src/main/java/org/apache/beam/sdk/util/ApiSurface.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/ApiSurface.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/ApiSurface.java index 735190b..1266d75 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/ApiSurface.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/ApiSurface.java @@ -834,6 +834,8 @@ public class ApiSurface { .pruningPattern("org[.]apache[.]beam[.].*Test") // Exposes Guava, but not intended for users .pruningClassName("org.apache.beam.sdk.util.common.ReflectHelpers") + // test only + .pruningClassName("org.apache.beam.sdk.testing.InterceptingUrlClassLoader") .pruningPrefix("java"); } } http://git-wip-us.apache.org/repos/asf/beam/blob/bccea9dc/sdks/java/core/src/main/java/org/apache/beam/sdk/util/SerializableUtils.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/SerializableUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/SerializableUtils.java index d4bfd0b..cf5a6f3 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/SerializableUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/SerializableUtils.java @@ -24,12 +24,16 @@ import static org.apache.beam.sdk.util.CoderUtils.encodeToByteArray; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.InputStream; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; +import java.io.ObjectStreamClass; import java.io.Serializable; +import java.lang.reflect.Proxy; import java.util.Arrays; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderException; +import org.apache.beam.sdk.util.common.ReflectHelpers; import org.xerial.snappy.SnappyInputStream; import org.xerial.snappy.SnappyOutputStream; @@ -67,7 +71,7 @@ public class SerializableUtils { public static Object deserializeFromByteArray(byte[] encodedValue, String description) { try { - try (ObjectInputStream ois = new ObjectInputStream( + try (ObjectInputStream ois = new ContextualObjectInputStream( new SnappyInputStream(new ByteArrayInputStream(encodedValue)))) { return ois.readObject(); } @@ -79,16 +83,31 @@ public class SerializableUtils { } public static <T extends Serializable> T ensureSerializable(T value) { - @SuppressWarnings("unchecked") - T copy = (T) deserializeFromByteArray(serializeToByteArray(value), - value.toString()); - return copy; + return clone(value); } public static <T extends Serializable> T clone(T value) { + final Thread thread = Thread.currentThread(); + final ClassLoader tccl = thread.getContextClassLoader(); + ClassLoader loader = tccl; + try { + if (tccl.loadClass(value.getClass().getName()) != value.getClass()) { + loader = value.getClass().getClassLoader(); + } + } catch (final NoClassDefFoundError | ClassNotFoundException e) { + loader = value.getClass().getClassLoader(); + } + if (loader == null) { + loader = tccl; // will likely fail but the best we can do + } + thread.setContextClassLoader(loader); @SuppressWarnings("unchecked") - T copy = (T) deserializeFromByteArray(serializeToByteArray(value), - value.toString()); + final T copy; + try { + copy = (T) deserializeFromByteArray(serializeToByteArray(value), value.toString()); + } finally { + thread.setContextClassLoader(tccl); + } return copy; } @@ -144,4 +163,40 @@ public class SerializableUtils { exn); } } + + private static final class ContextualObjectInputStream extends ObjectInputStream { + private ContextualObjectInputStream(final InputStream in) throws IOException { + super(in); + } + + @Override + protected Class<?> resolveClass(final ObjectStreamClass classDesc) + throws IOException, ClassNotFoundException { + // note: staying aligned on JVM default but can need class filtering here to avoid 0day issue + final String n = classDesc.getName(); + final ClassLoader classloader = ReflectHelpers.findClassLoader(); + try { + return Class.forName(n, false, classloader); + } catch (final ClassNotFoundException e) { + return super.resolveClass(classDesc); + } + } + + @Override + protected Class resolveProxyClass(final String[] interfaces) + throws IOException, ClassNotFoundException { + final ClassLoader classloader = ReflectHelpers.findClassLoader(); + + final Class[] cinterfaces = new Class[interfaces.length]; + for (int i = 0; i < interfaces.length; i++) { + cinterfaces[i] = classloader.loadClass(interfaces[i]); + } + + try { + return Proxy.getProxyClass(classloader, cinterfaces); + } catch (final IllegalArgumentException e) { + throw new ClassNotFoundException(null, e); + } + } + } } http://git-wip-us.apache.org/repos/asf/beam/blob/bccea9dc/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/AvroCoderTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/AvroCoderTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/AvroCoderTest.java index 60b3232..deecb96 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/AvroCoderTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/AvroCoderTest.java @@ -26,10 +26,8 @@ import static org.junit.Assert.fail; import com.esotericsoftware.kryo.Kryo; import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; -import com.google.common.io.ByteStreams; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; -import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.util.ArrayList; @@ -60,6 +58,7 @@ import org.apache.avro.util.Utf8; import org.apache.beam.sdk.coders.Coder.Context; import org.apache.beam.sdk.coders.Coder.NonDeterministicException; import org.apache.beam.sdk.testing.CoderProperties; +import org.apache.beam.sdk.testing.InterceptingUrlClassLoader; import org.apache.beam.sdk.testing.NeedsRunner; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; @@ -166,42 +165,16 @@ public class AvroCoderTest { } /** - * A classloader that intercepts loading of Pojo and makes a new one. - */ - private static class InterceptingUrlClassLoader extends ClassLoader { - - private InterceptingUrlClassLoader(ClassLoader parent) { - super(parent); - } - - @Override - public Class<?> loadClass(String name) throws ClassNotFoundException { - if (name.equals(AvroCoderTestPojo.class.getName())) { - // Quite a hack? - try { - String classAsResource = name.replace('.', '/') + ".class"; - byte[] classBytes = - ByteStreams.toByteArray(getParent().getResourceAsStream(classAsResource)); - return defineClass(name, classBytes, 0, classBytes.length); - } catch (IOException e) { - throw new RuntimeException(e); - } - } else { - return getParent().loadClass(name); - } - } - } - - /** * Tests that {@link AvroCoder} works around issues in Avro where cache classes might be * from the wrong ClassLoader, causing confusing "Cannot cast X to X" error messages. */ @Test public void testTwoClassLoaders() throws Exception { + ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader(); ClassLoader loader1 = - new InterceptingUrlClassLoader(Thread.currentThread().getContextClassLoader()); + new InterceptingUrlClassLoader(contextClassLoader, AvroCoderTestPojo.class.getName()); ClassLoader loader2 = - new InterceptingUrlClassLoader(Thread.currentThread().getContextClassLoader()); + new InterceptingUrlClassLoader(contextClassLoader, AvroCoderTestPojo.class.getName()); Class<?> pojoClass1 = loader1.loadClass(AvroCoderTestPojo.class.getName()); Class<?> pojoClass2 = loader2.loadClass(AvroCoderTestPojo.class.getName()); http://git-wip-us.apache.org/repos/asf/beam/blob/bccea9dc/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/InterceptingUrlClassLoader.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/InterceptingUrlClassLoader.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/InterceptingUrlClassLoader.java new file mode 100644 index 0000000..b5adcb5 --- /dev/null +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/InterceptingUrlClassLoader.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.testing; + +import com.google.common.collect.Sets; +import com.google.common.io.ByteStreams; +import java.io.IOException; +import java.util.Set; + +/** + * A classloader that intercepts loading of specifically named classes. This classloader copies + * the original classes definition and is useful for testing code which needs to validate usage + * with multiple classloaders.. + */ +public class InterceptingUrlClassLoader extends ClassLoader { + private final Set<String> ownedClasses; + + public InterceptingUrlClassLoader(final ClassLoader parent, final String... ownedClasses) { + super(parent); + this.ownedClasses = Sets.newHashSet(ownedClasses); + } + + @Override + public Class<?> loadClass(final String name) throws ClassNotFoundException { + final Class<?> alreadyLoaded = super.findLoadedClass(name); + if (alreadyLoaded != null) { + return alreadyLoaded; + } + + if (name != null && ownedClasses.contains(name)) { + try { + final String classAsResource = name.replace('.', '/') + ".class"; + final byte[] classBytes = + ByteStreams.toByteArray(getParent().getResourceAsStream(classAsResource)); + return defineClass(name, classBytes, 0, classBytes.length); + } catch (final IOException e) { + throw new RuntimeException(e); + } + } + return getParent().loadClass(name); + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/bccea9dc/sdks/java/core/src/test/java/org/apache/beam/sdk/util/SerializableUtilsTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/SerializableUtilsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/SerializableUtilsTest.java index 9a80730..c3b0171 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/SerializableUtilsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/SerializableUtilsTest.java @@ -18,8 +18,10 @@ package org.apache.beam.sdk.util; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; import com.google.common.collect.ImmutableList; + import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; @@ -28,6 +30,9 @@ import java.util.List; import org.apache.beam.sdk.coders.AtomicCoder; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderException; +import org.apache.beam.sdk.io.BoundedSource; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.testing.InterceptingUrlClassLoader; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; @@ -51,6 +56,30 @@ public class SerializableUtilsTest { } @Test + public void customClassLoader() throws Exception { + // define a classloader with test-classes in it + final ClassLoader testLoader = Thread.currentThread().getContextClassLoader(); + final ClassLoader loader = new InterceptingUrlClassLoader(testLoader, MySource.class.getName()); + final Class<?> source = loader.loadClass( + "org.apache.beam.sdk.util.SerializableUtilsTest$MySource"); + assertNotSame(source.getClassLoader(), MySource.class.getClassLoader()); + + // validate if the caller set the classloader that it works well + final Serializable customLoaderSourceInstance = Serializable.class.cast( + source.getConstructor().newInstance()); + final Thread thread = Thread.currentThread(); + thread.setContextClassLoader(loader); + try { + assertSerializationClassLoader(loader, customLoaderSourceInstance); + } finally { + thread.setContextClassLoader(testLoader); + } + + // now let beam be a little be more fancy and try to ensure it by itself from the incoming value + assertSerializationClassLoader(loader, customLoaderSourceInstance); + } + + @Test public void testTranscode() { String stringValue = "hi bob"; int intValue = 42; @@ -114,4 +143,35 @@ public class SerializableUtilsTest { expectedException.expectMessage("unable to serialize"); SerializableUtils.ensureSerializable(new UnserializableCoderByJava()); } + + private void assertSerializationClassLoader( + final ClassLoader loader, final Serializable customLoaderSourceInstance) { + final Serializable copy = SerializableUtils.ensureSerializable(customLoaderSourceInstance); + assertEquals(loader, copy.getClass().getClassLoader()); + assertEquals( + copy.getClass().getClassLoader(), + customLoaderSourceInstance.getClass().getClassLoader()); + } + + /** + * a sample class to test framework serialization, + * {@see SerializableUtilsTest#customClassLoader}. + */ + public static class MySource extends BoundedSource<String> { + @Override + public List<? extends BoundedSource<String>> split( + final long desiredBundleSizeBytes, final PipelineOptions options) throws Exception { + return null; + } + + @Override + public long getEstimatedSizeBytes(final PipelineOptions options) throws Exception { + return 0; + } + + @Override + public BoundedReader<String> createReader(final PipelineOptions options) throws IOException { + return null; + } + } }