Repository: beam
Updated Branches:
  refs/heads/master a481d5611 -> f7d4583bd
[BEAM-2825] Refactored SparkGroupAlsoByWindowViaWindowSet to improve readability.

Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/c8b99ba3
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/c8b99ba3
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/c8b99ba3

Branch: refs/heads/master
Commit: c8b99ba393c54da1a3ffbc61c2e5f2ae92b0b2bb
Parents: a481d56
Author: Stas Levin <stasle...@apache.org>
Authored: Wed Aug 30 12:01:32 2017 +0300
Committer: Stas Levin <stasle...@apache.org>
Committed: Sun Sep 3 15:40:25 2017 +0300

----------------------------------------------------------------------
 .../SparkGroupAlsoByWindowViaWindowSet.java     | 878 +++++++++++--------
 1 file changed, 498 insertions(+), 380 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/c8b99ba3/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java
index e6a55a6..2258f05 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java
@@ -58,12 +58,12 @@ import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.WindowingStrategy;
-import org.apache.spark.Partitioner;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext$;
 import org.apache.spark.api.java.function.FlatMapFunction;
 import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.Function2;
 import org.apache.spark.streaming.Duration;
 import org.apache.spark.streaming.Time;
 import org.apache.spark.streaming.api.java.JavaDStream;
@@ -73,435 +73,553 @@ import org.apache.spark.streaming.dstream.PairDStreamFunctions;
 import org.joda.time.Instant;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
-import scala.Function1;
 import scala.Option;
 import scala.Tuple2;
 import scala.Tuple3;
+import scala.collection.Iterator;
 import scala.collection.Seq;
-import scala.reflect.ClassTag;
 import scala.runtime.AbstractFunction1;
 
 /**
- * An implementation of {@link GroupAlsoByWindow}
- * logic for grouping by windows and controlling trigger firings and pane accumulation.
+ * An implementation of {@link GroupAlsoByWindow} logic for grouping by windows and controlling
+ * trigger firings and pane accumulation.
  *
  * <p>This implementation is a composite of Spark transformations revolving around state management
- * using Spark's
- * {@link PairDStreamFunctions#updateStateByKey(Function1, Partitioner, boolean, ClassTag)}
- * to update state with new data and timers.
+ * using Spark's {@link PairDStreamFunctions#updateStateByKey(scala.Function1,
+ * org.apache.spark.Partitioner, boolean, scala.reflect.ClassTag)} to update state with new data and
+ * timers.
* - * <p>Using updateStateByKey allows to scan through the entire state visiting not just the - * updated state (new values for key) but also check if timers are ready to fire. - * Since updateStateByKey bounds the types of state and output to be the same, - * a (state, output) tuple is used, filtering the state (and output if no firing) - * in the following steps. + * <p>Using updateStateByKey allows to scan through the entire state visiting not just the updated + * state (new values for key) but also check if timers are ready to fire. Since updateStateByKey + * bounds the types of state and output to be the same, a (state, output) tuple is used, filtering + * the state (and output if no firing) in the following steps. */ public class SparkGroupAlsoByWindowViaWindowSet implements Serializable { - private static final Logger LOG = LoggerFactory.getLogger( - SparkGroupAlsoByWindowViaWindowSet.class); - - /** - * A helper class that is essentially a {@link Serializable} {@link AbstractFunction1}. - */ - private abstract static class SerializableFunction1<T1, T2> - extends AbstractFunction1<T1, T2> implements Serializable { - } + private static final Logger LOG = + LoggerFactory.getLogger(SparkGroupAlsoByWindowViaWindowSet.class); - public static <K, InputT, W extends BoundedWindow> - JavaDStream<WindowedValue<KV<K, Iterable<InputT>>>> groupAlsoByWindow( - final JavaDStream<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>> inputDStream, - final Coder<K> keyCoder, - final Coder<WindowedValue<InputT>> wvCoder, - final WindowingStrategy<?, W> windowingStrategy, - final SerializablePipelineOptions options, - final List<Integer> sourceIds, - final String transformFullName) { - - final long batchDurationMillis = - options.get().as(SparkPipelineOptions.class).getBatchIntervalMillis(); - final IterableCoder<WindowedValue<InputT>> itrWvCoder = IterableCoder.of(wvCoder); - final Coder<InputT> iCoder = ((FullWindowedValueCoder<InputT>) wvCoder).getValueCoder(); - final Coder<? extends BoundedWindow> wCoder = - ((FullWindowedValueCoder<InputT>) wvCoder).getWindowCoder(); - final Coder<WindowedValue<KV<K, Iterable<InputT>>>> wvKvIterCoder = - FullWindowedValueCoder.of(KvCoder.of(keyCoder, IterableCoder.of(iCoder)), wCoder); - final TimerInternals.TimerDataCoder timerDataCoder = - TimerInternals.TimerDataCoder.of(windowingStrategy.getWindowFn().windowCoder()); - - long checkpointDurationMillis = - options.get().as(SparkPipelineOptions.class) - .getCheckpointDurationMillis(); + private static class StateAndTimers implements Serializable { + //Serializable state for internals (namespace to state tag to coded value). + private final Table<String, String, byte[]> state; + private final Collection<byte[]> serTimers; - // we have to switch to Scala API to avoid Optional in the Java API, see: SPARK-4819. - // we also have a broader API for Scala (access to the actual key and entire iterator). - // we use coders to convert objects in the PCollection to byte arrays, so they - // can be transferred over the network for the shuffle and be in serialized form - // for checkpointing. - // for readability, we add comments with actual type next to byte[]. 
- // to shorten line length, we use: - //---- WV: WindowedValue - //---- Iterable: Itr - //---- AccumT: A - //---- InputT: I - DStream<Tuple2</*K*/ ByteArray, /*Itr<WV<I>>*/ byte[]>> pairDStream = - inputDStream - .transformToPair( - new org.apache.spark.api.java.function.Function2< - JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>>, - Time, JavaPairRDD<ByteArray, byte[]>>() { - // we use mapPartitions with the RDD API because its the only available API - // that allows to preserve partitioning. - @Override - public JavaPairRDD<ByteArray, byte[]> call( - JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>> rdd, - final Time time) - throws Exception { - return rdd.mapPartitions( - TranslationUtils.functionToFlatMapFunction( - WindowingHelpers - .<KV<K, Iterable<WindowedValue<InputT>>>>unwindowFunction()), - true) - .mapPartitionsToPair( - TranslationUtils - .<K, Iterable<WindowedValue<InputT>>>toPairFlatMapFunction(), - true) - .mapValues(new Function<Iterable<WindowedValue<InputT>>, KV<Long, - Iterable<WindowedValue<InputT>>>>() { - - @Override - public KV<Long, Iterable<WindowedValue<InputT>>> call - (Iterable<WindowedValue<InputT>> values) - throws Exception { - // add the batch timestamp for visibility (e.g., debugging) - return KV.of(time.milliseconds(), values); - } - }) - // move to bytes representation and use coders for deserialization - // because of checkpointing. - .mapPartitionsToPair( - TranslationUtils.pairFunctionToPairFlatMapFunction( - CoderHelpers.toByteFunction(keyCoder, - KvCoder.of(VarLongCoder.of(), - itrWvCoder))), - true); - } - }) - .dstream(); + private StateAndTimers( + final Table<String, String, byte[]> state, final Collection<byte[]> timers) { + this.state = state; + this.serTimers = timers; + } - PairDStreamFunctions<ByteArray, byte[]> pairDStreamFunctions = - DStream.toPairDStreamFunctions( - pairDStream, - JavaSparkContext$.MODULE$.<ByteArray>fakeClassTag(), - JavaSparkContext$.MODULE$.<byte[]>fakeClassTag(), - null); - int defaultNumPartitions = pairDStreamFunctions.defaultPartitioner$default$1(); - Partitioner partitioner = pairDStreamFunctions.defaultPartitioner(defaultNumPartitions); + public Table<String, String, byte[]> getState() { + return state; + } - // use updateStateByKey to scan through the state and update elements and timers. - DStream<Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>> - firedStream = pairDStreamFunctions.updateStateByKey( - new SerializableFunction1< - scala.collection.Iterator<Tuple3</*K*/ ByteArray, Seq</*Itr<WV<I>>*/ byte[]>, - Option<Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>>>, - scala.collection.Iterator<Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, - /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>>>() { + Collection<byte[]> getTimers() { + return serTimers; + } + } - @Override - public scala.collection.Iterator<Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, - /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>> apply( - final scala.collection.Iterator<Tuple3</*K*/ ByteArray, Seq</*Itr<WV<I>>*/ byte[]>, - Option<Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>>> iter) { - //--- ACTUAL STATEFUL OPERATION: - // - // Input Iterator: the partition (~bundle) of a cogrouping of the input - // and the previous state (if exists). - // - // Output Iterator: the output key, and the updated state. - // - // possible input scenarios for (K, Seq, Option<S>): - // (1) Option<S>.isEmpty: new data with no previous state. 
- // (2) Seq.isEmpty: no new data, but evaluating previous state (timer-like behaviour). - // (3) Seq.nonEmpty && Option<S>.isDefined: new data with previous state. - - final SystemReduceFn<K, InputT, Iterable<InputT>, Iterable<InputT>, W> reduceFn = - SystemReduceFn.buffering( - ((FullWindowedValueCoder<InputT>) wvCoder).getValueCoder()); - // use in memory Aggregators since Spark Accumulators are not resilient - // in stateful operators, once done with this partition. - final MetricsContainerImpl cellProvider = new MetricsContainerImpl("cellProvider"); - final CounterCell droppedDueToClosedWindow = cellProvider.getCounter( - MetricName.named(SparkGroupAlsoByWindowViaWindowSet.class, - GroupAlsoByWindowsAggregators.DROPPED_DUE_TO_CLOSED_WINDOW_COUNTER)); - final CounterCell droppedDueToLateness = cellProvider.getCounter( - MetricName.named(SparkGroupAlsoByWindowViaWindowSet.class, - GroupAlsoByWindowsAggregators.DROPPED_DUE_TO_LATENESS_COUNTER)); - - AbstractIterator< - Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, KV<Long(Time),Itr<I>>>>*/ - List<byte[]>>>> - outIter = new AbstractIterator<Tuple2</*K*/ ByteArray, - Tuple2<StateAndTimers, /*WV<KV<K, KV<Long(Time),Itr<I>>>>*/ List<byte[]>>>>() { - @Override - protected Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, - /*WV<KV<K, Itr<I>>>*/ List<byte[]>>> computeNext() { - // input iterator is a Spark partition (~bundle), containing keys and their - // (possibly) previous-state and (possibly) new data. - while (iter.hasNext()) { - // for each element in the partition: - Tuple3<ByteArray, Seq<byte[]>, - Option<Tuple2<StateAndTimers, List<byte[]>>>> next = iter.next(); - ByteArray encodedKey = next._1(); - K key = CoderHelpers.fromByteArray(encodedKey.getValue(), keyCoder); - - Seq<byte[]> seq = next._2(); - - Option<Tuple2<StateAndTimers, - List<byte[]>>> prevStateAndTimersOpt = next._3(); - - SparkStateInternals<K> stateInternals; - Map<Integer, GlobalWatermarkHolder.SparkWatermarks> watermarks = - GlobalWatermarkHolder.get(batchDurationMillis); - SparkTimerInternals timerInternals = SparkTimerInternals.forStreamFromSources( - sourceIds, watermarks); - - // get state(internals) per key. - if (prevStateAndTimersOpt.isEmpty()) { - // no previous state. - stateInternals = SparkStateInternals.forKey(key); - } else { - // with pre-existing state. - StateAndTimers prevStateAndTimers = prevStateAndTimersOpt.get()._1(); - stateInternals = SparkStateInternals.forKeyAndState(key, - prevStateAndTimers.getState()); - Collection<byte[]> serTimers = prevStateAndTimers.getTimers(); - timerInternals.addTimers( - SparkTimerInternals.deserializeTimers(serTimers, timerDataCoder)); - } - - final OutputWindowedValueHolder<K, InputT> outputHolder = - new OutputWindowedValueHolder<>(); - - ReduceFnRunner<K, InputT, Iterable<InputT>, W> reduceFnRunner = - new ReduceFnRunner<>( - key, - windowingStrategy, - ExecutableTriggerStateMachine.create( - TriggerStateMachines.stateMachineForTrigger( - TriggerTranslation.toProto(windowingStrategy.getTrigger()))), - stateInternals, - timerInternals, - outputHolder, - new UnsupportedSideInputReader("GroupAlsoByWindow"), - reduceFn, - options.get()); - - if (!seq.isEmpty()) { - // new input for key. 
- try { - final KV<Long, Iterable<WindowedValue<InputT>>> keyedElements = - CoderHelpers.fromByteArray(seq.head(), - KvCoder.of(VarLongCoder.of(), itrWvCoder)); - - final Long rddTimestamp = keyedElements.getKey(); - - LOG.debug( - transformFullName - + ": processing RDD with timestamp: {}, watermarks: {}", - rddTimestamp, - watermarks); - - final Iterable<WindowedValue<InputT>> elements = keyedElements.getValue(); - - LOG.trace(transformFullName + ": input elements: {}", elements); - - /* - Incoming expired windows are filtered based on - timerInternals.currentInputWatermarkTime() and the configured allowed - lateness. Note that this is done prior to calling - timerInternals.advanceWatermark so essentially the inputWatermark is - the highWatermark of the previous batch and the lowWatermark of the - current batch. - The highWatermark of the current batch will only affect filtering - as of the next batch. - */ - final Iterable<WindowedValue<InputT>> nonExpiredElements = - Lists.newArrayList(LateDataUtils - .dropExpiredWindows( - key, - elements, - timerInternals, - windowingStrategy, - droppedDueToLateness)); - - LOG.trace(transformFullName + ": non expired input elements: {}", - elements); - - reduceFnRunner.processElements(nonExpiredElements); - } catch (Exception e) { - throw new RuntimeException( - "Failed to process element with ReduceFnRunner", e); - } - } else if (stateInternals.getState().isEmpty()) { - // no input and no state -> GC evict now. - continue; - } - try { - // advance the watermark to HWM to fire by timers. - LOG.debug(transformFullName + ": timerInternals before advance are {}", - timerInternals.toString()); - - // store the highWatermark as the new inputWatermark to calculate triggers - timerInternals.advanceWatermark(); - - LOG.debug(transformFullName + ": timerInternals after advance are {}", - timerInternals.toString()); - - // call on timers that are ready. - final Collection<TimerInternals.TimerData> readyToProcess = - timerInternals.getTimersReadyToProcess(); - - LOG.debug(transformFullName + ": ready timers are {}", readyToProcess); - - /* - Note that at this point, the watermark has already advanced since - timerInternals.advanceWatermark() has been called and the highWatermark - is now stored as the new inputWatermark, according to which triggers are - calculated. - */ - reduceFnRunner.onTimers(readyToProcess); - } catch (Exception e) { - throw new RuntimeException( - "Failed to process ReduceFnRunner onTimer.", e); - } - // this is mostly symbolic since actual persist is done by emitting output. - reduceFnRunner.persist(); - // obtain output, if fired. - List<WindowedValue<KV<K, Iterable<InputT>>>> outputs = outputHolder.get(); - - if (!outputs.isEmpty() || !stateInternals.getState().isEmpty()) { - // empty outputs are filtered later using DStream filtering - StateAndTimers updated = new StateAndTimers(stateInternals.getState(), - SparkTimerInternals.serializeTimers( - timerInternals.getTimers(), timerDataCoder)); - - /* - Not something we want to happen in production, but is very helpful - when debugging - TRACE. - */ - LOG.trace(transformFullName + ": output elements are {}", - Joiner.on(", ").join(outputs)); - - // persist Spark's state by outputting. - List<byte[]> serOutput = CoderHelpers.toByteArrays(outputs, wvKvIterCoder); - return new Tuple2<>(encodedKey, new Tuple2<>(updated, serOutput)); - } - // an empty state with no output, can be evicted completely - do nothing. 
- } - return endOfData(); - } - }; + private static class OutputWindowedValueHolder<K, V> + implements OutputWindowedValue<KV<K, Iterable<V>>> { + private final List<WindowedValue<KV<K, Iterable<V>>>> windowedValues = new ArrayList<>(); + + @Override + public void outputWindowedValue( + final KV<K, Iterable<V>> output, + final Instant timestamp, + final Collection<? extends BoundedWindow> windows, + final PaneInfo pane) { + windowedValues.add(WindowedValue.of(output, timestamp, windows, pane)); + } + + private List<WindowedValue<KV<K, Iterable<V>>>> getWindowedValues() { + return windowedValues; + } + + @Override + public <AdditionalOutputT> void outputWindowedValue( + final TupleTag<AdditionalOutputT> tag, + final AdditionalOutputT output, + final Instant timestamp, + final Collection<? extends BoundedWindow> windows, + final PaneInfo pane) { + throw new UnsupportedOperationException( + "Tagged outputs are not allowed in GroupAlsoByWindow."); + } + } - // log if there's something to log. - long lateDropped = droppedDueToLateness.getCumulative(); - if (lateDropped > 0) { - LOG.info(String.format("Dropped %d elements due to lateness.", lateDropped)); - droppedDueToLateness.inc(-droppedDueToLateness.getCumulative()); + private static class UpdateStateByKeyFunction<K, InputT, W extends BoundedWindow> + extends AbstractFunction1< + Iterator< + Tuple3< + /*K*/ ByteArray, Seq</*Itr<WV<I>>*/ byte[]>, + Option<Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>>>, + Iterator< + Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>>> + implements Serializable { + + private class UpdateStateByKeyOutputIterator + extends AbstractIterator< + Tuple2< + /*K*/ ByteArray, + Tuple2<StateAndTimers, /*WV<KV<K, KV<Long(Time),Itr<I>>>>*/ List<byte[]>>>> { + + private final Iterator< + Tuple3<ByteArray, Seq<byte[]>, Option<Tuple2<StateAndTimers, List<byte[]>>>>> + input; + private final SystemReduceFn<K, InputT, Iterable<InputT>, Iterable<InputT>, W> reduceFn; + private final CounterCell droppedDueToLateness; + + private SparkStateInternals<K> processPreviousState( + final Option<Tuple2<StateAndTimers, List<byte[]>>> prevStateAndTimersOpt, + final K key, + final SparkTimerInternals timerInternals) { + + final SparkStateInternals<K> stateInternals; + + if (prevStateAndTimersOpt.isEmpty()) { + // no previous state. + stateInternals = SparkStateInternals.forKey(key); + } else { + // with pre-existing state. + final StateAndTimers prevStateAndTimers = prevStateAndTimersOpt.get()._1(); + // get state(internals) per key. 
+ stateInternals = SparkStateInternals.forKeyAndState(key, prevStateAndTimers.getState()); + + timerInternals.addTimers( + SparkTimerInternals.deserializeTimers( + prevStateAndTimers.getTimers(), timerDataCoder)); } - long closedWindowDropped = droppedDueToClosedWindow.getCumulative(); - if (closedWindowDropped > 0) { - LOG.info(String.format("Dropped %d elements due to closed window.", closedWindowDropped)); - droppedDueToClosedWindow.inc(-droppedDueToClosedWindow.getCumulative()); + + return stateInternals; + } + + UpdateStateByKeyOutputIterator( + final Iterator< + Tuple3<ByteArray, Seq<byte[]>, Option<Tuple2<StateAndTimers, List<byte[]>>>>> + input, + final SystemReduceFn<K, InputT, Iterable<InputT>, Iterable<InputT>, W> reduceFn, + final CounterCell droppedDueToLateness) { + this.input = input; + this.reduceFn = reduceFn; + this.droppedDueToLateness = droppedDueToLateness; + } + + @Override + protected Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ List<byte[]>>> + computeNext() { + // input iterator is a Spark partition (~bundle), containing keys and their + // (possibly) previous-state and (possibly) new data. + while (input.hasNext()) { + + // for each element in the partition: + final Tuple3<ByteArray, Seq<byte[]>, Option<Tuple2<StateAndTimers, List<byte[]>>>> next = + input.next(); + + final ByteArray encodedKey = next._1(); + final Seq<byte[]> encodedKeyedElements = next._2(); + final Option<Tuple2<StateAndTimers, List<byte[]>>> prevStateAndTimersOpt = next._3(); + + final K key = CoderHelpers.fromByteArray(encodedKey.getValue(), keyCoder); + + final Map<Integer, GlobalWatermarkHolder.SparkWatermarks> watermarks = + GlobalWatermarkHolder.get(getBatchDuration(options)); + + final SparkTimerInternals timerInternals = + SparkTimerInternals.forStreamFromSources(sourceIds, watermarks); + + final SparkStateInternals<K> stateInternals = + processPreviousState(prevStateAndTimersOpt, key, timerInternals); + + final ExecutableTriggerStateMachine triggerStateMachine = + ExecutableTriggerStateMachine.create( + TriggerStateMachines.stateMachineForTrigger( + TriggerTranslation.toProto(windowingStrategy.getTrigger()))); + + final OutputWindowedValueHolder<K, InputT> outputHolder = + new OutputWindowedValueHolder<>(); + + final ReduceFnRunner<K, InputT, Iterable<InputT>, W> reduceFnRunner = + new ReduceFnRunner<>( + key, + windowingStrategy, + triggerStateMachine, + stateInternals, + timerInternals, + outputHolder, + new UnsupportedSideInputReader("GroupAlsoByWindow"), + reduceFn, + options.get()); + + if (!encodedKeyedElements.isEmpty()) { + // new input for key. + try { + final KV<Long, Iterable<WindowedValue<InputT>>> keyedElements = + CoderHelpers.fromByteArray( + encodedKeyedElements.head(), KvCoder.of(VarLongCoder.of(), itrWvCoder)); + + final Long rddTimestamp = keyedElements.getKey(); + + LOG.debug( + logPrefix + ": processing RDD with timestamp: {}, watermarks: {}", + rddTimestamp, + watermarks); + + final Iterable<WindowedValue<InputT>> elements = keyedElements.getValue(); + + LOG.trace(logPrefix + ": input elements: {}", elements); + + /* + Incoming expired windows are filtered based on + timerInternals.currentInputWatermarkTime() and the configured allowed + lateness. Note that this is done prior to calling + timerInternals.advanceWatermark so essentially the inputWatermark is + the highWatermark of the previous batch and the lowWatermark of the + current batch. + The highWatermark of the current batch will only affect filtering + as of the next batch. 
+ */ + final Iterable<WindowedValue<InputT>> nonExpiredElements = + Lists.newArrayList( + LateDataUtils.dropExpiredWindows( + key, elements, timerInternals, windowingStrategy, droppedDueToLateness)); + + LOG.trace(logPrefix + ": non expired input elements: {}", nonExpiredElements); + + reduceFnRunner.processElements(nonExpiredElements); + } catch (final Exception e) { + throw new RuntimeException("Failed to process element with ReduceFnRunner", e); + } + } else if (stateInternals.getState().isEmpty()) { + // no input and no state -> GC evict now. + continue; + } + try { + // advance the watermark to HWM to fire by timers. + LOG.debug( + logPrefix + ": timerInternals before advance are {}", + timerInternals.toString()); + + // store the highWatermark as the new inputWatermark to calculate triggers + timerInternals.advanceWatermark(); + + LOG.debug( + logPrefix + ": timerInternals after advance are {}", + timerInternals.toString()); + + // call on timers that are ready. + final Collection<TimerInternals.TimerData> readyToProcess = + timerInternals.getTimersReadyToProcess(); + + LOG.debug(logPrefix + ": ready timers are {}", readyToProcess); + + /* + Note that at this point, the watermark has already advanced since + timerInternals.advanceWatermark() has been called and the highWatermark + is now stored as the new inputWatermark, according to which triggers are + calculated. + */ + reduceFnRunner.onTimers(readyToProcess); + } catch (final Exception e) { + throw new RuntimeException("Failed to process ReduceFnRunner onTimer.", e); + } + // this is mostly symbolic since actual persist is done by emitting output. + reduceFnRunner.persist(); + // obtain output, if fired. + final List<WindowedValue<KV<K, Iterable<InputT>>>> outputs = + outputHolder.getWindowedValues(); + + if (!outputs.isEmpty() || !stateInternals.getState().isEmpty()) { + // empty outputs are filtered later using DStream filtering + final StateAndTimers updated = + new StateAndTimers( + stateInternals.getState(), + SparkTimerInternals.serializeTimers( + timerInternals.getTimers(), timerDataCoder)); + + /* + Not something we want to happen in production, but is very helpful + when debugging - TRACE. + */ + LOG.trace( + logPrefix + ": output elements are {}", Joiner.on(", ").join(outputs)); + + // persist Spark's state by outputting. + final List<byte[]> serOutput = CoderHelpers.toByteArrays(outputs, wvKvIterCoder); + return new Tuple2<>(encodedKey, new Tuple2<>(updated, serOutput)); + } + // an empty state with no output, can be evicted completely - do nothing. 
} + return endOfData(); + } + } + + private final FullWindowedValueCoder<InputT> wvCoder; + private final Coder<K> keyCoder; + private final List<Integer> sourceIds; + private final TimerInternals.TimerDataCoder timerDataCoder; + private final WindowingStrategy<?, W> windowingStrategy; + private final SerializablePipelineOptions options; + private final IterableCoder<WindowedValue<InputT>> itrWvCoder; + private final String logPrefix; + private final Coder<WindowedValue<KV<K, Iterable<InputT>>>> wvKvIterCoder; + + UpdateStateByKeyFunction( + final List<Integer> sourceIds, + final WindowingStrategy<?, W> windowingStrategy, + final FullWindowedValueCoder<InputT> wvCoder, + final Coder<K> keyCoder, + final SerializablePipelineOptions options, + final String logPrefix) { + this.wvCoder = wvCoder; + this.keyCoder = keyCoder; + this.sourceIds = sourceIds; + this.timerDataCoder = timerDataCoderOf(windowingStrategy); + this.windowingStrategy = windowingStrategy; + this.options = options; + this.itrWvCoder = IterableCoder.of(wvCoder); + this.logPrefix = logPrefix; + this.wvKvIterCoder = + windowedValueKeyValueCoderOf( + keyCoder, + wvCoder.getValueCoder(), + ((FullWindowedValueCoder<InputT>) wvCoder).getWindowCoder()); + } - return scala.collection.JavaConversions.asScalaIterator(outIter); + @Override + public Iterator< + Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>> + apply( + final Iterator< + Tuple3< + /*K*/ ByteArray, Seq</*Itr<WV<I>>*/ byte[]>, + Option<Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>>> + input) { + //--- ACTUAL STATEFUL OPERATION: + // + // Input Iterator: the partition (~bundle) of a co-grouping of the input + // and the previous state (if exists). + // + // Output Iterator: the output key, and the updated state. + // + // possible input scenarios for (K, Seq, Option<S>): + // (1) Option<S>.isEmpty: new data with no previous state. + // (2) Seq.isEmpty: no new data, but evaluating previous state (timer-like behaviour). + // (3) Seq.nonEmpty && Option<S>.isDefined: new data with previous state. + + final SystemReduceFn<K, InputT, Iterable<InputT>, Iterable<InputT>, W> reduceFn = + SystemReduceFn.buffering(wvCoder.getValueCoder()); + + final MetricsContainerImpl cellProvider = new MetricsContainerImpl("cellProvider"); + + final CounterCell droppedDueToClosedWindow = + cellProvider.getCounter( + MetricName.named( + SparkGroupAlsoByWindowViaWindowSet.class, + GroupAlsoByWindowsAggregators.DROPPED_DUE_TO_CLOSED_WINDOW_COUNTER)); + + final CounterCell droppedDueToLateness = + cellProvider.getCounter( + MetricName.named( + SparkGroupAlsoByWindowViaWindowSet.class, + GroupAlsoByWindowsAggregators.DROPPED_DUE_TO_LATENESS_COUNTER)); + + // log if there's something to log. 
+ final long lateDropped = droppedDueToLateness.getCumulative(); + if (lateDropped > 0) { + LOG.info(String.format("Dropped %d elements due to lateness.", lateDropped)); + droppedDueToLateness.inc(-droppedDueToLateness.getCumulative()); } - }, partitioner, true, - JavaSparkContext$.MODULE$.<Tuple2<StateAndTimers, List<byte[]>>>fakeClassTag()); + final long closedWindowDropped = droppedDueToClosedWindow.getCumulative(); + if (closedWindowDropped > 0) { + LOG.info(String.format("Dropped %d elements due to closed window.", closedWindowDropped)); + droppedDueToClosedWindow.inc(-droppedDueToClosedWindow.getCumulative()); + } + + return scala.collection.JavaConversions.asScalaIterator( + new UpdateStateByKeyOutputIterator(input, reduceFn, droppedDueToLateness)); + } + } + + private static <K, InputT> + FullWindowedValueCoder<KV<K, Iterable<InputT>>> windowedValueKeyValueCoderOf( + final Coder<K> keyCoder, + final Coder<InputT> iCoder, + final Coder<? extends BoundedWindow> wCoder) { + return FullWindowedValueCoder.of(KvCoder.of(keyCoder, IterableCoder.of(iCoder)), wCoder); + } + + private static <W extends BoundedWindow> TimerInternals.TimerDataCoder timerDataCoderOf( + final WindowingStrategy<?, W> windowingStrategy) { + return TimerInternals.TimerDataCoder.of(windowingStrategy.getWindowFn().windowCoder()); + } + + private static void + checkpointIfNeeded( + final DStream<Tuple2<ByteArray, Tuple2<StateAndTimers, List<byte[]>>>> firedStream, + final SerializablePipelineOptions options) { + + final Long checkpointDurationMillis = getBatchDuration(options); if (checkpointDurationMillis > 0) { firedStream.checkpoint(new Duration(checkpointDurationMillis)); } + } + + private static Long getBatchDuration(final SerializablePipelineOptions options) { + return options.get().as(SparkPipelineOptions.class).getCheckpointDurationMillis(); + } + + private static <K, InputT> JavaDStream<WindowedValue<KV<K, Iterable<InputT>>>> stripStateValues( + final DStream<Tuple2<ByteArray, Tuple2<StateAndTimers, List<byte[]>>>> firedStream, + final Coder<K> keyCoder, + final FullWindowedValueCoder<InputT> wvCoder) { - // go back to Java now. - JavaPairDStream</*K*/ ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ List<byte[]>>> - javaFiredStream = JavaPairDStream.fromPairDStream( + return JavaPairDStream.fromPairDStream( firedStream, JavaSparkContext$.MODULE$.<ByteArray>fakeClassTag(), - JavaSparkContext$.MODULE$.<Tuple2<StateAndTimers, List<byte[]>>>fakeClassTag()); - - // filter state-only output (nothing to fire) and remove the state from the output. - return javaFiredStream.filter( - new Function<Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, - /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>, Boolean>() { + JavaSparkContext$.MODULE$.<Tuple2<StateAndTimers, List<byte[]>>>fakeClassTag()) + .filter( + new Function< + Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>, + Boolean>() { @Override public Boolean call( - Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, - /*WV<KV<K, Itr<I>>>*/ List<byte[]>>> t2) throws Exception { + final Tuple2< + /*K*/ ByteArray, + Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ List<byte[]>>> + t2) + throws Exception { // filter output if defined. 
return !t2._2()._2().isEmpty(); } - }) + }) .flatMap( - new FlatMapFunction<Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, - /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>, + new FlatMapFunction< + Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>, WindowedValue<KV<K, Iterable<InputT>>>>() { + + private final FullWindowedValueCoder<KV<K, Iterable<InputT>>> + windowedValueKeyValueCoder = + windowedValueKeyValueCoderOf( + keyCoder, wvCoder.getValueCoder(), wvCoder.getWindowCoder()); + @Override public Iterable<WindowedValue<KV<K, Iterable<InputT>>>> call( - Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, - /*WV<KV<K, Itr<I>>>*/ List<byte[]>>> t2) throws Exception { + final Tuple2< + /*K*/ ByteArray, + Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ List<byte[]>>> + t2) + throws Exception { // drop the state since it is already persisted at this point. // return in serialized form. - return CoderHelpers.fromByteArrays(t2._2()._2(), wvKvIterCoder); + return CoderHelpers.fromByteArrays(t2._2()._2(), windowedValueKeyValueCoder); } - }); + }); } - private static class StateAndTimers implements Serializable { - //Serializable state for internals (namespace to state tag to coded value). - private final Table<String, String, byte[]> state; - private final Collection<byte[]> serTimers; + private static <K, InputT> PairDStreamFunctions<ByteArray, byte[]> buildPairDStream( + final JavaDStream<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>> inputDStream, + final Coder<K> keyCoder, + final Coder<WindowedValue<InputT>> wvCoder) { - private StateAndTimers( - Table<String, String, byte[]> state, Collection<byte[]> timers) { - this.state = state; - this.serTimers = timers; - } + // we have to switch to Scala API to avoid Optional in the Java API, see: SPARK-4819. + // we also have a broader API for Scala (access to the actual key and entire iterator). + // we use coders to convert objects in the PCollection to byte arrays, so they + // can be transferred over the network for the shuffle and be in serialized form + // for checkpointing. + // for readability, we add comments with actual type next to byte[]. + // to shorten line length, we use: + //---- WV: WindowedValue + //---- Iterable: Itr + //---- AccumT: A + //---- InputT: I + final DStream<Tuple2<ByteArray, byte[]>> tupleDStream = + inputDStream + .transformToPair( + new Function2< + JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>>, Time, + JavaPairRDD<ByteArray, byte[]>>() { - public Table<String, String, byte[]> getState() { - return state; - } + // we use mapPartitions with the RDD API because its the only available API + // that allows to preserve partitioning. 
+ @Override + public JavaPairRDD<ByteArray, byte[]> call( + final JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>> rdd, + final Time time) + throws Exception { + return rdd.mapPartitions( + TranslationUtils.functionToFlatMapFunction( + WindowingHelpers + .<KV<K, Iterable<WindowedValue<InputT>>>>unwindowFunction()), + true) + .mapPartitionsToPair( + TranslationUtils + .<K, Iterable<WindowedValue<InputT>>>toPairFlatMapFunction(), + true) + .mapValues( + new Function< + Iterable<WindowedValue<InputT>>, + KV<Long, Iterable<WindowedValue<InputT>>>>() { + + @Override + public KV<Long, Iterable<WindowedValue<InputT>>> call( + final Iterable<WindowedValue<InputT>> values) throws Exception { + // add the batch timestamp for visibility (e.g., debugging) + return KV.of(time.milliseconds(), values); + } + }) + // move to bytes representation and use coders for deserialization + // because of checkpointing. + .mapPartitionsToPair( + TranslationUtils.pairFunctionToPairFlatMapFunction( + CoderHelpers.toByteFunction( + keyCoder, + KvCoder.of(VarLongCoder.of(), IterableCoder.of(wvCoder)))), + true); + } + }) + .dstream(); - public Collection<byte[]> getTimers() { - return serTimers; - } + return DStream.toPairDStreamFunctions( + tupleDStream, + JavaSparkContext$.MODULE$.<ByteArray>fakeClassTag(), + JavaSparkContext$.MODULE$.<byte[]>fakeClassTag(), + null); } - private static class OutputWindowedValueHolder<K, V> - implements OutputWindowedValue<KV<K, Iterable<V>>> { - private List<WindowedValue<KV<K, Iterable<V>>>> windowedValues = new ArrayList<>(); + public static <K, InputT, W extends BoundedWindow> + JavaDStream<WindowedValue<KV<K, Iterable<InputT>>>> groupAlsoByWindow( + final JavaDStream<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>> inputDStream, + final Coder<K> keyCoder, + final Coder<WindowedValue<InputT>> wvCoder, + final WindowingStrategy<?, W> windowingStrategy, + final SerializablePipelineOptions options, + final List<Integer> sourceIds, + final String transformFullName) { - @Override - public void outputWindowedValue( - KV<K, Iterable<V>> output, - Instant timestamp, - Collection<? extends BoundedWindow> windows, - PaneInfo pane) { - windowedValues.add(WindowedValue.of(output, timestamp, windows, pane)); - } + final PairDStreamFunctions<ByteArray, byte[]> pairDStream = + buildPairDStream(inputDStream, keyCoder, wvCoder); - private List<WindowedValue<KV<K, Iterable<V>>>> get() { - return windowedValues; - } + // use updateStateByKey to scan through the state and update elements and timers. + final UpdateStateByKeyFunction<K, InputT, W> updateFunc = + new UpdateStateByKeyFunction<>( + sourceIds, + windowingStrategy, + (FullWindowedValueCoder<InputT>) wvCoder, keyCoder, options, transformFullName + ); + + final DStream< + Tuple2</*K*/ ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/ List<byte[]>>>> + firedStream = + pairDStream.updateStateByKey( + updateFunc, + pairDStream.defaultPartitioner(pairDStream.defaultPartitioner$default$1()), + true, + JavaSparkContext$.MODULE$.<Tuple2<StateAndTimers, List<byte[]>>>fakeClassTag()); + + checkpointIfNeeded(firedStream, options); - @Override - public <AdditionalOutputT> void outputWindowedValue( - TupleTag<AdditionalOutputT> tag, - AdditionalOutputT output, - Instant timestamp, - Collection<? 
extends BoundedWindow> windows, - PaneInfo pane) { - throw new UnsupportedOperationException( - "Tagged outputs are not allowed in GroupAlsoByWindow."); - } + // filter state-only output (nothing to fire) and remove the state from the output. + return stripStateValues(firedStream, keyCoder, (FullWindowedValueCoder<InputT>) wvCoder); } }
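
----------------------------------------------------------------------

A note on the pattern the class javadoc describes: updateStateByKey invokes the
update function over all keys that have state, not only keys that received new
data in the current batch, and it forces the state and output types to
coincide. The runner therefore keeps a (state, output) tuple per key and
filters empty outputs downstream. A minimal, Spark-free sketch of that shape in
plain Java (all names here are illustrative, not Beam or Spark API):

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.Collections;
    import java.util.HashMap;
    import java.util.HashSet;
    import java.util.List;
    import java.util.Map;
    import java.util.Set;

    public class StateAndOutputSketch {

      // Per-key state: buffered values plus a single timer bound.
      static final class State {
        final List<String> buffered = new ArrayList<>();
        long timerMillis = Long.MAX_VALUE;
      }

      public static void main(String[] args) {
        Map<String, State> stateByKey = new HashMap<>();
        Map<String, List<String>> newData = new HashMap<>();
        newData.put("a", Arrays.asList("a1", "a2"));

        long watermark = 100; // advances once per batch

        // Scan ALL keys: previous state and/or new data, covering the three
        // (K, Seq, Option<S>) scenarios listed in the diff.
        Set<String> keys = new HashSet<>(stateByKey.keySet());
        keys.addAll(newData.keySet());

        Map<String, List<String>> fired = new HashMap<>();
        for (String key : keys) {
          State state = stateByKey.computeIfAbsent(key, k -> new State());
          List<String> values = newData.getOrDefault(key, Collections.emptyList());
          state.buffered.addAll(values);
          if (!values.isEmpty()) {
            // Stand-in for setting an event-time timer on new input.
            state.timerMillis = Math.min(state.timerMillis, watermark);
          }
          if (watermark >= state.timerMillis) {
            // Timer is ready: emit the buffered values and evict the state.
            fired.put(key, new ArrayList<>(state.buffered));
            stateByKey.remove(key);
          }
        }
        // Mirrors the DStream filter that drops state-only entries (no firing).
        fired.values().removeIf(List::isEmpty);
        System.out.println(fired); // {a=[a1, a2]}
      }
    }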
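
The byte[] columns in the type comments exist because every value that crosses
a shuffle or checkpoint boundary is first encoded with a Beam coder. The diff
uses the Spark runner's internal CoderHelpers for this; a round-trip with the
public SDK's CoderUtils shows the same convention, including the
KvCoder/IterableCoder composition used for the batch-timestamped values (class
name and sample data below are made up):

    import java.util.Arrays;
    import org.apache.beam.sdk.coders.IterableCoder;
    import org.apache.beam.sdk.coders.KvCoder;
    import org.apache.beam.sdk.coders.VarLongCoder;
    import org.apache.beam.sdk.util.CoderUtils;
    import org.apache.beam.sdk.values.KV;

    public class CoderRoundTrip {
      public static void main(String[] args) throws Exception {
        // KV<timestamp, Iterable<value>>, analogous to KV<Long, Itr<WV<I>>> above.
        KvCoder<Long, Iterable<Long>> coder =
            KvCoder.of(VarLongCoder.of(), IterableCoder.of(VarLongCoder.of()));

        KV<Long, Iterable<Long>> original = KV.of(1000L, Arrays.asList(1L, 2L, 3L));

        // Serialized form: safe to shuffle over the network and to checkpoint.
        byte[] bytes = CoderUtils.encodeToByteArray(coder, original);

        // Decoded again on the other side of the boundary.
        KV<Long, Iterable<Long>> decoded = CoderUtils.decodeFromByteArray(coder, bytes);
        System.out.println(decoded.getKey() + " -> " + decoded.getValue());
      }
    }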
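
UpdateStateByKeyFunction extends scala.runtime.AbstractFunction1 and implements
Serializable because updateStateByKey is reached through the Scala API (see the
SPARK-4819 comment in buildPairDStream), which expects a scala.Function1, and
because Spark ships the function object to the executors. The shape of that
bridge, reduced to a toy (the increment is a placeholder for the real
per-partition update; assumes scala-library on the classpath, which any Spark
deployment has):

    import java.io.Serializable;
    import scala.runtime.AbstractFunction1;

    public class SerializableFn extends AbstractFunction1<Integer, Integer>
        implements Serializable {

      @Override
      public Integer apply(Integer x) {
        // Placeholder for the real state-update logic.
        return x + 1;
      }

      public static void main(String[] args) {
        System.out.println(new SerializableFn().apply(41)); // prints 42
      }
    }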