http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGraphVisitorTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGraphVisitorTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGraphVisitorTest.java index 8b4573f..b44c890 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGraphVisitorTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGraphVisitorTest.java @@ -45,7 +45,6 @@ import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PDone; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; -import org.apache.beam.sdk.values.TaggedPValue; import org.hamcrest.Matchers; import org.junit.Rule; import org.junit.Test; @@ -101,9 +100,9 @@ public class DirectGraphVisitorTest implements Serializable { graph.getProducer(created), graph.getProducer(counted), graph.getProducer(unCounted))); for (AppliedPTransform<?, ?, ?> root : graph.getRootTransforms()) { // Root transforms will have no inputs - assertThat(root.getInputs(), emptyIterable()); + assertThat(root.getInputs().entrySet(), emptyIterable()); assertThat( - Iterables.getOnlyElement(root.getOutputs()).getValue(), + Iterables.getOnlyElement(root.getOutputs().values()), Matchers.<POutput>isOneOf(created, counted, unCounted)); } } @@ -121,7 +120,7 @@ public class DirectGraphVisitorTest implements Serializable { Matchers.<AppliedPTransform<?, ?, ?>>containsInAnyOrder(graph.getProducer(empty))); AppliedPTransform<?, ?, ?> onlyRoot = Iterables.getOnlyElement(graph.getRootTransforms()); assertThat(onlyRoot.getTransform(), Matchers.<PTransform<?, ?>>equalTo(flatten)); - assertThat(onlyRoot.getInputs(), Matchers.<TaggedPValue>emptyIterable()); + assertThat(onlyRoot.getInputs().entrySet(), emptyIterable()); assertThat(onlyRoot.getOutputs(), equalTo(empty.expand())); }
http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoEvaluatorTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoEvaluatorTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoEvaluatorTest.java index c85b85e..2a94d48 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoEvaluatorTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoEvaluatorTest.java @@ -162,7 +162,7 @@ public class ParDoEvaluatorTest { evaluationContext, stepContext, transform, - ((PCollection<?>) Iterables.getOnlyElement(transform.getInputs()).getValue()) + ((PCollection<?>) Iterables.getOnlyElement(transform.getInputs().values())) .getWindowingStrategy(), fn, null /* key */, http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java index 946cd69..ecb8130 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java @@ -308,7 +308,7 @@ public class StatefulParDoEvaluatorFactoryTest implements Serializable { BUNDLE_FACTORY .createBundle( (PCollection<KeyedWorkItem<String, KV<String, Integer>>>) - Iterables.getOnlyElement(producingTransform.getInputs()).getValue()) + Iterables.getOnlyElement(producingTransform.getInputs().values())) .add(gbkOutputElement) .commit(Instant.now()); TransformEvaluator<KeyedWorkItem<String, KV<String, Integer>>> evaluator = http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java index fc689fe..0d909c2 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java @@ -40,8 +40,9 @@ import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.TaggedPValue; +import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.TimestampedValue; +import org.apache.beam.sdk.values.TupleTag; import org.hamcrest.Matchers; import org.joda.time.Duration; import org.joda.time.Instant; @@ -183,7 +184,7 @@ public class TestStreamEvaluatorFactoryTest { @Test public void overrideFactoryGetInputSucceeds() { DirectTestStreamFactory<?> factory = new DirectTestStreamFactory<>(runner); - PBegin begin = factory.getInput(Collections.<TaggedPValue>emptyList(), p); + PBegin begin = factory.getInput(Collections.<TupleTag<?>, PValue>emptyMap(), p); assertThat(begin.getPipeline(), Matchers.<Pipeline>equalTo(p)); } } http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewOverrideFactoryTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewOverrideFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewOverrideFactoryTest.java index 6dcc13c..258cb46 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewOverrideFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewOverrideFactoryTest.java @@ -107,7 +107,7 @@ public class ViewOverrideFactoryTest implements Serializable { is(false)); PCollectionView replacementView = ((WriteView) node.getTransform()).getView(); assertThat(replacementView, Matchers.<PCollectionView>theInstance(view)); - assertThat(node.getInputs(), hasSize(1)); + assertThat(node.getInputs().entrySet(), hasSize(1)); } } }); http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java index 1d6728b..ff9521c 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java @@ -27,6 +27,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import org.apache.beam.runners.flink.translation.functions.FlinkAssignWindows; import org.apache.beam.runners.flink.translation.functions.FlinkDoFnFunction; import org.apache.beam.runners.flink.translation.functions.FlinkMergingNonShuffleReduceFunction; @@ -71,7 +72,6 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PValue; -import org.apache.beam.sdk.values.TaggedPValue; import org.apache.beam.sdk.values.TupleTag; import org.apache.flink.api.common.functions.FilterFunction; import org.apache.flink.api.common.functions.FlatMapFunction; @@ -511,15 +511,15 @@ class FlinkBatchTransformTranslators { DataSet<WindowedValue<InputT>> inputDataSet = context.getInputDataSet(context.getInput(transform)); - List<TaggedPValue> outputs = context.getOutputs(transform); + Map<TupleTag<?>, PValue> outputs = context.getOutputs(transform); Map<TupleTag<?>, Integer> outputMap = Maps.newHashMap(); // put the main output at index 0, FlinkMultiOutputDoFnFunction expects this outputMap.put(transform.getMainOutputTag(), 0); int count = 1; - for (TaggedPValue taggedValue : outputs) { - if (!outputMap.containsKey(taggedValue.getTag())) { - outputMap.put(taggedValue.getTag(), count++); + for (TupleTag<?> tag : outputs.keySet()) { + if (!outputMap.containsKey(tag)) { + outputMap.put(tag, count++); } } @@ -528,13 +528,13 @@ class FlinkBatchTransformTranslators { // collect all output Coders and create a UnionCoder for our tagged outputs List<Coder<?>> outputCoders = Lists.newArrayList(); - for (TaggedPValue taggedValue : outputs) { + for (PValue taggedValue : outputs.values()) { checkState( - taggedValue.getValue() instanceof PCollection, + taggedValue instanceof PCollection, "Within ParDo, got a non-PCollection output %s of type %s", - taggedValue.getValue(), - taggedValue.getValue().getClass().getSimpleName()); - PCollection<?> coll = (PCollection<?>) taggedValue.getValue(); + taggedValue, + taggedValue.getClass().getSimpleName()); + PCollection<?> coll = (PCollection<?>) taggedValue; outputCoders.add(coll.getCoder()); windowingStrategy = coll.getWindowingStrategy(); } @@ -599,11 +599,11 @@ class FlinkBatchTransformTranslators { transformSideInputs(sideInputs, outputDataSet, context); - for (TaggedPValue output : outputs) { + for (Entry<TupleTag<?>, PValue> output : outputs.entrySet()) { pruneOutput( outputDataSet, context, - outputMap.get(output.getTag()), + outputMap.get(output.getKey()), (PCollection) output.getValue()); } @@ -640,7 +640,7 @@ class FlinkBatchTransformTranslators { Flatten.PCollections<T> transform, FlinkBatchTranslationContext context) { - List<TaggedPValue> allInputs = context.getInputs(transform); + Map<TupleTag<?>, PValue> allInputs = context.getInputs(transform); DataSet<WindowedValue<T>> result = null; if (allInputs.isEmpty()) { @@ -661,13 +661,13 @@ class FlinkBatchTransformTranslators { (Coder<T>) VoidCoder.of(), GlobalWindow.Coder.INSTANCE))); } else { - for (TaggedPValue taggedPc : allInputs) { + for (PValue taggedPc : allInputs.values()) { checkArgument( - taggedPc.getValue() instanceof PCollection, + taggedPc instanceof PCollection, "Got non-PCollection input to flatten: %s of type %s", - taggedPc.getValue(), - taggedPc.getValue().getClass().getSimpleName()); - PCollection<T> collection = (PCollection<T>) taggedPc.getValue(); + taggedPc, + taggedPc.getClass().getSimpleName()); + PCollection<T> collection = (PCollection<T>) taggedPc; DataSet<WindowedValue<T>> current = context.getInputDataSet(collection); if (result == null) { result = current; http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkBatchTranslationContext.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkBatchTranslationContext.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkBatchTranslationContext.java index cb69575..98dd0fb 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkBatchTranslationContext.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkBatchTranslationContext.java @@ -19,7 +19,6 @@ package org.apache.beam.runners.flink; import com.google.common.collect.Iterables; import java.util.HashMap; -import java.util.List; import java.util.Map; import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; import org.apache.beam.sdk.coders.Coder; @@ -31,7 +30,7 @@ import org.apache.beam.sdk.util.WindowingStrategy; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PValue; -import org.apache.beam.sdk.values.TaggedPValue; +import org.apache.beam.sdk.values.TupleTag; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.ExecutionEnvironment; @@ -134,21 +133,21 @@ class FlinkBatchTranslationContext { return new CoderTypeInformation<>(windowedValueCoder); } - List<TaggedPValue> getInputs(PTransform<?, ?> transform) { + Map<TupleTag<?>, PValue> getInputs(PTransform<?, ?> transform) { return currentTransform.getInputs(); } @SuppressWarnings("unchecked") <T extends PValue> T getInput(PTransform<T, ?> transform) { - return (T) Iterables.getOnlyElement(currentTransform.getInputs()).getValue(); + return (T) Iterables.getOnlyElement(currentTransform.getInputs().values()); } - List<TaggedPValue> getOutputs(PTransform<?, ?> transform) { + Map<TupleTag<?>, PValue> getOutputs(PTransform<?, ?> transform) { return currentTransform.getOutputs(); } @SuppressWarnings("unchecked") <T extends PValue> T getOutput(PTransform<?, T> transform) { - return (T) Iterables.getOnlyElement(currentTransform.getOutputs()).getValue(); + return (T) Iterables.getOnlyElement(currentTransform.getOutputs().values()); } } http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java index 8b5637e..70da2b3 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java @@ -39,7 +39,7 @@ import org.apache.beam.sdk.util.InstanceBuilder; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.PValue; -import org.apache.beam.sdk.values.TaggedPValue; +import org.apache.beam.sdk.values.TupleTag; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -259,14 +259,13 @@ class FlinkStreamingPipelineTranslator extends FlinkPipelineTranslator { } @Override - public PCollection<? extends InputT> getInput( - List<TaggedPValue> inputs, Pipeline p) { - return (PCollection<? extends InputT>) Iterables.getOnlyElement(inputs).getValue(); + public PCollection<? extends InputT> getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) { + return (PCollection<? extends InputT>) Iterables.getOnlyElement(inputs.values()); } @Override public Map<PValue, ReplacementOutput> mapOutputs( - List<TaggedPValue> outputs, PCollectionTuple newOutput) { + Map<TupleTag<?>, PValue> outputs, PCollectionTuple newOutput) { return ReplacementOutputs.tagged(outputs, newOutput); } } http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java index 5c29db2..af157f0 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java @@ -29,6 +29,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import org.apache.beam.runners.core.ElementAndRestriction; import org.apache.beam.runners.core.KeyedWorkItem; import org.apache.beam.runners.core.SplittableParDo; @@ -79,7 +80,6 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PValue; -import org.apache.beam.sdk.values.TaggedPValue; import org.apache.beam.sdk.values.TupleTag; import org.apache.flink.api.common.functions.FlatMapFunction; import org.apache.flink.api.common.functions.MapFunction; @@ -420,7 +420,7 @@ class FlinkStreamingTransformTranslators { DoFn<InputT, OutputT> doFn, PCollection<InputT> input, List<PCollectionView<?>> sideInputs, - List<TaggedPValue> outputs, + Map<TupleTag<?>, PValue> outputs, TupleTag<OutputT> mainOutputTag, List<TupleTag<?>> sideOutputTags, FlinkStreamingTranslationContext context, @@ -537,8 +537,8 @@ class FlinkStreamingTransformTranslators { } }); - for (TaggedPValue output : outputs) { - final int outputTag = tagsToLabels.get(output.getTag()); + for (Entry<TupleTag<?>, PValue> output : outputs.entrySet()) { + final int outputTag = tagsToLabels.get(output.getKey()); TypeInformation outputTypeInfo = context.getTypeInfo((PCollection<?>) output.getValue()); @@ -557,28 +557,28 @@ class FlinkStreamingTransformTranslators { private static Map<TupleTag<?>, Integer> transformTupleTagsToLabels( TupleTag<?> mainTag, - List<TaggedPValue> allTaggedValues) { + Map<TupleTag<?>, PValue> allTaggedValues) { Map<TupleTag<?>, Integer> tagToLabelMap = Maps.newHashMap(); int count = 0; tagToLabelMap.put(mainTag, count++); - for (TaggedPValue taggedPValue : allTaggedValues) { - if (!tagToLabelMap.containsKey(taggedPValue.getTag())) { - tagToLabelMap.put(taggedPValue.getTag(), count++); + for (TupleTag<?> key : allTaggedValues.keySet()) { + if (!tagToLabelMap.containsKey(key)) { + tagToLabelMap.put(key, count++); } } return tagToLabelMap; } - private static UnionCoder createUnionCoder(Collection<TaggedPValue> taggedCollections) { + private static UnionCoder createUnionCoder(Map<TupleTag<?>, PValue> taggedCollections) { List<Coder<?>> outputCoders = Lists.newArrayList(); - for (TaggedPValue taggedColl : taggedCollections) { + for (PValue taggedColl : taggedCollections.values()) { checkArgument( - taggedColl.getValue() instanceof PCollection, + taggedColl instanceof PCollection, "A Union Coder can only be created for a Collection of Tagged %s. Got %s", PCollection.class.getSimpleName(), - taggedColl.getValue().getClass().getSimpleName()); - PCollection<?> coll = (PCollection<?>) taggedColl.getValue(); + taggedColl.getClass().getSimpleName()); + PCollection<?> coll = (PCollection<?>) taggedColl; WindowedValue.FullWindowedValueCoder<?> windowedValueCoder = WindowedValue.getFullCoder( coll.getCoder(), @@ -1042,7 +1042,7 @@ class FlinkStreamingTransformTranslators { public void translateNode( Flatten.PCollections<T> transform, FlinkStreamingTranslationContext context) { - List<TaggedPValue> allInputs = context.getInputs(transform); + Map<TupleTag<?>, PValue> allInputs = context.getInputs(transform); if (allInputs.isEmpty()) { @@ -1069,8 +1069,8 @@ class FlinkStreamingTransformTranslators { } else { DataStream<T> result = null; - for (TaggedPValue input : allInputs) { - DataStream<T> current = context.getInputDataStream(input.getValue()); + for (PValue input : allInputs.values()) { + DataStream<T> current = context.getInputDataStream(input); result = (result == null) ? current : result.union(current); } context.setOutputDataStream(context.getOutput(transform), result); http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTranslationContext.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTranslationContext.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTranslationContext.java index 3d5b83f..1a943a3 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTranslationContext.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTranslationContext.java @@ -21,7 +21,6 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.collect.Iterables; import java.util.HashMap; -import java.util.List; import java.util.Map; import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; import org.apache.beam.sdk.coders.Coder; @@ -33,7 +32,7 @@ import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; import org.apache.beam.sdk.values.PValue; -import org.apache.beam.sdk.values.TaggedPValue; +import org.apache.beam.sdk.values.TupleTag; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -111,19 +110,20 @@ class FlinkStreamingTranslationContext { @SuppressWarnings("unchecked") public <T extends PValue> T getInput(PTransform<T, ?> transform) { - return (T) Iterables.getOnlyElement(currentTransform.getInputs()).getValue(); + return (T) Iterables.getOnlyElement(currentTransform.getInputs().values()); } - public <T extends PInput> List<TaggedPValue> getInputs(PTransform<T, ?> transform) { + public <T extends PInput> Map<TupleTag<?>, PValue> getInputs(PTransform<T, ?> transform) { return currentTransform.getInputs(); } @SuppressWarnings("unchecked") public <T extends PValue> T getOutput(PTransform<?, T> transform) { - return (T) Iterables.getOnlyElement(currentTransform.getOutputs()).getValue(); + return (T) Iterables.getOnlyElement(currentTransform.getOutputs().values()); } - public <OutputT extends POutput> List<TaggedPValue> getOutputs(PTransform<?, OutputT> transform) { + public <OutputT extends POutput> Map<TupleTag<?>, PValue> getOutputs( + PTransform<?, OutputT> transform) { return currentTransform.getOutputs(); } http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java ---------------------------------------------------------------------- diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java index 1d19d64..3ded079 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java @@ -20,7 +20,6 @@ package org.apache.beam.runners.dataflow; import static com.google.common.base.Preconditions.checkState; import com.google.common.collect.Iterables; -import java.util.List; import java.util.Map; import org.apache.beam.runners.core.construction.ReplacementOutputs; import org.apache.beam.runners.dataflow.BatchViewOverrides.GroupByKeyAndSortValuesOnly; @@ -42,7 +41,7 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.PValue; -import org.apache.beam.sdk.values.TaggedPValue; +import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TypeDescriptor; import org.joda.time.Instant; @@ -93,13 +92,13 @@ public class BatchStatefulParDoOverrides { } @Override - public PCollection<KV<K, InputT>> getInput(List<TaggedPValue> inputs, Pipeline p) { - return (PCollection<KV<K, InputT>>) Iterables.getOnlyElement(inputs).getValue(); + public PCollection<KV<K, InputT>> getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) { + return (PCollection<KV<K, InputT>>) Iterables.getOnlyElement(inputs.values()); } @Override public Map<PValue, ReplacementOutput> mapOutputs( - List<TaggedPValue> outputs, PCollection<OutputT> newOutput) { + Map<TupleTag<?>, PValue> outputs, PCollection<OutputT> newOutput) { return ReplacementOutputs.singleton(outputs, newOutput); } } @@ -116,13 +115,13 @@ public class BatchStatefulParDoOverrides { } @Override - public PCollection<KV<K, InputT>> getInput(List<TaggedPValue> inputs, Pipeline p) { - return (PCollection<KV<K, InputT>>) Iterables.getOnlyElement(inputs).getValue(); + public PCollection<KV<K, InputT>> getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) { + return (PCollection<KV<K, InputT>>) Iterables.getOnlyElement(inputs.values()); } @Override public Map<PValue, ReplacementOutput> mapOutputs( - List<TaggedPValue> outputs, PCollectionTuple newOutput) { + Map<TupleTag<?>, PValue> outputs, PCollectionTuple newOutput) { return ReplacementOutputs.tagged(outputs, newOutput); } } http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java ---------------------------------------------------------------------- diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java index ab9df70..1a2e663 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java @@ -97,7 +97,6 @@ import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; import org.apache.beam.sdk.values.PValue; -import org.apache.beam.sdk.values.TaggedPValue; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TypedPValue; import org.slf4j.Logger; @@ -371,24 +370,25 @@ public class DataflowPipelineTranslator { } @Override - public <InputT extends PInput> List<TaggedPValue> getInputs(PTransform<InputT, ?> transform) { + public <InputT extends PInput> Map<TupleTag<?>, PValue> getInputs( + PTransform<InputT, ?> transform) { return getCurrentTransform(transform).getInputs(); } @Override public <InputT extends PValue> InputT getInput(PTransform<InputT, ?> transform) { - return (InputT) Iterables.getOnlyElement(getInputs(transform)).getValue(); + return (InputT) Iterables.getOnlyElement(getInputs(transform).values()); } @Override - public <OutputT extends POutput> List<TaggedPValue> getOutputs( + public <OutputT extends POutput> Map<TupleTag<?>, PValue> getOutputs( PTransform<?, OutputT> transform) { return getCurrentTransform(transform).getOutputs(); } @Override public <OutputT extends PValue> OutputT getOutput(PTransform<?, OutputT> transform) { - return (OutputT) Iterables.getOnlyElement(getOutputs(transform)).getValue(); + return (OutputT) Iterables.getOnlyElement(getOutputs(transform).values()); } @Override @@ -758,10 +758,10 @@ public class DataflowPipelineTranslator { StepTranslationContext stepContext = context.addStep(transform, "Flatten"); List<OutputReference> inputs = new LinkedList<>(); - for (TaggedPValue input : context.getInputs(transform)) { + for (PValue input : context.getInputs(transform).values()) { inputs.add( context.asOutputReference( - input.getValue(), context.getProducer(input.getValue()))); + input, context.getProducer(input))); } stepContext.addInput(PropertyNames.INPUTS, inputs); stepContext.addOutput(context.getOutput(transform)); @@ -967,11 +967,11 @@ public class DataflowPipelineTranslator { } private static BiMap<Long, TupleTag<?>> translateOutputs( - List<TaggedPValue> outputs, + Map<TupleTag<?>, PValue> outputs, StepTranslationContext stepContext) { ImmutableBiMap.Builder<Long, TupleTag<?>> mapBuilder = ImmutableBiMap.builder(); - for (TaggedPValue taggedOutput : outputs) { - TupleTag<?> tag = taggedOutput.getTag(); + for (Map.Entry<TupleTag<?>, PValue> taggedOutput : outputs.entrySet()) { + TupleTag<?> tag = taggedOutput.getKey(); checkArgument(taggedOutput.getValue() instanceof PCollection, "Non %s returned from Multi-output %s", PCollection.class.getSimpleName(), http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java ---------------------------------------------------------------------- diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java index f789769..9b993f4 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java @@ -127,7 +127,7 @@ import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PDone; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.PValue; -import org.apache.beam.sdk.values.TaggedPValue; +import org.apache.beam.sdk.values.TupleTag; import org.joda.time.DateTimeUtils; import org.joda.time.DateTimeZone; import org.joda.time.format.DateTimeFormat; @@ -450,13 +450,13 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> { } @Override - public PBegin getInput(List<TaggedPValue> inputs, Pipeline p) { + public PBegin getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) { return p.begin(); } @Override public Map<PValue, ReplacementOutput> mapOutputs( - List<TaggedPValue> outputs, PCollection<T> newOutput) { + Map<TupleTag<?>, PValue> outputs, PCollection<T> newOutput) { return ReplacementOutputs.singleton(outputs, newOutput); } } @@ -760,7 +760,7 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> { if (node.getTransform() instanceof View.AsMap || node.getTransform() instanceof View.AsMultimap) { PCollection<KV<?, ?>> input = - (PCollection<KV<?, ?>>) Iterables.getOnlyElement(node.getInputs()).getValue(); + (PCollection<KV<?, ?>>) Iterables.getOnlyElement(node.getInputs().values()); KvCoder<?, ?> inputCoder = (KvCoder) input.getCoder(); try { inputCoder.getKeyCoder().verifyDeterministic(); @@ -825,13 +825,13 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> { } @Override - public PCollection<T> getInput(List<TaggedPValue> inputs, Pipeline p) { - return (PCollection<T>) Iterables.getOnlyElement(inputs).getValue(); + public PCollection<T> getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) { + return (PCollection<T>) Iterables.getOnlyElement(inputs.values()); } @Override public Map<PValue, ReplacementOutput> mapOutputs( - List<TaggedPValue> outputs, PDone newOutput) { + Map<TupleTag<?>, PValue> outputs, PDone newOutput) { return Collections.emptyMap(); } } @@ -1317,13 +1317,13 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> { @Override public PCollection<KV<K, Iterable<InputT>>> getInput( - List<TaggedPValue> inputs, Pipeline p) { - return (PCollection<KV<K, Iterable<InputT>>>) Iterables.getOnlyElement(inputs).getValue(); + Map<TupleTag<?>, PValue> inputs, Pipeline p) { + return (PCollection<KV<K, Iterable<InputT>>>) Iterables.getOnlyElement(inputs.values()); } @Override public Map<PValue, ReplacementOutput> mapOutputs( - List<TaggedPValue> outputs, PCollection<KV<K, OutputT>> newOutput) { + Map<TupleTag<?>, PValue> outputs, PCollection<KV<K, OutputT>> newOutput) { return ReplacementOutputs.singleton(outputs, newOutput); } } @@ -1343,12 +1343,13 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> { } @Override - public PCollection<T> getInput(List<TaggedPValue> inputs, Pipeline p) { - return (PCollection<T>) Iterables.getOnlyElement(inputs).getValue(); + public PCollection<T> getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) { + return (PCollection<T>) Iterables.getOnlyElement(inputs.values()); } @Override - public Map<PValue, ReplacementOutput> mapOutputs(List<TaggedPValue> outputs, PDone newOutput) { + public Map<PValue, ReplacementOutput> mapOutputs( + Map<TupleTag<?>, PValue> outputs, PDone newOutput) { return Collections.emptyMap(); } } http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/TransformTranslator.java ---------------------------------------------------------------------- diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/TransformTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/TransformTranslator.java index e020e83..52b3a31 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/TransformTranslator.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/TransformTranslator.java @@ -29,7 +29,7 @@ import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; import org.apache.beam.sdk.values.PValue; -import org.apache.beam.sdk.values.TaggedPValue; +import org.apache.beam.sdk.values.TupleTag; /** * A {@link TransformTranslator} knows how to translate a particular subclass of {@link PTransform} @@ -47,12 +47,12 @@ interface TransformTranslator<TransformT extends PTransform> { DataflowPipelineOptions getPipelineOptions(); /** Returns the input of the currently being translated transform. */ - <InputT extends PInput> List<TaggedPValue> getInputs(PTransform<InputT, ?> transform); + <InputT extends PInput> Map<TupleTag<?>, PValue> getInputs(PTransform<InputT, ?> transform); <InputT extends PValue> InputT getInput(PTransform<InputT, ?> transform); /** Returns the output of the currently being translated transform. */ - <OutputT extends POutput> List<TaggedPValue> getOutputs(PTransform<?, OutputT> transform); + <OutputT extends POutput> Map<TupleTag<?>, PValue> getOutputs(PTransform<?, OutputT> transform); <OutputT extends PValue> OutputT getOutput(PTransform<?, OutputT> transform); http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineJobTest.java ---------------------------------------------------------------------- diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineJobTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineJobTest.java index e3d2e4e..e7f2b48 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineJobTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineJobTest.java @@ -72,7 +72,8 @@ import org.apache.beam.sdk.util.NoopPathValidator; import org.apache.beam.sdk.util.TestCredential; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; -import org.apache.beam.sdk.values.TaggedPValue; +import org.apache.beam.sdk.values.PValue; +import org.apache.beam.sdk.values.TupleTag; import org.joda.time.Duration; import org.junit.Before; import org.junit.Rule; @@ -689,8 +690,8 @@ public class DataflowPipelineJobTest { when(input.getPipeline()).thenReturn(p); return AppliedPTransform.of( fullName, - Collections.<TaggedPValue>emptyList(), - Collections.<TaggedPValue>emptyList(), + Collections.<TupleTag<?>, PValue>emptyMap(), + Collections.<TupleTag<?>, PValue>emptyMap(), transform, p); } http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java index 5b4f73e..97487f3 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java @@ -22,6 +22,7 @@ import com.google.common.collect.Iterables; import java.util.Arrays; import java.util.Collection; import java.util.List; +import java.util.Map; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; @@ -55,7 +56,7 @@ import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; import org.apache.beam.sdk.values.PValue; -import org.apache.beam.sdk.values.TaggedPValue; +import org.apache.beam.sdk.values.TupleTag; import org.apache.spark.SparkEnv$; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.metrics.MetricsSystem; @@ -315,8 +316,7 @@ public final class SparkRunner extends PipelineRunner<SparkPipelineResult> { // The goal is to detect the PCollections accessed more than one time, and so enable cache // on the underlying RDDs or DStreams. - for (TaggedPValue input : node.getInputs()) { - PValue value = input.getValue(); + for (PValue value : node.getInputs().values()) { if (value instanceof PCollection) { long count = 1L; if (ctxt.getCacheCandidates().get(value) != null) { @@ -362,7 +362,7 @@ public final class SparkRunner extends PipelineRunner<SparkPipelineResult> { if (node.getInputs().size() != 1) { return false; } - PValue input = Iterables.getOnlyElement(node.getInputs()).getValue(); + PValue input = Iterables.getOnlyElement(node.getInputs().values()); if (!(input instanceof PCollection) || ((PCollection) input).getWindowingStrategy().getWindowFn().isNonMerging()) { return false; @@ -420,14 +420,14 @@ public final class SparkRunner extends PipelineRunner<SparkPipelineResult> { //--- determine if node is bounded/unbounded. // usually, the input determines if the PCollection to apply the next transformation to // is BOUNDED or UNBOUNDED, meaning RDD/DStream. - Collection<TaggedPValue> pValues; + Map<TupleTag<?>, PValue> pValues; if (node.getInputs().isEmpty()) { // in case of a PBegin, it's the output. pValues = node.getOutputs(); } else { pValues = node.getInputs(); } - PCollection.IsBounded isNodeBounded = isBoundedCollection(pValues); + PCollection.IsBounded isNodeBounded = isBoundedCollection(pValues.values()); // translate accordingly. LOG.debug("Translating {} as {}", transform, isNodeBounded); return isNodeBounded.equals(PCollection.IsBounded.BOUNDED) @@ -435,15 +435,15 @@ public final class SparkRunner extends PipelineRunner<SparkPipelineResult> { : translator.translateUnbounded(transformClass); } - protected PCollection.IsBounded isBoundedCollection(Collection<TaggedPValue> pValues) { + protected PCollection.IsBounded isBoundedCollection(Collection<PValue> pValues) { // anything that is not a PCollection, is BOUNDED. // For PCollections: // BOUNDED behaves as the Identity Element, BOUNDED + BOUNDED = BOUNDED // while BOUNDED + UNBOUNDED = UNBOUNDED. PCollection.IsBounded isBounded = PCollection.IsBounded.BOUNDED; - for (TaggedPValue pValue : pValues) { - if (pValue.getValue() instanceof PCollection) { - isBounded = isBounded.and(((PCollection) pValue.getValue()).isBounded()); + for (PValue pValue : pValues) { + if (pValue instanceof PCollection) { + isBounded = isBounded.and(((PCollection) pValue).isBounded()); } else { isBounded = isBounded.and(PCollection.IsBounded.BOUNDED); } http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java index fcc00f9..aacb942 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java @@ -28,7 +28,6 @@ import com.google.common.util.concurrent.Uninterruptibles; import java.io.File; import java.io.IOException; import java.util.Collections; -import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; import org.apache.beam.runners.core.construction.PTransformMatchers; @@ -53,7 +52,7 @@ import org.apache.beam.sdk.util.ValueWithRecordId; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PValue; -import org.apache.beam.sdk.values.TaggedPValue; +import org.apache.beam.sdk.values.TupleTag; import org.apache.commons.io.FileUtils; import org.joda.time.Duration; import org.joda.time.Instant; @@ -251,13 +250,13 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult> { } @Override - public PBegin getInput(List<TaggedPValue> inputs, Pipeline p) { + public PBegin getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) { return p.begin(); } @Override public Map<PValue, ReplacementOutput> mapOutputs( - List<TaggedPValue> outputs, PCollection<T> newOutput) { + Map<TupleTag<?>, PValue> outputs, PCollection<T> newOutput) { return ReplacementOutputs.singleton(outputs, newOutput); } } http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java index 643749d..838c504 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java @@ -24,7 +24,6 @@ import com.google.common.collect.Iterables; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.LinkedHashSet; -import java.util.List; import java.util.Map; import java.util.Set; import org.apache.beam.runners.spark.SparkPipelineOptions; @@ -37,7 +36,7 @@ import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PValue; -import org.apache.beam.sdk.values.TaggedPValue; +import org.apache.beam.sdk.values.TupleTag; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.streaming.api.java.JavaStreamingContext; @@ -96,11 +95,11 @@ public class EvaluationContext { public <T extends PValue> T getInput(PTransform<T, ?> transform) { @SuppressWarnings("unchecked") - T input = (T) Iterables.getOnlyElement(getInputs(transform)).getValue(); + T input = (T) Iterables.getOnlyElement(getInputs(transform).values()); return input; } - public <T> List<TaggedPValue> getInputs(PTransform<?, ?> transform) { + public <T> Map<TupleTag<?>, PValue> getInputs(PTransform<?, ?> transform) { checkArgument(currentTransform != null && currentTransform.getTransform() == transform, "can only be called with current transform"); return currentTransform.getInputs(); @@ -108,11 +107,11 @@ public class EvaluationContext { public <T extends PValue> T getOutput(PTransform<?, T> transform) { @SuppressWarnings("unchecked") - T output = (T) Iterables.getOnlyElement(getOutputs(transform)).getValue(); + T output = (T) Iterables.getOnlyElement(getOutputs(transform).values()); return output; } - public List<TaggedPValue> getOutputs(PTransform<?, ?> transform) { + public Map<TupleTag<?>, PValue> getOutputs(PTransform<?, ?> transform) { checkArgument(currentTransform != null && currentTransform.getTransform() == transform, "can only be called with current transform"); return currentTransform.getOutputs(); http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java index 7894c4e..c2a8b06 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java @@ -26,9 +26,10 @@ import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectS import com.google.common.base.Optional; import com.google.common.collect.Lists; import com.google.common.collect.Maps; +import java.util.Collection; import java.util.Collections; -import java.util.List; import java.util.Map; +import java.util.Map.Entry; import org.apache.beam.runners.core.SystemReduceFn; import org.apache.beam.runners.spark.aggregators.AggregatorsAccumulator; import org.apache.beam.runners.spark.aggregators.NamedAggregators; @@ -61,7 +62,7 @@ import org.apache.beam.sdk.util.WindowingStrategy; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; -import org.apache.beam.sdk.values.TaggedPValue; +import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.TupleTag; import org.apache.spark.Accumulator; import org.apache.spark.api.java.JavaPairRDD; @@ -83,19 +84,21 @@ public final class TransformTranslator { @SuppressWarnings("unchecked") @Override public void evaluate(Flatten.PCollections<T> transform, EvaluationContext context) { - List<TaggedPValue> pcs = context.getInputs(transform); + Collection<PValue> pcs = context.getInputs(transform).values(); JavaRDD<WindowedValue<T>> unionRDD; if (pcs.size() == 0) { unionRDD = context.getSparkContext().emptyRDD(); } else { JavaRDD<WindowedValue<T>>[] rdds = new JavaRDD[pcs.size()]; - for (int i = 0; i < rdds.length; i++) { + int index = 0; + for (PValue pc : pcs) { checkArgument( - pcs.get(i).getValue() instanceof PCollection, + pc instanceof PCollection, "Flatten had non-PCollection value in input: %s of type %s", - pcs.get(i).getValue(), - pcs.get(i).getValue().getClass().getSimpleName()); - rdds[i] = ((BoundedDataset<T>) context.borrowDataset(pcs.get(i).getValue())).getRDD(); + pc, + pc.getClass().getSimpleName()); + rdds[index] = ((BoundedDataset<T>) context.borrowDataset(pc)).getRDD(); + index++; } unionRDD = context.getSparkContext().union(rdds); } @@ -360,15 +363,15 @@ public final class TransformTranslator { transform.getMainOutputTag(), TranslationUtils.getSideInputs(transform.getSideInputs(), context), windowingStrategy)); - List<TaggedPValue> outputs = context.getOutputs(transform); + Map<TupleTag<?>, PValue> outputs = context.getOutputs(transform); if (outputs.size() > 1) { // cache the RDD if we're going to filter it more than once. all.cache(); } - for (TaggedPValue output : outputs) { + for (Map.Entry<TupleTag<?>, PValue> output : outputs.entrySet()) { @SuppressWarnings("unchecked") JavaPairRDD<TupleTag<?>, WindowedValue<?>> filtered = - all.filter(new TranslationUtils.TupleTagFilter(output.getTag())); + all.filter(new TranslationUtils.TupleTagFilter(output.getKey())); @SuppressWarnings("unchecked") // Object is the best we can do since different outputs can have different tags JavaRDD<WindowedValue<Object>> values = http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java index d4c6c9d..65892d2 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java @@ -77,7 +77,7 @@ import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.util.WindowingStrategy; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.TaggedPValue; +import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.TimestampedValue; import org.apache.beam.sdk.values.TupleTag; import org.apache.spark.Accumulator; @@ -191,19 +191,19 @@ public final class StreamingTransformTranslator { @SuppressWarnings("unchecked") @Override public void evaluate(Flatten.PCollections<T> transform, EvaluationContext context) { - List<TaggedPValue> pcs = context.getInputs(transform); + Map<TupleTag<?>, PValue> pcs = context.getInputs(transform); // since this is a streaming pipeline, at least one of the PCollections to "flatten" are // unbounded, meaning it represents a DStream. // So we could end up with an unbounded unified DStream. final List<JavaDStream<WindowedValue<T>>> dStreams = new ArrayList<>(); final List<Integer> streamingSources = new ArrayList<>(); - for (TaggedPValue pv : pcs) { + for (PValue pv : pcs.values()) { checkArgument( - pv.getValue() instanceof PCollection, + pv instanceof PCollection, "Flatten had non-PCollection value in input: %s of type %s", - pv.getValue(), - pv.getValue().getClass().getSimpleName()); - PCollection<T> pcol = (PCollection<T>) pv.getValue(); + pv, + pv.getClass().getSimpleName()); + PCollection<T> pcol = (PCollection<T>) pv; Dataset dataset = context.borrowDataset(pcol); if (dataset instanceof UnboundedDataset) { UnboundedDataset<T> unboundedDataset = (UnboundedDataset<T>) dataset; @@ -416,15 +416,15 @@ public final class StreamingTransformTranslator { windowingStrategy)); } }); - List<TaggedPValue> outputs = context.getOutputs(transform); + Map<TupleTag<?>, PValue> outputs = context.getOutputs(transform); if (outputs.size() > 1) { // cache the DStream if we're going to filter it more than once. all.cache(); } - for (TaggedPValue output : outputs) { + for (Map.Entry<TupleTag<?>, PValue> output : outputs.entrySet()) { @SuppressWarnings("unchecked") JavaPairDStream<TupleTag<?>, WindowedValue<?>> filtered = - all.filter(new TranslationUtils.TupleTagFilter(output.getTag())); + all.filter(new TranslationUtils.TupleTagFilter(output.getKey())); @SuppressWarnings("unchecked") // Object is the best we can do since different outputs can have different tags JavaDStream<WindowedValue<Object>> values = http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PTransformOverrideFactory.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PTransformOverrideFactory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PTransformOverrideFactory.java index e2b6009..57cba50 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PTransformOverrideFactory.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PTransformOverrideFactory.java @@ -20,7 +20,6 @@ package org.apache.beam.sdk.runners; import com.google.auto.value.AutoValue; -import java.util.List; import java.util.Map; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.annotations.Experimental; @@ -30,6 +29,7 @@ import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.TaggedPValue; +import org.apache.beam.sdk.values.TupleTag; /** * Produces {@link PipelineRunner}-specific overrides of {@link PTransform PTransforms}, and @@ -48,17 +48,15 @@ public interface PTransformOverrideFactory< /** * Returns the composite type that replacement transforms consumed from an equivalent expansion. */ - InputT getInput(List<TaggedPValue> inputs, Pipeline p); + InputT getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p); /** * Returns a {@link Map} from the expanded values in {@code newOutput} to the values produced by * the original transform. */ - Map<PValue, ReplacementOutput> mapOutputs(List<TaggedPValue> outputs, OutputT newOutput); + Map<PValue, ReplacementOutput> mapOutputs(Map<TupleTag<?>, PValue> outputs, OutputT newOutput); - /** - * A mapping between original {@link TaggedPValue} outputs and their replacements. - */ + /** A mapping between original {@link TaggedPValue} outputs and their replacements. */ @AutoValue abstract class ReplacementOutput { public static ReplacementOutput of(TaggedPValue original, TaggedPValue replacement) { http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java index 972cb5b..18bf2e9 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java @@ -23,6 +23,7 @@ import static com.google.common.base.Preconditions.checkState; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; +import com.google.common.collect.ImmutableMap; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -41,7 +42,7 @@ import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; import org.apache.beam.sdk.values.PValue; -import org.apache.beam.sdk.values.TaggedPValue; +import org.apache.beam.sdk.values.TupleTag; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -103,8 +104,8 @@ public class TransformHierarchy { "Replacing a node when the graph has an unexpanded input. This is an SDK bug."); Node replacement = new Node(existing.getEnclosingNode(), transform, existing.getFullName(), input); - for (TaggedPValue output : existing.getOutputs()) { - Node producer = producers.get(output.getValue()); + for (PValue output : existing.getOutputs().values()) { + Node producer = producers.get(output); boolean producedInExisting = false; do { if (producer.equals(existing)) { @@ -114,13 +115,13 @@ public class TransformHierarchy { } } while (!producedInExisting && !producer.isRootNode()); if (producedInExisting) { - producers.remove(output.getValue()); + producers.remove(output); LOG.debug("Removed producer for value {} as it is part of a replaced composite {}", - output.getValue(), + output, existing.getFullName()); } else { LOG.debug( - "Value {} not produced in existing node {}", output.getValue(), existing.getFullName()); + "Value {} not produced in existing node {}", output, existing.getFullName()); } } existing.getEnclosingNode().replaceChild(existing, replacement); @@ -137,18 +138,18 @@ public class TransformHierarchy { */ public void finishSpecifyingInput() { // Inputs must be completely specified before they are consumed by a transform. - for (TaggedPValue inputValue : current.getInputs()) { - Node producerNode = getProducer(inputValue.getValue()); - PInput input = producerInput.remove(inputValue.getValue()); - inputValue.getValue().finishSpecifying(input, producerNode.getTransform()); + for (PValue inputValue : current.getInputs().values()) { + Node producerNode = getProducer(inputValue); + PInput input = producerInput.remove(inputValue); + inputValue.finishSpecifying(input, producerNode.getTransform()); checkState( - producers.get(inputValue.getValue()) != null, + producers.get(inputValue) != null, "Producer unknown for input %s", inputValue); checkState( - producers.get(inputValue.getValue()) != null, + producers.get(inputValue) != null, "Producer unknown for input %s", - inputValue.getValue()); + inputValue); } } @@ -163,12 +164,12 @@ public class TransformHierarchy { * nodes. */ public void setOutput(POutput output) { - for (TaggedPValue value : output.expand()) { - if (!producers.containsKey(value.getValue())) { - producers.put(value.getValue(), current); + for (PValue value : output.expand().values()) { + if (!producers.containsKey(value)) { + producers.put(value, current); } - value.getValue().finishSpecifyingOutput(unexpandedInputs.get(current), current.transform); - producerInput.put(value.getValue(), unexpandedInputs.get(current)); + value.finishSpecifyingOutput(unexpandedInputs.get(current), current.transform); + producerInput.put(value, unexpandedInputs.get(current)); } output.finishSpecifyingOutput(unexpandedInputs.get(current), current.transform); current.setOutput(output); @@ -241,11 +242,11 @@ public class TransformHierarchy { private final List<Node> parts = new ArrayList<>(); // Input to the transform, in expanded form. - private final List<TaggedPValue> inputs; + private final Map<TupleTag<?>, PValue> inputs; // TODO: track which outputs need to be exported to parent. // Output of the transform, in expanded form. - private List<TaggedPValue> outputs; + private Map<TupleTag<?>, PValue> outputs; @VisibleForTesting boolean finishedSpecifying = false; @@ -269,7 +270,7 @@ public class TransformHierarchy { this.enclosingNode = enclosingNode; this.transform = transform; this.fullName = fullName; - this.inputs = input == null ? Collections.<TaggedPValue>emptyList() : input.expand(); + this.inputs = input == null ? Collections.<TupleTag<?>, PValue>emptyMap() : input.expand(); } /** @@ -333,8 +334,8 @@ public class TransformHierarchy { private boolean returnsOthersOutput() { PTransform<?, ?> transform = getTransform(); if (outputs != null) { - for (TaggedPValue outputValue : outputs) { - if (!getProducer(outputValue.getValue()).getTransform().equals(transform)) { + for (PValue outputValue : outputs.values()) { + if (!getProducer(outputValue).getTransform().equals(transform)) { return true; } } @@ -351,8 +352,8 @@ public class TransformHierarchy { } /** Returns the transform input, in unexpanded form. */ - public List<TaggedPValue> getInputs() { - return inputs == null ? Collections.<TaggedPValue>emptyList() : inputs; + public Map<TupleTag<?>, PValue> getInputs() { + return inputs == null ? Collections.<TupleTag<?>, PValue>emptyMap() : inputs; } /** @@ -368,8 +369,8 @@ public class TransformHierarchy { // Validate that a primitive transform produces only primitive output, and a composite // transform does not produce primitive output. Set<Node> outputProducers = new HashSet<>(); - for (TaggedPValue outputValue : output.expand()) { - outputProducers.add(getProducer(outputValue.getValue())); + for (PValue outputValue : output.expand().values()) { + outputProducers.add(getProducer(outputValue)); } if (outputProducers.contains(this)) { if (!parts.isEmpty() || outputProducers.size() > 1) { @@ -412,8 +413,8 @@ public class TransformHierarchy { // Replace the outputs of the component nodes component.replaceOutputs(originalToReplacement); } - List<TaggedPValue> newOutputs = new ArrayList<>(outputs.size()); - for (TaggedPValue output : outputs) { + ImmutableMap.Builder<TupleTag<?>, PValue> newOutputsBuilder = ImmutableMap.builder(); + for (Map.Entry<TupleTag<?>, PValue> output : outputs.entrySet()) { ReplacementOutput mapping = originalToReplacement.get(output.getValue()); if (mapping != null) { if (this.equals(producers.get(mapping.getReplacement().getValue()))) { @@ -429,11 +430,12 @@ public class TransformHierarchy { "Replacing output {} with original {}", mapping.getReplacement(), mapping.getOriginal()); - newOutputs.add(TaggedPValue.of(output.getTag(), mapping.getOriginal().getValue())); + newOutputsBuilder.put(output.getKey(), mapping.getOriginal().getValue()); } else { - newOutputs.add(output); + newOutputsBuilder.put(output); } } + ImmutableMap<TupleTag<?>, PValue> newOutputs = newOutputsBuilder.build(); checkState( outputs.size() == newOutputs.size(), "Number of outputs must be stable across replacement"); @@ -441,8 +443,8 @@ public class TransformHierarchy { } /** Returns the transform output, in expanded form. */ - public List<TaggedPValue> getOutputs() { - return outputs == null ? Collections.<TaggedPValue>emptyList() : outputs; + public Map<TupleTag<?>, PValue> getOutputs() { + return outputs == null ? Collections.<TupleTag<?>, PValue>emptyMap() : outputs; } /** @@ -466,9 +468,9 @@ public class TransformHierarchy { if (!isRootNode()) { // Visit inputs. - for (TaggedPValue inputValue : inputs) { - if (visitedValues.add(inputValue.getValue())) { - visitor.visitValue(inputValue.getValue(), getProducer(inputValue.getValue())); + for (PValue inputValue : inputs.values()) { + if (visitedValues.add(inputValue)) { + visitor.visitValue(inputValue, getProducer(inputValue)); } } } @@ -489,9 +491,9 @@ public class TransformHierarchy { if (!isRootNode()) { checkNotNull(outputs, "Outputs for non-root node %s are null", getFullName()); // Visit outputs. - for (TaggedPValue pValue : outputs) { - if (visitedValues.add(pValue.getValue())) { - visitor.visitValue(pValue.getValue(), this); + for (PValue pValue : outputs.values()) { + if (visitedValues.add(pValue)) { + visitor.visitValue(pValue, this); } } } http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java index e78d795..8d99a62 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java @@ -18,11 +18,12 @@ package org.apache.beam.sdk.transforms; import com.google.auto.value.AutoValue; -import java.util.List; +import java.util.Map; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; -import org.apache.beam.sdk.values.TaggedPValue; +import org.apache.beam.sdk.values.PValue; +import org.apache.beam.sdk.values.TupleTag; /** * Represents the application of a {@link PTransform} to a specific input to produce @@ -41,12 +42,14 @@ public abstract class AppliedPTransform< // To prevent extension outside of this package. AppliedPTransform() {} - public static <InputT extends PInput, OutputT extends POutput, + public static < + InputT extends PInput, + OutputT extends POutput, TransformT extends PTransform<? super InputT, OutputT>> AppliedPTransform<InputT, OutputT, TransformT> of( String fullName, - List<TaggedPValue> input, - List<TaggedPValue> output, + Map<TupleTag<?>, PValue> input, + Map<TupleTag<?>, PValue> output, TransformT transform, Pipeline p) { return new AutoValue_AppliedPTransform<InputT, OutputT, TransformT>( @@ -55,9 +58,9 @@ public abstract class AppliedPTransform< public abstract String getFullName(); - public abstract List<TaggedPValue> getInputs(); + public abstract Map<TupleTag<?>, PValue> getInputs(); - public abstract List<TaggedPValue> getOutputs(); + public abstract Map<TupleTag<?>, PValue> getOutputs(); public abstract TransformT getTransform(); http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/KeyedPCollectionTuple.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/KeyedPCollectionTuple.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/KeyedPCollectionTuple.java index b373909..2e7dd01 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/KeyedPCollectionTuple.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/KeyedPCollectionTuple.java @@ -17,8 +17,10 @@ */ package org.apache.beam.sdk.transforms.join; +import com.google.common.collect.ImmutableMap; import java.util.ArrayList; import java.util.List; +import java.util.Map; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.KvCoder; @@ -27,7 +29,7 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; -import org.apache.beam.sdk.values.TaggedPValue; +import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; @@ -120,12 +122,12 @@ public class KeyedPCollectionTuple<K> implements PInput { * any tag-specific information. */ @Override - public List<TaggedPValue> expand() { - List<TaggedPValue> retval = new ArrayList<>(); + public Map<TupleTag<?>, PValue> expand() { + ImmutableMap.Builder<TupleTag<?>, PValue> retval = ImmutableMap.builder(); for (TaggedKeyedPCollection<K, ?> taggedPCollection : keyedCollections) { - retval.add(TaggedPValue.of(taggedPCollection.tupleTag, taggedPCollection.pCollection)); + retval.put(taggedPCollection.tupleTag, taggedPCollection.pCollection); } - return retval; + return retval.build(); } /** http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PBegin.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PBegin.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PBegin.java index 2ba0f1c..04d1bdb 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PBegin.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PBegin.java @@ -18,7 +18,7 @@ package org.apache.beam.sdk.values; import java.util.Collections; -import java.util.List; +import java.util.Map; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.io.TextIO.Read; import org.apache.beam.sdk.transforms.Create; @@ -64,9 +64,9 @@ public class PBegin implements PInput { } @Override - public List<TaggedPValue> expand() { + public Map<TupleTag<?>, PValue> expand() { // A PBegin contains no PValues. - return Collections.emptyList(); + return Collections.emptyMap(); } ///////////////////////////////////////////////////////////////////////////// http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionList.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionList.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionList.java index dcb64a8..7b45deb 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionList.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionList.java @@ -18,9 +18,10 @@ package org.apache.beam.sdk.values; import com.google.common.collect.ImmutableList; -import com.google.common.collect.Iterables; +import com.google.common.collect.ImmutableMap; import java.util.Iterator; import java.util.List; +import java.util.Map; import java.util.Objects; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.transforms.AppliedPTransform; @@ -116,7 +117,7 @@ public class PCollectionList<T> implements PInput, POutput { return new PCollectionList<>(pipeline, ImmutableList.<TaggedPValue>builder() .addAll(pcollections) - .add(Iterables.getOnlyElement(pc.expand())) + .add(TaggedPValue.of(new TupleTag<T>(), pc)) .build()); } @@ -133,10 +134,9 @@ public class PCollectionList<T> implements PInput, POutput { builder.addAll(pcollections); for (PCollection<T> pc : pcs) { if (pc.getPipeline() != pipeline) { - throw new IllegalArgumentException( - "PCollections come from different Pipelines"); + throw new IllegalArgumentException("PCollections come from different Pipelines"); } - builder.add(Iterables.getOnlyElement(pc.expand())); + builder.add(TaggedPValue.of(new TupleTag<T>(), pc)); } return new PCollectionList<>(pipeline, builder.build()); } @@ -200,7 +200,10 @@ public class PCollectionList<T> implements PInput, POutput { // Internal details below here. final Pipeline pipeline; - // ImmutableMap has a defined iteration order. + /** + * The {@link PCollection PCollections} contained by this {@link PCollectionList}, and an + * arbitrary tags associated with each. + */ final List<TaggedPValue> pcollections; PCollectionList(Pipeline pipeline) { @@ -218,8 +221,12 @@ public class PCollectionList<T> implements PInput, POutput { } @Override - public List<TaggedPValue> expand() { - return pcollections; + public Map<TupleTag<?>, PValue> expand() { + ImmutableMap.Builder<TupleTag<?>, PValue> expanded = ImmutableMap.builder(); + for (TaggedPValue tagged : pcollections) { + expanded.put(tagged.getTag(), tagged.getValue()); + } + return expanded.build(); } @Override @@ -244,11 +251,11 @@ public class PCollectionList<T> implements PInput, POutput { return false; } PCollectionList that = (PCollectionList) other; - return this.pipeline.equals(that.pipeline) && this.pcollections.equals(that.pcollections); + return this.pipeline.equals(that.pipeline) && this.getAll().equals(that.getAll()); } @Override public int hashCode() { - return Objects.hash(this.pipeline, this.pcollections); + return Objects.hash(this.pipeline, this.getAll()); } } http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionTuple.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionTuple.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionTuple.java index d61db51..0ab26ca 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionTuple.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionTuple.java @@ -17,11 +17,9 @@ */ package org.apache.beam.sdk.values; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.Collections; import java.util.LinkedHashMap; -import java.util.List; import java.util.Map; import java.util.Objects; import org.apache.beam.sdk.Pipeline; @@ -234,12 +232,8 @@ public class PCollectionTuple implements PInput, POutput { } @Override - public List<TaggedPValue> expand() { - ImmutableList.Builder<TaggedPValue> values = ImmutableList.builder(); - for (Map.Entry<TupleTag<?>, PCollection<?>> entry : pcollectionMap.entrySet()) { - values.add(TaggedPValue.of(entry.getKey(), entry.getValue())); - } - return values.build(); + public Map<TupleTag<?>, PValue> expand() { + return ImmutableMap.<TupleTag<?>, PValue>copyOf(pcollectionMap); } @Override http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PDone.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PDone.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PDone.java index b4a3025..eb5db20 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PDone.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PDone.java @@ -18,7 +18,7 @@ package org.apache.beam.sdk.values; import java.util.Collections; -import java.util.List; +import java.util.Map; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.transforms.PTransform; @@ -36,9 +36,9 @@ public class PDone extends POutputValueBase { } @Override - public List<TaggedPValue> expand() { + public Map<TupleTag<?>, PValue> expand() { // A PDone contains no PValues. - return Collections.emptyList(); + return Collections.emptyMap(); } private PDone(Pipeline pipeline) { http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PInput.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PInput.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PInput.java index 30d4297..caf7812 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PInput.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PInput.java @@ -17,7 +17,7 @@ */ package org.apache.beam.sdk.values; -import java.util.List; +import java.util.Map; import org.apache.beam.sdk.Pipeline; /** @@ -43,5 +43,5 @@ public interface PInput { * * <p>Not intended to be invoked directly by user code. */ - List<TaggedPValue> expand(); + Map<TupleTag<?>, PValue> expand(); } http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/sdks/java/core/src/main/java/org/apache/beam/sdk/values/POutput.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/POutput.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/POutput.java index 062f565..bb01beb 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/POutput.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/POutput.java @@ -17,7 +17,7 @@ */ package org.apache.beam.sdk.values; -import java.util.List; +import java.util.Map; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.PTransform; @@ -45,7 +45,7 @@ public interface POutput { * * <p>Not intended to be invoked directly by user code. */ - List<TaggedPValue> expand(); + Map<TupleTag<?>, PValue> expand(); /** * Records that this {@code POutput} is an output of the given http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PValue.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PValue.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PValue.java index 4c62972..06546aa 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PValue.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PValue.java @@ -17,7 +17,7 @@ */ package org.apache.beam.sdk.values; -import java.util.List; +import java.util.Map; import org.apache.beam.sdk.transforms.PTransform; /** @@ -37,7 +37,7 @@ public interface PValue extends POutput, PInput { * never appropriate. */ @Deprecated - List<TaggedPValue> expand(); + Map<TupleTag<?>, PValue> expand(); /** * After building, finalizes this {@code PValue} to make it ready for being used as an input to a http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PValueBase.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PValueBase.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PValueBase.java index 8778597..91ee392 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PValueBase.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PValueBase.java @@ -18,7 +18,7 @@ package org.apache.beam.sdk.values; import java.util.Collections; -import java.util.List; +import java.util.Map; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.PTransform; @@ -133,8 +133,8 @@ public abstract class PValueBase extends POutputValueBase implements PValue { } @Override - public final List<TaggedPValue> expand() { - return Collections.singletonList(TaggedPValue.of(tag, this)); + public final Map<TupleTag<?>, PValue> expand() { + return Collections.<TupleTag<?>, PValue>singletonMap(tag, this); } @Override http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/sdks/java/core/src/main/java/org/apache/beam/sdk/values/TaggedPValue.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/TaggedPValue.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/TaggedPValue.java index 458d16f..3b4d599 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/TaggedPValue.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/TaggedPValue.java @@ -20,6 +20,7 @@ package org.apache.beam.sdk.values; import com.google.auto.value.AutoValue; +import com.google.common.collect.Iterables; /** * A (TupleTag, PValue) pair used in the expansion of a {@link PInput} or {@link POutput}. @@ -30,6 +31,10 @@ public abstract class TaggedPValue { return new AutoValue_TaggedPValue(tag, value); } + public static TaggedPValue ofExpandedValue(PValue value) { + return of(Iterables.getOnlyElement(value.expand().keySet()), value); + } + /** * Returns the local tag associated with the {@link PValue}. */ http://git-wip-us.apache.org/repos/asf/beam/blob/0e5737fd/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java index efe8db4..0a5746b 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java @@ -393,17 +393,21 @@ public class PipelineTest { } @Override - public PBegin getInput(List<TaggedPValue> inputs, Pipeline p) { + public PBegin getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) { return p.begin(); } @Override public Map<PValue, ReplacementOutput> mapOutputs( - List<TaggedPValue> outputs, PCollection<Long> newOutput) { + Map<TupleTag<?>, PValue> outputs, PCollection<Long> newOutput) { + Map.Entry<TupleTag<?>, PValue> original = Iterables.getOnlyElement(outputs.entrySet()); + Map.Entry<TupleTag<?>, PValue> replacement = + Iterables.getOnlyElement(newOutput.expand().entrySet()); return Collections.<PValue, ReplacementOutput>singletonMap( newOutput, ReplacementOutput.of( - Iterables.getOnlyElement(outputs), Iterables.getOnlyElement(newOutput.expand()))); + TaggedPValue.of(original.getKey(), original.getValue()), + TaggedPValue.of(replacement.getKey(), replacement.getValue()))); } } static class UnboundedCountingInputOverride @@ -415,17 +419,21 @@ public class PipelineTest { } @Override - public PBegin getInput(List<TaggedPValue> inputs, Pipeline p) { + public PBegin getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) { return p.begin(); } @Override public Map<PValue, ReplacementOutput> mapOutputs( - List<TaggedPValue> outputs, PCollection<Long> newOutput) { + Map<TupleTag<?>, PValue> outputs, PCollection<Long> newOutput) { + Map.Entry<TupleTag<?>, PValue> original = Iterables.getOnlyElement(outputs.entrySet()); + Map.Entry<TupleTag<?>, PValue> replacement = + Iterables.getOnlyElement(newOutput.expand().entrySet()); return Collections.<PValue, ReplacementOutput>singletonMap( newOutput, ReplacementOutput.of( - Iterables.getOnlyElement(outputs), Iterables.getOnlyElement(newOutput.expand()))); + TaggedPValue.of(original.getKey(), original.getValue()), + TaggedPValue.of(replacement.getKey(), replacement.getValue()))); } } }