Add some key-preserving infrastructure to KeyedPValueTrackingVisitor
Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/81702e67 Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/81702e67 Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/81702e67 Branch: refs/heads/python-sdk Commit: 81702e67b92a23849cbc8f4a16b2a619e4b477a1 Parents: 22e25a4 Author: Kenneth Knowles <k...@google.com> Authored: Thu Dec 8 11:49:15 2016 -0800 Committer: Kenneth Knowles <k...@google.com> Committed: Tue Dec 20 11:18:02 2016 -0800 ---------------------------------------------------------------------- .../beam/runners/direct/DirectRunner.java | 9 +-- .../direct/KeyedPValueTrackingVisitor.java | 35 +++++--- .../direct/KeyedPValueTrackingVisitorTest.java | 84 +++----------------- 3 files changed, 37 insertions(+), 91 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/81702e67/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java index 78163c0..afa43ff 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java @@ -31,8 +31,6 @@ import java.util.Map; import java.util.Set; import javax.annotation.Nullable; import org.apache.beam.runners.core.SplittableParDo; -import org.apache.beam.runners.direct.DirectGroupByKey.DirectGroupAlsoByWindow; -import org.apache.beam.runners.direct.DirectGroupByKey.DirectGroupByKeyOnly; import org.apache.beam.runners.direct.DirectRunner.DirectPipelineResult; import 
org.apache.beam.runners.direct.TestStreamEvaluatorFactory.DirectTestStreamFactory; import org.apache.beam.runners.direct.ViewEvaluatorFactory.ViewOverrideFactory; @@ -306,12 +304,7 @@ public class DirectRunner extends PipelineRunner<DirectPipelineResult> { graphVisitor.finishSpecifyingRemainder(); @SuppressWarnings("rawtypes") - KeyedPValueTrackingVisitor keyedPValueVisitor = - KeyedPValueTrackingVisitor.create( - ImmutableSet.of( - SplittableParDo.GBKIntoKeyedWorkItems.class, - DirectGroupByKeyOnly.class, - DirectGroupAlsoByWindow.class)); + KeyedPValueTrackingVisitor keyedPValueVisitor = KeyedPValueTrackingVisitor.create(); pipeline.traverseTopologically(keyedPValueVisitor); DisplayDataValidator.validatePipeline(pipeline); http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/81702e67/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java index 7f85169..e91a768 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitor.java @@ -18,9 +18,15 @@ package org.apache.beam.runners.direct; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Predicates.in; +import static com.google.common.collect.Iterables.all; +import com.google.common.collect.ImmutableSet; import java.util.HashSet; import java.util.Set; +import org.apache.beam.runners.core.SplittableParDo; +import org.apache.beam.runners.direct.DirectGroupByKey.DirectGroupAlsoByWindow; +import org.apache.beam.runners.direct.DirectGroupByKey.DirectGroupByKeyOnly; import 
org.apache.beam.sdk.Pipeline.PipelineVisitor; import org.apache.beam.sdk.runners.TransformHierarchy; import org.apache.beam.sdk.transforms.GroupByKey; @@ -38,19 +44,21 @@ import org.apache.beam.sdk.values.PValue; // TODO: Handle Key-preserving transforms when appropriate and more aggressively make PTransforms // unkeyed class KeyedPValueTrackingVisitor implements PipelineVisitor { - @SuppressWarnings("rawtypes") - private final Set<Class<? extends PTransform>> producesKeyedOutputs; + + private static final Set<Class<? extends PTransform>> PRODUCES_KEYED_OUTPUTS = + ImmutableSet.of( + SplittableParDo.GBKIntoKeyedWorkItems.class, + DirectGroupByKeyOnly.class, + DirectGroupAlsoByWindow.class); + private final Set<PValue> keyedValues; private boolean finalized; - public static KeyedPValueTrackingVisitor create( - @SuppressWarnings("rawtypes") Set<Class<? extends PTransform>> producesKeyedOutputs) { - return new KeyedPValueTrackingVisitor(producesKeyedOutputs); + public static KeyedPValueTrackingVisitor create() { + return new KeyedPValueTrackingVisitor(); } - private KeyedPValueTrackingVisitor( - @SuppressWarnings("rawtypes") Set<Class<? 
extends PTransform>> producesKeyedOutputs) { - this.producesKeyedOutputs = producesKeyedOutputs; + private KeyedPValueTrackingVisitor() { this.keyedValues = new HashSet<>(); } @@ -73,7 +81,7 @@ class KeyedPValueTrackingVisitor implements PipelineVisitor { node); if (node.isRootNode()) { finalized = true; - } else if (producesKeyedOutputs.contains(node.getTransform().getClass())) { + } else if (PRODUCES_KEYED_OUTPUTS.contains(node.getTransform().getClass())) { keyedValues.addAll(node.getOutputs()); } } @@ -83,7 +91,9 @@ class KeyedPValueTrackingVisitor implements PipelineVisitor { @Override public void visitValue(PValue value, TransformHierarchy.Node producer) { - if (producesKeyedOutputs.contains(producer.getTransform().getClass())) { + if (PRODUCES_KEYED_OUTPUTS.contains(producer.getTransform().getClass()) + || (isKeyPreserving(producer.getTransform()) + && all(producer.getInputs(), in(keyedValues)))) { keyedValues.add(value); } } @@ -93,4 +103,9 @@ class KeyedPValueTrackingVisitor implements PipelineVisitor { finalized, "can't call getKeyedPValues before a Pipeline has been completely traversed"); return keyedValues; } + + private static boolean isKeyPreserving(PTransform<?, ?> transform) { + // There are currently no key-preserving transforms; this lays the infrastructure for them + return false; + } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/81702e67/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java index eef3375..a357005 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java +++ 
b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/KeyedPValueTrackingVisitorTest.java @@ -21,9 +21,7 @@ import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.not; import static org.junit.Assert.assertThat; -import com.google.common.collect.ImmutableSet; import java.util.Collections; -import java.util.Set; import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.VarIntCoder; @@ -33,7 +31,6 @@ import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.Keys; -import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; @@ -57,54 +54,20 @@ public class KeyedPValueTrackingVisitorTest { @Before public void setup() { - - @SuppressWarnings("rawtypes") - Set<Class<? extends PTransform>> producesKeyed = - ImmutableSet.<Class<? 
extends PTransform>>of(PrimitiveKeyer.class, CompositeKeyer.class); - visitor = KeyedPValueTrackingVisitor.create(producesKeyed); - } - - @Test - public void primitiveProducesKeyedOutputUnkeyedInputKeyedOutput() { - PCollection<Integer> keyed = - p.apply(Create.<Integer>of(1, 2, 3)).apply(new PrimitiveKeyer<Integer>()); - - p.traverseTopologically(visitor); - assertThat(visitor.getKeyedPValues(), hasItem(keyed)); - } - - @Test - public void primitiveProducesKeyedOutputKeyedInputKeyedOutut() { - PCollection<Integer> keyed = - p.apply(Create.<Integer>of(1, 2, 3)) - .apply("firstKey", new PrimitiveKeyer<Integer>()) - .apply("secondKey", new PrimitiveKeyer<Integer>()); - - p.traverseTopologically(visitor); - assertThat(visitor.getKeyedPValues(), hasItem(keyed)); - } - - @Test - public void compositeProducesKeyedOutputUnkeyedInputKeyedOutput() { - PCollection<Integer> keyed = - p.apply(Create.<Integer>of(1, 2, 3)).apply(new CompositeKeyer<Integer>()); - - p.traverseTopologically(visitor); - assertThat(visitor.getKeyedPValues(), hasItem(keyed)); + p = TestPipeline.create(); + visitor = KeyedPValueTrackingVisitor.create(); } @Test - public void compositeProducesKeyedOutputKeyedInputKeyedOutut() { - PCollection<Integer> keyed = - p.apply(Create.<Integer>of(1, 2, 3)) - .apply("firstKey", new CompositeKeyer<Integer>()) - .apply("secondKey", new CompositeKeyer<Integer>()); + public void groupByKeyProducesKeyedOutput() { + PCollection<KV<String, Iterable<Integer>>> keyed = + p.apply(Create.of(KV.of("foo", 3))) + .apply(GroupByKey.<String, Integer>create()); p.traverseTopologically(visitor); assertThat(visitor.getKeyedPValues(), hasItem(keyed)); } - @Test public void noInputUnkeyedOutput() { PCollection<KV<Integer, Iterable<Void>>> unkeyed = @@ -117,26 +80,17 @@ public class KeyedPValueTrackingVisitorTest { } @Test - public void keyedInputNotProducesKeyedOutputUnkeyedOutput() { - PCollection<Integer> onceKeyed = - p.apply(Create.<Integer>of(1, 2, 3)) - .apply(new 
PrimitiveKeyer<Integer>()) - .apply(ParDo.of(new IdentityFn<Integer>())); + public void keyedInputWithoutKeyPreserving() { + PCollection<KV<String, Iterable<Integer>>> onceKeyed = + p.apply(Create.of(KV.of("hello", 42))) + .apply(GroupByKey.<String, Integer>create()) + .apply(ParDo.of(new IdentityFn<KV<String, Iterable<Integer>>>())); p.traverseTopologically(visitor); assertThat(visitor.getKeyedPValues(), not(hasItem(onceKeyed))); } @Test - public void unkeyedInputNotProducesKeyedOutputUnkeyedOutput() { - PCollection<Integer> unkeyed = - p.apply(Create.<Integer>of(1, 2, 3)).apply(ParDo.of(new IdentityFn<Integer>())); - - p.traverseTopologically(visitor); - assertThat(visitor.getKeyedPValues(), not(hasItem(unkeyed))); - } - - @Test public void traverseMultipleTimesThrows() { p.apply( Create.<KV<Integer, Void>>of( @@ -161,22 +115,6 @@ public class KeyedPValueTrackingVisitorTest { visitor.getKeyedPValues(); } - private static class PrimitiveKeyer<K> extends PTransform<PCollection<K>, PCollection<K>> { - @Override - public PCollection<K> expand(PCollection<K> input) { - return PCollection.<K>createPrimitiveOutputInternal( - input.getPipeline(), input.getWindowingStrategy(), input.isBounded()) - .setCoder(input.getCoder()); - } - } - - private static class CompositeKeyer<K> extends PTransform<PCollection<K>, PCollection<K>> { - @Override - public PCollection<K> expand(PCollection<K> input) { - return input.apply(new PrimitiveKeyer<K>()).apply(ParDo.of(new IdentityFn<K>())); - } - } - private static class IdentityFn<K> extends DoFn<K, K> { @ProcessElement public void processElement(ProcessContext c) throws Exception {