Streaming sources tracking test.

Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/add87166
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/add87166
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/add87166

Branch: refs/heads/master
Commit: add87166fd445d634f6faf0f838f234c3f908c2e
Parents: 9784f20
Author: Sela <ans...@paypal.com>
Authored: Sun Feb 12 19:23:28 2017 +0200
Committer: Sela <ans...@paypal.com>
Committed: Mon Feb 20 11:30:16 2017 +0200

----------------------------------------------------------------------
 .../translation/streaming/UnboundedDataset.java |   6 +
 .../streaming/TrackStreamingSourcesTest.java    | 152 +++++++++++++++++++
 2 files changed, 158 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/add87166/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/UnboundedDataset.java
----------------------------------------------------------------------
diff --git 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/UnboundedDataset.java
 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/UnboundedDataset.java
index 80c0515..08d1ab6 100644
--- 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/UnboundedDataset.java
+++ 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/UnboundedDataset.java
@@ -18,6 +18,7 @@
 
 package org.apache.beam.runners.spark.translation.streaming;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Iterables;
 
 import java.util.ArrayList;
@@ -72,6 +73,11 @@ public class UnboundedDataset<T> implements Dataset {
     this.streamingSources.add(queuedStreamIds.decrementAndGet());
   }
 
+  @VisibleForTesting
+  public static void resetQueuedStreamIds() {
+    queuedStreamIds.set(0);
+  }
+
   @SuppressWarnings("ConstantConditions")
   JavaDStream<WindowedValue<T>> getDStream() {
     if (dStream == null) {

http://git-wip-us.apache.org/repos/asf/beam/blob/add87166/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java
----------------------------------------------------------------------
diff --git 
a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java
 
b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java
new file mode 100644
index 0000000..f102ac8
--- /dev/null
+++ 
b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java
@@ -0,0 +1,152 @@
+package org.apache.beam.runners.spark.translation.streaming;
+
+import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.hamcrest.core.IsEqual.equalTo;
+import static org.junit.Assert.assertThat;
+
+import java.util.Collections;
+import java.util.List;
+import org.apache.beam.runners.spark.ReuseSparkContext;
+import org.apache.beam.runners.spark.SparkPipelineOptions;
+import org.apache.beam.runners.spark.SparkRunner;
+import org.apache.beam.runners.spark.io.CreateStream;
+import org.apache.beam.runners.spark.translation.Dataset;
+import org.apache.beam.runners.spark.translation.EvaluationContext;
+import org.apache.beam.runners.spark.translation.SparkContextFactory;
+import org.apache.beam.runners.spark.translation.TransformTranslator;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.runners.TransformHierarchy;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.Flatten;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionList;
+import org.apache.beam.sdk.values.PValue;
+import org.apache.spark.SparkStatusTracker;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.streaming.Duration;
+import org.apache.spark.streaming.api.java.JavaStreamingContext;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+
+
+/**
+ * A test suite that tests tracking of the streaming sources recorded in an
+ * {@link 
org.apache.beam.runners.spark.translation.streaming.UnboundedDataset}.
+ */
+public class TrackStreamingSourcesTest {
+
+  @Rule
+  public ReuseSparkContext reuseContext = ReuseSparkContext.yes();
+
+  private static final transient SparkPipelineOptions options =
+      PipelineOptionsFactory.create().as(SparkPipelineOptions.class);
+
+  @Before
+  public void before() {
+    UnboundedDataset.resetQueuedStreamIds();
+    StreamingSourceTracker.numAssertions = 0;
+  }
+
+  @Test
+  public void testTrackSingle() {
+    options.setRunner(SparkRunner.class);
+    JavaSparkContext jsc = SparkContextFactory.getSparkContext(options);
+    JavaStreamingContext jssc = new JavaStreamingContext(jsc,
+        new Duration(options.getBatchIntervalMillis()));
+
+    Pipeline p = Pipeline.create(options);
+
+    CreateStream.QueuedValues<Integer> queueStream =
+        CreateStream.fromQueue(Collections.<Iterable<Integer>>emptyList());
+
+    p.apply(queueStream).setCoder(VarIntCoder.of())
+        .apply(ParDo.of(new PassthroughFn<>()));
+
+    p.traverseTopologically(new StreamingSourceTracker(jssc, p, 
ParDo.Bound.class,  -1));
+    assertThat(StreamingSourceTracker.numAssertions, equalTo(1));
+  }
+
+  @Test
+  public void testTrackFlattened() {
+    options.setRunner(SparkRunner.class);
+    JavaSparkContext jsc = SparkContextFactory.getSparkContext(options);
+    JavaStreamingContext jssc = new JavaStreamingContext(jsc,
+        new Duration(options.getBatchIntervalMillis()));
+
+    Pipeline p = Pipeline.create(options);
+
+    CreateStream.QueuedValues<Integer> queueStream1 =
+        CreateStream.fromQueue(Collections.<Iterable<Integer>>emptyList());
+    CreateStream.QueuedValues<Integer> queueStream2 =
+        CreateStream.fromQueue(Collections.<Iterable<Integer>>emptyList());
+
+    PCollection<Integer> pcol1 = 
p.apply(queueStream1).setCoder(VarIntCoder.of());
+    PCollection<Integer> pcol2 = 
p.apply(queueStream2).setCoder(VarIntCoder.of());
+    PCollection<Integer> flattened =
+        
PCollectionList.of(pcol1).and(pcol2).apply(Flatten.<Integer>pCollections());
+    flattened.apply(ParDo.of(new PassthroughFn<>()));
+
+    p.traverseTopologically(new StreamingSourceTracker(jssc, p, 
ParDo.Bound.class, -1, -2));
+    assertThat(StreamingSourceTracker.numAssertions, equalTo(1));
+  }
+
+  private static class PassthroughFn<T> extends DoFn<T, T> {
+    @ProcessElement
+    public void processElement(ProcessContext c) {
+      c.output(c.element());
+    }
+  }
+
+  private static class StreamingSourceTracker extends 
Pipeline.PipelineVisitor.Defaults {
+    private final EvaluationContext ctxt;
+    private final SparkRunner.Evaluator evaluator;
+    private final Class<? extends PTransform> transformClassToAssert;
+    private final Integer[] expected;
+
+    private static int numAssertions = 0;
+
+    private StreamingSourceTracker(
+        JavaStreamingContext jssc,
+        Pipeline pipeline,
+        Class<? extends PTransform> transformClassToAssert,
+        Integer... expected) {
+      this.ctxt = new EvaluationContext(jssc.sparkContext(), pipeline, jssc);
+      this.evaluator = new SparkRunner.Evaluator(
+          new StreamingTransformTranslator.Translator(new 
TransformTranslator.Translator()), ctxt);
+      this.transformClassToAssert = transformClassToAssert;
+      this.expected = expected;
+    }
+
+    private void assertSourceIds(List<Integer> streamingSources) {
+      numAssertions++;
+      assertThat(streamingSources, containsInAnyOrder(expected));
+    }
+
+    @Override
+    public CompositeBehavior enterCompositeTransform(TransformHierarchy.Node 
node) {
+      return evaluator.enterCompositeTransform(node);
+    }
+
+    @Override
+    public void visitPrimitiveTransform(TransformHierarchy.Node node) {
+      PTransform transform = node.getTransform();
+      if (transform.getClass() == transformClassToAssert) {
+        AppliedPTransform<?, ?, ?> appliedTransform = 
node.toAppliedPTransform();
+        ctxt.setCurrentTransform(appliedTransform);
+        //noinspection unchecked
+        Dataset dataset = ctxt.borrowDataset((PTransform<? extends PValue, ?>) 
transform);
+        assertSourceIds(((UnboundedDataset<?>) dataset).getStreamingSources());
+        ctxt.setCurrentTransform(null);
+      } else {
+        evaluator.visitPrimitiveTransform(node);
+      }
+    }
+  }
+
+}

Reply via email to