[BEAM-807] Replace OldDoFn with DoFn. Add a custom AssignWindows implementation.
Set up and tear down the DoFn. Add a GroupAlsoByWindow implementation via flatMap. Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/4ffed3e0 Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/4ffed3e0 Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/4ffed3e0 Branch: refs/heads/master Commit: 4ffed3e09a2f0ec3583098f6cfd53a2ddcc6f8c2 Parents: 2be9a15 Author: Sela <ans...@paypal.com> Authored: Sun Dec 11 14:32:49 2016 +0200 Committer: Sela <ans...@paypal.com> Committed: Tue Dec 13 10:05:18 2016 +0200 ---------------------------------------------------------------------- .../beam/runners/spark/examples/WordCount.java | 6 +- .../runners/spark/translation/DoFnFunction.java | 2 +- .../translation/GroupCombineFunctions.java | 23 +- .../spark/translation/MultiDoFnFunction.java | 2 +- .../spark/translation/SparkAssignWindowFn.java | 69 ++++++ .../translation/SparkGroupAlsoByWindowFn.java | 214 +++++++++++++++++++ .../spark/translation/SparkProcessContext.java | 10 + .../spark/translation/TransformTranslator.java | 31 +-- .../streaming/StreamingTransformTranslator.java | 35 ++- .../streaming/utils/PAssertStreaming.java | 26 +-- 10 files changed, 345 insertions(+), 73 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/examples/WordCount.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/examples/WordCount.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/examples/WordCount.java index b2672b5..1252d12 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/examples/WordCount.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/examples/WordCount.java @@ -25,8 +25,8 @@ import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.transforms.Aggregator; import org.apache.beam.sdk.transforms.Count; +import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.MapElements; -import org.apache.beam.sdk.transforms.OldDoFn; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.SimpleFunction; @@ -44,11 +44,11 @@ public class WordCount { * of-line. This DoFn tokenizes lines of text into individual words; we pass it to a ParDo in the * pipeline. 
*/ - static class ExtractWordsFn extends OldDoFn<String, String> { + static class ExtractWordsFn extends DoFn<String, String> { private final Aggregator<Long, Long> emptyLines = createAggregator("emptyLines", new Sum.SumLongFn()); - @Override + @ProcessElement public void processElement(ProcessContext c) { if (c.element().trim().isEmpty()) { emptyLines.addValue(1L); http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java index 4c49a7f..6a641b5 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java @@ -93,7 +93,7 @@ public class DoFnFunction<InputT, OutputT> windowingStrategy ); - return new SparkProcessContext<>(doFnRunner, outputManager).processPartition(iter); + return new SparkProcessContext<>(doFn, doFnRunner, outputManager).processPartition(iter); } private class DoFnOutputManager http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java index 421b1b0..4875b0c 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java @@ -18,11 +18,9 @@ package org.apache.beam.runners.spark.translation; - import com.google.common.collect.Lists; import java.util.Collections; import java.util.Map; -import org.apache.beam.runners.core.GroupAlsoByWindowsViaOutputBufferDoFn; import org.apache.beam.runners.core.SystemReduceFn; import org.apache.beam.runners.spark.aggregators.NamedAggregators; import org.apache.beam.runners.spark.coders.CoderHelpers; @@ -33,9 +31,7 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.transforms.CombineWithContext; -import org.apache.beam.sdk.transforms.OldDoFn; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.transforms.windowing.WindowFn; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.util.WindowingStrategy; import org.apache.beam.sdk.values.KV; @@ -59,7 +55,7 @@ public class GroupCombineFunctions { /** * Apply {@link org.apache.beam.sdk.transforms.GroupByKey} to a Spark RDD. */ - public static <K, V, W extends BoundedWindow> JavaRDD<WindowedValue<KV<K, + public static <K, V, W extends BoundedWindow> JavaRDD<WindowedValue<KV<K, Iterable<V>>>> groupByKey(JavaRDD<WindowedValue<KV<K, V>>> rdd, Accumulator<NamedAggregators> accum, KvCoder<K, V> coder, @@ -86,15 +82,14 @@ public class GroupCombineFunctions { .map(WindowingHelpers.<KV<K, Iterable<WindowedValue<V>>>>windowFunction()); //--- now group also by window. 
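The hunk that follows replaces the OldDoFn-based GroupAlsoByWindowsViaOutputBufferDoFn with a plain Spark flatMap over the new SparkGroupAlsoByWindowFn, so grouping by window becomes an ordinary RDD transformation rather than a DoFnFunction invocation. For orientation, the 1.x-style Spark Java API that SparkGroupAlsoByWindowFn plugs into has this shape (a minimal sketch of the interface only, inferred from the Iterable return type in the new class below; not runner code):

    // Spark 1.x Java API: call() expands one element into zero or more results.
    public interface FlatMapFunction<T, R> extends java.io.Serializable {
      Iterable<R> call(T t) throws Exception;
    }
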
- @SuppressWarnings("unchecked") - WindowFn<Object, W> windowFn = (WindowFn<Object, W>) windowingStrategy.getWindowFn(); - // GroupAlsoByWindow current uses a dummy in-memory StateInternals - OldDoFn<KV<K, Iterable<WindowedValue<V>>>, KV<K, Iterable<V>>> gabwDoFn = - new GroupAlsoByWindowsViaOutputBufferDoFn<K, V, Iterable<V>, W>( - windowingStrategy, new TranslationUtils.InMemoryStateInternalsFactory<K>(), - SystemReduceFn.<K, V, W>buffering(valueCoder)); - return groupedByKey.mapPartitions(new DoFnFunction<>(accum, gabwDoFn, runtimeContext, null, - windowFn)); + // GroupAlsoByWindow currently uses a dummy in-memory StateInternals + return groupedByKey.flatMap( + new SparkGroupAlsoByWindowFn<>( + windowingStrategy, + new TranslationUtils.InMemoryStateInternalsFactory<K>(), + SystemReduceFn.<K, V, W>buffering(valueCoder), + runtimeContext, + accum)); } /** http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java index 710c5cd..8a55369 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java @@ -102,7 +102,7 @@ public class MultiDoFnFunction<InputT, OutputT> windowingStrategy ); - return new SparkProcessContext<>(doFnRunner, outputManager).processPartition(iter); + return new SparkProcessContext<>(doFn, doFnRunner, outputManager).processPartition(iter); } private class DoFnOutputManager http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkAssignWindowFn.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkAssignWindowFn.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkAssignWindowFn.java new file mode 100644 index 0000000..9d7ed7d --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkAssignWindowFn.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.beam.runners.spark.translation; + +import com.google.common.collect.Iterables; +import java.util.Collection; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.transforms.windowing.WindowFn; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.spark.api.java.function.Function; +import org.joda.time.Instant; + + +/** + * An implementation of {@link org.apache.beam.runners.core.AssignWindows} for the Spark runner. + */ +public class SparkAssignWindowFn<T, W extends BoundedWindow> + implements Function<WindowedValue<T>, WindowedValue<T>> { + + private WindowFn<? super T, W> fn; + + public SparkAssignWindowFn(WindowFn<? super T, W> fn) { + this.fn = fn; + } + + @Override + @SuppressWarnings("unchecked") + public WindowedValue<T> call(WindowedValue<T> windowedValue) throws Exception { + final BoundedWindow boundedWindow = Iterables.getOnlyElement(windowedValue.getWindows()); + final T element = windowedValue.getValue(); + final Instant timestamp = windowedValue.getTimestamp(); + Collection<W> windows = + ((WindowFn<T, W>) fn).assignWindows( + ((WindowFn<T, W>) fn).new AssignContext() { + @Override + public T element() { + return element; + } + + @Override + public Instant timestamp() { + return timestamp; + } + + @Override + public BoundedWindow window() { + return boundedWindow; + } + }); + return WindowedValue.of(element, timestamp, windows, PaneInfo.NO_FIRING); + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkGroupAlsoByWindowFn.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkGroupAlsoByWindowFn.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkGroupAlsoByWindowFn.java new file mode 100644 index 0000000..87d3f50 --- /dev/null +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkGroupAlsoByWindowFn.java @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.beam.runners.spark.translation; + +import com.google.common.collect.Iterables; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import org.apache.beam.runners.core.GroupAlsoByWindowsDoFn; +import org.apache.beam.runners.core.OutputWindowedValue; +import org.apache.beam.runners.core.ReduceFnRunner; +import org.apache.beam.runners.core.SystemReduceFn; +import org.apache.beam.runners.core.triggers.ExecutableTriggerStateMachine; +import org.apache.beam.runners.core.triggers.TriggerStateMachines; +import org.apache.beam.runners.spark.aggregators.NamedAggregators; +import org.apache.beam.sdk.transforms.Aggregator; +import org.apache.beam.sdk.transforms.Sum; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.PaneInfo; +import org.apache.beam.sdk.util.SideInputReader; +import org.apache.beam.sdk.util.TimerInternals; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.util.WindowingStrategy; +import org.apache.beam.sdk.util.state.InMemoryTimerInternals; +import org.apache.beam.sdk.util.state.StateInternals; +import org.apache.beam.sdk.util.state.StateInternalsFactory; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.spark.Accumulator; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.joda.time.Instant; + + + +/** + * An implementation of {@link org.apache.beam.runners.core.GroupAlsoByWindowsViaOutputBufferDoFn} + * for the Spark runner. + */ +public class SparkGroupAlsoByWindowFn<K, InputT, W extends BoundedWindow> + implements FlatMapFunction<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>, + WindowedValue<KV<K, Iterable<InputT>>>> { + + private final WindowingStrategy<?, W> windowingStrategy; + private final StateInternalsFactory<K> stateInternalsFactory; + private final SystemReduceFn<K, InputT, Iterable<InputT>, Iterable<InputT>, W> reduceFn; + private final SparkRuntimeContext runtimeContext; + private final Aggregator<Long, Long> droppedDueToClosedWindow; + + + public SparkGroupAlsoByWindowFn( + WindowingStrategy<?, W> windowingStrategy, + StateInternalsFactory<K> stateInternalsFactory, + SystemReduceFn<K, InputT, Iterable<InputT>, Iterable<InputT>, W> reduceFn, + SparkRuntimeContext runtimeContext, + Accumulator<NamedAggregators> accumulator) { + this.windowingStrategy = windowingStrategy; + this.stateInternalsFactory = stateInternalsFactory; + this.reduceFn = reduceFn; + this.runtimeContext = runtimeContext; + + droppedDueToClosedWindow = runtimeContext.createAggregator( + accumulator, + GroupAlsoByWindowsDoFn.DROPPED_DUE_TO_CLOSED_WINDOW_COUNTER, + new Sum.SumLongFn()); + } + + @Override + public Iterable<WindowedValue<KV<K, Iterable<InputT>>>> call( + WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>> windowedValue) throws Exception { + K key = windowedValue.getValue().getKey(); + Iterable<WindowedValue<InputT>> inputs = windowedValue.getValue().getValue(); + + //------ based on GroupAlsoByWindowsViaOutputBufferDoFn ------// + + // Used with Batch, we know that all the data is available for this key. We can't use the + // timer manager from the context because it doesn't exist. So we create one and emulate the + // watermark, knowing that we have all data and it is in timestamp order. 
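+ // InMemoryTimerInternals is a self-contained timer manager: it queues event-time, + // processing-time and synchronized-processing-time timers, and releases them only + // as the corresponding clocks are advanced by hand below.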
+ InMemoryTimerInternals timerInternals = new InMemoryTimerInternals(); + timerInternals.advanceProcessingTime(Instant.now()); + timerInternals.advanceSynchronizedProcessingTime(Instant.now()); + StateInternals<K> stateInternals = stateInternalsFactory.stateInternalsForKey(key); + GABWOutputWindowedValue<K, InputT> outputter = new GABWOutputWindowedValue<>(); + + ReduceFnRunner<K, InputT, Iterable<InputT>, W> reduceFnRunner = + new ReduceFnRunner<>( + key, + windowingStrategy, + ExecutableTriggerStateMachine.create( + TriggerStateMachines.stateMachineForTrigger(windowingStrategy.getTrigger())), + stateInternals, + timerInternals, + outputter, + new SideInputReader() { + @Override + public <T> T get(PCollectionView<T> view, BoundedWindow sideInputWindow) { + throw new UnsupportedOperationException( + "GroupAlsoByWindow must not have side inputs"); + } + + @Override + public <T> boolean contains(PCollectionView<T> view) { + throw new UnsupportedOperationException( + "GroupAlsoByWindow must not have side inputs"); + } + + @Override + public boolean isEmpty() { + throw new UnsupportedOperationException( + "GroupAlsoByWindow must not have side inputs"); + } + }, + droppedDueToClosedWindow, + reduceFn, + runtimeContext.getPipelineOptions()); + + Iterable<List<WindowedValue<InputT>>> chunks = Iterables.partition(inputs, 1000); + for (Iterable<WindowedValue<InputT>> chunk : chunks) { + // Process the chunk of elements. + reduceFnRunner.processElements(chunk); + + // Then, since elements are sorted by their timestamp, advance the input watermark + // to the first element. + timerInternals.advanceInputWatermark(chunk.iterator().next().getTimestamp()); + // Advance the processing times. + timerInternals.advanceProcessingTime(Instant.now()); + timerInternals.advanceSynchronizedProcessingTime(Instant.now()); + + // Fire all the eligible timers. + fireEligibleTimers(timerInternals, reduceFnRunner); + + // Leave the output watermark undefined. Since there's no late data in batch mode + // there's really no need to track it as we do for streaming. + } + + // Finish any pending windows by advancing the input watermark to infinity. + timerInternals.advanceInputWatermark(BoundedWindow.TIMESTAMP_MAX_VALUE); + + // Finally, advance the processing time to infinity to fire any timers. 
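+ // A timer becomes eligible once its scheduled time is no later than its clock, so + // with the input watermark already at TIMESTAMP_MAX_VALUE, pushing both processing + // clocks to TIMESTAMP_MAX_VALUE as well lets fireEligibleTimers() drain everything.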
+ timerInternals.advanceProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE); + timerInternals.advanceSynchronizedProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE); + + fireEligibleTimers(timerInternals, reduceFnRunner); + + reduceFnRunner.persist(); + + return outputter.getOutputs(); + } + + private void fireEligibleTimers(InMemoryTimerInternals timerInternals, + ReduceFnRunner<K, InputT, Iterable<InputT>, W> reduceFnRunner) throws Exception { + List<TimerInternals.TimerData> timers = new ArrayList<>(); + while (true) { + TimerInternals.TimerData timer; + while ((timer = timerInternals.removeNextEventTimer()) != null) { + timers.add(timer); + } + while ((timer = timerInternals.removeNextProcessingTimer()) != null) { + timers.add(timer); + } + while ((timer = timerInternals.removeNextSynchronizedProcessingTimer()) != null) { + timers.add(timer); + } + if (timers.isEmpty()) { + break; + } + reduceFnRunner.onTimers(timers); + timers.clear(); + } + } + + private static class GABWOutputWindowedValue<K, V> + implements OutputWindowedValue<KV<K, Iterable<V>>> { + private final List<WindowedValue<KV<K, Iterable<V>>>> outputs = new ArrayList<>(); + + @Override + public void outputWindowedValue( + KV<K, Iterable<V>> output, + Instant timestamp, + Collection<? extends BoundedWindow> windows, + PaneInfo pane) { + outputs.add(WindowedValue.of(output, timestamp, windows, pane)); + } + + @Override + public <SideOutputT> void sideOutputWindowedValue( + TupleTag<SideOutputT> tag, + SideOutputT output, + Instant timestamp, + Collection<? extends BoundedWindow> windows, PaneInfo pane) { + throw new UnsupportedOperationException("GroupAlsoByWindow should not use side outputs."); + } + + Iterable<WindowedValue<KV<K, Iterable<V>>>> getOutputs() { + return outputs; + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java index efd8202..3a31cae 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java @@ -25,6 +25,8 @@ import java.util.Iterator; import org.apache.beam.runners.core.DoFnRunner; import org.apache.beam.runners.core.DoFnRunners.OutputManager; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.ExecutionContext.StepContext; import org.apache.beam.sdk.util.TimerInternals; @@ -38,13 +40,16 @@ import org.apache.beam.sdk.values.TupleTag; */ class SparkProcessContext<FnInputT, FnOutputT, OutputT> { + private final DoFn<FnInputT, FnOutputT> doFn; private final DoFnRunner<FnInputT, FnOutputT> doFnRunner; private final SparkOutputManager<OutputT> outputManager; SparkProcessContext( + DoFn<FnInputT, FnOutputT> doFn, DoFnRunner<FnInputT, FnOutputT> doFnRunner, SparkOutputManager<OutputT> outputManager) { + this.doFn = doFn; this.doFnRunner = doFnRunner; this.outputManager = outputManager; } @@ -52,6 +57,9 @@ class SparkProcessContext<FnInputT, FnOutputT, OutputT> { Iterable<OutputT> 
processPartition( Iterator<WindowedValue<FnInputT>> partition) throws Exception { + // setup DoFn. + DoFnInvokers.invokerFor(doFn).invokeSetup(); + // skip if partition is empty. if (!partition.hasNext()) { return Lists.newArrayList(); @@ -160,6 +168,8 @@ class SparkProcessContext<FnInputT, FnOutputT, OutputT> { clearOutput(); calledFinish = true; doFnRunner.finishBundle(); + // teardown DoFn. + DoFnInvokers.invokerFor(doFn).invokeTeardown(); outputIterator = getOutputIterator(); continue; // try to consume outputIterator from start of loop } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java index 964eb37..ac91892 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java @@ -32,7 +32,6 @@ import java.util.Map; import org.apache.avro.mapred.AvroKey; import org.apache.avro.mapreduce.AvroJob; import org.apache.avro.mapreduce.AvroKeyInputFormat; -import org.apache.beam.runners.core.AssignWindowsDoFn; import org.apache.beam.runners.spark.aggregators.NamedAggregators; import org.apache.beam.runners.spark.aggregators.SparkAggregators; import org.apache.beam.runners.spark.coders.CoderHelpers; @@ -54,13 +53,11 @@ import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.Flatten; import org.apache.beam.sdk.transforms.GroupByKey; -import org.apache.beam.sdk.transforms.OldDoFn; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.View; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.Window; -import org.apache.beam.sdk.transforms.windowing.WindowFn; import org.apache.beam.sdk.util.CombineFnUtil; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.util.WindowingStrategy; @@ -235,16 +232,15 @@ public final class TransformTranslator { @SuppressWarnings("unchecked") JavaRDD<WindowedValue<InputT>> inRDD = ((BoundedDataset<InputT>) context.borrowDataset(transform)).getRDD(); - @SuppressWarnings("unchecked") - final WindowFn<Object, ?> windowFn = - (WindowFn<Object, ?>) context.getInput(transform).getWindowingStrategy().getWindowFn(); + WindowingStrategy<?, ?> windowingStrategy = + context.getInput(transform).getWindowingStrategy(); Accumulator<NamedAggregators> accum = SparkAggregators.getNamedAggregators(context.getSparkContext()); Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, BroadcastHelper<?>>> sideInputs = TranslationUtils.getSideInputs(transform.getSideInputs(), context); context.putDataset(transform, - new BoundedDataset<>(inRDD.mapPartitions(new DoFnFunction<>(accum, transform.getFn(), - context.getRuntimeContext(), sideInputs, windowFn)))); + new BoundedDataset<>(inRDD.mapPartitions(new DoFnFunction<>(accum, doFn, + context.getRuntimeContext(), sideInputs, windowingStrategy)))); } }; } @@ -259,16 +255,15 @@ public final class TransformTranslator { @SuppressWarnings("unchecked") JavaRDD<WindowedValue<InputT>> inRDD = 
((BoundedDataset<InputT>) context.borrowDataset(transform)).getRDD(); - @SuppressWarnings("unchecked") - final WindowFn<Object, ?> windowFn = - (WindowFn<Object, ?>) context.getInput(transform).getWindowingStrategy().getWindowFn(); + WindowingStrategy<?, ?> windowingStrategy = + context.getInput(transform).getWindowingStrategy(); Accumulator<NamedAggregators> accum = SparkAggregators.getNamedAggregators(context.getSparkContext()); JavaPairRDD<TupleTag<?>, WindowedValue<?>> all = inRDD .mapPartitionsToPair( - new MultiDoFnFunction<>(accum, transform.getFn(), context.getRuntimeContext(), + new MultiDoFnFunction<>(accum, doFn, context.getRuntimeContext(), transform.getMainOutputTag(), TranslationUtils.getSideInputs( - transform.getSideInputs(), context), windowFn)).cache(); + transform.getSideInputs(), context), windowingStrategy)).cache(); PCollectionTuple pct = context.getOutput(transform); for (Map.Entry<TupleTag<?>, PCollection<?>> e : pct.getAll().entrySet()) { @SuppressWarnings("unchecked") @@ -508,14 +503,8 @@ public final class TransformTranslator { if (TranslationUtils.skipAssignWindows(transform, context)) { context.putDataset(transform, new BoundedDataset<>(inRDD)); } else { - @SuppressWarnings("unchecked") - WindowFn<? super T, W> windowFn = (WindowFn<? super T, W>) transform.getWindowFn(); - OldDoFn<T, T> addWindowsDoFn = new AssignWindowsDoFn<>(windowFn); - Accumulator<NamedAggregators> accum = - SparkAggregators.getNamedAggregators(context.getSparkContext()); - context.putDataset(transform, - new BoundedDataset<>(inRDD.mapPartitions(new DoFnFunction<>(accum, addWindowsDoFn, - context.getRuntimeContext(), null, null)))); + context.putDataset(transform, new BoundedDataset<>( + inRDD.map(new SparkAssignWindowFn<>(transform.getWindowFn())))); } } }; http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java index 00df7d4..27204ed 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java @@ -24,7 +24,6 @@ import com.google.common.collect.Maps; import java.util.ArrayList; import java.util.List; import java.util.Map; -import org.apache.beam.runners.core.AssignWindowsDoFn; import org.apache.beam.runners.spark.aggregators.NamedAggregators; import org.apache.beam.runners.spark.aggregators.SparkAggregators; import org.apache.beam.runners.spark.io.ConsoleIO; @@ -36,6 +35,7 @@ import org.apache.beam.runners.spark.translation.DoFnFunction; import org.apache.beam.runners.spark.translation.EvaluationContext; import org.apache.beam.runners.spark.translation.GroupCombineFunctions; import org.apache.beam.runners.spark.translation.MultiDoFnFunction; +import org.apache.beam.runners.spark.translation.SparkAssignWindowFn; import org.apache.beam.runners.spark.translation.SparkKeyedCombineFn; import org.apache.beam.runners.spark.translation.SparkPipelineTranslator; import org.apache.beam.runners.spark.translation.SparkRuntimeContext; @@ -51,7 +51,6 @@ import 
org.apache.beam.sdk.transforms.CombineWithContext; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.Flatten; import org.apache.beam.sdk.transforms.GroupByKey; -import org.apache.beam.sdk.transforms.OldDoFn; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; @@ -163,7 +162,7 @@ final class StreamingTransformTranslator { private static <T, W extends BoundedWindow> TransformEvaluator<Window.Bound<T>> window() { return new TransformEvaluator<Window.Bound<T>>() { @Override - public void evaluate(Window.Bound<T> transform, EvaluationContext context) { + public void evaluate(final Window.Bound<T> transform, EvaluationContext context) { @SuppressWarnings("unchecked") WindowFn<? super T, W> windowFn = (WindowFn<? super T, W>) transform.getWindowFn(); @SuppressWarnings("unchecked") @@ -189,16 +188,11 @@ final class StreamingTransformTranslator { if (TranslationUtils.skipAssignWindows(transform, context)) { context.putDataset(transform, new UnboundedDataset<>(windowedDStream)); } else { - final OldDoFn<T, T> addWindowsDoFn = new AssignWindowsDoFn<>(windowFn); - final SparkRuntimeContext runtimeContext = context.getRuntimeContext(); JavaDStream<WindowedValue<T>> outStream = windowedDStream.transform( new Function<JavaRDD<WindowedValue<T>>, JavaRDD<WindowedValue<T>>>() { @Override public JavaRDD<WindowedValue<T>> call(JavaRDD<WindowedValue<T>> rdd) throws Exception { - final Accumulator<NamedAggregators> accum = - SparkAggregators.getNamedAggregators(new JavaSparkContext(rdd.context())); - return rdd.mapPartitions( - new DoFnFunction<>(accum, addWindowsDoFn, runtimeContext, null, null)); + return rdd.map(new SparkAssignWindowFn<>(transform.getWindowFn())); } }); context.putDataset(transform, new UnboundedDataset<>(outStream)); @@ -350,13 +344,13 @@ final class StreamingTransformTranslator { @Override public void evaluate(final ParDo.Bound<InputT, OutputT> transform, final EvaluationContext context) { - DoFn<InputT, OutputT> doFn = transform.getNewFn(); + final DoFn<InputT, OutputT> doFn = transform.getNewFn(); rejectStateAndTimers(doFn); final SparkRuntimeContext runtimeContext = context.getRuntimeContext(); final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, BroadcastHelper<?>>> sideInputs = TranslationUtils.getSideInputs(transform.getSideInputs(), context); - final WindowFn<Object, ?> windowFn = - (WindowFn<Object, ?>) context.getInput(transform).getWindowingStrategy().getWindowFn(); + final WindowingStrategy<?, ?> windowingStrategy = + context.getInput(transform).getWindowingStrategy(); JavaDStream<WindowedValue<InputT>> dStream = ((UnboundedDataset<InputT>) context.borrowDataset(transform)).getDStream(); @@ -369,7 +363,7 @@ final class StreamingTransformTranslator { final Accumulator<NamedAggregators> accum = SparkAggregators.getNamedAggregators(new JavaSparkContext(rdd.context())); return rdd.mapPartitions( - new DoFnFunction<>(accum, transform.getFn(), runtimeContext, sideInputs, windowFn)); + new DoFnFunction<>(accum, doFn, runtimeContext, sideInputs, windowingStrategy)); } }); @@ -384,14 +378,13 @@ final class StreamingTransformTranslator { @Override public void evaluate(final ParDo.BoundMulti<InputT, OutputT> transform, final EvaluationContext context) { - DoFn<InputT, OutputT> doFn = transform.getNewFn(); + final DoFn<InputT, OutputT> doFn = transform.getNewFn(); rejectStateAndTimers(doFn); final SparkRuntimeContext runtimeContext = 
context.getRuntimeContext(); final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, BroadcastHelper<?>>> sideInputs = TranslationUtils.getSideInputs(transform.getSideInputs(), context); - @SuppressWarnings("unchecked") - final WindowFn<Object, ?> windowFn = - (WindowFn<Object, ?>) context.getInput(transform).getWindowingStrategy().getWindowFn(); + final WindowingStrategy<?, ?> windowingStrategy = + context.getInput(transform).getWindowingStrategy(); @SuppressWarnings("unchecked") JavaDStream<WindowedValue<InputT>> dStream = ((UnboundedDataset<InputT>) context.borrowDataset(transform)).getDStream(); @@ -403,8 +396,8 @@ final class StreamingTransformTranslator { JavaRDD<WindowedValue<InputT>> rdd) throws Exception { final Accumulator<NamedAggregators> accum = SparkAggregators.getNamedAggregators(new JavaSparkContext(rdd.context())); - return rdd.mapPartitionsToPair(new MultiDoFnFunction<>(accum, transform.getFn(), - runtimeContext, transform.getMainOutputTag(), sideInputs, windowFn)); + return rdd.mapPartitionsToPair(new MultiDoFnFunction<>(accum, doFn, + runtimeContext, transform.getMainOutputTag(), sideInputs, windowingStrategy)); } }).cache(); PCollectionTuple pct = context.getOutput(transform); @@ -423,8 +416,8 @@ final class StreamingTransformTranslator { }; } - private static final Map<Class<? extends PTransform>, TransformEvaluator<?>> EVALUATORS = Maps - .newHashMap(); + private static final Map<Class<? extends PTransform>, TransformEvaluator<?>> EVALUATORS = + Maps.newHashMap(); static { EVALUATORS.put(Read.Unbounded.class, readUnbounded()); http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/PAssertStreaming.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/PAssertStreaming.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/PAssertStreaming.java index 471ec92..0284b3d 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/PAssertStreaming.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/PAssertStreaming.java @@ -27,8 +27,8 @@ import org.apache.beam.runners.spark.SparkPipelineResult; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.transforms.Aggregator; +import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.GroupByKey; -import org.apache.beam.sdk.transforms.OldDoFn; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.Sum; import org.apache.beam.sdk.transforms.Values; @@ -55,11 +55,12 @@ public final class PAssertStreaming implements Serializable { * Note that it is oblivious to windowing, so the assertion will apply indiscriminately to all * windows. */ - public static <T> SparkPipelineResult runAndAssertContents(Pipeline p, - PCollection<T> actual, - T[] expected, - Duration timeout, - boolean stopGracefully) { + public static <T> SparkPipelineResult runAndAssertContents( + Pipeline p, + PCollection<T> actual, + T[] expected, + Duration timeout, + boolean stopGracefully) { // Because PAssert does not support non-global windowing, but all our data is in one window, // we set up the assertion directly. 
actual @@ -86,14 +87,15 @@ public final class PAssertStreaming implements Serializable { * Default to stop gracefully so that tests will finish processing even if slower for reasons * such as a slow runtime environment. */ - public static <T> SparkPipelineResult runAndAssertContents(Pipeline p, - PCollection<T> actual, - T[] expected, - Duration timeout) { + public static <T> SparkPipelineResult runAndAssertContents( + Pipeline p, + PCollection<T> actual, + T[] expected, + Duration timeout) { return runAndAssertContents(p, actual, expected, timeout, true); } - private static class AssertDoFn<T> extends OldDoFn<Iterable<T>, Void> { + private static class AssertDoFn<T> extends DoFn<Iterable<T>, Void> { private final Aggregator<Integer, Integer> success = createAggregator(PAssert.SUCCESS_COUNTER, new Sum.SumIntegerFn()); private final Aggregator<Integer, Integer> failure = @@ -104,7 +106,7 @@ public final class PAssertStreaming implements Serializable { this.expected = expected; } - @Override + @ProcessElement public void processElement(ProcessContext c) throws Exception { try { assertThat(c.element(), containsInAnyOrder(expected));
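----------------------------------------------------------------------

Taken together, SparkProcessContext now drives the complete DoFn lifecycle for each partition. A rough sketch of the resulting call order (illustrative control flow only; the real class consumes output lazily through an iterator rather than in a straight loop):

    DoFnInvoker<InputT, OutputT> invoker = DoFnInvokers.invokerFor(doFn);
    invoker.invokeSetup();                          // @Setup, before any elements
    doFnRunner.startBundle();                       // @StartBundle
    while (partition.hasNext()) {
      doFnRunner.processElement(partition.next());  // @ProcessElement per element
    }
    doFnRunner.finishBundle();                      // @FinishBundle
    invoker.invokeTeardown();                       // @Teardown, after the bundle completes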