Use Graph Surgery in the DirectRunner. Remove DirectRunner#apply(). This migrates the DirectRunner to work on a runner-agnostic graph.
Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/77a1afb2 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/77a1afb2 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/77a1afb2 Branch: refs/heads/master Commit: 77a1afb2efc20076dcba207d8d7303c5635d4daf Parents: 67d02b9 Author: Thomas Groh <tg...@google.com> Authored: Thu Feb 9 11:45:06 2017 -0800 Committer: Thomas Groh <tg...@google.com> Committed: Thu Feb 16 18:57:58 2017 -0800 ---------------------------------------------------------------------- .../translation/ParDoBoundMultiTranslator.java | 5 ++ .../apex/translation/ParDoBoundTranslator.java | 5 ++ .../beam/runners/direct/DirectRunner.java | 88 ++++++++++++-------- .../beam/runners/direct/EvaluationContext.java | 3 +- .../direct/TestStreamEvaluatorFactory.java | 3 +- .../direct/WriteWithShardingFactory.java | 24 +++--- .../beam/runners/direct/DirectRunnerTest.java | 3 +- .../direct/KeyedPValueTrackingVisitorTest.java | 23 ++++- .../StatefulParDoEvaluatorFactoryTest.java | 57 ++++++++----- .../direct/TestStreamEvaluatorFactoryTest.java | 20 ++--- .../direct/ViewEvaluatorFactoryTest.java | 4 +- .../FlinkBatchTransformTranslators.java | 12 +++ .../FlinkStreamingTransformTranslators.java | 12 +++ .../dataflow/DataflowPipelineTranslator.java | 7 ++ .../spark/translation/TransformTranslator.java | 3 + .../spark/translation/TranslationUtils.java | 9 ++ .../streaming/StreamingTransformTranslator.java | 3 + .../org/apache/beam/sdk/testing/TestStream.java | 10 +-- .../org/apache/beam/sdk/transforms/ParDo.java | 9 +- .../apache/beam/sdk/testing/TestStreamTest.java | 17 ---- .../apache/beam/sdk/transforms/ParDoTest.java | 30 ------- 21 files changed, 200 insertions(+), 147 deletions(-) ---------------------------------------------------------------------- 
http://git-wip-us.apache.org/repos/asf/beam/blob/77a1afb2/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoBoundMultiTranslator.java ---------------------------------------------------------------------- diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoBoundMultiTranslator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoBoundMultiTranslator.java index 2439020..f55b48c 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoBoundMultiTranslator.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoBoundMultiTranslator.java @@ -56,6 +56,11 @@ class ParDoBoundMultiTranslator<InputT, OutputT> DoFn<InputT, OutputT> doFn = transform.getFn(); DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass()); + if (signature.processElement().isSplittable()) { + throw new UnsupportedOperationException( + String.format( + "%s does not support splittable DoFn: %s", ApexRunner.class.getSimpleName(), doFn)); + } if (signature.stateDeclarations().size() > 0) { throw new UnsupportedOperationException( String.format( http://git-wip-us.apache.org/repos/asf/beam/blob/77a1afb2/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoBoundTranslator.java ---------------------------------------------------------------------- diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoBoundTranslator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoBoundTranslator.java index c24250f..5195809 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoBoundTranslator.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoBoundTranslator.java @@ -43,6 +43,11 @@ class ParDoBoundTranslator<InputT, OutputT> DoFn<InputT, OutputT> doFn = transform.getFn(); DoFnSignature signature = 
DoFnSignatures.getSignature(doFn.getClass()); + if (signature.processElement().isSplittable()) { + throw new UnsupportedOperationException( + String.format( + "%s does not support splittable DoFn: %s", ApexRunner.class.getSimpleName(), doFn)); + } if (signature.stateDeclarations().size() > 0) { throw new UnsupportedOperationException( String.format( http://git-wip-us.apache.org/repos/asf/beam/blob/77a1afb2/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java index 40ef60e..06189a2 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java @@ -30,8 +30,9 @@ import java.util.HashMap; import java.util.Map; import java.util.Set; import javax.annotation.Nullable; -import org.apache.beam.runners.core.SplittableParDo; +import org.apache.beam.runners.core.SplittableParDo.GBKIntoKeyedWorkItems; import org.apache.beam.runners.core.TimerInternals.TimerData; +import org.apache.beam.runners.core.construction.PTransformMatchers; import org.apache.beam.runners.direct.DirectRunner.DirectPipelineResult; import org.apache.beam.runners.direct.TestStreamEvaluatorFactory.DirectTestStreamFactory; import org.apache.beam.runners.direct.ViewEvaluatorFactory.ViewOverrideFactory; @@ -41,10 +42,11 @@ import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.Pipeline.PipelineExecutionException; import org.apache.beam.sdk.PipelineResult; import org.apache.beam.sdk.io.Read; -import org.apache.beam.sdk.io.Write; +import org.apache.beam.sdk.io.Write.Bound; import org.apache.beam.sdk.metrics.MetricResults; import org.apache.beam.sdk.metrics.MetricsEnvironment; import 
org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.runners.PTransformMatcher; import org.apache.beam.sdk.runners.PTransformOverrideFactory; import org.apache.beam.sdk.runners.PipelineRunner; import org.apache.beam.sdk.testing.TestStream; @@ -53,14 +55,13 @@ import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.ParDo.BoundMulti; import org.apache.beam.sdk.transforms.View.CreatePCollectionView; import org.apache.beam.sdk.util.UserCodeException; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollection.IsBounded; import org.apache.beam.sdk.values.PCollectionView; -import org.apache.beam.sdk.values.PInput; -import org.apache.beam.sdk.values.POutput; import org.joda.time.Duration; import org.joda.time.Instant; @@ -72,23 +73,49 @@ public class DirectRunner extends PipelineRunner<DirectPipelineResult> { /** * The default set of transform overrides to use in the {@link DirectRunner}. * - * <p>A transform override must have a single-argument constructor that takes an instance of the - * type of transform it is overriding. + * <p>The order in which overrides is applied is important, as some overrides are expanded into a + * composite. If the composite contains {@link PTransform PTransforms} which are also overridden, + * these PTransforms must occur later in the iteration order. {@link ImmutableMap} has an + * iteration order based on the order at which elements are added to it. */ @SuppressWarnings("rawtypes") - private static Map<Class<? extends PTransform>, PTransformOverrideFactory> - defaultTransformOverrides = - ImmutableMap.<Class<? 
extends PTransform>, PTransformOverrideFactory>builder() - .put(CreatePCollectionView.class, new ViewOverrideFactory()) - .put(GroupByKey.class, new DirectGroupByKeyOverrideFactory()) - .put(TestStream.class, new DirectTestStreamFactory()) - .put(Write.Bound.class, new WriteWithShardingFactory()) - .put(ParDo.Bound.class, new ParDoSingleViaMultiOverrideFactory()) - .put(ParDo.BoundMulti.class, new ParDoMultiOverrideFactory()) - .put( - SplittableParDo.GBKIntoKeyedWorkItems.class, - new DirectGBKIntoKeyedWorkItemsOverrideFactory()) - .build(); + private static Map<PTransformMatcher, PTransformOverrideFactory> defaultTransformOverrides = + ImmutableMap.<PTransformMatcher, PTransformOverrideFactory>builder() + .put( + PTransformMatchers.classEqualTo(Bound.class), + new WriteWithShardingFactory()) /* Uses a view internally. */ + .put( + PTransformMatchers.classEqualTo(CreatePCollectionView.class), + new ViewOverrideFactory()) /* Uses pardos and GBKs */ + .put( + PTransformMatchers.classEqualTo(TestStream.class), + new DirectTestStreamFactory()) /* primitive */ + /* Single-output ParDos are implemented in terms of Multi-output ParDos. Any override + that is applied to a multi-output ParDo must first have all matching Single-output ParDos + converted to match. 
+ */ + .put(PTransformMatchers.splittableParDoSingle(), new ParDoSingleViaMultiOverrideFactory()) + .put( + PTransformMatchers.stateOrTimerParDoSingle(), + new ParDoSingleViaMultiOverrideFactory()) + // SplittableParMultiDo is implemented in terms of nonsplittable single ParDos + .put(PTransformMatchers.splittableParDoMulti(), new ParDoMultiOverrideFactory()) + // state and timer pardos are implemented in terms of nonsplittable single ParDos + .put(PTransformMatchers.stateOrTimerParDoMulti(), new ParDoMultiOverrideFactory()) + .put( + PTransformMatchers.classEqualTo(ParDo.Bound.class), + new ParDoSingleViaMultiOverrideFactory()) /* returns a BoundMulti */ + .put( + PTransformMatchers.classEqualTo(BoundMulti.class), + /* returns one of two primitives; SplittableParDos are replaced above. */ + new ParDoMultiOverrideFactory()) + .put( + PTransformMatchers.classEqualTo(GBKIntoKeyedWorkItems.class), + new DirectGBKIntoKeyedWorkItemsOverrideFactory()) /* Returns a GBKO */ + .put( + PTransformMatchers.classEqualTo(GroupByKey.class), + new DirectGroupByKeyOverrideFactory()) /* returns two chained primitives. */ + .build(); /** * Part of a {@link PCollection}. 
Elements are output to a bundle, which will cause them to be @@ -281,23 +308,11 @@ public class DirectRunner extends PipelineRunner<DirectPipelineResult> { } @Override - public <OutputT extends POutput, InputT extends PInput> OutputT apply( - PTransform<InputT, OutputT> transform, InputT input) { - PTransformOverrideFactory<InputT, OutputT, PTransform<InputT, OutputT>> overrideFactory = - defaultTransformOverrides.get(transform.getClass()); - if (overrideFactory != null) { - PTransform<InputT, OutputT> customTransform = - overrideFactory.getReplacementTransform(transform); - if (customTransform != transform) { - return Pipeline.applyTransform(transform.getName(), input, customTransform); - } - } - // If there is no override, or we should not apply the override, apply the original transform - return super.apply(transform, input); - } - - @Override public DirectPipelineResult run(Pipeline pipeline) { + for (Map.Entry<PTransformMatcher, PTransformOverrideFactory> override : + defaultTransformOverrides.entrySet()) { + pipeline.replace(override.getKey(), override.getValue()); + } MetricsEnvironment.setMetricsSupported(true); DirectGraphVisitor graphVisitor = new DirectGraphVisitor(); pipeline.traverseTopologically(graphVisitor); @@ -464,4 +479,7 @@ public class DirectRunner extends PipelineRunner<DirectPipelineResult> { return NanosOffsetClock.create(); } } + + private static class ComplexParDoMatcher { + } } http://git-wip-us.apache.org/repos/asf/beam/blob/77a1afb2/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java index 69752fa..49c9ec2 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java +++ 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java @@ -229,7 +229,8 @@ class EvaluationContext { private void fireAvailableCallbacks(AppliedPTransform<?, ?, ?> producingTransform) { TransformWatermarks watermarks = watermarkManager.getWatermarks(producingTransform); - callbackExecutor.fireForWatermark(producingTransform, watermarks.getOutputWatermark()); + Instant outputWatermark = watermarks.getOutputWatermark(); + callbackExecutor.fireForWatermark(producingTransform, outputWatermark); } /** http://git-wip-us.apache.org/repos/asf/beam/blob/77a1afb2/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java index 53e2671..628aa23 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java @@ -187,7 +187,8 @@ class TestStreamEvaluatorFactory implements TransformEvaluatorFactory { static class DirectTestStream<T> extends PTransform<PBegin, PCollection<T>> { private final TestStream<T> original; - private DirectTestStream(TestStream<T> transform) { + @VisibleForTesting + DirectTestStream(TestStream<T> transform) { this.original = transform; } http://git-wip-us.apache.org/repos/asf/beam/blob/77a1afb2/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java 
index 966ce4e..83c82a5 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java @@ -88,15 +88,19 @@ class WriteWithShardingFactory<InputT> @Override public PDone expand(PCollection<T> input) { - checkArgument(IsBounded.BOUNDED == input.isBounded(), + checkArgument( + IsBounded.BOUNDED == input.isBounded(), "%s can only be applied to a Bounded PCollection", getClass().getSimpleName()); - PCollection<T> records = input.apply("RewindowInputs", - Window.<T>into(new GlobalWindows()).triggering(DefaultTrigger.of()) - .withAllowedLateness(Duration.ZERO) - .discardingFiredPanes()); - final PCollectionView<Long> numRecords = records - .apply("CountRecords", Count.<T>globally().asSingletonView()); + PCollection<T> records = + input.apply( + "RewindowInputs", + Window.<T>into(new GlobalWindows()) + .triggering(DefaultTrigger.of()) + .withAllowedLateness(Duration.ZERO) + .discardingFiredPanes()); + final PCollectionView<Long> numRecords = + records.apply("CountRecords", Count.<T>globally().asSingletonView()); PCollection<T> resharded = records .apply( @@ -113,15 +117,13 @@ class WriteWithShardingFactory<InputT> // without adding a new Write Transform Node, which would be overwritten the same way, leading // to an infinite recursion. We cannot modify the number of shards, because that is determined // at runtime. 
- return original.expand(resharded); + return resharded.apply(original); } } @VisibleForTesting static class KeyBasedOnCountFn<T> extends DoFn<T, KV<Integer, T>> { - @VisibleForTesting - static final int MIN_SHARDS_FOR_LOG = 3; - + @VisibleForTesting static final int MIN_SHARDS_FOR_LOG = 3; private final PCollectionView<Long> numRecords; private final int randomExtraShards; private int currentShard; http://git-wip-us.apache.org/repos/asf/beam/blob/77a1afb2/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRunnerTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRunnerTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRunnerTest.java index ac1689d..d2b6d1d 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRunnerTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectRunnerTest.java @@ -83,8 +83,7 @@ public class DirectRunnerTest implements Serializable { PipelineOptions opts = PipelineOptionsFactory.create(); opts.setRunner(DirectRunner.class); - Pipeline p = Pipeline.create(opts); - return p; + return Pipeline.create(opts); } @Test http://git-wip-us.apache.org/repos/asf/beam/blob/77a1afb2/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java index 8fac534..74e70f8 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java @@ -24,6 +24,8 @@ import 
static org.junit.Assert.assertThat; import java.util.Collections; import org.apache.beam.runners.core.KeyedWorkItem; import org.apache.beam.runners.core.KeyedWorkItemCoder; +import org.apache.beam.runners.direct.DirectGroupByKey.DirectGroupAlsoByWindow; +import org.apache.beam.runners.direct.DirectGroupByKey.DirectGroupByKeyOnly; import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; @@ -39,8 +41,11 @@ import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.transforms.windowing.IntervalWindow; import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.util.WindowingStrategy; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TupleTagList; import org.joda.time.Instant; import org.junit.Before; import org.junit.Rule; @@ -67,7 +72,12 @@ public class KeyedPValueTrackingVisitorTest { @Test public void groupByKeyProducesKeyedOutput() { PCollection<KV<String, Iterable<Integer>>> keyed = - p.apply(Create.of(KV.of("foo", 3))).apply(GroupByKey.<String, Integer>create()); + p + .apply(Create.of(KV.of("foo", 3))) + .apply(new DirectGroupByKeyOnly<String, Integer>()) + .apply( + new DirectGroupAlsoByWindow<String, Integer>( + WindowingStrategy.globalDefault(), WindowingStrategy.globalDefault())); p.traverseTopologically(visitor); assertThat(visitor.getKeyedPValues(), hasItem(keyed)); @@ -144,10 +154,17 @@ public class KeyedPValueTrackingVisitorTest { WindowedValue.getValueOnlyCoder( KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()))))); + TupleTag<KeyedWorkItem<String, KV<String, Integer>>> keyedTag = new TupleTag<>(); PCollection<KeyedWorkItem<String, KV<String, Integer>>> keyed = input - .apply(GroupByKey.<String, WindowedValue<KV<String, Integer>>>create()) - 
.apply(ParDo.of(new ParDoMultiOverrideFactory.ToKeyedWorkItem<String, Integer>())) + .apply(new DirectGroupByKeyOnly<String, WindowedValue<KV<String, Integer>>>()) + .apply( + new DirectGroupAlsoByWindow<String, WindowedValue<KV<String, Integer>>>( + WindowingStrategy.globalDefault(), WindowingStrategy.globalDefault())) + .apply( + ParDo.of(new ParDoMultiOverrideFactory.ToKeyedWorkItem<String, Integer>()) + .withOutputTags(keyedTag, TupleTagList.empty())) + .get(keyedTag) .setCoder( KeyedWorkItemCoder.of( StringUtf8Coder.of(), http://git-wip-us.apache.org/repos/asf/beam/blob/77a1afb2/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java index ac7d2bd..9bf6bc9 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java @@ -45,6 +45,7 @@ import org.apache.beam.runners.direct.DirectRunner.UncommittedBundle; import org.apache.beam.runners.direct.ParDoMultiOverrideFactory.StatefulParDo; import org.apache.beam.runners.direct.WatermarkManager.TimerUpdate; import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.Create; @@ -66,6 +67,8 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; +import 
org.apache.beam.sdk.values.TupleTagList; import org.joda.time.Duration; import org.joda.time.Instant; import org.junit.Before; @@ -122,17 +125,23 @@ public class StatefulParDoEvaluatorFactoryTest implements Serializable { .apply(Create.of(KV.of("hello", 1), KV.of("hello", 2))) .apply(Window.<KV<String, Integer>>into(FixedWindows.of(Duration.millis(10)))); + TupleTag<Integer> mainOutput = new TupleTag<>(); PCollection<Integer> produced = - input.apply( - ParDo.of( - new DoFn<KV<String, Integer>, Integer>() { - @StateId(stateId) - private final StateSpec<Object, ValueState<String>> spec = - StateSpecs.value(StringUtf8Coder.of()); - - @ProcessElement - public void process(ProcessContext c) {} - })); + input + .apply( + new ParDoMultiOverrideFactory.GbkThenStatefulParDo<>( + ParDo.of( + new DoFn<KV<String, Integer>, Integer>() { + @StateId(stateId) + private final StateSpec<Object, ValueState<String>> spec = + StateSpecs.value(StringUtf8Coder.of()); + + @ProcessElement + public void process(ProcessContext c) {} + }) + .withOutputTags(mainOutput, TupleTagList.empty()))) + .get(mainOutput) + .setCoder(VarIntCoder.of()); StatefulParDoEvaluatorFactory<String, Integer, Integer> factory = new StatefulParDoEvaluatorFactory(mockEvaluationContext); @@ -229,18 +238,24 @@ public class StatefulParDoEvaluatorFactoryTest implements Serializable { .apply("Window side input", Window.<Integer>into(FixedWindows.of(Duration.millis(10)))) .apply("View side input", View.<Integer>asList()); + TupleTag<Integer> mainOutput = new TupleTag<>(); PCollection<Integer> produced = - mainInput.apply( - ParDo.withSideInputs(sideInput) - .of( - new DoFn<KV<String, Integer>, Integer>() { - @StateId(stateId) - private final StateSpec<Object, ValueState<String>> spec = - StateSpecs.value(StringUtf8Coder.of()); - - @ProcessElement - public void process(ProcessContext c) {} - })); + mainInput + .apply( + new ParDoMultiOverrideFactory.GbkThenStatefulParDo<>( + ParDo.withSideInputs(sideInput) + .of( + new 
DoFn<KV<String, Integer>, Integer>() { + @StateId(stateId) + private final StateSpec<Object, ValueState<String>> spec = + StateSpecs.value(StringUtf8Coder.of()); + + @ProcessElement + public void process(ProcessContext c) {} + }) + .withOutputTags(mainOutput, TupleTagList.empty()))) + .get(mainOutput) + .setCoder(VarIntCoder.of()); StatefulParDoEvaluatorFactory<String, Integer, Integer> factory = new StatefulParDoEvaluatorFactory(mockEvaluationContext); http://git-wip-us.apache.org/repos/asf/beam/blob/77a1afb2/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java index 4dc7738..9ed72d5 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java @@ -28,6 +28,7 @@ import java.util.Collection; import java.util.Collections; import org.apache.beam.runners.direct.DirectRunner.CommittedBundle; import org.apache.beam.runners.direct.TestStreamEvaluatorFactory.DirectTestStreamFactory; +import org.apache.beam.runners.direct.TestStreamEvaluatorFactory.DirectTestStreamFactory.DirectTestStream; import org.apache.beam.runners.direct.TestStreamEvaluatorFactory.TestClock; import org.apache.beam.runners.direct.TestStreamEvaluatorFactory.TestStreamIndex; import org.apache.beam.sdk.Pipeline; @@ -70,17 +71,16 @@ public class TestStreamEvaluatorFactoryTest { /** Demonstrates that returned evaluators produce elements in sequence. 
*/ @Test public void producesElementsInSequence() throws Exception { + TestStream<Integer> testStream = TestStream.create(VarIntCoder.of()) + .addElements(1, 2, 3) + .advanceWatermarkTo(new Instant(0)) + .addElements(TimestampedValue.atMinimumTimestamp(4), + TimestampedValue.atMinimumTimestamp(5), + TimestampedValue.atMinimumTimestamp(6)) + .advanceProcessingTime(Duration.standardMinutes(10)) + .advanceWatermarkToInfinity(); PCollection<Integer> streamVals = - p.apply( - TestStream.create(VarIntCoder.of()) - .addElements(1, 2, 3) - .advanceWatermarkTo(new Instant(0)) - .addElements( - TimestampedValue.atMinimumTimestamp(4), - TimestampedValue.atMinimumTimestamp(5), - TimestampedValue.atMinimumTimestamp(6)) - .advanceProcessingTime(Duration.standardMinutes(10)) - .advanceWatermarkToInfinity()); + p.apply(new DirectTestStream<Integer>(testStream)); TestClock clock = new TestClock(); when(context.getClock()).thenReturn(clock); http://git-wip-us.apache.org/repos/asf/beam/blob/77a1afb2/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewEvaluatorFactoryTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewEvaluatorFactoryTest.java index 5b03bcd..b094d17 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewEvaluatorFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewEvaluatorFactoryTest.java @@ -60,7 +60,6 @@ public class ViewEvaluatorFactoryTest { @Test public void testInMemoryEvaluator() throws Exception { - PCollection<String> input = p.apply(Create.of("foo", "bar")); CreatePCollectionView<String, Iterable<String>> createView = CreatePCollectionView.of( @@ -77,8 +76,7 @@ public class ViewEvaluatorFactoryTest { TestViewWriter<String, Iterable<String>> viewWriter = new 
TestViewWriter<>(); when(context.createPCollectionViewWriter(concat, view)).thenReturn(viewWriter); - CommittedBundle<String> inputBundle = - bundleFactory.createBundle(input).commit(Instant.now()); + CommittedBundle<String> inputBundle = bundleFactory.createBundle(input).commit(Instant.now()); AppliedPTransform<?, ?, ?> producer = DirectGraphs.getProducer(view); TransformEvaluator<Iterable<String>> evaluator = new ViewEvaluatorFactory(context) http://git-wip-us.apache.org/repos/asf/beam/blob/77a1afb2/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTransformTranslators.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTransformTranslators.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTransformTranslators.java index f7f1878..29ba9a6 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTransformTranslators.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkBatchTransformTranslators.java @@ -489,6 +489,16 @@ class FlinkBatchTransformTranslators { } } + private static void rejectSplittable(DoFn<?, ?> doFn) { + DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass()); + if (signature.processElement().isSplittable()) { + throw new UnsupportedOperationException( + String.format( + "%s does not currently support splittable DoFn: %s", + FlinkRunner.class.getSimpleName(), doFn)); + } + } + private static void rejectStateAndTimers(DoFn<?, ?> doFn) { DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass()); @@ -523,6 +533,7 @@ class FlinkBatchTransformTranslators { FlinkBatchTranslationContext context) { DoFn<InputT, OutputT> doFn = transform.getFn(); + rejectSplittable(doFn); rejectStateAndTimers(doFn); DataSet<WindowedValue<InputT>> inputDataSet = @@ 
-569,6 +580,7 @@ class FlinkBatchTransformTranslators { ParDo.BoundMulti<InputT, OutputT> transform, FlinkBatchTranslationContext context) { DoFn<InputT, OutputT> doFn = transform.getFn(); + rejectSplittable(doFn); rejectStateAndTimers(doFn); DataSet<WindowedValue<InputT>> inputDataSet = context.getInputDataSet(context.getInput(transform)); http://git-wip-us.apache.org/repos/asf/beam/blob/77a1afb2/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingTransformTranslators.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingTransformTranslators.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingTransformTranslators.java index 2131729..757cdd2 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingTransformTranslators.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/FlinkStreamingTransformTranslators.java @@ -310,6 +310,16 @@ public class FlinkStreamingTransformTranslators { } } + private static void rejectSplittable(DoFn<?, ?> doFn) { + DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass()); + if (signature.processElement().isSplittable()) { + throw new UnsupportedOperationException( + String.format( + "%s does not currently support splittable DoFn: %s", + FlinkRunner.class.getSimpleName(), doFn)); + } + } + private static void rejectTimers(DoFn<?, ?> doFn) { DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass()); @@ -334,6 +344,7 @@ public class FlinkStreamingTransformTranslators { FlinkStreamingTranslationContext context) { DoFn<InputT, OutputT> doFn = transform.getFn(); + rejectSplittable(doFn); rejectTimers(doFn); WindowingStrategy<?, ?> windowingStrategy = @@ -519,6 +530,7 @@ public class FlinkStreamingTransformTranslators { 
FlinkStreamingTranslationContext context) { DoFn<InputT, OutputT> doFn = transform.getFn(); + rejectSplittable(doFn); rejectTimers(doFn); // we assume that the transformation does not change the windowing strategy. http://git-wip-us.apache.org/repos/asf/beam/blob/77a1afb2/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java ---------------------------------------------------------------------- diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java index 697bb58..fa6b78e 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java @@ -946,6 +946,13 @@ public class DataflowPipelineTranslator { Map<Long, TupleTag<?>> outputMap) { DoFnSignature signature = DoFnSignatures.getSignature(fn.getClass()); + if (signature.processElement().isSplittable()) { + throw new UnsupportedOperationException( + String.format( + "%s does not currently support splittable DoFn: %s", + DataflowRunner.class.getSimpleName(), + fn)); + } stepContext.addInput(PropertyNames.USER_FN, fn.getClass().getName()); stepContext.addInput( http://git-wip-us.apache.org/repos/asf/beam/blob/77a1afb2/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java index 5ce1f77..14c14dc 100644 --- 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java @@ -24,6 +24,7 @@ import static org.apache.beam.runners.spark.io.hadoop.ShardNameBuilder.getOutput import static org.apache.beam.runners.spark.io.hadoop.ShardNameBuilder.getOutputFilePrefix; import static org.apache.beam.runners.spark.io.hadoop.ShardNameBuilder.getOutputFileTemplate; import static org.apache.beam.runners.spark.io.hadoop.ShardNameBuilder.replaceShardCount; +import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectSplittable; import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectStateAndTimers; import com.google.common.collect.Maps; @@ -244,6 +245,7 @@ public final class TransformTranslator { public void evaluate(ParDo.Bound<InputT, OutputT> transform, EvaluationContext context) { String stepName = context.getCurrentTransform().getFullName(); DoFn<InputT, OutputT> doFn = transform.getFn(); + rejectSplittable(doFn); rejectStateAndTimers(doFn); @SuppressWarnings("unchecked") JavaRDD<WindowedValue<InputT>> inRDD = @@ -271,6 +273,7 @@ public final class TransformTranslator { public void evaluate(ParDo.BoundMulti<InputT, OutputT> transform, EvaluationContext context) { String stepName = context.getCurrentTransform().getFullName(); DoFn<InputT, OutputT> doFn = transform.getFn(); + rejectSplittable(doFn); rejectStateAndTimers(doFn); @SuppressWarnings("unchecked") JavaRDD<WindowedValue<InputT>> inRDD = http://git-wip-us.apache.org/repos/asf/beam/blob/77a1afb2/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java 
index 4dcc705..890a91b 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java @@ -223,6 +223,15 @@ public final class TranslationUtils { } } + public static void rejectSplittable(DoFn<?, ?> doFn) { + DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass()); + + if (signature.processElement().isSplittable()) { + throw new UnsupportedOperationException( + String.format( + "%s does not support splittable DoFn: %s", SparkRunner.class.getSimpleName(), doFn)); + } + } /** * Reject state and timers {@link DoFn}. * http://git-wip-us.apache.org/repos/asf/beam/blob/77a1afb2/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java index 36cd2f3..a49b959 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java @@ -19,6 +19,7 @@ package org.apache.beam.runners.spark.translation.streaming; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; +import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectSplittable; import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectStateAndTimers; import com.google.common.collect.Maps; @@ -377,6 +378,7 @@ final class StreamingTransformTranslator { public void evaluate(final ParDo.Bound<InputT, OutputT> transform, 
final EvaluationContext context) { final DoFn<InputT, OutputT> doFn = transform.getFn(); + rejectSplittable(doFn); rejectStateAndTimers(doFn); final SparkRuntimeContext runtimeContext = context.getRuntimeContext(); final WindowingStrategy<?, ?> windowingStrategy = @@ -419,6 +421,7 @@ final class StreamingTransformTranslator { public void evaluate(final ParDo.BoundMulti<InputT, OutputT> transform, final EvaluationContext context) { final DoFn<InputT, OutputT> doFn = transform.getFn(); + rejectSplittable(doFn); rejectStateAndTimers(doFn); final SparkRuntimeContext runtimeContext = context.getRuntimeContext(); final SparkPCollectionView pviews = context.getPViews(); http://git-wip-us.apache.org/repos/asf/beam/blob/77a1afb2/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestStream.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestStream.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestStream.java index 392cad7..6d8ad6a 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestStream.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/TestStream.java @@ -42,8 +42,10 @@ import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.PropertyNames; import org.apache.beam.sdk.util.VarInt; +import org.apache.beam.sdk.util.WindowingStrategy; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollection.IsBounded; import org.apache.beam.sdk.values.TimestampedValue; import org.apache.beam.sdk.values.TimestampedValue.TimestampedValueCoder; import org.apache.beam.sdk.values.TypeDescriptor; @@ -255,11 +257,9 @@ public final class TestStream<T> extends PTransform<PBegin, PCollection<T>> { @Override public PCollection<T> expand(PBegin input) { - throw new 
IllegalStateException( - String.format( - "Pipeline Runner %s does not provide a required override for %s", - input.getPipeline().getRunner().getClass().getSimpleName(), - getClass().getSimpleName())); + return PCollection.<T>createPrimitiveOutputInternal( + input.getPipeline(), WindowingStrategy.globalDefault(), IsBounded.UNBOUNDED) + .setCoder(coder); } public Coder<T> getValueCoder() { http://git-wip-us.apache.org/repos/asf/beam/blob/77a1afb2/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java index 5b4fa19..19c5a2d 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java @@ -738,10 +738,6 @@ public class ParDo { @Override public PCollection<OutputT> expand(PCollection<? extends InputT> input) { - checkArgument( - !isSplittable(getFn()), - "%s does not support Splittable DoFn", - input.getPipeline().getOptions().getRunner().getName()); validateWindowType(input, fn); return PCollection.<OutputT>createPrimitiveOutputInternal( input.getPipeline(), @@ -932,10 +928,7 @@ public class ParDo { @Override public PCollectionTuple expand(PCollection<? extends InputT> input) { - checkArgument( - !isSplittable(getFn()), - "%s does not support Splittable DoFn", - input.getPipeline().getOptions().getRunner().getName()); + // SplittableDoFn should be forbidden on the runner-side. 
validateWindowType(input, fn); PCollectionTuple outputs = PCollectionTuple.ofPrimitiveOutputsInternal( input.getPipeline(), http://git-wip-us.apache.org/repos/asf/beam/blob/77a1afb2/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestStreamTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestStreamTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestStreamTest.java index a6a5f0e..1514601 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestStreamTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/testing/TestStreamTest.java @@ -24,12 +24,9 @@ import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.junit.Assert.assertThat; import java.io.Serializable; -import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.coders.VarLongCoder; -import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.TestStream.Builder; import org.apache.beam.sdk.transforms.Count; import org.apache.beam.sdk.transforms.Flatten; @@ -316,20 +313,6 @@ public class TestStreamTest implements Serializable { } @Test - public void testUnsupportedRunnerThrows() { - PipelineOptions opts = PipelineOptionsFactory.create(); - opts.setRunner(CrashingRunner.class); - - Pipeline p = Pipeline.create(opts); - - thrown.expect(IllegalStateException.class); - thrown.expectMessage("does not provide a required override"); - thrown.expectMessage(TestStream.class.getSimpleName()); - thrown.expectMessage(CrashingRunner.class.getSimpleName()); - p.apply(TestStream.create(VarIntCoder.of()).advanceWatermarkToInfinity()); - } - - @Test public void testEncodeDecode() throws Exception { TestStream.Event<Integer> elems = TestStream.ElementEvent.add( 
http://git-wip-us.apache.org/repos/asf/beam/blob/77a1afb2/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java index f40bbe1..6db0af4 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java @@ -2152,34 +2152,4 @@ public class ParDoTest implements Serializable { // If it doesn't crash, we made it! } - - @Test - public void testRejectsSplittableDoFnByDefault() { - // ParDo with a splittable DoFn must be overridden by the runner. - // Without an override, applying it directly must fail. - - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage(pipeline.getRunner().getClass().getName()); - thrown.expectMessage("does not support Splittable DoFn"); - - pipeline.apply(Create.of(1, 2, 3)).apply(ParDo.of(new TestSplittableDoFn())); - } - - @Test - public void testMultiRejectsSplittableDoFnByDefault() { - // ParDo with a splittable DoFn must be overridden by the runner. - // Without an override, applying it directly must fail. - - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage(pipeline.getRunner().getClass().getName()); - thrown.expectMessage("does not support Splittable DoFn"); - - pipeline - .apply(Create.of(1, 2, 3)) - .apply( - ParDo.of(new TestSplittableDoFn()) - .withOutputTags( - new TupleTag<String>("main") {}, - TupleTagList.of(new TupleTag<String>("side1") {}))); - } }