Provide local tags in PInput, POutput expansions Output an ordered collection in PInput and POutput expansions.
This provides information that is necessary to reconstruct a PInput or POutput from its expansion. Implement PCollectionList.equals, PCollectionTuple.equals Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/34373c21 Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/34373c21 Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/34373c21 Branch: refs/heads/python-sdk Commit: 34373c21ed67696235d88ef40d50e31c77b84c33 Parents: 6a05d7f Author: Thomas Groh <tg...@google.com> Authored: Tue Dec 6 11:03:52 2016 -0800 Committer: Thomas Groh <tg...@google.com> Committed: Tue Dec 20 15:18:55 2016 -0800 ---------------------------------------------------------------------- .../beam/runners/direct/DirectGraphVisitor.java | 18 +-- .../beam/runners/direct/EvaluationContext.java | 7 +- .../direct/KeyedPValueTrackingVisitor.java | 16 ++- .../beam/runners/direct/WatermarkManager.java | 19 +-- .../apache/beam/runners/spark/SparkRunner.java | 13 ++- .../beam/sdk/runners/TransformHierarchy.java | 49 ++++---- .../transforms/join/KeyedPCollectionTuple.java | 9 +- .../java/org/apache/beam/sdk/values/PBegin.java | 4 +- .../apache/beam/sdk/values/PCollectionList.java | 65 +++++++---- .../beam/sdk/values/PCollectionTuple.java | 28 ++++- .../java/org/apache/beam/sdk/values/PDone.java | 4 +- .../java/org/apache/beam/sdk/values/PInput.java | 4 +- .../org/apache/beam/sdk/values/POutput.java | 4 +- .../java/org/apache/beam/sdk/values/PValue.java | 10 ++ .../org/apache/beam/sdk/values/PValueBase.java | 11 +- .../apache/beam/sdk/values/TaggedPValue.java | 42 +++++++ .../sdk/runners/TransformHierarchyTest.java | 23 +++- .../apache/beam/sdk/transforms/ParDoTest.java | 34 ++++++ .../beam/sdk/values/PCollectionListTest.java | 117 +++++++++++++++++++ .../beam/sdk/values/PCollectionTupleTest.java | 70 +++++++++++ 20 files changed, 449 insertions(+), 98 deletions(-) 
---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/34373c21/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java index 0283d03..425bbf1 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java @@ -35,6 +35,7 @@ import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; import org.apache.beam.sdk.values.PValue; +import org.apache.beam.sdk.values.TaggedPValue; /** * Tracks the {@link AppliedPTransform AppliedPTransforms} that consume each {@link PValue} in the @@ -79,14 +80,16 @@ class DirectGraphVisitor extends PipelineVisitor.Defaults { @Override public void visitPrimitiveTransform(TransformHierarchy.Node node) { - toFinalize.removeAll(node.getInputs()); + for (TaggedPValue consumed : node.getInputs()) { + toFinalize.remove(consumed.getValue()); + } AppliedPTransform<?, ?, ?> appliedTransform = getAppliedTransform(node); stepNames.put(appliedTransform, genStepName()); if (node.getInputs().isEmpty()) { rootTransforms.add(appliedTransform); } else { - for (PValue value : node.getInputs()) { - primitiveConsumers.put(value, appliedTransform); + for (TaggedPValue value : node.getInputs()) { + primitiveConsumers.put(value.getValue(), appliedTransform); } } } @@ -96,15 +99,12 @@ class DirectGraphVisitor extends PipelineVisitor.Defaults { toFinalize.add(value); AppliedPTransform<?, ?, ?> appliedTransform = getAppliedTransform(producer); + if (value instanceof PCollectionView) { + 
views.add((PCollectionView<?>) value); + } if (!producers.containsKey(value)) { producers.put(value, appliedTransform); } - if (value instanceof PCollectionView) { - views.add((PCollectionView<?>) value); - } - if (!producers.containsKey(value)) { - producers.put(value, appliedTransform); - } } private AppliedPTransform<?, ?, ?> getAppliedTransform(TransformHierarchy.Node node) { http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/34373c21/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java index cb9ddd8..bbcab8e 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java @@ -53,6 +53,7 @@ import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollection.IsBounded; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PValue; +import org.apache.beam.sdk.values.TaggedPValue; import org.joda.time.Instant; /** @@ -419,9 +420,9 @@ class EvaluationContext { } // If the PTransform has any unbounded outputs, and unbounded producers should not be shut down, // the PTransform may produce additional output. It is not done. 
- for (PValue output : transform.getOutput().expand()) { - if (output instanceof PCollection) { - IsBounded bounded = ((PCollection<?>) output).isBounded(); + for (TaggedPValue output : transform.getOutput().expand()) { + if (output.getValue() instanceof PCollection) { + IsBounded bounded = ((PCollection<?>) output.getValue()).isBounded(); if (bounded.equals(IsBounded.UNBOUNDED) && !options.isShutdownUnboundedProducersWithMaxWatermark()) { return false; http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/34373c21/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java index 65c41e0..32eb692 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java @@ -18,11 +18,10 @@ package org.apache.beam.runners.direct; import static com.google.common.base.Preconditions.checkState; -import static com.google.common.base.Predicates.in; -import static com.google.common.collect.Iterables.all; import com.google.common.collect.ImmutableSet; import java.util.HashSet; +import java.util.List; import java.util.Set; import org.apache.beam.runners.core.SplittableParDo; import org.apache.beam.runners.direct.DirectGroupByKey.DirectGroupAlsoByWindow; @@ -33,6 +32,7 @@ import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.values.PValue; +import org.apache.beam.sdk.values.TaggedPValue; /** * A pipeline visitor that tracks all keyed {@link PValue PValues}. 
A {@link PValue} is keyed if it @@ -83,7 +83,10 @@ class KeyedPValueTrackingVisitor implements PipelineVisitor { if (node.isRootNode()) { finalized = true; } else if (PRODUCES_KEYED_OUTPUTS.contains(node.getTransform().getClass())) { - keyedValues.addAll(node.getOutputs()); + List<TaggedPValue> outputs = node.getOutputs(); + for (TaggedPValue output : outputs) { + keyedValues.add(output.getValue()); + } } } @@ -92,9 +95,12 @@ class KeyedPValueTrackingVisitor implements PipelineVisitor { @Override public void visitValue(PValue value, TransformHierarchy.Node producer) { + boolean inputsAreKeyed = true; + for (TaggedPValue input : producer.getInputs()) { + inputsAreKeyed = inputsAreKeyed && keyedValues.contains(input.getValue()); + } if (PRODUCES_KEYED_OUTPUTS.contains(producer.getTransform().getClass()) - || (isKeyPreserving(producer.getTransform()) - && all(producer.getInputs(), in(keyedValues)))) { + || (isKeyPreserving(producer.getTransform()) && inputsAreKeyed)) { keyedValues.add(value); } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/34373c21/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java index 247b1cc..7bed751 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java @@ -57,7 +57,7 @@ import org.apache.beam.sdk.util.TimeDomain; import org.apache.beam.sdk.util.TimerInternals; import org.apache.beam.sdk.util.TimerInternals.TimerData; import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PValue; +import org.apache.beam.sdk.values.TaggedPValue; import org.joda.time.Instant; /** @@ -755,13 +755,14 
@@ public class WatermarkManager { private Collection<Watermark> getInputProcessingWatermarks(AppliedPTransform<?, ?, ?> transform) { ImmutableList.Builder<Watermark> inputWmsBuilder = ImmutableList.builder(); - Collection<? extends PValue> inputs = transform.getInput().expand(); + List<TaggedPValue> inputs = transform.getInput().expand(); if (inputs.isEmpty()) { inputWmsBuilder.add(THE_END_OF_TIME); } - for (PValue pvalue : inputs) { + for (TaggedPValue pvalue : inputs) { Watermark producerOutputWatermark = - getTransformWatermark(graph.getProducer(pvalue)).synchronizedProcessingOutputWatermark; + getTransformWatermark(graph.getProducer(pvalue.getValue())) + .synchronizedProcessingOutputWatermark; inputWmsBuilder.add(producerOutputWatermark); } return inputWmsBuilder.build(); @@ -769,13 +770,13 @@ public class WatermarkManager { private List<Watermark> getInputWatermarks(AppliedPTransform<?, ?, ?> transform) { ImmutableList.Builder<Watermark> inputWatermarksBuilder = ImmutableList.builder(); - Collection<? 
extends PValue> inputs = transform.getInput().expand(); + List<TaggedPValue> inputs = transform.getInput().expand(); if (inputs.isEmpty()) { inputWatermarksBuilder.add(THE_END_OF_TIME); } - for (PValue pvalue : inputs) { + for (TaggedPValue pvalue : inputs) { Watermark producerOutputWatermark = - getTransformWatermark(graph.getProducer(pvalue)).outputWatermark; + getTransformWatermark(graph.getProducer(pvalue.getValue())).outputWatermark; inputWatermarksBuilder.add(producerOutputWatermark); } List<Watermark> inputCollectionWatermarks = inputWatermarksBuilder.build(); @@ -959,8 +960,8 @@ public class WatermarkManager { WatermarkUpdate updateResult = myWatermarks.refresh(); if (updateResult.isAdvanced()) { Set<AppliedPTransform<?, ?, ?>> additionalRefreshes = new HashSet<>(); - for (PValue outputPValue : toRefresh.getOutput().expand()) { - additionalRefreshes.addAll(graph.getPrimitiveConsumers(outputPValue)); + for (TaggedPValue outputPValue : toRefresh.getOutput().expand()) { + additionalRefreshes.addAll(graph.getPrimitiveConsumers(outputPValue.getValue())); } return additionalRefreshes; } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/34373c21/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java index 3d98b87..92c07bb 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkRunner.java @@ -48,6 +48,7 @@ import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; import org.apache.beam.sdk.values.PValue; +import org.apache.beam.sdk.values.TaggedPValue; import org.apache.spark.Accumulator; import org.apache.spark.SparkEnv$; 
import org.apache.spark.api.java.JavaSparkContext; @@ -282,7 +283,7 @@ public final class SparkRunner extends PipelineRunner<SparkPipelineResult> { if (node.getInputs().size() != 1) { return false; } - PValue input = Iterables.getOnlyElement(node.getInputs()); + PValue input = Iterables.getOnlyElement(node.getInputs()).getValue(); if (!(input instanceof PCollection) || ((PCollection) input).getWindowingStrategy().getWindowFn().isNonMerging()) { return false; @@ -338,7 +339,7 @@ public final class SparkRunner extends PipelineRunner<SparkPipelineResult> { //--- determine if node is bounded/unbounded. // usually, the input determines if the PCollection to apply the next transformation to // is BOUNDED or UNBOUNDED, meaning RDD/DStream. - Collection<? extends PValue> pValues; + Collection<TaggedPValue> pValues; if (node.getInputs().isEmpty()) { // in case of a PBegin, it's the output. pValues = node.getOutputs(); @@ -353,15 +354,15 @@ public final class SparkRunner extends PipelineRunner<SparkPipelineResult> { : translator.translateUnbounded(transformClass); } - private PCollection.IsBounded isBoundedCollection(Collection<? extends PValue> pValues) { + private PCollection.IsBounded isBoundedCollection(Collection<TaggedPValue> pValues) { // anything that is not a PCollection, is BOUNDED. // For PCollections: // BOUNDED behaves as the Identity Element, BOUNDED + BOUNDED = BOUNDED // while BOUNDED + UNBOUNDED = UNBOUNDED. 
PCollection.IsBounded isBounded = PCollection.IsBounded.BOUNDED; - for (PValue pValue: pValues) { - if (pValue instanceof PCollection) { - isBounded = isBounded.and(((PCollection) pValue).isBounded()); + for (TaggedPValue pValue: pValues) { + if (pValue.getValue() instanceof PCollection) { + isBounded = isBounded.and(((PCollection) pValue.getValue()).isBounded()); } else { isBounded = isBounded.and(PCollection.IsBounded.BOUNDED); } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/34373c21/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java index 33d5231..29e7fcb 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/TransformHierarchy.java @@ -37,6 +37,7 @@ import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; import org.apache.beam.sdk.values.PValue; +import org.apache.beam.sdk.values.TaggedPValue; /** * Captures information about a collection of transformations and their @@ -84,10 +85,12 @@ public class TransformHierarchy { */ public void finishSpecifyingInput() { // Inputs must be completely specified before they are consumed by a transform. 
- for (PValue inputValue : current.getInputs()) { - inputValue.finishSpecifying(); - checkState(producers.get(inputValue) != null, "Producer unknown for input %s", inputValue); - inputValue.finishSpecifying(); + for (TaggedPValue inputValue : current.getInputs()) { + inputValue.getValue().finishSpecifying(); + checkState( + producers.get(inputValue.getValue()) != null, + "Producer unknown for input %s", + inputValue.getValue()); } } @@ -103,9 +106,9 @@ public class TransformHierarchy { */ public void setOutput(POutput output) { output.finishSpecifyingOutput(); - for (PValue value : output.expand()) { - if (!producers.containsKey(value)) { - producers.put(value, current); + for (TaggedPValue value : output.expand()) { + if (!producers.containsKey(value.getValue())) { + producers.put(value.getValue(), current); } } current.setOutput(output); @@ -133,8 +136,8 @@ public class TransformHierarchy { */ List<Node> getProducingTransforms(POutput output) { List<Node> producingTransforms = new ArrayList<>(); - for (PValue value : output.expand()) { - Node producer = getProducer(value); + for (TaggedPValue value : output.expand()) { + Node producer = getProducer(value.getValue()); if (producer != null) { producingTransforms.add(producer); } @@ -238,8 +241,8 @@ public class TransformHierarchy { private boolean returnsOthersOutput() { PTransform<?, ?> transform = getTransform(); if (output != null) { - for (PValue outputValue : output.expand()) { - if (!getProducer(outputValue).getTransform().equals(transform)) { + for (TaggedPValue outputValue : output.expand()) { + if (!getProducer(outputValue.getValue()).getTransform().equals(transform)) { return true; } } @@ -256,8 +259,8 @@ public class TransformHierarchy { } /** Returns the transform input, in unexpanded form. */ - public Collection<? extends PValue> getInputs() { - return input == null ? Collections.<PValue>emptyList() : input.expand(); + public List<TaggedPValue> getInputs() { + return input == null ? 
Collections.<TaggedPValue>emptyList() : input.expand(); } /** @@ -273,8 +276,8 @@ public class TransformHierarchy { // Validate that a primitive transform produces only primitive output, and a composite // transform does not produce primitive output. Set<Node> outputProducers = new HashSet<>(); - for (PValue outputValue : output.expand()) { - outputProducers.add(getProducer(outputValue)); + for (TaggedPValue outputValue : output.expand()) { + outputProducers.add(getProducer(outputValue.getValue())); } if (outputProducers.contains(this) && outputProducers.size() != 1) { Set<String> otherProducerNames = new HashSet<>(); @@ -296,8 +299,8 @@ public class TransformHierarchy { } /** Returns the transform output, in unexpanded form. */ - public Collection<? extends PValue> getOutputs() { - return output == null ? Collections.<PValue>emptyList() : output.expand(); + public List<TaggedPValue> getOutputs() { + return output == null ? Collections.<TaggedPValue>emptyList() : output.expand(); } /** @@ -320,9 +323,9 @@ public class TransformHierarchy { if (!isRootNode()) { // Visit inputs. - for (PValue inputValue : input.expand()) { - if (visitedValues.add(inputValue)) { - visitor.visitValue(inputValue, getProducer(inputValue)); + for (TaggedPValue inputValue : input.expand()) { + if (visitedValues.add(inputValue.getValue())) { + visitor.visitValue(inputValue.getValue(), getProducer(inputValue.getValue())); } } } @@ -342,9 +345,9 @@ public class TransformHierarchy { if (!isRootNode()) { // Visit outputs. 
- for (PValue pValue : output.expand()) { - if (visitedValues.add(pValue)) { - visitor.visitValue(pValue, this); + for (TaggedPValue pValue : output.expand()) { + if (visitedValues.add(pValue.getValue())) { + visitor.visitValue(pValue.getValue(), this); } } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/34373c21/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/KeyedPCollectionTuple.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/KeyedPCollectionTuple.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/KeyedPCollectionTuple.java index 67b819f..13d4ee1 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/KeyedPCollectionTuple.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/join/KeyedPCollectionTuple.java @@ -18,7 +18,6 @@ package org.apache.beam.sdk.transforms.join; import java.util.ArrayList; -import java.util.Collection; import java.util.List; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.Coder; @@ -28,7 +27,7 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; -import org.apache.beam.sdk.values.PValue; +import org.apache.beam.sdk.values.TaggedPValue; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; @@ -121,10 +120,10 @@ public class KeyedPCollectionTuple<K> implements PInput { * any tag-specific information. */ @Override - public Collection<? 
extends PValue> expand() { - List<PCollection<?>> retval = new ArrayList<>(); + public List<TaggedPValue> expand() { + List<TaggedPValue> retval = new ArrayList<>(); for (TaggedKeyedPCollection<K, ?> taggedPCollection : keyedCollections) { - retval.add(taggedPCollection.pCollection); + retval.add(TaggedPValue.of(taggedPCollection.tupleTag, taggedPCollection.pCollection)); } return retval; } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/34373c21/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PBegin.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PBegin.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PBegin.java index f1dbb37..9aa4615 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PBegin.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PBegin.java @@ -17,8 +17,8 @@ */ package org.apache.beam.sdk.values; -import java.util.Collection; import java.util.Collections; +import java.util.List; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.io.TextIO.Read; import org.apache.beam.sdk.transforms.Create; @@ -64,7 +64,7 @@ public class PBegin implements PInput { } @Override - public Collection<? extends PValue> expand() { + public List<TaggedPValue> expand() { // A PBegin contains no PValues. 
return Collections.emptyList(); } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/34373c21/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionList.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionList.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionList.java index 4c9e220..e4bb7c5 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionList.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionList.java @@ -18,11 +18,10 @@ package org.apache.beam.sdk.values; import com.google.common.collect.ImmutableList; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; +import com.google.common.collect.Iterables; import java.util.Iterator; import java.util.List; +import java.util.Objects; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.Flatten; @@ -115,9 +114,9 @@ public class PCollectionList<T> implements PInput, POutput { "PCollections come from different Pipelines"); } return new PCollectionList<>(pipeline, - new ImmutableList.Builder<PCollection<T>>() + ImmutableList.<TaggedPValue>builder() .addAll(pcollections) - .add(pc) + .add(Iterables.getOnlyElement(pc.expand())) .build()); } @@ -130,15 +129,16 @@ public class PCollectionList<T> implements PInput, POutput { * part of the same {@link Pipeline}. 
*/ public PCollectionList<T> and(Iterable<PCollection<T>> pcs) { - List<PCollection<T>> copy = new ArrayList<>(pcollections); + ImmutableList.Builder<TaggedPValue> builder = ImmutableList.builder(); + builder.addAll(pcollections); for (PCollection<T> pc : pcs) { if (pc.getPipeline() != pipeline) { throw new IllegalArgumentException( "PCollections come from different Pipelines"); } - copy.add(pc); + builder.add(Iterables.getOnlyElement(pc.expand())); } - return new PCollectionList<>(pipeline, copy); + return new PCollectionList<>(pipeline, builder.build()); } /** @@ -155,7 +155,9 @@ public class PCollectionList<T> implements PInput, POutput { * {@code [0..size()-1]}. */ public PCollection<T> get(int index) { - return pcollections.get(index); + @SuppressWarnings("unchecked") // Type-safe by construction + PCollection<T> value = (PCollection<T>) pcollections.get(index).getValue(); + return value; } /** @@ -163,7 +165,13 @@ public class PCollectionList<T> implements PInput, POutput { * {@link PCollectionList}. */ public List<PCollection<T>> getAll() { - return pcollections; + ImmutableList.Builder<PCollection<T>> res = ImmutableList.builder(); + for (TaggedPValue value : pcollections) { + @SuppressWarnings("unchecked") // Type-safe by construction + PCollection<T> typedValue = (PCollection<T>) value.getValue(); + res.add(typedValue); + } + return res.build(); } /** @@ -192,15 +200,16 @@ public class PCollectionList<T> implements PInput, POutput { // Internal details below here. final Pipeline pipeline; - final List<PCollection<T>> pcollections; + // ImmutableMap has a defined iteration order. 
+ final List<TaggedPValue> pcollections; PCollectionList(Pipeline pipeline) { - this(pipeline, new ArrayList<PCollection<T>>()); + this(pipeline, ImmutableList.<TaggedPValue>of()); } - PCollectionList(Pipeline pipeline, List<PCollection<T>> pcollections) { + PCollectionList(Pipeline pipeline, List<TaggedPValue> values) { this.pipeline = pipeline; - this.pcollections = Collections.unmodifiableList(pcollections); + this.pcollections = ImmutableList.copyOf(values); } @Override @@ -209,14 +218,16 @@ public class PCollectionList<T> implements PInput, POutput { } @Override - public Collection<? extends PValue> expand() { + public List<TaggedPValue> expand() { return pcollections; } @Override public void recordAsOutput(AppliedPTransform<?, ?, ?> transform) { int i = 0; - for (PCollection<T> pc : pcollections) { + for (TaggedPValue tpv : pcollections) { + @SuppressWarnings("unchecked") + PCollection<T> pc = (PCollection<T>) tpv.getValue(); pc.recordAsOutput(transform, "out" + i); i++; } @@ -224,15 +235,29 @@ public class PCollectionList<T> implements PInput, POutput { @Override public void finishSpecifying() { - for (PCollection<T> pc : pcollections) { - pc.finishSpecifying(); + for (TaggedPValue pc : pcollections) { + pc.getValue().finishSpecifying(); } } @Override public void finishSpecifyingOutput() { - for (PCollection<T> pc : pcollections) { - pc.finishSpecifyingOutput(); + for (TaggedPValue pc : pcollections) { + pc.getValue().finishSpecifyingOutput(); } } + + @Override + public boolean equals(Object other) { + if (!(other instanceof PCollectionList)) { + return false; + } + PCollectionList that = (PCollectionList) other; + return this.pipeline.equals(that.pipeline) && this.pcollections.equals(that.pcollections); + } + + @Override + public int hashCode() { + return Objects.hash(this.pipeline, this.pcollections); + } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/34373c21/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionTuple.java 
---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionTuple.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionTuple.java index 727d882..6afe59e 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionTuple.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionTuple.java @@ -17,11 +17,13 @@ */ package org.apache.beam.sdk.values; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import java.util.Collection; import java.util.Collections; import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; +import java.util.Objects; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.PTransform; @@ -178,7 +180,7 @@ public class PCollectionTuple implements PInput, POutput { ///////////////////////////////////////////////////////////////////////////// // Internal details below here. - Pipeline pipeline; + final Pipeline pipeline; final Map<TupleTag<?>, PCollection<?>> pcollectionMap; PCollectionTuple(Pipeline pipeline) { @@ -232,8 +234,12 @@ public class PCollectionTuple implements PInput, POutput { } @Override - public Collection<? 
extends PValue> expand() { - return pcollectionMap.values(); + public List<TaggedPValue> expand() { + ImmutableList.Builder<TaggedPValue> values = ImmutableList.builder(); + for (Map.Entry<TupleTag<?>, PCollection<?>> entry : pcollectionMap.entrySet()) { + values.add(TaggedPValue.of(entry.getKey(), entry.getValue())); + } + return values.build(); } @Override @@ -261,4 +267,18 @@ public class PCollectionTuple implements PInput, POutput { pc.finishSpecifyingOutput(); } } + + @Override + public boolean equals(Object other) { + if (!(other instanceof PCollectionTuple)) { + return false; + } + PCollectionTuple that = (PCollectionTuple) other; + return this.pipeline.equals(that.pipeline) && this.pcollectionMap.equals(that.pcollectionMap); + } + + @Override + public int hashCode() { + return Objects.hash(this.pipeline, this.pcollectionMap); + } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/34373c21/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PDone.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PDone.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PDone.java index 9e8cae4..b4a3025 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PDone.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PDone.java @@ -17,8 +17,8 @@ */ package org.apache.beam.sdk.values; -import java.util.Collection; import java.util.Collections; +import java.util.List; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.transforms.PTransform; @@ -36,7 +36,7 @@ public class PDone extends POutputValueBase { } @Override - public Collection<? extends PValue> expand() { + public List<TaggedPValue> expand() { // A PDone contains no PValues. 
return Collections.emptyList(); } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/34373c21/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PInput.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PInput.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PInput.java index f938aeb..a27b939 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PInput.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PInput.java @@ -17,7 +17,7 @@ */ package org.apache.beam.sdk.values; -import java.util.Collection; +import java.util.List; import org.apache.beam.sdk.Pipeline; /** @@ -43,7 +43,7 @@ public interface PInput { * * <p>Not intended to be invoked directly by user code. */ - Collection<? extends PValue> expand(); + List<TaggedPValue> expand(); /** * After building, finalizes this {@code PInput} to make it ready for http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/34373c21/sdks/java/core/src/main/java/org/apache/beam/sdk/values/POutput.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/POutput.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/POutput.java index 27a280f..e5d4504 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/POutput.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/POutput.java @@ -17,7 +17,7 @@ */ package org.apache.beam.sdk.values; -import java.util.Collection; +import java.util.List; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.PTransform; @@ -45,7 +45,7 @@ public interface POutput { * * <p>Not intended to be invoked directly by user code. */ - Collection<? 
extends PValue> expand(); + List<TaggedPValue> expand(); /** * Records that this {@code POutput} is an output of the given http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/34373c21/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PValue.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PValue.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PValue.java index 0cee2ca..e6dbaf7 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PValue.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PValue.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.values; +import java.util.List; import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.PTransform; @@ -36,4 +37,13 @@ public interface PValue extends POutput, PInput { * <p>For internal use only. */ AppliedPTransform<?, ?, ?> getProducingTransformInternal(); + + /** + * {@inheritDoc}. + * + * <p>A {@link PValue} always expands into itself. Calling {@link #expand()} on a PValue is almost + * never appropriate. 
+ */ + @Deprecated + List<TaggedPValue> expand(); } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/34373c21/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PValueBase.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PValueBase.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PValueBase.java index 685e32f..3a10d5d 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PValueBase.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PValueBase.java @@ -17,8 +17,8 @@ */ package org.apache.beam.sdk.values; -import java.util.Collection; import java.util.Collections; +import java.util.List; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.PTransform; @@ -91,6 +91,11 @@ public abstract class PValueBase extends POutputValueBase implements PValue { private String name; /** + * A local {@link TupleTag} used in the expansion of this {@link PValueBase}. + */ + private TupleTag<?> tag = new TupleTag<>(); + + /** * Whether this {@link PValueBase} has been finalized, and its core * properties, e.g., name, can no longer be changed. */ @@ -128,8 +133,8 @@ public abstract class PValueBase extends POutputValueBase implements PValue { } @Override - public Collection<? 
extends PValue> expand() { - return Collections.singletonList(this); + public final List<TaggedPValue> expand() { + return Collections.singletonList(TaggedPValue.of(tag, this)); } @Override http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/34373c21/sdks/java/core/src/main/java/org/apache/beam/sdk/values/TaggedPValue.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/TaggedPValue.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/TaggedPValue.java new file mode 100644 index 0000000..458d16f --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/TaggedPValue.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package org.apache.beam.sdk.values; + +import com.google.auto.value.AutoValue; + +/** + * A (TupleTag, PValue) pair used in the expansion of a {@link PInput} or {@link POutput}. + */ +@AutoValue +public abstract class TaggedPValue { + public static TaggedPValue of(TupleTag<?> tag, PValue value) { + return new AutoValue_TaggedPValue(tag, value); + } + + /** + * Returns the local tag associated with the {@link PValue}. 
+ */ + public abstract TupleTag<?> getTag(); + + /** + * Returns the {@link PValue}. + */ + public abstract PValue getValue(); +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/34373c21/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformHierarchyTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformHierarchyTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformHierarchyTest.java index 2327459..d790d39 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformHierarchyTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/runners/TransformHierarchyTest.java @@ -22,7 +22,10 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.junit.Assert.assertThat; +import com.google.common.base.Function; +import com.google.common.collect.Lists; import java.util.HashSet; +import java.util.List; import java.util.Set; import org.apache.beam.sdk.Pipeline.PipelineVisitor; import org.apache.beam.sdk.io.CountingSource; @@ -38,6 +41,7 @@ import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollection.IsBounded; import org.apache.beam.sdk.values.PCollectionList; import org.apache.beam.sdk.values.PValue; +import org.apache.beam.sdk.values.TaggedPValue; import org.hamcrest.Matchers; import org.junit.Before; import org.junit.Rule; @@ -181,14 +185,16 @@ public class TransformHierarchyTest { assertThat(hierarchy.getCurrent(), equalTo(primitiveNode)); hierarchy.setOutput(created); hierarchy.popNode(); - assertThat(primitiveNode.getOutputs(), Matchers.<PValue>containsInAnyOrder(created)); - assertThat(primitiveNode.getInputs(), Matchers.<PValue>emptyIterable()); + assertThat( + fromTaggedValues(primitiveNode.getOutputs()), Matchers.<PValue>containsInAnyOrder(created)); + assertThat(primitiveNode.getInputs(), 
Matchers.<TaggedPValue>emptyIterable()); assertThat(primitiveNode.getTransform(), Matchers.<PTransform<?, ?>>equalTo(read)); assertThat(primitiveNode.getEnclosingNode(), equalTo(compositeNode)); hierarchy.setOutput(created); // The composite is listed as outputting a PValue created by the contained primitive - assertThat(compositeNode.getOutputs(), Matchers.<PValue>containsInAnyOrder(created)); + assertThat( + fromTaggedValues(compositeNode.getOutputs()), Matchers.<PValue>containsInAnyOrder(created)); // The producer of that PValue is still the primitive in which it is first output assertThat(hierarchy.getProducer(created), equalTo(primitiveNode)); hierarchy.popNode(); @@ -226,4 +232,15 @@ public class TransformHierarchyTest { assertThat(visitedValuesInVisitor, Matchers.<PValue>containsInAnyOrder(created, mapped)); assertThat(visitedValuesInVisitor, equalTo(visitedValues)); } + + private static List<PValue> fromTaggedValues(List<TaggedPValue> taggedValues) { + return Lists.transform( + taggedValues, + new Function<TaggedPValue, PValue>() { + @Override + public PValue apply(TaggedPValue input) { + return input.getValue(); + } + }); + } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/34373c21/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java index 3a47fc7..fa8874c 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java @@ -29,6 +29,7 @@ import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; import static 
org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder; import static org.hamcrest.collection.IsIterableContainingInOrder.contains; import static org.junit.Assert.assertArrayEquals; @@ -50,6 +51,7 @@ import java.util.List; import org.apache.beam.sdk.coders.AtomicCoder; import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.io.CountingInput; import org.apache.beam.sdk.testing.NeedsRunner; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.RunnableOnService; @@ -86,6 +88,7 @@ import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TimestampedValue; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; +import org.hamcrest.Matchers; import org.joda.time.Duration; import org.joda.time.Instant; import org.joda.time.MutableDateTime; @@ -864,6 +867,37 @@ public class ParDoTest implements Serializable { } @Test + public void testMultiOutputAppliedMultipleTimesDifferentOutputs() { + pipeline.enableAbandonedNodeEnforcement(false); + PCollection<Long> longs = pipeline.apply(CountingInput.unbounded()); + + TupleTag<Long> mainOut = new TupleTag<>(); + final TupleTag<String> sideOutOne = new TupleTag<>(); + final TupleTag<Integer> sideOutTwo = new TupleTag<>(); + DoFn<Long, Long> fn = + new DoFn<Long, Long>() { + @ProcessElement + public void processElement(ProcessContext cxt) { + cxt.output(cxt.element()); + cxt.sideOutput(sideOutOne, Long.toString(cxt.element())); + cxt.sideOutput(sideOutTwo, Long.valueOf(cxt.element()).intValue()); + } + }; + + ParDo.BoundMulti<Long, Long> parDo = + ParDo.of(fn).withOutputTags(mainOut, TupleTagList.of(sideOutOne).and(sideOutTwo)); + PCollectionTuple firstApplication = longs.apply("first", parDo); + PCollectionTuple secondApplication = longs.apply("second", parDo); + assertThat(firstApplication, not(equalTo(secondApplication))); + assertThat( + 
firstApplication.getAll().keySet(), + Matchers.<TupleTag<?>>containsInAnyOrder(mainOut, sideOutOne, sideOutTwo)); + assertThat( + secondApplication.getAll().keySet(), + Matchers.<TupleTag<?>>containsInAnyOrder(mainOut, sideOutOne, sideOutTwo)); + } + + @Test @Category(RunnableOnService.class) public void testParDoInCustomTransform() { http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/34373c21/sdks/java/core/src/test/java/org/apache/beam/sdk/values/PCollectionListTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/values/PCollectionListTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/values/PCollectionListTest.java index f76bf7e..2482f32 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/values/PCollectionListTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/values/PCollectionListTest.java @@ -18,10 +18,22 @@ package org.apache.beam.sdk.values; import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; import static org.junit.Assert.assertThat; import static org.junit.Assert.fail; +import com.google.common.collect.ImmutableList; +import com.google.common.testing.EqualsTester; import java.util.Collections; +import java.util.List; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.io.CountingInput; +import org.apache.beam.sdk.io.CountingInput.BoundedCountingInput; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.hamcrest.Matchers; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -44,4 +56,109 @@ public class PCollectionListTest { + "or must first call empty(Pipeline)")); } } + + @Test + public void testIterationOrder() { + Pipeline p = TestPipeline.create(); + PCollection<Long> 
createOne = p.apply("CreateOne", Create.of(1L, 2L, 3L)); + PCollection<Long> boundedCount = p.apply("CountBounded", CountingInput.upTo(23L)); + PCollection<Long> unboundedCount = p.apply("CountUnbounded", CountingInput.unbounded()); + PCollection<Long> createTwo = p.apply("CreateTwo", Create.of(-1L, -2L)); + PCollection<Long> maxRecordsCount = + p.apply("CountLimited", CountingInput.unbounded().withMaxNumRecords(22L)); + + ImmutableList<PCollection<Long>> counts = + ImmutableList.of(boundedCount, maxRecordsCount, unboundedCount); + // Build a PCollectionList from a list. This should have the same order as the input list. + PCollectionList<Long> pcList = PCollectionList.of(counts); + // Contains is the order-dependent matcher + assertThat( + pcList.getAll(), + contains(boundedCount, maxRecordsCount, unboundedCount)); + + // A list that is expanded with builder methods has the added value at the end + PCollectionList<Long> withOneCreate = pcList.and(createTwo); + assertThat( + withOneCreate.getAll(), contains(boundedCount, maxRecordsCount, unboundedCount, createTwo)); + + // Lists that are built entirely from the builder return outputs in the order they were added + PCollectionList<Long> fromEmpty = + PCollectionList.<Long>empty(p) + .and(unboundedCount) + .and(createOne) + .and(ImmutableList.of(boundedCount, maxRecordsCount)); + assertThat( + fromEmpty.getAll(), contains(unboundedCount, createOne, boundedCount, maxRecordsCount)); + + List<TaggedPValue> expansion = fromEmpty.expand(); + // TaggedPValues are stable between expansions + assertThat(expansion, equalTo(fromEmpty.expand())); + // TaggedPValues are equivalent between equivalent lists + assertThat( + expansion, + equalTo( + PCollectionList.of(unboundedCount) + .and(createOne) + .and(boundedCount) + .and(maxRecordsCount) + .expand())); + + List<PCollection<Long>> expectedList = + ImmutableList.of(unboundedCount, createOne, boundedCount, maxRecordsCount); + for (int i = 0; i < expansion.size(); i++) { + 
assertThat( + "Index " + i + " should have equal PValue", + expansion.get(i).getValue(), + Matchers.<PValue>equalTo(expectedList.get(i))); + } + } + + @Test + public void testEquals() { + Pipeline p = TestPipeline.create(); + PCollection<String> first = p.apply("Meta", Create.of("foo", "bar")); + PCollection<String> second = p.apply("Pythonic", Create.of("spam, ham")); + PCollection<String> third = p.apply("Syntactic", Create.of("eggs", "baz")); + + EqualsTester tester = new EqualsTester(); + tester.addEqualityGroup(PCollectionList.empty(p), PCollectionList.empty(p)); + tester.addEqualityGroup(PCollectionList.of(first).and(second)); + // Constructors should all produce equivalent + tester.addEqualityGroup( + PCollectionList.of(first).and(second).and(third), + PCollectionList.of(first).and(second).and(third), + PCollectionList.<String>empty(p).and(first).and(second).and(third), + PCollectionList.of(ImmutableList.of(first, second, third)), + PCollectionList.of(first).and(ImmutableList.of(second, third)), + PCollectionList.of(ImmutableList.of(first, second)).and(third)); + // Order is considered + tester.addEqualityGroup(PCollectionList.of(first).and(third).and(second)); + tester.addEqualityGroup(PCollectionList.empty(TestPipeline.create())); + + tester.testEquals(); + } + + @Test + public void testExpansionOrderWithDuplicates() { + TestPipeline p = TestPipeline.create(); + BoundedCountingInput count = CountingInput.upTo(10L); + PCollection<Long> firstCount = p.apply("CountFirst", count); + PCollection<Long> secondCount = p.apply("CountSecond", count); + + PCollectionList<Long> counts = + PCollectionList.of(firstCount).and(secondCount).and(firstCount).and(firstCount); + + ImmutableList<PCollection<Long>> expectedOrder = + ImmutableList.of(firstCount, secondCount, firstCount, firstCount); + PCollectionList<Long> reconstructed = PCollectionList.empty(p); + assertThat(counts.expand(), hasSize(4)); + for (int i = 0; i < 4; i++) { + PValue value = 
counts.expand().get(i).getValue(); + assertThat( + "Index " + i + " should be equal", value, + Matchers.<PValue>equalTo(expectedOrder.get(i))); + reconstructed = reconstructed.and((PCollection<Long>) value); + } + assertThat(reconstructed, equalTo(counts)); + } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/34373c21/sdks/java/core/src/test/java/org/apache/beam/sdk/values/PCollectionTupleTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/values/PCollectionTupleTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/values/PCollectionTupleTest.java index b5351da..7d767cf 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/values/PCollectionTupleTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/values/PCollectionTupleTest.java @@ -17,21 +17,31 @@ */ package org.apache.beam.sdk.values; +import static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; +import com.google.common.collect.ImmutableMap; +import com.google.common.testing.EqualsTester; import java.io.Serializable; import java.util.Arrays; import java.util.List; +import java.util.Map; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.io.CountingInput; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.RunnableOnService; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.SimpleFunction; import org.apache.beam.sdk.util.WindowingStrategy; import org.apache.beam.sdk.values.PCollection.IsBounded; +import org.hamcrest.Matchers; import org.junit.Rule; import 
org.junit.Test; import org.junit.experimental.categories.Category; @@ -93,4 +103,64 @@ public final class PCollectionTupleTest implements Serializable { pipeline.run(); } + @Test + public void testEquals() { + TestPipeline p = TestPipeline.create(); + TupleTag<Long> longTag = new TupleTag<>(); + PCollection<Long> longs = p.apply(CountingInput.unbounded()); + TupleTag<String> strTag = new TupleTag<>(); + PCollection<String> strs = p.apply(Create.of("foo", "bar")); + + EqualsTester tester = new EqualsTester(); + // Empty tuples in the same pipeline are equal + tester.addEqualityGroup(PCollectionTuple.empty(p), PCollectionTuple.empty(p)); + + tester.addEqualityGroup(PCollectionTuple.of(longTag, longs).and(strTag, strs), + PCollectionTuple.of(longTag, longs).and(strTag, strs)); + + tester.addEqualityGroup(PCollectionTuple.of(longTag, longs)); + tester.addEqualityGroup(PCollectionTuple.of(strTag, strs)); + + TestPipeline otherPipeline = TestPipeline.create(); + // Empty tuples in different pipelines are not equal + tester.addEqualityGroup(PCollectionTuple.empty(otherPipeline)); + tester.testEquals(); + } + + @Test + public void testExpandHasMatchingTags() { + TupleTag<Integer> intTag = new TupleTag<>(); + TupleTag<String> strTag = new TupleTag<>(); + TupleTag<Long> longTag = new TupleTag<>(); + + Pipeline p = TestPipeline.create(); + PCollection<Long> longs = p.apply(CountingInput.upTo(100L)); + PCollection<String> strs = p.apply(Create.of("foo", "bar", "baz")); + PCollection<Integer> ints = longs.apply(MapElements.via(new SimpleFunction<Long, Integer>() { + @Override + public Integer apply(Long input) { + return input.intValue(); + } + })); + + Map<TupleTag<?>, PCollection<?>> pcsByTag = + ImmutableMap.<TupleTag<?>, PCollection<?>>builder() + .put(strTag, strs) + .put(intTag, ints) + .put(longTag, longs) + .build(); + PCollectionTuple tuple = + PCollectionTuple.of(intTag, ints).and(longTag, longs).and(strTag, strs); + assertThat(tuple.getAll(), equalTo(pcsByTag)); + 
PCollectionTuple reconstructed = PCollectionTuple.empty(p); + for (TaggedPValue taggedValue : tuple.expand()) { + TupleTag<?> tag = taggedValue.getTag(); + PValue value = taggedValue.getValue(); + assertThat("The tag should map back to the value", tuple.get(tag), equalTo(value)); + assertThat(value, Matchers.<PValue>equalTo(pcsByTag.get(tag))); + reconstructed = reconstructed.and(tag, (PCollection) value); + } + + assertThat(reconstructed, equalTo(tuple)); + } }