Refactor translators according to the new GroupAlsoByWindow implementation for the Spark runner.
Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/96abe4f0 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/96abe4f0 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/96abe4f0 Branch: refs/heads/master Commit: 96abe4f08be12ac10dac39b55b8f8319a227b1ea Parents: 8c37970 Author: Sela <ans...@paypal.com> Authored: Fri Feb 17 01:19:23 2017 +0200 Committer: Sela <ans...@paypal.com> Committed: Wed Mar 1 00:17:59 2017 +0200 ---------------------------------------------------------------------- .../SparkGroupAlsoByWindowViaWindowSet.java | 14 +- .../spark/stateful/SparkStateInternals.java | 4 +- .../spark/stateful/SparkTimerInternals.java | 6 +- .../translation/GroupCombineFunctions.java | 237 ++++++------------ .../spark/translation/TransformTranslator.java | 238 +++++++++++++------ .../spark/translation/TranslationUtils.java | 22 +- .../streaming/StreamingTransformTranslator.java | 163 ++++--------- .../translation/streaming/UnboundedDataset.java | 12 +- .../beam/runners/spark/util/LateDataUtils.java | 2 +- .../spark/util/UnsupportedSideInputReader.java | 2 +- .../streaming/TrackStreamingSourcesTest.java | 2 +- 11 files changed, 314 insertions(+), 388 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/96abe4f0/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java index 7902d7c..2fb4100 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkGroupAlsoByWindowViaWindowSet.java @@ -17,8 +17,8 @@ */ package org.apache.beam.runners.spark.stateful; -import com.google.common.collect.Table; import com.google.common.collect.AbstractIterator; +import com.google.common.collect.Table; import java.io.Serializable; import java.util.ArrayList; import java.util.Collection; @@ -84,7 +84,8 @@ import scala.runtime.AbstractFunction1; * in the following steps. */ public class SparkGroupAlsoByWindowViaWindowSet { - private static final Logger LOG = LoggerFactory.getLogger(SparkGroupAlsoByWindowViaWindowSet.class); + private static final Logger LOG = LoggerFactory.getLogger( + SparkGroupAlsoByWindowViaWindowSet.class); /** * A helper class that is essentially a {@link Serializable} {@link AbstractFunction1}. 
@@ -101,7 +102,7 @@ public class SparkGroupAlsoByWindowViaWindowSet { final SparkRuntimeContext runtimeContext, final List<Integer> sourceIds) { - Long checkpointDuration = + long checkpointDurationMillis = runtimeContext.getPipelineOptions().as(SparkPipelineOptions.class) .getCheckpointDurationMillis(); @@ -271,8 +272,11 @@ public class SparkGroupAlsoByWindowViaWindowSet { return scala.collection.JavaConversions.asScalaIterator(outIter); } }, partitioner, true, JavaSparkContext$.MODULE$.<Tuple2<StateAndTimers, - List<WindowedValue<KV<K, Iterable<InputT>>>>>>fakeClassTag()) - .checkpoint(new Duration(checkpointDuration)); + List<WindowedValue<KV<K, Iterable<InputT>>>>>>fakeClassTag()); + + if (checkpointDurationMillis > 0) { + firedStream.checkpoint(new Duration(checkpointDurationMillis)); + } // go back to Java now. JavaPairDStream<K, Tuple2<StateAndTimers, List<WindowedValue<KV<K, Iterable<InputT>>>>>> http://git-wip-us.apache.org/repos/asf/beam/blob/96abe4f0/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java index e628d31..93b1f63 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java @@ -22,11 +22,11 @@ import com.google.common.collect.Table; import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import org.apache.beam.runners.spark.coders.CoderHelpers; import org.apache.beam.runners.core.StateInternals; import org.apache.beam.runners.core.StateNamespace; import org.apache.beam.runners.core.StateTag; import org.apache.beam.runners.core.StateTag.StateBinder; +import org.apache.beam.runners.spark.coders.CoderHelpers; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.InstantCoder; import org.apache.beam.sdk.coders.ListCoder; @@ -399,4 +399,4 @@ class SparkStateInternals<K> implements StateInternals<K> { }; } } -} \ No newline at end of file +} http://git-wip-us.apache.org/repos/asf/beam/blob/96abe4f0/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java index 65225c5..4072240 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkTimerInternals.java @@ -27,9 +27,9 @@ import java.util.List; import java.util.Map; import java.util.Set; import javax.annotation.Nullable; -import org.apache.beam.runners.spark.util.GlobalWatermarkHolder.SparkWatermarks; import org.apache.beam.runners.core.StateNamespace; import org.apache.beam.runners.core.TimerInternals; +import org.apache.beam.runners.spark.util.GlobalWatermarkHolder.SparkWatermarks; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.TimeDomain; import org.apache.spark.broadcast.Broadcast; @@ -145,7 +145,7 @@ class SparkTimerInternals implements TimerInternals { return 
inputWatermark; } - /** Advances the watermark - since */ + /** Advances the watermark. */ public void advanceWatermark() { inputWatermark = highWatermark; } @@ -170,4 +170,4 @@ class SparkTimerInternals implements TimerInternals { throw new UnsupportedOperationException("Deleting a timer by ID is not yet supported."); } -} \ No newline at end of file +} http://git-wip-us.apache.org/repos/asf/beam/blob/96abe4f0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java index 8a41b4e..1e879ce 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java @@ -18,32 +18,21 @@ package org.apache.beam.runners.spark.translation; -import com.google.common.collect.Lists; -import java.util.Collections; -import java.util.Map; -import org.apache.beam.runners.core.SystemReduceFn; -import org.apache.beam.runners.spark.aggregators.NamedAggregators; +import static com.google.common.base.Preconditions.checkArgument; + import org.apache.beam.runners.spark.coders.CoderHelpers; import org.apache.beam.runners.spark.util.ByteArray; -import org.apache.beam.runners.spark.util.SideInputBroadcast; -import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.coders.KvCoder; -import org.apache.beam.sdk.transforms.CombineWithContext; -import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder; import org.apache.beam.sdk.util.WindowingStrategy; import org.apache.beam.sdk.values.KV; -import org.apache.beam.sdk.values.TupleTag; -import org.apache.spark.Accumulator; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.PairFlatMapFunction; -import scala.Tuple2; @@ -53,113 +42,71 @@ import scala.Tuple2; public class GroupCombineFunctions { /** - * Apply {@link org.apache.beam.sdk.transforms.GroupByKey} to a Spark RDD. + * An implementation of + * {@link org.apache.beam.runners.core.GroupByKeyViaGroupByKeyOnly.GroupByKeyOnly} + * for the Spark runner. */ - public static <K, V, W extends BoundedWindow> JavaRDD<WindowedValue<KV<K, - Iterable<V>>>> groupByKey(JavaRDD<WindowedValue<KV<K, V>>> rdd, - Accumulator<NamedAggregators> accum, - KvCoder<K, V> coder, - SparkRuntimeContext runtimeContext, - WindowingStrategy<?, W> windowingStrategy) { - //--- coders. 
- final Coder<K> keyCoder = coder.getKeyCoder(); - final Coder<V> valueCoder = coder.getValueCoder(); - final WindowedValue.WindowedValueCoder<V> wvCoder = WindowedValue.FullWindowedValueCoder.of( - valueCoder, windowingStrategy.getWindowFn().windowCoder()); + public static <K, V> JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>> groupByKeyOnly( + JavaRDD<WindowedValue<KV<K, V>>> rdd, + Coder<K> keyCoder, + WindowedValueCoder<V> wvCoder) { - //--- groupByKey. // Use coders to convert objects in the PCollection to byte arrays, so they // can be transferred over the network for the shuffle. - JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>> groupedByKey = - rdd.map(new ReifyTimestampsAndWindowsFunction<K, V>()) - .map(WindowingHelpers.<KV<K, WindowedValue<V>>>unwindowFunction()) - .mapToPair(TranslationUtils.<K, WindowedValue<V>>toPairFunction()) - .mapToPair(CoderHelpers.toByteFunction(keyCoder, wvCoder)) - .groupByKey() - .mapToPair(CoderHelpers.fromByteFunctionIterable(keyCoder, wvCoder)) - // empty windows are OK here, see GroupByKey#evaluateHelper in the SDK - .map(TranslationUtils.<K, Iterable<WindowedValue<V>>>fromPairFunction()) - .map(WindowingHelpers.<KV<K, Iterable<WindowedValue<V>>>>windowFunction()); - - //--- now group also by window. - // GroupAlsoByWindow currently uses a dummy in-memory StateInternals - return groupedByKey.flatMap( - new SparkGroupAlsoByWindowViaOutputBufferFn<>( - windowingStrategy, - new TranslationUtils.InMemoryStateInternalsFactory<K>(), - SystemReduceFn.<K, V, W>buffering(valueCoder), - runtimeContext, - accum)); + return rdd + .map(new ReifyTimestampsAndWindowsFunction<K, V>()) + .map(WindowingHelpers.<KV<K, WindowedValue<V>>>unwindowFunction()) + .mapToPair(TranslationUtils.<K, WindowedValue<V>>toPairFunction()) + .mapToPair(CoderHelpers.toByteFunction(keyCoder, wvCoder)) + .groupByKey() + .mapToPair(CoderHelpers.fromByteFunctionIterable(keyCoder, wvCoder)) + .map(TranslationUtils.<K, Iterable<WindowedValue<V>>>fromPairFunction()) + .map(WindowingHelpers.<KV<K, Iterable<WindowedValue<V>>>>windowFunction()); } /** * Apply a composite {@link org.apache.beam.sdk.transforms.Combine.Globally} transformation. */ - public static <InputT, AccumT, OutputT> JavaRDD<WindowedValue<OutputT>> - combineGlobally(JavaRDD<WindowedValue<InputT>> rdd, - final CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT> combineFn, - final Coder<InputT> iCoder, - final Coder<OutputT> oCoder, - final SparkRuntimeContext runtimeContext, - final WindowingStrategy<?, ?> windowingStrategy, - final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> - sideInputs, - boolean hasDefault) { - // handle empty input RDD, which will natively skip the entire execution as Spark will not - // run on empty RDDs. - if (rdd.isEmpty()) { - JavaSparkContext jsc = new JavaSparkContext(rdd.context()); - if (hasDefault) { - OutputT defaultValue = combineFn.defaultValue(); - return jsc - .parallelize(Lists.newArrayList(CoderHelpers.toByteArray(defaultValue, oCoder))) - .map(CoderHelpers.fromByteFunction(oCoder)) - .map(WindowingHelpers.<OutputT>windowFunction()); - } else { - return jsc.emptyRDD(); - } - } - - //--- coders. - final Coder<AccumT> aCoder; - try { - aCoder = combineFn.getAccumulatorCoder(runtimeContext.getCoderRegistry(), iCoder); - } catch (CannotProvideCoderException e) { - throw new IllegalStateException("Could not determine coder for accumulator", e); - } - // windowed coders. 
+ public static <InputT, AccumT> Iterable<WindowedValue<AccumT>> combineGlobally( + JavaRDD<WindowedValue<InputT>> rdd, + final SparkGlobalCombineFn<InputT, AccumT, ?> sparkCombineFn, + final Coder<InputT> iCoder, + final Coder<AccumT> aCoder, + final WindowingStrategy<?, ?> windowingStrategy) { + checkArgument(!rdd.isEmpty(), "CombineGlobally computation should be skipped for empty RDDs."); + + // coders. final WindowedValue.FullWindowedValueCoder<InputT> wviCoder = WindowedValue.FullWindowedValueCoder.of(iCoder, windowingStrategy.getWindowFn().windowCoder()); final WindowedValue.FullWindowedValueCoder<AccumT> wvaCoder = WindowedValue.FullWindowedValueCoder.of(aCoder, windowingStrategy.getWindowFn().windowCoder()); - final WindowedValue.FullWindowedValueCoder<OutputT> wvoCoder = - WindowedValue.FullWindowedValueCoder.of(oCoder, - windowingStrategy.getWindowFn().windowCoder()); - - final SparkGlobalCombineFn<InputT, AccumT, OutputT> sparkCombineFn = - new SparkGlobalCombineFn<>(combineFn, runtimeContext, sideInputs, windowingStrategy); final IterableCoder<WindowedValue<AccumT>> iterAccumCoder = IterableCoder.of(wvaCoder); - // Use coders to convert objects in the PCollection to byte arrays, so they // can be transferred over the network for the shuffle. - JavaRDD<byte[]> inRddBytes = rdd.map(CoderHelpers.toByteFunction(wviCoder)); - /*AccumT*/ byte[] acc = inRddBytes.aggregate( + // for readability, we add comments with actual type next to byte[]. + // to shorten line length, we use: +//---- WV: WindowedValue +//---- Iterable: Itr +//---- AccumT: A +//---- InputT: I + JavaRDD<byte[]> inputRDDBytes = rdd.map(CoderHelpers.toByteFunction(wviCoder)); + /*Itr<WV<A>>*/ byte[] accumulatedBytes = inputRDDBytes.aggregate( CoderHelpers.toByteArray(sparkCombineFn.zeroValue(), iterAccumCoder), - new Function2</*AccumT*/ byte[], /*InputT*/ byte[], /*AccumT*/ byte[]>() { + new Function2</*A*/ byte[], /*I*/ byte[], /*A*/ byte[]>() { @Override - public /*AccumT*/ byte[] call(/*AccumT*/ byte[] ab, /*InputT*/ byte[] ib) + public /*Itr<WV<A>>*/ byte[] call(/*Itr<WV<A>>*/ byte[] ab, /*WV<I>*/ byte[] ib) throws Exception { Iterable<WindowedValue<AccumT>> a = CoderHelpers.fromByteArray(ab, iterAccumCoder); WindowedValue<InputT> i = CoderHelpers.fromByteArray(ib, wviCoder); return CoderHelpers.toByteArray(sparkCombineFn.seqOp(a, i), iterAccumCoder); } }, - new Function2</*AccumT*/ byte[], /*AccumT*/ byte[], /*AccumT*/ byte[]>() { + new Function2</*Itr<WV<A>>*/ byte[], /*Itr<WV<A>>*/ byte[], /*Itr<WV<A>>*/ byte[]>() { @Override - public /*AccumT*/ byte[] call(/*AccumT*/ byte[] a1b, /*AccumT*/ byte[] a2b) + public /*Itr<WV<A>>*/ byte[] call(/*Itr<WV<A>>*/ byte[] a1b, /*Itr<WV<A>>*/ byte[] a2b) throws Exception { Iterable<WindowedValue<AccumT>> a1 = CoderHelpers.fromByteArray(a1b, iterAccumCoder); Iterable<WindowedValue<AccumT>> a2 = CoderHelpers.fromByteArray(a2b, iterAccumCoder); @@ -168,10 +115,7 @@ public class GroupCombineFunctions { } } ); - Iterable<WindowedValue<OutputT>> output = - sparkCombineFn.extractOutput(CoderHelpers.fromByteArray(acc, iterAccumCoder)); - return new JavaSparkContext(rdd.context()).parallelize( - CoderHelpers.toByteArrays(output, wvoCoder)).map(CoderHelpers.fromByteFunction(wvoCoder)); + return CoderHelpers.fromByteArray(accumulatedBytes, iterAccumCoder); } /** @@ -183,31 +127,22 @@ public class GroupCombineFunctions { * For streaming, this will be called from within a serialized context * (DStream's transform callback), so passed arguments need to be Serializable.
*/ - public static <K, InputT, AccumT, OutputT> JavaRDD<WindowedValue<KV<K, OutputT>>> - combinePerKey(JavaRDD<WindowedValue<KV<K, InputT>>> rdd, - final CombineWithContext.KeyedCombineFnWithContext<K, InputT, AccumT, OutputT> - combineFn, - final KvCoder<K, InputT> inputCoder, - final SparkRuntimeContext runtimeContext, - final WindowingStrategy<?, ?> windowingStrategy, - final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> - sideInputs) { - //--- coders. - final Coder<K> keyCoder = inputCoder.getKeyCoder(); - final Coder<InputT> viCoder = inputCoder.getValueCoder(); - final Coder<AccumT> vaCoder; - try { - vaCoder = combineFn.getAccumulatorCoder(runtimeContext.getCoderRegistry(), keyCoder, viCoder); - } catch (CannotProvideCoderException e) { - throw new IllegalStateException("Could not determine coder for accumulator", e); - } - // windowed coders. + public static <K, InputT, AccumT> JavaPairRDD<K, Iterable<WindowedValue<KV<K, AccumT>>>> + combinePerKey( + JavaRDD<WindowedValue<KV<K, InputT>>> rdd, + final SparkKeyedCombineFn<K, InputT, AccumT, ?> sparkCombineFn, + final Coder<K> keyCoder, + final Coder<InputT> iCoder, + final Coder<AccumT> aCoder, + final WindowingStrategy<?, ?> windowingStrategy) { + // coders. final WindowedValue.FullWindowedValueCoder<KV<K, InputT>> wkviCoder = - WindowedValue.FullWindowedValueCoder.of(KvCoder.of(keyCoder, viCoder), + WindowedValue.FullWindowedValueCoder.of(KvCoder.of(keyCoder, iCoder), windowingStrategy.getWindowFn().windowCoder()); final WindowedValue.FullWindowedValueCoder<KV<K, AccumT>> wkvaCoder = - WindowedValue.FullWindowedValueCoder.of(KvCoder.of(keyCoder, vaCoder), + WindowedValue.FullWindowedValueCoder.of(KvCoder.of(keyCoder, aCoder), windowingStrategy.getWindowFn().windowCoder()); + final IterableCoder<WindowedValue<KV<K, AccumT>>> iterAccumCoder = IterableCoder.of(wkvaCoder); // We need to duplicate K as both the key of the JavaPairRDD as well as inside the value, // since the functions passed to combineByKey don't receive the associated key of each @@ -217,53 +152,46 @@ public class GroupCombineFunctions { // we won't need to duplicate the keys anymore. // Key has to be windowed in order to group by window as well. JavaPairRDD<K, WindowedValue<KV<K, InputT>>> inRddDuplicatedKeyPair = - rdd.flatMapToPair( - new PairFlatMapFunction<WindowedValue<KV<K, InputT>>, K, - WindowedValue<KV<K, InputT>>>() { - @Override - public Iterable<Tuple2<K, WindowedValue<KV<K, InputT>>>> - call(WindowedValue<KV<K, InputT>> wkv) { - return Collections.singletonList(new Tuple2<>(wkv.getValue().getKey(), wkv)); - } - }); - - final SparkKeyedCombineFn<K, InputT, AccumT, OutputT> sparkCombineFn = - new SparkKeyedCombineFn<>(combineFn, runtimeContext, sideInputs, windowingStrategy); - final IterableCoder<WindowedValue<KV<K, AccumT>>> iterAccumCoder = IterableCoder.of(wkvaCoder); + rdd.mapToPair(TranslationUtils.<K, InputT>toPairByKeyInWindowedValue()); // Use coders to convert objects in the PCollection to byte arrays, so they // can be transferred over the network for the shuffle. + // for readability, we add comments with actual type next to byte[].
+ // to shorten line length, we use: + //---- WV: WindowedValue + //---- Iterable: Itr + //---- AccumT: A + //---- InputT: I JavaPairRDD<ByteArray, byte[]> inRddDuplicatedKeyPairBytes = inRddDuplicatedKeyPair .mapToPair(CoderHelpers.toByteFunction(keyCoder, wkviCoder)); - // The output of combineByKey will be "AccumT" (accumulator) - // types rather than "OutputT" (final output types) since Combine.CombineFn - // only provides ways to merge VAs, and no way to merge VOs. - JavaPairRDD</*K*/ ByteArray, /*KV<K, AccumT>*/ byte[]> accumulatedBytes = + JavaPairRDD</*K*/ ByteArray, /*Itr<WV<KV<K, A>>>*/ byte[]> accumulatedBytes = inRddDuplicatedKeyPairBytes.combineByKey( - new Function</*KV<K, InputT>*/ byte[], /*KV<K, AccumT>*/ byte[]>() { + new Function</*WV<KV<K, I>>*/ byte[], /*Itr<WV<KV<K, A>>>*/ byte[]>() { @Override - public /*KV<K, AccumT>*/ byte[] call(/*KV<K, InputT>*/ byte[] input) { + public /*Itr<WV<KV<K, A>>>*/ byte[] call(/*WV<KV<K, I>>*/ byte[] input) { WindowedValue<KV<K, InputT>> wkvi = CoderHelpers.fromByteArray(input, wkviCoder); return CoderHelpers.toByteArray(sparkCombineFn.createCombiner(wkvi), iterAccumCoder); } }, - new Function2</*KV<K, AccumT>*/ byte[], /*KV<K, InputT>*/ byte[], - /*KV<K, AccumT>*/ byte[]>() { + new Function2</*Itr<WV<KV<K, A>>>*/ byte[], /*WV<KV<K, I>>*/ byte[], + /*Itr<WV<KV<K, A>>>*/ byte[]>() { @Override - public /*KV<K, AccumT>*/ byte[] call(/*KV<K, AccumT>*/ byte[] acc, - /*KV<K, InputT>*/ byte[] input) { + public /*Itr<WV<KV<K, A>>>*/ byte[] call( + /*Itr<WV<KV<K, A>>>*/ byte[] acc, + /*WV<KV<K, I>>*/ byte[] input) { Iterable<WindowedValue<KV<K, AccumT>>> wkvas = CoderHelpers.fromByteArray(acc, iterAccumCoder); WindowedValue<KV<K, InputT>> wkvi = CoderHelpers.fromByteArray(input, wkviCoder); return CoderHelpers.toByteArray(sparkCombineFn.mergeValue(wkvi, wkvas), iterAccumCoder); } }, - new Function2</*KV<K, AccumT>*/ byte[], /*KV<K, AccumT>*/ byte[], - /*KV<K, AccumT>*/ byte[]>() { + new Function2</*Itr<WV<KV<K, A>>>*/ byte[], /*Itr<WV<KV<K, A>>>*/ byte[], + /*Itr<WV<KV<K, A>>>*/ byte[]>() { @Override - public /*KV<K, AccumT>*/ byte[] call(/*KV<K, AccumT>*/ byte[] acc1, - /*KV<K, AccumT>*/ byte[] acc2) { + public /*Itr<WV<KV<K, A>>>*/ byte[] call( + /*Itr<WV<KV<K, A>>>*/ byte[] acc1, + /*Itr<WV<KV<K, A>>>*/ byte[] acc2) { Iterable<WindowedValue<KV<K, AccumT>>> wkvas1 = CoderHelpers.fromByteArray(acc1, iterAccumCoder); Iterable<WindowedValue<KV<K, AccumT>>> wkvas2 = @@ -273,23 +201,6 @@ public class GroupCombineFunctions { } }); - JavaPairRDD<K, WindowedValue<OutputT>> extracted = accumulatedBytes - .mapToPair(CoderHelpers.fromByteFunction(keyCoder, iterAccumCoder)) - .flatMapValues(new Function<Iterable<WindowedValue<KV<K, AccumT>>>, - Iterable<WindowedValue<OutputT>>>() { - @Override - public Iterable<WindowedValue<OutputT>> call( - Iterable<WindowedValue<KV<K, AccumT>>> accums) { - return sparkCombineFn.extractOutput(accums); - } - }); - return extracted.map(TranslationUtils.<K, WindowedValue<OutputT>>fromPairFunction()).map( - new Function<KV<K, WindowedValue<OutputT>>, WindowedValue<KV<K, OutputT>>>() { - @Override - public WindowedValue<KV<K, OutputT>> call(KV<K, WindowedValue<OutputT>> kwvo) - throws Exception { - return kwvo.getValue().withValue(KV.of(kwvo.getKey(), kwvo.getValue().getValue())); - } - }); + return accumulatedBytes.mapToPair(CoderHelpers.fromByteFunction(keyCoder, iterAccumCoder)); } } 
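For reference, the combineByKey contract that the byte[]-level code above implements can be seen more plainly without the coder plumbing. The following stand-alone sketch is illustrative only (the class and variable names are mine, and plain Integers stand in for the coder-encoded Itr<WV<KV<K, A>>> blobs); it shows the same createCombiner/mergeValue/mergeCombiners shape in the anonymous-class style of the surrounding code. This contract is also why the real code keys the shuffle by the runner's ByteArray wrapper rather than raw byte[]: Java arrays compare by identity, and grouping needs value equality.

import java.util.Arrays;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import scala.Tuple2;

public class CombineByKeySketch {
  public static void main(String[] args) {
    JavaSparkContext jsc = new JavaSparkContext("local[2]", "combineByKey-sketch");
    JavaPairRDD<String, Integer> keyed = jsc.parallelizePairs(Arrays.asList(
        new Tuple2<>("a", 1), new Tuple2<>("a", 2), new Tuple2<>("b", 3)));
    JavaPairRDD<String, Integer> sums = keyed.combineByKey(
        // createCombiner: start an accumulator from a key's first value.
        new Function<Integer, Integer>() {
          @Override
          public Integer call(Integer v) {
            return v;
          }
        },
        // mergeValue: fold another value of the same key into the accumulator.
        new Function2<Integer, Integer, Integer>() {
          @Override
          public Integer call(Integer acc, Integer v) {
            return acc + v;
          }
        },
        // mergeCombiners: merge partial accumulators from different partitions.
        new Function2<Integer, Integer, Integer>() {
          @Override
          public Integer call(Integer a1, Integer a2) {
            return a1 + a2;
          }
        });
    System.out.println(sums.collectAsMap()); // e.g. {a=3, b=3}
    jsc.stop();
  }
}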
http://git-wip-us.apache.org/repos/asf/beam/blob/96abe4f0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java index 14c14dc..a643651 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java @@ -27,6 +27,7 @@ import static org.apache.beam.runners.spark.io.hadoop.ShardNameBuilder.replaceSh import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectSplittable; import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectStateAndTimers; +import com.google.common.collect.Lists; import com.google.common.collect.Maps; import java.io.IOException; import java.util.Collections; @@ -35,6 +36,7 @@ import java.util.Map; import org.apache.avro.mapred.AvroKey; import org.apache.avro.mapreduce.AvroJob; import org.apache.avro.mapreduce.AvroKeyInputFormat; +import org.apache.beam.runners.core.SystemReduceFn; import org.apache.beam.runners.spark.aggregators.NamedAggregators; import org.apache.beam.runners.spark.aggregators.SparkAggregators; import org.apache.beam.runners.spark.coders.CoderHelpers; @@ -46,6 +48,7 @@ import org.apache.beam.runners.spark.io.hadoop.TemplatedTextOutputFormat; import org.apache.beam.runners.spark.metrics.MetricsAccumulator; import org.apache.beam.runners.spark.metrics.SparkMetricsContainer; import org.apache.beam.runners.spark.util.SideInputBroadcast; +import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; @@ -63,6 +66,7 @@ import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.View; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.transforms.windowing.WindowFn; import org.apache.beam.sdk.util.CombineFnUtil; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.util.WindowingStrategy; @@ -119,101 +123,159 @@ public final class TransformTranslator { }; } - private static <K, V> TransformEvaluator<GroupByKey<K, V>> groupByKey() { + private static <K, V, W extends BoundedWindow> TransformEvaluator<GroupByKey<K, V>> groupByKey() { return new TransformEvaluator<GroupByKey<K, V>>() { @Override public void evaluate(GroupByKey<K, V> transform, EvaluationContext context) { @SuppressWarnings("unchecked") JavaRDD<WindowedValue<KV<K, V>>> inRDD = ((BoundedDataset<KV<K, V>>) context.borrowDataset(transform)).getRDD(); - @SuppressWarnings("unchecked") final KvCoder<K, V> coder = (KvCoder<K, V>) context.getInput(transform).getCoder(); - final Accumulator<NamedAggregators> accum = SparkAggregators.getNamedAggregators(context.getSparkContext()); - - context.putDataset( - transform, - new BoundedDataset<>( - GroupCombineFunctions.groupByKey( - inRDD, - accum, - coder, - context.getRuntimeContext(), - context.getInput(transform).getWindowingStrategy()))); + @SuppressWarnings("unchecked") + final WindowingStrategy<?, W> windowingStrategy = + (WindowingStrategy<?, W>) context.getInput(transform).getWindowingStrategy(); 
+ @SuppressWarnings("unchecked") + final WindowFn<Object, W> windowFn = (WindowFn<Object, W>) windowingStrategy.getWindowFn(); + + //--- coders. + final Coder<K> keyCoder = coder.getKeyCoder(); + final WindowedValue.WindowedValueCoder<V> wvCoder = + WindowedValue.FullWindowedValueCoder.of(coder.getValueCoder(), windowFn.windowCoder()); + + //--- group by key only. + JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>> groupedByKey = + GroupCombineFunctions.groupByKeyOnly(inRDD, keyCoder, wvCoder); + + //--- now group also by window. + // for batch, GroupAlsoByWindow uses an in-memory StateInternals. + JavaRDD<WindowedValue<KV<K, Iterable<V>>>> groupedAlsoByWindow = groupedByKey.flatMap( + new SparkGroupAlsoByWindowViaOutputBufferFn<>( + windowingStrategy, + new TranslationUtils.InMemoryStateInternalsFactory<K>(), + SystemReduceFn.<K, V, W>buffering(coder.getValueCoder()), + context.getRuntimeContext(), + accum)); + + context.putDataset(transform, new BoundedDataset<>(groupedAlsoByWindow)); } }; } private static <K, InputT, OutputT> TransformEvaluator<Combine.GroupedValues<K, InputT, OutputT>> - combineGrouped() { - return new TransformEvaluator<Combine.GroupedValues<K, InputT, OutputT>>() { - @Override - public void evaluate(Combine.GroupedValues<K, InputT, OutputT> transform, - EvaluationContext context) { - // get the applied combine function. - PCollection<? extends KV<K, ? extends Iterable<InputT>>> input = - context.getInput(transform); - WindowingStrategy<?, ?> windowingStrategy = input.getWindowingStrategy(); - @SuppressWarnings("unchecked") - CombineWithContext.KeyedCombineFnWithContext<K, InputT, ?, OutputT> fn = - (CombineWithContext.KeyedCombineFnWithContext<K, InputT, ?, OutputT>) - CombineFnUtil.toFnWithContext(transform.getFn()); - - @SuppressWarnings("unchecked") - JavaRDD<WindowedValue<KV<K, Iterable<InputT>>>> inRDD = - ((BoundedDataset<KV<K, Iterable<InputT>>>) - context.borrowDataset(transform)).getRDD(); - - SparkKeyedCombineFn<K, InputT, ?, OutputT> combineFnWithContext = - new SparkKeyedCombineFn<>(fn, context.getRuntimeContext(), - TranslationUtils.getSideInputs(transform.getSideInputs(), context), - windowingStrategy); - context.putDataset(transform, new BoundedDataset<>(inRDD.map(new TranslationUtils - .CombineGroupedValues<>( - combineFnWithContext)))); - } - }; + combineGrouped() { + return new TransformEvaluator<Combine.GroupedValues<K, InputT, OutputT>>() { + @Override + public void evaluate( + Combine.GroupedValues<K, InputT, OutputT> transform, + EvaluationContext context) { + @SuppressWarnings("unchecked") + CombineWithContext.KeyedCombineFnWithContext<K, InputT, ?, OutputT> combineFn = + (CombineWithContext.KeyedCombineFnWithContext<K, InputT, ?, OutputT>) + CombineFnUtil.toFnWithContext(transform.getFn()); + final SparkKeyedCombineFn<K, InputT, ?, OutputT> sparkCombineFn = + new SparkKeyedCombineFn<>(combineFn, context.getRuntimeContext(), + TranslationUtils.getSideInputs(transform.getSideInputs(), context), + context.getInput(transform).getWindowingStrategy()); + + @SuppressWarnings("unchecked") + JavaRDD<WindowedValue<KV<K, Iterable<InputT>>>> inRDD = + ((BoundedDataset<KV<K, Iterable<InputT>>>) context.borrowDataset(transform)) + .getRDD(); + + JavaRDD<WindowedValue<KV<K, OutputT>>> outRDD = inRDD.map( + new Function<WindowedValue<KV<K, Iterable<InputT>>>, + WindowedValue<KV<K, OutputT>>>() { + @Override + public WindowedValue<KV<K, OutputT>> call( + WindowedValue<KV<K, Iterable<InputT>>> in) throws Exception { + return WindowedValue.of( + 
KV.of(in.getValue().getKey(), sparkCombineFn.apply(in)), + in.getTimestamp(), + in.getWindows(), + in.getPane()); + } + }); + context.putDataset(transform, new BoundedDataset<>(outRDD)); + } + }; } private static <InputT, AccumT, OutputT> TransformEvaluator<Combine.Globally<InputT, OutputT>> - combineGlobally() { - return new TransformEvaluator<Combine.Globally<InputT, OutputT>>() { - - @Override - public void evaluate(Combine.Globally<InputT, OutputT> transform, EvaluationContext context) { - final PCollection<InputT> input = context.getInput(transform); - // serializable arguments to pass. - final Coder<InputT> iCoder = context.getInput(transform).getCoder(); - final Coder<OutputT> oCoder = context.getOutput(transform).getCoder(); - @SuppressWarnings("unchecked") - final CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT> combineFn = - (CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT>) - CombineFnUtil.toFnWithContext(transform.getFn()); - final WindowingStrategy<?, ?> windowingStrategy = input.getWindowingStrategy(); - final SparkRuntimeContext runtimeContext = context.getRuntimeContext(); - final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs = - TranslationUtils.getSideInputs(transform.getSideInputs(), context); - final boolean hasDefault = transform.isInsertDefault(); + combineGlobally() { + return new TransformEvaluator<Combine.Globally<InputT, OutputT>>() { - @SuppressWarnings("unchecked") - JavaRDD<WindowedValue<InputT>> inRdd = - ((BoundedDataset<InputT>) context.borrowDataset(transform)).getRDD(); - - context.putDataset(transform, new BoundedDataset<>(GroupCombineFunctions - .combineGlobally(inRdd, combineFn, - iCoder, oCoder, runtimeContext, windowingStrategy, sideInputs, hasDefault))); - } - }; + @Override + public void evaluate( + Combine.Globally<InputT, OutputT> transform, + EvaluationContext context) { + final PCollection<InputT> input = context.getInput(transform); + final Coder<InputT> iCoder = context.getInput(transform).getCoder(); + final Coder<OutputT> oCoder = context.getOutput(transform).getCoder(); + final WindowingStrategy<?, ?> windowingStrategy = input.getWindowingStrategy(); + @SuppressWarnings("unchecked") + final CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT> combineFn = + (CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT>) + CombineFnUtil.toFnWithContext(transform.getFn()); + final WindowedValue.FullWindowedValueCoder<OutputT> wvoCoder = + WindowedValue.FullWindowedValueCoder.of(oCoder, + windowingStrategy.getWindowFn().windowCoder()); + final SparkRuntimeContext runtimeContext = context.getRuntimeContext(); + final boolean hasDefault = transform.isInsertDefault(); + + final SparkGlobalCombineFn<InputT, AccumT, OutputT> sparkCombineFn = + new SparkGlobalCombineFn<>( + combineFn, + runtimeContext, + TranslationUtils.getSideInputs(transform.getSideInputs(), context), + windowingStrategy); + final Coder<AccumT> aCoder; + try { + aCoder = combineFn.getAccumulatorCoder(runtimeContext.getCoderRegistry(), iCoder); + } catch (CannotProvideCoderException e) { + throw new IllegalStateException("Could not determine coder for accumulator", e); + } + + @SuppressWarnings("unchecked") + JavaRDD<WindowedValue<InputT>> inRdd = + ((BoundedDataset<InputT>) context.borrowDataset(transform)).getRDD(); + + JavaRDD<WindowedValue<OutputT>> outRdd; + // handle empty input RDD, which will naturally skip the entire execution + // as Spark will not run on empty RDDs. 
+ if (inRdd.isEmpty()) { + JavaSparkContext jsc = new JavaSparkContext(inRdd.context()); + if (hasDefault) { + OutputT defaultValue = combineFn.defaultValue(); + outRdd = jsc + .parallelize(Lists.newArrayList(CoderHelpers.toByteArray(defaultValue, oCoder))) + .map(CoderHelpers.fromByteFunction(oCoder)) + .map(WindowingHelpers.<OutputT>windowFunction()); + } else { + outRdd = jsc.emptyRDD(); + } + } else { + Iterable<WindowedValue<AccumT>> accumulated = GroupCombineFunctions.combineGlobally( + inRdd, sparkCombineFn, iCoder, aCoder, windowingStrategy); + Iterable<WindowedValue<OutputT>> output = sparkCombineFn.extractOutput(accumulated); + outRdd = context.getSparkContext() + .parallelize(CoderHelpers.toByteArrays(output, wvoCoder)) + .map(CoderHelpers.fromByteFunction(wvoCoder)); + } + context.putDataset(transform, new BoundedDataset<>(outRdd)); + } + }; } private static <K, InputT, AccumT, OutputT> TransformEvaluator<Combine.PerKey<K, InputT, OutputT>> combinePerKey() { return new TransformEvaluator<Combine.PerKey<K, InputT, OutputT>>() { @Override - public void evaluate(Combine.PerKey<K, InputT, OutputT> transform, - EvaluationContext context) { + public void evaluate( + Combine.PerKey<K, InputT, OutputT> transform, + EvaluationContext context) { final PCollection<KV<K, InputT>> input = context.getInput(transform); // serializable arguments to pass. @SuppressWarnings("unchecked") @@ -227,14 +289,44 @@ public final class TransformTranslator { final SparkRuntimeContext runtimeContext = context.getRuntimeContext(); final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs = TranslationUtils.getSideInputs(transform.getSideInputs(), context); + final SparkKeyedCombineFn<K, InputT, AccumT, OutputT> sparkCombineFn = + new SparkKeyedCombineFn<>(combineFn, runtimeContext, sideInputs, windowingStrategy); + final Coder<AccumT> vaCoder; + try { + vaCoder = combineFn.getAccumulatorCoder(runtimeContext.getCoderRegistry(), + inputCoder.getKeyCoder(), inputCoder.getValueCoder()); + } catch (CannotProvideCoderException e) { + throw new IllegalStateException("Could not determine coder for accumulator", e); + } @SuppressWarnings("unchecked") JavaRDD<WindowedValue<KV<K, InputT>>> inRdd = ((BoundedDataset<KV<K, InputT>>) context.borrowDataset(transform)).getRDD(); - context.putDataset(transform, new BoundedDataset<>(GroupCombineFunctions - .combinePerKey(inRdd, combineFn, - inputCoder, runtimeContext, windowingStrategy, sideInputs))); + JavaPairRDD<K, Iterable<WindowedValue<KV<K, AccumT>>>> accumulatePerKey = + GroupCombineFunctions.combinePerKey(inRdd, sparkCombineFn, inputCoder.getKeyCoder(), + inputCoder.getValueCoder(), vaCoder, windowingStrategy); + + JavaRDD<WindowedValue<KV<K, OutputT>>> outRdd = + accumulatePerKey.flatMapValues(new Function<Iterable<WindowedValue<KV<K, AccumT>>>, + Iterable<WindowedValue<OutputT>>>() { + @Override + public Iterable<WindowedValue<OutputT>> call( + Iterable<WindowedValue<KV<K, AccumT>>> iter) throws Exception { + return sparkCombineFn.extractOutput(iter); + } + }).map(TranslationUtils.<K, WindowedValue<OutputT>>fromPairFunction()) + .map(new Function<KV<K, WindowedValue<OutputT>>, + WindowedValue<KV<K, OutputT>>>() { + @Override + public WindowedValue<KV<K, OutputT>> call( + KV<K, WindowedValue<OutputT>> kv) throws Exception { + WindowedValue<OutputT> wv = kv.getValue(); + return wv.withValue(KV.of(kv.getKey(), wv.getValue())); + } + }); + + context.putDataset(transform, new BoundedDataset<>(outRdd)); } }; } 
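On the non-empty path above, GroupCombineFunctions.combineGlobally drives Spark's aggregate(zeroValue, seqOp, combOp) over coder-encoded bytes. Here is a minimal stand-alone sketch of that contract, with plain Integers standing in for the encoded Itr<WV<A>> accumulators (class and variable names are illustrative, not from the commit):

import java.util.Arrays;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function2;

public class AggregateSketch {
  public static void main(String[] args) {
    JavaSparkContext jsc = new JavaSparkContext("local[2]", "aggregate-sketch");
    JavaRDD<Integer> input = jsc.parallelize(Arrays.asList(1, 2, 3, 4));
    Integer sum = input.aggregate(
        // zeroValue: the neutral accumulator, playing the role of
        // sparkCombineFn.zeroValue() encoded with iterAccumCoder above.
        0,
        // seqOp: fold one input element into a partition's running accumulator.
        new Function2<Integer, Integer, Integer>() {
          @Override
          public Integer call(Integer acc, Integer v) {
            return acc + v;
          }
        },
        // combOp: merge the accumulators produced by different partitions.
        new Function2<Integer, Integer, Integer>() {
          @Override
          public Integer call(Integer a1, Integer a2) {
            return a1 + a2;
          }
        });
    System.out.println(sum); // 10
    jsc.stop();
  }
}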
http://git-wip-us.apache.org/repos/asf/beam/blob/96abe4f0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java index 7d83230..6b27436 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java @@ -21,7 +21,6 @@ package org.apache.beam.runners.spark.translation; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; import java.io.Serializable; -import java.util.Iterator; import java.util.List; import java.util.Map; import org.apache.beam.runners.core.InMemoryStateInternals; @@ -42,11 +41,11 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TupleTag; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaPairDStream; + import scala.Tuple2; /** @@ -148,20 +147,17 @@ public final class TranslationUtils { }; } - /** A Flatmap iterator function, flattening iterators into their elements. */ - public static <T> FlatMapFunction<Iterator<T>, T> flattenIter() { - return new FlatMapFunction<Iterator<T>, T>() { - @Override - public Iterable<T> call(final Iterator<T> t) throws Exception { - return new Iterable<T>() { + /** Extract key from a {@link WindowedValue} {@link KV} into a pair. */ + public static <K, V> PairFunction<WindowedValue<KV<K, V>>, K, WindowedValue<KV<K, V>>> + toPairByKeyInWindowedValue() { + return new PairFunction<WindowedValue<KV<K, V>>, K, WindowedValue<KV<K, V>>>() { @Override - public Iterator<T> iterator() { - return t; - } + public Tuple2<K, WindowedValue<KV<K, V>>> call( + WindowedValue<KV<K, V>> windowedKv) throws Exception { + return new Tuple2<>(windowedKv.getValue().getKey(), windowedKv); + } }; } - }; - } /** * A utility class to filter {@link TupleTag}s. 
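A hypothetical usage sketch for the new pair function (the wrapper class and method below are mine; the call mirrors GroupCombineFunctions.combinePerKey above). It keys each element by K while keeping the whole WindowedValue as the pair's value, which matters because Spark's combineByKey callbacks never see the key:

import org.apache.beam.runners.spark.translation.TranslationUtils;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;

class KeyInValueSketch {
  // Keys the RDD by K while duplicating the key inside the value, so the
  // combiner callbacks can still reach it.
  static <K, V> JavaPairRDD<K, WindowedValue<KV<K, V>>> keyByDuplicatedKey(
      JavaRDD<WindowedValue<KV<K, V>>> rdd) {
    return rdd.mapToPair(TranslationUtils.<K, V>toPairByKeyInWindowedValue());
  }
}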
http://git-wip-us.apache.org/repos/asf/beam/blob/96abe4f0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java index 9451df7..e90b490 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java @@ -31,6 +31,7 @@ import org.apache.beam.runners.spark.aggregators.SparkAggregators; import org.apache.beam.runners.spark.io.ConsoleIO; import org.apache.beam.runners.spark.io.CreateStream; import org.apache.beam.runners.spark.io.SparkUnboundedSource; +import org.apache.beam.runners.spark.stateful.SparkGroupAlsoByWindowViaWindowSet; import org.apache.beam.runners.spark.metrics.MetricsAccumulator; import org.apache.beam.runners.spark.metrics.SparkMetricsContainer; import org.apache.beam.runners.spark.translation.BoundedDataset; @@ -148,7 +149,7 @@ final class StreamingTransformTranslator { Dataset dataset = context.borrowDataset(pcol); if (dataset instanceof UnboundedDataset) { UnboundedDataset<T> unboundedDataset = (UnboundedDataset<T>) dataset; - streamingSources.addAll(unboundedDataset.getStreamingSources()); + streamingSources.addAll(unboundedDataset.getStreamSources()); dStreams.add(unboundedDataset.getDStream()); } else { rdds.add(((BoundedDataset<T>) dataset).getRDD()); @@ -205,7 +206,7 @@ final class StreamingTransformTranslator { //--- then we apply windowing to the elements if (TranslationUtils.skipAssignWindows(transform, context)) { context.putDataset(transform, - new UnboundedDataset<>(windowedDStream, unboundedDataset.getStreamingSources())); + new UnboundedDataset<>(windowedDStream, unboundedDataset.getStreamSources())); } else { JavaDStream<WindowedValue<T>> outStream = windowedDStream.transform( new Function<JavaRDD<WindowedValue<T>>, JavaRDD<WindowedValue<T>>>() { @@ -215,42 +216,55 @@ final class StreamingTransformTranslator { } }); context.putDataset(transform, - new UnboundedDataset<>(outStream, unboundedDataset.getStreamingSources())); + new UnboundedDataset<>(outStream, unboundedDataset.getStreamSources())); } } }; } - private static <K, V> TransformEvaluator<GroupByKey<K, V>> groupByKey() { + private static <K, V, W extends BoundedWindow> TransformEvaluator<GroupByKey<K, V>> groupByKey() { return new TransformEvaluator<GroupByKey<K, V>>() { @Override public void evaluate(GroupByKey<K, V> transform, EvaluationContext context) { - @SuppressWarnings("unchecked") - UnboundedDataset<KV<K, V>> unboundedDataset = - ((UnboundedDataset<KV<K, V>>) context.borrowDataset(transform)); - JavaDStream<WindowedValue<KV<K, V>>> dStream = unboundedDataset.getDStream(); - + @SuppressWarnings("unchecked") UnboundedDataset<KV<K, V>> inputDataset = + (UnboundedDataset<KV<K, V>>) context.borrowDataset(transform); + List<Integer> streamSources = inputDataset.getStreamSources(); + JavaDStream<WindowedValue<KV<K, V>>> dStream = inputDataset.getDStream(); @SuppressWarnings("unchecked") final KvCoder<K, V> coder = (KvCoder<K, V>) context.getInput(transform).getCoder(); - final SparkRuntimeContext runtimeContext = context.getRuntimeContext(); - final 
WindowingStrategy<?, ?> windowingStrategy = - context.getInput(transform).getWindowingStrategy(); + @SuppressWarnings("unchecked") + final WindowingStrategy<?, W> windowingStrategy = + (WindowingStrategy<?, W>) context.getInput(transform).getWindowingStrategy(); + @SuppressWarnings("unchecked") + final WindowFn<Object, W> windowFn = (WindowFn<Object, W>) windowingStrategy.getWindowFn(); - JavaDStream<WindowedValue<KV<K, Iterable<V>>>> outStream = + //--- coders. + final WindowedValue.WindowedValueCoder<V> wvCoder = + WindowedValue.FullWindowedValueCoder.of(coder.getValueCoder(), windowFn.windowCoder()); + + //--- group by key only. + JavaDStream<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>> groupedByKeyStream = dStream.transform(new Function<JavaRDD<WindowedValue<KV<K, V>>>, - JavaRDD<WindowedValue<KV<K, Iterable<V>>>>>() { - @Override - public JavaRDD<WindowedValue<KV<K, Iterable<V>>>> call( - JavaRDD<WindowedValue<KV<K, V>>> rdd) throws Exception { - final Accumulator<NamedAggregators> accum = - SparkAggregators.getNamedAggregators(new JavaSparkContext(rdd.context())); - return GroupCombineFunctions.groupByKey(rdd, accum, coder, runtimeContext, - windowingStrategy); - } - }); - context.putDataset(transform, - new UnboundedDataset<>(outStream, unboundedDataset.getStreamingSources())); + JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>>>() { + @Override + public JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>> call( + JavaRDD<WindowedValue<KV<K, V>>> rdd) throws Exception { + return GroupCombineFunctions.groupByKeyOnly( + rdd, coder.getKeyCoder(), wvCoder); + } + }); + + //--- now group also by window. + JavaDStream<WindowedValue<KV<K, Iterable<V>>>> outStream = + SparkGroupAlsoByWindowViaWindowSet.groupAlsoByWindow( + groupedByKeyStream, + coder.getValueCoder(), + windowingStrategy, + runtimeContext, + streamSources); + + context.putDataset(transform, new UnboundedDataset<>(outStream, streamSources)); } }; } @@ -296,96 +310,7 @@ final class StreamingTransformTranslator { }); context.putDataset(transform, - new UnboundedDataset<>(outStream, unboundedDataset.getStreamingSources())); - } - }; - } - - private static <InputT, AccumT, OutputT> TransformEvaluator<Combine.Globally<InputT, OutputT>> - combineGlobally() { - return new TransformEvaluator<Combine.Globally<InputT, OutputT>>() { - - @Override - public void evaluate( - final Combine.Globally<InputT, OutputT> transform, - EvaluationContext context) { - final PCollection<InputT> input = context.getInput(transform); - // serializable arguments to pass. 
- final Coder<InputT> iCoder = context.getInput(transform).getCoder(); - final Coder<OutputT> oCoder = context.getOutput(transform).getCoder(); - @SuppressWarnings("unchecked") - final CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT> combineFn = - (CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT>) - CombineFnUtil.toFnWithContext(transform.getFn()); - final WindowingStrategy<?, ?> windowingStrategy = input.getWindowingStrategy(); - final SparkRuntimeContext runtimeContext = context.getRuntimeContext(); - final boolean hasDefault = transform.isInsertDefault(); - final SparkPCollectionView pviews = context.getPViews(); - - @SuppressWarnings("unchecked") - UnboundedDataset<InputT> unboundedDataset = - ((UnboundedDataset<InputT>) context.borrowDataset(transform)); - JavaDStream<WindowedValue<InputT>> dStream = unboundedDataset.getDStream(); - - JavaDStream<WindowedValue<OutputT>> outStream = dStream.transform( - new Function<JavaRDD<WindowedValue<InputT>>, JavaRDD<WindowedValue<OutputT>>>() { - @Override - public JavaRDD<WindowedValue<OutputT>> call(JavaRDD<WindowedValue<InputT>> rdd) - throws Exception { - final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs = - TranslationUtils.getSideInputs(transform.getSideInputs(), - JavaSparkContext.fromSparkContext(rdd.context()), - pviews); - return GroupCombineFunctions.combineGlobally(rdd, combineFn, iCoder, oCoder, - runtimeContext, windowingStrategy, sideInputs, hasDefault); - } - }); - - context.putDataset(transform, - new UnboundedDataset<>(outStream, unboundedDataset.getStreamingSources())); - } - }; - } - - private static <K, InputT, AccumT, OutputT> - TransformEvaluator<Combine.PerKey<K, InputT, OutputT>> combinePerKey() { - return new TransformEvaluator<Combine.PerKey<K, InputT, OutputT>>() { - @Override - public void evaluate(final Combine.PerKey<K, InputT, OutputT> transform, - final EvaluationContext context) { - final PCollection<KV<K, InputT>> input = context.getInput(transform); - // serializable arguments to pass. 
- final KvCoder<K, InputT> inputCoder = - (KvCoder<K, InputT>) context.getInput(transform).getCoder(); - @SuppressWarnings("unchecked") - final CombineWithContext.KeyedCombineFnWithContext<K, InputT, AccumT, OutputT> combineFn = - (CombineWithContext.KeyedCombineFnWithContext<K, InputT, AccumT, OutputT>) - CombineFnUtil.toFnWithContext(transform.getFn()); - final WindowingStrategy<?, ?> windowingStrategy = input.getWindowingStrategy(); - final SparkRuntimeContext runtimeContext = context.getRuntimeContext(); - final SparkPCollectionView pviews = context.getPViews(); - - @SuppressWarnings("unchecked") - UnboundedDataset<KV<K, InputT>> unboundedDataset = - ((UnboundedDataset<KV<K, InputT>>) context.borrowDataset(transform)); - JavaDStream<WindowedValue<KV<K, InputT>>> dStream = unboundedDataset.getDStream(); - - JavaDStream<WindowedValue<KV<K, OutputT>>> outStream = - dStream.transform(new Function<JavaRDD<WindowedValue<KV<K, InputT>>>, - JavaRDD<WindowedValue<KV<K, OutputT>>>>() { - @Override - public JavaRDD<WindowedValue<KV<K, OutputT>>> call( - JavaRDD<WindowedValue<KV<K, InputT>>> rdd) throws Exception { - final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs = - TranslationUtils.getSideInputs(transform.getSideInputs(), - JavaSparkContext.fromSparkContext(rdd.context()), - pviews); - return GroupCombineFunctions.combinePerKey(rdd, combineFn, inputCoder, runtimeContext, - windowingStrategy, sideInputs); - } - }); - context.putDataset(transform, - new UnboundedDataset<>(outStream, unboundedDataset.getStreamingSources())); + new UnboundedDataset<>(outStream, unboundedDataset.getStreamSources())); } }; } @@ -431,7 +356,7 @@ final class StreamingTransformTranslator { }); context.putDataset(transform, - new UnboundedDataset<>(outStream, unboundedDataset.getStreamingSources())); + new UnboundedDataset<>(outStream, unboundedDataset.getStreamSources())); } }; } @@ -486,7 +411,7 @@ final class StreamingTransformTranslator { (JavaDStream<WindowedValue<Object>>) (JavaDStream<?>) TranslationUtils.dStreamValues(filtered); context.putDataset(e.getValue(), - new UnboundedDataset<>(values, unboundedDataset.getStreamingSources())); + new UnboundedDataset<>(values, unboundedDataset.getStreamSources())); } } }; @@ -499,8 +424,6 @@ final class StreamingTransformTranslator { EVALUATORS.put(Read.Unbounded.class, readUnbounded()); EVALUATORS.put(GroupByKey.class, groupByKey()); EVALUATORS.put(Combine.GroupedValues.class, combineGrouped()); - EVALUATORS.put(Combine.Globally.class, combineGlobally()); - EVALUATORS.put(Combine.PerKey.class, combinePerKey()); EVALUATORS.put(ParDo.Bound.class, parDo()); EVALUATORS.put(ParDo.BoundMulti.class, multiDo()); EVALUATORS.put(ConsoleIO.Write.Unbound.class, print()); @@ -523,7 +446,7 @@ final class StreamingTransformTranslator { @Override public boolean hasTranslation(Class<? 
extends PTransform<?, ?>> clazz) { // streaming includes rdd/bounded transformations as well - return EVALUATORS.containsKey(clazz) || batchTranslator.hasTranslation(clazz); + return EVALUATORS.containsKey(clazz); } @Override http://git-wip-us.apache.org/repos/asf/beam/blob/96abe4f0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/UnboundedDataset.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/UnboundedDataset.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/UnboundedDataset.java index 6f5fa93..8624f41 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/UnboundedDataset.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/UnboundedDataset.java @@ -56,11 +56,11 @@ public class UnboundedDataset<T> implements Dataset { // should be greater > 1 in case of Flatten for example. // when using GlobalWatermarkHolder this information helps to take only the relevant watermarks // and reason about them accordingly. - private final List<Integer> streamingSources = new ArrayList<>(); + private final List<Integer> streamSources = new ArrayList<>(); - public UnboundedDataset(JavaDStream<WindowedValue<T>> dStream, List<Integer> streamingSources) { + public UnboundedDataset(JavaDStream<WindowedValue<T>> dStream, List<Integer> streamSources) { this.dStream = dStream; - this.streamingSources.addAll(streamingSources); + this.streamSources.addAll(streamSources); } public UnboundedDataset(Iterable<Iterable<T>> values, JavaStreamingContext jssc, Coder<T> coder) { @@ -68,7 +68,7 @@ public class UnboundedDataset<T> implements Dataset { this.jssc = jssc; this.coder = coder; // QueuedStream will have a negative (decreasing) unique id. 
- this.streamingSources.add(queuedStreamIds.decrementAndGet()); + this.streamSources.add(queuedStreamIds.decrementAndGet()); } @VisibleForTesting @@ -97,8 +97,8 @@ public class UnboundedDataset<T> implements Dataset { return dStream; } - public List<Integer> getStreamingSources() { - return streamingSources; + public List<Integer> getStreamSources() { + return streamSources; } public void cache() { http://git-wip-us.apache.org/repos/asf/beam/blob/96abe4f0/runners/spark/src/main/java/org/apache/beam/runners/spark/util/LateDataUtils.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/LateDataUtils.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/LateDataUtils.java index 96e6ee5..18689bd 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/LateDataUtils.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/LateDataUtils.java @@ -89,4 +89,4 @@ public class LateDataUtils { } }); } -} \ No newline at end of file +} http://git-wip-us.apache.org/repos/asf/beam/blob/96abe4f0/runners/spark/src/main/java/org/apache/beam/runners/spark/util/UnsupportedSideInputReader.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/UnsupportedSideInputReader.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/UnsupportedSideInputReader.java index 6de7e86..96d889d 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/util/UnsupportedSideInputReader.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/util/UnsupportedSideInputReader.java @@ -49,4 +49,4 @@ public class UnsupportedSideInputReader implements SideInputReader { throw new UnsupportedOperationException( String.format("%s does not support side inputs.", transformName)); } -} \ No newline at end of file +} http://git-wip-us.apache.org/repos/asf/beam/blob/96abe4f0/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java index fbe5777..8449724 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/TrackStreamingSourcesTest.java @@ -157,7 +157,7 @@ public class TrackStreamingSourcesTest { ctxt.setCurrentTransform(appliedTransform); //noinspection unchecked Dataset dataset = ctxt.borrowDataset((PTransform<? extends PValue, ?>) transform); - assertSourceIds(((UnboundedDataset<?>) dataset).getStreamingSources()); + assertSourceIds(((UnboundedDataset<?>) dataset).getStreamSources()); ctxt.setCurrentTransform(null); } else { evaluator.visitPrimitiveTransform(node);
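Taken together, streaming GroupByKey now decomposes into a per-micro-batch GroupByKeyOnly followed by the stateful GroupAlsoByWindow. The sketch below restates the wiring from StreamingTransformTranslator outside the translator for readability; the wrapper class, method, and parameter names are mine, while the groupByKeyOnly and groupAlsoByWindow signatures are as they appear in this diff:

import java.util.List;
import org.apache.beam.runners.spark.stateful.SparkGroupAlsoByWindowViaWindowSet;
import org.apache.beam.runners.spark.translation.GroupCombineFunctions;
import org.apache.beam.runners.spark.translation.SparkRuntimeContext;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.util.WindowingStrategy;
import org.apache.beam.sdk.values.KV;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.streaming.api.java.JavaDStream;

class StreamingGbkSketch {
  static <K, V, W extends BoundedWindow> JavaDStream<WindowedValue<KV<K, Iterable<V>>>>
      groupByKey(
          JavaDStream<WindowedValue<KV<K, V>>> dStream,
          final Coder<K> keyCoder,
          final Coder<V> valueCoder,
          final WindowedValue.WindowedValueCoder<V> wvCoder,
          final WindowingStrategy<?, W> windowingStrategy,
          final SparkRuntimeContext runtimeContext,
          final List<Integer> streamSources) {
    // 1) Per micro-batch: group by key only, shuffling coder-encoded bytes.
    JavaDStream<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>> groupedByKey =
        dStream.transform(
            new Function<JavaRDD<WindowedValue<KV<K, V>>>,
                JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>>>() {
              @Override
              public JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>> call(
                  JavaRDD<WindowedValue<KV<K, V>>> rdd) throws Exception {
                return GroupCombineFunctions.groupByKeyOnly(rdd, keyCoder, wvCoder);
              }
            });
    // 2) Across micro-batches: group also by window using the stateful window
    //    set, which carries per-key state and timers between batches and is
    //    checkpointed when a checkpoint duration is configured (see above).
    return SparkGroupAlsoByWindowViaWindowSet.groupAlsoByWindow(
        groupedByKey, valueCoder, windowingStrategy, runtimeContext, streamSources);
  }
}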