Repository: beam
Updated Branches:
  refs/heads/master 5e1be9fa7 -> 9ac1ffcea


[BEAM-1074] Set default-partitioner in SourceRDD.Unbounded


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/623a5696
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/623a5696
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/623a5696

Branch: refs/heads/master
Commit: 623a5696bc328a9a55bf5de67ad0070a985c96ee
Parents: 5e1be9f
Author: Aviem Zur <aviem...@gmail.com>
Authored: Wed Mar 22 15:20:51 2017 +0200
Committer: Aviem Zur <aviem...@gmail.com>
Committed: Thu Mar 23 16:18:16 2017 +0200

----------------------------------------------------------------------
 .../spark/SparkNativePipelineVisitor.java       |  1 -
 .../beam/runners/spark/io/SourceDStream.java    | 52 +++++++++++++++-----
 .../apache/beam/runners/spark/io/SourceRDD.java | 19 +++++--
 .../runners/spark/io/SparkUnboundedSource.java  | 15 +++---
 4 files changed, 63 insertions(+), 24 deletions(-)
----------------------------------------------------------------------
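In short, this change threads a single partition count through three places: SourceDStream computes it once at construction, SourceRDD.Unbounded reports a matching HashPartitioner from partitioner(), and the mapWithState StateSpec requests the same count. The standalone sketch below (batch API, illustrative class name, not runner code) shows the Spark behavior this relies on: a key-based operation that asks for a partitioner equal to the one its parent RDD already reports runs without a fresh shuffle.

import java.util.Arrays;
import org.apache.spark.HashPartitioner;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import scala.Tuple2;

public class CoPartitioningDemo {
  public static void main(String[] args) {
    JavaSparkContext sc = new JavaSparkContext(
        new SparkConf().setMaster("local[2]").setAppName("co-partitioning"));

    JavaPairRDD<String, Integer> pairs = sc.parallelizePairs(Arrays.asList(
        new Tuple2<>("a", 1), new Tuple2<>("b", 2), new Tuple2<>("a", 3)));

    HashPartitioner partitioner = new HashPartitioner(2);

    // partitionBy shuffles once and records the partitioner on the result...
    JavaPairRDD<String, Integer> partitioned = pairs.partitionBy(partitioner);

    // ...so a key-based operation using an equal partitioner needs no second
    // shuffle. This is the effect the commit gets for the mapWithState read.
    partitioned.reduceByKey(partitioner, (x, y) -> x + y).collect();

    sc.stop();
  }
}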


http://git-wip-us.apache.org/repos/asf/beam/blob/623a5696/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkNativePipelineVisitor.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkNativePipelineVisitor.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkNativePipelineVisitor.java
index c2784a2..c2d38d7 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkNativePipelineVisitor.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkNativePipelineVisitor.java
@@ -92,7 +92,6 @@ public class SparkNativePipelineVisitor extends SparkRunner.Evaluator {
   @Override
   <TransformT extends PTransform<? super PInput, POutput>> void
   doVisitTransform(TransformHierarchy.Node node) {
-    super.doVisitTransform(node);
     @SuppressWarnings("unchecked")
     TransformT transform = (TransformT) node.getTransform();
     @SuppressWarnings("unchecked")

http://git-wip-us.apache.org/repos/asf/beam/blob/623a5696/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceDStream.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceDStream.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceDStream.java
index 8a0763b..3f2c10a 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceDStream.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceDStream.java
@@ -28,6 +28,7 @@ import org.apache.spark.api.java.JavaSparkContext$;
 import org.apache.spark.rdd.RDD;
 import org.apache.spark.streaming.StreamingContext;
 import org.apache.spark.streaming.Time;
+import org.apache.spark.streaming.api.java.JavaPairDStream;
 import org.apache.spark.streaming.dstream.InputDStream;
 import org.apache.spark.streaming.scheduler.RateController;
 import org.apache.spark.streaming.scheduler.RateController$;
@@ -36,7 +37,6 @@ import org.apache.spark.streaming.scheduler.rate.RateEstimator$;
 import org.joda.time.Duration;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
-
 import scala.Tuple2;
 
 
@@ -60,6 +60,9 @@ class SourceDStream<T, CheckpointMarkT extends UnboundedSource.CheckpointMark>
   private final UnboundedSource<T, CheckpointMarkT> unboundedSource;
   private final SparkRuntimeContext runtimeContext;
   private final Duration boundReadDuration;
+  // Number of partitions for the DStream is final and remains the same throughout the entire
+  // lifetime of the pipeline, including when resuming from checkpoint.
+  private final int numPartitions;
  // the initial parallelism, set by Spark's backend, will be determined once when the job starts.
  // in case of resuming/recovering from checkpoint, the DStream will be reconstructed and this
   // property should not be reset.
@@ -67,40 +70,55 @@ class SourceDStream<T, CheckpointMarkT extends UnboundedSource.CheckpointMark>
   // the bound on max records is optional.
   // in case it is set explicitly via PipelineOptions, it takes precedence
   // otherwise it could be activated via RateController.
-  private Long boundMaxRecords = null;
+  private final long boundMaxRecords;
 
   SourceDStream(
       StreamingContext ssc,
       UnboundedSource<T, CheckpointMarkT> unboundedSource,
-      SparkRuntimeContext runtimeContext) {
-
+      SparkRuntimeContext runtimeContext,
+      Long boundMaxRecords) {
    super(ssc, JavaSparkContext$.MODULE$.<scala.Tuple2<Source<T>, CheckpointMarkT>>fakeClassTag());
     this.unboundedSource = unboundedSource;
     this.runtimeContext = runtimeContext;
+
     SparkPipelineOptions options = runtimeContext.getPipelineOptions().as(
         SparkPipelineOptions.class);
+
     this.boundReadDuration = boundReadDuration(options.getReadTimePercentage(),
         options.getMinReadTimeMillis());
     // set initial parallelism once.
     this.initialParallelism = ssc().sc().defaultParallelism();
    checkArgument(this.initialParallelism > 0, "Number of partitions must be greater than zero.");
-  }
 
-  public void setMaxRecordsPerBatch(long maxRecordsPerBatch) {
-    boundMaxRecords = maxRecordsPerBatch;
+    this.boundMaxRecords = boundMaxRecords > 0 ? boundMaxRecords : rateControlledMaxRecords();
+
+    try {
+      this.numPartitions =
+          createMicrobatchSource()
+              .splitIntoBundles(initialParallelism, options)
+              .size();
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
   }
 
   @Override
  public scala.Option<RDD<Tuple2<Source<T>, CheckpointMarkT>>> compute(Time validTime) {
-    long maxNumRecords = boundMaxRecords != null ? boundMaxRecords : rateControlledMaxRecords();
-    MicrobatchSource<T, CheckpointMarkT> microbatchSource = new MicrobatchSource<>(
-        unboundedSource, boundReadDuration, initialParallelism, maxNumRecords, -1,
-        id());
-    RDD<scala.Tuple2<Source<T>, CheckpointMarkT>> rdd = new SourceRDD.Unbounded<>(
-        ssc().sc(), runtimeContext, microbatchSource);
+    RDD<scala.Tuple2<Source<T>, CheckpointMarkT>> rdd =
+        new SourceRDD.Unbounded<>(
+            ssc().sc(),
+            runtimeContext,
+            createMicrobatchSource(),
+            numPartitions);
     return scala.Option.apply(rdd);
   }
 
+
+  private MicrobatchSource<T, CheckpointMarkT> createMicrobatchSource() {
+    return new MicrobatchSource<>(unboundedSource, boundReadDuration, initialParallelism,
+        boundMaxRecords, -1, id());
+  }
+
   @Override
   public void start() { }
 
@@ -112,6 +130,14 @@ class SourceDStream<T, CheckpointMarkT extends UnboundedSource.CheckpointMark>
     return "Beam UnboundedSource [" + id() + "]";
   }
 
+  /**
+   * Number of partitions is exposed so clients of {@link SourceDStream} can use this to set
+   * appropriate partitioning for operations such as {@link JavaPairDStream#mapWithState}.
+   */
+  int getNumPartitions() {
+    return numPartitions;
+  }
+
   //---- Bound by time.
 
  // return the largest between the proportional read time (%batchDuration dedicated for read)

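Why the new numPartitions field is final and computed exactly once, surviving recovery from checkpoint: Spark decides whether a shuffle is needed by comparing partitioners, and HashPartitioner equality is defined purely by partition count, so any change in the count would break the match between the source RDD and the mapWithState state RDD. A minimal, self-contained illustration (class name is made up):

import org.apache.spark.HashPartitioner;

public class PartitionerEquality {
  public static void main(String[] args) {
    // Equal counts compare equal: the "no shuffle needed" match holds.
    System.out.println(new HashPartitioner(4).equals(new HashPartitioner(4))); // true
    // A different count breaks the match and would force a shuffle.
    System.out.println(new HashPartitioner(4).equals(new HashPartitioner(8))); // false
  }
}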
http://git-wip-us.apache.org/repos/asf/beam/blob/623a5696/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceRDD.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceRDD.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceRDD.java
index cf37b3a..1a3537f 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceRDD.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SourceRDD.java
@@ -30,15 +30,17 @@ import org.apache.beam.sdk.io.Source;
 import org.apache.beam.sdk.io.UnboundedSource;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.spark.Dependency;
+import org.apache.spark.HashPartitioner;
 import org.apache.spark.InterruptibleIterator;
 import org.apache.spark.Partition;
+import org.apache.spark.Partitioner;
 import org.apache.spark.SparkContext;
 import org.apache.spark.TaskContext;
 import org.apache.spark.api.java.JavaSparkContext$;
 import org.apache.spark.rdd.RDD;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
-
+import scala.Option;
 
 
 /**
@@ -213,8 +215,10 @@ public class SourceRDD {
    */
   public static class Unbounded<T, CheckpointMarkT extends
        UnboundedSource.CheckpointMark> extends RDD<scala.Tuple2<Source<T>, CheckpointMarkT>> {
+
     private final MicrobatchSource<T, CheckpointMarkT> microbatchSource;
     private final SparkRuntimeContext runtimeContext;
+    private final Partitioner partitioner;
 
     // to satisfy Scala API.
     private static final scala.collection.immutable.List<Dependency<?>> NIL =
@@ -222,12 +226,14 @@ public class SourceRDD {
             .asScalaBuffer(Collections.<Dependency<?>>emptyList()).toList();
 
     public Unbounded(SparkContext sc,
-                     SparkRuntimeContext runtimeContext,
-                     MicrobatchSource<T, CheckpointMarkT> microbatchSource) {
+        SparkRuntimeContext runtimeContext,
+        MicrobatchSource<T, CheckpointMarkT> microbatchSource,
+        int initialNumPartitions) {
       super(sc, NIL,
          JavaSparkContext$.MODULE$.<scala.Tuple2<Source<T>, CheckpointMarkT>>fakeClassTag());
       this.runtimeContext = runtimeContext;
       this.microbatchSource = microbatchSource;
+      this.partitioner = new HashPartitioner(initialNumPartitions);
     }
 
     @Override
@@ -247,6 +253,13 @@ public class SourceRDD {
     }
 
     @Override
+    public Option<Partitioner> partitioner() {
+      // setting the partitioner helps to "keep" the same partitioner in the following
+      // mapWithState read for Read.Unbounded, preventing a post-mapWithState shuffle.
+      return scala.Some.apply(partitioner);
+    }
+
+    @Override
     public scala.collection.Iterator<scala.Tuple2<Source<T>, CheckpointMarkT>>
     compute(Partition split, TaskContext context) {
       @SuppressWarnings("unchecked")

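For reference, the hunks above condense to the following hypothetical pair-RDD: hold a HashPartitioner sized by the count handed in from SourceDStream and report it from partitioner(). The class name and the empty compute() body are illustrative only; the real RDD reads (source, checkpoint-mark) tuples.

import java.util.Collections;
import org.apache.spark.Dependency;
import org.apache.spark.HashPartitioner;
import org.apache.spark.Partition;
import org.apache.spark.Partitioner;
import org.apache.spark.SparkContext;
import org.apache.spark.TaskContext;
import org.apache.spark.api.java.JavaSparkContext$;
import org.apache.spark.rdd.RDD;
import scala.Option;
import scala.Tuple2;

public class PartitionedPairRdd<K, V> extends RDD<Tuple2<K, V>> {

  // to satisfy the Scala API, as in SourceRDD above.
  private static final scala.collection.immutable.List<Dependency<?>> NIL =
      scala.collection.JavaConversions
          .asScalaBuffer(Collections.<Dependency<?>>emptyList()).toList();

  private final Partitioner partitioner;

  public PartitionedPairRdd(SparkContext sc, int numPartitions) {
    super(sc, NIL, JavaSparkContext$.MODULE$.<Tuple2<K, V>>fakeClassTag());
    this.partitioner = new HashPartitioner(numPartitions);
  }

  @Override
  public Option<Partitioner> partitioner() {
    // Reporting the partitioner is what lets Spark treat this RDD as already
    // partitioned and skip the shuffle before a matching stateful operation.
    return scala.Some.apply(partitioner);
  }

  @Override
  public Partition[] getPartitions() {
    // One partition per slot of the reported partitioner.
    Partition[] partitions = new Partition[partitioner.numPartitions()];
    for (int i = 0; i < partitions.length; i++) {
      final int index = i;
      partitions[i] = new Partition() {
        @Override
        public int index() {
          return index;
        }
      };
    }
    return partitions;
  }

  @Override
  public scala.collection.Iterator<Tuple2<K, V>> compute(Partition split, TaskContext context) {
    // Empty for the sketch; the real RDD emits (source, checkpoint-mark) pairs.
    return scala.collection.JavaConversions.asScalaIterator(
        Collections.<Tuple2<K, V>>emptyList().iterator());
  }
}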
http://git-wip-us.apache.org/repos/asf/beam/blob/623a5696/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java
index e5bbaf1..6c047ac 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/io/SparkUnboundedSource.java
@@ -77,11 +77,9 @@ public class SparkUnboundedSource {
 
    SparkPipelineOptions options = rc.getPipelineOptions().as(SparkPipelineOptions.class);
     Long maxRecordsPerBatch = options.getMaxRecordsPerBatch();
-    SourceDStream<T, CheckpointMarkT> sourceDStream = new SourceDStream<>(jssc.ssc(), source, rc);
-    // if max records per batch was set by the user.
-    if (maxRecordsPerBatch > 0) {
-      sourceDStream.setMaxRecordsPerBatch(maxRecordsPerBatch);
-    }
+    SourceDStream<T, CheckpointMarkT> sourceDStream =
+        new SourceDStream<>(jssc.ssc(), source, rc, maxRecordsPerBatch);
+
     JavaPairInputDStream<Source<T>, CheckpointMarkT> inputDStream =
         JavaPairInputDStream$.MODULE$.fromInputDStream(sourceDStream,
             JavaSparkContext$.MODULE$.<Source<T>>fakeClassTag(),
@@ -89,8 +87,11 @@ public class SparkUnboundedSource {
 
    // call mapWithState to read from checkpointable sources.
    JavaMapWithStateDStream<Source<T>, CheckpointMarkT, Tuple2<byte[], Instant>,
-        Tuple2<Iterable<byte[]>, Metadata>> mapWithStateDStream = inputDStream.mapWithState(
-            StateSpec.function(StateSpecFunctions.<T, CheckpointMarkT>mapSourceFunction(rc)));
+        Tuple2<Iterable<byte[]>, Metadata>> mapWithStateDStream =
+        inputDStream.mapWithState(
+            StateSpec
+                .function(StateSpecFunctions.<T, CheckpointMarkT>mapSourceFunction(rc))
+                .numPartitions(sourceDStream.getNumPartitions()));
 
     // set checkpoint duration for read stream, if set.
     checkpointStream(mapWithStateDStream, options);
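A toy-typed sketch of the StateSpec wiring above, assuming the Spark 1.6-era Java API (where the mapWithState function receives a Guava Optional); the real runner code uses StateSpecFunctions.mapSourceFunction and Source/CheckpointMark pairs. The essential call is numPartitions(...), which sizes the state's HashPartitioner to match the count reported upstream by SourceRDD.Unbounded:

import com.google.common.base.Optional;
import org.apache.spark.api.java.function.Function3;
import org.apache.spark.streaming.State;
import org.apache.spark.streaming.StateSpec;

public class StateSpecSketch {
  // Type parameters: key, value, state, mapped output.
  static StateSpec<String, Integer, Integer, Integer> runningSum(int numPartitions) {
    return StateSpec.function(
        new Function3<String, Optional<Integer>, State<Integer>, Integer>() {
          @Override
          public Integer call(String key, Optional<Integer> value, State<Integer> state) {
            // Fold the new value into the per-key state.
            int sum = (state.exists() ? state.get() : 0) + value.or(0);
            state.update(sum);
            return sum; // emitted downstream, like Metadata in the runner
          }
        })
        // Same count as sourceDStream.getNumPartitions() above: equal
        // HashPartitioners on both sides, so no post-read shuffle.
        .numPartitions(numPartitions);
  }
}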
