Use a PipelineRule for test pipelines.
Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/4ca56806
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/4ca56806
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/4ca56806

Branch: refs/heads/master
Commit: 4ca5680681e14c4be8b0ac1f3e1244c9d439e6d8
Parents: bd41d9a
Author: Sela <ans...@paypal.com>
Authored: Mon Feb 27 19:24:45 2017 +0200
Committer: Sela <ans...@paypal.com>
Committed: Wed Mar 1 00:18:07 2017 +0200

----------------------------------------------------------------------
 runners/spark/pom.xml                           |   1 -
 .../runners/spark/SparkPipelineOptions.java     |   9 --
 .../runners/spark/TestSparkPipelineOptions.java |  18 ++++
 .../beam/runners/spark/TestSparkRunner.java     |  28 ++---
 .../beam/runners/spark/ForceStreamingTest.java  |   8 +-
 .../apache/beam/runners/spark/PipelineRule.java | 103 +++++++++++++++++++
 .../metrics/sink/NamedAggregatorsTest.java      |   9 +-
 .../beam/runners/spark/io/AvroPipelineTest.java |   6 +-
 .../beam/runners/spark/io/NumShardsTest.java    |   6 +-
 .../io/hadoop/HadoopFileFormatPipelineTest.java |   6 +-
 .../spark/translation/StorageLevelTest.java     |   8 +-
 .../translation/streaming/CreateStreamTest.java |  77 ++++---------
 .../utils/SparkTestPipelineOptions.java         |  42 --------
 .../SparkTestPipelineOptionsForStreaming.java   |  48 ---------
 14 files changed, 175 insertions(+), 194 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/4ca56806/runners/spark/pom.xml
----------------------------------------------------------------------
diff --git a/runners/spark/pom.xml b/runners/spark/pom.xml
index f926bf5..409fc27 100644
--- a/runners/spark/pom.xml
+++ b/runners/spark/pom.xml
@@ -80,7 +80,6 @@
               org.apache.beam.sdk.testing.UsesCommittedMetrics
             </excludedGroups>
             <parallel>none</parallel>
-            <threadCount>4</threadCount>
             <forkCount>1</forkCount>
             <reuseForks>false</reuseForks>
             <failIfNoTests>true</failIfNoTests>

http://git-wip-us.apache.org/repos/asf/beam/blob/4ca56806/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java
index 52d8ce1..26b549b 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java
@@ -101,13 +101,4 @@ public interface SparkPipelineOptions
   boolean getUsesProvidedSparkContext();
   void setUsesProvidedSparkContext(boolean value);
 
-  @Description("A special flag that forces streaming in tests.")
-  @Default.Boolean(false)
-  boolean isForceStreaming();
-  void setForceStreaming(boolean forceStreaming);
-
-  @Description("A forced timeout (millis), mostly for testing.")
-  @Default.Long(3000L)
-  Long getForcedTimeout();
-  void setForcedTimeout(Long forcedTimeout);
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/4ca56806/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkPipelineOptions.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkPipelineOptions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkPipelineOptions.java
new file mode 100644
index 0000000..2cb58d8
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkPipelineOptions.java
@@ -0,0 +1,18 @@
+package org.apache.beam.runners.spark;
+
+import org.apache.beam.sdk.options.Default;
+import org.apache.beam.sdk.options.Description;
+import org.apache.beam.sdk.testing.TestPipelineOptions;
+
+
+/**
+ * A {@link SparkPipelineOptions} for tests.
+ */
+public interface TestSparkPipelineOptions extends SparkPipelineOptions, TestPipelineOptions {
+
+  @Description("A special flag that forces streaming in tests.")
+  @Default.Boolean(false)
+  boolean isForceStreaming();
+  void setForceStreaming(boolean forceStreaming);
+
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/4ca56806/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
index 5d71ea5..035da00 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
@@ -18,6 +18,7 @@
 package org.apache.beam.runners.spark;
 
+import static com.google.common.base.Preconditions.checkNotNull;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.is;
 
@@ -33,7 +34,6 @@ import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsValidator;
 import org.apache.beam.sdk.runners.PipelineRunner;
 import org.apache.beam.sdk.testing.PAssert;
-import org.apache.beam.sdk.testing.TestPipelineOptions;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.util.ValueWithRecordId;
@@ -76,15 +76,15 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult> {
   private boolean isForceStreaming;
   private int expectedNumberOfAssertions = 0;
 
-  private TestSparkRunner(SparkPipelineOptions options) {
+  private TestSparkRunner(TestSparkPipelineOptions options) {
     this.delegate = SparkRunner.fromOptions(options);
     this.isForceStreaming = options.isForceStreaming();
   }
 
   public static TestSparkRunner fromOptions(PipelineOptions options) {
     // Default options suffice to set it up as a test runner
-    SparkPipelineOptions sparkOptions =
-        PipelineOptionsValidator.validate(SparkPipelineOptions.class, options);
+    TestSparkPipelineOptions sparkOptions =
+        PipelineOptionsValidator.validate(TestSparkPipelineOptions.class, options);
     return new TestSparkRunner(sparkOptions);
   }
 
@@ -115,22 +115,22 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult> {
 
   @Override
   public SparkPipelineResult run(Pipeline pipeline) {
-    SparkPipelineOptions sparkOptions = pipeline.getOptions().as(SparkPipelineOptions.class);
+    TestSparkPipelineOptions testSparkPipelineOptions =
+        pipeline.getOptions().as(TestSparkPipelineOptions.class);
     SparkPipelineResult result = null;
 
     // clear state of Aggregators, Metrics and Watermarks.
     AggregatorsAccumulator.clear();
     SparkMetricsContainer.clear();
     GlobalWatermarkHolder.clear();
 
-    TestPipelineOptions testPipelineOptions = pipeline.getOptions().as(TestPipelineOptions.class);
-    LOG.info("About to run test pipeline " + sparkOptions.getJobName());
+    LOG.info("About to run test pipeline " + testSparkPipelineOptions.getJobName());
 
     // if the pipeline was executed in streaming mode, validate aggregators.
     if (isForceStreaming) {
      try {
         result = delegate.run(pipeline);
-        long timeout = sparkOptions.getForcedTimeout();
-        result.waitUntilFinish(Duration.millis(timeout));
+        Long timeout = testSparkPipelineOptions.getTestTimeoutSeconds();
+        result.waitUntilFinish(Duration.standardSeconds(checkNotNull(timeout)));
         // validate assertion succeeded (at least once).
         int successAssertions = result.getAggregatorValue(PAssert.SUCCESS_COUNTER, Integer.class);
         assertThat(
@@ -149,12 +149,12 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult> {
         LOG.info(
             String.format(
                 "Successfully asserted pipeline %s with %d successful assertions.",
-                sparkOptions.getJobName(),
+                testSparkPipelineOptions.getJobName(),
                 successAssertions));
-       } finally {
+      } finally {
         try {
           // cleanup checkpoint dir.
-          FileUtils.deleteDirectory(new File(sparkOptions.getCheckpointDir()));
+          FileUtils.deleteDirectory(new File(testSparkPipelineOptions.getCheckpointDir()));
         } catch (IOException e) {
           throw new RuntimeException("Failed to clear checkpoint tmp dir.", e);
         }
@@ -164,8 +164,8 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult> {
       result = delegate.run(pipeline);
       result.waitUntilFinish();
       // assert via matchers.
-      assertThat(result, testPipelineOptions.getOnCreateMatcher());
-      assertThat(result, testPipelineOptions.getOnSuccessMatcher());
+      assertThat(result, testSparkPipelineOptions.getOnCreateMatcher());
+      assertThat(result, testSparkPipelineOptions.getOnSuccessMatcher());
     }
     return result;
   }

http://git-wip-us.apache.org/repos/asf/beam/blob/4ca56806/runners/spark/src/test/java/org/apache/beam/runners/spark/ForceStreamingTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/ForceStreamingTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/ForceStreamingTest.java
index c3026ce..9b39558 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/ForceStreamingTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/ForceStreamingTest.java
@@ -21,7 +21,6 @@ package org.apache.beam.runners.spark;
 import static org.hamcrest.MatcherAssert.assertThat;
 
 import java.io.IOException;
-import org.apache.beam.runners.spark.translation.streaming.utils.SparkTestPipelineOptionsForStreaming;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.io.BoundedReadFromUnboundedSource;
 import org.apache.beam.sdk.io.CountingSource;
@@ -46,14 +45,11 @@ import org.junit.Test;
 public class ForceStreamingTest {
 
   @Rule
-  public SparkTestPipelineOptionsForStreaming commonOptions =
-      new SparkTestPipelineOptionsForStreaming();
+  public final PipelineRule pipelineRule = PipelineRule.streaming();
 
   @Test
   public void test() throws IOException {
-    SparkPipelineOptions options = commonOptions.getOptions();
-    options.setForceStreaming(true);
-    Pipeline pipeline = Pipeline.create(options);
+    Pipeline pipeline = pipelineRule.createPipeline();
 
     // apply the BoundedReadFromUnboundedSource.
     BoundedReadFromUnboundedSource<?> boundedRead =

http://git-wip-us.apache.org/repos/asf/beam/blob/4ca56806/runners/spark/src/test/java/org/apache/beam/runners/spark/PipelineRule.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/PipelineRule.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/PipelineRule.java
new file mode 100644
index 0000000..bb42510
--- /dev/null
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/PipelineRule.java
@@ -0,0 +1,103 @@
+package org.apache.beam.runners.spark;
+
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.joda.time.Duration;
+import org.junit.rules.ExternalResource;
+import org.junit.rules.RuleChain;
+import org.junit.rules.TemporaryFolder;
+import org.junit.rules.TestName;
+import org.junit.rules.TestRule;
+import org.junit.runner.Description;
+import org.junit.runners.model.Statement;
+
+/**
+ * A {@link org.junit.Rule} to provide a {@link Pipeline} instance for Spark runner tests.
+ */
+public class PipelineRule implements TestRule {
+
+  private final TestName testName = new TestName();
+
+  private final SparkPipelineRule delegate;
+  private final RuleChain chain;
+
+  private PipelineRule() {
+    this.delegate = new SparkPipelineRule(testName);
+    this.chain = RuleChain.outerRule(testName).around(this.delegate);
+  }
+
+  private PipelineRule(Duration forcedTimeout) {
+    this.delegate = new SparkStreamingPipelineRule(forcedTimeout, testName);
+    this.chain = RuleChain.outerRule(testName).around(this.delegate);
+  }
+
+  public static PipelineRule streaming() {
+    return new PipelineRule(Duration.standardSeconds(5));
+  }
+
+  public static PipelineRule batch() {
+    return new PipelineRule();
+  }
+
+  public Duration batchDuration() {
+    return Duration.millis(delegate.options.getBatchIntervalMillis());
+  }
+
+  public SparkPipelineOptions getOptions() {
+    return delegate.options;
+  }
+
+  public Pipeline createPipeline() {
+    return Pipeline.create(delegate.options);
+  }
+
+  @Override
+  public Statement apply(Statement statement, Description description) {
+    return chain.apply(statement, description);
+  }
+
+  private static class SparkStreamingPipelineRule extends SparkPipelineRule {
+
+    private final TemporaryFolder temporaryFolder = new TemporaryFolder();
+    private final Duration forcedTimeout;
+
+    SparkStreamingPipelineRule(Duration forcedTimeout, TestName testName) {
+      super(testName);
+      this.forcedTimeout = forcedTimeout;
+    }
+
+    @Override
+    protected void before() throws Throwable {
+      super.before();
+      temporaryFolder.create();
+      options.setForceStreaming(true);
+      options.setTestTimeoutSeconds(forcedTimeout.getStandardSeconds());
+      options.setCheckpointDir(
+          temporaryFolder.newFolder(options.getJobName()).toURI().toURL().toString());
+    }
+
+    @Override
+    protected void after() {
+      temporaryFolder.delete();
+    }
+  }
+
+  private static class SparkPipelineRule extends ExternalResource {
+
+    protected final TestSparkPipelineOptions options =
+        PipelineOptionsFactory.as(TestSparkPipelineOptions.class);
+
+    private final TestName testName;
+
+    private SparkPipelineRule(TestName testName) {
+      this.testName = testName;
+    }
+
+    @Override
+    protected void before() throws Throwable {
+      options.setRunner(TestSparkRunner.class);
+      options.setEnableSparkMetricSinks(false);
+      options.setJobName(testName.getMethodName());
+    }
+  }
+}
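For context, a minimal batch test built on the new rule might look like the sketch below. It is not part of the commit: the test class and its body are illustrative, while PipelineRule, createPipeline(), getOptions() and batchDuration() come from the class above, and PAssert, Create and Count are standard Beam SDK utilities.

    package org.apache.beam.runners.spark;

    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.testing.PAssert;
    import org.apache.beam.sdk.transforms.Count;
    import org.apache.beam.sdk.transforms.Create;
    import org.apache.beam.sdk.values.PCollection;
    import org.junit.Rule;
    import org.junit.Test;

    /** Hypothetical sketch of a test consuming the new PipelineRule. */
    public class PipelineRuleUsageTest {

      // batch() wires TestSparkPipelineOptions to the TestSparkRunner and sets
      // the job name from the running test method; streaming() additionally
      // forces streaming, provisions a temp checkpoint dir and a 5s timeout.
      @Rule
      public final PipelineRule pipelineRule = PipelineRule.batch();

      @Test
      public void testCount() {
        // Options can still be customized before the pipeline is created.
        pipelineRule.getOptions().setStorageLevel("MEMORY_ONLY");

        Pipeline p = pipelineRule.createPipeline();
        PCollection<Long> count =
            p.apply(Create.of("a", "b", "c")).apply(Count.<String>globally());
        PAssert.thatSingleton(count).isEqualTo(3L);
        p.run().waitUntilFinish();
      }
    }

Composing TestName into a RuleChain is what lets the rule derive the job name from the test method, which each test previously had to do by hand via options.setJobName(testName.getMethodName()).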
http://git-wip-us.apache.org/repos/asf/beam/blob/4ca56806/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/metrics/sink/NamedAggregatorsTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/metrics/sink/NamedAggregatorsTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/metrics/sink/NamedAggregatorsTest.java
index 2f7202c..a192807 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/metrics/sink/NamedAggregatorsTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/aggregators/metrics/sink/NamedAggregatorsTest.java
@@ -26,12 +26,12 @@ import com.google.common.collect.ImmutableSet;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Set;
+import org.apache.beam.runners.spark.PipelineRule;
 import org.apache.beam.runners.spark.SparkPipelineOptions;
 import org.apache.beam.runners.spark.aggregators.ClearAggregatorsRule;
 import org.apache.beam.runners.spark.aggregators.SparkAggregators;
 import org.apache.beam.runners.spark.examples.WordCount;
 import org.apache.beam.runners.spark.translation.SparkContextFactory;
-import org.apache.beam.runners.spark.translation.streaming.utils.SparkTestPipelineOptions;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
@@ -56,12 +56,11 @@ public class NamedAggregatorsTest {
   public ClearAggregatorsRule clearAggregators = new ClearAggregatorsRule();
 
   @Rule
-  public final SparkTestPipelineOptions pipelineOptions = new SparkTestPipelineOptions();
+  public final PipelineRule pipelineRule = PipelineRule.batch();
 
   private Pipeline createSparkPipeline() {
-    SparkPipelineOptions options = pipelineOptions.getOptions();
-    options.setEnableSparkMetricSinks(true);
-    return Pipeline.create(options);
+    pipelineRule.getOptions().setEnableSparkMetricSinks(true);
+    return pipelineRule.createPipeline();
   }
 
   private void runPipeline() {

http://git-wip-us.apache.org/repos/asf/beam/blob/4ca56806/runners/spark/src/test/java/org/apache/beam/runners/spark/io/AvroPipelineTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/io/AvroPipelineTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/io/AvroPipelineTest.java
index c5bb583..2a73c28 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/io/AvroPipelineTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/io/AvroPipelineTest.java
@@ -33,7 +33,7 @@ import org.apache.avro.generic.GenericData;
 import org.apache.avro.generic.GenericDatumReader;
 import org.apache.avro.generic.GenericDatumWriter;
 import org.apache.avro.generic.GenericRecord;
-import org.apache.beam.runners.spark.translation.streaming.utils.SparkTestPipelineOptions;
+import org.apache.beam.runners.spark.PipelineRule;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.io.AvroIO;
 import org.apache.beam.sdk.values.PCollection;
@@ -54,7 +54,7 @@ public class AvroPipelineTest {
   public final TemporaryFolder tmpDir = new TemporaryFolder();
 
   @Rule
-  public final SparkTestPipelineOptions pipelineOptions = new SparkTestPipelineOptions();
+  public final PipelineRule pipelineRule = PipelineRule.batch();
 
   @Before
   public void setUp() throws IOException {
@@ -72,7 +72,7 @@ public class AvroPipelineTest {
     savedRecord.put("siblingnames", Lists.newArrayList("Jimmy", "Jane"));
     populateGenericFile(Lists.newArrayList(savedRecord), schema);
 
-    Pipeline p = Pipeline.create(pipelineOptions.getOptions());
+    Pipeline p = pipelineRule.createPipeline();
     PCollection<GenericRecord> input = p.apply(
         AvroIO.Read.from(inputFile.getAbsolutePath()).withSchema(schema));
     input.apply(AvroIO.Write.to(outputDir.getAbsolutePath()).withSchema(schema));

http://git-wip-us.apache.org/repos/asf/beam/blob/4ca56806/runners/spark/src/test/java/org/apache/beam/runners/spark/io/NumShardsTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/io/NumShardsTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/io/NumShardsTest.java
index 34d6818..c936ed3 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/io/NumShardsTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/io/NumShardsTest.java
@@ -30,8 +30,8 @@ import java.io.IOException;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Set;
+import org.apache.beam.runners.spark.PipelineRule;
 import org.apache.beam.runners.spark.examples.WordCount;
-import org.apache.beam.runners.spark.translation.streaming.utils.SparkTestPipelineOptions;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.io.TextIO;
@@ -59,7 +59,7 @@ public class NumShardsTest {
   public final TemporaryFolder tmpDir = new TemporaryFolder();
 
   @Rule
-  public final SparkTestPipelineOptions pipelineOptions = new SparkTestPipelineOptions();
+  public final PipelineRule pipelineRule = PipelineRule.batch();
 
   @Before
   public void setUp() throws IOException {
@@ -69,7 +69,7 @@ public class NumShardsTest {
 
   @Test
   public void testText() throws Exception {
-    Pipeline p = Pipeline.create(pipelineOptions.getOptions());
+    Pipeline p = pipelineRule.createPipeline();
     PCollection<String> inputWords = p.apply(Create.of(WORDS).withCoder(StringUtf8Coder.of()));
     PCollection<String> output = inputWords.apply(new WordCount.CountWords())
         .apply(MapElements.via(new WordCount.FormatAsTextFn()));

http://git-wip-us.apache.org/repos/asf/beam/blob/4ca56806/runners/spark/src/test/java/org/apache/beam/runners/spark/io/hadoop/HadoopFileFormatPipelineTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/io/hadoop/HadoopFileFormatPipelineTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/io/hadoop/HadoopFileFormatPipelineTest.java
index 9efc670..a5072d6 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/io/hadoop/HadoopFileFormatPipelineTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/io/hadoop/HadoopFileFormatPipelineTest.java
@@ -22,8 +22,8 @@ import static org.junit.Assert.assertEquals;
 
 import java.io.File;
 import java.io.IOException;
+import org.apache.beam.runners.spark.PipelineRule;
 import org.apache.beam.runners.spark.coders.WritableCoder;
-import org.apache.beam.runners.spark.translation.streaming.utils.SparkTestPipelineOptions;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.values.KV;
@@ -52,7 +52,7 @@ public class HadoopFileFormatPipelineTest {
   private File outputFile;
 
   @Rule
-  public final SparkTestPipelineOptions pipelineOptions = new SparkTestPipelineOptions();
+  public final PipelineRule pipelineRule = PipelineRule.batch();
 
   @Rule
   public final TemporaryFolder tmpDir = new TemporaryFolder();
@@ -68,7 +68,7 @@ public class HadoopFileFormatPipelineTest {
   public void testSequenceFile() throws Exception {
     populateFile();
 
-    Pipeline p = Pipeline.create(pipelineOptions.getOptions());
+    Pipeline p = pipelineRule.createPipeline();
     @SuppressWarnings("unchecked")
     Class<? extends FileInputFormat<IntWritable, Text>> inputFormatClass =
         (Class<? extends FileInputFormat<IntWritable, Text>>)

http://git-wip-us.apache.org/repos/asf/beam/blob/4ca56806/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/StorageLevelTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/StorageLevelTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/StorageLevelTest.java
index 48105e1..4dc5dee 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/StorageLevelTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/StorageLevelTest.java
@@ -17,7 +17,7 @@
  */
 package org.apache.beam.runners.spark.translation;
 
-import org.apache.beam.runners.spark.translation.streaming.utils.SparkTestPipelineOptions;
+import org.apache.beam.runners.spark.PipelineRule;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.transforms.Count;
@@ -32,12 +32,12 @@ import org.junit.Test;
 public class StorageLevelTest {
 
   @Rule
-  public final transient SparkTestPipelineOptions pipelineOptions = new SparkTestPipelineOptions();
+  public final transient PipelineRule pipelineRule = PipelineRule.batch();
 
   @Test
   public void test() throws Exception {
-    pipelineOptions.getOptions().setStorageLevel("DISK_ONLY");
-    Pipeline p = Pipeline.create(pipelineOptions.getOptions());
+    pipelineRule.getOptions().setStorageLevel("DISK_ONLY");
+    Pipeline p = pipelineRule.createPipeline();
 
     PCollection<String> pCollection = p.apply(Create.of("foo"));

http://git-wip-us.apache.org/repos/asf/beam/blob/4ca56806/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java
index 9ee5cc5..f2783a1 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/CreateStreamTest.java
@@ -24,10 +24,9 @@ import static org.junit.Assert.assertThat;
 
 import java.io.IOException;
 import java.io.Serializable;
+import org.apache.beam.runners.spark.PipelineRule;
 import org.apache.beam.runners.spark.ReuseSparkContextRule;
-import org.apache.beam.runners.spark.SparkPipelineOptions;
 import org.apache.beam.runners.spark.io.CreateStream;
-import org.apache.beam.runners.spark.translation.streaming.utils.SparkTestPipelineOptionsForStreaming;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.coders.VarIntCoder;
@@ -59,8 +58,6 @@ import org.joda.time.Instant;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
-import org.junit.rules.TemporaryFolder;
-import org.junit.rules.TestName;
 
 /**
@@ -76,27 +73,18 @@ import org.junit.rules.TestName;
 public class CreateStreamTest implements Serializable {
 
   @Rule
-  public transient TemporaryFolder checkpointParentDir = new TemporaryFolder();
+  public final transient PipelineRule pipelineRule = PipelineRule.streaming();
   @Rule
-  public transient SparkTestPipelineOptionsForStreaming commonOptions =
-      new SparkTestPipelineOptionsForStreaming();
+  public final transient ReuseSparkContextRule noContextResue = ReuseSparkContextRule.no();
   @Rule
-  public transient ReuseSparkContextRule noContextResue = ReuseSparkContextRule.no();
-  @Rule
-  public transient TestName testName = new TestName();
-  @Rule
-  public transient ExpectedException thrown = ExpectedException.none();
+  public final transient ExpectedException thrown = ExpectedException.none();
 
   @Test
   public void testLateDataAccumulating() throws IOException {
-    SparkPipelineOptions options = commonOptions.withTmpCheckpointDir(checkpointParentDir);
-    Pipeline p = Pipeline.create(options);
-    options.setJobName(testName.getMethodName());
-    Duration batchDuration = Duration.millis(options.getBatchIntervalMillis());
-
+    Pipeline p = pipelineRule.createPipeline();
     Instant instant = new Instant(0);
     CreateStream<TimestampedValue<Integer>> source =
-        CreateStream.<TimestampedValue<Integer>>withBatchInterval(batchDuration)
+        CreateStream.<TimestampedValue<Integer>>withBatchInterval(pipelineRule.batchDuration())
            .nextBatch()
             .advanceWatermarkForNextBatch(instant.plus(Duration.standardMinutes(6)))
             .nextBatch(
@@ -167,13 +155,9 @@ public class CreateStreamTest implements Serializable {
 
   @Test
   public void testDiscardingMode() throws IOException {
-    SparkPipelineOptions options = commonOptions.withTmpCheckpointDir(checkpointParentDir);
-    Pipeline p = Pipeline.create(options);
-    options.setJobName(testName.getMethodName());
-    Duration batchDuration = Duration.millis(options.getBatchIntervalMillis());
-
+    Pipeline p = pipelineRule.createPipeline();
     CreateStream<TimestampedValue<String>> source =
-        CreateStream.<TimestampedValue<String>>withBatchInterval(batchDuration)
+        CreateStream.<TimestampedValue<String>>withBatchInterval(pipelineRule.batchDuration())
             .nextBatch(
                 TimestampedValue.of("firstPane", new Instant(100)),
                 TimestampedValue.of("alsoFirstPane", new Instant(200)))
@@ -221,14 +205,10 @@ public class CreateStreamTest implements Serializable {
 
   @Test
   public void testFirstElementLate() throws IOException {
-    SparkPipelineOptions options = commonOptions.withTmpCheckpointDir(checkpointParentDir);
-    Pipeline p = Pipeline.create(options);
-    options.setJobName(testName.getMethodName());
-    Duration batchDuration = Duration.millis(options.getBatchIntervalMillis());
-
+    Pipeline p = pipelineRule.createPipeline();
     Instant lateElementTimestamp = new Instant(-1_000_000);
     CreateStream<TimestampedValue<String>> source =
-        CreateStream.<TimestampedValue<String>>withBatchInterval(batchDuration)
+        CreateStream.<TimestampedValue<String>>withBatchInterval(pipelineRule.batchDuration())
             .nextBatch()
             .advanceWatermarkForNextBatch(new Instant(0))
             .nextBatch(
@@ -261,14 +241,10 @@ public class CreateStreamTest implements Serializable {
 
   @Test
   public void testElementsAtAlmostPositiveInfinity() throws IOException {
-    SparkPipelineOptions options = commonOptions.withTmpCheckpointDir(checkpointParentDir);
-    Pipeline p = Pipeline.create(options);
-    options.setJobName(testName.getMethodName());
-    Duration batchDuration = Duration.millis(options.getBatchIntervalMillis());
-
+    Pipeline p = pipelineRule.createPipeline();
     Instant endOfGlobalWindow = GlobalWindow.INSTANCE.maxTimestamp();
     CreateStream<TimestampedValue<String>> source =
-        CreateStream.<TimestampedValue<String>>withBatchInterval(batchDuration)
+        CreateStream.<TimestampedValue<String>>withBatchInterval(pipelineRule.batchDuration())
             .nextBatch(
                 TimestampedValue.of("foo", endOfGlobalWindow),
                 TimestampedValue.of("bar", endOfGlobalWindow))
@@ -292,17 +268,13 @@ public class CreateStreamTest implements Serializable {
 
   @Test
   public void testMultipleStreams() throws IOException {
-    SparkPipelineOptions options = commonOptions.withTmpCheckpointDir(checkpointParentDir);
-    Pipeline p = Pipeline.create(options);
-    options.setJobName(testName.getMethodName());
-    Duration batchDuration = Duration.millis(options.getBatchIntervalMillis());
-
+    Pipeline p = pipelineRule.createPipeline();
     CreateStream<String> source =
-        CreateStream.<String>withBatchInterval(batchDuration)
+        CreateStream.<String>withBatchInterval(pipelineRule.batchDuration())
            .nextBatch("foo", "bar")
             .advanceNextBatchWatermarkToInfinity();
     CreateStream<Integer> other =
-        CreateStream.<Integer>withBatchInterval(batchDuration)
+        CreateStream.<Integer>withBatchInterval(pipelineRule.batchDuration())
             .nextBatch(1, 2, 3, 4)
             .advanceNextBatchWatermarkToInfinity();
 
@@ -327,14 +299,10 @@ public class CreateStreamTest implements Serializable {
 
   @Test
   public void testFlattenedWithWatermarkHold() throws IOException {
-    SparkPipelineOptions options = commonOptions.withTmpCheckpointDir(checkpointParentDir);
-    Pipeline p = Pipeline.create(options);
-    options.setJobName(testName.getMethodName());
-    Duration batchDuration = Duration.millis(options.getBatchIntervalMillis());
-
+    Pipeline p = pipelineRule.createPipeline();
     Instant instant = new Instant(0);
     CreateStream<TimestampedValue<Integer>> source1 =
-        CreateStream.<TimestampedValue<Integer>>withBatchInterval(batchDuration)
+        CreateStream.<TimestampedValue<Integer>>withBatchInterval(pipelineRule.batchDuration())
             .nextBatch()
             .advanceWatermarkForNextBatch(instant.plus(Duration.standardMinutes(5)))
             .nextBatch(
@@ -343,7 +311,7 @@ public class CreateStreamTest implements Serializable {
                 TimestampedValue.of(3, instant))
             .advanceWatermarkForNextBatch(instant.plus(Duration.standardMinutes(10)));
     CreateStream<TimestampedValue<Integer>> source2 =
-        CreateStream.<TimestampedValue<Integer>>withBatchInterval(batchDuration)
+        CreateStream.<TimestampedValue<Integer>>withBatchInterval(pipelineRule.batchDuration())
             .nextBatch()
             .advanceWatermarkForNextBatch(instant.plus(Duration.standardMinutes(1)))
             .nextBatch(
@@ -384,9 +352,8 @@ public class CreateStreamTest implements Serializable {
 
   @Test
   public void testElementAtPositiveInfinityThrows() {
-    Duration batchDuration = Duration.millis(commonOptions.getOptions().getBatchIntervalMillis());
     CreateStream<TimestampedValue<Integer>> source =
-        CreateStream.<TimestampedValue<Integer>>withBatchInterval(batchDuration)
+        CreateStream.<TimestampedValue<Integer>>withBatchInterval(pipelineRule.batchDuration())
            .nextBatch(TimestampedValue.of(-1, BoundedWindow.TIMESTAMP_MAX_VALUE.minus(1L)));
     thrown.expect(IllegalArgumentException.class);
     source.nextBatch(TimestampedValue.of(1, BoundedWindow.TIMESTAMP_MAX_VALUE));
@@ -394,9 +361,8 @@ public class CreateStreamTest implements Serializable {
 
   @Test
   public void testAdvanceWatermarkNonMonotonicThrows() {
-    Duration batchDuration = Duration.millis(commonOptions.getOptions().getBatchIntervalMillis());
     CreateStream<Integer> source =
-        CreateStream.<Integer>withBatchInterval(batchDuration)
+        CreateStream.<Integer>withBatchInterval(pipelineRule.batchDuration())
             .advanceWatermarkForNextBatch(new Instant(0L));
     thrown.expect(IllegalArgumentException.class);
     source.advanceWatermarkForNextBatch(new Instant(-1L));
@@ -404,9 +370,8 @@ public class CreateStreamTest implements Serializable {
 
   @Test
   public void testAdvanceWatermarkEqualToPositiveInfinityThrows() {
-    Duration batchDuration = Duration.millis(commonOptions.getOptions().getBatchIntervalMillis());
     CreateStream<Integer> source =
-        CreateStream.<Integer>withBatchInterval(batchDuration)
+        CreateStream.<Integer>withBatchInterval(pipelineRule.batchDuration())
             .advanceWatermarkForNextBatch(BoundedWindow.TIMESTAMP_MAX_VALUE.minus(1L));
     thrown.expect(IllegalArgumentException.class);
     source.advanceWatermarkForNextBatch(BoundedWindow.TIMESTAMP_MAX_VALUE);

http://git-wip-us.apache.org/repos/asf/beam/blob/4ca56806/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptions.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptions.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptions.java
deleted file mode 100644
index efc17d3..0000000
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptions.java
+++ /dev/null
@@ -1,42 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.translation.streaming.utils;
-
-import org.apache.beam.runners.spark.SparkPipelineOptions;
-import org.apache.beam.runners.spark.TestSparkRunner;
-import org.apache.beam.sdk.options.PipelineOptionsFactory;
-import org.junit.rules.ExternalResource;
-
-/**
- * A rule to create a common {@link SparkPipelineOptions} test options for spark-runner.
- */
-public class SparkTestPipelineOptions extends ExternalResource {
-
-  protected final SparkPipelineOptions options =
-      PipelineOptionsFactory.as(SparkPipelineOptions.class);
-
-  @Override
-  protected void before() throws Throwable {
-    options.setRunner(TestSparkRunner.class);
-    options.setEnableSparkMetricSinks(false);
-  }
-
-  public SparkPipelineOptions getOptions() {
-    return options;
-  }
-}

http://git-wip-us.apache.org/repos/asf/beam/blob/4ca56806/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptionsForStreaming.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptionsForStreaming.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptionsForStreaming.java
deleted file mode 100644
index dd3e4c8..0000000
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptionsForStreaming.java
+++ /dev/null
@@ -1,48 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.beam.runners.spark.translation.streaming.utils;
-
-import java.io.IOException;
-import org.apache.beam.runners.spark.SparkPipelineOptions;
-import org.joda.time.Duration;
-import org.junit.rules.TemporaryFolder;
-
-
-/**
- * A rule to create a common {@link SparkPipelineOptions} for testing streaming pipelines.
- */
-public class SparkTestPipelineOptionsForStreaming extends SparkTestPipelineOptions {
-
-  private static final int DEFAULT_NUMBER_OF_BATCHES_TIMEOUT = 5;
-
-  public SparkPipelineOptions withTmpCheckpointDir(TemporaryFolder parent)
-      throws IOException {
-    // tests use JUnit's TemporaryFolder path in the form of: /.../junit/...
-    options.setCheckpointDir(parent.newFolder(options.getJobName()).toURI().toURL().toString());
-    options.setForceStreaming(true);
-    // set the default timeout to DEFAULT_NUMBER_OF_BATCHES_TIMEOUT x batchDuration
-    // to allow pipelines to finish.
-    Duration batchDuration = Duration.millis(options.getBatchIntervalMillis());
-    // set the checkpoint duration to match interval.
-    options.setCheckpointDurationMillis(batchDuration.getMillis());
-    long forcedTimeout = batchDuration.multipliedBy(DEFAULT_NUMBER_OF_BATCHES_TIMEOUT).getMillis();
-    options.setForcedTimeout(forcedTimeout);
-    return options;
-  }
-}
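For comparison, a streaming test that previously combined TemporaryFolder, TestName and SparkTestPipelineOptionsForStreaming.withTmpCheckpointDir(...) now needs only the single rule. A hypothetical sketch (the test class and pipeline body are illustrative; the CreateStream usage mirrors CreateStreamTest above):

    package org.apache.beam.runners.spark;

    import org.apache.beam.runners.spark.io.CreateStream;
    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.coders.StringUtf8Coder;
    import org.apache.beam.sdk.testing.PAssert;
    import org.apache.beam.sdk.values.PCollection;
    import org.junit.Rule;
    import org.junit.Test;

    /** Hypothetical sketch of a streaming test built on PipelineRule.streaming(). */
    public class StreamingPipelineRuleUsageTest {

      // streaming() forces streaming mode, provisions a per-test checkpoint dir
      // in a TemporaryFolder, and sets the test timeout to 5 seconds.
      @Rule
      public final transient PipelineRule pipelineRule = PipelineRule.streaming();

      @Test
      public void testSingleBatch() {
        Pipeline p = pipelineRule.createPipeline();
        // batchDuration() exposes the configured batch interval to the source.
        CreateStream<String> source =
            CreateStream.<String>withBatchInterval(pipelineRule.batchDuration())
                .nextBatch("hello", "world")
                .advanceNextBatchWatermarkToInfinity();
        PCollection<String> output = p.apply(source).setCoder(StringUtf8Coder.of());
        PAssert.that(output).containsInAnyOrder("hello", "world");
        // TestSparkRunner waits on the forced timeout and validates assertions.
        p.run();
      }
    }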