Repository: beam Updated Branches: refs/heads/master 9cdae6caf -> 43c44232d
[BEAM-2175] [BEAM-1115] Support for new State and Timer API in Spark batch mode Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/5e5fbed7 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/5e5fbed7 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/5e5fbed7 Branch: refs/heads/master Commit: 5e5fbed70af5d6ff827266d3db89cd5d8d51f544 Parents: 9cdae6c Author: JingsongLi <lzljs3620...@aliyun.com> Authored: Wed May 10 19:49:04 2017 +0800 Committer: Aviem Zur <aviem...@gmail.com> Committed: Sat Jun 3 16:49:59 2017 +0300 ---------------------------------------------------------------------- runners/spark/pom.xml | 2 - .../spark/translation/MultiDoFnFunction.java | 104 +++++++++++++++++-- .../spark/translation/SparkProcessContext.java | 23 +++- .../spark/translation/TransformTranslator.java | 84 ++++++++++++--- .../streaming/StreamingTransformTranslator.java | 3 +- 5 files changed, 189 insertions(+), 27 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/5e5fbed7/runners/spark/pom.xml ---------------------------------------------------------------------- diff --git a/runners/spark/pom.xml b/runners/spark/pom.xml index 697f67a..ddb4aca 100644 --- a/runners/spark/pom.xml +++ b/runners/spark/pom.xml @@ -77,8 +77,6 @@ org.apache.beam.runners.spark.UsesCheckpointRecovery </groups> <excludedGroups> - org.apache.beam.sdk.testing.UsesStatefulParDo, - org.apache.beam.sdk.testing.UsesTimersInParDo, org.apache.beam.sdk.testing.UsesSplittableParDo, org.apache.beam.sdk.testing.UsesCommittedMetrics, org.apache.beam.sdk.testing.UsesTestStream http://git-wip-us.apache.org/repos/asf/beam/blob/5e5fbed7/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java index 3274912..23d5b32 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java @@ -22,16 +22,24 @@ import com.google.common.base.Function; import com.google.common.collect.Iterators; import com.google.common.collect.LinkedListMultimap; import com.google.common.collect.Multimap; +import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import org.apache.beam.runners.core.DoFnRunner; import org.apache.beam.runners.core.DoFnRunners; +import org.apache.beam.runners.core.InMemoryStateInternals; +import org.apache.beam.runners.core.InMemoryTimerInternals; +import org.apache.beam.runners.core.StateInternals; +import org.apache.beam.runners.core.StepContext; +import org.apache.beam.runners.core.TimerInternals; import org.apache.beam.runners.core.metrics.MetricsContainerStepMap; import org.apache.beam.runners.spark.aggregators.NamedAggregators; import org.apache.beam.runners.spark.util.SideInputBroadcast; import org.apache.beam.runners.spark.util.SparkSideInputReader; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.TupleTag; @@ -60,6 +68,7 @@ public class MultiDoFnFunction<InputT, OutputT> private final List<TupleTag<?>> additionalOutputTags; private final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs; private final WindowingStrategy<?, ?> windowingStrategy; + private final boolean stateful; /** * @param aggAccum The Spark {@link Accumulator} that backs the Beam Aggregators. @@ -70,6 +79,7 @@ public class MultiDoFnFunction<InputT, OutputT> * @param additionalOutputTags Additional {@link TupleTag output tags}. * @param sideInputs Side inputs used in this {@link DoFn}. * @param windowingStrategy Input {@link WindowingStrategy}. + * @param stateful Stateful {@link DoFn}. */ public MultiDoFnFunction( Accumulator<NamedAggregators> aggAccum, @@ -80,7 +90,8 @@ public class MultiDoFnFunction<InputT, OutputT> TupleTag<OutputT> mainOutputTag, List<TupleTag<?>> additionalOutputTags, Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs, - WindowingStrategy<?, ?> windowingStrategy) { + WindowingStrategy<?, ?> windowingStrategy, + boolean stateful) { this.aggAccum = aggAccum; this.metricsAccum = metricsAccum; this.stepName = stepName; @@ -90,6 +101,7 @@ public class MultiDoFnFunction<InputT, OutputT> this.additionalOutputTags = additionalOutputTags; this.sideInputs = sideInputs; this.windowingStrategy = windowingStrategy; + this.stateful = stateful; } @Override @@ -98,7 +110,35 @@ public class MultiDoFnFunction<InputT, OutputT> DoFnOutputManager outputManager = new DoFnOutputManager(); - DoFnRunner<InputT, OutputT> doFnRunner = + final InMemoryTimerInternals timerInternals; + final StepContext context; + // Now only implements the StatefulParDo in Batch mode. + if (stateful) { + Object key = null; + if (iter.hasNext()) { + WindowedValue<InputT> currentValue = iter.next(); + key = ((KV) currentValue.getValue()).getKey(); + iter = Iterators.concat(Iterators.singletonIterator(currentValue), iter); + } + final InMemoryStateInternals<?> stateInternals = InMemoryStateInternals.forKey(key); + timerInternals = new InMemoryTimerInternals(); + context = new StepContext(){ + @Override + public StateInternals stateInternals() { + return stateInternals; + } + + @Override + public TimerInternals timerInternals() { + return timerInternals; + } + }; + } else { + timerInternals = null; + context = new SparkProcessContext.NoOpStepContext(); + } + + final DoFnRunner<InputT, OutputT> doFnRunner = DoFnRunners.simpleRunner( runtimeContext.getPipelineOptions(), doFn, @@ -106,20 +146,72 @@ public class MultiDoFnFunction<InputT, OutputT> outputManager, mainOutputTag, additionalOutputTags, - new SparkProcessContext.NoOpStepContext(), + context, windowingStrategy); DoFnRunnerWithMetrics<InputT, OutputT> doFnRunnerWithMetrics = new DoFnRunnerWithMetrics<>(stepName, doFnRunner, metricsAccum); - return new SparkProcessContext<>(doFn, doFnRunnerWithMetrics, outputManager) - .processPartition(iter); + return new SparkProcessContext<>( + doFn, doFnRunnerWithMetrics, outputManager, + stateful ? new TimerDataIterator(timerInternals) : + Collections.<TimerInternals.TimerData>emptyIterator()).processPartition(iter); + } + + private static class TimerDataIterator implements Iterator<TimerInternals.TimerData> { + + private InMemoryTimerInternals timerInternals; + private boolean hasAdvance; + private TimerInternals.TimerData timerData; + + TimerDataIterator(InMemoryTimerInternals timerInternals) { + this.timerInternals = timerInternals; + } + + @Override + public boolean hasNext() { + + // Advance + if (!hasAdvance) { + try { + // Finish any pending windows by advancing the input watermark to infinity. + timerInternals.advanceInputWatermark(BoundedWindow.TIMESTAMP_MAX_VALUE); + // Finally, advance the processing time to infinity to fire any timers. + timerInternals.advanceProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE); + timerInternals.advanceSynchronizedProcessingTime( + BoundedWindow.TIMESTAMP_MAX_VALUE); + } catch (Exception e) { + throw new RuntimeException(e); + } + hasAdvance = true; + } + + // Get timer data + return (timerData = timerInternals.removeNextEventTimer()) != null + || (timerData = timerInternals.removeNextProcessingTimer()) != null + || (timerData = timerInternals.removeNextSynchronizedProcessingTimer()) != null; + } + + @Override + public TimerInternals.TimerData next() { + if (timerData == null) { + throw new NoSuchElementException(); + } else { + return timerData; + } + } + + @Override + public void remove() { + throw new RuntimeException("TimerDataIterator not support remove!"); + } + } private class DoFnOutputManager implements SparkProcessContext.SparkOutputManager<Tuple2<TupleTag<?>, WindowedValue<?>>> { - private final Multimap<TupleTag<?>, WindowedValue<?>> outputs = LinkedListMultimap.create();; + private final Multimap<TupleTag<?>, WindowedValue<?>> outputs = LinkedListMultimap.create(); @Override public void clear() { http://git-wip-us.apache.org/repos/asf/beam/blob/5e5fbed7/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java index f4ab7d9..729eb1c 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java @@ -18,16 +18,21 @@ package org.apache.beam.runners.spark.translation; +import static com.google.common.base.Preconditions.checkArgument; + import com.google.common.collect.AbstractIterator; import com.google.common.collect.Lists; import java.util.Iterator; import org.apache.beam.runners.core.DoFnRunner; import org.apache.beam.runners.core.DoFnRunners.OutputManager; import org.apache.beam.runners.core.StateInternals; +import org.apache.beam.runners.core.StateNamespace; +import org.apache.beam.runners.core.StateNamespaces; import org.apache.beam.runners.core.StepContext; import org.apache.beam.runners.core.TimerInternals; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.WindowedValue; @@ -39,15 +44,18 @@ class SparkProcessContext<FnInputT, FnOutputT, OutputT> { private final DoFn<FnInputT, FnOutputT> doFn; private final DoFnRunner<FnInputT, FnOutputT> doFnRunner; private final SparkOutputManager<OutputT> outputManager; + private Iterator<TimerInternals.TimerData> timerDataIterator; SparkProcessContext( DoFn<FnInputT, FnOutputT> doFn, DoFnRunner<FnInputT, FnOutputT> doFnRunner, - SparkOutputManager<OutputT> outputManager) { + SparkOutputManager<OutputT> outputManager, + Iterator<TimerInternals.TimerData> timerDataIterator) { this.doFn = doFn; this.doFnRunner = doFnRunner; this.outputManager = outputManager; + this.timerDataIterator = timerDataIterator; } Iterable<OutputT> processPartition( @@ -137,6 +145,10 @@ class SparkProcessContext<FnInputT, FnOutputT, OutputT> { // grab the next element and process it. doFnRunner.processElement(inputIterator.next()); outputIterator = getOutputIterator(); + } else if (timerDataIterator.hasNext()) { + clearOutput(); + fireTimer(timerDataIterator.next()); + outputIterator = getOutputIterator(); } else { // no more input to consume, but finishBundle can produce more output if (!calledFinish) { @@ -152,5 +164,14 @@ class SparkProcessContext<FnInputT, FnOutputT, OutputT> { } } } + + private void fireTimer( + TimerInternals.TimerData timer) { + StateNamespace namespace = timer.getNamespace(); + checkArgument(namespace instanceof StateNamespaces.WindowNamespace); + BoundedWindow window = ((StateNamespaces.WindowNamespace) namespace).getWindow(); + doFnRunner.onTimer(timer.getTimerId(), window, timer.getTimestamp(), timer.getDomain()); + } + } } http://git-wip-us.apache.org/repos/asf/beam/blob/5e5fbed7/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java index 742ea83..64aa35a 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java @@ -21,13 +21,14 @@ package org.apache.beam.runners.spark.translation; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectSplittable; -import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectStateAndTimers; import com.google.common.base.Optional; +import com.google.common.collect.FluentIterable; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import java.util.Collection; import java.util.Collections; +import java.util.Iterator; import java.util.Map; import org.apache.beam.runners.core.SystemReduceFn; import org.apache.beam.runners.core.metrics.MetricsContainerStepMap; @@ -52,6 +53,8 @@ import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.Reshuffle; import org.apache.beam.sdk.transforms.View; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature; +import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.transforms.windowing.WindowFn; @@ -347,41 +350,57 @@ public final class TransformTranslator { private static <InputT, OutputT> TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>> parDo() { return new TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>>() { @Override + @SuppressWarnings("unchecked") public void evaluate( ParDo.MultiOutput<InputT, OutputT> transform, EvaluationContext context) { String stepName = context.getCurrentTransform().getFullName(); DoFn<InputT, OutputT> doFn = transform.getFn(); rejectSplittable(doFn); - rejectStateAndTimers(doFn); - @SuppressWarnings("unchecked") JavaRDD<WindowedValue<InputT>> inRDD = ((BoundedDataset<InputT>) context.borrowDataset(transform)).getRDD(); WindowingStrategy<?, ?> windowingStrategy = context.getInput(transform).getWindowingStrategy(); Accumulator<NamedAggregators> aggAccum = AggregatorsAccumulator.getInstance(); Accumulator<MetricsContainerStepMap> metricsAccum = MetricsAccumulator.getInstance(); - JavaPairRDD<TupleTag<?>, WindowedValue<?>> all = - inRDD.mapPartitionsToPair( - new MultiDoFnFunction<>( - aggAccum, - metricsAccum, - stepName, - doFn, - context.getRuntimeContext(), - transform.getMainOutputTag(), - transform.getAdditionalOutputTags().getAll(), - TranslationUtils.getSideInputs(transform.getSideInputs(), context), - windowingStrategy)); + + JavaPairRDD<TupleTag<?>, WindowedValue<?>> all; + + DoFnSignature signature = DoFnSignatures.getSignature(transform.getFn().getClass()); + boolean stateful = signature.stateDeclarations().size() > 0 + || signature.timerDeclarations().size() > 0; + + MultiDoFnFunction<InputT, OutputT> multiDoFnFunction = new MultiDoFnFunction<>( + aggAccum, + metricsAccum, + stepName, + doFn, + context.getRuntimeContext(), + transform.getMainOutputTag(), + transform.getAdditionalOutputTags().getAll(), + TranslationUtils.getSideInputs(transform.getSideInputs(), context), + windowingStrategy, + stateful); + + if (stateful) { + // Based on the fact that the signature is stateful, DoFnSignatures ensures + // that it is also keyed + all = statefulParDoTransform( + (KvCoder) context.getInput(transform).getCoder(), + windowingStrategy.getWindowFn().windowCoder(), + (JavaRDD) inRDD, + (MultiDoFnFunction) multiDoFnFunction); + } else { + all = inRDD.mapPartitionsToPair(multiDoFnFunction); + } + Map<TupleTag<?>, PValue> outputs = context.getOutputs(transform); if (outputs.size() > 1) { // cache the RDD if we're going to filter it more than once. all.cache(); } for (Map.Entry<TupleTag<?>, PValue> output : outputs.entrySet()) { - @SuppressWarnings("unchecked") JavaPairRDD<TupleTag<?>, WindowedValue<?>> filtered = all.filter(new TranslationUtils.TupleTagFilter(output.getKey())); - @SuppressWarnings("unchecked") // Object is the best we can do since different outputs can have different tags JavaRDD<WindowedValue<Object>> values = (JavaRDD<WindowedValue<Object>>) (JavaRDD<?>) filtered.values(); @@ -396,6 +415,37 @@ public final class TransformTranslator { }; } + private static <K, V, OutputT> JavaPairRDD<TupleTag<?>, WindowedValue<?>> statefulParDoTransform( + KvCoder<K, V> kvCoder, + Coder<? extends BoundedWindow> windowCoder, + JavaRDD<WindowedValue<KV<K, V>>> kvInRDD, + MultiDoFnFunction<KV<K, V>, OutputT> doFnFunction) { + Coder<K> keyCoder = kvCoder.getKeyCoder(); + + final WindowedValue.WindowedValueCoder<V> wvCoder = WindowedValue.FullWindowedValueCoder.of( + kvCoder.getValueCoder(), windowCoder); + + JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>> groupRDD = + GroupCombineFunctions.groupByKeyOnly(kvInRDD, keyCoder, wvCoder); + + return groupRDD.map(new Function< + WindowedValue<KV<K, Iterable<WindowedValue<V>>>>, Iterator<WindowedValue<KV<K, V>>>>() { + @Override + public Iterator<WindowedValue<KV<K, V>>> call( + WindowedValue<KV<K, Iterable<WindowedValue<V>>>> input) throws Exception { + final K key = input.getValue().getKey(); + Iterable<WindowedValue<V>> value = input.getValue().getValue(); + return FluentIterable.from(value).transform( + new com.google.common.base.Function<WindowedValue<V>, WindowedValue<KV<K, V>>>() { + @Override + public WindowedValue<KV<K, V>> apply(WindowedValue<V> windowedValue) { + return windowedValue.withValue(KV.of(key, windowedValue.getValue())); + } + }).iterator(); + } + }).flatMapToPair(doFnFunction); + } + private static <T> TransformEvaluator<Read.Bounded<T>> readBounded() { return new TransformEvaluator<Read.Bounded<T>>() { @Override http://git-wip-us.apache.org/repos/asf/beam/blob/5e5fbed7/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java index 43f4b75..cd5bb3e 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java @@ -413,7 +413,8 @@ public final class StreamingTransformTranslator { transform.getMainOutputTag(), transform.getAdditionalOutputTags().getAll(), sideInputs, - windowingStrategy)); + windowingStrategy, + false)); } }); Map<TupleTag<?>, PValue> outputs = context.getOutputs(transform);