[BEAM-807] Replace OldDoFn with DoFn.

Add a custom AssignWindows implementation.

Set up and tear down DoFn.

Add implementation for GroupAlsoByWindow via flatMap.
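
For reference, the shape of the migration (a sketch distilled from the WordCount change below; method bodies elided):

    // Before: OldDoFn, with processElement as a plain override.
    static class ExtractWordsFn extends OldDoFn<String, String> {
      @Override
      public void processElement(ProcessContext c) { /* ... */ }
    }

    // After: DoFn, with the processing method marked @ProcessElement.
    // @Setup/@Teardown methods, when declared, are now invoked by the runner.
    static class ExtractWordsFn extends DoFn<String, String> {
      @ProcessElement
      public void processElement(ProcessContext c) { /* ... */ }
    }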


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/4ffed3e0
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/4ffed3e0
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/4ffed3e0

Branch: refs/heads/master
Commit: 4ffed3e09a2f0ec3583098f6cfd53a2ddcc6f8c2
Parents: 2be9a15
Author: Sela <ans...@paypal.com>
Authored: Sun Dec 11 14:32:49 2016 +0200
Committer: Sela <ans...@paypal.com>
Committed: Tue Dec 13 10:05:18 2016 +0200

----------------------------------------------------------------------
 .../beam/runners/spark/examples/WordCount.java  |   6 +-
 .../runners/spark/translation/DoFnFunction.java |   2 +-
 .../translation/GroupCombineFunctions.java      |  23 +-
 .../spark/translation/MultiDoFnFunction.java    |   2 +-
 .../spark/translation/SparkAssignWindowFn.java  |  69 ++++++
 .../translation/SparkGroupAlsoByWindowFn.java   | 214 +++++++++++++++++++
 .../spark/translation/SparkProcessContext.java  |  10 +
 .../spark/translation/TransformTranslator.java  |  31 +--
 .../streaming/StreamingTransformTranslator.java |  35 ++-
 .../streaming/utils/PAssertStreaming.java       |  26 +--
 10 files changed, 345 insertions(+), 73 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/examples/WordCount.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/examples/WordCount.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/examples/WordCount.java
index b2672b5..1252d12 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/examples/WordCount.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/examples/WordCount.java
@@ -25,8 +25,8 @@ import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.transforms.Aggregator;
 import org.apache.beam.sdk.transforms.Count;
+import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.MapElements;
-import org.apache.beam.sdk.transforms.OldDoFn;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.SimpleFunction;
@@ -44,11 +44,11 @@ public class WordCount {
   * of-line. This DoFn tokenizes lines of text into individual words; we pass it to a ParDo in the
    * pipeline.
    */
-  static class ExtractWordsFn extends OldDoFn<String, String> {
+  static class ExtractWordsFn extends DoFn<String, String> {
     private final Aggregator<Long, Long> emptyLines =
         createAggregator("emptyLines", new Sum.SumLongFn());
 
-    @Override
+    @ProcessElement
     public void processElement(ProcessContext c) {
       if (c.element().trim().isEmpty()) {
         emptyLines.addValue(1L);

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
index 4c49a7f..6a641b5 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
@@ -93,7 +93,7 @@ public class DoFnFunction<InputT, OutputT>
             windowingStrategy
         );
 
-    return new SparkProcessContext<>(doFnRunner, outputManager).processPartition(iter);
+    return new SparkProcessContext<>(doFn, doFnRunner, outputManager).processPartition(iter);
   }
 
   private class DoFnOutputManager

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java
index 421b1b0..4875b0c 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java
@@ -18,11 +18,9 @@
 
 package org.apache.beam.runners.spark.translation;
 
-
 import com.google.common.collect.Lists;
 import java.util.Collections;
 import java.util.Map;
-import org.apache.beam.runners.core.GroupAlsoByWindowsViaOutputBufferDoFn;
 import org.apache.beam.runners.core.SystemReduceFn;
 import org.apache.beam.runners.spark.aggregators.NamedAggregators;
 import org.apache.beam.runners.spark.coders.CoderHelpers;
@@ -33,9 +31,7 @@ import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.IterableCoder;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.transforms.CombineWithContext;
-import org.apache.beam.sdk.transforms.OldDoFn;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
-import org.apache.beam.sdk.transforms.windowing.WindowFn;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.util.WindowingStrategy;
 import org.apache.beam.sdk.values.KV;
@@ -59,7 +55,7 @@ public class GroupCombineFunctions {
   /**
    * Apply {@link org.apache.beam.sdk.transforms.GroupByKey} to a Spark RDD.
    */
-  public static <K, V,  W extends BoundedWindow> JavaRDD<WindowedValue<KV<K,
+  public static <K, V, W extends BoundedWindow> JavaRDD<WindowedValue<KV<K,
       Iterable<V>>>> groupByKey(JavaRDD<WindowedValue<KV<K, V>>> rdd,
                                 Accumulator<NamedAggregators> accum,
                                 KvCoder<K, V> coder,
@@ -86,15 +82,14 @@ public class GroupCombineFunctions {
             .map(WindowingHelpers.<KV<K, Iterable<WindowedValue<V>>>>windowFunction());

     //--- now group also by window.
-    @SuppressWarnings("unchecked")
-    WindowFn<Object, W> windowFn = (WindowFn<Object, W>) windowingStrategy.getWindowFn();
-    // GroupAlsoByWindow current uses a dummy in-memory StateInternals
-    OldDoFn<KV<K, Iterable<WindowedValue<V>>>, KV<K, Iterable<V>>> gabwDoFn =
-        new GroupAlsoByWindowsViaOutputBufferDoFn<K, V, Iterable<V>, W>(
-            windowingStrategy, new TranslationUtils.InMemoryStateInternalsFactory<K>(),
-                SystemReduceFn.<K, V, W>buffering(valueCoder));
-    return groupedByKey.mapPartitions(new DoFnFunction<>(accum, gabwDoFn, runtimeContext, null,
-        windowFn));
+    // GroupAlsoByWindow currently uses a dummy in-memory StateInternals
+    return groupedByKey.flatMap(
+        new SparkGroupAlsoByWindowFn<>(
+            windowingStrategy,
+            new TranslationUtils.InMemoryStateInternalsFactory<K>(),
+            SystemReduceFn.<K, V, W>buffering(valueCoder),
+            runtimeContext,
+            accum));
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
index 710c5cd..8a55369 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java
@@ -102,7 +102,7 @@ public class MultiDoFnFunction<InputT, OutputT>
             windowingStrategy
         );
 
-    return new SparkProcessContext<>(doFnRunner, outputManager).processPartition(iter);
+    return new SparkProcessContext<>(doFn, doFnRunner, outputManager).processPartition(iter);
   }
 
   private class DoFnOutputManager

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkAssignWindowFn.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkAssignWindowFn.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkAssignWindowFn.java
new file mode 100644
index 0000000..9d7ed7d
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkAssignWindowFn.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.spark.translation;
+
+import com.google.common.collect.Iterables;
+import java.util.Collection;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.transforms.windowing.WindowFn;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.spark.api.java.function.Function;
+import org.joda.time.Instant;
+
+
+/**
+ * An implementation of {@link org.apache.beam.runners.core.AssignWindows} for the Spark runner.
+ */
+public class SparkAssignWindowFn<T, W extends BoundedWindow>
+    implements Function<WindowedValue<T>, WindowedValue<T>> {
+
+  private WindowFn<? super T, W> fn;
+
+  public SparkAssignWindowFn(WindowFn<? super T, W> fn) {
+    this.fn = fn;
+  }
+
+  @Override
+  @SuppressWarnings("unchecked")
+  public WindowedValue<T> call(WindowedValue<T> windowedValue) throws Exception {
+    final BoundedWindow boundedWindow = Iterables.getOnlyElement(windowedValue.getWindows());
+    final T element = windowedValue.getValue();
+    final Instant timestamp = windowedValue.getTimestamp();
+    Collection<W> windows =
+        ((WindowFn<T, W>) fn).assignWindows(
+            ((WindowFn<T, W>) fn).new AssignContext() {
+                @Override
+                public T element() {
+                  return element;
+                }
+
+                @Override
+                public Instant timestamp() {
+                  return timestamp;
+                }
+
+                @Override
+                public BoundedWindow window() {
+                  return boundedWindow;
+                }
+              });
+    return WindowedValue.of(element, timestamp, windows, PaneInfo.NO_FIRING);
+  }
+}
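
A minimal usage sketch for the new function (illustrative assumptions: a FixedWindows window fn and an existing JavaRDD<WindowedValue<String>> named rdd, neither part of this commit):

    // Assign every element of the RDD to a fixed one-minute window.
    WindowFn<Object, IntervalWindow> windowFn = FixedWindows.of(Duration.standardMinutes(1));
    JavaRDD<WindowedValue<String>> windowed =
        rdd.map(new SparkAssignWindowFn<String, IntervalWindow>(windowFn));

This mirrors the wiring in TransformTranslator and StreamingTransformTranslator further down.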

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkGroupAlsoByWindowFn.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkGroupAlsoByWindowFn.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkGroupAlsoByWindowFn.java
new file mode 100644
index 0000000..87d3f50
--- /dev/null
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkGroupAlsoByWindowFn.java
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.spark.translation;
+
+import com.google.common.collect.Iterables;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import org.apache.beam.runners.core.GroupAlsoByWindowsDoFn;
+import org.apache.beam.runners.core.OutputWindowedValue;
+import org.apache.beam.runners.core.ReduceFnRunner;
+import org.apache.beam.runners.core.SystemReduceFn;
+import org.apache.beam.runners.core.triggers.ExecutableTriggerStateMachine;
+import org.apache.beam.runners.core.triggers.TriggerStateMachines;
+import org.apache.beam.runners.spark.aggregators.NamedAggregators;
+import org.apache.beam.sdk.transforms.Aggregator;
+import org.apache.beam.sdk.transforms.Sum;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.util.SideInputReader;
+import org.apache.beam.sdk.util.TimerInternals;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.WindowingStrategy;
+import org.apache.beam.sdk.util.state.InMemoryTimerInternals;
+import org.apache.beam.sdk.util.state.StateInternals;
+import org.apache.beam.sdk.util.state.StateInternalsFactory;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.spark.Accumulator;
+import org.apache.spark.api.java.function.FlatMapFunction;
+import org.joda.time.Instant;
+
+
+
+/**
+ * An implementation of {@link org.apache.beam.runners.core.GroupAlsoByWindowsViaOutputBufferDoFn}
+ * for the Spark runner.
+ */
+public class SparkGroupAlsoByWindowFn<K, InputT, W extends BoundedWindow>
+    implements FlatMapFunction<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>,
+        WindowedValue<KV<K, Iterable<InputT>>>> {
+
+  private final WindowingStrategy<?, W> windowingStrategy;
+  private final StateInternalsFactory<K> stateInternalsFactory;
+  private final SystemReduceFn<K, InputT, Iterable<InputT>, Iterable<InputT>, W> reduceFn;
+  private final SparkRuntimeContext runtimeContext;
+  private final Aggregator<Long, Long> droppedDueToClosedWindow;
+
+
+  public SparkGroupAlsoByWindowFn(
+      WindowingStrategy<?, W> windowingStrategy,
+      StateInternalsFactory<K> stateInternalsFactory,
+      SystemReduceFn<K, InputT, Iterable<InputT>, Iterable<InputT>, W> reduceFn,
+      SparkRuntimeContext runtimeContext,
+      Accumulator<NamedAggregators> accumulator) {
+    this.windowingStrategy = windowingStrategy;
+    this.stateInternalsFactory = stateInternalsFactory;
+    this.reduceFn = reduceFn;
+    this.runtimeContext = runtimeContext;
+
+    droppedDueToClosedWindow = runtimeContext.createAggregator(
+        accumulator,
+        GroupAlsoByWindowsDoFn.DROPPED_DUE_TO_CLOSED_WINDOW_COUNTER,
+        new Sum.SumLongFn());
+  }
+
+  @Override
+  public Iterable<WindowedValue<KV<K, Iterable<InputT>>>> call(
+      WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>> windowedValue) throws Exception {
+    K key = windowedValue.getValue().getKey();
+    Iterable<WindowedValue<InputT>> inputs = windowedValue.getValue().getValue();
+
+    //------ based on GroupAlsoByWindowsViaOutputBufferDoFn ------//
+
+    // Used with Batch, we know that all the data is available for this key. We can't use the
+    // timer manager from the context because it doesn't exist. So we create one and emulate the
+    // watermark, knowing that we have all data and it is in timestamp order.
+    InMemoryTimerInternals timerInternals = new InMemoryTimerInternals();
+    timerInternals.advanceProcessingTime(Instant.now());
+    timerInternals.advanceSynchronizedProcessingTime(Instant.now());
+    StateInternals<K> stateInternals = stateInternalsFactory.stateInternalsForKey(key);
+    GABWOutputWindowedValue<K, InputT> outputter = new GABWOutputWindowedValue<>();
+
+    ReduceFnRunner<K, InputT, Iterable<InputT>, W> reduceFnRunner =
+        new ReduceFnRunner<>(
+            key,
+            windowingStrategy,
+            ExecutableTriggerStateMachine.create(
+                TriggerStateMachines.stateMachineForTrigger(windowingStrategy.getTrigger())),
+            stateInternals,
+            timerInternals,
+            outputter,
+            new SideInputReader() {
+                @Override
+                public <T> T get(PCollectionView<T> view, BoundedWindow sideInputWindow) {
+                  throw new UnsupportedOperationException(
+                      "GroupAlsoByWindow must not have side inputs");
+                }
+
+                @Override
+                public <T> boolean contains(PCollectionView<T> view) {
+                  throw new UnsupportedOperationException(
+                      "GroupAlsoByWindow must not have side inputs");
+                }
+
+                @Override
+                public boolean isEmpty() {
+                  throw new UnsupportedOperationException(
+                      "GroupAlsoByWindow must not have side inputs");
+                }
+              },
+            droppedDueToClosedWindow,
+            reduceFn,
+            runtimeContext.getPipelineOptions());
+
+    Iterable<List<WindowedValue<InputT>>> chunks = Iterables.partition(inputs, 1000);
+    for (Iterable<WindowedValue<InputT>> chunk : chunks) {
+      // Process the chunk of elements.
+      reduceFnRunner.processElements(chunk);
+
+      // Then, since elements are sorted by their timestamp, advance the input watermark
+      // to the first element.
+      timerInternals.advanceInputWatermark(chunk.iterator().next().getTimestamp());
+      // Advance the processing times.
+      timerInternals.advanceProcessingTime(Instant.now());
+      timerInternals.advanceSynchronizedProcessingTime(Instant.now());
+
+      // Fire all the eligible timers.
+      fireEligibleTimers(timerInternals, reduceFnRunner);
+
+      // Leave the output watermark undefined. Since there's no late data in batch mode
+      // there's really no need to track it as we do for streaming.
+    }
+
+    // Finish any pending windows by advancing the input watermark to infinity.
+    timerInternals.advanceInputWatermark(BoundedWindow.TIMESTAMP_MAX_VALUE);
+
+    // Finally, advance the processing time to infinity to fire any timers.
+    timerInternals.advanceProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE);
+    timerInternals.advanceSynchronizedProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE);
+
+    fireEligibleTimers(timerInternals, reduceFnRunner);
+
+    reduceFnRunner.persist();
+
+    return outputter.getOutputs();
+  }
+
+  private void fireEligibleTimers(InMemoryTimerInternals timerInternals,
+      ReduceFnRunner<K, InputT, Iterable<InputT>, W> reduceFnRunner) throws Exception {
+    List<TimerInternals.TimerData> timers = new ArrayList<>();
+    while (true) {
+        TimerInternals.TimerData timer;
+        while ((timer = timerInternals.removeNextEventTimer()) != null) {
+          timers.add(timer);
+        }
+        while ((timer = timerInternals.removeNextProcessingTimer()) != null) {
+          timers.add(timer);
+        }
+        while ((timer = timerInternals.removeNextSynchronizedProcessingTimer()) != null) {
+          timers.add(timer);
+        }
+        if (timers.isEmpty()) {
+          break;
+        }
+        reduceFnRunner.onTimers(timers);
+        timers.clear();
+    }
+  }
+
+  private static class GABWOutputWindowedValue<K, V>
+      implements OutputWindowedValue<KV<K, Iterable<V>>> {
+    private final List<WindowedValue<KV<K, Iterable<V>>>> outputs = new ArrayList<>();
+
+    @Override
+    public void outputWindowedValue(
+        KV<K, Iterable<V>> output,
+        Instant timestamp,
+        Collection<? extends BoundedWindow> windows,
+        PaneInfo pane) {
+      outputs.add(WindowedValue.of(output, timestamp, windows, pane));
+    }
+
+    @Override
+    public <SideOutputT> void sideOutputWindowedValue(
+        TupleTag<SideOutputT> tag,
+        SideOutputT output,
+        Instant timestamp,
+        Collection<? extends BoundedWindow> windows, PaneInfo pane) {
+      throw new UnsupportedOperationException("GroupAlsoByWindow should not use side outputs.");
+    }
+
+    Iterable<WindowedValue<KV<K, Iterable<V>>>> getOutputs() {
+      return outputs;
+    }
+  }
+}
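
Condensed, the per-key flow above emulates a streaming watermark over batch data that is fully available (a restatement of the code above, not additional API):

    // Per 1000-element chunk: buffer it, advance the emulated input watermark
    // to the chunk's first (earliest) timestamp, then fire eligible timers.
    for (Iterable<WindowedValue<InputT>> chunk : Iterables.partition(inputs, 1000)) {
      reduceFnRunner.processElements(chunk);
      timerInternals.advanceInputWatermark(chunk.iterator().next().getTimestamp());
      fireEligibleTimers(timerInternals, reduceFnRunner);
    }
    // Close all remaining windows by pushing the watermark to infinity.
    timerInternals.advanceInputWatermark(BoundedWindow.TIMESTAMP_MAX_VALUE);
    fireEligibleTimers(timerInternals, reduceFnRunner);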

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
index efd8202..3a31cae 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkProcessContext.java
@@ -25,6 +25,8 @@ import java.util.Iterator;
 import org.apache.beam.runners.core.DoFnRunner;
 import org.apache.beam.runners.core.DoFnRunners.OutputManager;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.util.ExecutionContext.StepContext;
 import org.apache.beam.sdk.util.TimerInternals;
@@ -38,13 +40,16 @@ import org.apache.beam.sdk.values.TupleTag;
  */
 class SparkProcessContext<FnInputT, FnOutputT, OutputT> {
 
+  private final DoFn<FnInputT, FnOutputT> doFn;
   private final DoFnRunner<FnInputT, FnOutputT> doFnRunner;
   private final SparkOutputManager<OutputT> outputManager;
 
   SparkProcessContext(
+      DoFn<FnInputT, FnOutputT> doFn,
       DoFnRunner<FnInputT, FnOutputT> doFnRunner,
       SparkOutputManager<OutputT> outputManager) {
 
+    this.doFn = doFn;
     this.doFnRunner = doFnRunner;
     this.outputManager = outputManager;
   }
@@ -52,6 +57,9 @@ class SparkProcessContext<FnInputT, FnOutputT, OutputT> {
   Iterable<OutputT> processPartition(
       Iterator<WindowedValue<FnInputT>> partition) throws Exception {
 
+    // setup DoFn.
+    DoFnInvokers.invokerFor(doFn).invokeSetup();
+
     // skip if partition is empty.
     if (!partition.hasNext()) {
       return Lists.newArrayList();
@@ -160,6 +168,8 @@ class SparkProcessContext<FnInputT, FnOutputT, OutputT> {
             clearOutput();
             calledFinish = true;
             doFnRunner.finishBundle();
+            // teardown DoFn.
+            DoFnInvokers.invokerFor(doFn).invokeTeardown();
             outputIterator = getOutputIterator();
             continue; // try to consume outputIterator from start of loop
           }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
index 964eb37..ac91892 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
@@ -32,7 +32,6 @@ import java.util.Map;
 import org.apache.avro.mapred.AvroKey;
 import org.apache.avro.mapreduce.AvroJob;
 import org.apache.avro.mapreduce.AvroKeyInputFormat;
-import org.apache.beam.runners.core.AssignWindowsDoFn;
 import org.apache.beam.runners.spark.aggregators.NamedAggregators;
 import org.apache.beam.runners.spark.aggregators.SparkAggregators;
 import org.apache.beam.runners.spark.coders.CoderHelpers;
@@ -54,13 +53,11 @@ import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.Flatten;
 import org.apache.beam.sdk.transforms.GroupByKey;
-import org.apache.beam.sdk.transforms.OldDoFn;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.View;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.Window;
-import org.apache.beam.sdk.transforms.windowing.WindowFn;
 import org.apache.beam.sdk.util.CombineFnUtil;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.util.WindowingStrategy;
@@ -235,16 +232,15 @@ public final class TransformTranslator {
         @SuppressWarnings("unchecked")
        JavaRDD<WindowedValue<InputT>> inRDD =
            ((BoundedDataset<InputT>) context.borrowDataset(transform)).getRDD();
-        @SuppressWarnings("unchecked")
-        final WindowFn<Object, ?> windowFn =
-            (WindowFn<Object, ?>) context.getInput(transform).getWindowingStrategy().getWindowFn();
+        WindowingStrategy<?, ?> windowingStrategy =
+            context.getInput(transform).getWindowingStrategy();
         Accumulator<NamedAggregators> accum =
             SparkAggregators.getNamedAggregators(context.getSparkContext());
         Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, BroadcastHelper<?>>> sideInputs =
             TranslationUtils.getSideInputs(transform.getSideInputs(), context);
         context.putDataset(transform,
-            new BoundedDataset<>(inRDD.mapPartitions(new DoFnFunction<>(accum, transform.getFn(),
-                context.getRuntimeContext(), sideInputs, windowFn))));
+            new BoundedDataset<>(inRDD.mapPartitions(new DoFnFunction<>(accum, doFn,
+                context.getRuntimeContext(), sideInputs, windowingStrategy))));
       }
     };
   }
@@ -259,16 +255,15 @@ public final class TransformTranslator {
         @SuppressWarnings("unchecked")
        JavaRDD<WindowedValue<InputT>> inRDD =
            ((BoundedDataset<InputT>) context.borrowDataset(transform)).getRDD();
-        @SuppressWarnings("unchecked")
-        final WindowFn<Object, ?> windowFn =
-            (WindowFn<Object, ?>) context.getInput(transform).getWindowingStrategy().getWindowFn();
+        WindowingStrategy<?, ?> windowingStrategy =
+            context.getInput(transform).getWindowingStrategy();
         Accumulator<NamedAggregators> accum =
             SparkAggregators.getNamedAggregators(context.getSparkContext());
         JavaPairRDD<TupleTag<?>, WindowedValue<?>> all = inRDD
             .mapPartitionsToPair(
-                new MultiDoFnFunction<>(accum, transform.getFn(), context.getRuntimeContext(),
+                new MultiDoFnFunction<>(accum, doFn, context.getRuntimeContext(),
                 transform.getMainOutputTag(), TranslationUtils.getSideInputs(
-                    transform.getSideInputs(), context), windowFn)).cache();
+                    transform.getSideInputs(), context), windowingStrategy)).cache();
         PCollectionTuple pct = context.getOutput(transform);
         for (Map.Entry<TupleTag<?>, PCollection<?>> e : pct.getAll().entrySet()) {
           @SuppressWarnings("unchecked")
@@ -508,14 +503,8 @@ public final class TransformTranslator {
         if (TranslationUtils.skipAssignWindows(transform, context)) {
           context.putDataset(transform, new BoundedDataset<>(inRDD));
         } else {
-          @SuppressWarnings("unchecked")
-          WindowFn<? super T, W> windowFn = (WindowFn<? super T, W>) transform.getWindowFn();
-          OldDoFn<T, T> addWindowsDoFn = new AssignWindowsDoFn<>(windowFn);
-          Accumulator<NamedAggregators> accum =
-              SparkAggregators.getNamedAggregators(context.getSparkContext());
-          context.putDataset(transform,
-              new BoundedDataset<>(inRDD.mapPartitions(new DoFnFunction<>(accum, addWindowsDoFn,
-                  context.getRuntimeContext(), null, null))));
+          context.putDataset(transform, new BoundedDataset<>(
+              inRDD.map(new SparkAssignWindowFn<>(transform.getWindowFn()))));
         }
       }
     };

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
index 00df7d4..27204ed 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
@@ -24,7 +24,6 @@ import com.google.common.collect.Maps;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
-import org.apache.beam.runners.core.AssignWindowsDoFn;
 import org.apache.beam.runners.spark.aggregators.NamedAggregators;
 import org.apache.beam.runners.spark.aggregators.SparkAggregators;
 import org.apache.beam.runners.spark.io.ConsoleIO;
@@ -36,6 +35,7 @@ import org.apache.beam.runners.spark.translation.DoFnFunction;
 import org.apache.beam.runners.spark.translation.EvaluationContext;
 import org.apache.beam.runners.spark.translation.GroupCombineFunctions;
 import org.apache.beam.runners.spark.translation.MultiDoFnFunction;
+import org.apache.beam.runners.spark.translation.SparkAssignWindowFn;
 import org.apache.beam.runners.spark.translation.SparkKeyedCombineFn;
 import org.apache.beam.runners.spark.translation.SparkPipelineTranslator;
 import org.apache.beam.runners.spark.translation.SparkRuntimeContext;
@@ -51,7 +51,6 @@ import org.apache.beam.sdk.transforms.CombineWithContext;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.Flatten;
 import org.apache.beam.sdk.transforms.GroupByKey;
-import org.apache.beam.sdk.transforms.OldDoFn;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
@@ -163,7 +162,7 @@ final class StreamingTransformTranslator {
  private static <T, W extends BoundedWindow> TransformEvaluator<Window.Bound<T>> window() {
    return new TransformEvaluator<Window.Bound<T>>() {
      @Override
-      public void evaluate(Window.Bound<T> transform, EvaluationContext context) {
+      public void evaluate(final Window.Bound<T> transform, EvaluationContext context) {
         @SuppressWarnings("unchecked")
         WindowFn<? super T, W> windowFn = (WindowFn<? super T, W>) transform.getWindowFn();
         @SuppressWarnings("unchecked")
@@ -189,16 +188,11 @@ final class StreamingTransformTranslator {
         if (TranslationUtils.skipAssignWindows(transform, context)) {
          context.putDataset(transform, new UnboundedDataset<>(windowedDStream));
         } else {
-          final OldDoFn<T, T> addWindowsDoFn = new AssignWindowsDoFn<>(windowFn);
-          final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
           JavaDStream<WindowedValue<T>> outStream = windowedDStream.transform(
               new Function<JavaRDD<WindowedValue<T>>, JavaRDD<WindowedValue<T>>>() {
             @Override
             public JavaRDD<WindowedValue<T>> call(JavaRDD<WindowedValue<T>> rdd) throws Exception {
-              final Accumulator<NamedAggregators> accum =
-                  SparkAggregators.getNamedAggregators(new JavaSparkContext(rdd.context()));
-              return rdd.mapPartitions(
-                new DoFnFunction<>(accum, addWindowsDoFn, runtimeContext, null, null));
+              return rdd.map(new SparkAssignWindowFn<>(transform.getWindowFn()));
             }
           });
           context.putDataset(transform, new UnboundedDataset<>(outStream));
@@ -350,13 +344,13 @@ final class StreamingTransformTranslator {
       @Override
       public void evaluate(final ParDo.Bound<InputT, OutputT> transform,
                            final EvaluationContext context) {
-        DoFn<InputT, OutputT> doFn = transform.getNewFn();
+        final DoFn<InputT, OutputT> doFn = transform.getNewFn();
         rejectStateAndTimers(doFn);
         final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
        final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, BroadcastHelper<?>>> sideInputs =
             TranslationUtils.getSideInputs(transform.getSideInputs(), context);
-        final WindowFn<Object, ?> windowFn =
-            (WindowFn<Object, ?>) context.getInput(transform).getWindowingStrategy().getWindowFn();
+        final WindowingStrategy<?, ?> windowingStrategy =
+            context.getInput(transform).getWindowingStrategy();
         JavaDStream<WindowedValue<InputT>> dStream =
             ((UnboundedDataset<InputT>) context.borrowDataset(transform)).getDStream();
 
@@ -369,7 +363,7 @@ final class StreamingTransformTranslator {
             final Accumulator<NamedAggregators> accum =
                SparkAggregators.getNamedAggregators(new JavaSparkContext(rdd.context()));
             return rdd.mapPartitions(
-                new DoFnFunction<>(accum, transform.getFn(), runtimeContext, sideInputs, windowFn));
+                new DoFnFunction<>(accum, doFn, runtimeContext, sideInputs, windowingStrategy));
           }
         });
 
@@ -384,14 +378,13 @@ final class StreamingTransformTranslator {
       @Override
       public void evaluate(final ParDo.BoundMulti<InputT, OutputT> transform,
                            final EvaluationContext context) {
-        DoFn<InputT, OutputT> doFn = transform.getNewFn();
+        final DoFn<InputT, OutputT> doFn = transform.getNewFn();
         rejectStateAndTimers(doFn);
         final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
        final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, BroadcastHelper<?>>> sideInputs =
             TranslationUtils.getSideInputs(transform.getSideInputs(), context);
-        @SuppressWarnings("unchecked")
-        final WindowFn<Object, ?> windowFn =
-            (WindowFn<Object, ?>) context.getInput(transform).getWindowingStrategy().getWindowFn();
+        final WindowingStrategy<?, ?> windowingStrategy =
+            context.getInput(transform).getWindowingStrategy();
         @SuppressWarnings("unchecked")
         JavaDStream<WindowedValue<InputT>> dStream =
             ((UnboundedDataset<InputT>) context.borrowDataset(transform)).getDStream();
@@ -403,8 +396,8 @@ final class StreamingTransformTranslator {
               JavaRDD<WindowedValue<InputT>> rdd) throws Exception {
             final Accumulator<NamedAggregators> accum =
                SparkAggregators.getNamedAggregators(new JavaSparkContext(rdd.context()));
-            return rdd.mapPartitionsToPair(new MultiDoFnFunction<>(accum, transform.getFn(),
-                runtimeContext, transform.getMainOutputTag(), sideInputs, windowFn));
+            return rdd.mapPartitionsToPair(new MultiDoFnFunction<>(accum, doFn,
+                runtimeContext, transform.getMainOutputTag(), sideInputs, windowingStrategy));
           }
         }).cache();
         PCollectionTuple pct = context.getOutput(transform);
@@ -423,8 +416,8 @@ final class StreamingTransformTranslator {
     };
   }
 
-  private static final Map<Class<? extends PTransform>, TransformEvaluator<?>> EVALUATORS = Maps
-      .newHashMap();
+  private static final Map<Class<? extends PTransform>, TransformEvaluator<?>> EVALUATORS =
+      Maps.newHashMap();
 
   static {
     EVALUATORS.put(Read.Unbounded.class, readUnbounded());

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ffed3e0/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/PAssertStreaming.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/PAssertStreaming.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/PAssertStreaming.java
index 471ec92..0284b3d 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/PAssertStreaming.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/PAssertStreaming.java
@@ -27,8 +27,8 @@ import org.apache.beam.runners.spark.SparkPipelineResult;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.transforms.Aggregator;
+import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.GroupByKey;
-import org.apache.beam.sdk.transforms.OldDoFn;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.Sum;
 import org.apache.beam.sdk.transforms.Values;
@@ -55,11 +55,12 @@ public final class PAssertStreaming implements Serializable {
    * Note that it is oblivious to windowing, so the assertion will apply indiscriminately to all
    * windows.
    */
-  public static <T> SparkPipelineResult runAndAssertContents(Pipeline p,
-                                                          PCollection<T> actual,
-                                                          T[] expected,
-                                                          Duration timeout,
-                                                          boolean stopGracefully) {
+  public static <T> SparkPipelineResult runAndAssertContents(
+      Pipeline p,
+      PCollection<T> actual,
+      T[] expected,
+      Duration timeout,
+      boolean stopGracefully) {
    // Because PAssert does not support non-global windowing, but all our data is in one window,
     // we set up the assertion directly.
     actual
@@ -86,14 +87,15 @@ public final class PAssertStreaming implements Serializable {
    * Default to stop gracefully so that tests will finish processing even if slower for reasons
    * such as a slow runtime environment.
    */
-  public static <T> SparkPipelineResult runAndAssertContents(Pipeline p,
-                                                          PCollection<T> actual,
-                                                          T[] expected,
-                                                          Duration timeout) {
+  public static <T> SparkPipelineResult runAndAssertContents(
+      Pipeline p,
+      PCollection<T> actual,
+      T[] expected,
+      Duration timeout) {
     return runAndAssertContents(p, actual, expected, timeout, true);
   }
 
-  private static class AssertDoFn<T> extends OldDoFn<Iterable<T>, Void> {
+  private static class AssertDoFn<T> extends DoFn<Iterable<T>, Void> {
     private final Aggregator<Integer, Integer> success =
         createAggregator(PAssert.SUCCESS_COUNTER, new Sum.SumIntegerFn());
     private final Aggregator<Integer, Integer> failure =
@@ -104,7 +106,7 @@ public final class PAssertStreaming implements Serializable {
       this.expected = expected;
     }
 
-    @Override
+    @ProcessElement
     public void processElement(ProcessContext c) throws Exception {
       try {
         assertThat(c.element(), containsInAnyOrder(expected));

