Get DataflowRunner out of the #apply()

Use Pipeline Surgery in the Dataflow Runner

Add additional override factories for Dataflow overrides.
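
In short: before this change, DataflowRunner swapped in its custom implementations by intercepting every transform inside #apply(). It now declares (PTransformMatcher, PTransformOverrideFactory) pairs and rewrites the pipeline once, in run(), via Pipeline#replace. A minimal sketch of the new pattern, using only the interfaces visible in the diff below (MyTransform and MyPrimitive are hypothetical placeholders):

    // Hedged sketch of one override, applied with pipeline surgery.
    class MyOverrideFactory<T>
        implements PTransformOverrideFactory<PCollection<T>, PCollection<T>, MyTransform<T>> {
      @Override
      public PTransform<PCollection<T>, PCollection<T>> getReplacementTransform(MyTransform<T> t) {
        return new MyPrimitive<>(t); // runner-specific replacement for the matched transform
      }

      @Override
      public PCollection<T> getInput(List<TaggedPValue> inputs, Pipeline p) {
        return (PCollection<T>) Iterables.getOnlyElement(inputs).getValue();
      }

      @Override
      public Map<PValue, ReplacementOutput> mapOutputs(
          List<TaggedPValue> outputs, PCollection<T> newOutput) {
        return ReplacementOutputs.singleton(outputs, newOutput);
      }
    }

    // Registered once and applied at run() time, not on every #apply() call:
    pipeline.replace(PTransformMatchers.classEqualTo(MyTransform.class), new MyOverrideFactory<String>());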


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/a6cd8c38
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/a6cd8c38
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/a6cd8c38

Branch: refs/heads/master
Commit: a6cd8c3842cc1e3f72e02b5a9d1332bc6cddd35c
Parents: f5c04e1
Author: Thomas Groh <tg...@google.com>
Authored: Thu Feb 16 13:57:21 2017 -0800
Committer: Thomas Groh <tg...@google.com>
Committed: Mon Feb 27 09:06:13 2017 -0800

----------------------------------------------------------------------
 .../EmptyFlattenAsCreateFactory.java            |  72 ++++
 .../UnsupportedOverrideFactory.java             |   6 +-
 runners/google-cloud-dataflow-java/pom.xml      |   9 +
 .../dataflow/DataflowPipelineTranslator.java    |  46 +--
 .../beam/runners/dataflow/DataflowRunner.java   | 414 ++++++++++++-------
 .../dataflow/StreamingViewOverrides.java        |   1 -
 .../DataflowPipelineTranslatorTest.java         |  22 +-
 .../runners/dataflow/DataflowRunnerTest.java    |   8 +-
 .../org/apache/beam/sdk/util/NameUtils.java     |   7 +-
 9 files changed, 395 insertions(+), 190 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/a6cd8c38/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/EmptyFlattenAsCreateFactory.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/EmptyFlattenAsCreateFactory.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/EmptyFlattenAsCreateFactory.java
new file mode 100644
index 0000000..3b29c0a
--- /dev/null
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/EmptyFlattenAsCreateFactory.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.core.construction;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+import java.util.List;
+import java.util.Map;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.coders.VoidCoder;
+import org.apache.beam.sdk.runners.PTransformOverrideFactory;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.Flatten;
+import org.apache.beam.sdk.transforms.Flatten.FlattenPCollectionList;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionList;
+import org.apache.beam.sdk.values.PValue;
+import org.apache.beam.sdk.values.TaggedPValue;
+
+/**
+ * A {@link PTransformOverrideFactory} that provides an empty {@link Create} to replace a {@link
+ * Flatten.FlattenPCollectionList} that takes no input {@link PCollection PCollections}.
+ */
+public class EmptyFlattenAsCreateFactory<T>
+    implements PTransformOverrideFactory<
+        PCollectionList<T>, PCollection<T>, Flatten.FlattenPCollectionList<T>> {
+  private static final EmptyFlattenAsCreateFactory<Object> INSTANCE =
+      new EmptyFlattenAsCreateFactory<>();
+
+  public static <T> EmptyFlattenAsCreateFactory<T> instance() {
+    return (EmptyFlattenAsCreateFactory<T>) INSTANCE;
+  }
+
+  private EmptyFlattenAsCreateFactory() {}
+
+  @Override
+  public PTransform<PCollectionList<T>, PCollection<T>> getReplacementTransform(
+      FlattenPCollectionList<T> transform) {
+    return (PTransform) Create.empty(VoidCoder.of());
+  }
+
+  @Override
+  public PCollectionList<T> getInput(
+      List<TaggedPValue> inputs, Pipeline p) {
+    checkArgument(
+        inputs.isEmpty(), "Must have an empty input to use %s", getClass().getSimpleName());
+    return PCollectionList.empty(p);
+  }
+
+  @Override
+  public Map<PValue, ReplacementOutput> mapOutputs(
+      List<TaggedPValue> outputs, PCollection<T> newOutput) {
+    return ReplacementOutputs.singleton(outputs, newOutput);
+  }
+}
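
For reference, a usage sketch (hedged): a runner routes empty Flattens through this factory with the matcher that this commit registers in DataflowRunner further down.

    // Replaces a Flatten of zero PCollections with Create.empty(VoidCoder.of()).
    pipeline.replace(PTransformMatchers.emptyFlatten(), EmptyFlattenAsCreateFactory.instance());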

http://git-wip-us.apache.org/repos/asf/beam/blob/a6cd8c38/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/UnsupportedOverrideFactory.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/UnsupportedOverrideFactory.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/UnsupportedOverrideFactory.java
index 2072574..38cbd2a 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/UnsupportedOverrideFactory.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/UnsupportedOverrideFactory.java
@@ -60,14 +60,12 @@ public final class UnsupportedOverrideFactory<
   }
 
   @Override
-  public InputT getInput(
-      List<TaggedPValue> inputs, Pipeline p) {
+  public InputT getInput(List<TaggedPValue> inputs, Pipeline p) {
     throw new UnsupportedOperationException(message);
   }
 
   @Override
-  public Map<PValue, ReplacementOutput> mapOutputs(
-      List<TaggedPValue> outputs, OutputT newOutput) {
+  public Map<PValue, ReplacementOutput> mapOutputs(List<TaggedPValue> outputs, OutputT newOutput) {
     throw new UnsupportedOperationException(message);
   }
 }
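
The withMessage factory above makes a forbidden transform fail fast when the override is applied; a sketch mirroring how the Dataflow runner registers it later in this commit:

    pipeline.replace(
        PTransformMatchers.classEqualTo(Read.Unbounded.class),
        UnsupportedOverrideFactory.withMessage(
            "The DataflowRunner in batch mode does not support Read.Unbounded"));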

http://git-wip-us.apache.org/repos/asf/beam/blob/a6cd8c38/runners/google-cloud-dataflow-java/pom.xml
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/pom.xml b/runners/google-cloud-dataflow-java/pom.xml
index fd9c331..1e707d4 100644
--- a/runners/google-cloud-dataflow-java/pom.xml
+++ b/runners/google-cloud-dataflow-java/pom.xml
@@ -175,6 +175,11 @@
     </dependency>
 
     <dependency>
+      <groupId>org.apache.beam</groupId>
+      <artifactId>beam-runners-core-construction-java</artifactId>
+    </dependency>
+
+    <dependency>
       <groupId>com.google.api-client</groupId>
       <artifactId>google-api-client</artifactId>
     </dependency>
@@ -349,5 +354,9 @@
       <type>test-jar</type>
       <scope>test</scope>
     </dependency>
+      <dependency>
+          <groupId>org.apache.beam</groupId>
+          <artifactId>beam-runners-core-construction-java</artifactId>
+      </dependency>
   </dependencies>
 </project>

http://git-wip-us.apache.org/repos/asf/beam/blob/a6cd8c38/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java
index 6eec603..c672e99 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java
@@ -58,6 +58,7 @@ import java.util.Map;
 import java.util.concurrent.atomic.AtomicLong;
 import javax.annotation.Nullable;
 import org.apache.beam.runners.dataflow.BatchViewOverrides.GroupByKeyAndSortValuesOnly;
+import org.apache.beam.runners.dataflow.DataflowRunner.CombineGroupedValues;
 import org.apache.beam.runners.dataflow.TransformTranslator.StepTranslationContext;
 import org.apache.beam.runners.dataflow.TransformTranslator.TranslationContext;
 import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions;
@@ -72,7 +73,6 @@ import org.apache.beam.sdk.options.StreamingOptions;
 import org.apache.beam.sdk.runners.TransformHierarchy;
 import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Combine;
-import org.apache.beam.sdk.transforms.Combine.GroupedValues;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.Flatten;
 import org.apache.beam.sdk.transforms.GroupByKey;
@@ -405,20 +405,12 @@ public class DataflowPipelineTranslator {
       return currentTransform;
     }
 
-
-    @Override
-    public void leaveCompositeTransform(TransformHierarchy.Node node) {
-    }
-
     @Override
     public void visitPrimitiveTransform(TransformHierarchy.Node node) {
       PTransform<?, ?> transform = node.getTransform();
-      TransformTranslator translator =
-          getTransformTranslator(transform.getClass());
-      if (translator == null) {
-        throw new IllegalStateException(
-            "no translator registered for " + transform);
-      }
+      TransformTranslator translator = getTransformTranslator(transform.getClass());
+      checkState(
+          translator != null, "no translator registered for primitive transform %s", transform);
       LOG.debug("Translating {}", transform);
       currentTransform = node.toAppliedPTransform();
       translator.translate(transform, this);
@@ -718,32 +710,36 @@ public class DataflowPipelineTranslator {
         });
 
     DataflowPipelineTranslator.registerTransformTranslator(
-        Combine.GroupedValues.class,
-        new TransformTranslator<GroupedValues>() {
+        DataflowRunner.CombineGroupedValues.class,
+        new TransformTranslator<CombineGroupedValues>() {
           @Override
-          public void translate(
-              Combine.GroupedValues transform,
-              TranslationContext context) {
+          public void translate(CombineGroupedValues transform, TranslationContext context) {
             translateHelper(transform, context);
           }
 
           private <K, InputT, OutputT> void translateHelper(
-              final Combine.GroupedValues<K, InputT, OutputT> transform,
+              final CombineGroupedValues<K, InputT, OutputT> primitiveTransform,
               TranslationContext context) {
-            StepTranslationContext stepContext = context.addStep(transform, "CombineValues");
+            Combine.GroupedValues<K, InputT, OutputT> originalTransform =
+                primitiveTransform.getOriginalCombine();
+            StepTranslationContext stepContext =
+                context.addStep(primitiveTransform, "CombineValues");
             translateInputs(
-                stepContext, context.getInput(transform), transform.getSideInputs(), context);
+                stepContext,
+                context.getInput(primitiveTransform),
+                originalTransform.getSideInputs(),
+                context);
 
             AppliedCombineFn<? super K, ? super InputT, ?, OutputT> fn =
-                transform.getAppliedFn(
-                    context.getInput(transform).getPipeline().getCoderRegistry(),
-                    context.getInput(transform).getCoder(),
-                    context.getInput(transform).getWindowingStrategy());
+                originalTransform.getAppliedFn(
+                    context.getInput(primitiveTransform).getPipeline().getCoderRegistry(),
+                    context.getInput(primitiveTransform).getCoder(),
+                    context.getInput(primitiveTransform).getWindowingStrategy());
 
             stepContext.addEncodingInput(fn.getAccumulatorCoder());
             stepContext.addInput(
                 PropertyNames.SERIALIZED_FN, byteArrayToJsonString(serializeToByteArray(fn)));
-            stepContext.addOutput(context.getOutput(transform));
+            stepContext.addOutput(context.getOutput(primitiveTransform));
           }
         });
 

http://git-wip-us.apache.org/repos/asf/beam/blob/a6cd8c38/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
index e5ed933..0fe3a89 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
@@ -38,6 +38,8 @@ import com.google.common.base.Joiner;
 import com.google.common.base.Strings;
 import com.google.common.base.Utf8;
 import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Iterables;
 import java.io.File;
 import java.io.IOException;
 import java.io.PrintWriter;
@@ -48,6 +50,7 @@ import java.nio.channels.Channels;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
@@ -56,7 +59,11 @@ import java.util.Random;
 import java.util.Set;
 import java.util.SortedSet;
 import java.util.TreeSet;
-import javax.annotation.Nullable;
+import org.apache.beam.runners.core.construction.EmptyFlattenAsCreateFactory;
+import org.apache.beam.runners.core.construction.PTransformMatchers;
+import org.apache.beam.runners.core.construction.ReplacementOutputs;
+import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory;
+import org.apache.beam.runners.core.construction.UnsupportedOverrideFactory;
 import org.apache.beam.runners.dataflow.DataflowPipelineTranslator.JobSpecification;
 import org.apache.beam.runners.dataflow.options.DataflowPipelineDebugOptions;
 import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions;
@@ -72,27 +79,35 @@ import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.VoidCoder;
 import org.apache.beam.sdk.io.BoundedSource;
 import org.apache.beam.sdk.io.FileBasedSink;
-import org.apache.beam.sdk.io.PubsubIO;
+import org.apache.beam.sdk.io.PubsubIO.Read.PubsubBoundedReader;
+import org.apache.beam.sdk.io.PubsubIO.Write.PubsubBoundedWriter;
 import org.apache.beam.sdk.io.PubsubUnboundedSink;
 import org.apache.beam.sdk.io.PubsubUnboundedSource;
 import org.apache.beam.sdk.io.Read;
 import org.apache.beam.sdk.io.UnboundedSource;
 import org.apache.beam.sdk.io.Write;
+import org.apache.beam.sdk.io.Write.Bound;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsValidator;
-import org.apache.beam.sdk.options.StreamingOptions;
 import org.apache.beam.sdk.options.ValueProvider.NestedValueProvider;
+import org.apache.beam.sdk.runners.PTransformMatcher;
+import org.apache.beam.sdk.runners.PTransformOverrideFactory;
 import org.apache.beam.sdk.runners.PipelineRunner;
 import org.apache.beam.sdk.runners.TransformHierarchy;
 import org.apache.beam.sdk.transforms.Aggregator;
 import org.apache.beam.sdk.transforms.Combine;
-import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.Combine.GloballyAsSingletonView;
+import org.apache.beam.sdk.transforms.Combine.GroupedValues;
 import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.sdk.transforms.Flatten;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.transforms.View.AsIterable;
+import org.apache.beam.sdk.transforms.View.AsList;
+import org.apache.beam.sdk.transforms.View.AsMap;
+import org.apache.beam.sdk.transforms.View.AsMultimap;
+import org.apache.beam.sdk.transforms.View.AsSingleton;
 import org.apache.beam.sdk.transforms.WithKeys;
 import org.apache.beam.sdk.transforms.display.DisplayData;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
@@ -111,12 +126,11 @@ import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollection.IsBounded;
-import org.apache.beam.sdk.values.PCollectionList;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.PDone;
 import org.apache.beam.sdk.values.PInput;
-import org.apache.beam.sdk.values.POutput;
 import org.apache.beam.sdk.values.PValue;
+import org.apache.beam.sdk.values.TaggedPValue;
 import org.joda.time.DateTimeUtils;
 import org.joda.time.DateTimeZone;
 import org.joda.time.format.DateTimeFormat;
@@ -151,7 +165,7 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> {
   private final DataflowPipelineTranslator translator;
 
   /** Custom transforms implementations. */
-  private final Map<Class<?>, Class<?>> overrides;
+  private final ImmutableMap<PTransformMatcher, PTransformOverrideFactory> overrides;
 
   /** A set of user defined functions to invoke at different points in execution. */
   private DataflowRunnerHooks hooks;
@@ -280,94 +294,166 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> {
     this.pcollectionsRequiringIndexedFormat = new HashSet<>();
     this.ptransformViewsWithNonDeterministicKeyCoders = new HashSet<>();
 
-    ImmutableMap.Builder<Class<?>, Class<?>> builder = ImmutableMap.<Class<?>, Class<?>>builder();
+    ImmutableMap.Builder<PTransformMatcher, PTransformOverrideFactory> ptoverrides =
+        ImmutableMap.builder();
+    // Create is implemented in terms of a Read, so it must precede the override to Read in
+    // streaming
+    ptoverrides.put(PTransformMatchers.emptyFlatten(), EmptyFlattenAsCreateFactory.instance());
     if (options.isStreaming()) {
-      builder.put(Combine.GloballyAsSingletonView.class,
-                  StreamingViewOverrides.StreamingCombineGloballyAsSingletonView.class);
-      builder.put(View.AsMap.class, StreamingViewOverrides.StreamingViewAsMap.class);
-      builder.put(View.AsMultimap.class, StreamingViewOverrides.StreamingViewAsMultimap.class);
-      builder.put(View.AsSingleton.class, StreamingViewOverrides.StreamingViewAsSingleton.class);
-      builder.put(View.AsList.class, StreamingViewOverrides.StreamingViewAsList.class);
-      builder.put(View.AsIterable.class, StreamingViewOverrides.StreamingViewAsIterable.class);
-      builder.put(Read.Unbounded.class, StreamingUnboundedRead.class);
-      builder.put(Read.Bounded.class, StreamingBoundedRead.class);
       // In streaming mode must use either the custom Pubsub unbounded source/sink or
       // defer to Windmill's built-in implementation.
-      builder.put(PubsubIO.Read.PubsubBoundedReader.class, UnsupportedIO.class);
-      builder.put(PubsubIO.Write.PubsubBoundedWriter.class, UnsupportedIO.class);
+      for (Class<? extends DoFn> unsupported :
+          ImmutableSet.of(PubsubBoundedReader.class, PubsubBoundedWriter.class)) {
+        ptoverrides.put(
+            PTransformMatchers.parDoWithFnType(unsupported),
+            UnsupportedOverrideFactory.withMessage(getUnsupportedMessage(unsupported, true)));
+      }
       if (options.getExperiments() == null
           || !options.getExperiments().contains("enable_custom_pubsub_source")) {
-        builder.put(PubsubUnboundedSource.class, StreamingPubsubIORead.class);
+        ptoverrides.put(
+            PTransformMatchers.classEqualTo(PubsubUnboundedSource.class),
+            new ReflectiveRootOverrideFactory(StreamingPubsubIORead.class, this));
       }
       if (options.getExperiments() == null
           || !options.getExperiments().contains("enable_custom_pubsub_sink")) {
-        builder.put(PubsubUnboundedSink.class, StreamingPubsubIOWrite.class);
+        ptoverrides.put(
+            PTransformMatchers.classEqualTo(PubsubUnboundedSink.class),
+            new StreamingPubsubIOWriteOverrideFactory(this));
       }
+      ptoverrides
+          .put(
+              // Streaming Bounded Read is implemented in terms of Streaming Unbounded Read, and
+              // must precede it
+              PTransformMatchers.classEqualTo(Read.Bounded.class),
+              new ReflectiveRootOverrideFactory(StreamingBoundedRead.class, this))
+          .put(
+              PTransformMatchers.classEqualTo(Read.Unbounded.class),
+              new ReflectiveRootOverrideFactory(StreamingUnboundedRead.class, this))
+          .put(
+              PTransformMatchers.classEqualTo(GloballyAsSingletonView.class),
+              new ReflectiveOneToOneOverrideFactory(
+                  StreamingViewOverrides.StreamingCombineGloballyAsSingletonView.class, this))
+          .put(
+              PTransformMatchers.classEqualTo(AsMap.class),
+              new ReflectiveOneToOneOverrideFactory(
+                  StreamingViewOverrides.StreamingViewAsMap.class, this))
+          .put(
+              PTransformMatchers.classEqualTo(AsMultimap.class),
+              new ReflectiveOneToOneOverrideFactory(
+                  StreamingViewOverrides.StreamingViewAsMultimap.class, this))
+          .put(
+              PTransformMatchers.classEqualTo(AsSingleton.class),
+              new ReflectiveOneToOneOverrideFactory(
+                  StreamingViewOverrides.StreamingViewAsSingleton.class, this))
+          .put(
+              PTransformMatchers.classEqualTo(AsList.class),
+              new ReflectiveOneToOneOverrideFactory(
+                  StreamingViewOverrides.StreamingViewAsList.class, this))
+          .put(
+              PTransformMatchers.classEqualTo(AsIterable.class),
+              new ReflectiveOneToOneOverrideFactory(
+                  StreamingViewOverrides.StreamingViewAsIterable.class, this));
     } else {
-      builder.put(Read.Unbounded.class, UnsupportedIO.class);
-      builder.put(Write.Bound.class, BatchWrite.class);
       // In batch mode must use the custom Pubsub bounded source/sink.
-      builder.put(PubsubUnboundedSource.class, UnsupportedIO.class);
-      builder.put(PubsubUnboundedSink.class, UnsupportedIO.class);
-      if (options.getExperiments() == null
-          || !options.getExperiments().contains("disable_ism_side_input")) {
-        builder.put(View.AsMap.class, BatchViewOverrides.BatchViewAsMap.class);
-        builder.put(View.AsMultimap.class, BatchViewOverrides.BatchViewAsMultimap.class);
-        builder.put(View.AsSingleton.class, BatchViewOverrides.BatchViewAsSingleton.class);
-        builder.put(View.AsList.class, BatchViewOverrides.BatchViewAsList.class);
-        builder.put(View.AsIterable.class, BatchViewOverrides.BatchViewAsIterable.class);
+      for (Class<? extends PTransform> unsupported :
+          ImmutableSet.of(PubsubUnboundedSink.class, PubsubUnboundedSource.class)) {
+        ptoverrides.put(
+            PTransformMatchers.classEqualTo(unsupported),
+            UnsupportedOverrideFactory.withMessage(getUnsupportedMessage(unsupported, false)));
       }
+      ptoverrides.put(
+          PTransformMatchers.classEqualTo(Read.Unbounded.class),
+          UnsupportedOverrideFactory.withMessage(
+              "The DataflowRunner in batch mode does not support 
Read.Unbounded"));
+      ptoverrides
+          // Write uses views internally
+          .put(PTransformMatchers.classEqualTo(Write.Bound.class), new 
BatchWriteFactory(this))
+          .put(
+              PTransformMatchers.classEqualTo(View.AsMap.class),
+              new 
ReflectiveOneToOneOverrideFactory(BatchViewOverrides.BatchViewAsMap.class, 
this))
+          .put(
+              PTransformMatchers.classEqualTo(View.AsMultimap.class),
+              new ReflectiveOneToOneOverrideFactory(
+                  BatchViewOverrides.BatchViewAsMultimap.class, this))
+          .put(
+              PTransformMatchers.classEqualTo(View.AsSingleton.class),
+              new ReflectiveOneToOneOverrideFactory(
+                  BatchViewOverrides.BatchViewAsSingleton.class, this))
+          .put(
+              PTransformMatchers.classEqualTo(View.AsList.class),
+              new 
ReflectiveOneToOneOverrideFactory(BatchViewOverrides.BatchViewAsList.class, 
this))
+          .put(
+              PTransformMatchers.classEqualTo(View.AsIterable.class),
+              new ReflectiveOneToOneOverrideFactory(
+                  BatchViewOverrides.BatchViewAsIterable.class, this));
     }
-    overrides = builder.build();
+    ptoverrides
+        // Order is important. Streaming views almost all use Combine internally.
+        .put(
+            PTransformMatchers.classEqualTo(Combine.GroupedValues.class),
+            new PrimitiveCombineGroupedValuesOverrideFactory());
+    overrides = ptoverrides.build();
   }
 
-  /**
-   * Applies the given transform to the input. For transforms with customized definitions
-   * for the Dataflow pipeline runner, the application is intercepted and modified here.
-   */
-  @Override
-  public <OutputT extends POutput, InputT extends PInput> OutputT apply(
-      PTransform<InputT, OutputT> transform, InputT input) {
-
-    if (Combine.GroupedValues.class.equals(transform.getClass())) {
-      // For both Dataflow runners (streaming and batch), GroupByKey and GroupedValues are
-      // primitives. Returning a primitive output instead of the expanded definition
-      // signals to the translator that translation is necessary.
-      @SuppressWarnings("unchecked")
-      PCollection<?> pc = (PCollection<?>) input;
-      @SuppressWarnings("unchecked")
-      OutputT outputT =
-          (OutputT)
-              PCollection.createPrimitiveOutputInternal(
-                  pc.getPipeline(), pc.getWindowingStrategy(), pc.isBounded());
-      return outputT;
-    } else if (Flatten.FlattenPCollectionList.class.equals(transform.getClass())
-        && ((PCollectionList<?>) input).size() == 0) {
-      // This can cause downstream coder inference to be screwy. Most of the time, that won't be
-      // hugely impactful, because there will never be any elements encoded with this coder;
-      // the issue stems from flattening this with another PCollection.
-      return (OutputT)
-          Pipeline.applyTransform(
-              input.getPipeline().begin(), Create.empty(VoidCoder.of()));
-    } else if (overrides.containsKey(transform.getClass())) {
-      // It is the responsibility of whoever constructs overrides to ensure this is type safe.
-      @SuppressWarnings("unchecked")
-      Class<PTransform<InputT, OutputT>> transformClass =
-          (Class<PTransform<InputT, OutputT>>) transform.getClass();
-
-      @SuppressWarnings("unchecked")
-      Class<PTransform<InputT, OutputT>> customTransformClass =
-          (Class<PTransform<InputT, OutputT>>) overrides.get(transform.getClass());
-
-      PTransform<InputT, OutputT> customTransform =
-          InstanceBuilder.ofType(customTransformClass)
-          .withArg(DataflowRunner.class, this)
-          .withArg(transformClass, transform)
+  private String getUnsupportedMessage(Class<?> unsupported, boolean streaming) {
+    return String.format(
+        "%s is not supported in %s",
+        NameUtils.approximateSimpleName(unsupported), streaming ? "streaming" : "batch");
+  }
+
+  private static class ReflectiveOneToOneOverrideFactory<
+          InputT extends PValue,
+          OutputT extends PValue,
+          TransformT extends PTransform<InputT, OutputT>>
+      extends SingleInputOutputOverrideFactory<InputT, OutputT, TransformT> {
+    private final Class<PTransform<InputT, OutputT>> replacement;
+    private final DataflowRunner runner;
+
+    private ReflectiveOneToOneOverrideFactory(
+        Class<PTransform<InputT, OutputT>> replacement, DataflowRunner runner) {
+      this.replacement = replacement;
+      this.runner = runner;
+    }
+
+    @Override
+    public PTransform<InputT, OutputT> getReplacementTransform(TransformT transform) {
+      return InstanceBuilder.ofType(replacement)
+          .withArg(DataflowRunner.class, runner)
+          .withArg((Class<PTransform<InputT, OutputT>>) transform.getClass(), transform)
           .build();
+    }
+  }
 
-      return Pipeline.applyTransform(input, customTransform);
-    } else {
-      return super.apply(transform, input);
+  private static class ReflectiveRootOverrideFactory<T>
+      implements PTransformOverrideFactory<
+          PBegin, PCollection<T>, PTransform<PInput, PCollection<T>>> {
+    private final Class<PTransform<PBegin, PCollection<T>>> replacement;
+    private final DataflowRunner runner;
+
+    private ReflectiveRootOverrideFactory(
+        Class<PTransform<PBegin, PCollection<T>>> replacement, DataflowRunner runner) {
+      this.replacement = replacement;
+      this.runner = runner;
+    }
+    @Override
+    public PTransform<PBegin, PCollection<T>> getReplacementTransform(
+        PTransform<PInput, PCollection<T>> transform) {
+      return InstanceBuilder.ofType(replacement)
+          .withArg(DataflowRunner.class, runner)
+          .withArg(
+              (Class<? super PTransform<PInput, PCollection<T>>>) transform.getClass(), transform)
+          .build();
+    }
+
+    @Override
+    public PBegin getInput(List<TaggedPValue> inputs, Pipeline p) {
+      return p.begin();
+    }
+
+    @Override
+    public Map<PValue, ReplacementOutput> mapOutputs(
+        List<TaggedPValue> outputs, PCollection<T> newOutput) {
+      return ReplacementOutputs.singleton(outputs, newOutput);
     }
   }
 
@@ -419,6 +505,7 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> {
   @Override
   public DataflowPipelineJob run(Pipeline pipeline) {
     logWarningIfPCollectionViewHasNonDeterministicKeyCoder(pipeline);
+    replaceTransforms(pipeline);
 
     LOG.info("Executing pipeline on the Dataflow Service, which will have 
billing implications "
         + "related to Google Compute Engine usage and other Google Cloud 
Services.");
@@ -594,6 +681,13 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> {
     return dataflowPipelineJob;
   }
 
+  @VisibleForTesting
+  void replaceTransforms(Pipeline pipeline) {
+    for (Map.Entry<PTransformMatcher, PTransformOverrideFactory> override : overrides.entrySet()) {
+      pipeline.replace(override.getKey(), override.getValue());
+    }
+  }
+
   /**
    * Returns the DataflowPipelineTranslator associated with this object.
    */
@@ -677,6 +771,30 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> {
     ptransformViewsWithNonDeterministicKeyCoders.add(ptransform);
   }
 
+  private class BatchWriteFactory<T>
+      implements PTransformOverrideFactory<PCollection<T>, PDone, Write.Bound<T>> {
+    private final DataflowRunner runner;
+    private BatchWriteFactory(DataflowRunner dataflowRunner) {
+      this.runner = dataflowRunner;
+    }
+
+    @Override
+    public PTransform<PCollection<T>, PDone> getReplacementTransform(Bound<T> transform) {
+      return new BatchWrite<>(runner, transform);
+    }
+
+    @Override
+    public PCollection<T> getInput(List<TaggedPValue> inputs, Pipeline p) {
+      return (PCollection<T>) Iterables.getOnlyElement(inputs).getValue();
+    }
+
+    @Override
+    public Map<PValue, ReplacementOutput> mapOutputs(
+        List<TaggedPValue> outputs, PDone newOutput) {
+      return Collections.emptyMap();
+    }
+  }
+
   /**
    * Specialized implementation which overrides
    * {@link org.apache.beam.sdk.io.Write.Bound Write.Bound} to provide Google
@@ -889,7 +1007,7 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> {
    * <p>In particular, if an UnboundedSource requires deduplication, then features of WindmillSink
    * are leveraged to do the deduplication.
    */
-  private static class StreamingUnboundedRead<T> extends PTransform<PInput, PCollection<T>> {
+  private static class StreamingUnboundedRead<T> extends PTransform<PBegin, PCollection<T>> {
     private final UnboundedSource<T, ?> source;
 
     /**
@@ -906,7 +1024,7 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> {
     }
 
     @Override
-    public final PCollection<T> expand(PInput input) {
+    public final PCollection<T> expand(PBegin input) {
       source.validate();
 
       if (source.requiresDeduping()) {
@@ -1064,72 +1182,6 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> {
     }
   }
 
-  /**
-   * Specialized expansion for unsupported IO transforms and DoFns that throws an error.
-   */
-  private static class UnsupportedIO<InputT extends PInput, OutputT extends POutput>
-      extends PTransform<InputT, OutputT> {
-    @Nullable
-    private PTransform<?, ?> transform;
-    @Nullable
-    private DoFn<?, ?> doFn;
-
-    /**
-     * Builds an instance of this class from the overridden transform.
-     */
-    @SuppressWarnings("unused") // used via reflection in DataflowRunner#apply()
-    public UnsupportedIO(DataflowRunner runner, Read.Unbounded<?> transform) {
-      this.transform = transform;
-    }
-
-    /**
-     * Builds an instance of this class from the overridden doFn.
-     */
-    @SuppressWarnings("unused") // used via reflection in DataflowRunner#apply()
-    public UnsupportedIO(DataflowRunner runner,
-                         PubsubIO.Read<?>.PubsubBoundedReader doFn) {
-      this.doFn = doFn;
-    }
-
-    /**
-     * Builds an instance of this class from the overridden doFn.
-     */
-    @SuppressWarnings("unused") // used via reflection in DataflowRunner#apply()
-    public UnsupportedIO(DataflowRunner runner,
-                         PubsubIO.Write<?>.PubsubBoundedWriter doFn) {
-      this.doFn = doFn;
-    }
-
-    /**
-     * Builds an instance of this class from the overridden transform.
-     */
-    @SuppressWarnings("unused") // used via reflection in DataflowRunner#apply()
-    public UnsupportedIO(DataflowRunner runner, PubsubUnboundedSource<?> transform) {
-      this.transform = transform;
-    }
-
-    /**
-     * Builds an instance of this class from the overridden transform.
-     */
-    @SuppressWarnings("unused") // used via reflection in DataflowRunner#apply()
-    public UnsupportedIO(DataflowRunner runner, PubsubUnboundedSink<?> transform) {
-      this.transform = transform;
-    }
-
-
-    @Override
-    public OutputT expand(InputT input) {
-      String mode = input.getPipeline().getOptions().as(StreamingOptions.class).isStreaming()
-          ? "streaming" : "batch";
-      String name =
-          transform == null
-              ? NameUtils.approximateSimpleName(doFn)
-              : NameUtils.approximatePTransformName(transform.getClass());
-      throw new UnsupportedOperationException(
-          String.format("The DataflowRunner in %s mode does not support %s.", mode, name));
-    }
-  }
-
   @Override
   public String toString() {
     return "DataflowRunner#" + options.getJobName();
@@ -1192,4 +1244,72 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> {
 
     throw new IllegalArgumentException("Could not find running job named " + jobName);
   }
+
+  static class CombineGroupedValues<K, InputT, OutputT>
+      extends PTransform<PCollection<KV<K, Iterable<InputT>>>, PCollection<KV<K, OutputT>>> {
+    private final Combine.GroupedValues<K, InputT, OutputT> original;
+
+    CombineGroupedValues(GroupedValues<K, InputT, OutputT> original) {
+      this.original = original;
+    }
+
+    @Override
+    public PCollection<KV<K, OutputT>> expand(PCollection<KV<K, Iterable<InputT>>> input) {
+      return PCollection.createPrimitiveOutputInternal(
+          input.getPipeline(), input.getWindowingStrategy(), input.isBounded());
+    }
+
+    public Combine.GroupedValues<K, InputT, OutputT> getOriginalCombine() {
+      return original;
+    }
+  }
+
+  private static class PrimitiveCombineGroupedValuesOverrideFactory<K, InputT, OutputT>
+      implements PTransformOverrideFactory<
+          PCollection<KV<K, Iterable<InputT>>>, PCollection<KV<K, OutputT>>,
+          Combine.GroupedValues<K, InputT, OutputT>> {
+    @Override
+    public PTransform<PCollection<KV<K, Iterable<InputT>>>, PCollection<KV<K, OutputT>>>
+        getReplacementTransform(GroupedValues<K, InputT, OutputT> transform) {
+      return new CombineGroupedValues<>(transform);
+    }
+
+    @Override
+    public PCollection<KV<K, Iterable<InputT>>> getInput(
+        List<TaggedPValue> inputs, Pipeline p) {
+      return (PCollection<KV<K, Iterable<InputT>>>) Iterables.getOnlyElement(inputs).getValue();
+    }
+
+    @Override
+    public Map<PValue, ReplacementOutput> mapOutputs(
+        List<TaggedPValue> outputs, PCollection<KV<K, OutputT>> newOutput) {
+      return ReplacementOutputs.singleton(outputs, newOutput);
+    }
+  }
+
+  private class StreamingPubsubIOWriteOverrideFactory<T>
+      implements PTransformOverrideFactory<PCollection<T>, PDone, PubsubUnboundedSink<T>> {
+    private final DataflowRunner runner;
+
+    private StreamingPubsubIOWriteOverrideFactory(DataflowRunner runner) {
+      this.runner = runner;
+    }
+
+    @Override
+    public PTransform<PCollection<T>, PDone> getReplacementTransform(
+        PubsubUnboundedSink<T> transform) {
+      return new StreamingPubsubIOWrite<>(runner, transform);
+    }
+
+    @Override
+    public PCollection<T> getInput(List<TaggedPValue> inputs, Pipeline p) {
+      return (PCollection<T>) Iterables.getOnlyElement(inputs).getValue();
+    }
+
+    @Override
+    public Map<PValue, ReplacementOutput> mapOutputs(List<TaggedPValue> outputs, PDone newOutput) {
+      return Collections.emptyMap();
+    }
+  }
+
 }
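
One subtlety: replaceTransforms fires overrides in registration order, which the "must precede" comments in the constructor depend on. That rests on Guava's documented guarantee that ImmutableMap iterates in insertion order; a standalone check of that assumption (the keys and values here are illustrative strings, not the real matchers):

    import com.google.common.collect.ImmutableMap;

    public class OverrideOrderCheck {
      public static void main(String[] args) {
        ImmutableMap<String, String> overrides =
            ImmutableMap.<String, String>builder()
                .put("emptyFlatten", "EmptyFlattenAsCreateFactory") // registered first, fires first
                .put("Read.Bounded", "StreamingBoundedRead")        // must precede Read.Unbounded
                .put("Read.Unbounded", "StreamingUnboundedRead")
                .build();
        // Prints keys in exactly the order they were put(), so earlier overrides
        // rewrite the pipeline before later ones are matched.
        overrides.keySet().forEach(System.out::println);
      }
    }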

http://git-wip-us.apache.org/repos/asf/beam/blob/a6cd8c38/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StreamingViewOverrides.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StreamingViewOverrides.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StreamingViewOverrides.java
index 6bd0cca..bab115f 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StreamingViewOverrides.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StreamingViewOverrides.java
@@ -349,5 +349,4 @@ class StreamingViewOverrides {
       return ListCoder.of(inputCoder);
     }
   }
-
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/a6cd8c38/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java
index 5d13c3e..d4271e5 100644
--- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java
+++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java
@@ -456,11 +456,13 @@ public class DataflowPipelineTranslatorTest implements Serializable {
         .apply(ParDo.of(new NoOpFn()))
         .apply(new EmbeddedTransform(predefinedStep.clone()))
         .apply(ParDo.of(new NoOpFn()));
+    DataflowRunner runner = DataflowRunner.fromOptions(options);
+    runner.replaceTransforms(pipeline);
     Job job =
         translator
             .translate(
                 pipeline,
-                (DataflowRunner) pipeline.getRunner(),
+                runner,
                 Collections.<DataflowPackage>emptyList())
             .getJob();
     assertAllStepOutputsHaveUniqueIds(job);
@@ -511,11 +513,13 @@ public class DataflowPipelineTranslatorTest implements Serializable {
     pipeline.apply("ReadMyFile", TextIO.Read.from("gs://bucket/in"))
         .apply(stepName, ParDo.of(new NoOpFn()))
         .apply("WriteMyFile", TextIO.Write.to("gs://bucket/out"));
+    DataflowRunner runner = DataflowRunner.fromOptions(options);
+    runner.replaceTransforms(pipeline);
     Job job =
         translator
             .translate(
                 pipeline,
-                (DataflowRunner) pipeline.getRunner(),
+                runner,
                 Collections.<DataflowPackage>emptyList())
             .getJob();
 
@@ -833,11 +837,13 @@ public class DataflowPipelineTranslatorTest implements Serializable {
     Pipeline pipeline = Pipeline.create(options);
     pipeline.apply(Create.of(1))
         .apply(View.<Integer>asSingleton());
+    DataflowRunner runner = DataflowRunner.fromOptions(options);
+    runner.replaceTransforms(pipeline);
     Job job =
         translator
             .translate(
                 pipeline,
-                (DataflowRunner) pipeline.getRunner(),
+                runner,
                 Collections.<DataflowPackage>emptyList())
             .getJob();
     assertAllStepOutputsHaveUniqueIds(job);
@@ -867,13 +873,11 @@ public class DataflowPipelineTranslatorTest implements Serializable {
     Pipeline pipeline = Pipeline.create(options);
     pipeline.apply(Create.of(1, 2, 3))
         .apply(View.<Integer>asIterable());
+
+    DataflowRunner runner = DataflowRunner.fromOptions(options);
+    runner.replaceTransforms(pipeline);
     Job job =
-        translator
-            .translate(
-                pipeline,
-                (DataflowRunner) pipeline.getRunner(),
-                Collections.<DataflowPackage>emptyList())
-            .getJob();
+        translator.translate(pipeline, runner, Collections.<DataflowPackage>emptyList()).getJob();
     assertAllStepOutputsHaveUniqueIds(job);
 
     List<Step> steps = job.getSteps();

http://git-wip-us.apache.org/repos/asf/beam/blob/a6cd8c38/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java
index 4719217..a788077 100644
--- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java
+++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java
@@ -560,11 +560,12 @@ public class DataflowRunnerTest {
   public void testNonGcsFilePathInWriteFailure() throws IOException {
     Pipeline p = buildDataflowPipeline(buildPipelineOptions());
 
-    PCollection<String> pc = p.apply("ReadMyGcsFile", 
TextIO.Read.from("gs://bucket/object"));
+    p.apply("ReadMyGcsFile", TextIO.Read.from("gs://bucket/object"))
+        .apply("WriteMyNonGcsFile", TextIO.Write.to("/tmp/file"));
 
     thrown.expect(IllegalArgumentException.class);
     thrown.expectMessage(containsString("Expected a valid 'gs://' path but was given"));
-    pc.apply("WriteMyNonGcsFile", TextIO.Write.to("/tmp/file"));
+    p.run();
   }
 
   @Test
@@ -586,10 +587,11 @@ public class DataflowRunnerTest {
   public void testMultiSlashGcsFileWritePath() throws IOException {
     Pipeline p = buildDataflowPipeline(buildPipelineOptions());
     PCollection<String> pc = p.apply("ReadMyGcsFile", 
TextIO.Read.from("gs://bucket/object"));
+    pc.apply("WriteInvalidGcsFile", TextIO.Write.to("gs://bucket/tmp//file"));
 
     thrown.expect(IllegalArgumentException.class);
     thrown.expectMessage("consecutive slashes");
-    pc.apply("WriteInvalidGcsFile", TextIO.Write.to("gs://bucket/tmp//file"));
+    p.run();
   }
 
   @Test

http://git-wip-us.apache.org/repos/asf/beam/blob/a6cd8c38/sdks/java/core/src/main/java/org/apache/beam/sdk/util/NameUtils.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/NameUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/NameUtils.java
index 72179a3..c67ccca 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/NameUtils.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/NameUtils.java
@@ -128,7 +128,12 @@ public class NameUtils {
       return ((NameOverride) object).getNameOverride();
     }
 
-    Class<?> clazz = object.getClass();
+    Class<?> clazz;
+    if (object instanceof Class) {
+      clazz = (Class<?>) object;
+    } else {
+      clazz = object.getClass();
+    }
     if (clazz.isAnonymousClass()) {
       return anonymousValue;
     }
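
This NameUtils tweak lets approximateSimpleName accept a Class object directly (getUnsupportedMessage above passes one); previously a Class argument would have been named after java.lang.Class itself. A small hedged example of the expected behavior:

    import org.apache.beam.sdk.util.NameUtils;

    public class NameCheck {
      static class Widget {}

      public static void main(String[] args) {
        // After this patch both calls should print "Widget"; before it, the first
        // would have described the Class object (java.lang.Class) instead.
        System.out.println(NameUtils.approximateSimpleName(Widget.class));
        System.out.println(NameUtils.approximateSimpleName(new Widget()));
      }
    }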
