This is an automated email from the ASF dual-hosted git repository. echauchot pushed a commit to branch spark-runner_structured-streaming in repository https://gitbox.apache.org/repos/asf/beam.git
commit bba08b4201a1dc53ae117d9bd495b117923e189d Author: Etienne Chauchot <echauc...@apache.org> AuthorDate: Thu Jun 13 11:23:52 2019 +0200 Implement reduce part of CombineGlobally translation with windowing --- .../batch/AggregatorCombinerGlobally.java | 165 +++++++++++++++++---- .../batch/CombineGloballyTranslatorBatch.java | 19 +-- 2 files changed, 144 insertions(+), 40 deletions(-) diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorCombinerGlobally.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorCombinerGlobally.java index 2f8293b..0d13218 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorCombinerGlobally.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorCombinerGlobally.java @@ -18,60 +18,173 @@ package org.apache.beam.runners.spark.structuredstreaming.translation.batch; import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers; -import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.RowHelpers; import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; +import org.apache.beam.sdk.transforms.windowing.WindowFn; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.WindowingStrategy; import org.apache.spark.sql.Encoder; -import org.apache.spark.sql.Row; import org.apache.spark.sql.expressions.Aggregator; +import 
org.joda.time.Instant; +import scala.Tuple2; -/** An {@link Aggregator} for the Spark Batch Runner. */ -class AggregatorCombinerGlobally<InputT, AccumT, OutputT> - extends Aggregator<InputT, AccumT, OutputT> { +/** An {@link Aggregator} for the Spark Batch Runner. It does not use ReduceFnRunner + * for windowMerging, because reduceFnRunner is based on state which requires a keyed collection. + * The accumulator is an {@code Iterable<WindowedValue<AccumT>>} because an {@code InputT} can be in multiple windows. So, when accumulating {@code InputT} values, we create one accumulator per input window. + * */ + +class AggregatorCombinerGlobally<InputT, AccumT, OutputT, W extends BoundedWindow> + extends Aggregator<WindowedValue<InputT>, Iterable<WindowedValue<AccumT>>, WindowedValue<OutputT>> { private final Combine.CombineFn<InputT, AccumT, OutputT> combineFn; + private WindowingStrategy<InputT, W> windowingStrategy; + private TimestampCombiner timestampCombiner; - public AggregatorCombinerGlobally(Combine.CombineFn<InputT, AccumT, OutputT> combineFn) { this.combineFn = combineFn; + public AggregatorCombinerGlobally(Combine.CombineFn<InputT, AccumT, OutputT> combineFn, WindowingStrategy<?, ?> windowingStrategy) { this.combineFn = combineFn; + this.windowingStrategy = (WindowingStrategy<InputT, W>) windowingStrategy; + this.timestampCombiner = windowingStrategy.getTimestampCombiner(); } - @Override - public AccumT zero() { - return combineFn.createAccumulator(); + @Override public Iterable<WindowedValue<AccumT>> zero() { + return new ArrayList<>(); } - @Override - public AccumT reduce(AccumT accumulator, InputT input) { - // because of generic type InputT, spark cannot infer an input type. 
- // it would pass Integer as input if we had a Aggregator<Integer, ..., ...> - // without the type inference it stores input in a GenericRowWithSchema - Row row = (Row) input; - InputT t = RowHelpers.extractObjectFromRow(row); - return combineFn.addInput(accumulator, t); + @Override public Iterable<WindowedValue<AccumT>> reduce(Iterable<WindowedValue<AccumT>> accumulators, + WindowedValue<InputT> input) { + + //concatenate accumulators windows and input windows and merge the windows + Collection<W> inputWindows = (Collection<W>)input.getWindows(); + Set<W> windows = collectAccumulatorsWindows(accumulators); + windows.addAll(inputWindows); + Map<W, W> windowToMergeResult = null; + try { + windowToMergeResult = mergeWindows(windowingStrategy, windows); + } catch (Exception e) { + throw new RuntimeException("Unable to merge accumulators windows and input windows", e); + } + + // iterate through the input windows and for each, create an accumulator with the merged window + // associated to it and call addInput with the accumulator. + // Maintain a map of the accumulators for use as output + Map<W, Tuple2<AccumT, Instant>> mapState = new HashMap<>(); + for (W inputWindow:inputWindows) { + W mergedWindow = windowToMergeResult.get(inputWindow); + mergedWindow = mergedWindow == null ? 
inputWindow : mergedWindow; + Tuple2<AccumT, Instant> accumAndInstant = mapState.get(mergedWindow); + // if there is no accumulator associated with this window yet, create one + if (accumAndInstant == null) { + AccumT accum = combineFn.addInput(combineFn.createAccumulator(), input.getValue()); + Instant windowTimestamp = + timestampCombiner.assign( + mergedWindow, windowingStrategy.getWindowFn().getOutputTime(input.getTimestamp(), mergedWindow)); + accumAndInstant = new Tuple2<>(accum, windowTimestamp); + mapState.put(mergedWindow, accumAndInstant); + } else { + AccumT updatedAccum = + combineFn.addInput(accumAndInstant._1, input.getValue()); + Instant updatedTimestamp = timestampCombiner.combine(accumAndInstant._2, timestampCombiner + .assign(mergedWindow, + windowingStrategy.getWindowFn().getOutputTime(input.getTimestamp(), mergedWindow))); + accumAndInstant = new Tuple2<>(updatedAccum, updatedTimestamp); + } + } + // output the accumulators map + List<WindowedValue<AccumT>> result = new ArrayList<>(); + for (Map.Entry<W, Tuple2<AccumT, Instant>> entry : mapState.entrySet()) { + AccumT accumulator = entry.getValue()._1; + Instant windowTimestamp = entry.getValue()._2; + W window = entry.getKey(); + result.add(WindowedValue.of(accumulator, windowTimestamp, window, PaneInfo.NO_FIRING)); + } + return result; } - @Override - public AccumT merge(AccumT accumulator1, AccumT accumulator2) { + @Override public Iterable<WindowedValue<AccumT>> merge( + Iterable<WindowedValue<AccumT>> accumulators1, + Iterable<WindowedValue<AccumT>> accumulators2) { + // TODO + /* ArrayList<AccumT> accumulators = new ArrayList<>(); accumulators.add(accumulator1); accumulators.add(accumulator2); return combineFn.mergeAccumulators(accumulators); +*/ + return null; } - @Override - public OutputT finish(AccumT reduction) { - return combineFn.extractOutput(reduction); + @Override public WindowedValue<OutputT> finish(Iterable<WindowedValue<AccumT>> reduction) { + // TODO + // return 
combineFn.extractOutput(reduction); + return null; } - @Override - public Encoder<AccumT> bufferEncoder() { + @Override public Encoder<Iterable<WindowedValue<AccumT>>> bufferEncoder() { // TODO replace with accumulatorCoder if possible return EncoderHelpers.genericEncoder(); } - @Override - public Encoder<OutputT> outputEncoder() { + @Override public Encoder<WindowedValue<OutputT>> outputEncoder() { // TODO replace with outputCoder if possible return EncoderHelpers.genericEncoder(); } + + private Set<W> collectAccumulatorsWindows(Iterable<WindowedValue<AccumT>> accumulators) { + Set<W> windows = new HashSet<>(); + for (WindowedValue<?> accumulator : accumulators) { + // an accumulator has only one window associated to it. + W accumulatorWindow = (W) accumulator.getWindows().iterator().next(); + windows.add(accumulatorWindow); + } return windows; + } + + private Map<W, W> mergeWindows(WindowingStrategy<InputT, W> windowingStrategy, Set<W> windows) + throws Exception { + WindowFn<InputT, W> windowFn = windowingStrategy.getWindowFn(); + + if (windowingStrategy.getWindowFn().isNonMerging()) { + // Return an empty map, indicating that every window is not merged. 
+ return Collections.emptyMap(); + } + + Map<W, W> windowToMergeResult = new HashMap<>(); + windowFn.mergeWindows(new MergeContextImpl(windowFn, windows, windowToMergeResult)); + return windowToMergeResult; + } + + + private class MergeContextImpl extends WindowFn<InputT, W>.MergeContext { + + private Set<W> windows; + private Map<W, W> windowToMergeResult; + + MergeContextImpl(WindowFn<InputT, W> windowFn, Set<W> windows, Map<W, W> windowToMergeResult) { + windowFn.super(); + this.windows = windows; + this.windowToMergeResult = windowToMergeResult; + } + + @Override + public Collection<W> windows() { + return windows; + } + + @Override + public void merge(Collection<W> toBeMerged, W mergeResult) throws Exception { + for (W w : toBeMerged) { + windowToMergeResult.put(w, mergeResult); + } + } + } + } diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTranslatorBatch.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTranslatorBatch.java index 53651cf..f18572b 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTranslatorBatch.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTranslatorBatch.java @@ -26,6 +26,7 @@ import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.WindowingStrategy; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; @@ -50,25 +51,15 @@ class CombineGloballyTranslatorBatch<InputT, AccumT, OutputT> @SuppressWarnings("unchecked") final Combine.CombineFn<InputT, AccumT, OutputT> combineFn = (Combine.CombineFn<InputT, AccumT, OutputT>) combineTransform.getFn(); - + WindowingStrategy<?, 
?> windowingStrategy = input.getWindowingStrategy(); Dataset<WindowedValue<InputT>> inputDataset = context.getDataset(input); - //TODO merge windows instead of doing unwindow/window to comply with beam model - Dataset<InputT> unWindowedDataset = - inputDataset.map(WindowingHelpers.unwindowMapFunction(), EncoderHelpers.genericEncoder()); - Dataset<Row> combinedRowDataset = - unWindowedDataset.agg(new AggregatorCombinerGlobally<>(combineFn).toColumn()); - - Dataset<OutputT> combinedDataset = - combinedRowDataset.map( - RowHelpers.extractObjectFromRowMapFunction(), EncoderHelpers.genericEncoder()); + inputDataset.agg(new AggregatorCombinerGlobally<>(combineFn, windowingStrategy).toColumn()); - // Window the result into global window. Dataset<WindowedValue<OutputT>> outputDataset = - combinedDataset.map( - WindowingHelpers.windowMapFunction(), EncoderHelpers.windowedValueEncoder()); - + combinedRowDataset.map( + RowHelpers.extractObjectFromRowMapFunction(), EncoderHelpers.windowedValueEncoder()); context.putDataset(output, outputDataset); } }