Repository: beam
Updated Branches:
  refs/heads/master 9cdae6caf -> 43c44232d


[BEAM-2175] [BEAM-1115] Support for new State and Timer API in Spark batch mode


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/5e5fbed7
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/5e5fbed7
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/5e5fbed7

Branch: refs/heads/master
Commit: 5e5fbed70af5d6ff827266d3db89cd5d8d51f544
Parents: 9cdae6c
Author: JingsongLi <lzljs3620...@aliyun.com>
Authored: Wed May 10 19:49:04 2017 +0800
Committer: Aviem Zur <aviem...@gmail.com>
Committed: Sat Jun 3 16:49:59 2017 +0300

----------------------------------------------------------------------
 runners/spark/pom.xml                           |   2 -
 .../spark/translation/MultiDoFnFunction.java    | 104 +++++++++++++++++--
 .../spark/translation/SparkProcessContext.java  |  23 +++-
 .../spark/translation/TransformTranslator.java  |  84 ++++++++++++---
 .../streaming/StreamingTransformTranslator.java |   3 +-
 5 files changed, 189 insertions(+), 27 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/5e5fbed7/runners/spark/pom.xml
----------------------------------------------------------------------
diff --git a/runners/spark/pom.xml b/runners/spark/pom.xml
index 697f67a..ddb4aca 100644
--- a/runners/spark/pom.xml
+++ b/runners/spark/pom.xml
@@ -77,8 +77,6 @@
                     org.apache.beam.runners.spark.UsesCheckpointRecovery
                   </groups>
                   <excludedGroups>
-                    org.apache.beam.sdk.testing.UsesStatefulParDo,
-                    org.apache.beam.sdk.testing.UsesTimersInParDo,
                     org.apache.beam.sdk.testing.UsesSplittableParDo,
                     org.apache.beam.sdk.testing.UsesCommittedMetrics,
                     org.apache.beam.sdk.testing.UsesTestStream

http://git-wip-us.apache.org/repos/asf/beam/blob/5e5fbed7/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
----------------------------------------------------------------------
diff --git 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
index 3274912..23d5b32 100644
--- 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
+++ 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
@@ -22,16 +22,24 @@ import com.google.common.base.Function;
 import com.google.common.collect.Iterators;
 import com.google.common.collect.LinkedListMultimap;
 import com.google.common.collect.Multimap;
+import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.NoSuchElementException;
 import org.apache.beam.runners.core.DoFnRunner;
 import org.apache.beam.runners.core.DoFnRunners;
+import org.apache.beam.runners.core.InMemoryStateInternals;
+import org.apache.beam.runners.core.InMemoryTimerInternals;
+import org.apache.beam.runners.core.StateInternals;
+import org.apache.beam.runners.core.StepContext;
+import org.apache.beam.runners.core.TimerInternals;
 import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
 import org.apache.beam.runners.spark.aggregators.NamedAggregators;
 import org.apache.beam.runners.spark.util.SideInputBroadcast;
 import org.apache.beam.runners.spark.util.SparkSideInputReader;
 import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.TupleTag;
@@ -60,6 +68,7 @@ public class MultiDoFnFunction<InputT, OutputT>
   private final List<TupleTag<?>> additionalOutputTags;
   private final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, 
SideInputBroadcast<?>>> sideInputs;
   private final WindowingStrategy<?, ?> windowingStrategy;
+  private final boolean stateful;
 
   /**
    * @param aggAccum       The Spark {@link Accumulator} that backs the Beam 
Aggregators.
@@ -70,6 +79,7 @@ public class MultiDoFnFunction<InputT, OutputT>
    * @param additionalOutputTags Additional {@link TupleTag output tags}.
    * @param sideInputs        Side inputs used in this {@link DoFn}.
    * @param windowingStrategy Input {@link WindowingStrategy}.
+   * @param stateful          Stateful {@link DoFn}.
    */
   public MultiDoFnFunction(
       Accumulator<NamedAggregators> aggAccum,
@@ -80,7 +90,8 @@ public class MultiDoFnFunction<InputT, OutputT>
       TupleTag<OutputT> mainOutputTag,
       List<TupleTag<?>> additionalOutputTags,
       Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> 
sideInputs,
-      WindowingStrategy<?, ?> windowingStrategy) {
+      WindowingStrategy<?, ?> windowingStrategy,
+      boolean stateful) {
     this.aggAccum = aggAccum;
     this.metricsAccum = metricsAccum;
     this.stepName = stepName;
@@ -90,6 +101,7 @@ public class MultiDoFnFunction<InputT, OutputT>
     this.additionalOutputTags = additionalOutputTags;
     this.sideInputs = sideInputs;
     this.windowingStrategy = windowingStrategy;
+    this.stateful = stateful;
   }
 
   @Override
@@ -98,7 +110,35 @@ public class MultiDoFnFunction<InputT, OutputT>
 
     DoFnOutputManager outputManager = new DoFnOutputManager();
 
-    DoFnRunner<InputT, OutputT> doFnRunner =
+    final InMemoryTimerInternals timerInternals;
+    final StepContext context;
+    // Now only implements the StatefulParDo in Batch mode.
+    if (stateful) {
+      Object key = null;
+      if (iter.hasNext()) {
+        WindowedValue<InputT> currentValue = iter.next();
+        key = ((KV) currentValue.getValue()).getKey();
+        iter = Iterators.concat(Iterators.singletonIterator(currentValue), 
iter);
+      }
+      final InMemoryStateInternals<?> stateInternals = 
InMemoryStateInternals.forKey(key);
+      timerInternals = new InMemoryTimerInternals();
+      context = new StepContext(){
+        @Override
+        public StateInternals stateInternals() {
+          return stateInternals;
+        }
+
+        @Override
+        public TimerInternals timerInternals() {
+          return timerInternals;
+        }
+      };
+    } else {
+      timerInternals = null;
+      context = new SparkProcessContext.NoOpStepContext();
+    }
+
+    final DoFnRunner<InputT, OutputT> doFnRunner =
         DoFnRunners.simpleRunner(
             runtimeContext.getPipelineOptions(),
             doFn,
@@ -106,20 +146,72 @@ public class MultiDoFnFunction<InputT, OutputT>
             outputManager,
             mainOutputTag,
             additionalOutputTags,
-            new SparkProcessContext.NoOpStepContext(),
+            context,
             windowingStrategy);
 
     DoFnRunnerWithMetrics<InputT, OutputT> doFnRunnerWithMetrics =
         new DoFnRunnerWithMetrics<>(stepName, doFnRunner, metricsAccum);
 
-    return new SparkProcessContext<>(doFn, doFnRunnerWithMetrics, 
outputManager)
-        .processPartition(iter);
+    return new SparkProcessContext<>(
+        doFn, doFnRunnerWithMetrics, outputManager,
+        stateful ? new TimerDataIterator(timerInternals) :
+            
Collections.<TimerInternals.TimerData>emptyIterator()).processPartition(iter);
+  }
+
+  private static class TimerDataIterator implements 
Iterator<TimerInternals.TimerData> {
+
+    private InMemoryTimerInternals timerInternals;
+    private boolean hasAdvance;
+    private TimerInternals.TimerData timerData;
+
+    TimerDataIterator(InMemoryTimerInternals timerInternals) {
+      this.timerInternals = timerInternals;
+    }
+
+    @Override
+    public boolean hasNext() {
+
+      // Advance
+      if (!hasAdvance) {
+        try {
+          // Finish any pending windows by advancing the input watermark to 
infinity.
+          
timerInternals.advanceInputWatermark(BoundedWindow.TIMESTAMP_MAX_VALUE);
+          // Finally, advance the processing time to infinity to fire any 
timers.
+          
timerInternals.advanceProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE);
+          timerInternals.advanceSynchronizedProcessingTime(
+              BoundedWindow.TIMESTAMP_MAX_VALUE);
+        } catch (Exception e) {
+          throw new RuntimeException(e);
+        }
+        hasAdvance = true;
+      }
+
+      // Get timer data
+      return (timerData = timerInternals.removeNextEventTimer()) != null
+          || (timerData = timerInternals.removeNextProcessingTimer()) != null
+          || (timerData = 
timerInternals.removeNextSynchronizedProcessingTimer()) != null;
+    }
+
+    @Override
+    public TimerInternals.TimerData next() {
+      if (timerData == null) {
+        throw new NoSuchElementException();
+      } else {
+        return timerData;
+      }
+    }
+
+    @Override
+    public void remove() {
+      throw new RuntimeException("TimerDataIterator not support remove!");
+    }
+
   }
 
   private class DoFnOutputManager
       implements SparkProcessContext.SparkOutputManager<Tuple2<TupleTag<?>, 
WindowedValue<?>>> {
 
-    private final Multimap<TupleTag<?>, WindowedValue<?>> outputs = 
LinkedListMultimap.create();;
+    private final Multimap<TupleTag<?>, WindowedValue<?>> outputs = 
LinkedListMultimap.create();
 
     @Override
     public void clear() {

http://git-wip-us.apache.org/repos/asf/beam/blob/5e5fbed7/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
----------------------------------------------------------------------
diff --git 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
index f4ab7d9..729eb1c 100644
--- 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
+++ 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
@@ -18,16 +18,21 @@
 
 package org.apache.beam.runners.spark.translation;
 
+import static com.google.common.base.Preconditions.checkArgument;
+
 import com.google.common.collect.AbstractIterator;
 import com.google.common.collect.Lists;
 import java.util.Iterator;
 import org.apache.beam.runners.core.DoFnRunner;
 import org.apache.beam.runners.core.DoFnRunners.OutputManager;
 import org.apache.beam.runners.core.StateInternals;
+import org.apache.beam.runners.core.StateNamespace;
+import org.apache.beam.runners.core.StateNamespaces;
 import org.apache.beam.runners.core.StepContext;
 import org.apache.beam.runners.core.TimerInternals;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.util.WindowedValue;
 
 
@@ -39,15 +44,18 @@ class SparkProcessContext<FnInputT, FnOutputT, OutputT> {
   private final DoFn<FnInputT, FnOutputT> doFn;
   private final DoFnRunner<FnInputT, FnOutputT> doFnRunner;
   private final SparkOutputManager<OutputT> outputManager;
+  private Iterator<TimerInternals.TimerData> timerDataIterator;
 
   SparkProcessContext(
       DoFn<FnInputT, FnOutputT> doFn,
       DoFnRunner<FnInputT, FnOutputT> doFnRunner,
-      SparkOutputManager<OutputT> outputManager) {
+      SparkOutputManager<OutputT> outputManager,
+      Iterator<TimerInternals.TimerData> timerDataIterator) {
 
     this.doFn = doFn;
     this.doFnRunner = doFnRunner;
     this.outputManager = outputManager;
+    this.timerDataIterator = timerDataIterator;
   }
 
   Iterable<OutputT> processPartition(
@@ -137,6 +145,10 @@ class SparkProcessContext<FnInputT, FnOutputT, OutputT> {
           // grab the next element and process it.
           doFnRunner.processElement(inputIterator.next());
           outputIterator = getOutputIterator();
+        } else if (timerDataIterator.hasNext()) {
+          clearOutput();
+          fireTimer(timerDataIterator.next());
+          outputIterator = getOutputIterator();
         } else {
           // no more input to consume, but finishBundle can produce more output
           if (!calledFinish) {
@@ -152,5 +164,14 @@ class SparkProcessContext<FnInputT, FnOutputT, OutputT> {
         }
       }
     }
+
+    /**
+     * Fires one timer by calling {@code onTimer} on the {@link DoFnRunner} with the window
+     * taken from the timer's namespace.
+     *
+     * @param timer the timer to fire; its namespace must be a
+     *     {@link StateNamespaces.WindowNamespace}
+     * @throws IllegalArgumentException if the timer's namespace is not a window namespace
+     */
+    private void fireTimer(TimerInternals.TimerData timer) {
+      StateNamespace namespace = timer.getNamespace();
+      checkArgument(
+          namespace instanceof StateNamespaces.WindowNamespace,
+          "Expected a timer in a window namespace, but got namespace %s",
+          namespace);
+      BoundedWindow window = ((StateNamespaces.WindowNamespace) namespace).getWindow();
+      doFnRunner.onTimer(timer.getTimerId(), window, timer.getTimestamp(), timer.getDomain());
+    }
+
   }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/5e5fbed7/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
----------------------------------------------------------------------
diff --git 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
index 742ea83..64aa35a 100644
--- 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
+++ 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
@@ -21,13 +21,14 @@ package org.apache.beam.runners.spark.translation;
 import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.base.Preconditions.checkState;
 import static 
org.apache.beam.runners.spark.translation.TranslationUtils.rejectSplittable;
-import static 
org.apache.beam.runners.spark.translation.TranslationUtils.rejectStateAndTimers;
 
 import com.google.common.base.Optional;
+import com.google.common.collect.FluentIterable;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.Iterator;
 import java.util.Map;
 import org.apache.beam.runners.core.SystemReduceFn;
 import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
@@ -52,6 +53,8 @@ import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.Reshuffle;
 import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
+import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.Window;
 import org.apache.beam.sdk.transforms.windowing.WindowFn;
@@ -347,41 +350,57 @@ public final class TransformTranslator {
   private static <InputT, OutputT> 
TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>> parDo() {
     return new TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>>() {
       @Override
+      @SuppressWarnings("unchecked")
       public void evaluate(
           ParDo.MultiOutput<InputT, OutputT> transform, EvaluationContext 
context) {
         String stepName = context.getCurrentTransform().getFullName();
         DoFn<InputT, OutputT> doFn = transform.getFn();
         rejectSplittable(doFn);
-        rejectStateAndTimers(doFn);
-        @SuppressWarnings("unchecked")
         JavaRDD<WindowedValue<InputT>> inRDD =
             ((BoundedDataset<InputT>) 
context.borrowDataset(transform)).getRDD();
         WindowingStrategy<?, ?> windowingStrategy =
             context.getInput(transform).getWindowingStrategy();
         Accumulator<NamedAggregators> aggAccum = 
AggregatorsAccumulator.getInstance();
         Accumulator<MetricsContainerStepMap> metricsAccum = 
MetricsAccumulator.getInstance();
-        JavaPairRDD<TupleTag<?>, WindowedValue<?>> all =
-            inRDD.mapPartitionsToPair(
-                new MultiDoFnFunction<>(
-                    aggAccum,
-                    metricsAccum,
-                    stepName,
-                    doFn,
-                    context.getRuntimeContext(),
-                    transform.getMainOutputTag(),
-                    transform.getAdditionalOutputTags().getAll(),
-                    TranslationUtils.getSideInputs(transform.getSideInputs(), 
context),
-                    windowingStrategy));
+
+        JavaPairRDD<TupleTag<?>, WindowedValue<?>> all;
+
+        DoFnSignature signature = 
DoFnSignatures.getSignature(transform.getFn().getClass());
+        boolean stateful = signature.stateDeclarations().size() > 0
+            || signature.timerDeclarations().size() > 0;
+
+        MultiDoFnFunction<InputT, OutputT> multiDoFnFunction = new 
MultiDoFnFunction<>(
+            aggAccum,
+            metricsAccum,
+            stepName,
+            doFn,
+            context.getRuntimeContext(),
+            transform.getMainOutputTag(),
+            transform.getAdditionalOutputTags().getAll(),
+            TranslationUtils.getSideInputs(transform.getSideInputs(), context),
+            windowingStrategy,
+            stateful);
+
+        if (stateful) {
+          // Based on the fact that the signature is stateful, DoFnSignatures 
ensures
+          // that it is also keyed
+          all = statefulParDoTransform(
+              (KvCoder) context.getInput(transform).getCoder(),
+              windowingStrategy.getWindowFn().windowCoder(),
+              (JavaRDD) inRDD,
+              (MultiDoFnFunction) multiDoFnFunction);
+        } else {
+          all = inRDD.mapPartitionsToPair(multiDoFnFunction);
+        }
+
         Map<TupleTag<?>, PValue> outputs = context.getOutputs(transform);
         if (outputs.size() > 1) {
           // cache the RDD if we're going to filter it more than once.
           all.cache();
         }
         for (Map.Entry<TupleTag<?>, PValue> output : outputs.entrySet()) {
-          @SuppressWarnings("unchecked")
           JavaPairRDD<TupleTag<?>, WindowedValue<?>> filtered =
               all.filter(new TranslationUtils.TupleTagFilter(output.getKey()));
-          @SuppressWarnings("unchecked")
           // Object is the best we can do since different outputs can have 
different tags
           JavaRDD<WindowedValue<Object>> values =
               (JavaRDD<WindowedValue<Object>>) (JavaRDD<?>) filtered.values();
@@ -396,6 +415,37 @@ public final class TransformTranslator {
     };
   }
 
+  /**
+   * Evaluates a stateful {@link ParDo} over a bounded, keyed input.
+   *
+   * <p>The input is grouped with {@link GroupCombineFunctions#groupByKeyOnly}. Each grouped
+   * record is then expanded back into an iterator of {@code KV<K, V>} elements that all carry
+   * that record's key. {@code flatMapToPair} hands each such iterator to {@code doFnFunction}
+   * as a separate input.
+   *
+   * @param kvCoder coder of the keyed input; supplies the key and value coders
+   * @param windowCoder coder of the input's windows
+   * @param kvInRDD the keyed input elements
+   * @param doFnFunction the function that runs the stateful {@link DoFn}
+   * @return the {@link DoFn}'s outputs, tagged with their output {@link TupleTag}
+   */
+  private static <K, V, OutputT>
+      JavaPairRDD<TupleTag<?>, WindowedValue<?>> statefulParDoTransform(
+          KvCoder<K, V> kvCoder,
+          Coder<? extends BoundedWindow> windowCoder,
+          JavaRDD<WindowedValue<KV<K, V>>> kvInRDD,
+          MultiDoFnFunction<KV<K, V>, OutputT> doFnFunction) {
+    WindowedValue.WindowedValueCoder<V> valueWithWindowCoder =
+        WindowedValue.FullWindowedValueCoder.of(kvCoder.getValueCoder(), windowCoder);
+
+    JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>> groupedByKey =
+        GroupCombineFunctions.groupByKeyOnly(
+            kvInRDD, kvCoder.getKeyCoder(), valueWithWindowCoder);
+
+    JavaRDD<Iterator<WindowedValue<KV<K, V>>>> perKeyInputs =
+        groupedByKey.map(
+            new Function<
+                WindowedValue<KV<K, Iterable<WindowedValue<V>>>>,
+                Iterator<WindowedValue<KV<K, V>>>>() {
+              @Override
+              public Iterator<WindowedValue<KV<K, V>>> call(
+                  WindowedValue<KV<K, Iterable<WindowedValue<V>>>> group) throws Exception {
+                final K groupKey = group.getValue().getKey();
+                Iterable<WindowedValue<V>> groupValues = group.getValue().getValue();
+                // Re-attach the group's key to each value so the DoFn sees KV<K, V> again.
+                return FluentIterable.from(groupValues)
+                    .transform(
+                        new com.google.common.base.Function<
+                            WindowedValue<V>, WindowedValue<KV<K, V>>>() {
+                          @Override
+                          public WindowedValue<KV<K, V>> apply(WindowedValue<V> element) {
+                            return element.withValue(KV.of(groupKey, element.getValue()));
+                          }
+                        })
+                    .iterator();
+              }
+            });
+
+    return perKeyInputs.flatMapToPair(doFnFunction);
+  }
+
   private static <T> TransformEvaluator<Read.Bounded<T>> readBounded() {
     return new TransformEvaluator<Read.Bounded<T>>() {
       @Override

http://git-wip-us.apache.org/repos/asf/beam/blob/5e5fbed7/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
----------------------------------------------------------------------
diff --git 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
index 43f4b75..cd5bb3e 100644
--- 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
+++ 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
@@ -413,7 +413,8 @@ public final class StreamingTransformTranslator {
                             transform.getMainOutputTag(),
                             transform.getAdditionalOutputTags().getAll(),
                             sideInputs,
-                            windowingStrategy));
+                            windowingStrategy,
+                            false));
                   }
                 });
         Map<TupleTag<?>, PValue> outputs = context.getOutputs(transform);

Reply via email to