Repository: beam Updated Branches: refs/heads/master 33883ed88 -> b4c77167f
Roll-forward Include Additional PTransform inputs in Transform Nodes

Update DirectGraph to have All and Non-Additional Inputs

This reverts commit 247f9bc1581984d026764b3d433cb594e700bc21

Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/696f8b28 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/696f8b28 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/696f8b28 Branch: refs/heads/master Commit: 696f8b28a3a17e7de81e2d46bb9774d57d6e265e Parents: 33883ed Author: Thomas Groh <tg...@google.com> Authored: Tue Jun 6 17:00:09 2017 -0700 Committer: Thomas Groh <tg...@google.com> Committed: Fri Jun 9 15:00:33 2017 -0700 ---------------------------------------------------------------------- .../apex/translation/TranslationContext.java | 4 +- .../core/construction/TransformInputs.java | 50 ++++++ .../core/construction/TransformInputsTest.java | 166 +++++++++++++++++++ .../apache/beam/runners/direct/DirectGraph.java | 34 +++- .../beam/runners/direct/DirectGraphVisitor.java | 26 ++- .../direct/ExecutorServiceParallelExecutor.java | 2 +- .../runners/direct/ParDoEvaluatorFactory.java | 9 +- ...littableProcessElementsEvaluatorFactory.java | 2 + .../direct/StatefulParDoEvaluatorFactory.java | 1 + .../beam/runners/direct/WatermarkManager.java | 14 +- .../runners/direct/DirectGraphVisitorTest.java | 10 +- .../runners/direct/EvaluationContextTest.java | 2 +- .../beam/runners/direct/ParDoEvaluatorTest.java | 6 +- .../flink/FlinkBatchTranslationContext.java | 3 +- .../flink/FlinkStreamingTranslationContext.java | 3 +- .../dataflow/DataflowPipelineTranslator.java | 5 +- .../spark/translation/EvaluationContext.java | 4 +- .../beam/sdk/runners/TransformHierarchy.java | 28 +++- 18 files changed, 323 insertions(+), 46 deletions(-) ---------------------------------------------------------------------- 
http://git-wip-us.apache.org/repos/asf/beam/blob/696f8b28/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/TranslationContext.java ---------------------------------------------------------------------- diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/TranslationContext.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/TranslationContext.java index aff3863..94d13e1 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/TranslationContext.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/TranslationContext.java @@ -34,6 +34,7 @@ import org.apache.beam.runners.apex.translation.utils.ApexStateInternals; import org.apache.beam.runners.apex.translation.utils.ApexStateInternals.ApexStateBackend; import org.apache.beam.runners.apex.translation.utils.ApexStreamTuple; import org.apache.beam.runners.apex.translation.utils.CoderAdapterStreamCodec; +import org.apache.beam.runners.core.construction.TransformInputs; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.transforms.PTransform; @@ -93,7 +94,8 @@ class TranslationContext { } public <InputT extends PValue> InputT getInput() { - return (InputT) Iterables.getOnlyElement(getCurrentTransform().getInputs().values()); + return (InputT) + Iterables.getOnlyElement(TransformInputs.nonAdditionalInputs(getCurrentTransform())); } public Map<TupleTag<?>, PValue> getOutputs() { http://git-wip-us.apache.org/repos/asf/beam/blob/696f8b28/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformInputs.java ---------------------------------------------------------------------- diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformInputs.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformInputs.java 
new file mode 100644 index 0000000..2baf93a --- /dev/null +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformInputs.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.core.construction; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.common.collect.ImmutableList; +import java.util.Collection; +import java.util.Map; +import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.values.PValue; +import org.apache.beam.sdk.values.TupleTag; + +/** Utilities for extracting subsets of inputs from an {@link AppliedPTransform}. */ +public class TransformInputs { + /** + * Gets all inputs of the {@link AppliedPTransform} that are not returned by {@link + * PTransform#getAdditionalInputs()}. 
+ */ + public static Collection<PValue> nonAdditionalInputs(AppliedPTransform<?, ?, ?> application) { + ImmutableList.Builder<PValue> mainInputs = ImmutableList.builder(); + PTransform<?, ?> transform = application.getTransform(); + for (Map.Entry<TupleTag<?>, PValue> input : application.getInputs().entrySet()) { + if (!transform.getAdditionalInputs().containsKey(input.getKey())) { + mainInputs.add(input.getValue()); + } + } + checkArgument( + !mainInputs.build().isEmpty() || application.getInputs().isEmpty(), + "Expected at least one main input if any inputs exist"); + return mainInputs.build(); + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/696f8b28/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/TransformInputsTest.java ---------------------------------------------------------------------- diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/TransformInputsTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/TransformInputsTest.java new file mode 100644 index 0000000..f5b2c11 --- /dev/null +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/TransformInputsTest.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.core.construction; + +import static org.junit.Assert.assertThat; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import org.apache.beam.sdk.coders.VoidCoder; +import org.apache.beam.sdk.io.GenerateSequence; +import org.apache.beam.sdk.runners.AppliedPTransform; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PDone; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.POutput; +import org.apache.beam.sdk.values.PValue; +import org.apache.beam.sdk.values.TupleTag; +import org.hamcrest.Matchers; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link TransformInputs}. 
*/ +@RunWith(JUnit4.class) +public class TransformInputsTest { + @Rule public TestPipeline pipeline = TestPipeline.create().enableAbandonedNodeEnforcement(false); + @Rule public ExpectedException thrown = ExpectedException.none(); + + @Test + public void nonAdditionalInputsWithNoInputSucceeds() { + AppliedPTransform<PInput, POutput, TestTransform> transform = + AppliedPTransform.of( + "input-free", + Collections.<TupleTag<?>, PValue>emptyMap(), + Collections.<TupleTag<?>, PValue>emptyMap(), + new TestTransform(), + pipeline); + + assertThat(TransformInputs.nonAdditionalInputs(transform), Matchers.<PValue>empty()); + } + + @Test + public void nonAdditionalInputsWithOneMainInputSucceeds() { + PCollection<Long> input = pipeline.apply(GenerateSequence.from(1L)); + AppliedPTransform<PInput, POutput, TestTransform> transform = + AppliedPTransform.of( + "input-single", + Collections.<TupleTag<?>, PValue>singletonMap(new TupleTag<Long>() {}, input), + Collections.<TupleTag<?>, PValue>emptyMap(), + new TestTransform(), + pipeline); + + assertThat( + TransformInputs.nonAdditionalInputs(transform), Matchers.<PValue>containsInAnyOrder(input)); + } + + @Test + public void nonAdditionalInputsWithMultipleNonAdditionalInputsSucceeds() { + Map<TupleTag<?>, PValue> allInputs = new HashMap<>(); + PCollection<Integer> mainInts = pipeline.apply("MainInput", Create.of(12, 3)); + allInputs.put(new TupleTag<Integer>() {}, mainInts); + PCollection<Void> voids = pipeline.apply("VoidInput", Create.empty(VoidCoder.of())); + allInputs.put(new TupleTag<Void>() {}, voids); + AppliedPTransform<PInput, POutput, TestTransform> transform = + AppliedPTransform.of( + "additional-free", + allInputs, + Collections.<TupleTag<?>, PValue>emptyMap(), + new TestTransform(), + pipeline); + + assertThat( + TransformInputs.nonAdditionalInputs(transform), + Matchers.<PValue>containsInAnyOrder(voids, mainInts)); + } + + @Test + public void nonAdditionalInputsWithAdditionalInputsSucceeds() { + Map<TupleTag<?>, 
PValue> additionalInputs = new HashMap<>(); + additionalInputs.put(new TupleTag<String>() {}, pipeline.apply(Create.of("1, 2", "3"))); + additionalInputs.put(new TupleTag<Long>() {}, pipeline.apply(GenerateSequence.from(3L))); + + Map<TupleTag<?>, PValue> allInputs = new HashMap<>(); + PCollection<Integer> mainInts = pipeline.apply("MainInput", Create.of(12, 3)); + allInputs.put(new TupleTag<Integer>() {}, mainInts); + PCollection<Void> voids = pipeline.apply("VoidInput", Create.empty(VoidCoder.of())); + allInputs.put( + new TupleTag<Void>() {}, voids); + allInputs.putAll(additionalInputs); + + AppliedPTransform<PInput, POutput, TestTransform> transform = + AppliedPTransform.of( + "additional", + allInputs, + Collections.<TupleTag<?>, PValue>emptyMap(), + new TestTransform(additionalInputs), + pipeline); + + assertThat( + TransformInputs.nonAdditionalInputs(transform), + Matchers.<PValue>containsInAnyOrder(mainInts, voids)); + } + + @Test + public void nonAdditionalInputsWithOnlyAdditionalInputsThrows() { + Map<TupleTag<?>, PValue> additionalInputs = new HashMap<>(); + additionalInputs.put(new TupleTag<String>() {}, pipeline.apply(Create.of("1, 2", "3"))); + additionalInputs.put(new TupleTag<Long>() {}, pipeline.apply(GenerateSequence.from(3L))); + + AppliedPTransform<PInput, POutput, TestTransform> transform = + AppliedPTransform.of( + "additional-only", + additionalInputs, + Collections.<TupleTag<?>, PValue>emptyMap(), + new TestTransform(additionalInputs), + pipeline); + + thrown.expect(IllegalArgumentException.class); + thrown.expectMessage("at least one"); + TransformInputs.nonAdditionalInputs(transform); + } + + private static class TestTransform extends PTransform<PInput, POutput> { + private final Map<TupleTag<?>, PValue> additionalInputs; + + private TestTransform() { + this(Collections.<TupleTag<?>, PValue>emptyMap()); + } + + private TestTransform(Map<TupleTag<?>, PValue> additionalInputs) { + this.additionalInputs = additionalInputs; + } + + @Override + 
public POutput expand(PInput input) { + return PDone.in(input.getPipeline()); + } + + @Override + public Map<TupleTag<?>, PValue> getAdditionalInputs() { + return additionalInputs; + } + } +} http://git-wip-us.apache.org/repos/asf/beam/blob/696f8b28/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraph.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraph.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraph.java index 9ca745d..ad17b2b 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraph.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraph.java @@ -17,6 +17,8 @@ */ package org.apache.beam.runners.direct; +import static com.google.common.base.Preconditions.checkArgument; + import com.google.common.collect.ListMultimap; import java.util.Collection; import java.util.List; @@ -36,7 +38,8 @@ import org.apache.beam.sdk.values.PValue; class DirectGraph { private final Map<PCollection<?>, AppliedPTransform<?, ?, ?>> producers; private final Map<PCollectionView<?>, AppliedPTransform<?, ?, ?>> viewWriters; - private final ListMultimap<PInput, AppliedPTransform<?, ?, ?>> primitiveConsumers; + private final ListMultimap<PInput, AppliedPTransform<?, ?, ?>> perElementConsumers; + private final ListMultimap<PValue, AppliedPTransform<?, ?, ?>> allConsumers; private final Set<AppliedPTransform<?, ?, ?>> rootTransforms; private final Map<AppliedPTransform<?, ?, ?>, String> stepNames; @@ -44,23 +47,36 @@ class DirectGraph { public static DirectGraph create( Map<PCollection<?>, AppliedPTransform<?, ?, ?>> producers, Map<PCollectionView<?>, AppliedPTransform<?, ?, ?>> viewWriters, - ListMultimap<PInput, AppliedPTransform<?, ?, ?>> primitiveConsumers, + ListMultimap<PInput, AppliedPTransform<?, ?, ?>> perElementConsumers, + ListMultimap<PValue, 
AppliedPTransform<?, ?, ?>> allConsumers, Set<AppliedPTransform<?, ?, ?>> rootTransforms, Map<AppliedPTransform<?, ?, ?>, String> stepNames) { - return new DirectGraph(producers, viewWriters, primitiveConsumers, rootTransforms, stepNames); + return new DirectGraph( + producers, viewWriters, perElementConsumers, allConsumers, rootTransforms, stepNames); } private DirectGraph( Map<PCollection<?>, AppliedPTransform<?, ?, ?>> producers, Map<PCollectionView<?>, AppliedPTransform<?, ?, ?>> viewWriters, - ListMultimap<PInput, AppliedPTransform<?, ?, ?>> primitiveConsumers, + ListMultimap<PInput, AppliedPTransform<?, ?, ?>> perElementConsumers, + ListMultimap<PValue, AppliedPTransform<?, ?, ?>> allConsumers, Set<AppliedPTransform<?, ?, ?>> rootTransforms, Map<AppliedPTransform<?, ?, ?>, String> stepNames) { this.producers = producers; this.viewWriters = viewWriters; - this.primitiveConsumers = primitiveConsumers; + this.perElementConsumers = perElementConsumers; + this.allConsumers = allConsumers; this.rootTransforms = rootTransforms; this.stepNames = stepNames; + for (AppliedPTransform<?, ?, ?> step : stepNames.keySet()) { + for (PValue input : step.getInputs().values()) { + checkArgument( + allConsumers.get(input).contains(step), + "Step %s lists value %s as input, but it is not in the graph of consumers", + step.getFullName(), + input); + } + } } AppliedPTransform<?, ?, ?> getProducer(PCollection<?> produced) { @@ -71,8 +87,12 @@ class DirectGraph { return viewWriters.get(view); } - List<AppliedPTransform<?, ?, ?>> getPrimitiveConsumers(PValue consumed) { - return primitiveConsumers.get(consumed); + List<AppliedPTransform<?, ?, ?>> getPerElementConsumers(PValue consumed) { + return perElementConsumers.get(consumed); + } + + List<AppliedPTransform<?, ?, ?>> getAllConsumers(PValue consumed) { + return allConsumers.get(consumed); } Set<AppliedPTransform<?, ?, ?>> getRootTransforms() { 
http://git-wip-us.apache.org/repos/asf/beam/blob/696f8b28/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java index 07bcf06..675de2c 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java @@ -22,10 +22,12 @@ import static com.google.common.base.Preconditions.checkState; import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.ListMultimap; import com.google.common.collect.Sets; +import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; +import org.apache.beam.runners.core.construction.TransformInputs; import org.apache.beam.runners.direct.ViewOverrideFactory.WriteView; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.Pipeline.PipelineVisitor; @@ -37,6 +39,8 @@ import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.PValue; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Tracks the {@link AppliedPTransform AppliedPTransforms} that consume each {@link PValue} in the @@ -44,12 +48,15 @@ import org.apache.beam.sdk.values.PValue; * input after the upstream transform has produced and committed output. 
*/ class DirectGraphVisitor extends PipelineVisitor.Defaults { + private static final Logger LOG = LoggerFactory.getLogger(DirectGraphVisitor.class); private Map<PCollection<?>, AppliedPTransform<?, ?, ?>> producers = new HashMap<>(); private Map<PCollectionView<?>, AppliedPTransform<?, ?, ?>> viewWriters = new HashMap<>(); private Set<PCollectionView<?>> consumedViews = new HashSet<>(); - private ListMultimap<PInput, AppliedPTransform<?, ?, ?>> primitiveConsumers = + private ListMultimap<PInput, AppliedPTransform<?, ?, ?>> perElementConsumers = + ArrayListMultimap.create(); + private ListMultimap<PValue, AppliedPTransform<?, ?, ?>> allConsumers = ArrayListMultimap.create(); private Set<AppliedPTransform<?, ?, ?>> rootTransforms = new HashSet<>(); @@ -94,8 +101,19 @@ class DirectGraphVisitor extends PipelineVisitor.Defaults { if (node.getInputs().isEmpty()) { rootTransforms.add(appliedTransform); } else { + Collection<PValue> mainInputs = + TransformInputs.nonAdditionalInputs(node.toAppliedPTransform(getPipeline())); + if (!mainInputs.containsAll(node.getInputs().values())) { + LOG.debug( + "Inputs reduced to {} from {} by removing additional inputs", + mainInputs, + node.getInputs().values()); + } + for (PValue value : mainInputs) { + perElementConsumers.put(value, appliedTransform); + } for (PValue value : node.getInputs().values()) { - primitiveConsumers.put(value, appliedTransform); + allConsumers.put(value, appliedTransform); } } if (node.getTransform() instanceof ParDo.MultiOutput) { @@ -106,7 +124,7 @@ class DirectGraphVisitor extends PipelineVisitor.Defaults { } } - @Override + @Override public void visitValue(PValue value, TransformHierarchy.Node producer) { AppliedPTransform<?, ?, ?> appliedTransform = getAppliedTransform(producer); if (value instanceof PCollection && !producers.containsKey(value)) { @@ -131,6 +149,6 @@ class DirectGraphVisitor extends PipelineVisitor.Defaults { public DirectGraph getGraph() { checkState(finalized, "Can't get a graph 
before the Pipeline has been completely traversed"); return DirectGraph.create( - producers, viewWriters, primitiveConsumers, rootTransforms, stepNames); + producers, viewWriters, perElementConsumers, allConsumers, rootTransforms, stepNames); } } http://git-wip-us.apache.org/repos/asf/beam/blob/696f8b28/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java index 71ab4cc..6fe8ebd 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java @@ -355,7 +355,7 @@ final class ExecutorServiceParallelExecutor implements PipelineExecutor { for (CommittedBundle<?> outputBundle : committedResult.getOutputs()) { allUpdates.offer( ExecutorUpdate.fromBundle( - outputBundle, graph.getPrimitiveConsumers(outputBundle.getPCollection()))); + outputBundle, graph.getPerElementConsumers(outputBundle.getPCollection()))); } CommittedBundle<?> unprocessedInputs = committedResult.getUnprocessedInputs(); if (unprocessedInputs != null && !Iterables.isEmpty(unprocessedInputs.getElements())) { http://git-wip-us.apache.org/repos/asf/beam/blob/696f8b28/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java index 8aa75cf..516f798 100644 --- 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java @@ -20,7 +20,6 @@ package org.apache.beam.runners.direct; import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; -import com.google.common.collect.Iterables; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -79,6 +78,7 @@ final class ParDoEvaluatorFactory<InputT, OutputT> implements TransformEvaluator (TransformEvaluator<T>) createEvaluator( (AppliedPTransform) application, + (PCollection<InputT>) inputBundle.getPCollection(), inputBundle.getKey(), doFn, transform.getSideInputs(), @@ -102,6 +102,7 @@ final class ParDoEvaluatorFactory<InputT, OutputT> implements TransformEvaluator @SuppressWarnings({"unchecked", "rawtypes"}) DoFnLifecycleManagerRemovingTransformEvaluator<InputT> createEvaluator( AppliedPTransform<PCollection<InputT>, PCollectionTuple, ?> application, + PCollection<InputT> mainInput, StructuralKey<?> inputBundleKey, DoFn<InputT, OutputT> doFn, List<PCollectionView<?>> sideInputs, @@ -120,6 +121,7 @@ final class ParDoEvaluatorFactory<InputT, OutputT> implements TransformEvaluator createParDoEvaluator( application, inputBundleKey, + mainInput, sideInputs, mainOutputTag, additionalOutputTags, @@ -132,6 +134,7 @@ final class ParDoEvaluatorFactory<InputT, OutputT> implements TransformEvaluator ParDoEvaluator<InputT> createParDoEvaluator( AppliedPTransform<PCollection<InputT>, PCollectionTuple, ?> application, StructuralKey<?> key, + PCollection<InputT> mainInput, List<PCollectionView<?>> sideInputs, TupleTag<OutputT> mainOutputTag, List<TupleTag<?>> additionalOutputTags, @@ -144,8 +147,7 @@ final class ParDoEvaluatorFactory<InputT, OutputT> implements TransformEvaluator evaluationContext, stepContext, application, - ((PCollection<InputT>) 
Iterables.getOnlyElement(application.getInputs().values())) - .getWindowingStrategy(), + mainInput.getWindowingStrategy(), fn, key, sideInputs, @@ -173,5 +175,4 @@ final class ParDoEvaluatorFactory<InputT, OutputT> implements TransformEvaluator } return pcs; } - } http://git-wip-us.apache.org/repos/asf/beam/blob/696f8b28/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java index b85f481c..eccc83a 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java @@ -116,6 +116,8 @@ class SplittableProcessElementsEvaluatorFactory< delegateFactory.createParDoEvaluator( application, inputBundle.getKey(), + (PCollection<KeyedWorkItem<String, ElementAndRestriction<InputT, RestrictionT>>>) + inputBundle.getPCollection(), transform.getSideInputs(), transform.getMainOutputTag(), transform.getAdditionalOutputTags().getAll(), http://git-wip-us.apache.org/repos/asf/beam/blob/696f8b28/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java index 506c84c..3619d05 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java +++ 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java @@ -117,6 +117,7 @@ final class StatefulParDoEvaluatorFactory<K, InputT, OutputT> implements Transfo DoFnLifecycleManagerRemovingTransformEvaluator<KV<K, InputT>> delegateEvaluator = delegateFactory.createEvaluator( (AppliedPTransform) application, + (PCollection) inputBundle.getPCollection(), inputBundle.getKey(), doFn, application.getTransform().getUnderlyingParDo().getSideInputs(), http://git-wip-us.apache.org/repos/asf/beam/blob/696f8b28/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java index 40ce163..80a3504 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java @@ -54,6 +54,7 @@ import javax.annotation.concurrent.GuardedBy; import org.apache.beam.runners.core.StateNamespace; import org.apache.beam.runners.core.TimerInternals; import org.apache.beam.runners.core.TimerInternals.TimerData; +import org.apache.beam.runners.core.construction.TransformInputs; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.state.TimeDomain; @@ -62,7 +63,6 @@ import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PValue; -import org.apache.beam.sdk.values.TupleTag; import org.joda.time.Instant; /** @@ -831,11 +831,11 @@ class WatermarkManager { private Collection<Watermark> getInputProcessingWatermarks(AppliedPTransform<?, ?, ?> transform) { 
ImmutableList.Builder<Watermark> inputWmsBuilder = ImmutableList.builder(); - Map<TupleTag<?>, PValue> inputs = transform.getInputs(); + Collection<PValue> inputs = TransformInputs.nonAdditionalInputs(transform); if (inputs.isEmpty()) { inputWmsBuilder.add(THE_END_OF_TIME); } - for (PValue pvalue : inputs.values()) { + for (PValue pvalue : inputs) { Watermark producerOutputWatermark = getValueWatermark(pvalue).synchronizedProcessingOutputWatermark; inputWmsBuilder.add(producerOutputWatermark); @@ -845,11 +845,11 @@ class WatermarkManager { private List<Watermark> getInputWatermarks(AppliedPTransform<?, ?, ?> transform) { ImmutableList.Builder<Watermark> inputWatermarksBuilder = ImmutableList.builder(); - Map<TupleTag<?>, PValue> inputs = transform.getInputs(); + Collection< PValue> inputs = TransformInputs.nonAdditionalInputs(transform); if (inputs.isEmpty()) { inputWatermarksBuilder.add(THE_END_OF_TIME); } - for (PValue pvalue : inputs.values()) { + for (PValue pvalue : inputs) { Watermark producerOutputWatermark = getValueWatermark(pvalue).outputWatermark; inputWatermarksBuilder.add(producerOutputWatermark); } @@ -987,7 +987,7 @@ class WatermarkManager { // refresh. 
for (CommittedBundle<?> bundle : result.getOutputs()) { for (AppliedPTransform<?, ?, ?> consumer : - graph.getPrimitiveConsumers(bundle.getPCollection())) { + graph.getPerElementConsumers(bundle.getPCollection())) { TransformWatermarks watermarks = transformToWatermarks.get(consumer); watermarks.addPending(bundle); } @@ -1035,7 +1035,7 @@ class WatermarkManager { if (updateResult.isAdvanced()) { Set<AppliedPTransform<?, ?, ?>> additionalRefreshes = new HashSet<>(); for (PValue outputPValue : toRefresh.getOutputs().values()) { - additionalRefreshes.addAll(graph.getPrimitiveConsumers(outputPValue)); + additionalRefreshes.addAll(graph.getPerElementConsumers(outputPValue)); } return additionalRefreshes; } http://git-wip-us.apache.org/repos/asf/beam/blob/696f8b28/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGraphVisitorTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGraphVisitorTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGraphVisitorTest.java index 576edf3..bf3e83e 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGraphVisitorTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGraphVisitorTest.java @@ -151,13 +151,13 @@ public class DirectGraphVisitorTest implements Serializable { graph.getProducer(flattened); assertThat( - graph.getPrimitiveConsumers(created), + graph.getPerElementConsumers(created), Matchers.<AppliedPTransform<?, ?, ?>>containsInAnyOrder( transformedProducer, flattenedProducer)); assertThat( - graph.getPrimitiveConsumers(transformed), + graph.getPerElementConsumers(transformed), Matchers.<AppliedPTransform<?, ?, ?>>containsInAnyOrder(flattenedProducer)); - assertThat(graph.getPrimitiveConsumers(flattened), emptyIterable()); + assertThat(graph.getPerElementConsumers(flattened), emptyIterable()); } @Test @@ 
-173,10 +173,10 @@ public class DirectGraphVisitorTest implements Serializable { AppliedPTransform<?, ?, ?> flattenedProducer = graph.getProducer(flattened); assertThat( - graph.getPrimitiveConsumers(created), + graph.getPerElementConsumers(created), Matchers.<AppliedPTransform<?, ?, ?>>containsInAnyOrder(flattenedProducer, flattenedProducer)); - assertThat(graph.getPrimitiveConsumers(flattened), emptyIterable()); + assertThat(graph.getPerElementConsumers(flattened), emptyIterable()); } @Test http://git-wip-us.apache.org/repos/asf/beam/blob/696f8b28/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EvaluationContextTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EvaluationContextTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EvaluationContextTest.java index f3edf55..699a318 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EvaluationContextTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/EvaluationContextTest.java @@ -414,7 +414,7 @@ public class EvaluationContextTest { StepTransformResult.withoutHold(unboundedProducer).build()); assertThat(context.isDone(), is(false)); - for (AppliedPTransform<?, ?, ?> consumers : graph.getPrimitiveConsumers(created)) { + for (AppliedPTransform<?, ?, ?> consumers : graph.getPerElementConsumers(created)) { context.handleResult( committedBundle, ImmutableList.<TimerData>of(), http://git-wip-us.apache.org/repos/asf/beam/blob/696f8b28/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoEvaluatorTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoEvaluatorTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoEvaluatorTest.java index df84cbf..7912538 100644 --- 
a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoEvaluatorTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoEvaluatorTest.java @@ -98,7 +98,7 @@ public class ParDoEvaluatorTest { when(evaluationContext.createBundle(output)).thenReturn(outputBundle); ParDoEvaluator<Integer> evaluator = - createEvaluator(singletonView, fn, output); + createEvaluator(singletonView, fn, inputPc, output); IntervalWindow nonGlobalWindow = new IntervalWindow(new Instant(0), new Instant(10_000L)); WindowedValue<Integer> first = WindowedValue.valueInGlobalWindow(3); @@ -132,6 +132,7 @@ public class ParDoEvaluatorTest { private ParDoEvaluator<Integer> createEvaluator( PCollectionView<Integer> singletonView, RecorderFn fn, + PCollection<Integer> input, PCollection<Integer> output) { when( evaluationContext.createSideInputReader( @@ -157,8 +158,7 @@ public class ParDoEvaluatorTest { evaluationContext, stepContext, transform, - ((PCollection<?>) Iterables.getOnlyElement(transform.getInputs().values())) - .getWindowingStrategy(), + input.getWindowingStrategy(), fn, null /* key */, ImmutableList.<PCollectionView<?>>of(singletonView), http://git-wip-us.apache.org/repos/asf/beam/blob/696f8b28/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTranslationContext.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTranslationContext.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTranslationContext.java index 0439119..6e70198 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTranslationContext.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTranslationContext.java @@ -20,6 +20,7 @@ package org.apache.beam.runners.flink; import com.google.common.collect.Iterables; import java.util.HashMap; import java.util.Map; +import 
org.apache.beam.runners.core.construction.TransformInputs; import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.options.PipelineOptions; @@ -143,7 +144,7 @@ class FlinkBatchTranslationContext { @SuppressWarnings("unchecked") <T extends PValue> T getInput(PTransform<T, ?> transform) { - return (T) Iterables.getOnlyElement(currentTransform.getInputs().values()); + return (T) Iterables.getOnlyElement(TransformInputs.nonAdditionalInputs(currentTransform)); } Map<TupleTag<?>, PValue> getOutputs(PTransform<?, ?> transform) { http://git-wip-us.apache.org/repos/asf/beam/blob/696f8b28/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTranslationContext.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTranslationContext.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTranslationContext.java index ea5f6b3..74a5fb9 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTranslationContext.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTranslationContext.java @@ -22,6 +22,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.collect.Iterables; import java.util.HashMap; import java.util.Map; +import org.apache.beam.runners.core.construction.TransformInputs; import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.options.PipelineOptions; @@ -113,7 +114,7 @@ class FlinkStreamingTranslationContext { @SuppressWarnings("unchecked") public <T extends PValue> T getInput(PTransform<T, ?> transform) { - return (T) Iterables.getOnlyElement(currentTransform.getInputs().values()); + return (T) 
Iterables.getOnlyElement(TransformInputs.nonAdditionalInputs(currentTransform)); } public <T extends PInput> Map<TupleTag<?>, PValue> getInputs(PTransform<T, ?> transform) { http://git-wip-us.apache.org/repos/asf/beam/blob/696f8b28/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java ---------------------------------------------------------------------- diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java index a3a7ab6..afc34e6 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java @@ -56,6 +56,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicLong; import javax.annotation.Nullable; +import org.apache.beam.runners.core.construction.TransformInputs; import org.apache.beam.runners.core.construction.WindowingStrategyTranslation; import org.apache.beam.runners.dataflow.BatchViewOverrides.GroupByKeyAndSortValuesOnly; import org.apache.beam.runners.dataflow.DataflowRunner.CombineGroupedValues; @@ -395,7 +396,9 @@ public class DataflowPipelineTranslator { @Override public <InputT extends PValue> InputT getInput(PTransform<InputT, ?> transform) { - return (InputT) Iterables.getOnlyElement(getInputs(transform).values()); + return (InputT) + Iterables.getOnlyElement( + TransformInputs.nonAdditionalInputs(getCurrentTransform(transform))); } @Override http://git-wip-us.apache.org/repos/asf/beam/blob/696f8b28/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java ---------------------------------------------------------------------- diff --git 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java index 8102926..0c6c4d1 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java @@ -26,6 +26,7 @@ import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.Map; import java.util.Set; +import org.apache.beam.runners.core.construction.TransformInputs; import org.apache.beam.runners.spark.SparkPipelineOptions; import org.apache.beam.runners.spark.coders.CoderHelpers; import org.apache.beam.sdk.Pipeline; @@ -103,7 +104,8 @@ public class EvaluationContext { public <T extends PValue> T getInput(PTransform<T, ?> transform) { @SuppressWarnings("unchecked") - T input = (T) Iterables.getOnlyElement(getInputs(transform).values()); + T input = + (T) Iterables.getOnlyElement(TransformInputs.nonAdditionalInputs(getCurrentTransform())); return input; } http://git-wip-us.apache.org/repos/asf/beam/blob/696f8b28/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java index 5e048eb..9c5f148 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java @@ -34,7 +34,6 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Set; -import javax.annotation.Nullable; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.Pipeline.PipelineVisitor; import 
org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior; @@ -71,7 +70,7 @@ public class TransformHierarchy { producers = new HashMap<>(); producerInput = new HashMap<>(); unexpandedInputs = new HashMap<>(); - root = new Node(null, null, "", null); + root = new Node(); current = root; } @@ -297,25 +296,36 @@ public class TransformHierarchy { boolean finishedSpecifying = false; /** + * Creates the root-level node. The root level node has a null enclosing node, a null transform, + * an empty map of inputs, and a name equal to the empty string. + */ + private Node() { + this.enclosingNode = null; + this.transform = null; + this.fullName = ""; + this.inputs = Collections.emptyMap(); + } + + /** * Creates a new Node with the given parent and transform. * - * <p>EnclosingNode and transform may both be null for a root-level node, which holds all other - * nodes. - * * @param enclosingNode the composite node containing this node * @param transform the PTransform tracked by this node * @param fullName the fully qualified name of the transform * @param input the unexpanded input to the transform */ private Node( - @Nullable Node enclosingNode, - @Nullable PTransform<?, ?> transform, + Node enclosingNode, + PTransform<?, ?> transform, String fullName, - @Nullable PInput input) { + PInput input) { this.enclosingNode = enclosingNode; this.transform = transform; this.fullName = fullName; - this.inputs = input == null ? Collections.<TupleTag<?>, PValue>emptyMap() : input.expand(); + ImmutableMap.Builder<TupleTag<?>, PValue> inputs = ImmutableMap.builder(); + inputs.putAll(input.expand()); + inputs.putAll(transform.getAdditionalInputs()); + this.inputs = inputs.build(); } /**