Honor user-requested shard limits for AvroIO.Write on the DirectPipelineRunner

During the migration to a custom sink within AvroIO, the shard controls were removed for the DirectPipelineRunner. This change adds them back.
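For context, a minimal sketch of the user-facing behavior this restores. The element values, output path, and shard count below are illustrative, not part of this commit:

    // With this change, the DirectPipelineRunner honors withNumShards(3) and
    // produces three output files rather than collapsing everything into one.
    DirectPipeline p = DirectPipeline.createForTest();
    p.apply(Create.of("a", "b", "c", "d", "e", "f"))
        .apply(AvroIO.Write.withSchema(String.class)
            .to("/tmp/output")
            .withNumShards(3)
            .withSuffix(".avro"));
    p.run();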
----Release Notes----
[]
-------------
Created by MOE: https://github.com/google/moe
MOE_MIGRATED_REVID=115515647


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/510a55db
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/510a55db
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/510a55db

Branch: refs/heads/master
Commit: 510a55dbbf9b6d1a94817f7e8e78e8211dd559a4
Parents: 8b5257f
Author: lcwik <lc...@google.com>
Authored: Wed Feb 24 18:01:53 2016 -0800
Committer: Davor Bonaci <davorbon...@users.noreply.github.com>
Committed: Thu Feb 25 23:58:28 2016 -0800

----------------------------------------------------------------------
 .../sdk/runners/DirectPipelineRunner.java       | 59 ++++++++++++++++++++
 .../sdk/runners/DirectPipelineRunnerTest.java   | 53 ++++++++++++++++++
 2 files changed, 112 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/510a55db/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunner.java
----------------------------------------------------------------------
diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunner.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunner.java
index 4543b5a..872cfef 100644
--- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunner.java
+++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunner.java
@@ -25,6 +25,7 @@ import com.google.cloud.dataflow.sdk.PipelineResult;
 import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException;
 import com.google.cloud.dataflow.sdk.coders.Coder;
 import com.google.cloud.dataflow.sdk.coders.ListCoder;
+import com.google.cloud.dataflow.sdk.io.AvroIO;
 import com.google.cloud.dataflow.sdk.io.FileBasedSink;
 import com.google.cloud.dataflow.sdk.io.TextIO;
 import com.google.cloud.dataflow.sdk.options.DirectPipelineOptions;
@@ -240,6 +241,8 @@ public class DirectPipelineRunner
       return (OutputT) applyTestCombine((Combine.GroupedValues) transform, (PCollection) input);
     } else if (transform instanceof TextIO.Write.Bound) {
       return (OutputT) applyTextIOWrite((TextIO.Write.Bound) transform, (PCollection<?>) input);
+    } else if (transform instanceof AvroIO.Write.Bound) {
+      return (OutputT) applyAvroIOWrite((AvroIO.Write.Bound) transform, (PCollection<?>) input);
     } else {
       return super.apply(transform, input);
     }
@@ -343,6 +346,62 @@ public class DirectPipelineRunner
   }
 
   /**
+   * Applies AvroIO.Write honoring user requested sharding controls (i.e. withNumShards)
+   * by applying a partition function based upon the number of shards the user requested.
+   */
+  private static class DirectAvroIOWrite<T> extends PTransform<PCollection<T>, PDone> {
+    private final AvroIO.Write.Bound<T> transform;
+
+    private DirectAvroIOWrite(AvroIO.Write.Bound<T> transform) {
+      this.transform = transform;
+    }
+
+    @Override
+    public PDone apply(PCollection<T> input) {
+      checkState(transform.getNumShards() > 1,
+          "DirectAvroIOWrite is expected to only be used when sharding controls are required.");
+
+      // Evenly distribute all the elements across the partitions.
+      PCollectionList<T> partitionedElements =
+          input.apply(Partition.of(transform.getNumShards(),
+              new ElementProcessingOrderPartitionFn<T>()));
+
+      // For each input PCollection partition, create a write transform that represents
+      // one of the specific shards.
+      for (int i = 0; i < transform.getNumShards(); ++i) {
+        /*
+         * This logic mirrors the file naming strategy within
+         * {@link FileBasedSink#generateDestinationFilenames()}
+         */
+        String outputFilename = IOChannelUtils.constructName(
+            transform.getFilenamePrefix(),
+            transform.getShardNameTemplate(),
+            getFileExtension(transform.getFilenameSuffix()),
+            i,
+            transform.getNumShards());
+
+        String transformName = String.format("%s(Shard:%s)", transform.getName(), i);
+        partitionedElements.get(i).apply(transformName,
+            transform.withNumShards(1).withShardNameTemplate("").withSuffix("").to(outputFilename));
+      }
+      return PDone.in(input.getPipeline());
+    }
+  }
+
+  /**
+   * Apply the override for AvroIO.Write.Bound if the user requested sharding controls
+   * greater than one.
+   */
+  private <T> PDone applyAvroIOWrite(AvroIO.Write.Bound<T> transform, PCollection<T> input) {
+    if (transform.getNumShards() <= 1) {
+      // By default, the DirectPipelineRunner outputs to only 1 shard. Since the user never
+      // requested sharding controls greater than 1, we default to outputting to 1 file.
+      return super.apply(transform.withNumShards(1), input);
+    }
+    return input.apply(new DirectAvroIOWrite<>(transform));
+  }
+
+  /**
    * The implementation may split the {@link KeyedCombineFn} into ADD, MERGE and EXTRACT phases (
    * see {@code com.google.cloud.dataflow.sdk.runners.worker.CombineValuesFn}). In order to emulate
    * this for the {@link DirectPipelineRunner} and provide an experience closer to the service, go

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/510a55db/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunnerTest.java
----------------------------------------------------------------------
diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunnerTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunnerTest.java
index 4a0f91c..6524e14 100644
--- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunnerTest.java
+++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/DirectPipelineRunnerTest.java
@@ -25,8 +25,10 @@ import static org.junit.Assert.assertThat;
 
 import com.google.cloud.dataflow.sdk.Pipeline;
 import com.google.cloud.dataflow.sdk.coders.AtomicCoder;
+import com.google.cloud.dataflow.sdk.coders.AvroCoder;
 import com.google.cloud.dataflow.sdk.coders.Coder;
 import com.google.cloud.dataflow.sdk.coders.CoderException;
+import com.google.cloud.dataflow.sdk.io.AvroIO;
 import com.google.cloud.dataflow.sdk.io.ShardNameTemplate;
 import com.google.cloud.dataflow.sdk.io.TextIO;
 import com.google.cloud.dataflow.sdk.options.DirectPipelineOptions;
@@ -36,8 +38,10 @@ import com.google.cloud.dataflow.sdk.transforms.Create;
 import com.google.cloud.dataflow.sdk.transforms.DoFn;
 import com.google.cloud.dataflow.sdk.transforms.ParDo;
 import com.google.cloud.dataflow.sdk.util.IOChannelUtils;
+import com.google.common.collect.Iterables;
 import com.google.common.io.Files;
 
+import org.apache.avro.file.DataFileReader;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
@@ -154,4 +158,53 @@ public class DirectPipelineRunnerTest implements Serializable {
     assertThat(allContents,
        containsInAnyOrder(expectedElements));
   }
+
+  @Test
+  public void testAvroIOWriteWithDefaultShardingStrategy() throws Exception {
+    String prefix = IOChannelUtils.resolve(Files.createTempDir().toString(), "output");
+    Pipeline p = DirectPipeline.createForTest();
+    String[] expectedElements = new String[]{ "a", "b", "c", "d", "e", "f", "g", "h", "i" };
+    p.apply(Create.of(expectedElements))
+        .apply(AvroIO.Write.withSchema(String.class).to(prefix).withSuffix(".avro"));
+    p.run();
+
+    String filename =
+        IOChannelUtils.constructName(prefix, ShardNameTemplate.INDEX_OF_MAX, ".avro", 0, 1);
+    List<String> fileContents = new ArrayList<>();
+    Iterables.addAll(fileContents, DataFileReader.openReader(
+        new File(filename), AvroCoder.of(String.class).createDatumReader()));
+
+    // Ensure that the single output file got at least one record.
+    assertFalse(fileContents.isEmpty());
+
+    assertThat(fileContents, containsInAnyOrder(expectedElements));
+  }
+
+  @Test
+  public void testAvroIOWriteWithLimitedNumberOfShards() throws Exception {
+    final int numShards = 3;
+    String prefix = IOChannelUtils.resolve(Files.createTempDir().toString(), "shardedOutput");
+    Pipeline p = DirectPipeline.createForTest();
+    String[] expectedElements = new String[]{ "a", "b", "c", "d", "e", "f", "g", "h", "i" };
+    p.apply(Create.of(expectedElements))
+        .apply(AvroIO.Write.withSchema(String.class).to(prefix)
+            .withNumShards(numShards).withSuffix(".avro"));
+    p.run();
+
+    List<String> allContents = new ArrayList<>();
+    for (int i = 0; i < numShards; ++i) {
+      String shardFileName = IOChannelUtils.constructName(
+          prefix, ShardNameTemplate.INDEX_OF_MAX, ".avro", i, numShards);
+      List<String> shardFileContents = new ArrayList<>();
+      Iterables.addAll(shardFileContents, DataFileReader.openReader(
+          new File(shardFileName), AvroCoder.of(String.class).createDatumReader()));
+
+      // Ensure that each shard file got at least one record.
+      assertFalse(shardFileContents.isEmpty());
+
+      allContents.addAll(shardFileContents);
+    }
+
+    assertThat(allContents, containsInAnyOrder(expectedElements));
+  }
 }
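As a side note on the naming logic mirrored from FileBasedSink#generateDestinationFilenames(), here is a small sketch of the filename expansion that the override and the tests above rely on. It assumes the default INDEX_OF_MAX shard name template ("-SSSSS-of-NNNNN"); the prefix and printed names are illustrative:

    // Each shard index expands into a distinct filename, e.g.
    //   /tmp/shardedOutput-00000-of-00003.avro
    //   /tmp/shardedOutput-00001-of-00003.avro
    //   /tmp/shardedOutput-00002-of-00003.avro
    for (int i = 0; i < 3; ++i) {
      System.out.println(IOChannelUtils.constructName(
          "/tmp/shardedOutput", ShardNameTemplate.INDEX_OF_MAX, ".avro", i, 3));
    }

This is why the override rewrites each per-shard transform with withNumShards(1) and an empty shard name template: the shard index and count are already baked into the explicit filename passed to to().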