Repository: beam
Updated Branches:
  refs/heads/master a481d5611 -> f7d4583bd


[BEAM-2825] Refactored SparkGroupAlsoByWindowViaWindowSet to improve 
readability.


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/c8b99ba3
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/c8b99ba3
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/c8b99ba3

Branch: refs/heads/master
Commit: c8b99ba393c54da1a3ffbc61c2e5f2ae92b0b2bb
Parents: a481d56
Author: Stas Levin <stasle...@apache.org>
Authored: Wed Aug 30 12:01:32 2017 +0300
Committer: Stas Levin <stasle...@apache.org>
Committed: Sun Sep 3 15:40:25 2017 +0300

----------------------------------------------------------------------
 .../SparkGroupAlsoByWindowViaWindowSet.java     | 878 +++++++++++--------
 1 file changed, 498 insertions(+), 380 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/c8b99ba3/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java
index e6a55a6..2258f05 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java
@@ -58,12 +58,12 @@ import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.WindowingStrategy;
-import org.apache.spark.Partitioner;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext$;
 import org.apache.spark.api.java.function.FlatMapFunction;
 import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.Function2;
 import org.apache.spark.streaming.Duration;
 import org.apache.spark.streaming.Time;
 import org.apache.spark.streaming.api.java.JavaDStream;
@@ -73,435 +73,553 @@ import org.apache.spark.streaming.dstream.PairDStreamFunctions;
 import org.joda.time.Instant;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
-import scala.Function1;
 import scala.Option;
 import scala.Tuple2;
 import scala.Tuple3;
+import scala.collection.Iterator;
 import scala.collection.Seq;
-import scala.reflect.ClassTag;
 import scala.runtime.AbstractFunction1;
 
 /**
- * An implementation of {@link GroupAlsoByWindow}
- * logic for grouping by windows and controlling trigger firings and pane 
accumulation.
+ * An implementation of {@link GroupAlsoByWindow} logic for grouping by 
windows and controlling
+ * trigger firings and pane accumulation.
  *
  * <p>This implementation is a composite of Spark transformations revolving 
around state management
- * using Spark's
- * {@link PairDStreamFunctions#updateStateByKey(Function1, Partitioner, 
boolean, ClassTag)}
- * to update state with new data and timers.
+ * using Spark's {@link PairDStreamFunctions#updateStateByKey(scala.Function1,
+ * org.apache.spark.Partitioner, boolean, scala.reflect.ClassTag)} to update 
state with new data and
+ * timers.
  *
- * <p>Using updateStateByKey allows to scan through the entire state visiting 
not just the
- * updated state (new values for key) but also check if timers are ready to 
fire.
- * Since updateStateByKey bounds the types of state and output to be the same,
- * a (state, output) tuple is used, filtering the state (and output if no 
firing)
- * in the following steps.
+ * <p>Using updateStateByKey lets us scan the entire state, visiting not just the
+ * updated state (keys with new values) but also checking whether timers are ready to
+ * fire. Since updateStateByKey requires the state and output types to be the same, a
+ * (state, output) tuple is used, and the state (and the output, when nothing fired)
+ * is filtered out in the following steps.
  */
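
For orientation, here is a minimal, self-contained sketch of the (state, output)
tuple pattern described above. It is not part of this change and is deliberately
simplified: it uses Spark's Java streaming API (Spark 2.x signatures assumed),
and the toy "fire once the sum exceeds 10" trigger and all names are
hypothetical. The real code below uses the Scala updateStateByKey instead, to
avoid SPARK-4819 and to get access to the key and the entire iterator:

    import java.util.ArrayList;
    import java.util.List;
    import org.apache.spark.api.java.Optional;
    import org.apache.spark.api.java.function.Function2;
    import org.apache.spark.streaming.api.java.JavaDStream;
    import org.apache.spark.streaming.api.java.JavaPairDStream;
    import scala.Tuple2;

    class StateAndOutputTupleSketch {
      // updateStateByKey forces state and output to share one type, so we carry
      // a (runningState, firedOutputs) tuple and strip the state afterwards.
      static JavaDStream<String> apply(final JavaPairDStream<String, Integer> input) {
        final JavaPairDStream<String, Tuple2<Integer, List<String>>> stateAndOutput =
            input.updateStateByKey(
                new Function2<
                    List<Integer>,
                    Optional<Tuple2<Integer, List<String>>>,
                    Optional<Tuple2<Integer, List<String>>>>() {
                  @Override
                  public Optional<Tuple2<Integer, List<String>>> call(
                      final List<Integer> newValues,
                      final Optional<Tuple2<Integer, List<String>>> prev) {
                    int sum = prev.isPresent() ? prev.get()._1() : 0;
                    for (final Integer v : newValues) {
                      sum += v;
                    }
                    final List<String> fired = new ArrayList<>();
                    if (sum > 10) { // toy stand-in for a real trigger firing.
                      fired.add("sum=" + sum);
                      sum = 0;
                    }
                    // (eviction, by returning an empty Optional, is elided here.)
                    return Optional.of(new Tuple2<>(sum, fired));
                  }
                });
        return stateAndOutput
            // filter state-only entries (nothing fired)...
            .filter(t -> !t._2()._2().isEmpty())
            // ...then drop the state, keeping only the fired outputs.
            .flatMap(t -> t._2()._2().iterator());
      }
    }
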
 public class SparkGroupAlsoByWindowViaWindowSet implements Serializable {
-  private static final Logger LOG = LoggerFactory.getLogger(
-      SparkGroupAlsoByWindowViaWindowSet.class);
-
-  /**
-   * A helper class that is essentially a {@link Serializable} {@link 
AbstractFunction1}.
-   */
-  private abstract static class SerializableFunction1<T1, T2>
-      extends AbstractFunction1<T1, T2> implements Serializable {
-  }
+  private static final Logger LOG =
+      LoggerFactory.getLogger(SparkGroupAlsoByWindowViaWindowSet.class);
 
-  public static <K, InputT, W extends BoundedWindow>
-      JavaDStream<WindowedValue<KV<K, Iterable<InputT>>>> groupAlsoByWindow(
-      final JavaDStream<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>> 
inputDStream,
-      final Coder<K> keyCoder,
-      final Coder<WindowedValue<InputT>> wvCoder,
-      final WindowingStrategy<?, W> windowingStrategy,
-      final SerializablePipelineOptions options,
-      final List<Integer> sourceIds,
-      final String transformFullName) {
-
-    final long batchDurationMillis =
-        options.get().as(SparkPipelineOptions.class).getBatchIntervalMillis();
-    final IterableCoder<WindowedValue<InputT>> itrWvCoder = 
IterableCoder.of(wvCoder);
-    final Coder<InputT> iCoder = ((FullWindowedValueCoder<InputT>) 
wvCoder).getValueCoder();
-    final Coder<? extends BoundedWindow> wCoder =
-        ((FullWindowedValueCoder<InputT>) wvCoder).getWindowCoder();
-    final Coder<WindowedValue<KV<K, Iterable<InputT>>>> wvKvIterCoder =
-        FullWindowedValueCoder.of(KvCoder.of(keyCoder, 
IterableCoder.of(iCoder)), wCoder);
-    final TimerInternals.TimerDataCoder timerDataCoder =
-        
TimerInternals.TimerDataCoder.of(windowingStrategy.getWindowFn().windowCoder());
-
-    long checkpointDurationMillis =
-        options.get().as(SparkPipelineOptions.class)
-            .getCheckpointDurationMillis();
+  private static class StateAndTimers implements Serializable {
+    // Serializable state for internals (namespace -> state tag -> coded value).
+    private final Table<String, String, byte[]> state;
+    private final Collection<byte[]> serTimers;
 
-    // we have to switch to Scala API to avoid Optional in the Java API, see: 
SPARK-4819.
-    // we also have a broader API for Scala (access to the actual key and 
entire iterator).
-    // we use coders to convert objects in the PCollection to byte arrays, so 
they
-    // can be transferred over the network for the shuffle and be in 
serialized form
-    // for checkpointing.
-    // for readability, we add comments with actual type next to byte[].
-    // to shorten line length, we use:
-    //---- WV: WindowedValue
-    //---- Iterable: Itr
-    //---- AccumT: A
-    //---- InputT: I
-    DStream<Tuple2</*K*/ ByteArray, /*Itr<WV<I>>*/ byte[]>> pairDStream =
-        inputDStream
-            .transformToPair(
-                new org.apache.spark.api.java.function.Function2<
-                    JavaRDD<WindowedValue<KV<K, 
Iterable<WindowedValue<InputT>>>>>,
-                    Time, JavaPairRDD<ByteArray, byte[]>>() {
-                  // we use mapPartitions with the RDD API because its the 
only available API
-                  // that allows to preserve partitioning.
-                  @Override
-                  public JavaPairRDD<ByteArray, byte[]> call(
-                      JavaRDD<WindowedValue<KV<K, 
Iterable<WindowedValue<InputT>>>>> rdd,
-                      final Time time)
-                      throws Exception {
-                    return rdd.mapPartitions(
-                        TranslationUtils.functionToFlatMapFunction(
-                            WindowingHelpers
-                                .<KV<K, 
Iterable<WindowedValue<InputT>>>>unwindowFunction()),
-                        true)
-                              .mapPartitionsToPair(
-                                  TranslationUtils
-                                      .<K, 
Iterable<WindowedValue<InputT>>>toPairFlatMapFunction(),
-                                  true)
-                              .mapValues(new 
Function<Iterable<WindowedValue<InputT>>, KV<Long,
-                                  Iterable<WindowedValue<InputT>>>>() {
-
-                                @Override
-                                public KV<Long, 
Iterable<WindowedValue<InputT>>> call
-                                    (Iterable<WindowedValue<InputT>> values)
-                                    throws Exception {
-                                  // add the batch timestamp for visibility 
(e.g., debugging)
-                                  return KV.of(time.milliseconds(), values);
-                                }
-                              })
-                              // move to bytes representation and use coders 
for deserialization
-                              // because of checkpointing.
-                              .mapPartitionsToPair(
-                                  
TranslationUtils.pairFunctionToPairFlatMapFunction(
-                                      CoderHelpers.toByteFunction(keyCoder,
-                                                                  
KvCoder.of(VarLongCoder.of(),
-                                                                             
itrWvCoder))),
-                                  true);
-                  }
-                })
-            .dstream();
+    private StateAndTimers(
+        final Table<String, String, byte[]> state, final Collection<byte[]> 
timers) {
+      this.state = state;
+      this.serTimers = timers;
+    }
 
-    PairDStreamFunctions<ByteArray, byte[]> pairDStreamFunctions =
-        DStream.toPairDStreamFunctions(
-        pairDStream,
-        JavaSparkContext$.MODULE$.<ByteArray>fakeClassTag(),
-        JavaSparkContext$.MODULE$.<byte[]>fakeClassTag(),
-        null);
-    int defaultNumPartitions = 
pairDStreamFunctions.defaultPartitioner$default$1();
-    Partitioner partitioner = 
pairDStreamFunctions.defaultPartitioner(defaultNumPartitions);
+    public Table<String, String, byte[]> getState() {
+      return state;
+    }
 
-    // use updateStateByKey to scan through the state and update elements and 
timers.
-    DStream<Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, 
Itr<I>>>*/ List<byte[]>>>>
-        firedStream = pairDStreamFunctions.updateStateByKey(
-            new SerializableFunction1<
-                scala.collection.Iterator<Tuple3</*K*/ ByteArray, 
Seq</*Itr<WV<I>>*/ byte[]>,
-                    Option<Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ 
List<byte[]>>>>>,
-                scala.collection.Iterator<Tuple2</*K*/ ByteArray, 
Tuple2<StateAndTimers,
-                    /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>>>() {
+    Collection<byte[]> getTimers() {
+      return serTimers;
+    }
+  }
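
For intuition, the state Table above maps (state namespace, state tag id) to a
coder-encoded value. A toy sketch with assumed namespace and tag strings
(Guava's HashBasedTable; the strings are illustrative, not this class's actual
contents):

    import com.google.common.collect.HashBasedTable;
    import com.google.common.collect.Table;

    class StateTableSketch {
      public static void main(final String[] args) {
        final Table<String, String, byte[]> state = HashBasedTable.create();
        // row key = state namespace (e.g. a window), column key = state tag id,
        // value = the state's value, encoded with its coder.
        state.put("/window[0,10)/", "buffer", new byte[] {1, 2, 3});
        final byte[] coded = state.get("/window[0,10)/", "buffer");
        System.out.println(coded.length); // prints 3
      }
    }
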
 
-      @Override
-      public scala.collection.Iterator<Tuple2</*K*/ ByteArray, 
Tuple2<StateAndTimers,
-          /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>> apply(
-              final scala.collection.Iterator<Tuple3</*K*/ ByteArray, 
Seq</*Itr<WV<I>>*/ byte[]>,
-              Option<Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ 
List<byte[]>>>>> iter) {
-        //--- ACTUAL STATEFUL OPERATION:
-        //
-        // Input Iterator: the partition (~bundle) of a cogrouping of the input
-        // and the previous state (if exists).
-        //
-        // Output Iterator: the output key, and the updated state.
-        //
-        // possible input scenarios for (K, Seq, Option<S>):
-        // (1) Option<S>.isEmpty: new data with no previous state.
-        // (2) Seq.isEmpty: no new data, but evaluating previous state 
(timer-like behaviour).
-        // (3) Seq.nonEmpty && Option<S>.isDefined: new data with previous 
state.
-
-        final SystemReduceFn<K, InputT, Iterable<InputT>, Iterable<InputT>, W> 
reduceFn =
-            SystemReduceFn.buffering(
-                ((FullWindowedValueCoder<InputT>) wvCoder).getValueCoder());
-        // use in memory Aggregators since Spark Accumulators are not resilient
-        // in stateful operators, once done with this partition.
-        final MetricsContainerImpl cellProvider = new 
MetricsContainerImpl("cellProvider");
-        final CounterCell droppedDueToClosedWindow = cellProvider.getCounter(
-            MetricName.named(SparkGroupAlsoByWindowViaWindowSet.class,
-            
GroupAlsoByWindowsAggregators.DROPPED_DUE_TO_CLOSED_WINDOW_COUNTER));
-        final CounterCell droppedDueToLateness = cellProvider.getCounter(
-            MetricName.named(SparkGroupAlsoByWindowViaWindowSet.class,
-                
GroupAlsoByWindowsAggregators.DROPPED_DUE_TO_LATENESS_COUNTER));
-
-        AbstractIterator<
-            Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, 
KV<Long(Time),Itr<I>>>>*/
-                List<byte[]>>>>
-                outIter = new AbstractIterator<Tuple2</*K*/ ByteArray,
-                    Tuple2<StateAndTimers, /*WV<KV<K, 
KV<Long(Time),Itr<I>>>>*/ List<byte[]>>>>() {
-                  @Override
-                  protected Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers,
-                      /*WV<KV<K, Itr<I>>>*/ List<byte[]>>> computeNext() {
-                    // input iterator is a Spark partition (~bundle), 
containing keys and their
-                    // (possibly) previous-state and (possibly) new data.
-                    while (iter.hasNext()) {
-                      // for each element in the partition:
-                      Tuple3<ByteArray, Seq<byte[]>,
-                          Option<Tuple2<StateAndTimers, List<byte[]>>>> next = 
iter.next();
-                      ByteArray encodedKey = next._1();
-                      K key = 
CoderHelpers.fromByteArray(encodedKey.getValue(), keyCoder);
-
-                      Seq<byte[]> seq = next._2();
-
-                      Option<Tuple2<StateAndTimers,
-                          List<byte[]>>> prevStateAndTimersOpt = next._3();
-
-                      SparkStateInternals<K> stateInternals;
-                      Map<Integer, GlobalWatermarkHolder.SparkWatermarks> 
watermarks =
-                          GlobalWatermarkHolder.get(batchDurationMillis);
-                      SparkTimerInternals timerInternals = 
SparkTimerInternals.forStreamFromSources(
-                          sourceIds, watermarks);
-
-                      // get state(internals) per key.
-                      if (prevStateAndTimersOpt.isEmpty()) {
-                        // no previous state.
-                        stateInternals = SparkStateInternals.forKey(key);
-                      } else {
-                        // with pre-existing state.
-                        StateAndTimers prevStateAndTimers = 
prevStateAndTimersOpt.get()._1();
-                        stateInternals = 
SparkStateInternals.forKeyAndState(key,
-                            prevStateAndTimers.getState());
-                        Collection<byte[]> serTimers = 
prevStateAndTimers.getTimers();
-                        timerInternals.addTimers(
-                            SparkTimerInternals.deserializeTimers(serTimers, 
timerDataCoder));
-                      }
-
-                      final OutputWindowedValueHolder<K, InputT> outputHolder =
-                          new OutputWindowedValueHolder<>();
-
-                      ReduceFnRunner<K, InputT, Iterable<InputT>, W> 
reduceFnRunner =
-                          new ReduceFnRunner<>(
-                              key,
-                              windowingStrategy,
-                              ExecutableTriggerStateMachine.create(
-                                  TriggerStateMachines.stateMachineForTrigger(
-                                      
TriggerTranslation.toProto(windowingStrategy.getTrigger()))),
-                              stateInternals,
-                              timerInternals,
-                              outputHolder,
-                              new 
UnsupportedSideInputReader("GroupAlsoByWindow"),
-                              reduceFn,
-                              options.get());
-
-                      if (!seq.isEmpty()) {
-                        // new input for key.
-                        try {
-                          final KV<Long, Iterable<WindowedValue<InputT>>> 
keyedElements =
-                              CoderHelpers.fromByteArray(seq.head(),
-                                                         
KvCoder.of(VarLongCoder.of(), itrWvCoder));
-
-                          final Long rddTimestamp = keyedElements.getKey();
-
-                          LOG.debug(
-                              transformFullName
-                                  + ": processing RDD with timestamp: {}, 
watermarks: {}",
-                              rddTimestamp,
-                              watermarks);
-
-                          final Iterable<WindowedValue<InputT>> elements = 
keyedElements.getValue();
-
-                          LOG.trace(transformFullName + ": input elements: 
{}", elements);
-
-                          /*
-                          Incoming expired windows are filtered based on
-                          timerInternals.currentInputWatermarkTime() and the 
configured allowed
-                          lateness. Note that this is done prior to calling
-                          timerInternals.advanceWatermark so essentially the 
inputWatermark is
-                          the highWatermark of the previous batch and the 
lowWatermark of the
-                          current batch.
-                          The highWatermark of the current batch will only 
affect filtering
-                          as of the next batch.
-                           */
-                          final Iterable<WindowedValue<InputT>> 
nonExpiredElements =
-                              Lists.newArrayList(LateDataUtils
-                                                     .dropExpiredWindows(
-                                                         key,
-                                                         elements,
-                                                         timerInternals,
-                                                         windowingStrategy,
-                                                         
droppedDueToLateness));
-
-                          LOG.trace(transformFullName + ": non expired input 
elements: {}",
-                                    elements);
-
-                          reduceFnRunner.processElements(nonExpiredElements);
-                        } catch (Exception e) {
-                          throw new RuntimeException(
-                              "Failed to process element with ReduceFnRunner", 
e);
-                        }
-                      } else if (stateInternals.getState().isEmpty()) {
-                        // no input and no state -> GC evict now.
-                        continue;
-                      }
-                      try {
-                        // advance the watermark to HWM to fire by timers.
-                        LOG.debug(transformFullName + ": timerInternals before 
advance are {}",
-                                  timerInternals.toString());
-
-                        // store the highWatermark as the new inputWatermark 
to calculate triggers
-                        timerInternals.advanceWatermark();
-
-                        LOG.debug(transformFullName + ": timerInternals after 
advance are {}",
-                                  timerInternals.toString());
-
-                        // call on timers that are ready.
-                        final Collection<TimerInternals.TimerData> 
readyToProcess =
-                            timerInternals.getTimersReadyToProcess();
-
-                        LOG.debug(transformFullName + ": ready timers are {}", 
readyToProcess);
-
-                        /*
-                        Note that at this point, the watermark has already 
advanced since
-                        timerInternals.advanceWatermark() has been called and 
the highWatermark
-                        is now stored as the new inputWatermark, according to 
which triggers are
-                        calculated.
-                         */
-                        reduceFnRunner.onTimers(readyToProcess);
-                      } catch (Exception e) {
-                        throw new RuntimeException(
-                            "Failed to process ReduceFnRunner onTimer.", e);
-                      }
-                      // this is mostly symbolic since actual persist is done 
by emitting output.
-                      reduceFnRunner.persist();
-                      // obtain output, if fired.
-                      List<WindowedValue<KV<K, Iterable<InputT>>>> outputs = 
outputHolder.get();
-
-                      if (!outputs.isEmpty() || 
!stateInternals.getState().isEmpty()) {
-                        // empty outputs are filtered later using DStream 
filtering
-                        StateAndTimers updated = new 
StateAndTimers(stateInternals.getState(),
-                            SparkTimerInternals.serializeTimers(
-                                timerInternals.getTimers(), timerDataCoder));
-
-                        /*
-                        Not something we want to happen in production, but is 
very helpful
-                        when debugging - TRACE.
-                         */
-                        LOG.trace(transformFullName + ": output elements are 
{}",
-                                  Joiner.on(", ").join(outputs));
-
-                        // persist Spark's state by outputting.
-                        List<byte[]> serOutput = 
CoderHelpers.toByteArrays(outputs, wvKvIterCoder);
-                        return new Tuple2<>(encodedKey, new Tuple2<>(updated, 
serOutput));
-                      }
-                      // an empty state with no output, can be evicted 
completely - do nothing.
-                    }
-                    return endOfData();
-                  }
-        };
+  private static class OutputWindowedValueHolder<K, V>
+      implements OutputWindowedValue<KV<K, Iterable<V>>> {
+    private final List<WindowedValue<KV<K, Iterable<V>>>> windowedValues = new 
ArrayList<>();
+
+    @Override
+    public void outputWindowedValue(
+        final KV<K, Iterable<V>> output,
+        final Instant timestamp,
+        final Collection<? extends BoundedWindow> windows,
+        final PaneInfo pane) {
+      windowedValues.add(WindowedValue.of(output, timestamp, windows, pane));
+    }
+
+    private List<WindowedValue<KV<K, Iterable<V>>>> getWindowedValues() {
+      return windowedValues;
+    }
+
+    @Override
+    public <AdditionalOutputT> void outputWindowedValue(
+        final TupleTag<AdditionalOutputT> tag,
+        final AdditionalOutputT output,
+        final Instant timestamp,
+        final Collection<? extends BoundedWindow> windows,
+        final PaneInfo pane) {
+      throw new UnsupportedOperationException(
+          "Tagged outputs are not allowed in GroupAlsoByWindow.");
+    }
+  }
 
-        // log if there's something to log.
-        long lateDropped = droppedDueToLateness.getCumulative();
-        if (lateDropped > 0) {
-          LOG.info(String.format("Dropped %d elements due to lateness.", 
lateDropped));
-          droppedDueToLateness.inc(-droppedDueToLateness.getCumulative());
+  private static class UpdateStateByKeyFunction<K, InputT, W extends 
BoundedWindow>
+      extends AbstractFunction1<
+          Iterator<
+              Tuple3<
+                  /*K*/ ByteArray, Seq</*Itr<WV<I>>*/ byte[]>,
+                  Option<Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ 
List<byte[]>>>>>,
+          Iterator<
+              Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, 
Itr<I>>>*/ List<byte[]>>>>>
+      implements Serializable {
+
+    private class UpdateStateByKeyOutputIterator
+        extends AbstractIterator<
+            Tuple2<
+                /*K*/ ByteArray,
+                Tuple2<StateAndTimers, /*WV<KV<K, KV<Long(Time),Itr<I>>>>*/ 
List<byte[]>>>> {
+
+      private final Iterator<
+              Tuple3<ByteArray, Seq<byte[]>, Option<Tuple2<StateAndTimers, 
List<byte[]>>>>>
+          input;
+      private final SystemReduceFn<K, InputT, Iterable<InputT>, 
Iterable<InputT>, W> reduceFn;
+      private final CounterCell droppedDueToLateness;
+
+      private SparkStateInternals<K> processPreviousState(
+          final Option<Tuple2<StateAndTimers, List<byte[]>>> 
prevStateAndTimersOpt,
+          final K key,
+          final SparkTimerInternals timerInternals) {
+
+        final SparkStateInternals<K> stateInternals;
+
+        if (prevStateAndTimersOpt.isEmpty()) {
+          // no previous state.
+          stateInternals = SparkStateInternals.forKey(key);
+        } else {
+          // with pre-existing state.
+          final StateAndTimers prevStateAndTimers = 
prevStateAndTimersOpt.get()._1();
+          // get state(internals) per key.
+          stateInternals = SparkStateInternals.forKeyAndState(key, 
prevStateAndTimers.getState());
+
+          timerInternals.addTimers(
+              SparkTimerInternals.deserializeTimers(
+                  prevStateAndTimers.getTimers(), timerDataCoder));
         }
-        long closedWindowDropped = droppedDueToClosedWindow.getCumulative();
-        if (closedWindowDropped > 0) {
-          LOG.info(String.format("Dropped %d elements due to closed window.", 
closedWindowDropped));
-          
droppedDueToClosedWindow.inc(-droppedDueToClosedWindow.getCumulative());
+
+        return stateInternals;
+      }
+
+      UpdateStateByKeyOutputIterator(
+          final Iterator<
+                  Tuple3<ByteArray, Seq<byte[]>, Option<Tuple2<StateAndTimers, 
List<byte[]>>>>>
+              input,
+          final SystemReduceFn<K, InputT, Iterable<InputT>, Iterable<InputT>, 
W> reduceFn,
+          final CounterCell droppedDueToLateness) {
+        this.input = input;
+        this.reduceFn = reduceFn;
+        this.droppedDueToLateness = droppedDueToLateness;
+      }
+
+      @Override
+      protected Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, 
Itr<I>>>*/ List<byte[]>>>
+          computeNext() {
+        // input iterator is a Spark partition (~bundle), containing keys and 
their
+        // (possibly) previous-state and (possibly) new data.
+        while (input.hasNext()) {
+
+          // for each element in the partition:
+          final Tuple3<ByteArray, Seq<byte[]>, Option<Tuple2<StateAndTimers, 
List<byte[]>>>> next =
+              input.next();
+
+          final ByteArray encodedKey = next._1();
+          final Seq<byte[]> encodedKeyedElements = next._2();
+          final Option<Tuple2<StateAndTimers, List<byte[]>>> 
prevStateAndTimersOpt = next._3();
+
+          final K key = CoderHelpers.fromByteArray(encodedKey.getValue(), 
keyCoder);
+
+          final Map<Integer, GlobalWatermarkHolder.SparkWatermarks> watermarks 
=
+              GlobalWatermarkHolder.get(getBatchDuration(options));
+
+          final SparkTimerInternals timerInternals =
+              SparkTimerInternals.forStreamFromSources(sourceIds, watermarks);
+
+          final SparkStateInternals<K> stateInternals =
+              processPreviousState(prevStateAndTimersOpt, key, timerInternals);
+
+          final ExecutableTriggerStateMachine triggerStateMachine =
+              ExecutableTriggerStateMachine.create(
+                  TriggerStateMachines.stateMachineForTrigger(
+                      
TriggerTranslation.toProto(windowingStrategy.getTrigger())));
+
+          final OutputWindowedValueHolder<K, InputT> outputHolder =
+              new OutputWindowedValueHolder<>();
+
+          final ReduceFnRunner<K, InputT, Iterable<InputT>, W> reduceFnRunner =
+              new ReduceFnRunner<>(
+                  key,
+                  windowingStrategy,
+                  triggerStateMachine,
+                  stateInternals,
+                  timerInternals,
+                  outputHolder,
+                  new UnsupportedSideInputReader("GroupAlsoByWindow"),
+                  reduceFn,
+                  options.get());
+
+          if (!encodedKeyedElements.isEmpty()) {
+            // new input for key.
+            try {
+              final KV<Long, Iterable<WindowedValue<InputT>>> keyedElements =
+                  CoderHelpers.fromByteArray(
+                      encodedKeyedElements.head(), 
KvCoder.of(VarLongCoder.of(), itrWvCoder));
+
+              final Long rddTimestamp = keyedElements.getKey();
+
+              LOG.debug(
+                  logPrefix + ": processing RDD with timestamp: {}, 
watermarks: {}",
+                  rddTimestamp,
+                  watermarks);
+
+              final Iterable<WindowedValue<InputT>> elements = 
keyedElements.getValue();
+
+              LOG.trace(logPrefix + ": input elements: {}", elements);
+
+              /*
+              Incoming expired windows are filtered based on
+              timerInternals.currentInputWatermarkTime() and the configured 
allowed
+              lateness. Note that this is done prior to calling
+              timerInternals.advanceWatermark so essentially the 
inputWatermark is
+              the highWatermark of the previous batch and the lowWatermark of 
the
+              current batch.
+              The highWatermark of the current batch will only affect filtering
+              as of the next batch.
+               */
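+              /*
+              Toy illustration (assumed numbers, not from this change): for a
+              fixed window [0, 10) with allowed lateness of 5ms, an incoming
+              element in that window is dropped once the inputWatermark, i.e.
+              the previous batch's highWatermark, passes the window's expiry:
+              window.maxTimestamp() + allowedLateness = 9ms + 5ms = 14ms.
+               */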
+              final Iterable<WindowedValue<InputT>> nonExpiredElements =
+                  Lists.newArrayList(
+                      LateDataUtils.dropExpiredWindows(
+                          key, elements, timerInternals, windowingStrategy, 
droppedDueToLateness));
+
+              LOG.trace(logPrefix + ": non expired input elements: {}", 
nonExpiredElements);
+
+              reduceFnRunner.processElements(nonExpiredElements);
+            } catch (final Exception e) {
+              throw new RuntimeException("Failed to process element with 
ReduceFnRunner", e);
+            }
+          } else if (stateInternals.getState().isEmpty()) {
+            // no input and no state -> GC evict now.
+            continue;
+          }
+          try {
+            // advance the watermark to the HWM so timers can fire.
+            LOG.debug(
+                logPrefix + ": timerInternals before advance are {}",
+                timerInternals.toString());
+
+            // store the highWatermark as the new inputWatermark to calculate 
triggers
+            timerInternals.advanceWatermark();
+
+            LOG.debug(
+                logPrefix + ": timerInternals after advance are {}",
+                timerInternals.toString());
+
+            // call on timers that are ready.
+            final Collection<TimerInternals.TimerData> readyToProcess =
+                timerInternals.getTimersReadyToProcess();
+
+            LOG.debug(logPrefix + ": ready timers are {}", readyToProcess);
+
+            /*
+            Note that at this point, the watermark has already advanced since
+            timerInternals.advanceWatermark() has been called and the 
highWatermark
+            is now stored as the new inputWatermark, according to which 
triggers are
+            calculated.
+             */
+            reduceFnRunner.onTimers(readyToProcess);
+          } catch (final Exception e) {
+            throw new RuntimeException("Failed to process ReduceFnRunner 
onTimer.", e);
+          }
+          // this is mostly symbolic since actual persist is done by emitting 
output.
+          reduceFnRunner.persist();
+          // obtain output, if fired.
+          final List<WindowedValue<KV<K, Iterable<InputT>>>> outputs =
+              outputHolder.getWindowedValues();
+
+          if (!outputs.isEmpty() || !stateInternals.getState().isEmpty()) {
+            // empty outputs are filtered later using DStream filtering
+            final StateAndTimers updated =
+                new StateAndTimers(
+                    stateInternals.getState(),
+                    SparkTimerInternals.serializeTimers(
+                        timerInternals.getTimers(), timerDataCoder));
+
+            /*
+            Not something we want to log in production, but very helpful
+            when debugging - hence TRACE.
+             */
+            LOG.trace(
+                logPrefix + ": output elements are {}", Joiner.on(", 
").join(outputs));
+
+            // persist Spark's state by outputting.
+            final List<byte[]> serOutput = CoderHelpers.toByteArrays(outputs, 
wvKvIterCoder);
+            return new Tuple2<>(encodedKey, new Tuple2<>(updated, serOutput));
+          }
+          // an empty state with no output can be evicted completely - do nothing.
         }
+        return endOfData();
+      }
+    }
+
+    private final FullWindowedValueCoder<InputT> wvCoder;
+    private final Coder<K> keyCoder;
+    private final List<Integer> sourceIds;
+    private final TimerInternals.TimerDataCoder timerDataCoder;
+    private final WindowingStrategy<?, W> windowingStrategy;
+    private final SerializablePipelineOptions options;
+    private final IterableCoder<WindowedValue<InputT>> itrWvCoder;
+    private final String logPrefix;
+    private final Coder<WindowedValue<KV<K, Iterable<InputT>>>> wvKvIterCoder;
+
+    UpdateStateByKeyFunction(
+        final List<Integer> sourceIds,
+        final WindowingStrategy<?, W> windowingStrategy,
+        final FullWindowedValueCoder<InputT> wvCoder,
+        final Coder<K> keyCoder,
+        final SerializablePipelineOptions options,
+        final String logPrefix) {
+      this.wvCoder = wvCoder;
+      this.keyCoder = keyCoder;
+      this.sourceIds = sourceIds;
+      this.timerDataCoder = timerDataCoderOf(windowingStrategy);
+      this.windowingStrategy = windowingStrategy;
+      this.options = options;
+      this.itrWvCoder = IterableCoder.of(wvCoder);
+      this.logPrefix = logPrefix;
+      this.wvKvIterCoder =
+          windowedValueKeyValueCoderOf(
+              keyCoder,
+              wvCoder.getValueCoder(),
+              ((FullWindowedValueCoder<InputT>) wvCoder).getWindowCoder());
+    }
 
-        return scala.collection.JavaConversions.asScalaIterator(outIter);
+    @Override
+    public Iterator<
+            Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, 
Itr<I>>>*/ List<byte[]>>>>
+        apply(
+            final Iterator<
+                    Tuple3<
+                        /*K*/ ByteArray, Seq</*Itr<WV<I>>*/ byte[]>,
+                        Option<Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ 
List<byte[]>>>>>
+                input) {
+      //--- ACTUAL STATEFUL OPERATION:
+      //
+      // Input Iterator: the partition (~bundle) of a co-grouping of the input
+      // and the previous state (if it exists).
+      //
+      // Output Iterator: the output key, and the updated state.
+      //
+      // possible input scenarios for (K, Seq, Option<S>):
+      // (1) Option<S>.isEmpty: new data with no previous state.
+      // (2) Seq.isEmpty: no new data, but evaluating previous state 
(timer-like behaviour).
+      // (3) Seq.nonEmpty && Option<S>.isDefined: new data with previous state.
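+      //
+      // For illustration (hypothetical key "k"): the first batch containing "k"
+      // is case (1); a later batch with no new elements for "k", but with state
+      // carried over, is case (2), which is what lets timers fire without new
+      // input; a batch with both new elements and carried-over state is case (3).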
+
+      final SystemReduceFn<K, InputT, Iterable<InputT>, Iterable<InputT>, W> 
reduceFn =
+          SystemReduceFn.buffering(wvCoder.getValueCoder());
+
+      final MetricsContainerImpl cellProvider = new 
MetricsContainerImpl("cellProvider");
+
+      final CounterCell droppedDueToClosedWindow =
+          cellProvider.getCounter(
+              MetricName.named(
+                  SparkGroupAlsoByWindowViaWindowSet.class,
+                  
GroupAlsoByWindowsAggregators.DROPPED_DUE_TO_CLOSED_WINDOW_COUNTER));
+
+      final CounterCell droppedDueToLateness =
+          cellProvider.getCounter(
+              MetricName.named(
+                  SparkGroupAlsoByWindowViaWindowSet.class,
+                  
GroupAlsoByWindowsAggregators.DROPPED_DUE_TO_LATENESS_COUNTER));
+
+      // log if there's something to log.
+      final long lateDropped = droppedDueToLateness.getCumulative();
+      if (lateDropped > 0) {
+        LOG.info(String.format("Dropped %d elements due to lateness.", 
lateDropped));
+        droppedDueToLateness.inc(-droppedDueToLateness.getCumulative());
       }
-    }, partitioner, true,
-        JavaSparkContext$.MODULE$.<Tuple2<StateAndTimers, 
List<byte[]>>>fakeClassTag());
+      final long closedWindowDropped = 
droppedDueToClosedWindow.getCumulative();
+      if (closedWindowDropped > 0) {
+        LOG.info(String.format("Dropped %d elements due to closed window.", 
closedWindowDropped));
+        
droppedDueToClosedWindow.inc(-droppedDueToClosedWindow.getCumulative());
+      }
+
+      return scala.collection.JavaConversions.asScalaIterator(
+          new UpdateStateByKeyOutputIterator(input, reduceFn, 
droppedDueToLateness));
+    }
+  }
+
+  private static <K, InputT>
+      FullWindowedValueCoder<KV<K, Iterable<InputT>>> 
windowedValueKeyValueCoderOf(
+          final Coder<K> keyCoder,
+          final Coder<InputT> iCoder,
+          final Coder<? extends BoundedWindow> wCoder) {
+    return FullWindowedValueCoder.of(KvCoder.of(keyCoder, 
IterableCoder.of(iCoder)), wCoder);
+  }
+
+  private static <W extends BoundedWindow> TimerInternals.TimerDataCoder 
timerDataCoderOf(
+      final WindowingStrategy<?, W> windowingStrategy) {
+    return 
TimerInternals.TimerDataCoder.of(windowingStrategy.getWindowFn().windowCoder());
+  }
+
+  private static void checkpointIfNeeded(
+      final DStream<Tuple2<ByteArray, Tuple2<StateAndTimers, List<byte[]>>>> firedStream,
+      final SerializablePipelineOptions options) {
+
+    final Long checkpointDurationMillis =
+        options.get().as(SparkPipelineOptions.class).getCheckpointDurationMillis();
 
     if (checkpointDurationMillis > 0) {
       firedStream.checkpoint(new Duration(checkpointDurationMillis));
     }
+  }
+
+  private static Long getBatchDuration(final SerializablePipelineOptions options) {
+    // the batch interval, used to look up the current batch's watermarks.
+    return options.get().as(SparkPipelineOptions.class).getBatchIntervalMillis();
+  }
+
+  private static <K, InputT> JavaDStream<WindowedValue<KV<K, 
Iterable<InputT>>>> stripStateValues(
+      final DStream<Tuple2<ByteArray, Tuple2<StateAndTimers, List<byte[]>>>> 
firedStream,
+      final Coder<K> keyCoder,
+      final FullWindowedValueCoder<InputT> wvCoder) {
 
-    // go back to Java now.
-    JavaPairDStream</*K*/ ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, 
Itr<I>>>*/ List<byte[]>>>
-        javaFiredStream = JavaPairDStream.fromPairDStream(
+    return JavaPairDStream.fromPairDStream(
             firedStream,
             JavaSparkContext$.MODULE$.<ByteArray>fakeClassTag(),
-            JavaSparkContext$.MODULE$.<Tuple2<StateAndTimers, 
List<byte[]>>>fakeClassTag());
-
-    // filter state-only output (nothing to fire) and remove the state from 
the output.
-    return javaFiredStream.filter(
-        new Function<Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers,
-            /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>, Boolean>() {
+            JavaSparkContext$.MODULE$.<Tuple2<StateAndTimers, 
List<byte[]>>>fakeClassTag())
+        .filter(
+            new Function<
+                Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, 
Itr<I>>>*/ List<byte[]>>>,
+                Boolean>() {
               @Override
               public Boolean call(
-                  Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers,
-                  /*WV<KV<K, Itr<I>>>*/ List<byte[]>>> t2) throws Exception {
+                  final Tuple2<
+                          /*K*/ ByteArray,
+                          Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ 
List<byte[]>>>
+                      t2)
+                  throws Exception {
                 // filter output if defined.
                 return !t2._2()._2().isEmpty();
               }
-        })
+            })
         .flatMap(
-            new FlatMapFunction<Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers,
-                /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>,
+            new FlatMapFunction<
+                Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, 
Itr<I>>>*/ List<byte[]>>>,
                 WindowedValue<KV<K, Iterable<InputT>>>>() {
+
+              private final FullWindowedValueCoder<KV<K, Iterable<InputT>>>
+                  windowedValueKeyValueCoder =
+                      windowedValueKeyValueCoderOf(
+                          keyCoder, wvCoder.getValueCoder(), 
wvCoder.getWindowCoder());
+
               @Override
               public Iterable<WindowedValue<KV<K, Iterable<InputT>>>> call(
-                  Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers,
-                  /*WV<KV<K, Itr<I>>>*/ List<byte[]>>> t2) throws Exception {
+                  final Tuple2<
+                          /*K*/ ByteArray,
+                          Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ 
List<byte[]>>>
+                      t2)
+                  throws Exception {
                 // drop the state since it is already persisted at this point.
                 // return in serialized form.
-                return CoderHelpers.fromByteArrays(t2._2()._2(), 
wvKvIterCoder);
+                return CoderHelpers.fromByteArrays(t2._2()._2(), 
windowedValueKeyValueCoder);
               }
-        });
+            });
   }
 
-  private static class StateAndTimers implements Serializable {
-    //Serializable state for internals (namespace to state tag to coded value).
-    private final Table<String, String, byte[]> state;
-    private final Collection<byte[]> serTimers;
+  private static <K, InputT> PairDStreamFunctions<ByteArray, byte[]> 
buildPairDStream(
+      final JavaDStream<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>> 
inputDStream,
+      final Coder<K> keyCoder,
+      final Coder<WindowedValue<InputT>> wvCoder) {
 
-    private StateAndTimers(
-        Table<String, String, byte[]> state, Collection<byte[]> timers) {
-      this.state = state;
-      this.serTimers = timers;
-    }
+    // we have to switch to the Scala API to avoid the Optional in the Java API,
+    // see: SPARK-4819. the Scala API is also broader (access to the actual key
+    // and the entire iterator).
+    // we use coders to convert objects in the PCollection to byte arrays, so 
they
+    // can be transferred over the network for the shuffle and be in 
serialized form
+    // for checkpointing.
+    // for readability, we add comments with actual type next to byte[].
+    // to shorten line length, we use:
+    //---- WindowedValue: WV
+    //---- Iterable: Itr
+    //---- AccumT: A
+    //---- InputT: I
+    final DStream<Tuple2<ByteArray, byte[]>> tupleDStream =
+        inputDStream
+            .transformToPair(
+                new Function2<
+                    JavaRDD<WindowedValue<KV<K, 
Iterable<WindowedValue<InputT>>>>>, Time,
+                    JavaPairRDD<ByteArray, byte[]>>() {
 
-    public Table<String, String, byte[]> getState() {
-      return state;
-    }
+                  // we use mapPartitions with the RDD API because it's the only
+                  // available API that preserves partitioning.
+                  @Override
+                  public JavaPairRDD<ByteArray, byte[]> call(
+                      final JavaRDD<WindowedValue<KV<K, 
Iterable<WindowedValue<InputT>>>>> rdd,
+                      final Time time)
+                      throws Exception {
+                    return rdd.mapPartitions(
+                            TranslationUtils.functionToFlatMapFunction(
+                                WindowingHelpers
+                                    .<KV<K, 
Iterable<WindowedValue<InputT>>>>unwindowFunction()),
+                            true)
+                        .mapPartitionsToPair(
+                            TranslationUtils
+                                .<K, 
Iterable<WindowedValue<InputT>>>toPairFlatMapFunction(),
+                            true)
+                        .mapValues(
+                            new Function<
+                                Iterable<WindowedValue<InputT>>,
+                                KV<Long, Iterable<WindowedValue<InputT>>>>() {
+
+                              @Override
+                              public KV<Long, Iterable<WindowedValue<InputT>>> 
call(
+                                  final Iterable<WindowedValue<InputT>> 
values) throws Exception {
+                                // add the batch timestamp for visibility 
(e.g., debugging)
+                                return KV.of(time.milliseconds(), values);
+                              }
+                            })
+                        // move to a bytes representation and use coders for
+                        // (de)serialization, because of checkpointing.
+                        .mapPartitionsToPair(
+                            TranslationUtils.pairFunctionToPairFlatMapFunction(
+                                CoderHelpers.toByteFunction(
+                                    keyCoder,
+                                    KvCoder.of(VarLongCoder.of(), 
IterableCoder.of(wvCoder)))),
+                            true);
+                  }
+                })
+            .dstream();
 
-    public Collection<byte[]> getTimers() {
-      return serTimers;
-    }
+    return DStream.toPairDStreamFunctions(
+        tupleDStream,
+        JavaSparkContext$.MODULE$.<ByteArray>fakeClassTag(),
+        JavaSparkContext$.MODULE$.<byte[]>fakeClassTag(),
+        null);
   }
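
The coder-based byte[] round trip that the comments in buildPairDStream describe
can be sketched with Beam's public CoderUtils (used here as a simplified
stand-in for the runner's CoderHelpers; the values and coders are illustrative):

    import org.apache.beam.sdk.coders.KvCoder;
    import org.apache.beam.sdk.coders.StringUtf8Coder;
    import org.apache.beam.sdk.coders.VarLongCoder;
    import org.apache.beam.sdk.util.CoderUtils;
    import org.apache.beam.sdk.values.KV;

    class CoderRoundTripSketch {
      public static void main(final String[] args) throws Exception {
        final KvCoder<Long, String> coder =
            KvCoder.of(VarLongCoder.of(), StringUtf8Coder.of());
        final KV<Long, String> element = KV.of(42L, "payload");
        // encode so the element can cross the shuffle and live in checkpoints...
        final byte[] bytes = CoderUtils.encodeToByteArray(coder, element);
        // ...and decode it on the other side of the network boundary.
        final KV<Long, String> decoded = CoderUtils.decodeFromByteArray(coder, bytes);
        System.out.println(decoded.equals(element)); // prints true
      }
    }
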
 
-  private static class OutputWindowedValueHolder<K, V>
-      implements OutputWindowedValue<KV<K, Iterable<V>>> {
-    private List<WindowedValue<KV<K, Iterable<V>>>> windowedValues = new 
ArrayList<>();
+  public static <K, InputT, W extends BoundedWindow>
+      JavaDStream<WindowedValue<KV<K, Iterable<InputT>>>> groupAlsoByWindow(
+          final JavaDStream<WindowedValue<KV<K, 
Iterable<WindowedValue<InputT>>>>> inputDStream,
+          final Coder<K> keyCoder,
+          final Coder<WindowedValue<InputT>> wvCoder,
+          final WindowingStrategy<?, W> windowingStrategy,
+          final SerializablePipelineOptions options,
+          final List<Integer> sourceIds,
+          final String transformFullName) {
 
-    @Override
-    public void outputWindowedValue(
-        KV<K, Iterable<V>> output,
-        Instant timestamp,
-        Collection<? extends BoundedWindow> windows,
-        PaneInfo pane) {
-      windowedValues.add(WindowedValue.of(output, timestamp, windows, pane));
-    }
+    final PairDStreamFunctions<ByteArray, byte[]> pairDStream =
+        buildPairDStream(inputDStream, keyCoder, wvCoder);
 
-    private List<WindowedValue<KV<K, Iterable<V>>>> get() {
-      return windowedValues;
-    }
+    // use updateStateByKey to scan through the state and update elements and 
timers.
+    final UpdateStateByKeyFunction<K, InputT, W> updateFunc =
+        new UpdateStateByKeyFunction<>(
+            sourceIds,
+            windowingStrategy,
+            (FullWindowedValueCoder<InputT>) wvCoder,
+            keyCoder,
+            options,
+            transformFullName);
+
+    final DStream<
+            Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, 
Itr<I>>>*/ List<byte[]>>>>
+        firedStream =
+            pairDStream.updateStateByKey(
+                updateFunc,
+                
pairDStream.defaultPartitioner(pairDStream.defaultPartitioner$default$1()),
+                true,
+                JavaSparkContext$.MODULE$.<Tuple2<StateAndTimers, 
List<byte[]>>>fakeClassTag());
+
+    checkpointIfNeeded(firedStream, options);
 
-    @Override
-    public <AdditionalOutputT> void outputWindowedValue(
-        TupleTag<AdditionalOutputT> tag,
-        AdditionalOutputT output,
-        Instant timestamp,
-        Collection<? extends BoundedWindow> windows,
-        PaneInfo pane) {
-      throw new UnsupportedOperationException(
-          "Tagged outputs are not allowed in GroupAlsoByWindow.");
-    }
+    // filter state-only output (nothing to fire) and remove the state from 
the output.
+    return stripStateValues(firedStream, keyCoder, 
(FullWindowedValueCoder<InputT>) wvCoder);
   }
 }
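
As a usage sketch (a hypothetical call site with made-up names; the actual
caller is the Spark runner's streaming GroupByKey translation, and imports are
as in the file above), groupAlsoByWindow takes the keyed DStream together with
the coders, the windowing strategy, the pipeline options, the source ids and the
transform's full name:

    static <W extends BoundedWindow>
        JavaDStream<WindowedValue<KV<String, Iterable<Integer>>>> translateGroupByKey(
            final JavaDStream<WindowedValue<KV<String, Iterable<WindowedValue<Integer>>>>>
                inputDStream,
            final Coder<String> keyCoder,
            final Coder<WindowedValue<Integer>> wvCoder,
            final WindowingStrategy<?, W> windowingStrategy,
            final SerializablePipelineOptions options,
            final List<Integer> sourceIds) {
      return SparkGroupAlsoByWindowViaWindowSet.groupAlsoByWindow(
          inputDStream,
          keyCoder,
          wvCoder,
          windowingStrategy,
          options,
          sourceIds,
          "MyStep/GroupByKey");
    }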
