Repository: beam
Updated Branches:
  refs/heads/master fe7fc298f -> c1b7f8695


[BEAM-647] Fault-tolerant sideInputs via Broadcast variables
Fix comments by Amit + rebase from master + checkstyle

Reformat + add unpersist on push


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/130c113e
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/130c113e
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/130c113e

Branch: refs/heads/master
Commit: 130c113e514c8bfe95ef69b68247eceac9301b17
Parents: fe7fc29
Author: ksalant <ksal...@payapal.com>
Authored: Thu Dec 15 19:42:47 2016 +0200
Committer: Sela <ans...@paypal.com>
Committed: Tue Jan 10 19:12:24 2017 +0200

----------------------------------------------------------------------
 .../spark/translation/EvaluationContext.java    |  31 ++++--
 .../spark/translation/SparkPCollectionView.java | 103 ++++++++++++++++++
 .../spark/translation/TransformTranslator.java  |  31 +++++-
 .../spark/translation/TranslationUtils.java     |  26 +++--
 .../streaming/StreamingTransformTranslator.java |  73 +++++++++----
 .../runners/spark/util/BroadcastHelper.java     | 107 ++++++-------------
 .../spark/util/SparkSideInputReader.java        |   1 +
 .../ResumeFromCheckpointStreamingTest.java      |  19 ++++
 8 files changed, 273 insertions(+), 118 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/130c113e/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java
----------------------------------------------------------------------
diff --git 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java
 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java
index ec5ad3d..b1a1142 100644
--- 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java
+++ 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java
@@ -55,8 +55,8 @@ public class EvaluationContext {
   private final Set<Dataset> leaves = new LinkedHashSet<>();
   private final Set<PValue> multiReads = new LinkedHashSet<>();
   private final Map<PValue, Object> pobjects = new LinkedHashMap<>();
-  private final Map<PValue, Iterable<? extends WindowedValue<?>>> pview = new 
LinkedHashMap<>();
   private AppliedPTransform<?, ?, ?> currentTransform;
+  private final SparkPCollectionView pviews = new SparkPCollectionView();
 
   public EvaluationContext(JavaSparkContext jsc, Pipeline pipeline) {
     this.jsc = jsc;
@@ -129,10 +129,6 @@ public class EvaluationContext {
     datasets.put((PValue) getOutput(transform), new UnboundedDataset<>(values, 
jssc, coder));
   }
 
-  void putPView(PValue view, Iterable<? extends WindowedValue<?>> value) {
-    pview.put(view, value);
-  }
-
   public Dataset borrowDataset(PTransform<?, ?> transform) {
     return borrowDataset((PValue) getInput(transform));
   }
@@ -149,10 +145,6 @@ public class EvaluationContext {
     return dataset;
   }
 
-  <T> Iterable<? extends WindowedValue<?>> 
getPCollectionView(PCollectionView<T> view) {
-    return pview.get(view);
-  }
-
   /**
    * Computes the outputs for all RDDs that are leaves in the DAG and do not 
have any actions (like
    * saving to a file) registered on them (i.e. they are performed for side 
effects).
@@ -199,6 +191,26 @@ public class EvaluationContext {
     return Iterables.transform(windowedValues, 
WindowingHelpers.<T>unwindowValueFunction());
   }
 
+  /**
+   * Returns the current views created in the pipeline.
+   * @return SparkPCollectionView
+   */
+  public SparkPCollectionView getPviews() {
+    return pviews;
+  }
+
+  /**
+   * Adds/replaces a view in the current views created in the pipeline.
+   * @param view - Identifier of the view
+   * @param value - Actual value of the view
+   * @param coder - Coder of the value
+   */
+  public void putPView(PCollectionView<?> view,
+      Iterable<WindowedValue<?>> value,
+      Coder<Iterable<WindowedValue<?>>> coder) {
+    pviews.putPView(view, value, coder, jsc);
+  }
+
   <T> Iterable<WindowedValue<T>> getWindowedValues(PCollection<T> pcollection) 
{
     @SuppressWarnings("unchecked")
     BoundedDataset<T> boundedDataset = (BoundedDataset<T>) 
datasets.get(pcollection);
@@ -209,4 +221,5 @@ public class EvaluationContext {
   private String storageLevel() {
     return 
runtime.getPipelineOptions().as(SparkPipelineOptions.class).getStorageLevel();
   }
+
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/130c113e/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPCollectionView.java
----------------------------------------------------------------------
diff --git 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPCollectionView.java
 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPCollectionView.java
new file mode 100644
index 0000000..e888182
--- /dev/null
+++ 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkPCollectionView.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.translation;
+
+import java.io.Serializable;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import org.apache.beam.runners.spark.coders.CoderHelpers;
+import org.apache.beam.runners.spark.util.BroadcastHelper;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.spark.api.java.JavaSparkContext;
+import scala.Tuple2;
+
+/**
+ * SparkPCollectionView is used to pass serialized views to lambdas.
+ */
+public class SparkPCollectionView implements Serializable {
+
+    // Holds the view --> broadcast mapping. Transient so it will be null when resuming from a checkpoint.
+    private transient volatile Map<PCollectionView<?>, BroadcastHelper> 
broadcastHelperMap = null;
+
+    // Holds the actual data of the views in serialized form.
+    private Map<PCollectionView<?>,
+        Tuple2<byte[], Coder<Iterable<WindowedValue<?>>>>> pviews =
+            new LinkedHashMap<>();
+
+    // Driver only - during evaluation stage
+    void putPView(
+        PCollectionView<?> view,
+        Iterable<WindowedValue<?>> value,
+        Coder<Iterable<WindowedValue<?>>> coder,
+        JavaSparkContext context) {
+
+        pviews.put(view, new Tuple2<>(CoderHelpers.toByteArray(value, coder), 
coder));
+        // Overwrite/create the broadcast. A future improvement is to initialize the BroadcastHelper lazily.
+        getPCollectionView(view, context, true);
+    }
+
+    BroadcastHelper getPCollectionView(
+        PCollectionView<?> view,
+        JavaSparkContext context) {
+        return getPCollectionView(view, context, false);
+    }
+
+    private BroadcastHelper getPCollectionView(
+        PCollectionView<?> view,
+        JavaSparkContext context,
+        boolean overwrite) {
+        // initialize broadcastHelperMap if needed
+        if (broadcastHelperMap == null) {
+            synchronized (SparkPCollectionView.class) {
+                if (broadcastHelperMap == null) {
+                    broadcastHelperMap = new LinkedHashMap<>();
+                }
+            }
+        }
+
+        // lazily broadcast views
+        BroadcastHelper helper = broadcastHelperMap.get(view);
+        if (helper == null) {
+            synchronized (SparkPCollectionView.class) {
+                helper = broadcastHelperMap.get(view);
+                if (helper == null) {
+                    helper = createBroadcastHelper(view, context);
+                }
+            }
+        } else if (overwrite) {
+            synchronized (SparkPCollectionView.class) {
+                // Currently a non-blocking unpersist; it can be changed to blocking if needed.
+                helper.unpersist();
+                helper = createBroadcastHelper(view, context);
+            }
+        }
+        return helper;
+    }
+
+    private BroadcastHelper createBroadcastHelper(
+        PCollectionView<?> view,
+        JavaSparkContext context) {
+        Tuple2<byte[], Coder<Iterable<WindowedValue<?>>>> tuple2 = 
pviews.get(view);
+        BroadcastHelper helper = BroadcastHelper.create(tuple2._1, tuple2._2);
+        helper.broadcast(context);
+        broadcastHelperMap.put(view, helper);
+        return helper;
+    }
+}
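
The interesting part of SparkPCollectionView is how it (re)creates broadcasts: the broadcast map is transient and volatile, so it is null again after resuming from a checkpoint and is rebuilt on first access with double-checked locking, while an overwrite unpersists the stale broadcast before re-broadcasting. Below is a minimal, Spark-free sketch of that locking pattern; "Handle" and "HandleFactory" are illustrative stand-ins for BroadcastHelper and createBroadcastHelper(), not part of the patch.

import java.io.Serializable;
import java.util.LinkedHashMap;
import java.util.Map;

class LazyBroadcastCache implements Serializable {

  /** Stand-in for BroadcastHelper. */
  interface Handle {
    void unpersist();
  }

  /** Stand-in for creating and broadcasting a helper. */
  interface HandleFactory {
    Handle create(String key);
  }

  // transient + volatile, like broadcastHelperMap: null again after this object
  // is deserialized (e.g. on resume from a checkpoint), rebuilt on first access.
  private transient volatile Map<String, Handle> handles;

  Handle get(String key, boolean overwrite, HandleFactory factory) {
    if (handles == null) {                       // first check, no lock
      synchronized (LazyBroadcastCache.class) {
        if (handles == null) {                   // second check, under lock
          handles = new LinkedHashMap<>();
        }
      }
    }
    Handle handle = handles.get(key);
    if (handle == null) {
      synchronized (LazyBroadcastCache.class) {
        handle = handles.get(key);
        if (handle == null) {                    // lazily create and cache
          handle = factory.create(key);
          handles.put(key, handle);
        }
      }
    } else if (overwrite) {
      synchronized (LazyBroadcastCache.class) {
        handle.unpersist();                      // drop the stale broadcast
        handle = factory.create(key);
        handles.put(key, handle);
      }
    }
    return handle;
  }
}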

http://git-wip-us.apache.org/repos/asf/beam/blob/130c113e/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
----------------------------------------------------------------------
diff --git 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
index 5dd6beb..0cf3dc6 100644
--- 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
+++ 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
@@ -65,6 +65,7 @@ import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionList;
 import org.apache.beam.sdk.values.PCollectionTuple;
+import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.io.NullWritable;
@@ -529,7 +530,15 @@ public final class TransformTranslator {
       public void evaluate(View.AsSingleton<T> transform, EvaluationContext 
context) {
         Iterable<? extends WindowedValue<?>> iter =
             context.getWindowedValues(context.getInput(transform));
-        context.putPView(context.getOutput(transform), iter);
+        PCollectionView<T> output = context.getOutput(transform);
+        Coder<Iterable<WindowedValue<?>>> coderInternal = 
output.getCoderInternal();
+
+        @SuppressWarnings("unchecked")
+        Iterable<WindowedValue<?>> iterCast =  (Iterable<WindowedValue<?>>) 
iter;
+
+        context.putPView(output,
+            iterCast,
+            coderInternal);
       }
     };
   }
@@ -540,7 +549,15 @@ public final class TransformTranslator {
       public void evaluate(View.AsIterable<T> transform, EvaluationContext 
context) {
         Iterable<? extends WindowedValue<?>> iter =
             context.getWindowedValues(context.getInput(transform));
-        context.putPView(context.getOutput(transform), iter);
+        PCollectionView<Iterable<T>> output = context.getOutput(transform);
+        Coder<Iterable<WindowedValue<?>>> coderInternal = 
output.getCoderInternal();
+
+        @SuppressWarnings("unchecked")
+        Iterable<WindowedValue<?>> iterCast =  (Iterable<WindowedValue<?>>) 
iter;
+
+        context.putPView(output,
+            iterCast,
+            coderInternal);
       }
     };
   }
@@ -553,7 +570,15 @@ public final class TransformTranslator {
                            EvaluationContext context) {
         Iterable<? extends WindowedValue<?>> iter =
             context.getWindowedValues(context.getInput(transform));
-        context.putPView(context.getOutput(transform), iter);
+        PCollectionView<WriteT> output = context.getOutput(transform);
+        Coder<Iterable<WindowedValue<?>>> coderInternal = 
output.getCoderInternal();
+
+        @SuppressWarnings("unchecked")
+        Iterable<WindowedValue<?>> iterCast =  (Iterable<WindowedValue<?>>) 
iter;
+
+        context.putPView(output,
+            iterCast,
+            coderInternal);
       }
     };
   }

http://git-wip-us.apache.org/repos/asf/beam/blob/130c113e/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java
----------------------------------------------------------------------
diff --git 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java
 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java
index eddc771..ae9cb3e 100644
--- 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java
+++ 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java
@@ -26,7 +26,6 @@ import java.util.List;
 import java.util.Map;
 import org.apache.beam.runners.spark.SparkRunner;
 import org.apache.beam.runners.spark.util.BroadcastHelper;
-import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
@@ -42,6 +41,7 @@ import org.apache.beam.sdk.util.state.StateInternalsFactory;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.TupleTag;
+import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.FlatMapFunction;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.api.java.function.PairFunction;
@@ -191,8 +191,22 @@ public final class TranslationUtils {
    * @param context The {@link EvaluationContext}.
    * @return a map of tagged {@link BroadcastHelper}s and their {@link 
WindowingStrategy}.
    */
+  static Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, BroadcastHelper<?>>>
+  getSideInputs(List<PCollectionView<?>> views, EvaluationContext context) {
+    return getSideInputs(views, context.getSparkContext(), 
context.getPviews());
+  }
+
+  /**
+   * Create SideInputs as Broadcast variables.
+   *
+   * @param views   The {@link PCollectionView}s.
+   * @param context The {@link JavaSparkContext}.
+   * @param pviews  The {@link SparkPCollectionView}.
+   * @return a map of tagged {@link BroadcastHelper}s and their {@link 
WindowingStrategy}.
+   */
   public static Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, 
BroadcastHelper<?>>>
-      getSideInputs(List<PCollectionView<?>> views, EvaluationContext context) 
{
+  getSideInputs(List<PCollectionView<?>> views, JavaSparkContext context,
+                SparkPCollectionView pviews) {
 
     if (views == null) {
       return ImmutableMap.of();
@@ -200,14 +214,8 @@ public final class TranslationUtils {
       Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, BroadcastHelper<?>>> 
sideInputs =
           Maps.newHashMap();
       for (PCollectionView<?> view : views) {
-        Iterable<? extends WindowedValue<?>> collectionView = 
context.getPCollectionView(view);
-        Coder<Iterable<WindowedValue<?>>> coderInternal = 
view.getCoderInternal();
+        BroadcastHelper helper = pviews.getPCollectionView(view, context);
         WindowingStrategy<?, ?> windowingStrategy = 
view.getWindowingStrategyInternal();
-        @SuppressWarnings("unchecked")
-        BroadcastHelper<?> helper =
-            BroadcastHelper.create((Iterable<WindowedValue<?>>) 
collectionView, coderInternal);
-        //broadcast side inputs
-        helper.broadcast(context.getSparkContext());
         sideInputs.put(view.getTagInternal(),
             KV.<WindowingStrategy<?, ?>, 
BroadcastHelper<?>>of(windowingStrategy, helper));
       }

http://git-wip-us.apache.org/repos/asf/beam/blob/130c113e/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
----------------------------------------------------------------------
diff --git 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
index 070ccbb..0b2b4d6 100644
--- 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
+++ 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
@@ -37,6 +37,7 @@ import 
org.apache.beam.runners.spark.translation.GroupCombineFunctions;
 import org.apache.beam.runners.spark.translation.MultiDoFnFunction;
 import org.apache.beam.runners.spark.translation.SparkAssignWindowFn;
 import org.apache.beam.runners.spark.translation.SparkKeyedCombineFn;
+import org.apache.beam.runners.spark.translation.SparkPCollectionView;
 import org.apache.beam.runners.spark.translation.SparkPipelineTranslator;
 import org.apache.beam.runners.spark.translation.SparkRuntimeContext;
 import org.apache.beam.runners.spark.translation.TransformEvaluator;
@@ -238,12 +239,12 @@ final class StreamingTransformTranslator {
     return new TransformEvaluator<Combine.GroupedValues<K, InputT, OutputT>>() 
{
       @SuppressWarnings("unchecked")
       @Override
-      public void evaluate(Combine.GroupedValues<K, InputT, OutputT> transform,
+      public void evaluate(final Combine.GroupedValues<K, InputT, OutputT> 
transform,
                            EvaluationContext context) {
         // get the applied combine function.
         PCollection<? extends KV<K, ? extends Iterable<InputT>>> input =
             context.getInput(transform);
-        WindowingStrategy<?, ?> windowingStrategy = 
input.getWindowingStrategy();
+        final WindowingStrategy<?, ?> windowingStrategy = 
input.getWindowingStrategy();
         final CombineWithContext.KeyedCombineFnWithContext<K, InputT, ?, 
OutputT> fn =
             (CombineWithContext.KeyedCombineFnWithContext<K, InputT, ?, 
OutputT>)
                 CombineFnUtil.toFnWithContext(transform.getFn());
@@ -252,13 +253,27 @@ final class StreamingTransformTranslator {
             ((UnboundedDataset<KV<K, Iterable<InputT>>>) 
context.borrowDataset(transform))
                 .getDStream();
 
-        SparkKeyedCombineFn<K, InputT, ?, OutputT> combineFnWithContext =
-            new SparkKeyedCombineFn<>(fn, context.getRuntimeContext(),
-                TranslationUtils.getSideInputs(transform.getSideInputs(), 
context),
-                windowingStrategy);
-        context.putDataset(transform, new UnboundedDataset<>(dStream.map(new 
TranslationUtils
-            .CombineGroupedValues<>(
-            combineFnWithContext))));
+        final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
+        final SparkPCollectionView pviews = context.getPviews();
+
+        JavaDStream<WindowedValue<KV<K, OutputT>>> outStream = 
dStream.transform(
+            new Function<JavaRDD<WindowedValue<KV<K, Iterable<InputT>>>>,
+                         JavaRDD<WindowedValue<KV<K, OutputT>>>>() {
+                @Override
+                public JavaRDD<WindowedValue<KV<K, OutputT>>>
+                    call(JavaRDD<WindowedValue<KV<K, Iterable<InputT>>>> rdd)
+                        throws Exception {
+                        SparkKeyedCombineFn<K, InputT, ?, OutputT> 
combineFnWithContext =
+                            new SparkKeyedCombineFn<>(fn, runtimeContext,
+                                
TranslationUtils.getSideInputs(transform.getSideInputs(),
+                                new JavaSparkContext(rdd.context()), pviews),
+                                windowingStrategy);
+                    return rdd.map(
+                        new 
TranslationUtils.CombineGroupedValues<>(combineFnWithContext));
+                  }
+                });
+
+        context.putDataset(transform, new UnboundedDataset<>(outStream));
       }
     };
   }
@@ -269,7 +284,8 @@ final class StreamingTransformTranslator {
 
       @SuppressWarnings("unchecked")
       @Override
-      public void evaluate(Combine.Globally<InputT, OutputT> transform, 
EvaluationContext context) {
+      public void evaluate(final Combine.Globally<InputT, OutputT> transform,
+                           EvaluationContext context) {
         final PCollection<InputT> input = context.getInput(transform);
         // serializable arguments to pass.
         final Coder<InputT> iCoder = context.getInput(transform).getCoder();
@@ -279,9 +295,8 @@ final class StreamingTransformTranslator {
                 CombineFnUtil.toFnWithContext(transform.getFn());
         final WindowingStrategy<?, ?> windowingStrategy = 
input.getWindowingStrategy();
         final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
-        final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, 
BroadcastHelper<?>>> sideInputs =
-            TranslationUtils.getSideInputs(transform.getSideInputs(), context);
         final boolean hasDefault = transform.isInsertDefault();
+        final SparkPCollectionView pviews = context.getPviews();
 
         JavaDStream<WindowedValue<InputT>> dStream =
             ((UnboundedDataset<InputT>) 
context.borrowDataset(transform)).getDStream();
@@ -291,6 +306,10 @@ final class StreamingTransformTranslator {
           @Override
           public JavaRDD<WindowedValue<OutputT>> 
call(JavaRDD<WindowedValue<InputT>> rdd)
               throws Exception {
+            final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, 
BroadcastHelper<?>>> sideInputs =
+                TranslationUtils.getSideInputs(transform.getSideInputs(),
+                    JavaSparkContext.fromSparkContext(rdd.context()),
+                    pviews);
             return GroupCombineFunctions.combineGlobally(rdd, combineFn, 
iCoder, oCoder,
                 runtimeContext, windowingStrategy, sideInputs, hasDefault);
           }
@@ -317,8 +336,7 @@ final class StreamingTransformTranslator {
                 CombineFnUtil.toFnWithContext(transform.getFn());
         final WindowingStrategy<?, ?> windowingStrategy = 
input.getWindowingStrategy();
         final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
-        final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, 
BroadcastHelper<?>>> sideInputs =
-            TranslationUtils.getSideInputs(transform.getSideInputs(), context);
+        final SparkPCollectionView pviews = context.getPviews();
 
         JavaDStream<WindowedValue<KV<K, InputT>>> dStream =
             ((UnboundedDataset<KV<K, InputT>>) 
context.borrowDataset(transform)).getDStream();
@@ -329,6 +347,10 @@ final class StreamingTransformTranslator {
           @Override
           public JavaRDD<WindowedValue<KV<K, OutputT>>> call(
               JavaRDD<WindowedValue<KV<K, InputT>>> rdd) throws Exception {
+            final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, 
BroadcastHelper<?>>> sideInputs =
+                TranslationUtils.getSideInputs(transform.getSideInputs(),
+                    JavaSparkContext.fromSparkContext(rdd.context()),
+                    pviews);
             return GroupCombineFunctions.combinePerKey(rdd, combineFn, 
inputCoder, runtimeContext,
                 windowingStrategy, sideInputs);
           }
@@ -347,10 +369,10 @@ final class StreamingTransformTranslator {
         final DoFn<InputT, OutputT> doFn = transform.getFn();
         rejectStateAndTimers(doFn);
         final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
-        final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, 
BroadcastHelper<?>>> sideInputs =
-            TranslationUtils.getSideInputs(transform.getSideInputs(), context);
         final WindowingStrategy<?, ?> windowingStrategy =
             context.getInput(transform).getWindowingStrategy();
+        final SparkPCollectionView pviews = context.getPviews();
+
         JavaDStream<WindowedValue<InputT>> dStream =
             ((UnboundedDataset<InputT>) 
context.borrowDataset(transform)).getDStream();
 
@@ -360,8 +382,14 @@ final class StreamingTransformTranslator {
           @Override
           public JavaRDD<WindowedValue<OutputT>> 
call(JavaRDD<WindowedValue<InputT>> rdd) throws
               Exception {
+            final JavaSparkContext jsc = new JavaSparkContext(rdd.context());
+
             final Accumulator<NamedAggregators> accum =
-                SparkAggregators.getNamedAggregators(new 
JavaSparkContext(rdd.context()));
+                SparkAggregators.getNamedAggregators(jsc);
+
+            final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, 
BroadcastHelper<?>>> sideInputs =
+                TranslationUtils.getSideInputs(transform.getSideInputs(),
+                    jsc, pviews);
             return rdd.mapPartitions(
                 new DoFnFunction<>(accum, doFn, runtimeContext, sideInputs, 
windowingStrategy));
           }
@@ -381,8 +409,7 @@ final class StreamingTransformTranslator {
         final DoFn<InputT, OutputT> doFn = transform.getFn();
         rejectStateAndTimers(doFn);
         final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
-        final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, 
BroadcastHelper<?>>> sideInputs =
-            TranslationUtils.getSideInputs(transform.getSideInputs(), context);
+        final SparkPCollectionView pviews = context.getPviews();
         final WindowingStrategy<?, ?> windowingStrategy =
             context.getInput(transform).getWindowingStrategy();
         @SuppressWarnings("unchecked")
@@ -396,8 +423,12 @@ final class StreamingTransformTranslator {
               JavaRDD<WindowedValue<InputT>> rdd) throws Exception {
             final Accumulator<NamedAggregators> accum =
                 SparkAggregators.getNamedAggregators(new 
JavaSparkContext(rdd.context()));
-            return rdd.mapPartitionsToPair(new MultiDoFnFunction<>(accum, doFn,
-                runtimeContext, transform.getMainOutputTag(), sideInputs, 
windowingStrategy));
+
+            final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, 
BroadcastHelper<?>>> sideInputs =
+                TranslationUtils.getSideInputs(transform.getSideInputs(),
+                    JavaSparkContext.fromSparkContext(rdd.context()), pviews);
+              return rdd.mapPartitionsToPair(new MultiDoFnFunction<>(accum, 
doFn,
+                  runtimeContext, transform.getMainOutputTag(), sideInputs, 
windowingStrategy));
           }
         }).cache();
         PCollectionTuple pct = context.getOutput(transform);
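
The streaming translators now resolve side inputs inside the RDD transform function rather than once at translation time, so every micro-batch, including the first one after recovering from a checkpoint, obtains (or re-creates) the broadcast from the driver-held SparkPCollectionView. A hedged sketch of the same pattern with simplified types follows; SideInputCache, the "expected" tag and tagWithSideInputSize are illustrative only, not Beam or patch APIs.

import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.streaming.api.java.JavaDStream;

/** Simplified stand-in for SparkPCollectionView: driver-held values, lazily broadcast. */
class SideInputCache implements Serializable {
  private final Map<String, List<String>> values = new HashMap<>();
  // transient: dropped on checkpoint recovery, rebuilt on first use.
  private transient Map<String, Broadcast<List<String>>> broadcasts;

  void put(String tag, List<String> value) {
    values.put(tag, value);
  }

  synchronized Broadcast<List<String>> getOrBroadcast(String tag, JavaSparkContext jsc) {
    if (broadcasts == null) {
      broadcasts = new HashMap<>();
    }
    Broadcast<List<String>> broadcast = broadcasts.get(tag);
    if (broadcast == null) {
      broadcast = jsc.broadcast(values.get(tag));
      broadcasts.put(tag, broadcast);
    }
    return broadcast;
  }
}

class PerBatchSideInputs {
  static JavaDStream<String> tagWithSideInputSize(
      JavaDStream<String> input, final SideInputCache cache) {
    return input.transform(
        new Function<JavaRDD<String>, JavaRDD<String>>() {
          @Override
          public JavaRDD<String> call(JavaRDD<String> rdd) {
            // Runs on the driver once per micro-batch, so the broadcast is
            // re-created after recovery instead of being a stale closure capture.
            JavaSparkContext jsc = JavaSparkContext.fromSparkContext(rdd.context());
            final Broadcast<List<String>> side = cache.getOrBroadcast("expected", jsc);
            return rdd.map(new Function<String, String>() {
              @Override
              public String call(String element) {
                return element + " [side input size: " + side.value().size() + "]";
              }
            });
          }
        });
  }
}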

http://git-wip-us.apache.org/repos/asf/beam/blob/130c113e/runners/spark/src/main/java/org/apache/beam/runners/spark/util/BroadcastHelper.java
----------------------------------------------------------------------
diff --git 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/BroadcastHelper.java
 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/BroadcastHelper.java
index 5c13b80..946f786 100644
--- 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/BroadcastHelper.java
+++ 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/BroadcastHelper.java
@@ -21,7 +21,6 @@ package org.apache.beam.runners.spark.util;
 import java.io.ByteArrayInputStream;
 import java.io.IOException;
 import java.io.Serializable;
-import org.apache.beam.runners.spark.coders.CoderHelpers;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.broadcast.Broadcast;
@@ -31,7 +30,7 @@ import org.slf4j.LoggerFactory;
 /**
  * Broadcast helper.
  */
-public abstract class BroadcastHelper<T> implements Serializable {
+public class BroadcastHelper<T> implements Serializable {
 
   /**
    * If the property {@code beam.spark.directBroadcast} is set to
@@ -39,89 +38,45 @@ public abstract class BroadcastHelper<T> implements 
Serializable {
    * in View objects. By default this property is not set, and values are 
coded using
    * the appropriate {@link Coder}.
    */
-  public static final String DIRECT_BROADCAST = "beam.spark.directBroadcast";
-
   private static final Logger LOG = 
LoggerFactory.getLogger(BroadcastHelper.class);
-
-  public static <T> BroadcastHelper<T> create(T value, Coder<T> coder) {
-    if (Boolean.parseBoolean(System.getProperty(DIRECT_BROADCAST, "false"))) {
-      return new DirectBroadcastHelper<>(value);
-    }
-    return new CodedBroadcastHelper<>(value, coder);
+  private Broadcast<byte[]> bcast;
+  private final Coder<T> coder;
+  private transient T value;
+  private transient byte[] bytes = null;
+
+  private BroadcastHelper(byte[] bytes, Coder<T> coder) {
+    this.bytes = bytes;
+    this.coder = coder;
   }
 
-  public abstract T getValue();
-
-  public abstract void broadcast(JavaSparkContext jsc);
-
-  /**
-   * A {@link BroadcastHelper} that relies on the underlying
-   * Spark serialization (Kryo) to broadcast values. This is appropriate when
-   * broadcasting very large values, since no copy of the object is made.
-   * @param <T> the type of the value stored in the broadcast variable
-   */
-  static class DirectBroadcastHelper<T> extends BroadcastHelper<T> {
-    private Broadcast<T> bcast;
-    private transient T value;
-
-    DirectBroadcastHelper(T value) {
-      this.value = value;
-    }
-
-    @Override
-    public synchronized T getValue() {
-      if (value == null) {
-        value = bcast.getValue();
-      }
-      return value;
-    }
-
-    @Override
-    public void broadcast(JavaSparkContext jsc) {
-      this.bcast = jsc.broadcast(value);
-    }
+  public static <T> BroadcastHelper<T> create(byte[] bytes, Coder<T> coder) {
+    return new BroadcastHelper<>(bytes, coder);
   }
 
-  /**
-   * A {@link BroadcastHelper} that uses a
-   * {@link Coder} to encode values as byte arrays
-   * before broadcasting.
-   * @param <T> the type of the value stored in the broadcast variable
-   */
-  static class CodedBroadcastHelper<T> extends BroadcastHelper<T> {
-    private Broadcast<byte[]> bcast;
-    private final Coder<T> coder;
-    private transient T value;
-
-    CodedBroadcastHelper(T value, Coder<T> coder) {
-      this.value = value;
-      this.coder = coder;
+  public synchronized T getValue() {
+    if (value == null) {
+       value = deserialize();
     }
+    return value;
+  }
 
-    @Override
-    public synchronized T getValue() {
-      if (value == null) {
-        value = deserialize();
-      }
-      return value;
-    }
+  public void broadcast(JavaSparkContext jsc) {
+    this.bcast = jsc.broadcast(bytes);
+  }
 
-    @Override
-    public void broadcast(JavaSparkContext jsc) {
-      this.bcast = jsc.broadcast(CoderHelpers.toByteArray(value, coder));
-    }
+  public void unpersist() {
+    this.bcast.unpersist();
+  }
 
-    private T deserialize() {
-      T val;
-      try {
-        val = coder.decode(new ByteArrayInputStream(bcast.value()),
-            new Coder.Context(true));
-      } catch (IOException ioe) {
-        // this should not ever happen, log it if it does.
-        LOG.warn(ioe.getMessage());
-        val = null;
-      }
-      return val;
+  private T deserialize() {
+    T val;
+    try {
+      val = coder.decode(new ByteArrayInputStream(bcast.value()), new 
Coder.Context(true));
+    } catch (IOException ioe) {
+      // this should not ever happen, log it if it does.
+      LOG.warn(ioe.getMessage());
+      val = null;
     }
+    return val;
   }
 }
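
With this change BroadcastHelper always travels as a byte[] plus the Coder that can decode it, and it decodes at most once per JVM, on the first getValue(). Below is a Spark-free sketch of that lazy-decode idea; plain Java serialization stands in for the Beam Coder, and CodedValue is an illustrative name, not part of the patch.

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

class CodedValue<T extends Serializable> implements Serializable {

  private final byte[] bytes;   // what would be broadcast by Spark
  private transient T value;    // decoded copy, never serialized

  private CodedValue(byte[] bytes) {
    this.bytes = bytes;
  }

  /** Encode once on the driver; only the bytes are shipped. */
  static <T extends Serializable> CodedValue<T> encode(T value) throws IOException {
    ByteArrayOutputStream bos = new ByteArrayOutputStream();
    try (ObjectOutputStream oos = new ObjectOutputStream(bos)) {
      oos.writeObject(value);
    }
    return new CodedValue<>(bos.toByteArray());
  }

  /** Decode lazily, at most once per JVM, on first access. */
  synchronized T getValue() throws IOException, ClassNotFoundException {
    if (value == null) {
      try (ObjectInputStream ois =
          new ObjectInputStream(new ByteArrayInputStream(bytes))) {
        @SuppressWarnings("unchecked")
        T decoded = (T) ois.readObject();
        value = decoded;
      }
    }
    return value;
  }
}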

http://git-wip-us.apache.org/repos/asf/beam/blob/130c113e/runners/spark/src/main/java/org/apache/beam/runners/spark/util/SparkSideInputReader.java
----------------------------------------------------------------------
diff --git 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/SparkSideInputReader.java
 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/SparkSideInputReader.java
index 0a804ae..8167ee0 100644
--- 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/SparkSideInputReader.java
+++ 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/SparkSideInputReader.java
@@ -61,6 +61,7 @@ public class SparkSideInputReader implements SideInputReader {
     //--- match the appropriate sideInput window.
     // a tag will point to all matching sideInputs, that is all windows.
     // now that we've obtained the appropriate sideInputWindow, all that's 
left is to filter by it.
+    @SuppressWarnings("unchecked")
     Iterable<WindowedValue<?>> availableSideInputs =
         (Iterable<WindowedValue<?>>) 
windowedBroadcastHelper.getValue().getValue();
     Iterable<WindowedValue<?>> sideInputForWindow =

http://git-wip-us.apache.org/repos/asf/beam/blob/130c113e/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java
----------------------------------------------------------------------
diff --git 
a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java
 
b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java
index ab04c5c..352a7d8 100644
--- 
a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java
+++ 
b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/ResumeFromCheckpointStreamingTest.java
@@ -23,7 +23,9 @@ import static org.junit.Assert.assertThat;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.util.concurrent.Uninterruptibles;
 import java.io.IOException;
+import java.util.Arrays;
 import java.util.Collections;
+import java.util.List;
 import java.util.Map;
 import java.util.Properties;
 import java.util.concurrent.TimeUnit;
@@ -37,19 +39,23 @@ import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.io.kafka.KafkaIO;
 import org.apache.beam.sdk.transforms.Aggregator;
+import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.Sum;
+import org.apache.beam.sdk.transforms.View;
 import org.apache.beam.sdk.transforms.windowing.FixedWindows;
 import org.apache.beam.sdk.transforms.windowing.Window;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.kafka.clients.producer.KafkaProducer;
 import org.apache.kafka.clients.producer.ProducerRecord;
 import org.apache.kafka.common.serialization.Serializer;
 import org.apache.kafka.common.serialization.StringSerializer;
 import org.joda.time.Duration;
 import org.junit.AfterClass;
+import org.junit.Assert;
 import org.junit.BeforeClass;
 import org.junit.Rule;
 import org.junit.Test;
@@ -163,8 +169,21 @@ public class ResumeFromCheckpointStreamingTest {
     Duration windowDuration = new Duration(options.getBatchIntervalMillis());
 
     Pipeline p = Pipeline.create(options);
+
+    PCollection<String> expectedCol = 
p.apply(Create.of(EXPECTED).withCoder(StringUtf8Coder.of()));
+    final PCollectionView<List<String>> expectedView = 
expectedCol.apply(View.<String>asList());
+
     PCollection<String> formattedKV =
         p.apply(read.withoutMetadata())
+          .apply(ParDo.of(new DoFn<KV<String, String>, KV<String, String>>() {
+               @ProcessElement
+               public void process(ProcessContext c) {
+
+                  // Check side input is passed correctly
+                  Assert.assertEquals(c.sideInput(expectedView), 
Arrays.asList(EXPECTED));
+                  c.output(c.element());
+                }
+          }).withSideInputs(expectedView))
         .apply(Window.<KV<String, 
String>>into(FixedWindows.of(windowDuration)))
         .apply(ParDo.of(new FormatAsText()));
 
