Repository: beam
Updated Branches:
  refs/heads/master 80aebd902 -> a481d5611


[BEAM-2669] Fixed a Kryo serialization exception when a DStream is cached, by using
coders to move elements to bytes before Spark serializes the RDD as part of
caching it.
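
In short, the fix wraps the persist() call in a coder round-trip so that only
byte arrays ever reach Spark's serializer. A minimal sketch of the pattern
(illustrative only; "rdd" and "level" are placeholders for the values the
runner supplies, and the String coder is an arbitrary example):

    // Sketch of the caching path, assuming a JavaRDD<WindowedValue<String>> named
    // "rdd" and a Spark StorageLevel named "level" (placeholders, not part of the patch).
    Coder<WindowedValue<String>> windowedValueCoder =
        WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE);
    JavaRDD<WindowedValue<String>> cached =
        rdd.map(CoderHelpers.toByteFunction(windowedValueCoder))    // WindowedValue -> byte[]
           .persist(level)                                          // only byte[] is cached
           .map(CoderHelpers.fromByteFunction(windowedValueCoder)); // byte[] -> WindowedValue

    // MEMORY_ONLY and MEMORY_ONLY_2 keep deserialized objects on the heap, so the
    // round-trip is skipped for those levels to avoid the extra encoding cost.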


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/ffd08dae
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/ffd08dae
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/ffd08dae

Branch: refs/heads/master
Commit: ffd08dae0d1a6fcde438ae4e9c2a348eb2a5d493
Parents: 80aebd9
Author: ksalant <kobi.sal...@gmail.com>
Authored: Wed Aug 23 14:54:46 2017 +0300
Committer: Stas Levin <stasle...@apache.org>
Committed: Sun Sep 3 09:03:28 2017 +0300

----------------------------------------------------------------------
 .../SparkGroupAlsoByWindowViaWindowSet.java     | 15 ++--
 .../spark/translation/BoundedDataset.java       | 17 ++++-
 .../beam/runners/spark/translation/Dataset.java |  3 +-
 .../spark/translation/EvaluationContext.java    | 23 ++++--
 .../spark/translation/SparkContextFactory.java  |  2 -
 .../translation/StorageLevelPTransform.java     | 37 ----------
 .../spark/translation/TransformTranslator.java  | 53 +++++--------
 .../spark/translation/TranslationUtils.java     | 78 ++++++++++++++++++++
 .../streaming/StreamingTransformTranslator.java | 15 +++-
 .../translation/streaming/UnboundedDataset.java | 27 +++++--
 .../spark/translation/StorageLevelTest.java     | 75 -------------------
 11 files changed, 166 insertions(+), 179 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/ffd08dae/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java
index 52f7376..e6a55a6 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java
@@ -96,7 +96,7 @@ import scala.runtime.AbstractFunction1;
 * a (state, output) tuple is used, filtering the state (and output if no firing)
  * in the following steps.
  */
-public class SparkGroupAlsoByWindowViaWindowSet {
+public class SparkGroupAlsoByWindowViaWindowSet implements Serializable {
   private static final Logger LOG = LoggerFactory.getLogger(
       SparkGroupAlsoByWindowViaWindowSet.class);
 
@@ -226,8 +226,6 @@ public class SparkGroupAlsoByWindowViaWindowSet {
        final SystemReduceFn<K, InputT, Iterable<InputT>, Iterable<InputT>, W> reduceFn =
             SystemReduceFn.buffering(
                 ((FullWindowedValueCoder<InputT>) wvCoder).getValueCoder());
-        final OutputWindowedValueHolder<K, InputT> outputHolder =
-            new OutputWindowedValueHolder<>();
         // use in memory Aggregators since Spark Accumulators are not resilient
         // in stateful operators, once done with this partition.
        final MetricsContainerImpl cellProvider = new MetricsContainerImpl("cellProvider");
@@ -280,6 +278,9 @@ public class SparkGroupAlsoByWindowViaWindowSet {
                            SparkTimerInternals.deserializeTimers(serTimers, timerDataCoder));
                       }
 
+                      final OutputWindowedValueHolder<K, InputT> outputHolder =
+                          new OutputWindowedValueHolder<>();
+
                      ReduceFnRunner<K, InputT, Iterable<InputT>, W> reduceFnRunner =
                           new ReduceFnRunner<>(
                               key,
@@ -294,8 +295,6 @@ public class SparkGroupAlsoByWindowViaWindowSet {
                               reduceFn,
                               options.get());
 
-                      outputHolder.clear(); // clear before potential use.
-
                       if (!seq.isEmpty()) {
                         // new input for key.
                         try {
@@ -457,7 +456,7 @@ public class SparkGroupAlsoByWindowViaWindowSet {
         });
   }
 
-  private static class StateAndTimers {
+  private static class StateAndTimers implements Serializable {
     //Serializable state for internals (namespace to state tag to coded value).
     private final Table<String, String, byte[]> state;
     private final Collection<byte[]> serTimers;
@@ -494,10 +493,6 @@ public class SparkGroupAlsoByWindowViaWindowSet {
       return windowedValues;
     }
 
-    private void clear() {
-      windowedValues.clear();
-    }
-
     @Override
     public <AdditionalOutputT> void outputWindowedValue(
         TupleTag<AdditionalOutputT> tag,

http://git-wip-us.apache.org/repos/asf/beam/blob/ffd08dae/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/BoundedDataset.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/BoundedDataset.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/BoundedDataset.java
index 652c753..7c38348 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/BoundedDataset.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/BoundedDataset.java
@@ -98,9 +98,20 @@ public class BoundedDataset<T> implements Dataset {
   }
 
   @Override
-  public void cache(String storageLevel) {
-    // populate the rdd if needed
-    getRDD().persist(StorageLevel.fromString(storageLevel));
+  @SuppressWarnings("unchecked")
+  public void cache(String storageLevel, Coder<?> coder) {
+    StorageLevel level = StorageLevel.fromString(storageLevel);
+    if (TranslationUtils.avoidRddSerialization(level)) {
+      // if it is memory-only, skip the overhead of moving to bytes
+      this.rdd = getRDD().persist(level);
+    } else {
+      // Caching may trigger serialization, so we encode to bytes first;
+      // more details in https://issues.apache.org/jira/browse/BEAM-2669
+      Coder<WindowedValue<T>> windowedValueCoder = (Coder<WindowedValue<T>>) coder;
+      this.rdd = getRDD().map(CoderHelpers.toByteFunction(windowedValueCoder))
+          .persist(level)
+          .map(CoderHelpers.fromByteFunction(windowedValueCoder));
+    }
   }
 
   @Override

http://git-wip-us.apache.org/repos/asf/beam/blob/ffd08dae/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/Dataset.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/Dataset.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/Dataset.java
index b5d550e..b361756 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/Dataset.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/Dataset.java
@@ -19,6 +19,7 @@
 package org.apache.beam.runners.spark.translation;
 
 import java.io.Serializable;
+import org.apache.beam.sdk.coders.Coder;
 
 
 /**
@@ -26,7 +27,7 @@ import java.io.Serializable;
  */
 public interface Dataset extends Serializable {
 
-  void cache(String storageLevel);
+  void cache(String storageLevel, Coder<?> coder);
 
   void action();
 

http://git-wip-us.apache.org/repos/asf/beam/blob/ffd08dae/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java
index 463e507..10526f9 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/EvaluationContext.java
@@ -35,6 +35,7 @@ import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.runners.AppliedPTransform;
 import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionView;
@@ -137,18 +138,30 @@ public class EvaluationContext {
     return false;
   }
 
+  public void putDataset(PTransform<?, ? extends PValue> transform, Dataset dataset,
+      boolean forceCache) {
+    putDataset(getOutput(transform), dataset, forceCache);
+  }
+
+
   public void putDataset(PTransform<?, ? extends PValue> transform, Dataset dataset) {
-    putDataset(getOutput(transform), dataset);
+    putDataset(transform, dataset,  false);
   }
 
-  public void putDataset(PValue pvalue, Dataset dataset) {
+  public void putDataset(PValue pvalue, Dataset dataset, boolean forceCache) {
     try {
       dataset.setName(pvalue.getName());
     } catch (IllegalStateException e) {
       // name not set, ignore
     }
-    if (shouldCache(pvalue)) {
-      dataset.cache(storageLevel());
+    if (forceCache || shouldCache(pvalue)) {
+      // we cache only PCollection
+      if (pvalue instanceof PCollection) {
+        Coder<?> coder = ((PCollection<?>) pvalue).getCoder();
+        Coder<? extends BoundedWindow> wCoder =
+            ((PCollection<?>) pvalue).getWindowingStrategy().getWindowFn().windowCoder();
+        dataset.cache(storageLevel(), WindowedValue.getFullCoder(coder, wCoder));
+      }
     }
     datasets.put(pvalue, dataset);
     leaves.add(dataset);
@@ -254,7 +267,7 @@ public class EvaluationContext {
     return boundedDataset.getValues(pcollection);
   }
 
-  private String storageLevel() {
+  public String storageLevel() {
     return serializableOptions.get().as(SparkPipelineOptions.class).getStorageLevel();
   }
 

http://git-wip-us.apache.org/repos/asf/beam/blob/ffd08dae/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkContextFactory.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkContextFactory.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkContextFactory.java
index cdeddad..0132de3 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkContextFactory.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkContextFactory.java
@@ -23,7 +23,6 @@ import org.apache.beam.runners.spark.SparkPipelineOptions;
 import org.apache.beam.runners.spark.coders.BeamSparkRunnerRegistrator;
 import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.serializer.KryoSerializer;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -96,7 +95,6 @@ public final class SparkContextFactory {
       conf.setAppName(contextOptions.getAppName());
       // register immutable collections serializers because the SDK uses them.
       conf.set("spark.kryo.registrator", BeamSparkRunnerRegistrator.class.getName());
-      conf.set("spark.serializer", KryoSerializer.class.getName());
       return new JavaSparkContext(conf);
     }
   }
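
Note: with the forced spark.serializer setting removed, cached data goes through
Beam coders instead of a globally forced Kryo serializer, and Spark falls back to
its own default serializer unless the user overrides it. Users who still want Kryo
for Spark's internal serialization can opt back in themselves; a sketch using
standard Spark configuration (not part of this patch, app name is a placeholder):

    SparkConf conf = new SparkConf()
        .setAppName("my-beam-app")  // placeholder application name
        // keep the Beam registrator for the SDK's immutable-collection serializers
        .set("spark.kryo.registrator", BeamSparkRunnerRegistrator.class.getName())
        // explicitly opt back into Kryo, if desired
        .set("spark.serializer", KryoSerializer.class.getName());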

http://git-wip-us.apache.org/repos/asf/beam/blob/ffd08dae/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/StorageLevelPTransform.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/StorageLevelPTransform.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/StorageLevelPTransform.java
deleted file mode 100644
index b236ce7..0000000
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/StorageLevelPTransform.java
+++ /dev/null
@@ -1,37 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.translation;
-
-import org.apache.beam.sdk.coders.StringUtf8Coder;
-import org.apache.beam.sdk.transforms.PTransform;
-import org.apache.beam.sdk.values.PCollection;
-import org.apache.beam.sdk.values.WindowingStrategy;
-
-/**
- * Get RDD storage level for the input PCollection (mostly used for testing purpose).
- */
-public final class StorageLevelPTransform extends PTransform<PCollection<?>, PCollection<String>> {
-
-  @Override
-  public PCollection<String> expand(PCollection<?> input) {
-    return PCollection.createPrimitiveOutputInternal(input.getPipeline(),
-        WindowingStrategy.globalDefault(),
-        PCollection.IsBounded.BOUNDED,
-        StringUtf8Coder.of());
-  }
-}

http://git-wip-us.apache.org/repos/asf/beam/blob/ffd08dae/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
index e060e1d..7cb8628 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
@@ -20,6 +20,7 @@ package org.apache.beam.runners.spark.translation;
 
 import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.base.Preconditions.checkState;
+import static org.apache.beam.runners.spark.translation.TranslationUtils.avoidRddSerialization;
 import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectSplittable;
 
 import com.google.common.base.Optional;
@@ -27,7 +28,6 @@ import com.google.common.collect.FluentIterable;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import java.util.Collection;
-import java.util.Collections;
 import java.util.Iterator;
 import java.util.Map;
 import org.apache.beam.runners.core.SystemReduceFn;
@@ -41,7 +41,6 @@ import org.apache.beam.runners.spark.util.SideInputBroadcast;
 import org.apache.beam.sdk.coders.CannotProvideCoderException;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.KvCoder;
-import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.io.Read;
 import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.CombineWithContext;
@@ -71,7 +70,7 @@ import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
-
+import org.apache.spark.storage.StorageLevel;
 
 /**
 * Supports translation between a Beam transform, and Spark's operations on RDDs.
@@ -393,8 +392,20 @@ public final class TransformTranslator {
 
         Map<TupleTag<?>, PValue> outputs = context.getOutputs(transform);
         if (outputs.size() > 1) {
-          // cache the RDD if we're going to filter it more than once.
-          all.cache();
+          StorageLevel level = StorageLevel.fromString(context.storageLevel());
+          if (avoidRddSerialization(level)) {
+            // if it is memory-only, skip the overhead of moving to bytes
+            all = all.persist(level);
+          } else {
+            // Caching may trigger serialization, so we encode to bytes first;
+            // more details in https://issues.apache.org/jira/browse/BEAM-2669
+            Map<TupleTag<?>, Coder<WindowedValue<?>>> coderMap =
+                TranslationUtils.getTupleTagCoders(outputs);
+            all = all
+                .mapToPair(TranslationUtils.getTupleTagEncodeFunction(coderMap))
+                .persist(level)
+                .mapToPair(TranslationUtils.getTupleTagDecodeFunction(coderMap));
+          }
         }
         for (Map.Entry<TupleTag<?>, PValue> output : outputs.entrySet()) {
           JavaPairRDD<TupleTag<?>, WindowedValue<?>> filtered =
@@ -402,7 +413,7 @@ public final class TransformTranslator {
          // Object is the best we can do since different outputs can have different tags
           JavaRDD<WindowedValue<Object>> values =
               (JavaRDD<WindowedValue<Object>>) (JavaRDD<?>) filtered.values();
-          context.putDataset(output.getValue(), new BoundedDataset<>(values));
+          context.putDataset(output.getValue(), new BoundedDataset<>(values), false);
         }
       }
 
@@ -456,7 +467,7 @@ public final class TransformTranslator {
                    jsc.sc(), transform.getSource(), context.getSerializableOptions(), stepName)
                 .toJavaRDD();
        // cache to avoid re-evaluation of the source by Spark's lazy DAG evaluation.
-        context.putDataset(transform, new BoundedDataset<>(input.cache()));
+        context.putDataset(transform, new BoundedDataset<>(input), true);
       }
 
       @Override
@@ -531,32 +542,6 @@ public final class TransformTranslator {
     };
   }
 
-  private static TransformEvaluator<StorageLevelPTransform> storageLevel() {
-    return new TransformEvaluator<StorageLevelPTransform>() {
-      @Override
-      public void evaluate(StorageLevelPTransform transform, EvaluationContext context) {
-        JavaRDD rdd = ((BoundedDataset) (context).borrowDataset(transform)).getRDD();
-        JavaSparkContext javaSparkContext = context.getSparkContext();
-
-        WindowedValue.ValueOnlyWindowedValueCoder<String> windowCoder =
-            WindowedValue.getValueOnlyCoder(StringUtf8Coder.of());
-        JavaRDD output =
-            javaSparkContext.parallelize(
-                CoderHelpers.toByteArrays(
-                    Collections.singletonList(rdd.getStorageLevel().description()),
-                    StringUtf8Coder.of()))
-            .map(CoderHelpers.fromByteFunction(windowCoder));
-
-        context.putDataset(transform, new BoundedDataset<String>(output));
-      }
-
-      @Override
-      public String toNativeString() {
-        return "sparkContext.parallelize(rdd.getStorageLevel().description())";
-      }
-    };
-  }
-
  private static <K, V, W extends BoundedWindow> TransformEvaluator<Reshuffle<K, V>> reshuffle() {
    return new TransformEvaluator<Reshuffle<K, V>>() {
      @Override public void evaluate(Reshuffle<K, V> transform, EvaluationContext context) {
@@ -605,8 +590,6 @@ public final class TransformTranslator {
     EVALUATORS.put(View.CreatePCollectionView.class, createPCollView());
     EVALUATORS.put(Window.Assign.class, window());
     EVALUATORS.put(Reshuffle.class, reshuffle());
-    // mostly test evaluators
-    EVALUATORS.put(StorageLevelPTransform.class, storageLevel());
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/beam/blob/ffd08dae/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java
index 993062c..90f5ee3 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java
@@ -22,6 +22,7 @@ import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Iterators;
 import com.google.common.collect.Maps;
 import java.io.Serializable;
+import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
@@ -29,7 +30,9 @@ import org.apache.beam.runners.core.InMemoryStateInternals;
 import org.apache.beam.runners.core.StateInternals;
 import org.apache.beam.runners.core.StateInternalsFactory;
 import org.apache.beam.runners.spark.SparkRunner;
+import org.apache.beam.runners.spark.coders.CoderHelpers;
 import org.apache.beam.runners.spark.util.SideInputBroadcast;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
@@ -39,7 +42,9 @@ import org.apache.beam.sdk.transforms.windowing.Window;
 import org.apache.beam.sdk.transforms.windowing.WindowFn;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.PValue;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.WindowingStrategy;
 import org.apache.spark.api.java.JavaSparkContext;
@@ -48,6 +53,7 @@ import org.apache.spark.api.java.function.Function;
 import org.apache.spark.api.java.function.PairFlatMapFunction;
 import org.apache.spark.api.java.function.PairFunction;
 import org.apache.spark.api.java.function.VoidFunction;
+import org.apache.spark.storage.StorageLevel;
 import org.apache.spark.streaming.api.java.JavaDStream;
 import org.apache.spark.streaming.api.java.JavaPairDStream;
 import scala.Tuple2;
@@ -413,4 +419,76 @@ public final class TranslationUtils {
       }
     };
   }
+
+  /**
+   * Utility to get mapping between TupleTag and a coder.
+   * @param outputs - A map of tuple tags and pcollections
+   * @return mapping between TupleTag and a coder
+   */
+  public static Map<TupleTag<?>, Coder<WindowedValue<?>>> getTupleTagCoders(
+      Map<TupleTag<?>, PValue> outputs) {
+    Map<TupleTag<?>, Coder<WindowedValue<?>>> coderMap = new HashMap<>(outputs.size());
+
+    for (Map.Entry<TupleTag<?>, PValue> output : outputs.entrySet()) {
+      // we can cast each output to PCollection as all of them are of the same type.
+      PCollection<?> pCollection = (PCollection<?>) output.getValue();
+      Coder<?> coder = pCollection.getCoder();
+      Coder<? extends BoundedWindow> wCoder =
+          pCollection.getWindowingStrategy().getWindowFn().windowCoder();
+      @SuppressWarnings("unchecked")
+      Coder<WindowedValue<?>> windowedValueCoder =
+          (Coder<WindowedValue<?>>) (Coder<?>) WindowedValue.getFullCoder(coder, wCoder);
+      coderMap.put(output.getKey(), windowedValueCoder);
+    }
+    return coderMap;
+  }
+
+  /**
+   * Returns a pair function to convert value to bytes via coder.
+   * @param coderMap - mapping between TupleTag and a coder
+   * @return a pair function to convert value to bytes via coder
+   */
+  public static PairFunction<Tuple2<TupleTag<?>, WindowedValue<?>>, TupleTag<?>, byte[]>
+      getTupleTagEncodeFunction(final Map<TupleTag<?>, Coder<WindowedValue<?>>> coderMap) {
+    return new PairFunction<Tuple2<TupleTag<?>, WindowedValue<?>>, TupleTag<?>, byte[]>() {
+
+      @Override public Tuple2<TupleTag<?>, byte[]>
+      call(Tuple2<TupleTag<?>, WindowedValue<?>> tuple2) throws Exception {
+        TupleTag<?> tupleTag = tuple2._1;
+        WindowedValue<?> windowedValue = tuple2._2;
+        return new Tuple2<TupleTag<?>, byte[]>
+            (tupleTag, CoderHelpers.toByteArray(windowedValue, coderMap.get(tupleTag)));
+      }
+    };
+  }
+
+  /**
+   * Returns a pair function to convert bytes to value via coder.
+   * @param coderMap - mapping between TupleTag and a coder
+   * @return a pair function to convert bytes to value via coder
+   */
+  public static PairFunction<Tuple2<TupleTag<?>, byte[]>, TupleTag<?>, WindowedValue<?>>
+      getTupleTagDecodeFunction(final Map<TupleTag<?>, Coder<WindowedValue<?>>> coderMap) {
+    return new PairFunction<Tuple2<TupleTag<?>, byte[]>, TupleTag<?>, WindowedValue<?>>() {
+
+      @Override public Tuple2<TupleTag<?>, WindowedValue<?>>
+      call(Tuple2<TupleTag<?>, byte[]> tuple2) throws Exception {
+        TupleTag<?> tupleTag = tuple2._1;
+        byte[] windowedByteValue = tuple2._2;
+        return new Tuple2<TupleTag<?>, WindowedValue<?>>
+            (tupleTag, CoderHelpers.fromByteArray(windowedByteValue, coderMap.get(tupleTag)));
+      }
+    };
+  }
+
+  /**
+   * Checks whether RDD serialization can be avoided for the given storage level.
+   * Relevant to RDDs only; DStreams are always memory-serialized in Spark.
+   * @param level the requested StorageLevel
+   * @return true if the level keeps deserialized objects in memory only
+   */
+  public static boolean avoidRddSerialization(StorageLevel level) {
+    return level.equals(StorageLevel.MEMORY_ONLY()) || level.equals(StorageLevel.MEMORY_ONLY_2());
+  }
+
+
 }
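
For intuition, the byte round-trip these helpers build on can be sketched in
isolation (illustrative only, not part of the patch; the String coder and value
below are arbitrary placeholders):

    // Encode a windowed value to bytes with a Beam coder and decode it back;
    // caching the byte[] form keeps Spark's serializer away from arbitrary user types.
    Coder<WindowedValue<String>> wvCoder =
        WindowedValue.getValueOnlyCoder(StringUtf8Coder.of());
    WindowedValue<String> original = WindowedValue.valueInGlobalWindow("hello");
    byte[] bytes = CoderHelpers.toByteArray(original, wvCoder);
    WindowedValue<String> decoded = CoderHelpers.fromByteArray(bytes, wvCoder);
    // decoded.getValue() is "hello" again; the full coder used in the caching path
    // (WindowedValue.getFullCoder) also preserves windows and timestamps.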

http://git-wip-us.apache.org/repos/asf/beam/blob/ffd08dae/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
index 4114803..ea26007 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
@@ -442,11 +442,19 @@ public final class StreamingTransformTranslator {
                             false));
                   }
                 });
+
         Map<TupleTag<?>, PValue> outputs = context.getOutputs(transform);
         if (outputs.size() > 1) {
-          // cache the DStream if we're going to filter it more than once.
-          all.cache();
+          // Caching may trigger serialization, so we encode to bytes first;
+          // more details in https://issues.apache.org/jira/browse/BEAM-2669
+          Map<TupleTag<?>, Coder<WindowedValue<?>>> coderMap =
+              TranslationUtils.getTupleTagCoders(outputs);
+          all = all
+              .mapToPair(TranslationUtils.getTupleTagEncodeFunction(coderMap))
+              .cache()
+              .mapToPair(TranslationUtils.getTupleTagDecodeFunction(coderMap));
         }
+
         for (Map.Entry<TupleTag<?>, PValue> output : outputs.entrySet()) {
           @SuppressWarnings("unchecked")
           JavaPairDStream<TupleTag<?>, WindowedValue<?>> filtered =
@@ -458,7 +466,8 @@ public final class StreamingTransformTranslator {
                   (JavaDStream<?>) TranslationUtils.dStreamValues(filtered);
           context.putDataset(
               output.getValue(),
-              new UnboundedDataset<>(values, unboundedDataset.getStreamSources()));
+              new UnboundedDataset<>(values, unboundedDataset.getStreamSources()),
+              false);
         }
       }
 

http://git-wip-us.apache.org/repos/asf/beam/blob/ffd08dae/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/UnboundedDataset.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/UnboundedDataset.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/UnboundedDataset.java
index ccdaf11..df927af 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/UnboundedDataset.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/UnboundedDataset.java
@@ -20,11 +20,15 @@ package org.apache.beam.runners.spark.translation.streaming;
 
 import java.util.ArrayList;
 import java.util.List;
+
+import org.apache.beam.runners.spark.coders.CoderHelpers;
 import org.apache.beam.runners.spark.translation.Dataset;
 import org.apache.beam.runners.spark.translation.TranslationUtils;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.function.VoidFunction;
+import org.apache.spark.storage.StorageLevel;
 import org.apache.spark.streaming.api.java.JavaDStream;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -37,7 +41,7 @@ public class UnboundedDataset<T> implements Dataset {
 
  private static final Logger LOG = LoggerFactory.getLogger(UnboundedDataset.class);
 
-  private final JavaDStream<WindowedValue<T>> dStream;
+  private JavaDStream<WindowedValue<T>> dStream;
   // points to the input streams that created this UnboundedDataset,
   // should be greater > 1 in case of Flatten for example.
  // when using GlobalWatermarkHolder this information helps to take only the relevant watermarks
@@ -57,15 +61,22 @@ public class UnboundedDataset<T> implements Dataset {
     return streamSources;
   }
 
-  public void cache() {
-    dStream.cache();
-  }
-
   @Override
-  public void cache(String storageLevel) {
+  @SuppressWarnings("unchecked")
+  public void cache(String storageLevel, Coder<?> coder) {
     // we "force" MEMORY storage level in streaming
-    LOG.warn("Provided StorageLevel ignored for stream, using default level");
-    cache();
+    if (!StorageLevel.fromString(storageLevel).equals(StorageLevel.MEMORY_ONLY_SER())) {
+      LOG.warn("Provided StorageLevel: {} is ignored for streams, using the default level: {}",
+          storageLevel,
+          StorageLevel.MEMORY_ONLY_SER());
+    }
+    // Caching may trigger serialization, so we encode to bytes first;
+    // more details in https://issues.apache.org/jira/browse/BEAM-2669
+    Coder<WindowedValue<T>> wc = (Coder<WindowedValue<T>>) coder;
+    this.dStream = dStream.map(CoderHelpers.toByteFunction(wc))
+        .cache()
+        .map(CoderHelpers.fromByteFunction(wc));
+
   }
 
   @Override

http://git-wip-us.apache.org/repos/asf/beam/blob/ffd08dae/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/StorageLevelTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/StorageLevelTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/StorageLevelTest.java
deleted file mode 100644
index 8bd6dae..0000000
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/StorageLevelTest.java
+++ /dev/null
@@ -1,75 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.beam.runners.spark.translation;
-
-import org.apache.beam.sdk.testing.PAssert;
-import org.apache.beam.sdk.testing.TestPipeline;
-import org.apache.beam.sdk.transforms.Count;
-import org.apache.beam.sdk.transforms.Create;
-import org.apache.beam.sdk.values.PCollection;
-import org.junit.AfterClass;
-import org.junit.BeforeClass;
-import org.junit.Rule;
-import org.junit.Test;
-
-
-/**
- * Test the RDD storage level defined by user.
- */
-public class StorageLevelTest {
-
-  private static String beamTestPipelineOptions;
-
-  @Rule
-  public final TestPipeline pipeline = TestPipeline.create();
-
-  @BeforeClass
-  public static void init() {
-    beamTestPipelineOptions =
-        System.getProperty(TestPipeline.PROPERTY_BEAM_TEST_PIPELINE_OPTIONS);
-
-    System.setProperty(
-        TestPipeline.PROPERTY_BEAM_TEST_PIPELINE_OPTIONS,
-        beamTestPipelineOptions.replace("]", ", \"--storageLevel=DISK_ONLY\"]"));
-  }
-
-  @AfterClass
-  public static void teardown() {
-    System.setProperty(
-        TestPipeline.PROPERTY_BEAM_TEST_PIPELINE_OPTIONS,
-        beamTestPipelineOptions);
-  }
-
-  @Test
-  public void test() throws Exception {
-    PCollection<String> pCollection = pipeline.apply("CreateFoo", Create.of("foo"));
-
-    // by default, the Spark runner doesn't cache the RDD if it accessed only one time.
-    // So, to "force" the caching of the RDD, we have to call the RDD at least two time.
-    // That's why we are using Count fn on the PCollection.
-    pCollection.apply("CountAll", Count.<String>globally());
-
-    PCollection<String> output = pCollection.apply(new StorageLevelPTransform());
-
-    PAssert.thatSingleton(output).isEqualTo("Disk Serialized 1x Replicated");
-
-    pipeline.run();
-  }
-
-}
