Use TestSparkRunner in tests.
Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/2bcd40c2 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/2bcd40c2 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/2bcd40c2 Branch: refs/heads/master Commit: 2bcd40c206d6b615135375276b249364fedf32af Parents: c25a02f Author: Sela <ans...@paypal.com> Authored: Sat Feb 18 22:10:50 2017 +0200 Committer: Sela <ans...@paypal.com> Committed: Wed Mar 1 00:17:59 2017 +0200 ---------------------------------------------------------------------- .../beam/runners/spark/TestSparkRunner.java | 92 ++++++++++++-------- .../utils/SparkTestPipelineOptions.java | 4 +- .../SparkTestPipelineOptionsForStreaming.java | 11 +++ 3 files changed, 69 insertions(+), 38 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/2bcd40c2/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java index 8b8f9ba..24bc038 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java @@ -21,12 +21,12 @@ package org.apache.beam.runners.spark; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.is; +import java.io.File; +import java.io.IOException; import org.apache.beam.runners.core.UnboundedReadFromBoundedSource; -import org.apache.beam.runners.spark.aggregators.AggregatorsAccumulator; -import org.apache.beam.runners.spark.metrics.SparkMetricsContainer; +import org.apache.beam.runners.spark.aggregators.AccumulatorSingleton; import org.apache.beam.runners.spark.util.GlobalWatermarkHolder; import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.PipelineResult.State; import org.apache.beam.sdk.io.BoundedReadFromUnboundedSource; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsValidator; @@ -40,6 +40,10 @@ import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.POutput; +import org.apache.commons.io.FileUtils; +import org.joda.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** @@ -65,6 +69,8 @@ import org.apache.beam.sdk.values.POutput; */ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult> { + private static final Logger LOG = LoggerFactory.getLogger(TestSparkRunner.class); + private SparkRunner delegate; private boolean isForceStreaming; private int expectedNumberOfAssertions = 0; @@ -87,7 +93,7 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult> { @SuppressWarnings("unchecked") @Override public <OutputT extends POutput, InputT extends PInput> OutputT apply( - PTransform<InputT, OutputT> transform, InputT input) { + PTransform<InputT, OutputT> transform, InputT input) { // if the pipeline forces execution as a streaming pipeline, // and the source is an adapted unbounded source (as bounded), // read it as unbounded source via UnboundedReadFromBoundedSource. @@ -108,38 +114,52 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult> { @Override public SparkPipelineResult run(Pipeline pipeline) { - // clear state of Aggregators, Metrics and Watermarks if exists. - AggregatorsAccumulator.clear(); - SparkMetricsContainer.clear(); - GlobalWatermarkHolder.clear(); - - TestPipelineOptions testPipelineOptions = pipeline.getOptions().as(TestPipelineOptions.class); - SparkPipelineResult result = delegate.run(pipeline); - result.waitUntilFinish(); - - - // make sure the test pipeline finished successfully. - State resultState = result.getState(); - assertThat( - String.format("Test pipeline result state was %s instead of %s", resultState, State.DONE), - resultState, - is(State.DONE)); - assertThat(result, testPipelineOptions.getOnCreateMatcher()); - assertThat(result, testPipelineOptions.getOnSuccessMatcher()); - - // if the pipeline was executed in streaming mode, validate aggregators. - if (isForceStreaming) { - // validate assertion succeeded (at least once). - int success = result.getAggregatorValue(PAssert.SUCCESS_COUNTER, Integer.class); - assertThat( - String.format( - "Expected %d successful assertions, but found %d.", - expectedNumberOfAssertions, success), - success, - is(expectedNumberOfAssertions)); - // validate assertion didn't fail. - int failure = result.getAggregatorValue(PAssert.FAILURE_COUNTER, Integer.class); - assertThat("Failure aggregator should be zero.", failure, is(0)); + SparkPipelineOptions sparkOptions = pipeline.getOptions().as(SparkPipelineOptions.class); + long timeout = sparkOptions.getForcedTimeout(); + SparkPipelineResult result = null; + try { + // clear state of Accumulators and Aggregators. + AccumulatorSingleton.clear(); + GlobalWatermarkHolder.clear(); + + TestPipelineOptions testPipelineOptions = pipeline.getOptions().as(TestPipelineOptions.class); + LOG.info("About to run test pipeline " + sparkOptions.getJobName()); + result = delegate.run(pipeline); + result.waitUntilFinish(Duration.millis(timeout)); + + assertThat(result, testPipelineOptions.getOnCreateMatcher()); + assertThat(result, testPipelineOptions.getOnSuccessMatcher()); + + // if the pipeline was executed in streaming mode, validate aggregators. + if (isForceStreaming) { + // validate assertion succeeded (at least once). + int successAssertions = result.getAggregatorValue(PAssert.SUCCESS_COUNTER, Integer.class); + assertThat( + String.format( + "Expected %d successful assertions, but found %d.", + expectedNumberOfAssertions, successAssertions), + successAssertions, + is(expectedNumberOfAssertions)); + // validate assertion didn't fail. + int failedAssertions = result.getAggregatorValue(PAssert.FAILURE_COUNTER, Integer.class); + assertThat( + String.format("Found %d failed assertions.", failedAssertions), + failedAssertions, + is(0)); + + LOG.info( + String.format( + "Successfully asserted pipeline %s with %d successful assertions.", + sparkOptions.getJobName(), + successAssertions)); + } + } finally { + try { + // cleanup checkpoint dir. + FileUtils.deleteDirectory(new File(sparkOptions.getCheckpointDir())); + } catch (IOException e) { + throw new RuntimeException("Failed to clear checkpoint tmp dir.", e); + } } return result; } http://git-wip-us.apache.org/repos/asf/beam/blob/2bcd40c2/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptions.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptions.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptions.java index 2da9888..efc17d3 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptions.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptions.java @@ -18,7 +18,7 @@ package org.apache.beam.runners.spark.translation.streaming.utils; import org.apache.beam.runners.spark.SparkPipelineOptions; -import org.apache.beam.runners.spark.SparkRunner; +import org.apache.beam.runners.spark.TestSparkRunner; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.junit.rules.ExternalResource; @@ -32,7 +32,7 @@ public class SparkTestPipelineOptions extends ExternalResource { @Override protected void before() throws Throwable { - options.setRunner(SparkRunner.class); + options.setRunner(TestSparkRunner.class); options.setEnableSparkMetricSinks(false); } http://git-wip-us.apache.org/repos/asf/beam/blob/2bcd40c2/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptionsForStreaming.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptionsForStreaming.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptionsForStreaming.java index 28f6d5d..dd3e4c8 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptionsForStreaming.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/SparkTestPipelineOptionsForStreaming.java @@ -20,6 +20,7 @@ package org.apache.beam.runners.spark.translation.streaming.utils; import java.io.IOException; import org.apache.beam.runners.spark.SparkPipelineOptions; +import org.joda.time.Duration; import org.junit.rules.TemporaryFolder; @@ -28,10 +29,20 @@ import org.junit.rules.TemporaryFolder; */ public class SparkTestPipelineOptionsForStreaming extends SparkTestPipelineOptions { + private static final int DEFAULT_NUMBER_OF_BATCHES_TIMEOUT = 5; + public SparkPipelineOptions withTmpCheckpointDir(TemporaryFolder parent) throws IOException { // tests use JUnit's TemporaryFolder path in the form of: /.../junit/... options.setCheckpointDir(parent.newFolder(options.getJobName()).toURI().toURL().toString()); + options.setForceStreaming(true); + // set the default timeout to DEFAULT_NUMBER_OF_BATCHES_TIMEOUT x batchDuration + // to allow pipelines to finish. + Duration batchDuration = Duration.millis(options.getBatchIntervalMillis()); + // set the checkpoint duration to match interval. + options.setCheckpointDurationMillis(batchDuration.getMillis()); + long forcedTimeout = batchDuration.multipliedBy(DEFAULT_NUMBER_OF_BATCHES_TIMEOUT).getMillis(); + options.setForcedTimeout(forcedTimeout); return options; } }