Port DirectRunner WriteFiles override to SDK-agnostic APIs

Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/ed6bd18b
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/ed6bd18b
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/ed6bd18b

Branch: refs/heads/master
Commit: ed6bd18bffe8a51d5fc2a59ff9aaa731b196d58a
Parents: 02dbaef
Author: Kenneth Knowles <k...@google.com>
Authored: Fri May 26 16:07:45 2017 -0700
Committer: Kenneth Knowles <k...@google.com>
Committed: Fri Jun 9 19:56:52 2017 -0700

----------------------------------------------------------------------
 .../core/construction/PTransformMatchers.java   | 17 ++++++++---
 .../direct/WriteWithShardingFactory.java        | 30 ++++++++++++++------
 .../direct/WriteWithShardingFactoryTest.java    | 26 +++++++++++------
 3 files changed, 52 insertions(+), 21 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/ed6bd18b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java
index c339891..0d27241 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java
@@ -17,13 +17,14 @@
  */
 package org.apache.beam.runners.core.construction;
 
+import static 
org.apache.beam.runners.core.construction.PTransformTranslation.WRITE_FILES_TRANSFORM_URN;
+
 import com.google.common.base.MoreObjects;
 import java.io.IOException;
 import java.util.HashSet;
 import java.util.Set;
 import org.apache.beam.sdk.annotations.Experimental;
 import org.apache.beam.sdk.annotations.Experimental.Kind;
-import org.apache.beam.sdk.io.WriteFiles;
 import org.apache.beam.sdk.runners.AppliedPTransform;
 import org.apache.beam.sdk.runners.PTransformMatcher;
 import org.apache.beam.sdk.transforms.DoFn;
@@ -359,10 +360,18 @@ public class PTransformMatchers {
     return new PTransformMatcher() {
       @Override
       public boolean matches(AppliedPTransform<?, ?, ?> application) {
-        if (PTransformTranslation.WRITE_FILES_TRANSFORM_URN.equals(
+        if (WRITE_FILES_TRANSFORM_URN.equals(
             
PTransformTranslation.urnForTransformOrNull(application.getTransform()))) {
-          WriteFiles write = (WriteFiles) application.getTransform();
-          return write.getSharding() == null && write.getNumShards() == null;
+          try {
+            return WriteFilesTranslation.isRunnerDeterminedSharding(
+                (AppliedPTransform) application);
+          } catch (IOException exc) {
+            throw new RuntimeException(
+                String.format(
+                    "Transform with URN %s failed to parse: %s",
+                    WRITE_FILES_TRANSFORM_URN, application.getTransform()),
+                exc);
+          }
         }
         return false;
       }

http://git-wip-us.apache.org/repos/asf/beam/blob/ed6bd18b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java
index 65a5a19..d8734a1 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java
@@ -21,11 +21,13 @@ package org.apache.beam.runners.direct;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Supplier;
 import com.google.common.base.Suppliers;
+import java.io.IOException;
 import java.io.Serializable;
 import java.util.Collections;
 import java.util.Map;
 import java.util.concurrent.ThreadLocalRandom;
 import org.apache.beam.runners.core.construction.PTransformReplacements;
+import org.apache.beam.runners.core.construction.WriteFilesTranslation;
 import org.apache.beam.sdk.io.WriteFiles;
 import org.apache.beam.sdk.runners.AppliedPTransform;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
@@ -43,23 +45,33 @@ import org.apache.beam.sdk.values.PValue;
 import org.apache.beam.sdk.values.TupleTag;
 
 /**
- * A {@link PTransformOverrideFactory} that overrides {@link WriteFiles}
- * {@link PTransform PTransforms} with an unspecified number of shards with a 
write with a
- * specified number of shards. The number of shards is the log base 10 of the 
number of input
- * records, with up to 2 additional shards.
+ * A {@link PTransformOverrideFactory} that overrides {@link WriteFiles} 
{@link PTransform
+ * PTransforms} with an unspecified number of shards with a write with a 
specified number of shards.
+ * The number of shards is the log base 10 of the number of input records, 
with up to 2 additional
+ * shards.
  */
 class WriteWithShardingFactory<InputT>
-    implements PTransformOverrideFactory<PCollection<InputT>, PDone, 
WriteFiles<InputT>> {
+    implements PTransformOverrideFactory<
+        PCollection<InputT>, PDone, PTransform<PCollection<InputT>, PDone>> {
   static final int MAX_RANDOM_EXTRA_SHARDS = 3;
   @VisibleForTesting static final int MIN_SHARDS_FOR_LOG = 3;
 
   @Override
   public PTransformReplacement<PCollection<InputT>, PDone> 
getReplacementTransform(
-      AppliedPTransform<PCollection<InputT>, PDone, WriteFiles<InputT>> 
transform) {
+      AppliedPTransform<PCollection<InputT>, PDone, 
PTransform<PCollection<InputT>, PDone>>
+          transform) {
 
-    return PTransformReplacement.of(
-        PTransformReplacements.getSingletonMainInput(transform),
-        transform.getTransform().withSharding(new 
LogElementShardsWithDrift<InputT>()));
+    try {
+      WriteFiles<InputT> replacement = 
WriteFiles.to(WriteFilesTranslation.getSink(transform));
+      if (WriteFilesTranslation.isWindowedWrites(transform)) {
+        replacement = replacement.withWindowedWrites();
+      }
+      return PTransformReplacement.of(
+          PTransformReplacements.getSingletonMainInput(transform),
+          replacement.withSharding(new LogElementShardsWithDrift<InputT>()));
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
   }
 
   @Override

http://git-wip-us.apache.org/repos/asf/beam/blob/ed6bd18b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java
index a88d95e..41d671f 100644
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java
@@ -30,6 +30,7 @@ import static org.junit.Assert.assertThat;
 import java.io.File;
 import java.io.FileReader;
 import java.io.Reader;
+import java.io.Serializable;
 import java.nio.CharBuffer;
 import java.util.ArrayList;
 import java.util.Collections;
@@ -53,6 +54,7 @@ import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.DoFnTester;
+import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionView;
@@ -71,11 +73,17 @@ import org.junit.runners.JUnit4;
  * Tests for {@link WriteWithShardingFactory}.
  */
 @RunWith(JUnit4.class)
-public class WriteWithShardingFactoryTest {
+public class WriteWithShardingFactoryTest implements Serializable {
+
   private static final int INPUT_SIZE = 10000;
-  @Rule public TemporaryFolder tmp = new TemporaryFolder();
-  private WriteWithShardingFactory<Object> factory = new 
WriteWithShardingFactory<>();
-  @Rule public final TestPipeline p = 
TestPipeline.create().enableAbandonedNodeEnforcement(false);
+
+  @Rule public transient TemporaryFolder tmp = new TemporaryFolder();
+
+  private transient WriteWithShardingFactory<Object> factory = new 
WriteWithShardingFactory<>();
+
+  @Rule
+  public final transient TestPipeline p =
+      TestPipeline.create().enableAbandonedNodeEnforcement(false);
 
   @Test
   public void dynamicallyReshardedWrite() throws Exception {
@@ -135,7 +143,8 @@ public class WriteWithShardingFactoryTest {
             DefaultFilenamePolicy.DEFAULT_UNWINDOWED_SHARD_TEMPLATE,
             "",
             false);
-    WriteFiles<Object> original =
+
+    PTransform<PCollection<Object>, PDone> original =
         WriteFiles.to(
             new FileBasedSink<Object>(StaticValueProvider.of(outputDirectory), 
policy) {
               @Override
@@ -146,9 +155,10 @@ public class WriteWithShardingFactoryTest {
     @SuppressWarnings("unchecked")
     PCollection<Object> objs = (PCollection) 
p.apply(Create.empty(VoidCoder.of()));
 
-    AppliedPTransform<PCollection<Object>, PDone, WriteFiles<Object>> 
originalApplication =
-        AppliedPTransform.of(
-            "write", objs.expand(), Collections.<TupleTag<?>, 
PValue>emptyMap(), original, p);
+    AppliedPTransform<PCollection<Object>, PDone, 
PTransform<PCollection<Object>, PDone>>
+        originalApplication =
+            AppliedPTransform.of(
+                "write", objs.expand(), Collections.<TupleTag<?>, 
PValue>emptyMap(), original, p);
 
     assertThat(
         factory.getReplacementTransform(originalApplication).getTransform(),

Reply via email to