Use a PipelineRule for test pipelines.
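
A minimal sketch of how a test adopts the rule (class and test names here are
hypothetical; the factory methods and configured defaults are those of the
PipelineRule added below):

    import org.apache.beam.runners.spark.PipelineRule;
    import org.apache.beam.sdk.Pipeline;
    import org.junit.Rule;
    import org.junit.Test;

    public class ExampleBatchTest {
      @Rule
      public final PipelineRule pipelineRule = PipelineRule.batch(); // or PipelineRule.streaming()

      @Test
      public void testSomething() {
        // Options come pre-configured: TestSparkRunner, metric sinks disabled,
        // and the job name set to the test method's name.
        Pipeline p = pipelineRule.createPipeline();
        // ... apply transforms and run the pipeline ...
      }
    }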

Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/4ca56806
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/4ca56806
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/4ca56806

Branch: refs/heads/master
Commit: 4ca5680681e14c4be8b0ac1f3e1244c9d439e6d8
Parents: bd41d9a
Author: Sela <ans...@paypal.com>
Authored: Mon Feb 27 19:24:45 2017 +0200
Committer: Sela <ans...@paypal.com>
Committed: Wed Mar 1 00:18:07 2017 +0200

----------------------------------------------------------------------
 runners/spark/pom.xml                           |   1 -
 .../runners/spark/SparkPipelineOptions.java     |   9 --
 .../runners/spark/TestSparkPipelineOptions.java |  18 ++++
 .../beam/runners/spark/TestSparkRunner.java     |  28 ++---
 .../beam/runners/spark/ForceStreamingTest.java  |   8 +-
 .../apache/beam/runners/spark/PipelineRule.java | 103 +++++++++++++++++++
 .../metrics/sink/NamedAggregatorsTest.java      |   9 +-
 .../beam/runners/spark/io/AvroPipelineTest.java |   6 +-
 .../beam/runners/spark/io/NumShardsTest.java    |   6 +-
 .../io/hadoop/HadoopFileFormatPipelineTest.java |   6 +-
 .../spark/translation/StorageLevelTest.java     |   8 +-
 .../translation/streaming/CreateStreamTest.java |  77 ++++----------
 .../utils/SparkTestPipelineOptions.java         |  42 --------
 .../SparkTestPipelineOptionsForStreaming.java   |  48 ---------
 14 files changed, 175 insertions(+), 194 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/4ca56806/runners/spark/pom.xml
----------------------------------------------------------------------
diff --git a/runners/spark/pom.xml b/runners/spark/pom.xml
index f926bf5..409fc27 100644
--- a/runners/spark/pom.xml
+++ b/runners/spark/pom.xml
@@ -80,7 +80,6 @@
                     org.apache.beam.sdk.testing.UsesCommittedMetrics
                   </excludedGroups>
                   <parallel>none</parallel>
-                  <threadCount>4</threadCount>
                   <forkCount>1</forkCount>
                   <reuseForks>false</reuseForks>
                   <failIfNoTests>true</failIfNoTests>

http://git-wip-us.apache.org/repos/asf/beam/blob/4ca56806/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java
index 52d8ce1..26b549b 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java
@@ -101,13 +101,4 @@ public interface SparkPipelineOptions
   boolean getUsesProvidedSparkContext();
   void setUsesProvidedSparkContext(boolean value);
 
-  @Description("A special flag that forces streaming in tests.")
-  @Default.Boolean(false)
-  boolean isForceStreaming();
-  void setForceStreaming(boolean forceStreaming);
-
-  @Description("A forced timeout (millis), mostly for testing.")
-  @Default.Long(3000L)
-  Long getForcedTimeout();
-  void setForcedTimeout(Long forcedTimeout);
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/4ca56806/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkPipelineOptions.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkPipelineOptions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkPipelineOptions.java
new file mode 100644
index 0000000..2cb58d8
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkPipelineOptions.java
@@ -0,0 +1,18 @@
+package org.apache.beam.runners.spark;
+
+import org.apache.beam.sdk.options.Default;
+import org.apache.beam.sdk.options.Description;
+import org.apache.beam.sdk.testing.TestPipelineOptions;
+
+
+/**
+ * A {@link SparkPipelineOptions} for tests.
+ */
+public interface TestSparkPipelineOptions extends SparkPipelineOptions, TestPipelineOptions {
+
+  @Description("A special flag that forces streaming in tests.")
+  @Default.Boolean(false)
+  boolean isForceStreaming();
+  void setForceStreaming(boolean forceStreaming);
+
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/4ca56806/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
index 5d71ea5..035da00 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
@@ -18,6 +18,7 @@
 
 package org.apache.beam.runners.spark;
 
+import static com.google.common.base.Preconditions.checkNotNull;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.is;
 
@@ -33,7 +34,6 @@ import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsValidator;
 import org.apache.beam.sdk.runners.PipelineRunner;
 import org.apache.beam.sdk.testing.PAssert;
-import org.apache.beam.sdk.testing.TestPipelineOptions;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.util.ValueWithRecordId;
@@ -76,15 +76,15 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult> {
   private boolean isForceStreaming;
   private int expectedNumberOfAssertions = 0;
 
-  private TestSparkRunner(SparkPipelineOptions options) {
+  private TestSparkRunner(TestSparkPipelineOptions options) {
     this.delegate = SparkRunner.fromOptions(options);
     this.isForceStreaming = options.isForceStreaming();
   }
 
   public static TestSparkRunner fromOptions(PipelineOptions options) {
     // Default options suffice to set it up as a test runner
-    SparkPipelineOptions sparkOptions =
-        PipelineOptionsValidator.validate(SparkPipelineOptions.class, options);
+    TestSparkPipelineOptions sparkOptions =
+        PipelineOptionsValidator.validate(TestSparkPipelineOptions.class, options);
     return new TestSparkRunner(sparkOptions);
   }
 
@@ -115,22 +115,22 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult> {
 
   @Override
   public SparkPipelineResult run(Pipeline pipeline) {
-    SparkPipelineOptions sparkOptions = pipeline.getOptions().as(SparkPipelineOptions.class);
+    TestSparkPipelineOptions testSparkPipelineOptions =
+        pipeline.getOptions().as(TestSparkPipelineOptions.class);
     SparkPipelineResult result = null;
     // clear state of Aggregators, Metrics and Watermarks.
     AggregatorsAccumulator.clear();
     SparkMetricsContainer.clear();
     GlobalWatermarkHolder.clear();
 
-    TestPipelineOptions testPipelineOptions = pipeline.getOptions().as(TestPipelineOptions.class);
-    LOG.info("About to run test pipeline " + sparkOptions.getJobName());
+    LOG.info("About to run test pipeline " + testSparkPipelineOptions.getJobName());
 
     // if the pipeline was executed in streaming mode, validate aggregators.
     if (isForceStreaming) {
       try {
         result = delegate.run(pipeline);
-        long timeout = sparkOptions.getForcedTimeout();
-        result.waitUntilFinish(Duration.millis(timeout));
+        Long timeout = testSparkPipelineOptions.getTestTimeoutSeconds();
+        result.waitUntilFinish(Duration.standardSeconds(checkNotNull(timeout)));
         // validate assertion succeeded (at least once).
         int successAssertions = result.getAggregatorValue(PAssert.SUCCESS_COUNTER, Integer.class);
         assertThat(
@@ -149,12 +149,12 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult> {
         LOG.info(
             String.format(
                 "Successfully asserted pipeline %s with %d successful 
assertions.",
-                sparkOptions.getJobName(),
+                testSparkPipelineOptions.getJobName(),
                 successAssertions));
-        } finally {
+      } finally {
         try {
           // cleanup checkpoint dir.
-          FileUtils.deleteDirectory(new File(sparkOptions.getCheckpointDir()));
+          FileUtils.deleteDirectory(new File(testSparkPipelineOptions.getCheckpointDir()));
         } catch (IOException e) {
           throw new RuntimeException("Failed to clear checkpoint tmp dir.", e);
         }
@@ -164,8 +164,8 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult> {
       result = delegate.run(pipeline);
       result.waitUntilFinish();
       // assert via matchers.
-      assertThat(result, testPipelineOptions.getOnCreateMatcher());
-      assertThat(result, testPipelineOptions.getOnSuccessMatcher());
+      assertThat(result, testSparkPipelineOptions.getOnCreateMatcher());
+      assertThat(result, testSparkPipelineOptions.getOnSuccessMatcher());
     }
     return result;
   }
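
Note: the forced timeout now rides on the standard
TestPipelineOptions#getTestTimeoutSeconds instead of the removed
getForcedTimeout. A rough sketch of the flow (illustrative only, not part of
the commit):

    // How the streaming test timeout is threaded through, per the diff above.
    TestSparkPipelineOptions opts = PipelineOptionsFactory.as(TestSparkPipelineOptions.class);
    opts.setTestTimeoutSeconds(5L); // PipelineRule.streaming() sets this in before()
    // TestSparkRunner.run() then bounds the wait:
    //   result.waitUntilFinish(Duration.standardSeconds(checkNotNull(opts.getTestTimeoutSeconds())));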

http://git-wip-us.apache.org/repos/asf/beam/blob/4ca56806/runners/spark/src/test/java/org/apache/beam/runners/spark/ForceStreamingTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/ForceStreamingTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/ForceStreamingTest.java
index c3026ce..9b39558 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/ForceStreamingTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/ForceStreamingTest.java
@@ -21,7 +21,6 @@ package org.apache.beam.runners.spark;
 import static org.hamcrest.MatcherAssert.assertThat;
 
 import java.io.IOException;
-import org.apache.beam.runners.spark.translation.streaming.utils.SparkTestPipelineOptionsForStreaming;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.io.BoundedReadFromUnboundedSource;
 import org.apache.beam.sdk.io.CountingSource;
@@ -46,14 +45,11 @@ import org.junit.Test;
 public class ForceStreamingTest {
 
   @Rule
-  public SparkTestPipelineOptionsForStreaming commonOptions =
-      new SparkTestPipelineOptionsForStreaming();
+  public final PipelineRule pipelineRule = PipelineRule.streaming();
 
   @Test
   public void test() throws IOException {
-    SparkPipelineOptions options = commonOptions.getOptions();
-    options.setForceStreaming(true);
-    Pipeline pipeline = Pipeline.create(options);
+    Pipeline pipeline = pipelineRule.createPipeline();
 
     // apply the BoundedReadFromUnboundedSource.
     BoundedReadFromUnboundedSource<?> boundedRead =

http://git-wip-us.apache.org/repos/asf/beam/blob/4ca56806/runners/spark/src/test/java/org/apache/beam/runners/spark/PipelineRule.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/PipelineRule.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/PipelineRule.java
new file mode 100644
index 0000000..bb42510
--- /dev/null
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/PipelineRule.java
@@ -0,0 +1,103 @@
+package org.apache.beam.runners.spark;
+
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.joda.time.Duration;
+import org.junit.rules.ExternalResource;
+import org.junit.rules.RuleChain;
+import org.junit.rules.TemporaryFolder;
+import org.junit.rules.TestName;
+import org.junit.rules.TestRule;
+import org.junit.runner.Description;
+import org.junit.runners.model.Statement;
+
+/**
+ * A {@link org.junit.Rule} to provide a {@link Pipeline} instance for Spark runner tests.
+ */
+public class PipelineRule implements TestRule {
+
+  private final TestName testName = new TestName();
+
+  private final SparkPipelineRule delegate;
+  private final RuleChain chain;
+
+  private PipelineRule() {
+    this.delegate = new SparkPipelineRule(testName);
+    this.chain = RuleChain.outerRule(testName).around(this.delegate);
+  }
+
+  private PipelineRule(Duration forcedTimeout) {
+    this.delegate = new SparkStreamingPipelineRule(forcedTimeout, testName);
+    this.chain = RuleChain.outerRule(testName).around(this.delegate);
+  }
+
+  public static PipelineRule streaming() {
+    return new PipelineRule(Duration.standardSeconds(5));
+  }
+
+  public static PipelineRule batch() {
+    return new PipelineRule();
+  }
+
+  public Duration batchDuration() {
+    return Duration.millis(delegate.options.getBatchIntervalMillis());
+  }
+
+  public SparkPipelineOptions getOptions() {
+    return delegate.options;
+  }
+
+  public Pipeline createPipeline() {
+    return Pipeline.create(delegate.options);
+  }
+
+  @Override
+  public Statement apply(Statement statement, Description description) {
+    return chain.apply(statement, description);
+  }
+
+  private static class SparkStreamingPipelineRule extends SparkPipelineRule {
+
+    private final TemporaryFolder temporaryFolder = new TemporaryFolder();
+    private final Duration forcedTimeout;
+
+    SparkStreamingPipelineRule(Duration forcedTimeout, TestName testName) {
+      super(testName);
+      this.forcedTimeout = forcedTimeout;
+    }
+
+    @Override
+    protected void before() throws Throwable {
+      super.before();
+      temporaryFolder.create();
+      options.setForceStreaming(true);
+      options.setTestTimeoutSeconds(forcedTimeout.getStandardSeconds());
+      options.setCheckpointDir(
+          temporaryFolder.newFolder(options.getJobName()).toURI().toURL().toString());
+    }
+
+    @Override
+    protected void after() {
+      temporaryFolder.delete();
+    }
+  }
+
+  private static class SparkPipelineRule extends ExternalResource {
+
+    protected final TestSparkPipelineOptions options =
+        PipelineOptionsFactory.as(TestSparkPipelineOptions.class);
+
+    private final TestName testName;
+
+    private SparkPipelineRule(TestName testName) {
+      this.testName = testName;
+    }
+
+    @Override
+    protected void before() throws Throwable {
+      options.setRunner(TestSparkRunner.class);
+      options.setEnableSparkMetricSinks(false);
+      options.setJobName(testName.getMethodName());
+    }
+  }
+}
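
The RuleChain above runs TestName before the options rule, so the job name is
available when before() executes; streaming() additionally forces streaming,
sets a 5-second test timeout, and points the checkpoint dir at a per-test
temporary folder. A sketch of a streaming test built on the rule (the test
class is hypothetical; the CreateStream calls mirror CreateStreamTest below):

    import java.io.Serializable;
    import org.apache.beam.runners.spark.io.CreateStream;
    import org.apache.beam.sdk.Pipeline;
    import org.junit.Rule;
    import org.junit.Test;

    public class ExampleStreamingTest implements Serializable {
      @Rule
      public final transient PipelineRule pipelineRule = PipelineRule.streaming();

      @Test
      public void testExample() {
        Pipeline p = pipelineRule.createPipeline();
        // The batch interval comes from the rule's options.
        CreateStream<String> source =
            CreateStream.<String>withBatchInterval(pipelineRule.batchDuration())
                .nextBatch("foo", "bar")
                .advanceNextBatchWatermarkToInfinity();
        // ... read from source, assert with PAssert, then run p ...
      }
    }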

http://git-wip-us.apache.org/repos/asf/beam/blob/4ca56806/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/metrics/sink/NamedAggregatorsTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/metrics/sink/NamedAggregatorsTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/metrics/sink/NamedAggregatorsTest.java
index 2f7202c..a192807 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/metrics/sink/NamedAggregatorsTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/metrics/sink/NamedAggregatorsTest.java
@@ -26,12 +26,12 @@ import com.google.common.collect.ImmutableSet;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Set;
+import org.apache.beam.runners.spark.PipelineRule;
 import org.apache.beam.runners.spark.SparkPipelineOptions;
 import org.apache.beam.runners.spark.aggregators.ClearAggregatorsRule;
 import org.apache.beam.runners.spark.aggregators.SparkAggregators;
 import org.apache.beam.runners.spark.examples.WordCount;
 import org.apache.beam.runners.spark.translation.SparkContextFactory;
-import org.apache.beam.runners.spark.translation.streaming.utils.SparkTestPipelineOptions;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
@@ -56,12 +56,11 @@ public class NamedAggregatorsTest {
   public ClearAggregatorsRule clearAggregators = new ClearAggregatorsRule();
 
   @Rule
-  public final SparkTestPipelineOptions pipelineOptions = new SparkTestPipelineOptions();
+  public final PipelineRule pipelineRule = PipelineRule.batch();
 
   private Pipeline createSparkPipeline() {
-    SparkPipelineOptions options = pipelineOptions.getOptions();
-    options.setEnableSparkMetricSinks(true);
-    return Pipeline.create(options);
+    pipelineRule.getOptions().setEnableSparkMetricSinks(true);
+    return pipelineRule.createPipeline();
   }
 
   private void runPipeline() {

http://git-wip-us.apache.org/repos/asf/beam/blob/4ca56806/runners/spark/src/test/java/org/apache/beam/runners/spark/io/AvroPipelineTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/io/AvroPipelineTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/io/AvroPipelineTest.java
index c5bb583..2a73c28 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/io/AvroPipelineTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/io/AvroPipelineTest.java
@@ -33,7 +33,7 @@ import org.apache.avro.generic.GenericData;
 import org.apache.avro.generic.GenericDatumReader;
 import org.apache.avro.generic.GenericDatumWriter;
 import org.apache.avro.generic.GenericRecord;
-import org.apache.beam.runners.spark.translation.streaming.utils.SparkTestPipelineOptions;
+import org.apache.beam.runners.spark.PipelineRule;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.io.AvroIO;
 import org.apache.beam.sdk.values.PCollection;
@@ -54,7 +54,7 @@ public class AvroPipelineTest {
   public final TemporaryFolder tmpDir = new TemporaryFolder();
 
   @Rule
-  public final SparkTestPipelineOptions pipelineOptions = new SparkTestPipelineOptions();
+  public final PipelineRule pipelineRule = PipelineRule.batch();
 
   @Before
   public void setUp() throws IOException {
@@ -72,7 +72,7 @@ public class AvroPipelineTest {
     savedRecord.put("siblingnames", Lists.newArrayList("Jimmy", "Jane"));
     populateGenericFile(Lists.newArrayList(savedRecord), schema);
 
-    Pipeline p = Pipeline.create(pipelineOptions.getOptions());
+    Pipeline p = pipelineRule.createPipeline();
     PCollection<GenericRecord> input = p.apply(
         AvroIO.Read.from(inputFile.getAbsolutePath()).withSchema(schema));
     input.apply(AvroIO.Write.to(outputDir.getAbsolutePath()).withSchema(schema));

http://git-wip-us.apache.org/repos/asf/beam/blob/4ca56806/runners/spark/src/test/java/org/apache/beam/runners/spark/io/NumShardsTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/io/NumShardsTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/io/NumShardsTest.java
index 34d6818..c936ed3 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/io/NumShardsTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/io/NumShardsTest.java
@@ -30,8 +30,8 @@ import java.io.IOException;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Set;
+import org.apache.beam.runners.spark.PipelineRule;
 import org.apache.beam.runners.spark.examples.WordCount;
-import org.apache.beam.runners.spark.translation.streaming.utils.SparkTestPipelineOptions;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.io.TextIO;
@@ -59,7 +59,7 @@ public class NumShardsTest {
   public final TemporaryFolder tmpDir = new TemporaryFolder();
 
   @Rule
-  public final SparkTestPipelineOptions pipelineOptions = new SparkTestPipelineOptions();
+  public final PipelineRule pipelineRule = PipelineRule.batch();
 
   @Before
   public void setUp() throws IOException {
@@ -69,7 +69,7 @@ public class NumShardsTest {
 
   @Test
   public void testText() throws Exception {
-    Pipeline p = Pipeline.create(pipelineOptions.getOptions());
+    Pipeline p = pipelineRule.createPipeline();
     PCollection<String> inputWords = p.apply(Create.of(WORDS).withCoder(StringUtf8Coder.of()));
     PCollection<String> output = inputWords.apply(new WordCount.CountWords())
         .apply(MapElements.via(new WordCount.FormatAsTextFn()));

http://git-wip-us.apache.org/repos/asf/beam/blob/4ca56806/runners/spark/src/test/java/org/apache/beam/runners/spark/io/hadoop/HadoopFileFormatPipelineTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/io/hadoop/HadoopFileFormatPipelineTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/io/hadoop/HadoopFileFormatPipelineTest.java
index 9efc670..a5072d6 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/io/hadoop/HadoopFileFormatPipelineTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/io/hadoop/HadoopFileFormatPipelineTest.java
@@ -22,8 +22,8 @@ import static org.junit.Assert.assertEquals;
 
 import java.io.File;
 import java.io.IOException;
+import org.apache.beam.runners.spark.PipelineRule;
 import org.apache.beam.runners.spark.coders.WritableCoder;
-import org.apache.beam.runners.spark.translation.streaming.utils.SparkTestPipelineOptions;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.values.KV;
@@ -52,7 +52,7 @@ public class HadoopFileFormatPipelineTest {
   private File outputFile;
 
   @Rule
-  public final SparkTestPipelineOptions pipelineOptions = new SparkTestPipelineOptions();
+  public final PipelineRule pipelineRule = PipelineRule.batch();
 
   @Rule
   public final TemporaryFolder tmpDir = new TemporaryFolder();
@@ -68,7 +68,7 @@ public class HadoopFileFormatPipelineTest {
   public void testSequenceFile() throws Exception {
     populateFile();
 
-    Pipeline p = Pipeline.create(pipelineOptions.getOptions());
+    Pipeline p = pipelineRule.createPipeline();
     @SuppressWarnings("unchecked")
     Class<? extends FileInputFormat<IntWritable, Text>> inputFormatClass =
         (Class<? extends FileInputFormat<IntWritable, Text>>)

http://git-wip-us.apache.org/repos/asf/beam/blob/4ca56806/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/StorageLevelTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/StorageLevelTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/StorageLevelTest.java
index 48105e1..4dc5dee 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/StorageLevelTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/StorageLevelTest.java
@@ -17,7 +17,7 @@
  */
 package org.apache.beam.runners.spark.translation;
 
-import org.apache.beam.runners.spark.translation.streaming.utils.SparkTestPipelineOptions;
+import org.apache.beam.runners.spark.PipelineRule;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.transforms.Count;
@@ -32,12 +32,12 @@ import org.junit.Test;
 public class StorageLevelTest {
 
   @Rule
-  public final transient SparkTestPipelineOptions pipelineOptions = new SparkTestPipelineOptions();
+  public final transient PipelineRule pipelineRule = PipelineRule.batch();
 
   @Test
   public void test() throws Exception {
-    pipelineOptions.getOptions().setStorageLevel("DISK_ONLY");
-    Pipeline p = Pipeline.create(pipelineOptions.getOptions());
+    pipelineRule.getOptions().setStorageLevel("DISK_ONLY");
+    Pipeline p = pipelineRule.createPipeline();
 
     PCollection<String> pCollection = p.apply(Create.of("foo"));
 

http://git-wip-us.apache.org/repos/asf/beam/blob/4ca56806/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java
index 9ee5cc5..f2783a1 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java
@@ -24,10 +24,9 @@ import static org.junit.Assert.assertThat;
 
 import java.io.IOException;
 import java.io.Serializable;
+import org.apache.beam.runners.spark.PipelineRule;
 import org.apache.beam.runners.spark.ReuseSparkContextRule;
-import org.apache.beam.runners.spark.SparkPipelineOptions;
 import org.apache.beam.runners.spark.io.CreateStream;
-import org.apache.beam.runners.spark.translation.streaming.utils.SparkTestPipelineOptionsForStreaming;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.coders.VarIntCoder;
@@ -59,8 +58,6 @@ import org.joda.time.Instant;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
-import org.junit.rules.TemporaryFolder;
-import org.junit.rules.TestName;
 
 
 /**
@@ -76,27 +73,18 @@ import org.junit.rules.TestName;
 public class CreateStreamTest implements Serializable {
 
   @Rule
-  public transient TemporaryFolder checkpointParentDir = new TemporaryFolder();
+  public final transient PipelineRule pipelineRule = PipelineRule.streaming();
   @Rule
-  public transient SparkTestPipelineOptionsForStreaming commonOptions =
-      new SparkTestPipelineOptionsForStreaming();
+  public final transient ReuseSparkContextRule noContextResue = ReuseSparkContextRule.no();
   @Rule
-  public transient ReuseSparkContextRule noContextResue = ReuseSparkContextRule.no();
-  @Rule
-  public transient TestName testName = new TestName();
-  @Rule
-  public transient ExpectedException thrown = ExpectedException.none();
+  public final transient ExpectedException thrown = ExpectedException.none();
 
   @Test
   public void testLateDataAccumulating() throws IOException {
-    SparkPipelineOptions options = commonOptions.withTmpCheckpointDir(checkpointParentDir);
-    Pipeline p = Pipeline.create(options);
-    options.setJobName(testName.getMethodName());
-    Duration batchDuration = Duration.millis(options.getBatchIntervalMillis());
-
+    Pipeline p = pipelineRule.createPipeline();
     Instant instant = new Instant(0);
     CreateStream<TimestampedValue<Integer>> source =
-        CreateStream.<TimestampedValue<Integer>>withBatchInterval(batchDuration)
+        CreateStream.<TimestampedValue<Integer>>withBatchInterval(pipelineRule.batchDuration())
             .nextBatch()
             .advanceWatermarkForNextBatch(instant.plus(Duration.standardMinutes(6)))
             .nextBatch(
@@ -167,13 +155,9 @@ public class CreateStreamTest implements Serializable {
 
   @Test
   public void testDiscardingMode() throws IOException {
-    SparkPipelineOptions options = commonOptions.withTmpCheckpointDir(checkpointParentDir);
-    Pipeline p = Pipeline.create(options);
-    options.setJobName(testName.getMethodName());
-    Duration batchDuration = Duration.millis(options.getBatchIntervalMillis());
-
+    Pipeline p = pipelineRule.createPipeline();
     CreateStream<TimestampedValue<String>> source =
-        CreateStream.<TimestampedValue<String>>withBatchInterval(batchDuration)
+        CreateStream.<TimestampedValue<String>>withBatchInterval(pipelineRule.batchDuration())
             .nextBatch(
                 TimestampedValue.of("firstPane", new Instant(100)),
                 TimestampedValue.of("alsoFirstPane", new Instant(200)))
@@ -221,14 +205,10 @@ public class CreateStreamTest implements Serializable {
 
   @Test
   public void testFirstElementLate() throws IOException {
-    SparkPipelineOptions options = commonOptions.withTmpCheckpointDir(checkpointParentDir);
-    Pipeline p = Pipeline.create(options);
-    options.setJobName(testName.getMethodName());
-    Duration batchDuration = Duration.millis(options.getBatchIntervalMillis());
-
+    Pipeline p = pipelineRule.createPipeline();
     Instant lateElementTimestamp = new Instant(-1_000_000);
     CreateStream<TimestampedValue<String>> source =
-        CreateStream.<TimestampedValue<String>>withBatchInterval(batchDuration)
+        CreateStream.<TimestampedValue<String>>withBatchInterval(pipelineRule.batchDuration())
             .nextBatch()
             .advanceWatermarkForNextBatch(new Instant(0))
             .nextBatch(
@@ -261,14 +241,10 @@ public class CreateStreamTest implements Serializable {
 
   @Test
   public void testElementsAtAlmostPositiveInfinity() throws IOException {
-    SparkPipelineOptions options = commonOptions.withTmpCheckpointDir(checkpointParentDir);
-    Pipeline p = Pipeline.create(options);
-    options.setJobName(testName.getMethodName());
-    Duration batchDuration = Duration.millis(options.getBatchIntervalMillis());
-
+    Pipeline p = pipelineRule.createPipeline();
     Instant endOfGlobalWindow = GlobalWindow.INSTANCE.maxTimestamp();
     CreateStream<TimestampedValue<String>> source =
-        CreateStream.<TimestampedValue<String>>withBatchInterval(batchDuration)
+        CreateStream.<TimestampedValue<String>>withBatchInterval(pipelineRule.batchDuration())
             .nextBatch(
                 TimestampedValue.of("foo", endOfGlobalWindow),
                 TimestampedValue.of("bar", endOfGlobalWindow))
@@ -292,17 +268,13 @@ public class CreateStreamTest implements Serializable {
 
   @Test
   public void testMultipleStreams() throws IOException {
-    SparkPipelineOptions options = commonOptions.withTmpCheckpointDir(checkpointParentDir);
-    Pipeline p = Pipeline.create(options);
-    options.setJobName(testName.getMethodName());
-    Duration batchDuration = Duration.millis(options.getBatchIntervalMillis());
-
+    Pipeline p = pipelineRule.createPipeline();
     CreateStream<String> source =
-        CreateStream.<String>withBatchInterval(batchDuration)
+        CreateStream.<String>withBatchInterval(pipelineRule.batchDuration())
             .nextBatch("foo", "bar")
             .advanceNextBatchWatermarkToInfinity();
     CreateStream<Integer> other =
-        CreateStream.<Integer>withBatchInterval(batchDuration)
+        CreateStream.<Integer>withBatchInterval(pipelineRule.batchDuration())
             .nextBatch(1, 2, 3, 4)
             .advanceNextBatchWatermarkToInfinity();
 
@@ -327,14 +299,10 @@ public class CreateStreamTest implements Serializable {
 
   @Test
   public void testFlattenedWithWatermarkHold() throws IOException {
-    SparkPipelineOptions options = commonOptions.withTmpCheckpointDir(checkpointParentDir);
-    Pipeline p = Pipeline.create(options);
-    options.setJobName(testName.getMethodName());
-    Duration batchDuration = Duration.millis(options.getBatchIntervalMillis());
-
+    Pipeline p = pipelineRule.createPipeline();
     Instant instant = new Instant(0);
     CreateStream<TimestampedValue<Integer>> source1 =
-        CreateStream.<TimestampedValue<Integer>>withBatchInterval(batchDuration)
+        CreateStream.<TimestampedValue<Integer>>withBatchInterval(pipelineRule.batchDuration())
             .nextBatch()
             .advanceWatermarkForNextBatch(instant.plus(Duration.standardMinutes(5)))
             .nextBatch(
@@ -343,7 +311,7 @@ public class CreateStreamTest implements Serializable {
                 TimestampedValue.of(3, instant))
             .advanceWatermarkForNextBatch(instant.plus(Duration.standardMinutes(10)));
     CreateStream<TimestampedValue<Integer>> source2 =
-        CreateStream.<TimestampedValue<Integer>>withBatchInterval(batchDuration)
+        CreateStream.<TimestampedValue<Integer>>withBatchInterval(pipelineRule.batchDuration())
             .nextBatch()
             .advanceWatermarkForNextBatch(instant.plus(Duration.standardMinutes(1)))
             .nextBatch(
@@ -384,9 +352,8 @@ public class CreateStreamTest implements Serializable {
 
   @Test
   public void testElementAtPositiveInfinityThrows() {
-    Duration batchDuration = Duration.millis(commonOptions.getOptions().getBatchIntervalMillis());
     CreateStream<TimestampedValue<Integer>> source =
-        CreateStream.<TimestampedValue<Integer>>withBatchInterval(batchDuration)
+        CreateStream.<TimestampedValue<Integer>>withBatchInterval(pipelineRule.batchDuration())
             .nextBatch(TimestampedValue.of(-1, BoundedWindow.TIMESTAMP_MAX_VALUE.minus(1L)));
     thrown.expect(IllegalArgumentException.class);
     source.nextBatch(TimestampedValue.of(1, BoundedWindow.TIMESTAMP_MAX_VALUE));
@@ -394,9 +361,8 @@ public class CreateStreamTest implements Serializable {
 
   @Test
   public void testAdvanceWatermarkNonMonotonicThrows() {
-    Duration batchDuration = Duration.millis(commonOptions.getOptions().getBatchIntervalMillis());
     CreateStream<Integer> source =
-        CreateStream.<Integer>withBatchInterval(batchDuration)
+        CreateStream.<Integer>withBatchInterval(pipelineRule.batchDuration())
             .advanceWatermarkForNextBatch(new Instant(0L));
     thrown.expect(IllegalArgumentException.class);
     source.advanceWatermarkForNextBatch(new Instant(-1L));
@@ -404,9 +370,8 @@ public class CreateStreamTest implements Serializable {
 
   @Test
   public void testAdvanceWatermarkEqualToPositiveInfinityThrows() {
-    Duration batchDuration = Duration.millis(commonOptions.getOptions().getBatchIntervalMillis());
     CreateStream<Integer> source =
-        CreateStream.<Integer>withBatchInterval(batchDuration)
+        CreateStream.<Integer>withBatchInterval(pipelineRule.batchDuration())
             .advanceWatermarkForNextBatch(BoundedWindow.TIMESTAMP_MAX_VALUE.minus(1L));
     thrown.expect(IllegalArgumentException.class);
     source.advanceWatermarkForNextBatch(BoundedWindow.TIMESTAMP_MAX_VALUE);

http://git-wip-us.apache.org/repos/asf/beam/blob/4ca56806/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptions.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptions.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptions.java
deleted file mode 100644
index efc17d3..0000000
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptions.java
+++ /dev/null
@@ -1,42 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.translation.streaming.utils;
-
-import org.apache.beam.runners.spark.SparkPipelineOptions;
-import org.apache.beam.runners.spark.TestSparkRunner;
-import org.apache.beam.sdk.options.PipelineOptionsFactory;
-import org.junit.rules.ExternalResource;
-
-/**
- * A rule to create a common {@link SparkPipelineOptions} test options for spark-runner.
- */
-public class SparkTestPipelineOptions extends ExternalResource {
-
-  protected final SparkPipelineOptions options =
-      PipelineOptionsFactory.as(SparkPipelineOptions.class);
-
-  @Override
-  protected void before() throws Throwable {
-    options.setRunner(TestSparkRunner.class);
-    options.setEnableSparkMetricSinks(false);
-  }
-
-  public SparkPipelineOptions getOptions() {
-    return options;
-  }
-}

http://git-wip-us.apache.org/repos/asf/beam/blob/4ca56806/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptionsForStreaming.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptionsForStreaming.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptionsForStreaming.java
deleted file mode 100644
index dd3e4c8..0000000
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptionsForStreaming.java
+++ /dev/null
@@ -1,48 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.beam.runners.spark.translation.streaming.utils;
-
-import java.io.IOException;
-import org.apache.beam.runners.spark.SparkPipelineOptions;
-import org.joda.time.Duration;
-import org.junit.rules.TemporaryFolder;
-
-
-/**
- * A rule to create a common {@link SparkPipelineOptions} for testing streaming pipelines.
- */
-public class SparkTestPipelineOptionsForStreaming extends SparkTestPipelineOptions {
-
-  private static final int DEFAULT_NUMBER_OF_BATCHES_TIMEOUT = 5;
-
-  public SparkPipelineOptions withTmpCheckpointDir(TemporaryFolder parent)
-      throws IOException {
-    // tests use JUnit's TemporaryFolder path in the form of: /.../junit/...
-    options.setCheckpointDir(parent.newFolder(options.getJobName()).toURI().toURL().toString());
-    options.setForceStreaming(true);
-    // set the default timeout to DEFAULT_NUMBER_OF_BATCHES_TIMEOUT x batchDuration
-    // to allow pipelines to finish.
-    Duration batchDuration = Duration.millis(options.getBatchIntervalMillis());
-    // set the checkpoint duration to match interval.
-    options.setCheckpointDurationMillis(batchDuration.getMillis());
-    long forcedTimeout = batchDuration.multipliedBy(DEFAULT_NUMBER_OF_BATCHES_TIMEOUT).getMillis();
-    options.setForcedTimeout(forcedTimeout);
-    return options;
-  }
-}
