Add support for Stateful ParDo in the Direct runner

This adds overrides and new evaluators to ensure that
state is accessed in a single-threaded manner per key
and is cleaned up when a window expires.


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/ec2c0e06
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/ec2c0e06
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/ec2c0e06

Branch: refs/heads/master
Commit: ec2c0e0698c1380b309a609eb642aba445c77e27
Parents: 7e158e4
Author: Kenneth Knowles <k...@google.com>
Authored: Wed Nov 9 21:59:15 2016 -0800
Committer: Kenneth Knowles <k...@google.com>
Committed: Mon Nov 28 11:48:32 2016 -0800

----------------------------------------------------------------------
 .../beam/runners/direct/EvaluationContext.java  |  15 +
 .../beam/runners/direct/ParDoEvaluator.java     |  11 +-
 .../runners/direct/ParDoEvaluatorFactory.java   |  53 +++-
 .../direct/ParDoMultiOverrideFactory.java       |  76 ++++-
 .../ParDoSingleViaMultiOverrideFactory.java     |   6 +-
 .../direct/StatefulParDoEvaluatorFactory.java   | 256 ++++++++++++++++
 .../direct/TransformEvaluatorRegistry.java      |   2 +
 .../direct/WatermarkCallbackExecutor.java       |  34 +++
 .../StatefulParDoEvaluatorFactoryTest.java      | 300 +++++++++++++++++++
 .../org/apache/beam/sdk/transforms/DoFn.java    |   4 +-
 .../org/apache/beam/sdk/transforms/OldDoFn.java |   8 +-
 11 files changed, 741 insertions(+), 24 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/ec2c0e06/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java
----------------------------------------------------------------------
diff --git 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java
 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java
index c1225f6..201aaed 100644
--- 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java
+++ 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EvaluationContext.java
@@ -296,6 +296,21 @@ class EvaluationContext {
     fireAvailableCallbacks(lookupProducing(value));
   }
 
+  /**
+   * Schedule a callback to be executed after the given window is expired.
+   *
+   * <p>For example, upstream state associated with the window may be cleared.
+   */
+  public void scheduleAfterWindowExpiration(
+      AppliedPTransform<?, ?, ?> producing,
+      BoundedWindow window,
+      WindowingStrategy<?, ?> windowingStrategy,
+      Runnable runnable) {
+    callbackExecutor.callOnWindowExpiration(producing, window, 
windowingStrategy, runnable);
+
+    fireAvailableCallbacks(producing);
+  }
+
   private AppliedPTransform<?, ?, ?> getProducing(PValue value) {
     if (value.getProducingTransformInternal() != null) {
       return value.getProducingTransformInternal();

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/ec2c0e06/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluator.java
----------------------------------------------------------------------
diff --git 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluator.java
 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluator.java
index 3285c7e..750e5f1 100644
--- 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluator.java
+++ 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluator.java
@@ -42,6 +42,7 @@ import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.TupleTag;
 
 class ParDoEvaluator<InputT, OutputT> implements TransformEvaluator<InputT> {
+
   public static <InputT, OutputT> ParDoEvaluator<InputT, OutputT> create(
       EvaluationContext evaluationContext,
       DirectStepContext stepContext,
@@ -84,11 +85,17 @@ class ParDoEvaluator<InputT, OutputT> implements 
TransformEvaluator<InputT> {
     }
 
     return new ParDoEvaluator<>(
-        runner, application, aggregatorChanges, outputBundles.values(), 
stepContext);
+        evaluationContext,
+        runner,
+        application,
+        aggregatorChanges,
+        outputBundles.values(),
+        stepContext);
   }
 
   
////////////////////////////////////////////////////////////////////////////////////////////////
 
+  private final EvaluationContext evaluationContext;
   private final PushbackSideInputDoFnRunner<InputT, ?> fnRunner;
   private final AppliedPTransform<?, ?, ?> transform;
   private final AggregatorContainer.Mutator aggregatorChanges;
@@ -98,11 +105,13 @@ class ParDoEvaluator<InputT, OutputT> implements 
TransformEvaluator<InputT> {
   private final ImmutableList.Builder<WindowedValue<InputT>> 
unprocessedElements;
 
   private ParDoEvaluator(
+      EvaluationContext evaluationContext,
       PushbackSideInputDoFnRunner<InputT, ?> fnRunner,
       AppliedPTransform<?, ?, ?> transform,
       AggregatorContainer.Mutator aggregatorChanges,
       Collection<UncommittedBundle<?>> outputBundles,
       DirectStepContext stepContext) {
+    this.evaluationContext = evaluationContext;
     this.fnRunner = fnRunner;
     this.transform = transform;
     this.outputBundles = outputBundles;

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/ec2c0e06/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java
----------------------------------------------------------------------
diff --git 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java
 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java
index b776da1..02e034a 100644
--- 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java
+++ 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java
@@ -20,14 +20,16 @@ package org.apache.beam.runners.direct;
 import com.google.common.cache.CacheBuilder;
 import com.google.common.cache.CacheLoader;
 import com.google.common.cache.LoadingCache;
+import java.util.List;
 import org.apache.beam.runners.direct.DirectExecutionContext.DirectStepContext;
 import org.apache.beam.runners.direct.DirectRunner.CommittedBundle;
 import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.ParDo;
-import org.apache.beam.sdk.transforms.ParDo.BoundMulti;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionTuple;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.TupleTag;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -54,10 +56,26 @@ final class ParDoEvaluatorFactory<InputT, OutputT> 
implements TransformEvaluator
   @Override
   public <T> TransformEvaluator<T> forApplication(
       AppliedPTransform<?, ?, ?> application, CommittedBundle<?> inputBundle) 
throws Exception {
+
+    AppliedPTransform<PCollection<InputT>, PCollectionTuple, 
ParDo.BoundMulti<InputT, OutputT>>
+        parDoApplication =
+            (AppliedPTransform<
+                    PCollection<InputT>, PCollectionTuple, 
ParDo.BoundMulti<InputT, OutputT>>)
+                application;
+
+    ParDo.BoundMulti<InputT, OutputT> transform = 
parDoApplication.getTransform();
+    final DoFn<InputT, OutputT> doFn = transform.getNewFn();
+
     @SuppressWarnings({"unchecked", "rawtypes"})
     TransformEvaluator<T> evaluator =
         (TransformEvaluator<T>)
-            createEvaluator((AppliedPTransform) application, (CommittedBundle) 
inputBundle);
+            createEvaluator(
+                (AppliedPTransform) application,
+                inputBundle.getKey(),
+                doFn,
+                transform.getSideInputs(),
+                transform.getMainOutputTag(),
+                transform.getSideOutputTags().getAll());
     return evaluator;
   }
 
@@ -66,21 +84,32 @@ final class ParDoEvaluatorFactory<InputT, OutputT> 
implements TransformEvaluator
     DoFnLifecycleManagers.removeAllFromManagers(fnClones.asMap().values());
   }
 
+  /**
+   * Creates an evaluator for an arbitrary {@link AppliedPTransform} node, 
with the pieces of the
+   * {@link ParDo} unpacked.
+   *
+   * <p>This can thus be invoked regardless of whether the types in the {@link 
AppliedPTransform}
+   * correspond with the type in the unpacked {@link DoFn}, side inputs, and 
output tags.
+   */
   @SuppressWarnings({"unchecked", "rawtypes"})
-  private TransformEvaluator<InputT> createEvaluator(
-      AppliedPTransform<PCollection<InputT>, PCollectionTuple, 
BoundMulti<InputT, OutputT>>
-          application,
-      CommittedBundle<InputT> inputBundle)
+  TransformEvaluator<InputT> createEvaluator(
+        AppliedPTransform<PCollection<?>, PCollectionTuple, ?>
+        application,
+        StructuralKey<?> inputBundleKey,
+        DoFn<InputT, OutputT> doFn,
+        List<PCollectionView<?>> sideInputs,
+        TupleTag<OutputT> mainOutputTag,
+        List<TupleTag<?>> sideOutputTags)
       throws Exception {
     String stepName = evaluationContext.getStepName(application);
     DirectStepContext stepContext =
         evaluationContext
-            .getExecutionContext(application, inputBundle.getKey())
+            .getExecutionContext(application, inputBundleKey)
             .getOrCreateStepContext(stepName, stepName);
 
-    DoFnLifecycleManager fnManager = 
fnClones.getUnchecked(application.getTransform().getNewFn());
+    DoFnLifecycleManager fnManager = fnClones.getUnchecked(doFn);
+
     try {
-      ParDo.BoundMulti<InputT, OutputT> transform = application.getTransform();
       return DoFnLifecycleManagerRemovingTransformEvaluator.wrapping(
           ParDoEvaluator.<InputT, OutputT>create(
               evaluationContext,
@@ -88,9 +117,9 @@ final class ParDoEvaluatorFactory<InputT, OutputT> 
implements TransformEvaluator
               application,
               application.getInput().getWindowingStrategy(),
               fnManager.get(),
-              transform.getSideInputs(),
-              transform.getMainOutputTag(),
-              transform.getSideOutputTags().getAll(),
+              sideInputs,
+              mainOutputTag,
+              sideOutputTags,
               application.getOutput().getAll()),
           fnManager);
     } catch (Exception e) {

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/ec2c0e06/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
----------------------------------------------------------------------
diff --git 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
index 6cc3e6e..8db5159 100644
--- 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
+++ 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
@@ -18,13 +18,19 @@
 package org.apache.beam.runners.direct;
 
 import org.apache.beam.runners.core.SplittableParDo;
+import org.apache.beam.sdk.coders.CannotProvideCoderException;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
+import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionTuple;
+import org.apache.beam.sdk.values.TupleTagList;
+import org.apache.beam.sdk.values.TypedPValue;
 
 /**
  * A {@link PTransformOverrideFactory} that provides overrides for 
applications of a {@link ParDo}
@@ -42,10 +48,74 @@ class ParDoMultiOverrideFactory<InputT, OutputT>
 
     DoFn<InputT, OutputT> fn = transform.getNewFn();
     DoFnSignature signature = DoFnSignatures.getSignature(fn.getClass());
-    if (!signature.processElement().isSplittable()) {
-      return transform;
-    } else {
+    if (signature.processElement().isSplittable()) {
       return new SplittableParDo(fn);
+    } else if (signature.stateDeclarations().size() > 0
+        || signature.timerDeclarations().size() > 0) {
+
+      // Based on the fact that the signature is stateful, DoFnSignatures 
ensures
+      // that it is also keyed
+      ParDo.BoundMulti<KV<?, ?>, OutputT> keyedTransform =
+          (ParDo.BoundMulti<KV<?, ?>, OutputT>) transform;
+
+      return new GbkThenStatefulParDo(keyedTransform);
+    } else {
+      return transform;
+    }
+  }
+
+  static class GbkThenStatefulParDo<K, InputT, OutputT>
+      extends PTransform<PCollection<KV<K, InputT>>, PCollectionTuple> {
+    private final ParDo.BoundMulti<KV<K, InputT>, OutputT> underlyingParDo;
+
+    public GbkThenStatefulParDo(ParDo.BoundMulti<KV<K, InputT>, OutputT> 
underlyingParDo) {
+      this.underlyingParDo = underlyingParDo;
+    }
+
+    @Override
+    public PCollectionTuple apply(PCollection<KV<K, InputT>> input) {
+
+      PCollectionTuple outputs = input
+          .apply("Group by key", GroupByKey.<K, InputT>create())
+          .apply("Stateful ParDo", new StatefulParDo<>(underlyingParDo, 
input));
+
+      return outputs;
+    }
+  }
+
+  static class StatefulParDo<K, InputT, OutputT>
+      extends PTransform<PCollection<? extends KV<K, Iterable<InputT>>>, 
PCollectionTuple> {
+    private final transient ParDo.BoundMulti<KV<K, InputT>, OutputT> 
underlyingParDo;
+    private final transient PCollection<KV<K, InputT>> originalInput;
+
+    public StatefulParDo(
+        ParDo.BoundMulti<KV<K, InputT>, OutputT> underlyingParDo,
+        PCollection<KV<K, InputT>> originalInput) {
+      this.underlyingParDo = underlyingParDo;
+      this.originalInput = originalInput;
+    }
+
+    public ParDo.BoundMulti<KV<K, InputT>, OutputT> getUnderlyingParDo() {
+      return underlyingParDo;
+    }
+
+    @Override
+    public <T> Coder<T> getDefaultOutputCoder(
+        PCollection<? extends KV<K, Iterable<InputT>>> input, TypedPValue<T> 
output)
+        throws CannotProvideCoderException {
+      return underlyingParDo.getDefaultOutputCoder(originalInput, output);
+    }
+
+    public PCollectionTuple apply(PCollection<? extends KV<K, 
Iterable<InputT>>> input) {
+
+      PCollectionTuple outputs = PCollectionTuple.ofPrimitiveOutputsInternal(
+          input.getPipeline(),
+          TupleTagList.of(underlyingParDo.getMainOutputTag())
+              .and(underlyingParDo.getSideOutputTags().getAll()),
+          input.getWindowingStrategy(),
+          input.isBounded());
+
+      return outputs;
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/ec2c0e06/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactory.java
----------------------------------------------------------------------
diff --git 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactory.java
 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactory.java
index ee3dfc5..f220a46 100644
--- 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactory.java
+++ 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoSingleViaMultiOverrideFactory.java
@@ -54,13 +54,15 @@ class ParDoSingleViaMultiOverrideFactory<InputT, OutputT>
       // Output tags for ParDo need only be unique up to applied transform
       TupleTag<OutputT> mainOutputTag = new TupleTag<OutputT>(MAIN_OUTPUT_TAG);
 
-      PCollectionTuple output =
+      PCollectionTuple outputs =
           input.apply(
               ParDo.of(underlyingParDo.getNewFn())
                   .withSideInputs(underlyingParDo.getSideInputs())
                   .withOutputTags(mainOutputTag, TupleTagList.empty()));
+      PCollection<OutputT> output = outputs.get(mainOutputTag);
 
-      return output.get(mainOutputTag);
+      
output.setTypeDescriptorInternal(underlyingParDo.getNewFn().getOutputTypeDescriptor());
+      return output;
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/ec2c0e06/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java
----------------------------------------------------------------------
diff --git 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java
 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java
new file mode 100644
index 0000000..1f3286c
--- /dev/null
+++ 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java
@@ -0,0 +1,256 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.direct;
+
+import com.google.auto.value.AutoValue;
+import com.google.common.cache.CacheBuilder;
+import com.google.common.cache.CacheLoader;
+import com.google.common.cache.LoadingCache;
+import com.google.common.collect.Lists;
+import java.util.Collections;
+import org.apache.beam.runners.direct.DirectExecutionContext.DirectStepContext;
+import org.apache.beam.runners.direct.DirectRunner.CommittedBundle;
+import org.apache.beam.runners.direct.ParDoMultiOverrideFactory.StatefulParDo;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
+import org.apache.beam.sdk.transforms.reflect.DoFnSignature.StateDeclaration;
+import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.WindowingStrategy;
+import org.apache.beam.sdk.util.state.StateNamespace;
+import org.apache.beam.sdk.util.state.StateNamespaces;
+import org.apache.beam.sdk.util.state.StateSpec;
+import org.apache.beam.sdk.util.state.StateTag;
+import org.apache.beam.sdk.util.state.StateTags;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionTuple;
+
+/** A {@link TransformEvaluatorFactory} for stateful {@link ParDo}. */
+final class StatefulParDoEvaluatorFactory<K, InputT, OutputT> implements 
TransformEvaluatorFactory {
+
+  private final LoadingCache<AppliedPTransformOutputKeyAndWindow<K, InputT, 
OutputT>, Runnable>
+      cleanupRegistry;
+
+  private final ParDoEvaluatorFactory<KV<K, InputT>, OutputT> delegateFactory;
+
+  StatefulParDoEvaluatorFactory(EvaluationContext evaluationContext) {
+    this.delegateFactory = new ParDoEvaluatorFactory<>(evaluationContext);
+    this.cleanupRegistry =
+        CacheBuilder.newBuilder()
+            .weakValues()
+            .build(new CleanupSchedulingLoader(evaluationContext));
+  }
+
+  @Override
+  public <T> TransformEvaluator<T> forApplication(
+      AppliedPTransform<?, ?, ?> application, CommittedBundle<?> inputBundle) 
throws Exception {
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    TransformEvaluator<T> evaluator =
+        (TransformEvaluator<T>)
+            createEvaluator((AppliedPTransform) application, (CommittedBundle) 
inputBundle);
+    return evaluator;
+  }
+
+  @Override
+  public void cleanup() throws Exception {
+    delegateFactory.cleanup();
+  }
+
+  @SuppressWarnings({"unchecked", "rawtypes"})
+  private TransformEvaluator<KV<K, Iterable<InputT>>> createEvaluator(
+      AppliedPTransform<
+              PCollection<? extends KV<K, Iterable<InputT>>>, PCollectionTuple,
+              StatefulParDo<K, InputT, OutputT>>
+          application,
+      CommittedBundle<KV<K, Iterable<InputT>>> inputBundle)
+      throws Exception {
+
+    final DoFn<KV<K, InputT>, OutputT> doFn =
+        application.getTransform().getUnderlyingParDo().getNewFn();
+    final DoFnSignature signature = 
DoFnSignatures.getSignature(doFn.getClass());
+
+    // If the DoFn is stateful, schedule state clearing.
+    // It is semantically correct to schedule any number of redundant clear 
tasks; the
+    // cache is used to limit the number of tasks to avoid performance 
degradation.
+    if (signature.stateDeclarations().size() > 0) {
+      for (final WindowedValue<?> element : inputBundle.getElements()) {
+        for (final BoundedWindow window : element.getWindows()) {
+          cleanupRegistry.get(
+              AppliedPTransformOutputKeyAndWindow.create(
+                  application, (StructuralKey<K>) inputBundle.getKey(), 
window));
+        }
+      }
+    }
+
+    TransformEvaluator<KV<K, InputT>> delegateEvaluator =
+        delegateFactory.createEvaluator(
+            (AppliedPTransform) application,
+            inputBundle.getKey(),
+            doFn,
+            application.getTransform().getUnderlyingParDo().getSideInputs(),
+            application.getTransform().getUnderlyingParDo().getMainOutputTag(),
+            
application.getTransform().getUnderlyingParDo().getSideOutputTags().getAll());
+
+    return new StatefulParDoEvaluator<>(delegateEvaluator);
+  }
+
+  private class CleanupSchedulingLoader
+      extends CacheLoader<AppliedPTransformOutputKeyAndWindow<K, InputT, 
OutputT>, Runnable> {
+
+    private final EvaluationContext evaluationContext;
+
+    public CleanupSchedulingLoader(EvaluationContext evaluationContext) {
+      this.evaluationContext = evaluationContext;
+    }
+
+    @Override
+    public Runnable load(
+        final AppliedPTransformOutputKeyAndWindow<K, InputT, OutputT> 
transformOutputWindow) {
+      String stepName = 
evaluationContext.getStepName(transformOutputWindow.getTransform());
+
+      PCollection<?> pc =
+          transformOutputWindow
+              .getTransform()
+              .getOutput()
+              .get(
+                  transformOutputWindow
+                      .getTransform()
+                      .getTransform()
+                      .getUnderlyingParDo()
+                      .getMainOutputTag());
+      WindowingStrategy<?, ?> windowingStrategy = pc.getWindowingStrategy();
+      BoundedWindow window = transformOutputWindow.getWindow();
+      final DoFn<?, ?> doFn =
+          
transformOutputWindow.getTransform().getTransform().getUnderlyingParDo().getNewFn();
+      final DoFnSignature signature = 
DoFnSignatures.getSignature(doFn.getClass());
+
+      final DirectStepContext stepContext =
+          evaluationContext
+              .getExecutionContext(
+                  transformOutputWindow.getTransform(), 
transformOutputWindow.getKey())
+              .getOrCreateStepContext(stepName, stepName);
+
+      final StateNamespace namespace =
+          StateNamespaces.window(
+              (Coder<BoundedWindow>) 
windowingStrategy.getWindowFn().windowCoder(), window);
+
+      Runnable cleanup =
+          new Runnable() {
+            @Override
+            public void run() {
+              for (StateDeclaration stateDecl : 
signature.stateDeclarations().values()) {
+                StateTag<Object, ?> tag;
+                try {
+                  tag =
+                      StateTags.tagForSpec(stateDecl.id(), (StateSpec) 
stateDecl.field().get(doFn));
+                } catch (IllegalAccessException e) {
+                  throw new RuntimeException(
+                      String.format(
+                          "Error accessing %s for %s",
+                          StateSpec.class.getName(), 
doFn.getClass().getName()),
+                      e);
+                }
+                stepContext.stateInternals().state(namespace, tag).clear();
+              }
+              cleanupRegistry.invalidate(transformOutputWindow);
+            }
+          };
+
+      evaluationContext.scheduleAfterWindowExpiration(
+          transformOutputWindow.getTransform(), window, windowingStrategy, 
cleanup);
+      return cleanup;
+    }
+  }
+
+  @AutoValue
+  abstract static class AppliedPTransformOutputKeyAndWindow<K, InputT, 
OutputT> {
+    abstract AppliedPTransform<
+            PCollection<? extends KV<K, Iterable<InputT>>>, PCollectionTuple,
+            StatefulParDo<K, InputT, OutputT>>
+        getTransform();
+
+    abstract StructuralKey<K> getKey();
+
+    abstract BoundedWindow getWindow();
+
+    static <K, InputT, OutputT> AppliedPTransformOutputKeyAndWindow<K, InputT, 
OutputT> create(
+        AppliedPTransform<
+                PCollection<? extends KV<K, Iterable<InputT>>>, 
PCollectionTuple,
+                StatefulParDo<K, InputT, OutputT>>
+            transform,
+        StructuralKey<K> key,
+        BoundedWindow w) {
+      return new 
AutoValue_StatefulParDoEvaluatorFactory_AppliedPTransformOutputKeyAndWindow<>(
+          transform, key, w);
+    }
+  }
+
+  private static class StatefulParDoEvaluator<K, InputT>
+      implements TransformEvaluator<KV<K, Iterable<InputT>>> {
+
+    private final TransformEvaluator<KV<K, InputT>> delegateEvaluator;
+
+    public StatefulParDoEvaluator(TransformEvaluator<KV<K, InputT>> 
delegateEvaluator) {
+      this.delegateEvaluator = delegateEvaluator;
+    }
+
+    @Override
+    public void processElement(WindowedValue<KV<K, Iterable<InputT>>> 
gbkResult) throws Exception {
+
+      for (InputT value : gbkResult.getValue().getValue()) {
+        delegateEvaluator.processElement(
+            gbkResult.withValue(KV.of(gbkResult.getValue().getKey(), value)));
+      }
+    }
+
+    @Override
+    public TransformResult<KV<K, Iterable<InputT>>> finishBundle() throws 
Exception {
+      TransformResult<KV<K, InputT>> delegateResult = 
delegateEvaluator.finishBundle();
+
+      StepTransformResult.Builder<KV<K, Iterable<InputT>>> regroupedResult =
+          StepTransformResult.<KV<K, Iterable<InputT>>>withHold(
+                  delegateResult.getTransform(), 
delegateResult.getWatermarkHold())
+              .withTimerUpdate(delegateResult.getTimerUpdate())
+              .withAggregatorChanges(delegateResult.getAggregatorChanges())
+              .withMetricUpdates(delegateResult.getLogicalMetricUpdates())
+              
.addOutput(Lists.newArrayList(delegateResult.getOutputBundles()));
+
+      // The delegate may have pushed back unprocessed elements across multiple keys and windows.
+      // Since processing is single-threaded per key and window, there is no need to regroup
+      // them; each pushed-back element is simply re-wrapped as a singleton Iterable.
+      for (WindowedValue<?> untypedUnprocessed : 
delegateResult.getUnprocessedElements()) {
+        WindowedValue<KV<K, InputT>> windowedKv = (WindowedValue<KV<K, 
InputT>>) untypedUnprocessed;
+        WindowedValue<KV<K, Iterable<InputT>>> pushedBack =
+            windowedKv.withValue(
+                KV.of(
+                    windowedKv.getValue().getKey(),
+                    (Iterable<InputT>)
+                        
Collections.singletonList(windowedKv.getValue().getValue())));
+
+        regroupedResult.addUnprocessedElements(pushedBack);
+      }
+
+      return regroupedResult.build();
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/ec2c0e06/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java
----------------------------------------------------------------------
diff --git 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java
 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java
index 0514c3a..a4c462a 100644
--- 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java
+++ 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java
@@ -28,6 +28,7 @@ import java.util.concurrent.atomic.AtomicBoolean;
 import org.apache.beam.runners.direct.DirectGroupByKey.DirectGroupAlsoByWindow;
 import org.apache.beam.runners.direct.DirectGroupByKey.DirectGroupByKeyOnly;
 import org.apache.beam.runners.direct.DirectRunner.CommittedBundle;
+import org.apache.beam.runners.direct.ParDoMultiOverrideFactory.StatefulParDo;
 import org.apache.beam.sdk.io.Read;
 import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Flatten.FlattenPCollectionList;
@@ -50,6 +51,7 @@ class TransformEvaluatorRegistry implements 
TransformEvaluatorFactory {
             .put(Read.Bounded.class, new BoundedReadEvaluatorFactory(ctxt))
             .put(Read.Unbounded.class, new UnboundedReadEvaluatorFactory(ctxt))
             .put(ParDo.BoundMulti.class, new ParDoEvaluatorFactory<>(ctxt))
+            .put(StatefulParDo.class, new 
StatefulParDoEvaluatorFactory<>(ctxt))
             .put(FlattenPCollectionList.class, new 
FlattenEvaluatorFactory(ctxt))
             .put(ViewEvaluatorFactory.WriteView.class, new 
ViewEvaluatorFactory(ctxt))
             .put(Window.Bound.class, new WindowEvaluatorFactory(ctxt))

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/ec2c0e06/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkCallbackExecutor.java
----------------------------------------------------------------------
diff --git 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkCallbackExecutor.java
 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkCallbackExecutor.java
index 54cab7c..fcefc5f 100644
--- 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkCallbackExecutor.java
+++ 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkCallbackExecutor.java
@@ -89,6 +89,32 @@ class WatermarkCallbackExecutor {
   }
 
   /**
+   * Execute the provided {@link Runnable} after the next call to
+   * {@link #fireForWatermark(AppliedPTransform, Instant)} where the window
+   * is guaranteed to be expired.
+   */
+  public void callOnWindowExpiration(
+      AppliedPTransform<?, ?, ?> step,
+      BoundedWindow window,
+      WindowingStrategy<?, ?> windowingStrategy,
+      Runnable runnable) {
+    WatermarkCallback callback =
+        WatermarkCallback.afterWindowExpiration(window, windowingStrategy, 
runnable);
+
+    PriorityQueue<WatermarkCallback> callbackQueue = callbacks.get(step);
+    if (callbackQueue == null) {
+      callbackQueue = new PriorityQueue<>(11, new CallbackOrdering());
+      if (callbacks.putIfAbsent(step, callbackQueue) != null) {
+        callbackQueue = callbacks.get(step);
+      }
+    }
+
+    synchronized (callbackQueue) {
+      callbackQueue.offer(callback);
+    }
+  }
+
+  /**
    * Schedule all pending callbacks that must have produced output by the time 
of the provided
    * watermark.
    */
@@ -112,6 +138,14 @@ class WatermarkCallbackExecutor {
       return new WatermarkCallback(firingAfter, callback);
     }
 
+    public static <W extends BoundedWindow> WatermarkCallback 
afterWindowExpiration(
+        BoundedWindow window, WindowingStrategy<?, W> strategy, Runnable 
callback) {
+      // Fire one milli past the window's expiration (max timestamp plus allowed 
lateness), so that
+      // all window expiration timers are delivered first
+      Instant firingAfter = 
window.maxTimestamp().plus(strategy.getAllowedLateness()).plus(1L);
+      return new WatermarkCallback(firingAfter, callback);
+    }
+
     private final Instant fireAfter;
     private final Runnable callback;
 

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/ec2c0e06/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java
----------------------------------------------------------------------
diff --git 
a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java
 
b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java
new file mode 100644
index 0000000..ecf11ed
--- /dev/null
+++ 
b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java
@@ -0,0 +1,300 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.direct;
+
+import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.nullValue;
+import static org.junit.Assert.assertThat;
+import static org.mockito.Matchers.anyList;
+import static org.mockito.Matchers.anyString;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.beam.runners.direct.DirectRunner.CommittedBundle;
+import org.apache.beam.runners.direct.DirectRunner.UncommittedBundle;
+import org.apache.beam.runners.direct.ParDoMultiOverrideFactory.StatefulParDo;
+import org.apache.beam.runners.direct.WatermarkManager.TimerUpdate;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.FixedWindows;
+import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.transforms.windowing.Window;
+import org.apache.beam.sdk.util.ReadyCheckingSideInputReader;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.util.WindowingStrategy;
+import org.apache.beam.sdk.util.state.CopyOnAccessInMemoryStateInternals;
+import org.apache.beam.sdk.util.state.StateInternals;
+import org.apache.beam.sdk.util.state.StateNamespace;
+import org.apache.beam.sdk.util.state.StateNamespaces;
+import org.apache.beam.sdk.util.state.StateSpec;
+import org.apache.beam.sdk.util.state.StateSpecs;
+import org.apache.beam.sdk.util.state.StateTag;
+import org.apache.beam.sdk.util.state.StateTags;
+import org.apache.beam.sdk.util.state.ValueState;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionTuple;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Matchers;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.MockitoAnnotations;
+
+/** Tests for {@link StatefulParDoEvaluatorFactory}. */
+@RunWith(JUnit4.class)
+public class StatefulParDoEvaluatorFactoryTest implements Serializable {
+  @Mock private transient EvaluationContext mockEvaluationContext;
+  @Mock private transient DirectExecutionContext mockExecutionContext;
+  @Mock private transient DirectExecutionContext.DirectStepContext 
mockStepContext;
+  @Mock private transient ReadyCheckingSideInputReader mockSideInputReader;
+  @Mock private transient UncommittedBundle<Integer> mockUncommittedBundle;
+
+  private static final String KEY = "any-key";
+  private transient StateInternals<Object> stateInternals =
+      CopyOnAccessInMemoryStateInternals.<Object>withUnderlying(KEY, null);
+
+  private static final BundleFactory BUNDLE_FACTORY = 
ImmutableListBundleFactory.create();
+
+  @Before
+  public void setup() {
+    MockitoAnnotations.initMocks(this);
+    when((StateInternals<Object>) 
mockStepContext.stateInternals()).thenReturn(stateInternals);
+  }
+
+  @Test
+  public void windowCleanupScheduled() throws Exception {
+    // To test the factory, first we set up a pipeline and then we use the 
constructed
+    // pipeline to create the right parameters to pass to the factory
+    TestPipeline pipeline = TestPipeline.create();
+
+    final String stateId = "my-state-id";
+
+    // For consistency, window it into FixedWindows. Actually we will 
fabricate an input bundle.
+    PCollection<KV<String, Integer>> input =
+        pipeline
+            .apply(Create.of(KV.of("hello", 1), KV.of("hello", 2)))
+            .apply(Window.<KV<String, 
Integer>>into(FixedWindows.of(Duration.millis(10))));
+
+    PCollection<Integer> produced =
+        input.apply(
+            ParDo.of(
+                new DoFn<KV<String, Integer>, Integer>() {
+                  @StateId(stateId)
+                  private final StateSpec<Object, ValueState<String>> spec =
+                      StateSpecs.value(StringUtf8Coder.of());
+
+                  @ProcessElement
+                  public void process(ProcessContext c) {}
+                }));
+
+    StatefulParDoEvaluatorFactory<String, Integer, Integer> factory =
+        new StatefulParDoEvaluatorFactory(mockEvaluationContext);
+
+    AppliedPTransform<
+            PCollection<? extends KV<String, Iterable<Integer>>>, 
PCollectionTuple,
+            StatefulParDo<String, Integer, Integer>>
+        producingTransform = (AppliedPTransform) 
produced.getProducingTransformInternal();
+
+    // Then there will be a digging down to the step context to get the state 
internals
+    when(mockEvaluationContext.getExecutionContext(
+            eq(producingTransform), Mockito.<StructuralKey>any()))
+        .thenReturn(mockExecutionContext);
+    when(mockExecutionContext.getOrCreateStepContext(anyString(), anyString()))
+        .thenReturn(mockStepContext);
+
+    IntervalWindow firstWindow = new IntervalWindow(new Instant(0), new 
Instant(9));
+    IntervalWindow secondWindow = new IntervalWindow(new Instant(10), new 
Instant(19));
+
+    StateNamespace firstWindowNamespace =
+        StateNamespaces.window(IntervalWindow.getCoder(), firstWindow);
+    StateNamespace secondWindowNamespace =
+        StateNamespaces.window(IntervalWindow.getCoder(), secondWindow);
+    StateTag<Object, ValueState<String>> tag =
+        StateTags.tagForSpec(stateId, StateSpecs.value(StringUtf8Coder.of()));
+
+    // Set up non-empty state. We don't mock + verify calls to clear() but 
instead
+    // check that state is actually empty. We mustn't care how it is 
accomplished.
+    stateInternals.state(firstWindowNamespace, tag).write("first");
+    stateInternals.state(secondWindowNamespace, tag).write("second");
+
+    // A single bundle with one element in each of the two windows; it should 
register cleanup for
+    // each window's state merely by having the evaluator created. The cleanup 
logic does not
+    // depend on the window.
+    CommittedBundle<KV<String, Integer>> inputBundle =
+        BUNDLE_FACTORY
+            .createBundle(input)
+            .add(
+                WindowedValue.of(
+                    KV.of("hello", 1), new Instant(3), firstWindow, 
PaneInfo.NO_FIRING))
+            .add(
+                WindowedValue.of(
+                    KV.of("hello", 2), new Instant(11), secondWindow, 
PaneInfo.NO_FIRING))
+            .commit(Instant.now());
+
+    // Merely creating the evaluator should suffice to register the cleanup 
callback
+    factory.forApplication(producingTransform, inputBundle);
+
+    ArgumentCaptor<Runnable> argumentCaptor = 
ArgumentCaptor.forClass(Runnable.class);
+    verify(mockEvaluationContext)
+        .scheduleAfterWindowExpiration(
+            eq(producingTransform),
+            eq(firstWindow),
+            Mockito.<WindowingStrategy<?, ?>>any(),
+            argumentCaptor.capture());
+
+    // Should actually clear the state for the first window
+    argumentCaptor.getValue().run();
+    assertThat(stateInternals.state(firstWindowNamespace, tag).read(), 
nullValue());
+    assertThat(stateInternals.state(secondWindowNamespace, tag).read(), 
equalTo("second"));
+
+    verify(mockEvaluationContext)
+        .scheduleAfterWindowExpiration(
+            eq(producingTransform),
+            eq(secondWindow),
+            Mockito.<WindowingStrategy<?, ?>>any(),
+            argumentCaptor.capture());
+
+    // Should actually clear the state for the second window
+    argumentCaptor.getValue().run();
+    assertThat(stateInternals.state(secondWindowNamespace, tag).read(), 
nullValue());
+  }
+
+  /**
+   * A test that explicitly delays a side input so that the main input will 
have to be reprocessed,
+   * testing that {@code finishBundle()} re-assembles the GBK outputs 
correctly.
+   */
+  @Test
+  public void testUnprocessedElements() throws Exception {
+    // To test the factory, first we set up a pipeline and then we use the 
constructed
+    // pipeline to create the right parameters to pass to the factory
+    TestPipeline pipeline = TestPipeline.create();
+
+    final String stateId = "my-state-id";
+
+    // For consistency, window it into FixedWindows. Actually we will 
fabricate an input bundle.
+    PCollection<KV<String, Integer>> mainInput =
+        pipeline
+            .apply(Create.of(KV.of("hello", 1), KV.of("hello", 2)))
+            .apply(Window.<KV<String, 
Integer>>into(FixedWindows.of(Duration.millis(10))));
+
+    final PCollectionView<List<Integer>> sideInput =
+        pipeline
+            .apply("Create side input", Create.of(42))
+            .apply("Window side input", 
Window.<Integer>into(FixedWindows.of(Duration.millis(10))))
+            .apply("View side input", View.<Integer>asList());
+
+    PCollection<Integer> produced =
+        mainInput.apply(
+            ParDo.withSideInputs(sideInput)
+                .of(
+                    new DoFn<KV<String, Integer>, Integer>() {
+                      @StateId(stateId)
+                      private final StateSpec<Object, ValueState<String>> spec 
=
+                          StateSpecs.value(StringUtf8Coder.of());
+
+                      @ProcessElement
+                      public void process(ProcessContext c) {}
+                    }));
+
+    StatefulParDoEvaluatorFactory<String, Integer, Integer> factory =
+        new StatefulParDoEvaluatorFactory(mockEvaluationContext);
+
+    // This will be the stateful ParDo from the expansion
+    AppliedPTransform<
+            PCollection<KV<String, Iterable<Integer>>>, PCollectionTuple,
+            StatefulParDo<String, Integer, Integer>>
+        producingTransform = (AppliedPTransform) 
produced.getProducingTransformInternal();
+
+    // Then there will be a digging down to the step context to get the state 
internals
+    when(mockEvaluationContext.getExecutionContext(
+            eq(producingTransform), Mockito.<StructuralKey>any()))
+        .thenReturn(mockExecutionContext);
+    when(mockExecutionContext.getOrCreateStepContext(anyString(), anyString()))
+        .thenReturn(mockStepContext);
+    
when(mockEvaluationContext.createBundle(Matchers.<PCollection<Integer>>any()))
+        .thenReturn(mockUncommittedBundle);
+    when(mockStepContext.getTimerUpdate()).thenReturn(TimerUpdate.empty());
+
+    // And digging to check whether the window is ready
+    
when(mockEvaluationContext.createSideInputReader(anyList())).thenReturn(mockSideInputReader);
+    when(mockSideInputReader.isReady(
+            Matchers.<PCollectionView<?>>any(), Matchers.<BoundedWindow>any()))
+        .thenReturn(false);
+
+    IntervalWindow firstWindow = new IntervalWindow(new Instant(0), new 
Instant(9));
+
+    // A single bundle with one GBK output element in the first window. The side 
input is mocked as
+    // never ready, so the evaluator is expected to push the element back as 
unprocessed rather
+    // than process it.
+    WindowedValue<KV<String, Iterable<Integer>>> gbkOutputElement =
+        WindowedValue.of(
+            KV.<String, Iterable<Integer>>of("hello", Lists.newArrayList(1, 
13, 15)),
+            new Instant(3),
+            firstWindow,
+            PaneInfo.NO_FIRING);
+    CommittedBundle<KV<String, Iterable<Integer>>> inputBundle =
+        BUNDLE_FACTORY
+            .createBundle(producingTransform.getInput())
+            .add(gbkOutputElement)
+            .commit(Instant.now());
+    TransformEvaluator<KV<String, Iterable<Integer>>> evaluator =
+        factory.forApplication(producingTransform, inputBundle);
+    evaluator.processElement(gbkOutputElement);
+
+    // This should push back every element as a KV<String, Iterable<Integer>>
+    // in the appropriate window. Since the keys are equal they are 
single-threaded
+    TransformResult<KV<String, Iterable<Integer>>> result = 
evaluator.finishBundle();
+
+    List<Integer> pushedBackInts = new ArrayList<>();
+
+    for (WindowedValue<?> unprocessedElement : 
result.getUnprocessedElements()) {
+      WindowedValue<KV<String, Iterable<Integer>>> unprocessedKv =
+          (WindowedValue<KV<String, Iterable<Integer>>>) unprocessedElement;
+
+      assertThat(
+          Iterables.getOnlyElement(unprocessedElement.getWindows()),
+          equalTo((BoundedWindow) firstWindow));
+      assertThat(unprocessedKv.getValue().getKey(), equalTo("hello"));
+      for (Integer i : unprocessedKv.getValue().getValue()) {
+        pushedBackInts.add(i);
+      }
+    }
+    assertThat(pushedBackInts, containsInAnyOrder(1, 13, 15));
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/ec2c0e06/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
----------------------------------------------------------------------
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
index 221d942..3f1a3f9 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
@@ -315,7 +315,7 @@ public abstract class DoFn<InputT, OutputT> implements 
Serializable, HasDisplayD
    *
    * <p>See {@link #getOutputTypeDescriptor} for more discussion.
    */
-  protected TypeDescriptor<InputT> getInputTypeDescriptor() {
+  public TypeDescriptor<InputT> getInputTypeDescriptor() {
     return new TypeDescriptor<InputT>(getClass()) {};
   }
 
@@ -330,7 +330,7 @@ public abstract class DoFn<InputT, OutputT> implements 
Serializable, HasDisplayD
    * for choosing a default output {@code Coder<O>} for the output
    * {@code PCollection<O>}.
    */
-  protected TypeDescriptor<OutputT> getOutputTypeDescriptor() {
+  public TypeDescriptor<OutputT> getOutputTypeDescriptor() {
     return new TypeDescriptor<OutputT>(getClass()) {};
   }
 

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/ec2c0e06/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/OldDoFn.java
----------------------------------------------------------------------
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/OldDoFn.java 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/OldDoFn.java
index 9bf9003..2d2c1fd 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/OldDoFn.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/OldDoFn.java
@@ -671,7 +671,7 @@ public abstract class OldDoFn<InputT, OutputT> implements 
Serializable, HasDispl
     }
 
     @Override
-    protected TypeDescriptor<InputT> getInputTypeDescriptor() {
+    public TypeDescriptor<InputT> getInputTypeDescriptor() {
       return OldDoFn.this.getInputTypeDescriptor();
     }
 
@@ -681,7 +681,7 @@ public abstract class OldDoFn<InputT, OutputT> implements 
Serializable, HasDispl
     }
 
     @Override
-    protected TypeDescriptor<OutputT> getOutputTypeDescriptor() {
+    public TypeDescriptor<OutputT> getOutputTypeDescriptor() {
       return OldDoFn.this.getOutputTypeDescriptor();
     }
   }
@@ -746,12 +746,12 @@ public abstract class OldDoFn<InputT, OutputT> implements 
Serializable, HasDispl
     }
 
     @Override
-    protected TypeDescriptor<InputT> getInputTypeDescriptor() {
+    public TypeDescriptor<InputT> getInputTypeDescriptor() {
       return OldDoFn.this.getInputTypeDescriptor();
     }
 
     @Override
-    protected TypeDescriptor<OutputT> getOutputTypeDescriptor() {
+    public TypeDescriptor<OutputT> getOutputTypeDescriptor() {
       return OldDoFn.this.getOutputTypeDescriptor();
     }
   }

Reply via email to