Repository: beam
Updated Branches:
  refs/heads/master 5e1be9fa7 -> 9ac1ffcea


[BEAM-1074] Set default-partitioner in SourceRDD.Unbounded

Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/623a5696
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/623a5696
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/623a5696

Branch: refs/heads/master
Commit: 623a5696bc328a9a55bf5de67ad0070a985c96ee
Parents: 5e1be9f
Author: Aviem Zur <aviem...@gmail.com>
Authored: Wed Mar 22 15:20:51 2017 +0200
Committer: Aviem Zur <aviem...@gmail.com>
Committed: Thu Mar 23 16:18:16 2017 +0200

----------------------------------------------------------------------
 .../spark/SparkNativePipelineVisitor.java       |  1 -
 .../beam/runners/spark/io/SourceDStream.java    | 52 +++++++++++++++-----
 .../apache/beam/runners/spark/io/SourceRDD.java | 19 +++++--
 .../runners/spark/io/SparkUnboundedSource.java  | 15 +++---
 4 files changed, 63 insertions(+), 24 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/623a5696/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkNativePipelineVisitor.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkNativePipelineVisitor.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkNativePipelineVisitor.java
index c2784a2..c2d38d7 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkNativePipelineVisitor.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkNativePipelineVisitor.java
@@ -92,7 +92,6 @@ public class SparkNativePipelineVisitor extends SparkRunner.Evaluator {
   @Override
   <TransformT extends PTransform<? super PInput, POutput>> void
       doVisitTransform(TransformHierarchy.Node node) {
-    super.doVisitTransform(node);
     @SuppressWarnings("unchecked")
     TransformT transform = (TransformT) node.getTransform();
     @SuppressWarnings("unchecked")

http://git-wip-us.apache.org/repos/asf/beam/blob/623a5696/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceDStream.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceDStream.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceDStream.java
index 8a0763b..3f2c10a 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceDStream.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceDStream.java
@@ -28,6 +28,7 @@ import org.apache.spark.api.java.JavaSparkContext$;
 import org.apache.spark.rdd.RDD;
 import org.apache.spark.streaming.StreamingContext;
 import org.apache.spark.streaming.Time;
+import org.apache.spark.streaming.api.java.JavaPairDStream;
 import org.apache.spark.streaming.dstream.InputDStream;
 import org.apache.spark.streaming.scheduler.RateController;
 import org.apache.spark.streaming.scheduler.RateController$;
@@ -36,7 +37,6 @@ import org.apache.spark.streaming.scheduler.rate.RateEstimator$;
 import org.joda.time.Duration;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
-
 import scala.Tuple2;


@@ -60,6 +60,9 @@ class SourceDStream<T, CheckpointMarkT extends UnboundedSource.CheckpointMark>
   private final UnboundedSource<T, CheckpointMarkT> unboundedSource;
   private final SparkRuntimeContext runtimeContext;
   private final Duration boundReadDuration;
+  // Number of partitions for the DStream is final and remains the same throughout the entire
+  // lifetime of the pipeline, including when resuming from checkpoint.
+  private final int numPartitions;
   // the initial parallelism, set by Spark's backend, will be determined once when the job starts.
   // in case of resuming/recovering from checkpoint, the DStream will be reconstructed and this
   // property should not be reset.
@@ -67,40 +70,55 @@ class SourceDStream<T, CheckpointMarkT extends UnboundedSource.CheckpointMark>
   // the bound on max records is optional.
   // in case it is set explicitly via PipelineOptions, it takes precedence
   // otherwise it could be activated via RateController.
-  private Long boundMaxRecords = null;
+  private final long boundMaxRecords;

   SourceDStream(
       StreamingContext ssc,
       UnboundedSource<T, CheckpointMarkT> unboundedSource,
-      SparkRuntimeContext runtimeContext) {
-
+      SparkRuntimeContext runtimeContext,
+      Long boundMaxRecords) {
     super(ssc, JavaSparkContext$.MODULE$.<scala.Tuple2<Source<T>, CheckpointMarkT>>fakeClassTag());
     this.unboundedSource = unboundedSource;
     this.runtimeContext = runtimeContext;
+
     SparkPipelineOptions options = runtimeContext.getPipelineOptions().as(
         SparkPipelineOptions.class);
+
     this.boundReadDuration = boundReadDuration(options.getReadTimePercentage(),
         options.getMinReadTimeMillis());
     // set initial parallelism once.
     this.initialParallelism = ssc().sc().defaultParallelism();
     checkArgument(this.initialParallelism > 0, "Number of partitions must be greater than zero.");
-  }

-  public void setMaxRecordsPerBatch(long maxRecordsPerBatch) {
-    boundMaxRecords = maxRecordsPerBatch;
+    this.boundMaxRecords = boundMaxRecords > 0 ? boundMaxRecords : rateControlledMaxRecords();
+
+    try {
+      this.numPartitions =
+          createMicrobatchSource()
+              .splitIntoBundles(initialParallelism, options)
+              .size();
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
   }

   @Override
   public scala.Option<RDD<Tuple2<Source<T>, CheckpointMarkT>>> compute(Time validTime) {
-    long maxNumRecords = boundMaxRecords != null ? boundMaxRecords : rateControlledMaxRecords();
-    MicrobatchSource<T, CheckpointMarkT> microbatchSource = new MicrobatchSource<>(
-        unboundedSource, boundReadDuration, initialParallelism, maxNumRecords, -1,
-        id());
-    RDD<scala.Tuple2<Source<T>, CheckpointMarkT>> rdd = new SourceRDD.Unbounded<>(
-        ssc().sc(), runtimeContext, microbatchSource);
+    RDD<scala.Tuple2<Source<T>, CheckpointMarkT>> rdd =
+        new SourceRDD.Unbounded<>(
+            ssc().sc(),
+            runtimeContext,
+            createMicrobatchSource(),
+            numPartitions);
     return scala.Option.apply(rdd);
   }
+
+  private MicrobatchSource<T, CheckpointMarkT> createMicrobatchSource() {
+    return new MicrobatchSource<>(unboundedSource, boundReadDuration, initialParallelism,
+        boundMaxRecords, -1, id());
+  }
+
   @Override
   public void start() { }
@@ -112,6 +130,14 @@ class SourceDStream<T, CheckpointMarkT extends UnboundedSource.CheckpointMark>
     return "Beam UnboundedSource [" + id() + "]";
   }

+  /**
+   * Number of partitions is exposed so clients of {@link SourceDStream} can use this to set
+   * appropriate partitioning for operations such as {@link JavaPairDStream#mapWithState}.
+   */
+  int getNumPartitions() {
+    return numPartitions;
+  }
+
   //---- Bound by time.
   // return the largest between the proportional read time (%batchDuration dedicated for read)

http://git-wip-us.apache.org/repos/asf/beam/blob/623a5696/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceRDD.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceRDD.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceRDD.java
index cf37b3a..1a3537f 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceRDD.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceRDD.java
@@ -30,15 +30,17 @@ import org.apache.beam.sdk.io.Source;
 import org.apache.beam.sdk.io.UnboundedSource;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.spark.Dependency;
+import org.apache.spark.HashPartitioner;
 import org.apache.spark.InterruptibleIterator;
 import org.apache.spark.Partition;
+import org.apache.spark.Partitioner;
 import org.apache.spark.SparkContext;
 import org.apache.spark.TaskContext;
 import org.apache.spark.api.java.JavaSparkContext$;
 import org.apache.spark.rdd.RDD;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
-
+import scala.Option;


 /**
@@ -213,8 +215,10 @@ public class SourceRDD {
    */
   public static class Unbounded<T, CheckpointMarkT extends UnboundedSource.CheckpointMark>
       extends RDD<scala.Tuple2<Source<T>, CheckpointMarkT>> {
+
     private final MicrobatchSource<T, CheckpointMarkT> microbatchSource;
     private final SparkRuntimeContext runtimeContext;
+    private final Partitioner partitioner;

     // to satisfy Scala API.
     private static final scala.collection.immutable.List<Dependency<?>> NIL =
@@ -222,12 +226,14 @@ public class SourceRDD {
             .asScalaBuffer(Collections.<Dependency<?>>emptyList()).toList();

     public Unbounded(SparkContext sc,
-        SparkRuntimeContext runtimeContext,
-        MicrobatchSource<T, CheckpointMarkT> microbatchSource) {
+                     SparkRuntimeContext runtimeContext,
+                     MicrobatchSource<T, CheckpointMarkT> microbatchSource,
+                     int initialNumPartitions) {
       super(sc, NIL,
           JavaSparkContext$.MODULE$.<scala.Tuple2<Source<T>, CheckpointMarkT>>fakeClassTag());
       this.runtimeContext = runtimeContext;
       this.microbatchSource = microbatchSource;
+      this.partitioner = new HashPartitioner(initialNumPartitions);
     }

     @Override
@@ -247,6 +253,13 @@ public class SourceRDD {
     }

     @Override
+    public Option<Partitioner> partitioner() {
+      // setting the partitioner helps to "keep" the same partitioner in the following
+      // mapWithState read for Read.Unbounded, preventing a post-mapWithState shuffle.
+      return scala.Some.apply(partitioner);
+    }
+
+    @Override
     public scala.collection.Iterator<scala.Tuple2<Source<T>, CheckpointMarkT>>
     compute(Partition split, TaskContext context) {
       @SuppressWarnings("unchecked")

http://git-wip-us.apache.org/repos/asf/beam/blob/623a5696/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java
index e5bbaf1..6c047ac 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java
@@ -77,11 +77,9 @@ public class SparkUnboundedSource {
     SparkPipelineOptions options = rc.getPipelineOptions().as(SparkPipelineOptions.class);
     Long maxRecordsPerBatch = options.getMaxRecordsPerBatch();
-    SourceDStream<T, CheckpointMarkT> sourceDStream = new SourceDStream<>(jssc.ssc(), source, rc);
-    // if max records per batch was set by the user.
-    if (maxRecordsPerBatch > 0) {
-      sourceDStream.setMaxRecordsPerBatch(maxRecordsPerBatch);
-    }
+    SourceDStream<T, CheckpointMarkT> sourceDStream =
+        new SourceDStream<>(jssc.ssc(), source, rc, maxRecordsPerBatch);
+
     JavaPairInputDStream<Source<T>, CheckpointMarkT> inputDStream =
         JavaPairInputDStream$.MODULE$.fromInputDStream(sourceDStream,
             JavaSparkContext$.MODULE$.<Source<T>>fakeClassTag(),
@@ -89,8 +87,11 @@ public class SparkUnboundedSource {
     // call mapWithState to read from a checkpointable sources.
     JavaMapWithStateDStream<Source<T>, CheckpointMarkT, Tuple2<byte[], Instant>,
-        Tuple2<Iterable<byte[]>, Metadata>> mapWithStateDStream = inputDStream.mapWithState(
-            StateSpec.function(StateSpecFunctions.<T, CheckpointMarkT>mapSourceFunction(rc)));
+        Tuple2<Iterable<byte[]>, Metadata>> mapWithStateDStream =
+        inputDStream.mapWithState(
+            StateSpec
+                .function(StateSpecFunctions.<T, CheckpointMarkT>mapSourceFunction(rc))
+                .numPartitions(sourceDStream.getNumPartitions()));

     // set checkpoint duration for read stream, if set.
     checkpointStream(mapWithStateDStream, options);
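
----------------------------------------------------------------------
Background on why the commit both overrides partitioner() in SourceRDD.Unbounded and sets
numPartitions on the StateSpec: Spark elides the shuffle in front of a key-based operation
only when the parent RDD already reports a partitioner equal to the one the operation will
use. The sketch below demonstrates that general Spark behavior with the batch API; the class
name PartitionerDemo and its sample data are illustrative assumptions, not code from this
commit or from Beam.

import java.util.Arrays;

import org.apache.spark.HashPartitioner;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;

import scala.Tuple2;

public class PartitionerDemo {
  public static void main(String[] args) {
    SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("partitioner-demo");
    try (JavaSparkContext sc = new JavaSparkContext(conf)) {
      JavaPairRDD<String, Integer> pairs = sc.parallelizePairs(Arrays.asList(
          new Tuple2<>("a", 1), new Tuple2<>("b", 2), new Tuple2<>("a", 3)));

      // Parent has no known partitioner, so reduceByKey must shuffle.
      System.out.println(pairs.reduceByKey(Integer::sum).collect());

      // Parent is already laid out by the same partitioner the key-based
      // operation uses, so Spark reuses the layout and skips the shuffle.
      // This mirrors what SourceRDD.Unbounded#partitioner() achieves for
      // the mapWithState read (illustrative only).
      HashPartitioner partitioner = new HashPartitioner(2);
      JavaPairRDD<String, Integer> prePartitioned = pairs.partitionBy(partitioner);
      System.out.println(prePartitioned.reduceByKey(partitioner, Integer::sum).collect());
    }
  }
}

Comparing toDebugString() (or the Spark UI's DAG view) for the two reductions shows the extra
shuffle stage disappearing in the pre-partitioned case; reporting a HashPartitioner over
getNumPartitions() partitions from SourceRDD.Unbounded aims at the same effect, per the
commit's own comment about preventing a post-mapWithState shuffle.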