Honor user-requested shard limits for AvroIO.Write on DirectPipelineRunner

During the migration of AvroIO to a custom sink, shard controls
were removed for the DirectPipelineRunner. This change adds them
back.
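
For context, a minimal sketch of the user-facing behavior this restores; the
output path and element values below are illustrative only, not part of the
change:

    // Illustrative sketch: with this change, DirectPipelineRunner honors
    // withNumShards(3) and produces exactly three output files instead of
    // collapsing everything into a single shard.
    Pipeline p = DirectPipeline.createForTest();
    p.apply(Create.of("a", "b", "c"))
     .apply(AvroIO.Write.withSchema(String.class)
                        .to("/tmp/out")
                        .withNumShards(3)
                        .withSuffix(".avro"));
    p.run();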

----Release Notes----

[]
-------------
Created by MOE: https://github.com/google/moe
MOE_MIGRATED_REVID=115515647


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/510a55db
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/510a55db
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/510a55db

Branch: refs/heads/master
Commit: 510a55dbbf9b6d1a94817f7e8e78e8211dd559a4
Parents: 8b5257f
Author: lcwik <lc...@google.com>
Authored: Wed Feb 24 18:01:53 2016 -0800
Committer: Davor Bonaci <davorbon...@users.noreply.github.com>
Committed: Thu Feb 25 23:58:28 2016 -0800

----------------------------------------------------------------------
 .../sdk/runners/DirectPipelineRunner.java       | 59 ++++++++++++++++++++
 .../sdk/runners/DirectPipelineRunnerTest.java   | 53 ++++++++++++++++++
 2 files changed, 112 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/510a55db/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunner.java
----------------------------------------------------------------------
diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunner.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunner.java
index 4543b5a..872cfef 100644
--- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunner.java
+++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunner.java
@@ -25,6 +25,7 @@ import com.google.cloud.dataflow.sdk.PipelineResult;
 import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException;
 import com.google.cloud.dataflow.sdk.coders.Coder;
 import com.google.cloud.dataflow.sdk.coders.ListCoder;
+import com.google.cloud.dataflow.sdk.io.AvroIO;
 import com.google.cloud.dataflow.sdk.io.FileBasedSink;
 import com.google.cloud.dataflow.sdk.io.TextIO;
 import com.google.cloud.dataflow.sdk.options.DirectPipelineOptions;
@@ -240,6 +241,8 @@ public class DirectPipelineRunner
       return (OutputT) applyTestCombine((Combine.GroupedValues) transform, (PCollection) input);
     } else if (transform instanceof TextIO.Write.Bound) {
       return (OutputT) applyTextIOWrite((TextIO.Write.Bound) transform, (PCollection<?>) input);
+    } else if (transform instanceof AvroIO.Write.Bound) {
+      return (OutputT) applyAvroIOWrite((AvroIO.Write.Bound) transform, (PCollection<?>) input);
     } else {
       return super.apply(transform, input);
     }
@@ -343,6 +346,62 @@ public class DirectPipelineRunner
   }
 
   /**
+   * Applies AvroIO.Write honoring user-requested sharding controls (i.e. withNumShards)
+   * by applying a partition function based upon the number of shards the user requested.
+   */
+  private static class DirectAvroIOWrite<T> extends PTransform<PCollection<T>, PDone> {
+    private final AvroIO.Write.Bound<T> transform;
+
+    private DirectAvroIOWrite(AvroIO.Write.Bound<T> transform) {
+      this.transform = transform;
+    }
+
+    @Override
+    public PDone apply(PCollection<T> input) {
+      checkState(transform.getNumShards() > 1,
+          "DirectAvroIOWrite is expected to only be used when sharding controls are required.");
+
+      // Evenly distribute all the elements across the partitions.
+      PCollectionList<T> partitionedElements =
+          input.apply(Partition.of(transform.getNumShards(),
+                                   new ElementProcessingOrderPartitionFn<T>()));
+
+      // For each input PCollection partition, create a write transform that represents
+      // one of the specific shards.
+      for (int i = 0; i < transform.getNumShards(); ++i) {
+        /*
+         * This logic mirrors the file naming strategy within
+         * {@link FileBasedSink#generateDestinationFilenames()}
+         */
+        String outputFilename = IOChannelUtils.constructName(
+            transform.getFilenamePrefix(),
+            transform.getShardNameTemplate(),
+            getFileExtension(transform.getFilenameSuffix()),
+            i,
+            transform.getNumShards());
+
+        String transformName = String.format("%s(Shard:%s)", transform.getName(), i);
+        partitionedElements.get(i).apply(transformName,
+            transform.withNumShards(1).withShardNameTemplate("").withSuffix("").to(outputFilename));
+      }
+      return PDone.in(input.getPipeline());
+    }
+  }
+
+  /**
+   * Applies the override for AvroIO.Write.Bound if the user requested a shard count
+   * greater than one.
+   */
+  private <T> PDone applyAvroIOWrite(AvroIO.Write.Bound<T> transform, PCollection<T> input) {
+    if (transform.getNumShards() <= 1) {
+      // By default, the DirectPipelineRunner outputs to only 1 shard. Since the user never
+      // requested sharding controls greater than 1, we default to outputting to 1 file.
+      return super.apply(transform.withNumShards(1), input);
+    }
+    return input.apply(new DirectAvroIOWrite<>(transform));
+  }
+
+  /**
    * The implementation may split the {@link KeyedCombineFn} into ADD, MERGE and EXTRACT phases (
    * see {@code com.google.cloud.dataflow.sdk.runners.worker.CombineValuesFn}). In order to emulate
    * this for the {@link DirectPipelineRunner} and provide an experience closer to the service, go

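The partition function referenced above, ElementProcessingOrderPartitionFn, is not
included in this diff. A minimal sketch of what it plausibly does, assuming a simple
round-robin assignment in element processing order (the field name below is assumed):

    // Assumed implementation, not taken from this commit: hands elements to
    // partitions round-robin in processing order, so each of the requested
    // shards receives an even share of the input.
    private static class ElementProcessingOrderPartitionFn<T>
        implements Partition.PartitionFn<T> {
      private int elementNumber;

      @Override
      public int partitionFor(T elem, int numPartitions) {
        return elementNumber++ % numPartitions;
      }
    }
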
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/510a55db/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunnerTest.java
----------------------------------------------------------------------
diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunnerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunnerTest.java
index 4a0f91c..6524e14 100644
--- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunnerTest.java
+++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunnerTest.java
@@ -25,8 +25,10 @@ import static org.junit.Assert.assertThat;
 
 import com.google.cloud.dataflow.sdk.Pipeline;
 import com.google.cloud.dataflow.sdk.coders.AtomicCoder;
+import com.google.cloud.dataflow.sdk.coders.AvroCoder;
 import com.google.cloud.dataflow.sdk.coders.Coder;
 import com.google.cloud.dataflow.sdk.coders.CoderException;
+import com.google.cloud.dataflow.sdk.io.AvroIO;
 import com.google.cloud.dataflow.sdk.io.ShardNameTemplate;
 import com.google.cloud.dataflow.sdk.io.TextIO;
 import com.google.cloud.dataflow.sdk.options.DirectPipelineOptions;
@@ -36,8 +38,10 @@ import com.google.cloud.dataflow.sdk.transforms.Create;
 import com.google.cloud.dataflow.sdk.transforms.DoFn;
 import com.google.cloud.dataflow.sdk.transforms.ParDo;
 import com.google.cloud.dataflow.sdk.util.IOChannelUtils;
+import com.google.common.collect.Iterables;
 import com.google.common.io.Files;
 
+import org.apache.avro.file.DataFileReader;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
@@ -154,4 +158,53 @@ public class DirectPipelineRunnerTest implements Serializable {
 
     assertThat(allContents, containsInAnyOrder(expectedElements));
   }
+
+  @Test
+  public void testAvroIOWriteWithDefaultShardingStrategy() throws Exception {
+    String prefix = IOChannelUtils.resolve(Files.createTempDir().toString(), "output");
+    Pipeline p = DirectPipeline.createForTest();
+    String[] expectedElements = new String[]{ "a", "b", "c", "d", "e", "f", "g", "h", "i" };
+    p.apply(Create.of(expectedElements))
+     .apply(AvroIO.Write.withSchema(String.class).to(prefix).withSuffix(".avro"));
+    p.run();
+
+    String filename =
+        IOChannelUtils.constructName(prefix, ShardNameTemplate.INDEX_OF_MAX, ".avro", 0, 1);
+    List<String> fileContents = new ArrayList<>();
+    Iterables.addAll(fileContents, DataFileReader.openReader(
+        new File(filename), AvroCoder.of(String.class).createDatumReader()));
+
+    // Ensure that the file got at least one record
+    assertFalse(fileContents.isEmpty());
+
+    assertThat(fileContents, containsInAnyOrder(expectedElements));
+  }
+
+  @Test
+  public void testAvroIOWriteWithLimitedNumberOfShards() throws Exception {
+    final int numShards = 3;
+    String prefix = IOChannelUtils.resolve(Files.createTempDir().toString(), "shardedOutput");
+    Pipeline p = DirectPipeline.createForTest();
+    String[] expectedElements = new String[]{ "a", "b", "c", "d", "e", "f", "g", "h", "i" };
+    p.apply(Create.of(expectedElements))
+     .apply(AvroIO.Write.withSchema(String.class).to(prefix)
+                        .withNumShards(numShards).withSuffix(".avro"));
+    p.run();
+
+    List<String> allContents = new ArrayList<>();
+    for (int i = 0; i < numShards; ++i) {
+      String shardFileName =
+          IOChannelUtils.constructName(prefix, ShardNameTemplate.INDEX_OF_MAX, ".avro", i, numShards);
+      List<String> shardFileContents = new ArrayList<>();
+      Iterables.addAll(shardFileContents, DataFileReader.openReader(
+          new File(shardFileName), AvroCoder.of(String.class).createDatumReader()));
+
+      // Ensure that each file got at least one record
+      assertFalse(shardFileContents.isEmpty());
+
+      allContents.addAll(shardFileContents);
+    }
+
+    assertThat(allContents, containsInAnyOrder(expectedElements));
+  }
 }
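
The expected file names in both tests are reconstructed with
IOChannelUtils.constructName. Assuming ShardNameTemplate.INDEX_OF_MAX expands to
"-SSSSS-of-NNNNN", the sharded test above looks for names such as:

    // Illustration of the names the sharded test reconstructs:
    String name = IOChannelUtils.constructName(
        "shardedOutput", ShardNameTemplate.INDEX_OF_MAX, ".avro", 0, 3);
    // name -> "shardedOutput-00000-of-00003.avro"; shards 1 and 2 follow
    // the same pattern with indices 00001 and 00002.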
