Inject the number of shards as a side input to Write. This permits users to pass a PTransform that computes the number of shards from the input elements, instead of requiring a constant number of shards. This enables sharding to be determined by the input data rather than by a constant value.
Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/0b737496 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/0b737496 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/0b737496 Branch: refs/heads/master Commit: 0b737496b9870ee1b3155e7596cf90be37f6cf71 Parents: 9335738 Author: Thomas Groh <tg...@google.com> Authored: Tue Feb 7 15:16:19 2017 -0800 Committer: Thomas Groh <tg...@google.com> Committed: Fri Feb 17 16:27:23 2017 -0800 ---------------------------------------------------------------------- .../direct/WriteWithShardingFactory.java | 2 +- .../main/java/org/apache/beam/sdk/io/Write.java | 195 +++++++++++++++---- .../java/org/apache/beam/sdk/io/WriteTest.java | 138 +++++++++++-- 3 files changed, 281 insertions(+), 54 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/0b737496/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java index 83c82a5..50ec586 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java @@ -61,7 +61,7 @@ class WriteWithShardingFactory<InputT> @Override public PTransform<PCollection<InputT>, PDone> getReplacementTransform( Bound<InputT> transform) { - if (transform.getNumShards() == 0) { + if (transform.getSharding() == null) { return new DynamicallyReshardedWrite<>(transform); } return transform; http://git-wip-us.apache.org/repos/asf/beam/blob/0b737496/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Write.java 
---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Write.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Write.java index bc651d8..acbbb97 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Write.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Write.java @@ -20,10 +20,13 @@ package org.apache.beam.sdk.io; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import java.util.List; import java.util.UUID; import java.util.concurrent.ThreadLocalRandom; +import javax.annotation.Nullable; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.coders.Coder; @@ -31,14 +34,14 @@ import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.io.Sink.WriteOperation; import org.apache.beam.sdk.io.Sink.Writer; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.ValueProvider; +import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.View; -import org.apache.beam.sdk.transforms.WithKeys; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.transforms.windowing.DefaultTrigger; import org.apache.beam.sdk.transforms.windowing.GlobalWindows; @@ -66,7 +69,7 @@ import org.slf4j.LoggerFactory; * <p>{@code Write} re-windows the data into the global window, so 
it is typically not well suited * to use in streaming pipelines. * - * <p>Example usage with runner-controlled sharding: + * <p>Example usage with runner-determined sharding: * * <pre>{@code p.apply(Write.to(new MySink(...)));}</pre> * @@ -84,7 +87,7 @@ public class Write { */ public static <T> Bound<T> to(Sink<T> sink) { checkNotNull(sink, "sink"); - return new Bound<>(sink, 0 /* runner-controlled sharding */); + return new Bound<>(sink, null /* runner-determined sharding */); } /** @@ -96,16 +99,20 @@ public class Write { */ public static class Bound<T> extends PTransform<PCollection<T>, PDone> { private final Sink<T> sink; - private int numShards; + @Nullable + private final PTransform<PCollection<T>, PCollectionView<Integer>> computeNumShards; - private Bound(Sink<T> sink, int numShards) { + private Bound( + Sink<T> sink, + @Nullable PTransform<PCollection<T>, PCollectionView<Integer>> computeNumShards) { this.sink = sink; - this.numShards = numShards; + this.computeNumShards = computeNumShards; } @Override public PDone expand(PCollection<T> input) { - checkArgument(IsBounded.BOUNDED == input.isBounded(), + checkArgument( + IsBounded.BOUNDED == input.isBounded(), "%s can only be applied to a Bounded PCollection", Write.class.getSimpleName()); PipelineOptions options = input.getPipeline().getOptions(); @@ -118,26 +125,28 @@ public class Write { super.populateDisplayData(builder); builder .add(DisplayData.item("sink", sink.getClass()).withLabel("Write Sink")) - .include("sink", sink) - .addIfNotDefault( - DisplayData.item("numShards", getNumShards()).withLabel("Fixed Number of Shards"), - 0); + .include("sink", sink); + if (getSharding() != null) { + builder.include("sharding", getSharding()); + } } /** - * Returns the number of shards that will be produced in the output. - * - * @see Write for more information + * Returns the {@link Sink} associated with this PTransform. 
*/ - public int getNumShards() { - return numShards; + public Sink<T> getSink() { + return sink; } /** - * Returns the {@link Sink} associated with this PTransform. + * Gets the {@link PTransform} that will be used to determine sharding. This can be either a + * static number of shards (as following a call to {@link #withNumShards(int)}), dynamic (by + * {@link #withSharding(PTransform)}), or runner-determined (by {@link + * #withRunnerDeterminedSharding()}. */ - public Sink<T> getSink() { - return sink; + @Nullable + public PTransform<PCollection<T>, PCollectionView<Integer>> getSharding() { + return computeNumShards; } /** @@ -148,10 +157,45 @@ public class Write { * more information. * * <p>A value less than or equal to 0 will be equivalent to the default behavior of - * runner-controlled sharding. + * runner-determined sharding. */ public Bound<T> withNumShards(int numShards) { - return new Bound<>(sink, Math.max(numShards, 0)); + if (numShards > 0) { + return withNumShards(StaticValueProvider.of(numShards)); + } + return withRunnerDeterminedSharding(); + } + + /** + * Returns a new {@link Write.Bound} that will write to the current {@link Sink} using the + * {@link ValueProvider} specified number of shards. + * + * <p>This option should be used sparingly as it can hurt performance. See {@link Write} for + * more information. + */ + public Bound<T> withNumShards(ValueProvider<Integer> numShards) { + return new Bound<>(sink, new ConstantShards<T>(numShards)); + } + + /** + * Returns a new {@link Write.Bound} that will write to the current {@link Sink} using the + * specified {@link PTransform} to compute the number of shards. + * + * <p>This option should be used sparingly as it can hurt performance. See {@link Write} for + * more information. + */ + public Bound<T> withSharding(PTransform<PCollection<T>, PCollectionView<Integer>> sharding) { + checkNotNull( + sharding, "Cannot provide null sharding. 
Use withRunnerDeterminedSharding() instead"); + return new Bound<>(sink, sharding); + } + + /** + * Returns a new {@link Write.Bound} that will write to the current {@link Sink} with + * runner-determined sharding. + */ + public Bound<T> withRunnerDeterminedSharding() { + return new Bound<>(sink, null); } /** @@ -265,25 +309,31 @@ public class Write { } } - private static class ApplyShardingKey<T> implements SerializableFunction<T, Integer> { - private final int numShards; + private static class ApplyShardingKey<T> extends DoFn<T, KV<Integer, T>> { + private final PCollectionView<Integer> numShards; private int shardNumber; - ApplyShardingKey(int numShards) { + ApplyShardingKey(PCollectionView<Integer> numShards) { this.numShards = numShards; shardNumber = -1; } - @Override - public Integer apply(T input) { + @ProcessElement + public void processElement(ProcessContext context) { + Integer shardCount = context.sideInput(numShards); + checkArgument( + shardCount > 0, + "Must have a positive number of shards specified for non-runner-determined sharding." + + " Got %s", + shardCount); if (shardNumber == -1) { // We want to desynchronize the first record sharding key for each instance of // ApplyShardingKey, so records in a small PCollection will be statistically balanced. - shardNumber = ThreadLocalRandom.current().nextInt(numShards); + shardNumber = ThreadLocalRandom.current().nextInt(shardCount); } else { - shardNumber = (shardNumber + 1) % numShards; + shardNumber = (shardNumber + 1) % shardCount; } - return shardNumber; + context.output(KV.of(shardNumber, context.element())); } } @@ -366,18 +416,26 @@ public class Write { // There is a dependency between this ParDo and the first (the WriteOperation PCollection // as a side input), so this will happen after the initial ParDo. 
PCollection<WriteT> results; - if (getNumShards() <= 0) { - results = inputInGlobalWindow - .apply("WriteBundles", + final PCollectionView<Integer> numShards; + if (computeNumShards == null) { + numShards = null; + results = + inputInGlobalWindow.apply( + "WriteBundles", ParDo.of(new WriteBundles<>(writeOperationView)) .withSideInputs(writeOperationView)); } else { - results = inputInGlobalWindow - .apply("ApplyShardLabel", WithKeys.of(new ApplyShardingKey<T>(getNumShards()))) - .apply("GroupIntoShards", GroupByKey.<Integer, T>create()) - .apply("WriteShardedBundles", - ParDo.of(new WriteShardedBundles<>(writeOperationView)) - .withSideInputs(writeOperationView)); + numShards = inputInGlobalWindow.apply(computeNumShards); + results = + inputInGlobalWindow + .apply( + "ApplyShardLabel", + ParDo.of(new ApplyShardingKey<T>(numShards)).withSideInputs(numShards)) + .apply("GroupIntoShards", GroupByKey.<Integer, T>create()) + .apply( + "WriteShardedBundles", + ParDo.of(new WriteShardedBundles<>(writeOperationView)) + .withSideInputs(writeOperationView)); } results.setCoder(writeOperation.getWriterResultCoder()); @@ -389,6 +447,11 @@ public class Write { // The WriteOperation's state is the same as after its initialization in the first do-once // ParDo. There is a dependency between this ParDo and the parallel write (the writer results // collection as a side input), so it will happen after the parallel write. + ImmutableList.Builder<PCollectionView<?>> sideInputs = + ImmutableList.<PCollectionView<?>>builder().add(resultsView); + if (numShards != null) { + sideInputs.add(numShards); + } operationCollection .apply("Finalize", ParDo.of(new DoFn<WriteOperation<T, WriteT>, Integer>() { @ProcessElement @@ -399,7 +462,17 @@ public class Write { LOG.debug("Side input initialized to finalize write operation {}.", writeOperation); // We must always output at least 1 shard, and honor user-specified numShards if set. 
- int minShardsNeeded = Math.max(1, getNumShards()); + int minShardsNeeded; + if (numShards == null) { + minShardsNeeded = 1; + } else { + minShardsNeeded = c.sideInput(numShards); + checkArgument( + minShardsNeeded > 0, + "Must have a positive number of shards for non-runner-determined sharding." + + " Got %s", + minShardsNeeded); + } int extraShardsNeeded = minShardsNeeded - results.size(); if (extraShardsNeeded > 0) { LOG.info( @@ -417,8 +490,48 @@ public class Write { writeOperation.finalize(results, c.getPipelineOptions()); LOG.debug("Done finalizing write operation {}", writeOperation); } - }).withSideInputs(resultsView)); + }).withSideInputs(sideInputs.build())); return PDone.in(input.getPipeline()); } } + + @VisibleForTesting + static class ConstantShards<T> + extends PTransform<PCollection<T>, PCollectionView<Integer>> { + private final ValueProvider<Integer> numShards; + + private ConstantShards(ValueProvider<Integer> numShards) { + this.numShards = numShards; + } + + @Override + public PCollectionView<Integer> expand(PCollection<T> input) { + return input + .getPipeline() + .apply(Create.of(0)) + .apply( + "FixedNumShards", + ParDo.of( + new DoFn<Integer, Integer>() { + @ProcessElement + public void outputNumShards(ProcessContext ctxt) { + checkArgument( + numShards.isAccessible(), + "NumShards must be accessible at runtime to use constant sharding"); + ctxt.output(numShards.get()); + } + })) + .apply(View.<Integer>asSingleton()); + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add( + DisplayData.item("Fixed Number of Shards", numShards).withLabel("ConstantShards")); + } + + public ValueProvider<Integer> getNumShards() { + return numShards; + } + } } http://git-wip-us.apache.org/repos/asf/beam/blob/0b737496/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteTest.java ---------------------------------------------------------------------- diff --git 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteTest.java index 846d445..fd349e2 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteTest.java @@ -23,7 +23,9 @@ import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; @@ -47,6 +49,7 @@ import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.io.Sink.WriteOperation; import org.apache.beam.sdk.io.Sink.Writer; +import org.apache.beam.sdk.io.Write.ConstantShards; import org.apache.beam.sdk.options.Description; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactoryTest.TestPipelineOptions; @@ -54,18 +57,22 @@ import org.apache.beam.sdk.testing.NeedsRunner; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.Flatten; import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.SimpleFunction; import org.apache.beam.sdk.transforms.ToString; +import org.apache.beam.sdk.transforms.Top; +import org.apache.beam.sdk.transforms.View; import org.apache.beam.sdk.transforms.display.DisplayData; import 
org.apache.beam.sdk.transforms.windowing.FixedWindows; import org.apache.beam.sdk.transforms.windowing.Sessions; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionView; import org.hamcrest.Matchers; import org.joda.time.Duration; import org.junit.Rule; @@ -93,12 +100,22 @@ public class WriteTest { @SuppressWarnings("unchecked") // covariant cast private static final PTransform<PCollection<String>, PCollection<String>> IDENTITY_MAP = (PTransform) - MapElements.via(new SimpleFunction<String, String>() { - @Override - public String apply(String input) { - return input; - } - }); + MapElements.via( + new SimpleFunction<String, String>() { + @Override + public String apply(String input) { + return input; + } + }); + + private static final PTransform<PCollection<String>, PCollectionView<Integer>> + SHARDING_TRANSFORM = + new PTransform<PCollection<String>, PCollectionView<Integer>>() { + @Override + public PCollectionView<Integer> expand(PCollection<String> input) { + return null; + } + }; private static class WindowAndReshuffle<T> extends PTransform<PCollection<T>, PCollection<T>> { private final Window.Bound<T> window; @@ -169,6 +186,43 @@ public class WriteTest { Optional.of(1)); } + @Test + @Category(NeedsRunner.class) + public void testCustomShardedWrite() { + // Flag to validate that the pipeline options are passed to the Sink + WriteOptions options = TestPipeline.testingPipelineOptions().as(WriteOptions.class); + options.setTestFlag("test_value"); + Pipeline p = TestPipeline.create(options); + + // Clear the sink's contents. + sinkContents.clear(); + // Reset the number of shards produced. + numShards.set(0); + // Reset the number of records in each shard. + recordsPerShard.clear(); + + List<String> inputs = new ArrayList<>(); + // Prepare timestamps for the elements. 
+ List<Long> timestamps = new ArrayList<>(); + for (long i = 0; i < 1000; i++) { + inputs.add(Integer.toString(3)); + timestamps.add(i + 1); + } + + TestSink sink = new TestSink(); + Write.Bound<String> write = Write.to(sink).withSharding(new LargestInt()); + p.apply(Create.timestamped(inputs, timestamps).withCoder(StringUtf8Coder.of())) + .apply(IDENTITY_MAP) + .apply(write); + + p.run(); + assertThat(sinkContents, containsInAnyOrder(inputs.toArray())); + assertTrue(sink.hasCorrectState()); + // The PCollection has values all equal to three, which should be fed as the sharding strategy + assertEquals(3, numShards.intValue()); + assertEquals(3, recordsPerShard.size()); + } + /** * Test that Write with a configured number of shards produces the desired number of shards even * when there are too few elements. @@ -254,14 +308,21 @@ public class WriteTest { public void testBuildWrite() { Sink<String> sink = new TestSink() {}; Write.Bound<String> write = Write.to(sink).withNumShards(3); - assertEquals(3, write.getNumShards()); assertThat(write.getSink(), is(sink)); + PTransform<PCollection<String>, PCollectionView<Integer>> originalSharding = + write.getSharding(); + assertThat(write.getSharding(), instanceOf(ConstantShards.class)); + assertThat(((ConstantShards<String>) write.getSharding()).getNumShards().get(), equalTo(3)); + assertThat(write.getSharding(), equalTo(originalSharding)); - Write.Bound<String> write2 = write.withNumShards(7); - assertEquals(7, write2.getNumShards()); + Write.Bound<String> write2 = write.withSharding(SHARDING_TRANSFORM); assertThat(write2.getSink(), is(sink)); + assertThat(write2.getSharding(), equalTo(SHARDING_TRANSFORM)); // original unchanged - assertEquals(3, write.getNumShards()); + + Write.Bound<String> writeUnsharded = write2.withRunnerDeterminedSharding(); + assertThat(writeUnsharded.getSharding(), nullValue()); + assertThat(write.getSharding(), equalTo(originalSharding)); } @Test @@ -291,7 +352,35 @@ public class WriteTest { 
DisplayData displayData = DisplayData.from(write); assertThat(displayData, hasDisplayItem("sink", sink.getClass())); assertThat(displayData, includesDisplayDataFor("sink", sink)); - assertThat(displayData, hasDisplayItem("numShards", 1)); + assertThat(displayData, hasDisplayItem("Fixed Number of Shards", 1)); + } + + @Test + public void testCustomShardStrategyDisplayData() { + TestSink sink = new TestSink() { + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add(DisplayData.item("foo", "bar")); + } + }; + Write.Bound<String> write = + Write.to(sink) + .withSharding( + new PTransform<PCollection<String>, PCollectionView<Integer>>() { + @Override + public PCollectionView<Integer> expand(PCollection<String> input) { + return null; + } + + @Override + public void populateDisplayData(DisplayData.Builder builder) { + builder.add(DisplayData.item("spam", "ham")); + } + }); + DisplayData displayData = DisplayData.from(write); + assertThat(displayData, hasDisplayItem("sink", sink.getClass())); + assertThat(displayData, includesDisplayDataFor("sink", sink)); + assertThat(displayData, hasDisplayItem("spam", "ham")); } @Test @@ -322,7 +411,8 @@ public class WriteTest { * verifies that the output number of shards is correct. */ private static void runShardedWrite( - List<String> inputs, PTransform<PCollection<String>, PCollection<String>> transform, + List<String> inputs, + PTransform<PCollection<String>, PCollection<String>> transform, Optional<Integer> numConfiguredShards) { // Flag to validate that the pipeline options are passed to the Sink WriteOptions options = TestPipeline.testingPipelineOptions().as(WriteOptions.class); @@ -573,4 +663,28 @@ public class WriteTest { String getTestFlag(); void setTestFlag(String value); } + + /** + * Outputs the largest integer in a {@link PCollection} into a {@link PCollectionView}. 
The input + * {@link PCollection} must be convertible to integers via {@link Integer#valueOf(String)} + */ + private static class LargestInt + extends PTransform<PCollection<String>, PCollectionView<Integer>> { + @Override + public PCollectionView<Integer> expand(PCollection<String> input) { + return input + .apply( + ParDo.of( + new DoFn<String, Integer>() { + @ProcessElement + public void toInteger(ProcessContext ctxt) { + ctxt.output(Integer.valueOf(ctxt.element())); + } + })) + .apply(Top.<Integer>largest(1)) + .apply(Flatten.<Integer>iterables()) + .apply(View.<Integer>asSingleton()); + } + } + }