Use TestSparkRunner in tests.

Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/2bcd40c2
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/2bcd40c2
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/2bcd40c2

Branch: refs/heads/master
Commit: 2bcd40c206d6b615135375276b249364fedf32af
Parents: c25a02f
Author: Sela <ans...@paypal.com>
Authored: Sat Feb 18 22:10:50 2017 +0200
Committer: Sela <ans...@paypal.com>
Committed: Wed Mar 1 00:17:59 2017 +0200

----------------------------------------------------------------------
 .../beam/runners/spark/TestSparkRunner.java     | 92 ++++++++++++--------
 .../utils/SparkTestPipelineOptions.java         |  4 +-
 .../SparkTestPipelineOptionsForStreaming.java   | 11 +++
 3 files changed, 69 insertions(+), 38 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/2bcd40c2/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
index 8b8f9ba..24bc038 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
@@ -21,12 +21,12 @@ package org.apache.beam.runners.spark;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.is;
 
+import java.io.File;
+import java.io.IOException;
 import org.apache.beam.runners.core.UnboundedReadFromBoundedSource;
-import org.apache.beam.runners.spark.aggregators.AggregatorsAccumulator;
-import org.apache.beam.runners.spark.metrics.SparkMetricsContainer;
+import org.apache.beam.runners.spark.aggregators.AccumulatorSingleton;
 import org.apache.beam.runners.spark.util.GlobalWatermarkHolder;
 import org.apache.beam.sdk.Pipeline;
-import org.apache.beam.sdk.PipelineResult.State;
 import org.apache.beam.sdk.io.BoundedReadFromUnboundedSource;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsValidator;
@@ -40,6 +40,10 @@ import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.POutput;
+import org.apache.commons.io.FileUtils;
+import org.joda.time.Duration;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 
 /**
@@ -65,6 +69,8 @@ import org.apache.beam.sdk.values.POutput;
  */
 public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult> {
 
+  private static final Logger LOG = LoggerFactory.getLogger(TestSparkRunner.class);
+
   private SparkRunner delegate;
   private boolean isForceStreaming;
   private int expectedNumberOfAssertions = 0;
@@ -87,7 +93,7 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult> {
   @SuppressWarnings("unchecked")
   @Override
   public <OutputT extends POutput, InputT extends PInput> OutputT apply(
-          PTransform<InputT, OutputT> transform, InputT input) {
+      PTransform<InputT, OutputT> transform, InputT input) {
     // if the pipeline forces execution as a streaming pipeline,
     // and the source is an adapted unbounded source (as bounded),
     // read it as unbounded source via UnboundedReadFromBoundedSource.
@@ -108,38 +114,52 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult> {
 
   @Override
   public SparkPipelineResult run(Pipeline pipeline) {
-    // clear state of Aggregators, Metrics and Watermarks if exists.
-    AggregatorsAccumulator.clear();
-    SparkMetricsContainer.clear();
-    GlobalWatermarkHolder.clear();
-
-    TestPipelineOptions testPipelineOptions = pipeline.getOptions().as(TestPipelineOptions.class);
-    SparkPipelineResult result = delegate.run(pipeline);
-    result.waitUntilFinish();
-
-
-    // make sure the test pipeline finished successfully.
-    State resultState = result.getState();
-    assertThat(
-        String.format("Test pipeline result state was %s instead of %s", 
resultState, State.DONE),
-        resultState,
-        is(State.DONE));
-    assertThat(result, testPipelineOptions.getOnCreateMatcher());
-    assertThat(result, testPipelineOptions.getOnSuccessMatcher());
-
-    // if the pipeline was executed in streaming mode, validate aggregators.
-    if (isForceStreaming) {
-      // validate assertion succeeded (at least once).
-      int success = result.getAggregatorValue(PAssert.SUCCESS_COUNTER, Integer.class);
-      assertThat(
-          String.format(
-              "Expected %d successful assertions, but found %d.",
-              expectedNumberOfAssertions, success),
-          success,
-          is(expectedNumberOfAssertions));
-      // validate assertion didn't fail.
-      int failure = result.getAggregatorValue(PAssert.FAILURE_COUNTER, Integer.class);
-      assertThat("Failure aggregator should be zero.", failure, is(0));
+    SparkPipelineOptions sparkOptions = pipeline.getOptions().as(SparkPipelineOptions.class);
+    long timeout = sparkOptions.getForcedTimeout();
+    SparkPipelineResult result = null;
+    try {
+      // clear state of Accumulators and Aggregators.
+      AccumulatorSingleton.clear();
+      GlobalWatermarkHolder.clear();
+
+      TestPipelineOptions testPipelineOptions = pipeline.getOptions().as(TestPipelineOptions.class);
+      LOG.info("About to run test pipeline " + sparkOptions.getJobName());
+      result = delegate.run(pipeline);
+      result.waitUntilFinish(Duration.millis(timeout));
+
+      assertThat(result, testPipelineOptions.getOnCreateMatcher());
+      assertThat(result, testPipelineOptions.getOnSuccessMatcher());
+
+      // if the pipeline was executed in streaming mode, validate aggregators.
+      if (isForceStreaming) {
+        // validate assertion succeeded (at least once).
+        int successAssertions = result.getAggregatorValue(PAssert.SUCCESS_COUNTER, Integer.class);
+        assertThat(
+            String.format(
+                "Expected %d successful assertions, but found %d.",
+                expectedNumberOfAssertions, successAssertions),
+            successAssertions,
+            is(expectedNumberOfAssertions));
+        // validate assertion didn't fail.
+        int failedAssertions = result.getAggregatorValue(PAssert.FAILURE_COUNTER, Integer.class);
+        assertThat(
+            String.format("Found %d failed assertions.", failedAssertions),
+            failedAssertions,
+            is(0));
+
+        LOG.info(
+            String.format(
+                "Successfully asserted pipeline %s with %d successful 
assertions.",
+                sparkOptions.getJobName(),
+                successAssertions));
+      }
+    } finally {
+      try {
+        // cleanup checkpoint dir.
+        FileUtils.deleteDirectory(new File(sparkOptions.getCheckpointDir()));
+      } catch (IOException e) {
+        throw new RuntimeException("Failed to clear checkpoint tmp dir.", e);
+      }
     }
     return result;
   }

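For context, a minimal sketch of a test-side caller exercising the new run() path. The runner class and the option setters follow the diff above; the timeout value and checkpoint path are hypothetical, and the pipeline body is elided:

import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.runners.spark.TestSparkRunner;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.options.PipelineOptionsFactory;

public class ForcedTimeoutSketch {
  public static void main(String[] args) {
    SparkPipelineOptions options = PipelineOptionsFactory.as(SparkPipelineOptions.class);
    options.setRunner(TestSparkRunner.class);
    // TestSparkRunner.run() now bounds waitUntilFinish() with this value,
    // so a (forced) streaming test cannot hang indefinitely.
    options.setForcedTimeout(5000L); // hypothetical timeout, in millis
    // run() deletes this directory in its finally block, so the test
    // no longer leaks checkpoint state.
    options.setCheckpointDir("/tmp/beam-test-checkpoint"); // hypothetical path

    Pipeline pipeline = Pipeline.create(options);
    // ... apply sources, transforms, and PAssert checks here ...
    pipeline.run();
  }
}
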
http://git-wip-us.apache.org/repos/asf/beam/blob/2bcd40c2/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptions.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptions.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptions.java
index 2da9888..efc17d3 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptions.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptions.java
@@ -18,7 +18,7 @@
 package org.apache.beam.runners.spark.translation.streaming.utils;
 
 import org.apache.beam.runners.spark.SparkPipelineOptions;
-import org.apache.beam.runners.spark.SparkRunner;
+import org.apache.beam.runners.spark.TestSparkRunner;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.junit.rules.ExternalResource;
 
@@ -32,7 +32,7 @@ public class SparkTestPipelineOptions extends ExternalResource {
 
   @Override
   protected void before() throws Throwable {
-    options.setRunner(SparkRunner.class);
+    options.setRunner(TestSparkRunner.class);
     options.setEnableSparkMetricSinks(false);
   }
 

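As a usage note, this rule is attached with JUnit's @Rule, so every test in the class now runs on TestSparkRunner rather than the plain SparkRunner. A minimal sketch; the getOptions() accessor is assumed from the existing class, and the test body is elided:

import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.runners.spark.translation.streaming.utils.SparkTestPipelineOptions;
import org.junit.Rule;
import org.junit.Test;

public class ExampleBatchTest {
  // before() now sets TestSparkRunner (instead of SparkRunner) on the options.
  @Rule
  public final SparkTestPipelineOptions pipelineOptions = new SparkTestPipelineOptions();

  @Test
  public void runsOnTestSparkRunner() {
    SparkPipelineOptions options = pipelineOptions.getOptions(); // assumed accessor
    // ... create a Pipeline with these options, apply PAssert, and run ...
  }
}
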
http://git-wip-us.apache.org/repos/asf/beam/blob/2bcd40c2/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptionsForStreaming.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptionsForStreaming.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptionsForStreaming.java
index 28f6d5d..dd3e4c8 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptionsForStreaming.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptionsForStreaming.java
@@ -20,6 +20,7 @@ package org.apache.beam.runners.spark.translation.streaming.utils;
 
 import java.io.IOException;
 import org.apache.beam.runners.spark.SparkPipelineOptions;
+import org.joda.time.Duration;
 import org.junit.rules.TemporaryFolder;
 
 
@@ -28,10 +29,20 @@ import org.junit.rules.TemporaryFolder;
  */
 public class SparkTestPipelineOptionsForStreaming extends SparkTestPipelineOptions {
 
+  private static final int DEFAULT_NUMBER_OF_BATCHES_TIMEOUT = 5;
+
   public SparkPipelineOptions withTmpCheckpointDir(TemporaryFolder parent)
       throws IOException {
     // tests use JUnit's TemporaryFolder path in the form of: /.../junit/...
     options.setCheckpointDir(parent.newFolder(options.getJobName()).toURI().toURL().toString());
+    options.setForceStreaming(true);
+    // set the default timeout to DEFAULT_NUMBER_OF_BATCHES_TIMEOUT x batchDuration
+    // to allow pipelines to finish.
+    Duration batchDuration = Duration.millis(options.getBatchIntervalMillis());
+    // set the checkpoint duration to match interval.
+    options.setCheckpointDurationMillis(batchDuration.getMillis());
+    long forcedTimeout = batchDuration.multipliedBy(DEFAULT_NUMBER_OF_BATCHES_TIMEOUT).getMillis();
+    options.setForcedTimeout(forcedTimeout);
     return options;
   }
 }

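Putting the streaming pieces together, a sketch of how a streaming test would combine this rule with a TemporaryFolder. withTmpCheckpointDir() and the option setters come from the diff above; the test body is hypothetical:

import java.io.IOException;
import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.runners.spark.translation.streaming.utils.SparkTestPipelineOptionsForStreaming;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

public class ExampleStreamingTest {
  @Rule
  public final TemporaryFolder checkpointParent = new TemporaryFolder();

  @Rule
  public final SparkTestPipelineOptionsForStreaming commonOptions =
      new SparkTestPipelineOptionsForStreaming();

  @Test
  public void runsWithForcedStreaming() throws IOException {
    // Sets the checkpoint dir under the temp folder, forces streaming,
    // aligns the checkpoint duration with the batch interval, and derives
    // the forced timeout as 5 x batchDuration so slow pipelines can still
    // finish without hanging the build.
    SparkPipelineOptions options = commonOptions.withTmpCheckpointDir(checkpointParent);
    // ... build a streaming pipeline with PAssert checks and run it ...
  }
}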