Refactor translators according to the new GroupAlsoByWindow implementation for the
Spark runner.


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/96abe4f0
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/96abe4f0
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/96abe4f0

Branch: refs/heads/master
Commit: 96abe4f08be12ac10dac39b55b8f8319a227b1ea
Parents: 8c37970
Author: Sela <ans...@paypal.com>
Authored: Fri Feb 17 01:19:23 2017 +0200
Committer: Sela <ans...@paypal.com>
Committed: Wed Mar 1 00:17:59 2017 +0200

----------------------------------------------------------------------
 .../SparkGroupAlsoByWindowViaWindowSet.java     |  14 +-
 .../spark/stateful/SparkStateInternals.java     |   4 +-
 .../spark/stateful/SparkTimerInternals.java     |   6 +-
 .../translation/GroupCombineFunctions.java      | 237 ++++++------------
 .../spark/translation/TransformTranslator.java  | 238 +++++++++++++------
 .../spark/translation/TranslationUtils.java     |  22 +-
 .../streaming/StreamingTransformTranslator.java | 163 ++++---------
 .../translation/streaming/UnboundedDataset.java |  12 +-
 .../beam/runners/spark/util/LateDataUtils.java  |   2 +-
 .../spark/util/UnsupportedSideInputReader.java  |   2 +-
 .../streaming/TrackStreamingSourcesTest.java    |   2 +-
 11 files changed, 314 insertions(+), 388 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/96abe4f0/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java
----------------------------------------------------------------------
diff --git 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java
 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java
index 7902d7c..2fb4100 100644
--- 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java
+++ 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java
@@ -17,8 +17,8 @@
  */
 package org.apache.beam.runners.spark.stateful;
 
-import com.google.common.collect.Table;
 import com.google.common.collect.AbstractIterator;
+import com.google.common.collect.Table;
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Collection;
@@ -84,7 +84,8 @@ import scala.runtime.AbstractFunction1;
  * in the following steps.
  */
 public class SparkGroupAlsoByWindowViaWindowSet {
-  private static final Logger LOG = 
LoggerFactory.getLogger(SparkGroupAlsoByWindowViaWindowSet.class);
+  private static final Logger LOG = LoggerFactory.getLogger(
+      SparkGroupAlsoByWindowViaWindowSet.class);
 
   /**
    * A helper class that is essentially a {@link Serializable} {@link 
AbstractFunction1}.
@@ -101,7 +102,7 @@ public class SparkGroupAlsoByWindowViaWindowSet {
           final SparkRuntimeContext runtimeContext,
           final List<Integer> sourceIds) {
 
-    Long checkpointDuration =
+    long checkpointDurationMillis =
         runtimeContext.getPipelineOptions().as(SparkPipelineOptions.class)
             .getCheckpointDurationMillis();
 
@@ -271,8 +272,11 @@ public class SparkGroupAlsoByWindowViaWindowSet {
         return scala.collection.JavaConversions.asScalaIterator(outIter);
       }
     }, partitioner, true, JavaSparkContext$.MODULE$.<Tuple2<StateAndTimers,
-        List<WindowedValue<KV<K, Iterable<InputT>>>>>>fakeClassTag())
-            .checkpoint(new Duration(checkpointDuration));
+        List<WindowedValue<KV<K, Iterable<InputT>>>>>>fakeClassTag());
+
+    if (checkpointDurationMillis > 0) {
+      firedStream.checkpoint(new Duration(checkpointDurationMillis));
+    }
 
     // go back to Java now.
     JavaPairDStream<K, Tuple2<StateAndTimers, List<WindowedValue<KV<K, 
Iterable<InputT>>>>>>

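Side note for readers of the hunk above (not part of the diff): checkpointing of the
fired stream is now applied only when a positive checkpoint duration is configured.
A minimal sketch of that guard, using a hypothetical checkpointIfEnabled helper:

import org.apache.spark.streaming.Duration;
import org.apache.spark.streaming.api.java.JavaDStream;

// Illustrative helper only; mirrors the guard added in the hunk above.
final class CheckpointingSketch {
  // Checkpoint the stream only when a positive interval was configured;
  // a non-positive value leaves the stream untouched.
  static <T> JavaDStream<T> checkpointIfEnabled(JavaDStream<T> stream, long intervalMillis) {
    if (intervalMillis > 0) {
      stream.checkpoint(new Duration(intervalMillis));
    }
    return stream;
  }
}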
http://git-wip-us.apache.org/repos/asf/beam/blob/96abe4f0/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java
----------------------------------------------------------------------
diff --git 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java
 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java
index e628d31..93b1f63 100644
--- 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java
+++ 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java
@@ -22,11 +22,11 @@ import com.google.common.collect.Table;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
-import org.apache.beam.runners.spark.coders.CoderHelpers;
 import org.apache.beam.runners.core.StateInternals;
 import org.apache.beam.runners.core.StateNamespace;
 import org.apache.beam.runners.core.StateTag;
 import org.apache.beam.runners.core.StateTag.StateBinder;
+import org.apache.beam.runners.spark.coders.CoderHelpers;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.InstantCoder;
 import org.apache.beam.sdk.coders.ListCoder;
@@ -399,4 +399,4 @@ class SparkStateInternals<K> implements StateInternals<K> {
       };
     }
   }
-}
\ No newline at end of file
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/96abe4f0/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java
----------------------------------------------------------------------
diff --git 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java
 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java
index 65225c5..4072240 100644
--- 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java
+++ 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java
@@ -27,9 +27,9 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import javax.annotation.Nullable;
-import 
org.apache.beam.runners.spark.util.GlobalWatermarkHolder.SparkWatermarks;
 import org.apache.beam.runners.core.StateNamespace;
 import org.apache.beam.runners.core.TimerInternals;
+import 
org.apache.beam.runners.spark.util.GlobalWatermarkHolder.SparkWatermarks;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.util.TimeDomain;
 import org.apache.spark.broadcast.Broadcast;
@@ -145,7 +145,7 @@ class SparkTimerInternals implements TimerInternals {
     return inputWatermark;
   }
 
-  /** Advances the watermark - since */
+  /** Advances the watermark. */
   public void advanceWatermark() {
     inputWatermark = highWatermark;
   }
@@ -170,4 +170,4 @@ class SparkTimerInternals implements TimerInternals {
     throw new UnsupportedOperationException("Deleting a timer by ID is not yet 
supported.");
   }
 
-}
\ No newline at end of file
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/96abe4f0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java
----------------------------------------------------------------------
diff --git 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java
 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java
index 8a41b4e..1e879ce 100644
--- 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java
+++ 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java
@@ -18,32 +18,21 @@
 
 package org.apache.beam.runners.spark.translation;
 
-import com.google.common.collect.Lists;
-import java.util.Collections;
-import java.util.Map;
-import org.apache.beam.runners.core.SystemReduceFn;
-import org.apache.beam.runners.spark.aggregators.NamedAggregators;
+import static com.google.common.base.Preconditions.checkArgument;
+
 import org.apache.beam.runners.spark.coders.CoderHelpers;
 import org.apache.beam.runners.spark.util.ByteArray;
-import org.apache.beam.runners.spark.util.SideInputBroadcast;
-import org.apache.beam.sdk.coders.CannotProvideCoderException;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.IterableCoder;
 import org.apache.beam.sdk.coders.KvCoder;
-import org.apache.beam.sdk.transforms.CombineWithContext;
-import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder;
 import org.apache.beam.sdk.util.WindowingStrategy;
 import org.apache.beam.sdk.values.KV;
-import org.apache.beam.sdk.values.TupleTag;
-import org.apache.spark.Accumulator;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.api.java.function.Function2;
-import org.apache.spark.api.java.function.PairFlatMapFunction;
-import scala.Tuple2;
 
 
 
@@ -53,113 +42,71 @@ import scala.Tuple2;
 public class GroupCombineFunctions {
 
   /**
-   * Apply {@link org.apache.beam.sdk.transforms.GroupByKey} to a Spark RDD.
+   * An implementation of
+   * {@link 
org.apache.beam.runners.core.GroupByKeyViaGroupByKeyOnly.GroupByKeyOnly}
+   * for the Spark runner.
    */
-  public static <K, V, W extends BoundedWindow> JavaRDD<WindowedValue<KV<K,
-      Iterable<V>>>> groupByKey(JavaRDD<WindowedValue<KV<K, V>>> rdd,
-                                Accumulator<NamedAggregators> accum,
-                                KvCoder<K, V> coder,
-                                SparkRuntimeContext runtimeContext,
-                                WindowingStrategy<?, W> windowingStrategy) {
-    //--- coders.
-    final Coder<K> keyCoder = coder.getKeyCoder();
-    final Coder<V> valueCoder = coder.getValueCoder();
-    final WindowedValue.WindowedValueCoder<V> wvCoder = 
WindowedValue.FullWindowedValueCoder.of(
-        valueCoder, windowingStrategy.getWindowFn().windowCoder());
+  public static <K, V> JavaRDD<WindowedValue<KV<K, 
Iterable<WindowedValue<V>>>>> groupByKeyOnly(
+      JavaRDD<WindowedValue<KV<K, V>>> rdd,
+      Coder<K> keyCoder,
+      WindowedValueCoder<V> wvCoder) {
 
-    //--- groupByKey.
     // Use coders to convert objects in the PCollection to byte arrays, so they
     // can be transferred over the network for the shuffle.
-    JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>> groupedByKey =
-        rdd.map(new ReifyTimestampsAndWindowsFunction<K, V>())
-            .map(WindowingHelpers.<KV<K, WindowedValue<V>>>unwindowFunction())
-            .mapToPair(TranslationUtils.<K, WindowedValue<V>>toPairFunction())
-            .mapToPair(CoderHelpers.toByteFunction(keyCoder, wvCoder))
-            .groupByKey()
-            .mapToPair(CoderHelpers.fromByteFunctionIterable(keyCoder, 
wvCoder))
-            // empty windows are OK here, see GroupByKey#evaluateHelper in the 
SDK
-            .map(TranslationUtils.<K, 
Iterable<WindowedValue<V>>>fromPairFunction())
-            .map(WindowingHelpers.<KV<K, 
Iterable<WindowedValue<V>>>>windowFunction());
-
-    //--- now group also by window.
-    // GroupAlsoByWindow currently uses a dummy in-memory StateInternals
-    return groupedByKey.flatMap(
-        new SparkGroupAlsoByWindowViaOutputBufferFn<>(
-            windowingStrategy,
-            new TranslationUtils.InMemoryStateInternalsFactory<K>(),
-            SystemReduceFn.<K, V, W>buffering(valueCoder),
-            runtimeContext,
-            accum));
+    return rdd
+        .map(new ReifyTimestampsAndWindowsFunction<K, V>())
+        .map(WindowingHelpers.<KV<K, WindowedValue<V>>>unwindowFunction())
+        .mapToPair(TranslationUtils.<K, WindowedValue<V>>toPairFunction())
+        .mapToPair(CoderHelpers.toByteFunction(keyCoder, wvCoder))
+        .groupByKey()
+        .mapToPair(CoderHelpers.fromByteFunctionIterable(keyCoder, wvCoder))
+        .map(TranslationUtils.<K, 
Iterable<WindowedValue<V>>>fromPairFunction())
+        .map(WindowingHelpers.<KV<K, 
Iterable<WindowedValue<V>>>>windowFunction());
   }
 
   /**
    * Apply a composite {@link org.apache.beam.sdk.transforms.Combine.Globally} 
transformation.
    */
-  public static <InputT, AccumT, OutputT> JavaRDD<WindowedValue<OutputT>>
-  combineGlobally(JavaRDD<WindowedValue<InputT>> rdd,
-                  final CombineWithContext.CombineFnWithContext<InputT, 
AccumT, OutputT> combineFn,
-                  final Coder<InputT> iCoder,
-                  final Coder<OutputT> oCoder,
-                  final SparkRuntimeContext runtimeContext,
-                  final WindowingStrategy<?, ?> windowingStrategy,
-                  final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, 
SideInputBroadcast<?>>>
-                      sideInputs,
-                  boolean hasDefault) {
-    // handle empty input RDD, which will natively skip the entire execution 
as Spark will not
-    // run on empty RDDs.
-    if (rdd.isEmpty()) {
-      JavaSparkContext jsc = new JavaSparkContext(rdd.context());
-      if (hasDefault) {
-        OutputT defaultValue = combineFn.defaultValue();
-        return jsc
-            
.parallelize(Lists.newArrayList(CoderHelpers.toByteArray(defaultValue, oCoder)))
-            .map(CoderHelpers.fromByteFunction(oCoder))
-            .map(WindowingHelpers.<OutputT>windowFunction());
-      } else {
-        return jsc.emptyRDD();
-      }
-    }
-
-    //--- coders.
-    final Coder<AccumT> aCoder;
-    try {
-      aCoder = 
combineFn.getAccumulatorCoder(runtimeContext.getCoderRegistry(), iCoder);
-    } catch (CannotProvideCoderException e) {
-      throw new IllegalStateException("Could not determine coder for 
accumulator", e);
-    }
-    // windowed coders.
+  public static <InputT, AccumT> Iterable<WindowedValue<AccumT>> 
combineGlobally(
+      JavaRDD<WindowedValue<InputT>> rdd,
+      final SparkGlobalCombineFn<InputT, AccumT, ?> sparkCombineFn,
+      final Coder<InputT> iCoder,
+      final Coder<AccumT> aCoder,
+      final WindowingStrategy<?, ?> windowingStrategy) {
+    checkArgument(!rdd.isEmpty(), "CombineGlobally computation should be 
skipped for empty RDDs.");
+
+    // coders.
     final WindowedValue.FullWindowedValueCoder<InputT> wviCoder =
         WindowedValue.FullWindowedValueCoder.of(iCoder,
             windowingStrategy.getWindowFn().windowCoder());
     final WindowedValue.FullWindowedValueCoder<AccumT> wvaCoder =
         WindowedValue.FullWindowedValueCoder.of(aCoder,
             windowingStrategy.getWindowFn().windowCoder());
-    final WindowedValue.FullWindowedValueCoder<OutputT> wvoCoder =
-        WindowedValue.FullWindowedValueCoder.of(oCoder,
-            windowingStrategy.getWindowFn().windowCoder());
-
-    final SparkGlobalCombineFn<InputT, AccumT, OutputT> sparkCombineFn =
-        new SparkGlobalCombineFn<>(combineFn, runtimeContext, sideInputs, 
windowingStrategy);
     final IterableCoder<WindowedValue<AccumT>> iterAccumCoder = 
IterableCoder.of(wvaCoder);
 
-
     // Use coders to convert objects in the PCollection to byte arrays, so they
     // can be transferred over the network for the shuffle.
-    JavaRDD<byte[]> inRddBytes = 
rdd.map(CoderHelpers.toByteFunction(wviCoder));
-    /*AccumT*/ byte[] acc = inRddBytes.aggregate(
+    // for readability, we add comments with actual type next to byte[].
+    // to shorten line length, we use:
+    //---- WV: WindowedValue
+    //---- Iterable: Itr
+    //---- AccumT: A
+    //---- InputT: I
+    JavaRDD<byte[]> inputRDDBytes = 
rdd.map(CoderHelpers.toByteFunction(wviCoder));
+    /*Itr<WV<A>>*/ byte[] accumulatedBytes = inputRDDBytes.aggregate(
         CoderHelpers.toByteArray(sparkCombineFn.zeroValue(), iterAccumCoder),
-        new Function2</*AccumT*/ byte[], /*InputT*/ byte[], /*AccumT*/ 
byte[]>() {
+        new Function2</*A*/ byte[], /*I*/ byte[], /*A*/ byte[]>() {
           @Override
-          public /*AccumT*/ byte[] call(/*AccumT*/ byte[] ab, /*InputT*/ 
byte[] ib)
+          public /*Itr<WV<A>>*/ byte[] call(/*Itr<WV<A>>*/ byte[] ab, 
/*WV<I>*/ byte[] ib)
               throws Exception {
             Iterable<WindowedValue<AccumT>> a = CoderHelpers.fromByteArray(ab, 
iterAccumCoder);
             WindowedValue<InputT> i = CoderHelpers.fromByteArray(ib, wviCoder);
             return CoderHelpers.toByteArray(sparkCombineFn.seqOp(a, i), 
iterAccumCoder);
           }
         },
-        new Function2</*AccumT*/ byte[], /*AccumT*/ byte[], /*AccumT*/ 
byte[]>() {
+        new Function2</*Itr<WV<A>>>*/ byte[], /*Itr<WV<A>>>*/ byte[], 
/*Itr<WV<A>>>*/ byte[]>() {
           @Override
-          public /*AccumT*/ byte[] call(/*AccumT*/ byte[] a1b, /*AccumT*/ 
byte[] a2b)
+          public /*Itr<WV<A>>>*/ byte[] call(/*Itr<WV<A>>>*/ byte[] a1b, 
/*Itr<WV<A>>>*/ byte[] a2b)
               throws Exception {
             Iterable<WindowedValue<AccumT>> a1 = 
CoderHelpers.fromByteArray(a1b, iterAccumCoder);
             Iterable<WindowedValue<AccumT>> a2 = 
CoderHelpers.fromByteArray(a2b, iterAccumCoder);
@@ -168,10 +115,7 @@ public class GroupCombineFunctions {
           }
         }
     );
-    Iterable<WindowedValue<OutputT>> output =
-        sparkCombineFn.extractOutput(CoderHelpers.fromByteArray(acc, 
iterAccumCoder));
-    return new JavaSparkContext(rdd.context()).parallelize(
-        CoderHelpers.toByteArrays(output, 
wvoCoder)).map(CoderHelpers.fromByteFunction(wvoCoder));
+    return CoderHelpers.fromByteArray(accumulatedBytes, iterAccumCoder);
   }
 
   /**
@@ -183,31 +127,22 @@ public class GroupCombineFunctions {
    * For streaming, this will be called from within a serialized context
    * (DStream's transform callback), so passed arguments need to be 
Serializable.
    */
-  public static <K, InputT, AccumT, OutputT> JavaRDD<WindowedValue<KV<K, 
OutputT>>>
-  combinePerKey(JavaRDD<WindowedValue<KV<K, InputT>>> rdd,
-                final CombineWithContext.KeyedCombineFnWithContext<K, InputT, 
AccumT, OutputT>
-                    combineFn,
-                final KvCoder<K, InputT> inputCoder,
-                final SparkRuntimeContext runtimeContext,
-                final WindowingStrategy<?, ?> windowingStrategy,
-                final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, 
SideInputBroadcast<?>>>
-                    sideInputs) {
-    //--- coders.
-    final Coder<K> keyCoder = inputCoder.getKeyCoder();
-    final Coder<InputT> viCoder = inputCoder.getValueCoder();
-    final Coder<AccumT> vaCoder;
-    try {
-      vaCoder = 
combineFn.getAccumulatorCoder(runtimeContext.getCoderRegistry(), keyCoder, 
viCoder);
-    } catch (CannotProvideCoderException e) {
-      throw new IllegalStateException("Could not determine coder for 
accumulator", e);
-    }
-    // windowed coders.
+  public static <K, InputT, AccumT> JavaPairRDD<K, 
Iterable<WindowedValue<KV<K, AccumT>>>>
+      combinePerKey(
+          JavaRDD<WindowedValue<KV<K, InputT>>> rdd,
+          final SparkKeyedCombineFn<K, InputT, AccumT, ?> sparkCombineFn,
+          final Coder<K> keyCoder,
+          final Coder<InputT> iCoder,
+          final Coder<AccumT> aCoder,
+          final WindowingStrategy<?, ?> windowingStrategy) {
+    // coders.
     final WindowedValue.FullWindowedValueCoder<KV<K, InputT>> wkviCoder =
-        WindowedValue.FullWindowedValueCoder.of(KvCoder.of(keyCoder, viCoder),
+        WindowedValue.FullWindowedValueCoder.of(KvCoder.of(keyCoder, iCoder),
             windowingStrategy.getWindowFn().windowCoder());
     final WindowedValue.FullWindowedValueCoder<KV<K, AccumT>> wkvaCoder =
-        WindowedValue.FullWindowedValueCoder.of(KvCoder.of(keyCoder, vaCoder),
+        WindowedValue.FullWindowedValueCoder.of(KvCoder.of(keyCoder, aCoder),
             windowingStrategy.getWindowFn().windowCoder());
+    final IterableCoder<WindowedValue<KV<K, AccumT>>> iterAccumCoder = 
IterableCoder.of(wkvaCoder);
 
     // We need to duplicate K as both the key of the JavaPairRDD as well as 
inside the value,
     // since the functions passed to combineByKey don't receive the associated 
key of each
@@ -217,53 +152,46 @@ public class GroupCombineFunctions {
     // we won't need to duplicate the keys anymore.
     // Key has to bw windowed in order to group by window as well.
     JavaPairRDD<K, WindowedValue<KV<K, InputT>>> inRddDuplicatedKeyPair =
-        rdd.flatMapToPair(
-            new PairFlatMapFunction<WindowedValue<KV<K, InputT>>, K,
-                WindowedValue<KV<K, InputT>>>() {
-              @Override
-              public Iterable<Tuple2<K, WindowedValue<KV<K, InputT>>>>
-              call(WindowedValue<KV<K, InputT>> wkv) {
-                return Collections.singletonList(new 
Tuple2<>(wkv.getValue().getKey(), wkv));
-              }
-            });
-
-    final SparkKeyedCombineFn<K, InputT, AccumT, OutputT> sparkCombineFn =
-        new SparkKeyedCombineFn<>(combineFn, runtimeContext, sideInputs, 
windowingStrategy);
-    final IterableCoder<WindowedValue<KV<K, AccumT>>> iterAccumCoder = 
IterableCoder.of(wkvaCoder);
+        rdd.mapToPair(TranslationUtils.<K, 
InputT>toPairByKeyInWindowedValue());
 
     // Use coders to convert objects in the PCollection to byte arrays, so they
     // can be transferred over the network for the shuffle.
+    // for readability, we add comments with actual type next to byte[].
+    // to shorten line length, we use:
+    //---- WV: WindowedValue
+    //---- Iterable: Itr
+    //---- AccumT: A
+    //---- InputT: I
     JavaPairRDD<ByteArray, byte[]> inRddDuplicatedKeyPairBytes = 
inRddDuplicatedKeyPair
         .mapToPair(CoderHelpers.toByteFunction(keyCoder, wkviCoder));
 
-    // The output of combineByKey will be "AccumT" (accumulator)
-    // types rather than "OutputT" (final output types) since Combine.CombineFn
-    // only provides ways to merge VAs, and no way to merge VOs.
-    JavaPairRDD</*K*/ ByteArray, /*KV<K, AccumT>*/ byte[]> accumulatedBytes =
+    JavaPairRDD</*K*/ ByteArray, /*Itr<WV<KV<K, A>>>*/ byte[]> 
accumulatedBytes =
         inRddDuplicatedKeyPairBytes.combineByKey(
-        new Function</*KV<K, InputT>*/ byte[], /*KV<K, AccumT>*/ byte[]>() {
+        new Function</*WV<KV<K, I>>*/ byte[], /*Itr<WV<KV<K, A>>>*/ byte[]>() {
           @Override
-          public /*KV<K, AccumT>*/ byte[] call(/*KV<K, InputT>*/ byte[] input) 
{
+          public /*Itr<WV<KV<K, A>>>*/ byte[] call(/*WV<KV<K, I>>*/ byte[] 
input) {
             WindowedValue<KV<K, InputT>> wkvi = 
CoderHelpers.fromByteArray(input, wkviCoder);
             return 
CoderHelpers.toByteArray(sparkCombineFn.createCombiner(wkvi), iterAccumCoder);
           }
         },
-        new Function2</*KV<K, AccumT>*/ byte[], /*KV<K, InputT>*/ byte[],
-            /*KV<K, AccumT>*/ byte[]>() {
+        new Function2</*Itr<WV<KV<K, A>>>*/ byte[], /*WV<KV<K, I>>*/ byte[],
+            /*Itr<WV<KV<K, A>>>*/ byte[]>() {
           @Override
-          public /*KV<K, AccumT>*/ byte[] call(/*KV<K, AccumT>*/ byte[] acc,
-              /*KV<K, InputT>*/ byte[] input) {
+          public /*Itr<WV<KV<K, A>>>*/ byte[] call(
+              /*Itr<WV<KV<K, A>>>*/ byte[] acc,
+              /*WV<KV<K, I>>*/ byte[] input) {
             Iterable<WindowedValue<KV<K, AccumT>>> wkvas =
                 CoderHelpers.fromByteArray(acc, iterAccumCoder);
             WindowedValue<KV<K, InputT>> wkvi = 
CoderHelpers.fromByteArray(input, wkviCoder);
             return CoderHelpers.toByteArray(sparkCombineFn.mergeValue(wkvi, 
wkvas), iterAccumCoder);
           }
         },
-        new Function2</*KV<K, AccumT>*/ byte[], /*KV<K, AccumT>*/ byte[],
-            /*KV<K, AccumT>*/ byte[]>() {
+        new Function2</*Itr<WV<KV<K, A>>>*/ byte[], /*Itr<WV<KV<K, A>>>*/ 
byte[],
+            /*Itr<WV<KV<K, A>>>*/ byte[]>() {
           @Override
-          public /*KV<K, AccumT>*/ byte[] call(/*KV<K, AccumT>*/ byte[] acc1,
-              /*KV<K, AccumT>*/ byte[] acc2) {
+          public /*Itr<WV<KV<K, A>>>*/ byte[] call(
+              /*Itr<WV<KV<K, A>>>*/ byte[] acc1,
+              /*Itr<WV<KV<K, A>>>*/ byte[] acc2) {
             Iterable<WindowedValue<KV<K, AccumT>>> wkvas1 =
                 CoderHelpers.fromByteArray(acc1, iterAccumCoder);
             Iterable<WindowedValue<KV<K, AccumT>>> wkvas2 =
@@ -273,23 +201,6 @@ public class GroupCombineFunctions {
           }
         });
 
-    JavaPairRDD<K, WindowedValue<OutputT>> extracted = accumulatedBytes
-        .mapToPair(CoderHelpers.fromByteFunction(keyCoder, iterAccumCoder))
-        .flatMapValues(new Function<Iterable<WindowedValue<KV<K, AccumT>>>,
-            Iterable<WindowedValue<OutputT>>>() {
-              @Override
-              public Iterable<WindowedValue<OutputT>> call(
-                  Iterable<WindowedValue<KV<K, AccumT>>> accums) {
-                return sparkCombineFn.extractOutput(accums);
-              }
-            });
-    return extracted.map(TranslationUtils.<K, 
WindowedValue<OutputT>>fromPairFunction()).map(
-        new Function<KV<K, WindowedValue<OutputT>>, WindowedValue<KV<K, 
OutputT>>>() {
-          @Override
-          public WindowedValue<KV<K, OutputT>> call(KV<K, 
WindowedValue<OutputT>> kwvo)
-              throws Exception {
-            return kwvo.getValue().withValue(KV.of(kwvo.getKey(), 
kwvo.getValue().getValue()));
-          }
-        });
+    return accumulatedBytes.mapToPair(CoderHelpers.fromByteFunction(keyCoder, 
iterAccumCoder));
   }
 }

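Side note (not part of the diff): groupByKeyOnly above shuffles byte arrays rather than
objects, relying on Beam coders for the encode/decode round-trip. A minimal sketch of
that round-trip via CoderUtils; the class name and sample value are made up for
illustration:

import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.util.CoderUtils;

public class CoderRoundTripSketch {
  public static void main(String[] args) throws Exception {
    StringUtf8Coder coder = StringUtf8Coder.of();
    // Values are encoded with their Beam Coder before the Spark shuffle...
    byte[] onTheWire = CoderUtils.encodeToByteArray(coder, "some value");
    // ...and decoded again afterwards, so nothing depends on Java serialization.
    String decoded = CoderUtils.decodeFromByteArray(coder, onTheWire);
    System.out.println(decoded);
  }
}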
http://git-wip-us.apache.org/repos/asf/beam/blob/96abe4f0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
----------------------------------------------------------------------
diff --git 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
index 14c14dc..a643651 100644
--- 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
+++ 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
@@ -27,6 +27,7 @@ import static 
org.apache.beam.runners.spark.io.hadoop.ShardNameBuilder.replaceSh
 import static 
org.apache.beam.runners.spark.translation.TranslationUtils.rejectSplittable;
 import static 
org.apache.beam.runners.spark.translation.TranslationUtils.rejectStateAndTimers;
 
+import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import java.io.IOException;
 import java.util.Collections;
@@ -35,6 +36,7 @@ import java.util.Map;
 import org.apache.avro.mapred.AvroKey;
 import org.apache.avro.mapreduce.AvroJob;
 import org.apache.avro.mapreduce.AvroKeyInputFormat;
+import org.apache.beam.runners.core.SystemReduceFn;
 import org.apache.beam.runners.spark.aggregators.NamedAggregators;
 import org.apache.beam.runners.spark.aggregators.SparkAggregators;
 import org.apache.beam.runners.spark.coders.CoderHelpers;
@@ -46,6 +48,7 @@ import 
org.apache.beam.runners.spark.io.hadoop.TemplatedTextOutputFormat;
 import org.apache.beam.runners.spark.metrics.MetricsAccumulator;
 import org.apache.beam.runners.spark.metrics.SparkMetricsContainer;
 import org.apache.beam.runners.spark.util.SideInputBroadcast;
+import org.apache.beam.sdk.coders.CannotProvideCoderException;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
@@ -63,6 +66,7 @@ import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.View;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.Window;
+import org.apache.beam.sdk.transforms.windowing.WindowFn;
 import org.apache.beam.sdk.util.CombineFnUtil;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.util.WindowingStrategy;
@@ -119,101 +123,159 @@ public final class TransformTranslator {
     };
   }
 
-  private static <K, V> TransformEvaluator<GroupByKey<K, V>> groupByKey() {
+  private static <K, V, W extends BoundedWindow> 
TransformEvaluator<GroupByKey<K, V>> groupByKey() {
     return new TransformEvaluator<GroupByKey<K, V>>() {
       @Override
       public void evaluate(GroupByKey<K, V> transform, EvaluationContext 
context) {
         @SuppressWarnings("unchecked")
         JavaRDD<WindowedValue<KV<K, V>>> inRDD =
             ((BoundedDataset<KV<K, V>>) 
context.borrowDataset(transform)).getRDD();
-
         @SuppressWarnings("unchecked")
         final KvCoder<K, V> coder = (KvCoder<K, V>) 
context.getInput(transform).getCoder();
-
         final Accumulator<NamedAggregators> accum =
             SparkAggregators.getNamedAggregators(context.getSparkContext());
-
-        context.putDataset(
-            transform,
-            new BoundedDataset<>(
-                GroupCombineFunctions.groupByKey(
-                    inRDD,
-                    accum,
-                    coder,
-                    context.getRuntimeContext(),
-                    context.getInput(transform).getWindowingStrategy())));
+        @SuppressWarnings("unchecked")
+        final WindowingStrategy<?, W> windowingStrategy =
+            (WindowingStrategy<?, W>) 
context.getInput(transform).getWindowingStrategy();
+        @SuppressWarnings("unchecked")
+        final WindowFn<Object, W> windowFn = (WindowFn<Object, W>) 
windowingStrategy.getWindowFn();
+
+        //--- coders.
+        final Coder<K> keyCoder = coder.getKeyCoder();
+        final WindowedValue.WindowedValueCoder<V> wvCoder =
+            WindowedValue.FullWindowedValueCoder.of(coder.getValueCoder(), 
windowFn.windowCoder());
+
+        //--- group by key only.
+        JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>> groupedByKey 
=
+            GroupCombineFunctions.groupByKeyOnly(inRDD, keyCoder, wvCoder);
+
+        //--- now group also by window.
+        // for batch, GroupAlsoByWindow uses an in-memory StateInternals.
+        JavaRDD<WindowedValue<KV<K, Iterable<V>>>> groupedAlsoByWindow = 
groupedByKey.flatMap(
+            new SparkGroupAlsoByWindowViaOutputBufferFn<>(
+                windowingStrategy,
+                new TranslationUtils.InMemoryStateInternalsFactory<K>(),
+                SystemReduceFn.<K, V, W>buffering(coder.getValueCoder()),
+                context.getRuntimeContext(),
+                accum));
+
+        context.putDataset(transform, new 
BoundedDataset<>(groupedAlsoByWindow));
       }
     };
   }
 
   private static <K, InputT, OutputT> 
TransformEvaluator<Combine.GroupedValues<K, InputT, OutputT>>
-  combineGrouped() {
-    return new TransformEvaluator<Combine.GroupedValues<K, InputT, OutputT>>() 
{
-      @Override
-      public void evaluate(Combine.GroupedValues<K, InputT, OutputT> transform,
-                           EvaluationContext context) {
-        // get the applied combine function.
-        PCollection<? extends KV<K, ? extends Iterable<InputT>>> input =
-            context.getInput(transform);
-        WindowingStrategy<?, ?> windowingStrategy = 
input.getWindowingStrategy();
-        @SuppressWarnings("unchecked")
-        CombineWithContext.KeyedCombineFnWithContext<K, InputT, ?, OutputT> fn 
=
-            (CombineWithContext.KeyedCombineFnWithContext<K, InputT, ?, 
OutputT>)
-                CombineFnUtil.toFnWithContext(transform.getFn());
-
-        @SuppressWarnings("unchecked")
-        JavaRDD<WindowedValue<KV<K, Iterable<InputT>>>> inRDD =
-            ((BoundedDataset<KV<K, Iterable<InputT>>>)
-                context.borrowDataset(transform)).getRDD();
-
-        SparkKeyedCombineFn<K, InputT, ?, OutputT> combineFnWithContext =
-            new SparkKeyedCombineFn<>(fn, context.getRuntimeContext(),
-                TranslationUtils.getSideInputs(transform.getSideInputs(), 
context),
-                windowingStrategy);
-        context.putDataset(transform, new BoundedDataset<>(inRDD.map(new 
TranslationUtils
-            .CombineGroupedValues<>(
-            combineFnWithContext))));
-      }
-    };
+      combineGrouped() {
+          return new TransformEvaluator<Combine.GroupedValues<K, InputT, 
OutputT>>() {
+            @Override
+            public void evaluate(
+                Combine.GroupedValues<K, InputT, OutputT> transform,
+                EvaluationContext context) {
+              @SuppressWarnings("unchecked")
+              CombineWithContext.KeyedCombineFnWithContext<K, InputT, ?, 
OutputT> combineFn =
+                  (CombineWithContext.KeyedCombineFnWithContext<K, InputT, ?, 
OutputT>)
+                      CombineFnUtil.toFnWithContext(transform.getFn());
+              final SparkKeyedCombineFn<K, InputT, ?, OutputT> sparkCombineFn =
+                  new SparkKeyedCombineFn<>(combineFn, 
context.getRuntimeContext(),
+                      
TranslationUtils.getSideInputs(transform.getSideInputs(), context),
+                          context.getInput(transform).getWindowingStrategy());
+
+              @SuppressWarnings("unchecked")
+              JavaRDD<WindowedValue<KV<K, Iterable<InputT>>>> inRDD =
+                  ((BoundedDataset<KV<K, Iterable<InputT>>>) 
context.borrowDataset(transform))
+                      .getRDD();
+
+              JavaRDD<WindowedValue<KV<K, OutputT>>> outRDD = inRDD.map(
+                   new Function<WindowedValue<KV<K, Iterable<InputT>>>,
+                       WindowedValue<KV<K, OutputT>>>() {
+                         @Override
+                         public WindowedValue<KV<K, OutputT>> call(
+                             WindowedValue<KV<K, Iterable<InputT>>> in) throws 
Exception {
+                               return WindowedValue.of(
+                                   KV.of(in.getValue().getKey(), 
sparkCombineFn.apply(in)),
+                                   in.getTimestamp(),
+                                   in.getWindows(),
+                                   in.getPane());
+                             }
+                       });
+               context.putDataset(transform, new BoundedDataset<>(outRDD));
+            }
+          };
   }
 
   private static <InputT, AccumT, OutputT> 
TransformEvaluator<Combine.Globally<InputT, OutputT>>
-  combineGlobally() {
-    return new TransformEvaluator<Combine.Globally<InputT, OutputT>>() {
-
-      @Override
-      public void evaluate(Combine.Globally<InputT, OutputT> transform, 
EvaluationContext context) {
-        final PCollection<InputT> input = context.getInput(transform);
-        // serializable arguments to pass.
-        final Coder<InputT> iCoder = context.getInput(transform).getCoder();
-        final Coder<OutputT> oCoder = context.getOutput(transform).getCoder();
-        @SuppressWarnings("unchecked")
-        final CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT> 
combineFn =
-            (CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT>)
-                CombineFnUtil.toFnWithContext(transform.getFn());
-        final WindowingStrategy<?, ?> windowingStrategy = 
input.getWindowingStrategy();
-        final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
-        final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, 
SideInputBroadcast<?>>> sideInputs =
-            TranslationUtils.getSideInputs(transform.getSideInputs(), context);
-        final boolean hasDefault = transform.isInsertDefault();
+      combineGlobally() {
+        return new TransformEvaluator<Combine.Globally<InputT, OutputT>>() {
 
-        @SuppressWarnings("unchecked")
-        JavaRDD<WindowedValue<InputT>> inRdd =
-            ((BoundedDataset<InputT>) 
context.borrowDataset(transform)).getRDD();
-
-        context.putDataset(transform, new 
BoundedDataset<>(GroupCombineFunctions
-            .combineGlobally(inRdd, combineFn,
-                iCoder, oCoder, runtimeContext, windowingStrategy, sideInputs, 
hasDefault)));
-      }
-    };
+          @Override
+          public void evaluate(
+              Combine.Globally<InputT, OutputT> transform,
+              EvaluationContext context) {
+            final PCollection<InputT> input = context.getInput(transform);
+            final Coder<InputT> iCoder = 
context.getInput(transform).getCoder();
+            final Coder<OutputT> oCoder = 
context.getOutput(transform).getCoder();
+            final WindowingStrategy<?, ?> windowingStrategy = 
input.getWindowingStrategy();
+            @SuppressWarnings("unchecked")
+            final CombineWithContext.CombineFnWithContext<InputT, AccumT, 
OutputT> combineFn =
+                (CombineWithContext.CombineFnWithContext<InputT, AccumT, 
OutputT>)
+                    CombineFnUtil.toFnWithContext(transform.getFn());
+            final WindowedValue.FullWindowedValueCoder<OutputT> wvoCoder =
+                WindowedValue.FullWindowedValueCoder.of(oCoder,
+                    windowingStrategy.getWindowFn().windowCoder());
+            final SparkRuntimeContext runtimeContext = 
context.getRuntimeContext();
+            final boolean hasDefault = transform.isInsertDefault();
+
+            final SparkGlobalCombineFn<InputT, AccumT, OutputT> sparkCombineFn 
=
+                new SparkGlobalCombineFn<>(
+                    combineFn,
+                    runtimeContext,
+                    TranslationUtils.getSideInputs(transform.getSideInputs(), 
context),
+                    windowingStrategy);
+            final Coder<AccumT> aCoder;
+            try {
+              aCoder = 
combineFn.getAccumulatorCoder(runtimeContext.getCoderRegistry(), iCoder);
+            } catch (CannotProvideCoderException e) {
+              throw new IllegalStateException("Could not determine coder for 
accumulator", e);
+            }
+
+            @SuppressWarnings("unchecked")
+            JavaRDD<WindowedValue<InputT>> inRdd =
+                ((BoundedDataset<InputT>) 
context.borrowDataset(transform)).getRDD();
+
+            JavaRDD<WindowedValue<OutputT>> outRdd;
+            // handle empty input RDD, which will naturally skip the entire 
execution
+            // as Spark will not run on empty RDDs.
+            if (inRdd.isEmpty()) {
+              JavaSparkContext jsc = new JavaSparkContext(inRdd.context());
+              if (hasDefault) {
+                OutputT defaultValue = combineFn.defaultValue();
+                outRdd = jsc
+                    
.parallelize(Lists.newArrayList(CoderHelpers.toByteArray(defaultValue, oCoder)))
+                    .map(CoderHelpers.fromByteFunction(oCoder))
+                    .map(WindowingHelpers.<OutputT>windowFunction());
+              } else {
+                outRdd = jsc.emptyRDD();
+              }
+            } else {
+              Iterable<WindowedValue<AccumT>> accumulated = 
GroupCombineFunctions.combineGlobally(
+                  inRdd, sparkCombineFn, iCoder, aCoder, windowingStrategy);
+              Iterable<WindowedValue<OutputT>> output = 
sparkCombineFn.extractOutput(accumulated);
+              outRdd = context.getSparkContext()
+                  .parallelize(CoderHelpers.toByteArrays(output, wvoCoder))
+                  .map(CoderHelpers.fromByteFunction(wvoCoder));
+            }
+            context.putDataset(transform, new BoundedDataset<>(outRdd));
+          }
+        };
   }
 
   private static <K, InputT, AccumT, OutputT>
   TransformEvaluator<Combine.PerKey<K, InputT, OutputT>> combinePerKey() {
     return new TransformEvaluator<Combine.PerKey<K, InputT, OutputT>>() {
       @Override
-      public void evaluate(Combine.PerKey<K, InputT, OutputT> transform,
-                           EvaluationContext context) {
+      public void evaluate(
+          Combine.PerKey<K, InputT, OutputT> transform,
+          EvaluationContext context) {
         final PCollection<KV<K, InputT>> input = context.getInput(transform);
         // serializable arguments to pass.
         @SuppressWarnings("unchecked")
@@ -227,14 +289,44 @@ public final class TransformTranslator {
         final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
         final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, 
SideInputBroadcast<?>>> sideInputs =
             TranslationUtils.getSideInputs(transform.getSideInputs(), context);
+        final SparkKeyedCombineFn<K, InputT, AccumT, OutputT> sparkCombineFn =
+            new SparkKeyedCombineFn<>(combineFn, runtimeContext, sideInputs, 
windowingStrategy);
+        final Coder<AccumT> vaCoder;
+        try {
+          vaCoder = 
combineFn.getAccumulatorCoder(runtimeContext.getCoderRegistry(),
+              inputCoder.getKeyCoder(), inputCoder.getValueCoder());
+        } catch (CannotProvideCoderException e) {
+          throw new IllegalStateException("Could not determine coder for 
accumulator", e);
+        }
 
         @SuppressWarnings("unchecked")
         JavaRDD<WindowedValue<KV<K, InputT>>> inRdd =
             ((BoundedDataset<KV<K, InputT>>) 
context.borrowDataset(transform)).getRDD();
 
-        context.putDataset(transform, new 
BoundedDataset<>(GroupCombineFunctions
-            .combinePerKey(inRdd, combineFn,
-                inputCoder, runtimeContext, windowingStrategy, sideInputs)));
+        JavaPairRDD<K, Iterable<WindowedValue<KV<K, AccumT>>>> 
accumulatePerKey =
+            GroupCombineFunctions.combinePerKey(inRdd, sparkCombineFn, 
inputCoder.getKeyCoder(),
+                inputCoder.getValueCoder(), vaCoder, windowingStrategy);
+
+        JavaRDD<WindowedValue<KV<K, OutputT>>> outRdd =
+            accumulatePerKey.flatMapValues(new 
Function<Iterable<WindowedValue<KV<K, AccumT>>>,
+                Iterable<WindowedValue<OutputT>>>() {
+                  @Override
+                  public Iterable<WindowedValue<OutputT>> call(
+                      Iterable<WindowedValue<KV<K, AccumT>>> iter) throws 
Exception {
+                        return sparkCombineFn.extractOutput(iter);
+                      }
+                }).map(TranslationUtils.<K, 
WindowedValue<OutputT>>fromPairFunction())
+                  .map(new Function<KV<K, WindowedValue<OutputT>>,
+                      WindowedValue<KV<K, OutputT>>>() {
+                        @Override
+                          public WindowedValue<KV<K, OutputT>> call(
+                              KV<K, WindowedValue<OutputT>> kv) throws 
Exception {
+                                WindowedValue<OutputT> wv = kv.getValue();
+                                return wv.withValue(KV.of(kv.getKey(), 
wv.getValue()));
+                              }
+                      });
+
+        context.putDataset(transform, new BoundedDataset<>(outRdd));
       }
     };
   }

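Side note (not part of the diff): the empty-input branch moved into the translator
because Spark skips computation on empty RDDs entirely, so Combine.Globally's default
value has to be injected by hand. A simplified sketch of just that branch, with the
coder round-trip omitted and a hypothetical defaultOrEmpty helper:

import java.util.Collections;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

final class EmptyCombineSketch {
  // If the transform declares a default, materialize it as a one-element RDD;
  // otherwise the result is simply an empty RDD.
  static <T> JavaRDD<T> defaultOrEmpty(JavaSparkContext jsc, boolean hasDefault, T defaultValue) {
    if (hasDefault) {
      return jsc.parallelize(Collections.singletonList(defaultValue));
    }
    return jsc.<T>emptyRDD();
  }
}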
http://git-wip-us.apache.org/repos/asf/beam/blob/96abe4f0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java
----------------------------------------------------------------------
diff --git 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java
 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java
index 7d83230..6b27436 100644
--- 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java
+++ 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java
@@ -21,7 +21,6 @@ package org.apache.beam.runners.spark.translation;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Maps;
 import java.io.Serializable;
-import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import org.apache.beam.runners.core.InMemoryStateInternals;
@@ -42,11 +41,11 @@ import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.api.java.function.FlatMapFunction;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.api.java.function.PairFunction;
 import org.apache.spark.streaming.api.java.JavaDStream;
 import org.apache.spark.streaming.api.java.JavaPairDStream;
+
 import scala.Tuple2;
 
 /**
@@ -148,20 +147,17 @@ public final class TranslationUtils {
     };
   }
 
-  /** A Flatmap iterator function, flattening iterators into their elements. */
-  public static <T> FlatMapFunction<Iterator<T>, T> flattenIter() {
-    return new FlatMapFunction<Iterator<T>, T>() {
-      @Override
-      public Iterable<T> call(final Iterator<T> t) throws Exception {
-        return new Iterable<T>() {
+  /** Extract key from a {@link WindowedValue} {@link KV} into a pair. */
+  public static <K, V> PairFunction<WindowedValue<KV<K, V>>, K, 
WindowedValue<KV<K, V>>>
+      toPairByKeyInWindowedValue() {
+        return new PairFunction<WindowedValue<KV<K, V>>, K, 
WindowedValue<KV<K, V>>>() {
           @Override
-          public Iterator<T> iterator() {
-            return t;
-          }
+          public Tuple2<K, WindowedValue<KV<K, V>>> call(
+              WindowedValue<KV<K, V>> windowedKv) throws Exception {
+                return new Tuple2<>(windowedKv.getValue().getKey(), 
windowedKv);
+              }
         };
       }
-    };
-  }
 
   /**
    * A utility class to filter {@link TupleTag}s.

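Side note (not part of the diff): toPairByKeyInWindowedValue duplicates the key as the
pair key while keeping the whole WindowedValue<KV<K, V>> as the value, so combineByKey
can see the key without losing window or timestamp metadata. A small sketch of the
resulting pair for one element; the sample key and value are made up:

import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import scala.Tuple2;

public class PairByKeySketch {
  public static void main(String[] args) {
    WindowedValue<KV<String, Integer>> wkv =
        WindowedValue.valueInGlobalWindow(KV.of("user-1", 42));
    // Key extracted for the shuffle; full windowed value retained alongside it.
    Tuple2<String, WindowedValue<KV<String, Integer>>> pair =
        new Tuple2<>(wkv.getValue().getKey(), wkv);
    System.out.println(pair._1()); // "user-1"
  }
}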
http://git-wip-us.apache.org/repos/asf/beam/blob/96abe4f0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
----------------------------------------------------------------------
diff --git 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
index 9451df7..e90b490 100644
--- 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
+++ 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
@@ -31,6 +31,7 @@ import 
org.apache.beam.runners.spark.aggregators.SparkAggregators;
 import org.apache.beam.runners.spark.io.ConsoleIO;
 import org.apache.beam.runners.spark.io.CreateStream;
 import org.apache.beam.runners.spark.io.SparkUnboundedSource;
+import 
org.apache.beam.runners.spark.stateful.SparkGroupAlsoByWindowViaWindowSet;
 import org.apache.beam.runners.spark.metrics.MetricsAccumulator;
 import org.apache.beam.runners.spark.metrics.SparkMetricsContainer;
 import org.apache.beam.runners.spark.translation.BoundedDataset;
@@ -148,7 +149,7 @@ final class StreamingTransformTranslator {
           Dataset dataset = context.borrowDataset(pcol);
           if (dataset instanceof UnboundedDataset) {
             UnboundedDataset<T> unboundedDataset = (UnboundedDataset<T>) 
dataset;
-            streamingSources.addAll(unboundedDataset.getStreamingSources());
+            streamingSources.addAll(unboundedDataset.getStreamSources());
             dStreams.add(unboundedDataset.getDStream());
           } else {
             rdds.add(((BoundedDataset<T>) dataset).getRDD());
@@ -205,7 +206,7 @@ final class StreamingTransformTranslator {
         //--- then we apply windowing to the elements
         if (TranslationUtils.skipAssignWindows(transform, context)) {
           context.putDataset(transform,
-              new UnboundedDataset<>(windowedDStream, 
unboundedDataset.getStreamingSources()));
+              new UnboundedDataset<>(windowedDStream, 
unboundedDataset.getStreamSources()));
         } else {
           JavaDStream<WindowedValue<T>> outStream = windowedDStream.transform(
               new Function<JavaRDD<WindowedValue<T>>, 
JavaRDD<WindowedValue<T>>>() {
@@ -215,42 +216,55 @@ final class StreamingTransformTranslator {
             }
           });
           context.putDataset(transform,
-              new UnboundedDataset<>(outStream, 
unboundedDataset.getStreamingSources()));
+              new UnboundedDataset<>(outStream, 
unboundedDataset.getStreamSources()));
         }
       }
     };
   }
 
-  private static <K, V> TransformEvaluator<GroupByKey<K, V>> groupByKey() {
+  private static <K, V, W extends BoundedWindow> 
TransformEvaluator<GroupByKey<K, V>> groupByKey() {
     return new TransformEvaluator<GroupByKey<K, V>>() {
       @Override
       public void evaluate(GroupByKey<K, V> transform, EvaluationContext 
context) {
-        @SuppressWarnings("unchecked")
-        UnboundedDataset<KV<K, V>> unboundedDataset =
-            ((UnboundedDataset<KV<K, V>>) context.borrowDataset(transform));
-        JavaDStream<WindowedValue<KV<K, V>>> dStream = 
unboundedDataset.getDStream();
-
+        @SuppressWarnings("unchecked") UnboundedDataset<KV<K, V>> inputDataset 
=
+            (UnboundedDataset<KV<K, V>>) context.borrowDataset(transform);
+        List<Integer> streamSources = inputDataset.getStreamSources();
+        JavaDStream<WindowedValue<KV<K, V>>> dStream = 
inputDataset.getDStream();
         @SuppressWarnings("unchecked")
         final KvCoder<K, V> coder = (KvCoder<K, V>) 
context.getInput(transform).getCoder();
-
         final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
-        final WindowingStrategy<?, ?> windowingStrategy =
-            context.getInput(transform).getWindowingStrategy();
+        @SuppressWarnings("unchecked")
+        final WindowingStrategy<?, W> windowingStrategy =
+            (WindowingStrategy<?, W>) 
context.getInput(transform).getWindowingStrategy();
+        @SuppressWarnings("unchecked")
+        final WindowFn<Object, W> windowFn = (WindowFn<Object, W>) 
windowingStrategy.getWindowFn();
 
-        JavaDStream<WindowedValue<KV<K, Iterable<V>>>> outStream =
+        //--- coders.
+        final WindowedValue.WindowedValueCoder<V> wvCoder =
+            WindowedValue.FullWindowedValueCoder.of(coder.getValueCoder(), 
windowFn.windowCoder());
+
+        //--- group by key only.
+        JavaDStream<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>> 
groupedByKeyStream =
             dStream.transform(new Function<JavaRDD<WindowedValue<KV<K, V>>>,
-                JavaRDD<WindowedValue<KV<K, Iterable<V>>>>>() {
-          @Override
-          public JavaRDD<WindowedValue<KV<K, Iterable<V>>>> call(
-              JavaRDD<WindowedValue<KV<K, V>>> rdd) throws Exception {
-            final Accumulator<NamedAggregators> accum =
-                SparkAggregators.getNamedAggregators(new 
JavaSparkContext(rdd.context()));
-            return GroupCombineFunctions.groupByKey(rdd, accum, coder, 
runtimeContext,
-                windowingStrategy);
-          }
-        });
-        context.putDataset(transform,
-            new UnboundedDataset<>(outStream, 
unboundedDataset.getStreamingSources()));
+                JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>>>() {
+                  @Override
+                  public JavaRDD<WindowedValue<KV<K, 
Iterable<WindowedValue<V>>>>> call(
+                      JavaRDD<WindowedValue<KV<K, V>>> rdd) throws Exception {
+                        return GroupCombineFunctions.groupByKeyOnly(
+                            rdd, coder.getKeyCoder(), wvCoder);
+                      }
+                });
+
+        //--- now group also by window.
+        JavaDStream<WindowedValue<KV<K, Iterable<V>>>> outStream =
+            SparkGroupAlsoByWindowViaWindowSet.groupAlsoByWindow(
+                groupedByKeyStream,
+                coder.getValueCoder(),
+                windowingStrategy,
+                runtimeContext,
+                streamSources);
+
+        context.putDataset(transform, new UnboundedDataset<>(outStream, 
streamSources));
       }
     };
   }
@@ -296,96 +310,7 @@ final class StreamingTransformTranslator {
                 });
 
         context.putDataset(transform,
-            new UnboundedDataset<>(outStream, 
unboundedDataset.getStreamingSources()));
-      }
-    };
-  }
-
-  private static <InputT, AccumT, OutputT> 
TransformEvaluator<Combine.Globally<InputT, OutputT>>
-  combineGlobally() {
-    return new TransformEvaluator<Combine.Globally<InputT, OutputT>>() {
-
-      @Override
-      public void evaluate(
-          final Combine.Globally<InputT, OutputT> transform,
-          EvaluationContext context) {
-        final PCollection<InputT> input = context.getInput(transform);
-        // serializable arguments to pass.
-        final Coder<InputT> iCoder = context.getInput(transform).getCoder();
-        final Coder<OutputT> oCoder = context.getOutput(transform).getCoder();
-        @SuppressWarnings("unchecked")
-        final CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT> 
combineFn =
-            (CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT>)
-                CombineFnUtil.toFnWithContext(transform.getFn());
-        final WindowingStrategy<?, ?> windowingStrategy = 
input.getWindowingStrategy();
-        final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
-        final boolean hasDefault = transform.isInsertDefault();
-        final SparkPCollectionView pviews = context.getPViews();
-
-        @SuppressWarnings("unchecked")
-        UnboundedDataset<InputT> unboundedDataset =
-            ((UnboundedDataset<InputT>) context.borrowDataset(transform));
-        JavaDStream<WindowedValue<InputT>> dStream = 
unboundedDataset.getDStream();
-
-        JavaDStream<WindowedValue<OutputT>> outStream = dStream.transform(
-            new Function<JavaRDD<WindowedValue<InputT>>, 
JavaRDD<WindowedValue<OutputT>>>() {
-          @Override
-          public JavaRDD<WindowedValue<OutputT>> 
call(JavaRDD<WindowedValue<InputT>> rdd)
-              throws Exception {
-            final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, 
SideInputBroadcast<?>>> sideInputs =
-                TranslationUtils.getSideInputs(transform.getSideInputs(),
-                    JavaSparkContext.fromSparkContext(rdd.context()),
-                    pviews);
-            return GroupCombineFunctions.combineGlobally(rdd, combineFn, 
iCoder, oCoder,
-                runtimeContext, windowingStrategy, sideInputs, hasDefault);
-          }
-        });
-
-        context.putDataset(transform,
-            new UnboundedDataset<>(outStream, 
unboundedDataset.getStreamingSources()));
-      }
-    };
-  }
-
-  private static <K, InputT, AccumT, OutputT>
-  TransformEvaluator<Combine.PerKey<K, InputT, OutputT>> combinePerKey() {
-    return new TransformEvaluator<Combine.PerKey<K, InputT, OutputT>>() {
-      @Override
-      public void evaluate(final Combine.PerKey<K, InputT, OutputT> transform,
-                           final EvaluationContext context) {
-        final PCollection<KV<K, InputT>> input = context.getInput(transform);
-        // serializable arguments to pass.
-        final KvCoder<K, InputT> inputCoder =
-            (KvCoder<K, InputT>) context.getInput(transform).getCoder();
-        @SuppressWarnings("unchecked")
-        final CombineWithContext.KeyedCombineFnWithContext<K, InputT, AccumT, 
OutputT> combineFn =
-            (CombineWithContext.KeyedCombineFnWithContext<K, InputT, AccumT, 
OutputT>)
-                CombineFnUtil.toFnWithContext(transform.getFn());
-        final WindowingStrategy<?, ?> windowingStrategy = 
input.getWindowingStrategy();
-        final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
-        final SparkPCollectionView pviews = context.getPViews();
-
-        @SuppressWarnings("unchecked")
-        UnboundedDataset<KV<K, InputT>> unboundedDataset =
-            ((UnboundedDataset<KV<K, InputT>>) 
context.borrowDataset(transform));
-        JavaDStream<WindowedValue<KV<K, InputT>>> dStream = 
unboundedDataset.getDStream();
-
-        JavaDStream<WindowedValue<KV<K, OutputT>>> outStream =
-            dStream.transform(new Function<JavaRDD<WindowedValue<KV<K, 
InputT>>>,
-                JavaRDD<WindowedValue<KV<K, OutputT>>>>() {
-          @Override
-          public JavaRDD<WindowedValue<KV<K, OutputT>>> call(
-              JavaRDD<WindowedValue<KV<K, InputT>>> rdd) throws Exception {
-            final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, 
SideInputBroadcast<?>>> sideInputs =
-                TranslationUtils.getSideInputs(transform.getSideInputs(),
-                    JavaSparkContext.fromSparkContext(rdd.context()),
-                    pviews);
-            return GroupCombineFunctions.combinePerKey(rdd, combineFn, inputCoder, runtimeContext,
-                windowingStrategy, sideInputs);
-          }
-        });
-        context.putDataset(transform,
-            new UnboundedDataset<>(outStream, unboundedDataset.getStreamingSources()));
+            new UnboundedDataset<>(outStream, unboundedDataset.getStreamSources()));
       }
     };
   }
@@ -431,7 +356,7 @@ final class StreamingTransformTranslator {
         });
 
         context.putDataset(transform,
-            new UnboundedDataset<>(outStream, unboundedDataset.getStreamingSources()));
+            new UnboundedDataset<>(outStream, unboundedDataset.getStreamSources()));
       }
     };
   }
@@ -486,7 +411,7 @@ final class StreamingTransformTranslator {
               (JavaDStream<WindowedValue<Object>>)
                   (JavaDStream<?>) TranslationUtils.dStreamValues(filtered);
           context.putDataset(e.getValue(),
-              new UnboundedDataset<>(values, unboundedDataset.getStreamingSources()));
+              new UnboundedDataset<>(values, unboundedDataset.getStreamSources()));
         }
       }
     };
@@ -499,8 +424,6 @@ final class StreamingTransformTranslator {
     EVALUATORS.put(Read.Unbounded.class, readUnbounded());
     EVALUATORS.put(GroupByKey.class, groupByKey());
     EVALUATORS.put(Combine.GroupedValues.class, combineGrouped());
-    EVALUATORS.put(Combine.Globally.class, combineGlobally());
-    EVALUATORS.put(Combine.PerKey.class, combinePerKey());
     EVALUATORS.put(ParDo.Bound.class, parDo());
     EVALUATORS.put(ParDo.BoundMulti.class, multiDo());
     EVALUATORS.put(ConsoleIO.Write.Unbound.class, print());
@@ -523,7 +446,7 @@ final class StreamingTransformTranslator {
     @Override
     public boolean hasTranslation(Class<? extends PTransform<?, ?>> clazz) {
       // streaming includes rdd/bounded transformations as well
-      return EVALUATORS.containsKey(clazz) || batchTranslator.hasTranslation(clazz);
+      return EVALUATORS.containsKey(clazz);
     }
 
     @Override
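
For readers unfamiliar with the translator layout: streaming translation is driven by a class-keyed map of evaluators, and after this change hasTranslation() answers from that map alone instead of also consulting the batch translator. A minimal, self-contained sketch of such a registry (hypothetical names, not the actual StreamingTransformTranslator code) could look like this:

import java.util.HashMap;
import java.util.Map;

// Minimal sketch of a class-keyed evaluator registry; names are hypothetical.
final class EvaluatorRegistry {

  /** Stand-in for a TransformEvaluator. */
  interface Evaluator<T> {
    void evaluate(T transform);
  }

  private final Map<Class<?>, Evaluator<?>> evaluators = new HashMap<>();

  <T> void register(Class<T> transformClass, Evaluator<T> evaluator) {
    evaluators.put(transformClass, evaluator);
  }

  /** Membership in the registry alone decides whether a transform is supported. */
  boolean hasTranslation(Class<?> transformClass) {
    return evaluators.containsKey(transformClass);
  }

  /** Dispatch: look the evaluator up by the transform's class and run it. */
  @SuppressWarnings("unchecked")
  <T> void translate(T transform) {
    Evaluator<T> evaluator = (Evaluator<T>) evaluators.get(transform.getClass());
    if (evaluator == null) {
      throw new IllegalStateException("No evaluator registered for " + transform.getClass());
    }
    evaluator.evaluate(transform);
  }
}

With this shape, registration is both the support check and the dispatch path, which matches the containsKey-only test introduced above.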

http://git-wip-us.apache.org/repos/asf/beam/blob/96abe4f0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/UnboundedDataset.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/UnboundedDataset.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/UnboundedDataset.java
index 6f5fa93..8624f41 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/UnboundedDataset.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/UnboundedDataset.java
@@ -56,11 +56,11 @@ public class UnboundedDataset<T> implements Dataset {
   // should be greater > 1 in case of Flatten for example.
   // when using GlobalWatermarkHolder this information helps to take only the relevant watermarks
   // and reason about them accordingly.
-  private final List<Integer> streamingSources = new ArrayList<>();
+  private final List<Integer> streamSources = new ArrayList<>();
 
-  public UnboundedDataset(JavaDStream<WindowedValue<T>> dStream, List<Integer> streamingSources) {
+  public UnboundedDataset(JavaDStream<WindowedValue<T>> dStream, List<Integer> streamSources) {
     this.dStream = dStream;
-    this.streamingSources.addAll(streamingSources);
+    this.streamSources.addAll(streamSources);
   }
 
   public UnboundedDataset(Iterable<Iterable<T>> values, JavaStreamingContext jssc, Coder<T> coder) {
@@ -68,7 +68,7 @@ public class UnboundedDataset<T> implements Dataset {
     this.jssc = jssc;
     this.coder = coder;
     // QueuedStream will have a negative (decreasing) unique id.
-    this.streamingSources.add(queuedStreamIds.decrementAndGet());
+    this.streamSources.add(queuedStreamIds.decrementAndGet());
   }
 
   @VisibleForTesting
@@ -97,8 +97,8 @@ public class UnboundedDataset<T> implements Dataset {
     return dStream;
   }
 
-  public List<Integer> getStreamingSources() {
-    return streamingSources;
+  public List<Integer> getStreamSources() {
+    return streamSources;
   }
 
   public void cache() {
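
The rename above (streamingSources to streamSources) is mechanical, but the field itself encodes the runner's bookkeeping: a dataset remembers the ids of the unbounded sources it originated from, so after a Flatten it may carry several ids, and GlobalWatermarkHolder can then pick only the relevant watermarks. A simplified, hypothetical illustration of that bookkeeping (not the actual UnboundedDataset class):

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Simplified, hypothetical illustration of tracking originating source ids per dataset.
final class SourceTrackingExample {

  /** Minimal stand-in for an unbounded dataset: a payload plus its source ids. */
  static final class Tracked<T> {
    final List<T> values = new ArrayList<>();
    final List<Integer> streamSources = new ArrayList<>();

    Tracked(List<T> values, List<Integer> streamSources) {
      this.values.addAll(values);
      this.streamSources.addAll(streamSources);
    }

    /** A Flatten-like union keeps the source ids of both inputs. */
    static <T> Tracked<T> flatten(Tracked<T> a, Tracked<T> b) {
      List<T> values = new ArrayList<>(a.values);
      values.addAll(b.values);
      List<Integer> sources = new ArrayList<>(a.streamSources);
      sources.addAll(b.streamSources);
      return new Tracked<>(values, sources);
    }
  }

  public static void main(String[] args) {
    Tracked<String> first = new Tracked<>(Arrays.asList("a"), Arrays.asList(0));
    Tracked<String> second = new Tracked<>(Arrays.asList("b"), Arrays.asList(1));
    // After the flatten, both source ids are retained: [0, 1].
    System.out.println(Tracked.flatten(first, second).streamSources);
  }
}

Here flatten() unions both the values and the source ids, which is why the list can hold more than one entry in the Flatten case noted in the comment above.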

http://git-wip-us.apache.org/repos/asf/beam/blob/96abe4f0/runners/spark/src/main/java/org/apache/beam/runners/spark/util/LateDataUtils.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/LateDataUtils.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/LateDataUtils.java
index 96e6ee5..18689bd 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/LateDataUtils.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/LateDataUtils.java
@@ -89,4 +89,4 @@ public class LateDataUtils {
               }
             });
   }
-}
\ No newline at end of file
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/96abe4f0/runners/spark/src/main/java/org/apache/beam/runners/spark/util/UnsupportedSideInputReader.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/UnsupportedSideInputReader.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/UnsupportedSideInputReader.java
index 6de7e86..96d889d 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/UnsupportedSideInputReader.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/UnsupportedSideInputReader.java
@@ -49,4 +49,4 @@ public class UnsupportedSideInputReader implements SideInputReader {
     throw new UnsupportedOperationException(
         String.format("%s does not support side inputs.", transformName));
   }
-}
\ No newline at end of file
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/96abe4f0/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java
index fbe5777..8449724 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java
@@ -157,7 +157,7 @@ public class TrackStreamingSourcesTest {
         ctxt.setCurrentTransform(appliedTransform);
         //noinspection unchecked
         Dataset dataset = ctxt.borrowDataset((PTransform<? extends PValue, ?>) transform);
-        assertSourceIds(((UnboundedDataset<?>) dataset).getStreamingSources());
+        assertSourceIds(((UnboundedDataset<?>) dataset).getStreamSources());
         ctxt.setCurrentTransform(null);
       } else {
         evaluator.visitPrimitiveTransform(node);
