Repository: incubator-beam Updated Branches: refs/heads/master 7855f29c5 -> 09f6aa607
Add an Enforcement enum This tracks enabled enforcements independently of the PipelineOptions object. Move utility methods to construct enforcements and bundle factories based around enabled enforcements to the Enforcement enum. Only apply Immutability Enforcement if Immutability Enforcement is applicable to the producing PTransform. Remove EncodabilityEnforcement, as its responsibilities are handled in CloningBundleFactory. Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/a3b80d1e Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/a3b80d1e Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/a3b80d1e Branch: refs/heads/master Commit: a3b80d1ea0919ec260b97fc8529b3625895d780e Parents: 7855f29 Author: Thomas Groh <tg...@google.com> Authored: Mon Nov 7 11:30:22 2016 -0800 Committer: Thomas Groh <tg...@google.com> Committed: Tue Nov 8 11:10:49 2016 -0800 ---------------------------------------------------------------------- .../beam/runners/direct/DirectRunner.java | 114 ++++--- .../direct/EncodabilityEnforcementFactory.java | 80 ----- .../ImmutabilityCheckingBundleFactory.java | 11 +- .../direct/ImmutabilityEnforcementFactory.java | 2 - .../direct/CloningBundleFactoryTest.java | 122 ++++++- .../beam/runners/direct/DirectRunnerTest.java | 26 ++ .../EncodabilityEnforcementFactoryTest.java | 323 ------------------- 7 files changed, 224 insertions(+), 454 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a3b80d1e/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java index 4d5a449..f4aeb3e 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java @@ -24,8 +24,11 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import java.io.IOException; import java.util.Collection; +import java.util.Collections; +import java.util.EnumSet; import java.util.HashMap; import java.util.Map; +import java.util.Set; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import javax.annotation.Nullable; @@ -189,8 +192,72 @@ public class DirectRunner void add(Iterable<WindowedValue<ElemT>> values); } + /** The set of {@link PTransform PTransforms} that execute a UDF. Useful for some enforcements. */ + private static final Set<Class<? extends PTransform>> CONTAINS_UDF = + ImmutableSet.of( + Read.Bounded.class, Read.Unbounded.class, ParDo.Bound.class, ParDo.BoundMulti.class); + + enum Enforcement { + ENCODABILITY { + @Override + public boolean appliesTo(PTransform<?, ?> transform) { + return true; + } + }, + IMMUTABILITY { + @Override + public boolean appliesTo(PTransform<?, ?> transform) { + return CONTAINS_UDF.contains(transform.getClass()); + } + }; + + public abstract boolean appliesTo(PTransform<?, ?> transform); + + //////////////////////////////////////////////////////////////////////////////////////////////// + // Utilities for creating enforcements + public static Set<Enforcement> enabled(DirectOptions options) { + EnumSet<Enforcement> enabled = EnumSet.noneOf(Enforcement.class); + if (options.isEnforceEncodability()) { + enabled.add(ENCODABILITY); + } + if (options.isEnforceImmutability()) { + enabled.add(IMMUTABILITY); + } + return Collections.unmodifiableSet(enabled); + } + + public static BundleFactory bundleFactoryFor(Set<Enforcement> enforcements) { + BundleFactory bundleFactory = + enforcements.contains(Enforcement.ENCODABILITY) + ? CloningBundleFactory.create() + : ImmutableListBundleFactory.create(); + if (enforcements.contains(Enforcement.IMMUTABILITY)) { + bundleFactory = ImmutabilityCheckingBundleFactory.create(bundleFactory); + } + return bundleFactory; + } + + @SuppressWarnings("rawtypes") + private static Map<Class<? extends PTransform>, Collection<ModelEnforcementFactory>> + defaultModelEnforcements(Set<Enforcement> enabledEnforcements) { + ImmutableMap.Builder<Class<? extends PTransform>, Collection<ModelEnforcementFactory>> + enforcements = ImmutableMap.builder(); + ImmutableList.Builder<ModelEnforcementFactory> enabledParDoEnforcements = + ImmutableList.builder(); + if (enabledEnforcements.contains(Enforcement.IMMUTABILITY)) { + enabledParDoEnforcements.add(ImmutabilityEnforcementFactory.create()); + } + Collection<ModelEnforcementFactory> parDoEnforcements = enabledParDoEnforcements.build(); + enforcements.put(ParDo.Bound.class, parDoEnforcements); + enforcements.put(ParDo.BoundMulti.class, parDoEnforcements); + return enforcements.build(); + } + + } + //////////////////////////////////////////////////////////////////////////////////////////////// private final DirectOptions options; + private final Set<Enforcement> enabledEnforcements; private Supplier<ExecutorService> executorServiceSupplier; private Supplier<Clock> clockSupplier = new NanosOffsetClockSupplier(); @@ -200,6 +267,7 @@ public class DirectRunner private DirectRunner(DirectOptions options) { this.options = options; + this.enabledEnforcements = Enforcement.enabled(options); this.executorServiceSupplier = new FixedThreadPoolSupplier(options); } @@ -252,7 +320,7 @@ public class DirectRunner EvaluationContext.create( getPipelineOptions(), clockSupplier.get(), - createBundleFactory(getPipelineOptions()), + Enforcement.bundleFactoryFor(enabledEnforcements), consumerTrackingVisitor.getRootTransforms(), consumerTrackingVisitor.getValueToConsumers(), consumerTrackingVisitor.getStepNames(), @@ -270,7 +338,7 @@ public class DirectRunner keyedPValueVisitor.getKeyedPValues(), rootInputProvider, registry, - defaultModelEnforcements(options), + Enforcement.defaultModelEnforcements(enabledEnforcements), context); executor.start(consumerTrackingVisitor.getRootTransforms()); @@ -292,48 +360,6 @@ public class DirectRunner return result; } - @SuppressWarnings("rawtypes") - private Map<Class<? extends PTransform>, Collection<ModelEnforcementFactory>> - defaultModelEnforcements(DirectOptions options) { - ImmutableMap.Builder<Class<? extends PTransform>, Collection<ModelEnforcementFactory>> - enforcements = ImmutableMap.builder(); - Collection<ModelEnforcementFactory> parDoEnforcements = createParDoEnforcements(options); - enforcements.put(ParDo.Bound.class, parDoEnforcements); - enforcements.put(ParDo.BoundMulti.class, parDoEnforcements); - if (options.isEnforceEncodability()) { - enforcements.put( - Read.Unbounded.class, - ImmutableSet.<ModelEnforcementFactory>of(EncodabilityEnforcementFactory.create())); - enforcements.put( - Read.Bounded.class, - ImmutableSet.<ModelEnforcementFactory>of(EncodabilityEnforcementFactory.create())); - } - return enforcements.build(); - } - - private Collection<ModelEnforcementFactory> createParDoEnforcements( - DirectOptions options) { - ImmutableList.Builder<ModelEnforcementFactory> enforcements = ImmutableList.builder(); - if (options.isEnforceImmutability()) { - enforcements.add(ImmutabilityEnforcementFactory.create()); - } - if (options.isEnforceEncodability()) { - enforcements.add(EncodabilityEnforcementFactory.create()); - } - return enforcements.build(); - } - - private BundleFactory createBundleFactory(DirectOptions pipelineOptions) { - BundleFactory bundleFactory = - pipelineOptions.isEnforceEncodability() - ? CloningBundleFactory.create() - : ImmutableListBundleFactory.create(); - if (pipelineOptions.isEnforceImmutability()) { - bundleFactory = ImmutabilityCheckingBundleFactory.create(bundleFactory); - } - return bundleFactory; - } - /** * The result of running a {@link Pipeline} with the {@link DirectRunner}. * http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a3b80d1e/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EncodabilityEnforcementFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EncodabilityEnforcementFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EncodabilityEnforcementFactory.java deleted file mode 100644 index 0a5f03f..0000000 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EncodabilityEnforcementFactory.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.direct; - -import static com.google.common.base.Preconditions.checkArgument; - -import org.apache.beam.runners.direct.DirectRunner.CommittedBundle; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.transforms.AppliedPTransform; -import org.apache.beam.sdk.util.CoderUtils; -import org.apache.beam.sdk.util.UserCodeException; -import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.sdk.values.PCollection; - -/** - * Enforces that all elements in a {@link PCollection} can be encoded using that - * {@link PCollection PCollection's} {@link Coder}. - */ -class EncodabilityEnforcementFactory implements ModelEnforcementFactory { - // The factory proper is stateless - private static final EncodabilityEnforcementFactory INSTANCE = - new EncodabilityEnforcementFactory(); - - public static EncodabilityEnforcementFactory create() { - return INSTANCE; - } - - @Override - public <T> ModelEnforcement<T> forBundle( - CommittedBundle<T> input, AppliedPTransform<?, ?, ?> consumer) { - return new EncodabilityEnforcement<>(); - } - - private static class EncodabilityEnforcement<T> extends AbstractModelEnforcement<T> { - @Override - public void afterFinish( - CommittedBundle<T> input, - TransformResult result, - Iterable<? extends CommittedBundle<?>> outputs) { - for (CommittedBundle<?> bundle : outputs) { - ensureBundleEncodable(bundle); - } - } - - private <T> void ensureBundleEncodable(CommittedBundle<T> bundle) { - Coder<T> coder = bundle.getPCollection().getCoder(); - for (WindowedValue<T> element : bundle.getElements()) { - try { - T clone = CoderUtils.clone(coder, element.getValue()); - if (coder.consistentWithEquals()) { - checkArgument( - coder.structuralValue(element.getValue()).equals(coder.structuralValue(clone)), - "Coder %s of class %s does not maintain structural value equality" - + " on input element %s", - coder, - coder.getClass().getSimpleName(), - element.getValue()); - } - } catch (Exception e) { - throw UserCodeException.wrap(e); - } - } - } - } -} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a3b80d1e/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ImmutabilityCheckingBundleFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ImmutabilityCheckingBundleFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ImmutabilityCheckingBundleFactory.java index 08c6e78..4f72f68 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ImmutabilityCheckingBundleFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ImmutabilityCheckingBundleFactory.java @@ -22,6 +22,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.collect.HashMultimap; import com.google.common.collect.SetMultimap; import org.apache.beam.runners.direct.DirectRunner.CommittedBundle; +import org.apache.beam.runners.direct.DirectRunner.Enforcement; import org.apache.beam.runners.direct.DirectRunner.UncommittedBundle; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderException; @@ -71,13 +72,19 @@ class ImmutabilityCheckingBundleFactory implements BundleFactory { @Override public <T> UncommittedBundle<T> createBundle(PCollection<T> output) { - return new ImmutabilityEnforcingBundle<>(underlying.createBundle(output)); + if (Enforcement.IMMUTABILITY.appliesTo(output.getProducingTransformInternal().getTransform())) { + return new ImmutabilityEnforcingBundle<>(underlying.createBundle(output)); + } + return underlying.createBundle(output); } @Override public <K, T> UncommittedBundle<T> createKeyedBundle( StructuralKey<K> key, PCollection<T> output) { - return new ImmutabilityEnforcingBundle<>(underlying.createKeyedBundle(key, output)); + if (Enforcement.IMMUTABILITY.appliesTo(output.getProducingTransformInternal().getTransform())) { + return new ImmutabilityEnforcingBundle<>(underlying.createKeyedBundle(key, output)); + } + return underlying.createKeyedBundle(key, output); } private static class ImmutabilityEnforcingBundle<T> implements UncommittedBundle<T> { http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a3b80d1e/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ImmutabilityEnforcementFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ImmutabilityEnforcementFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ImmutabilityEnforcementFactory.java index 1602f68..612922a 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ImmutabilityEnforcementFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ImmutabilityEnforcementFactory.java @@ -32,8 +32,6 @@ import org.apache.beam.sdk.util.WindowedValue; /** * {@link ModelEnforcement} that enforces elements are not modified over the course of processing * an element. - * - * <p>Implies {@link EncodabilityEnforcment}. */ class ImmutabilityEnforcementFactory implements ModelEnforcementFactory { public static ModelEnforcementFactory create() { http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a3b80d1e/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CloningBundleFactoryTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CloningBundleFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CloningBundleFactoryTest.java index 03846d9..bafab59 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CloningBundleFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CloningBundleFactoryTest.java @@ -28,18 +28,21 @@ import static org.junit.Assert.assertThat; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; import org.apache.beam.runners.direct.DirectRunner.CommittedBundle; import org.apache.beam.runners.direct.DirectRunner.UncommittedBundle; -import org.apache.beam.runners.direct.EncodabilityEnforcementFactoryTest.Record; -import org.apache.beam.runners.direct.EncodabilityEnforcementFactoryTest.RecordNoDecodeCoder; -import org.apache.beam.runners.direct.EncodabilityEnforcementFactoryTest.RecordNoEncodeCoder; +import org.apache.beam.sdk.coders.AtomicCoder; import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.GroupByKey; +import org.apache.beam.sdk.transforms.SimpleFunction; import org.apache.beam.sdk.transforms.WithKeys; import org.apache.beam.sdk.util.UserCodeException; import org.apache.beam.sdk.util.WindowedValue; @@ -174,4 +177,117 @@ public class CloningBundleFactoryTest { thrown.expectMessage("Decode not allowed"); bundle.add(WindowedValue.valueInGlobalWindow(new Record())); } + + static class Record {} + static class RecordNoEncodeCoder extends AtomicCoder<Record> { + + @Override + public void encode( + Record value, + OutputStream outStream, + org.apache.beam.sdk.coders.Coder.Context context) + throws IOException { + throw new CoderException("Encode not allowed"); + } + + @Override + public Record decode( + InputStream inStream, org.apache.beam.sdk.coders.Coder.Context context) + throws IOException { + return null; + } + } + + static class RecordNoDecodeCoder extends AtomicCoder<Record> { + @Override + public void encode( + Record value, + OutputStream outStream, + org.apache.beam.sdk.coders.Coder.Context context) + throws IOException {} + + @Override + public Record decode( + InputStream inStream, org.apache.beam.sdk.coders.Coder.Context context) + throws IOException { + throw new CoderException("Decode not allowed"); + } + } + + private static class RecordStructuralValueCoder extends AtomicCoder<Record> { + @Override + public void encode( + Record value, + OutputStream outStream, + org.apache.beam.sdk.coders.Coder.Context context) + throws CoderException, IOException {} + + @Override + public Record decode( + InputStream inStream, org.apache.beam.sdk.coders.Coder.Context context) + throws CoderException, IOException { + return new Record() { + @Override + public String toString() { + return "DecodedRecord"; + } + }; + } + + @Override + public boolean consistentWithEquals() { + return true; + } + + @Override + public Object structuralValue(Record value) { + return value; + } + } + + private static class RecordNotConsistentWithEqualsStructuralValueCoder + extends AtomicCoder<Record> { + @Override + public void encode( + Record value, + OutputStream outStream, + org.apache.beam.sdk.coders.Coder.Context context) + throws CoderException, IOException {} + + @Override + public Record decode( + InputStream inStream, org.apache.beam.sdk.coders.Coder.Context context) + throws CoderException, IOException { + return new Record() { + @Override + public String toString() { + return "DecodedRecord"; + } + }; + } + + @Override + public boolean consistentWithEquals() { + return false; + } + + @Override + public Object structuralValue(Record value) { + return value; + } + } + + private static class IdentityDoFn extends DoFn<Record, Record> { + @ProcessElement + public void proc(ProcessContext ctxt) { + ctxt.output(ctxt.element()); + } + } + + private static class SimpleIdentity extends SimpleFunction<Record, Record> { + @Override + public Record apply(Record input) { + return input; + } + } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a3b80d1e/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRunnerTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRunnerTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRunnerTest.java index 34a5469..3836f58 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRunnerTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRunnerTest.java @@ -271,6 +271,32 @@ public class DirectRunnerTest implements Serializable { * {@link DirectRunner}. */ @Test + public void testMutatingOutputWithEnforcementDisabledSucceeds() throws Exception { + PipelineOptions options = PipelineOptionsFactory.create(); + options.setRunner(DirectRunner.class); + options.as(DirectOptions.class).setEnforceImmutability(false); + Pipeline pipeline = Pipeline.create(options); + + pipeline + .apply(Create.of(42)) + .apply(ParDo.of(new DoFn<Integer, List<Integer>>() { + @ProcessElement + public void processElement(ProcessContext c) { + List<Integer> outputList = Arrays.asList(1, 2, 3, 4); + c.output(outputList); + outputList.set(0, 37); + c.output(outputList); + } + })); + + pipeline.run(); + } + + /** + * Tests that a {@link DoFn} that mutates an output with a good equals() fails in the + * {@link DirectRunner}. + */ + @Test public void testMutatingOutputThenTerminateDoFnError() throws Exception { Pipeline pipeline = getPipeline(); http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a3b80d1e/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EncodabilityEnforcementFactoryTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EncodabilityEnforcementFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EncodabilityEnforcementFactoryTest.java deleted file mode 100644 index e6bdbd0..0000000 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EncodabilityEnforcementFactoryTest.java +++ /dev/null @@ -1,323 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.direct; - -import static org.hamcrest.Matchers.isA; - -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.util.Collections; -import org.apache.beam.runners.direct.DirectRunner.CommittedBundle; -import org.apache.beam.runners.direct.DirectRunner.UncommittedBundle; -import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.coders.AtomicCoder; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.CoderException; -import org.apache.beam.sdk.coders.VarIntCoder; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.transforms.AppliedPTransform; -import org.apache.beam.sdk.transforms.Count; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.MapElements; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.transforms.SimpleFunction; -import org.apache.beam.sdk.util.UserCodeException; -import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.sdk.values.PCollection; -import org.joda.time.Instant; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** - * Tests for {@link EncodabilityEnforcementFactory}. - */ -@RunWith(JUnit4.class) -public class EncodabilityEnforcementFactoryTest { - @Rule public ExpectedException thrown = ExpectedException.none(); - private EncodabilityEnforcementFactory factory = EncodabilityEnforcementFactory.create(); - private BundleFactory bundleFactory = ImmutableListBundleFactory.create(); - - private PCollection<Record> inputPCollection; - private CommittedBundle<Record> inputBundle; - private PCollection<Record> outputPCollection; - - @Before - public void setup() { - Pipeline p = TestPipeline.create(); - inputPCollection = p.apply(Create.of(new Record()).withCoder(new RecordNoDecodeCoder())); - outputPCollection = inputPCollection.apply(ParDo.of(new IdentityDoFn())); - - inputBundle = - bundleFactory - .<Record>createRootBundle() - .add(WindowedValue.valueInGlobalWindow(new Record())) - .commit(Instant.now()); - } - - @Test - public void encodeFailsThrows() { - WindowedValue<Record> record = WindowedValue.valueInGlobalWindow(new Record()); - - ModelEnforcement<Record> enforcement = createEnforcement(new RecordNoEncodeCoder(), record); - - UncommittedBundle<Record> output = - bundleFactory.createBundle(outputPCollection).add(record); - - enforcement.beforeElement(record); - enforcement.afterElement(record); - thrown.expect(UserCodeException.class); - thrown.expectCause(isA(CoderException.class)); - thrown.expectMessage("Encode not allowed"); - enforcement.afterFinish( - inputBundle, - StepTransformResult.withoutHold(outputPCollection.getProducingTransformInternal()) - .addOutput(output) - .build(), - Collections.<CommittedBundle<?>>singleton(output.commit(Instant.now()))); - } - - @Test - public void decodeFailsThrows() { - WindowedValue<Record> record = WindowedValue.valueInGlobalWindow(new Record()); - - ModelEnforcement<Record> enforcement = createEnforcement(new RecordNoDecodeCoder(), record); - - UncommittedBundle<Record> output = - bundleFactory.createBundle(outputPCollection).add(record); - - enforcement.beforeElement(record); - enforcement.afterElement(record); - thrown.expect(UserCodeException.class); - thrown.expectCause(isA(CoderException.class)); - thrown.expectMessage("Decode not allowed"); - enforcement.afterFinish( - inputBundle, - StepTransformResult.withoutHold(outputPCollection.getProducingTransformInternal()) - .addOutput(output) - .build(), - Collections.<CommittedBundle<?>>singleton(output.commit(Instant.now()))); - } - - @Test - public void consistentWithEqualsStructuralValueNotEqualThrows() { - WindowedValue<Record> record = - WindowedValue.<Record>valueInGlobalWindow( - new Record() { - @Override - public String toString() { - return "OriginalRecord"; - } - }); - - ModelEnforcement<Record> enforcement = - createEnforcement(new RecordStructuralValueCoder(), record); - - UncommittedBundle<Record> output = - bundleFactory.createBundle(outputPCollection).add(record); - - enforcement.beforeElement(record); - enforcement.afterElement(record); - - thrown.expect(UserCodeException.class); - thrown.expectCause(isA(IllegalArgumentException.class)); - thrown.expectMessage("does not maintain structural value equality"); - thrown.expectMessage(RecordStructuralValueCoder.class.getSimpleName()); - thrown.expectMessage("OriginalRecord"); - enforcement.afterFinish( - inputBundle, - StepTransformResult.withoutHold(outputPCollection.getProducingTransformInternal()) - .addOutput(output) - .build(), - Collections.<CommittedBundle<?>>singleton(output.commit(Instant.now()))); - } - - @Test - public void notConsistentWithEqualsStructuralValueNotEqualSucceeds() { - outputPCollection.setCoder(new RecordNotConsistentWithEqualsStructuralValueCoder()); - WindowedValue<Record> record = WindowedValue.<Record>valueInGlobalWindow(new Record()); - - ModelEnforcement<Record> enforcement = - factory.forBundle(inputBundle, outputPCollection.getProducingTransformInternal()); - - UncommittedBundle<Record> output = - bundleFactory.createBundle(outputPCollection).add(record); - - enforcement.beforeElement(record); - enforcement.afterElement(record); - enforcement.afterFinish( - inputBundle, - StepTransformResult.withoutHold(outputPCollection.getProducingTransformInternal()) - .addOutput(output) - .build(), - Collections.<CommittedBundle<?>>singleton(output.commit(Instant.now()))); - } - - private ModelEnforcement<Record> createEnforcement( - Coder<Record> coder, WindowedValue<Record> record) { - TestPipeline p = TestPipeline.create(); - PCollection<Record> unencodable = p.apply(Create.<Record>of().withCoder(coder)); - outputPCollection = - unencodable.apply( - MapElements.via(new SimpleIdentity())); - AppliedPTransform<?, ?, ?> consumer = outputPCollection.getProducingTransformInternal(); - CommittedBundle<Record> input = - bundleFactory.createBundle(unencodable).add(record).commit(Instant.now()); - ModelEnforcement<Record> enforcement = factory.forBundle(input, consumer); - return enforcement; - } - - @Test - public void structurallyEqualResultsSucceeds() { - TestPipeline p = TestPipeline.create(); - PCollection<Integer> unencodable = p.apply(Create.of(1).withCoder(VarIntCoder.of())); - AppliedPTransform<?, ?, ?> consumer = - unencodable.apply(Count.<Integer>globally()).getProducingTransformInternal(); - - WindowedValue<Integer> value = WindowedValue.valueInGlobalWindow(1); - - CommittedBundle<Integer> input = - bundleFactory.createBundle(unencodable).add(value).commit(Instant.now()); - ModelEnforcement<Integer> enforcement = factory.forBundle(input, consumer); - - enforcement.beforeElement(value); - enforcement.afterElement(value); - enforcement.afterFinish( - input, - StepTransformResult.withoutHold(consumer).build(), - Collections.<CommittedBundle<?>>emptyList()); - } - - static class Record {} - static class RecordNoEncodeCoder extends AtomicCoder<Record> { - - @Override - public void encode( - Record value, - OutputStream outStream, - org.apache.beam.sdk.coders.Coder.Context context) - throws IOException { - throw new CoderException("Encode not allowed"); - } - - @Override - public Record decode( - InputStream inStream, org.apache.beam.sdk.coders.Coder.Context context) - throws IOException { - return null; - } - } - - static class RecordNoDecodeCoder extends AtomicCoder<Record> { - @Override - public void encode( - Record value, - OutputStream outStream, - org.apache.beam.sdk.coders.Coder.Context context) - throws IOException {} - - @Override - public Record decode( - InputStream inStream, org.apache.beam.sdk.coders.Coder.Context context) - throws IOException { - throw new CoderException("Decode not allowed"); - } - } - - private static class RecordStructuralValueCoder extends AtomicCoder<Record> { - @Override - public void encode( - Record value, - OutputStream outStream, - org.apache.beam.sdk.coders.Coder.Context context) - throws CoderException, IOException {} - - @Override - public Record decode( - InputStream inStream, org.apache.beam.sdk.coders.Coder.Context context) - throws CoderException, IOException { - return new Record() { - @Override - public String toString() { - return "DecodedRecord"; - } - }; - } - - @Override - public boolean consistentWithEquals() { - return true; - } - - @Override - public Object structuralValue(Record value) { - return value; - } - } - - private static class RecordNotConsistentWithEqualsStructuralValueCoder - extends AtomicCoder<Record> { - @Override - public void encode( - Record value, - OutputStream outStream, - org.apache.beam.sdk.coders.Coder.Context context) - throws CoderException, IOException {} - - @Override - public Record decode( - InputStream inStream, org.apache.beam.sdk.coders.Coder.Context context) - throws CoderException, IOException { - return new Record() { - @Override - public String toString() { - return "DecodedRecord"; - } - }; - } - - @Override - public boolean consistentWithEquals() { - return false; - } - - @Override - public Object structuralValue(Record value) { - return value; - } - } - - private static class IdentityDoFn extends DoFn<Record, Record> { - @ProcessElement - public void proc(ProcessContext ctxt) { - ctxt.output(ctxt.element()); - } - } - - private static class SimpleIdentity extends SimpleFunction<Record, Record> { - @Override - public Record apply(Record input) { - return input; - } - } -}