Repository: beam. Updated branches: refs/heads/master 063fbd4ed -> 932bb823f
Inject Sharding Strategy in the Direct Runner. This removes the need for WriteBundles to be a highly implementation-dependent override that relies on both the behavior of the Write transform and the behavior of the DirectRunner. Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/24613fd9 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/24613fd9 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/24613fd9 Branch: refs/heads/master Commit: 24613fd9c46d491740d57646aa37d0fe7fc10579 Parents: 063fbd4 Author: Thomas Groh <tg...@google.com> Authored: Fri Feb 17 16:52:27 2017 -0800 Committer: Thomas Groh <tg...@google.com> Committed: Mon Feb 27 11:09:21 2017 -0800 ---------------------------------------------------------------------- .../core/construction/PTransformMatchers.java | 32 ++--- .../construction/PTransformMatchersTest.java | 43 ++++++ .../beam/runners/direct/DirectRunner.java | 3 +- .../direct/WriteWithShardingFactory.java | 144 ++++++++----------- .../direct/WriteWithShardingFactoryTest.java | 141 +++++------------- 5 files changed, 155 insertions(+), 208 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/24613fd9/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java ---------------------------------------------------------------------- diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java index 7b05ed1..05b632b 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java @@ -19,6 +19,7 @@ package 
org.apache.beam.runners.core.construction; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.annotations.Experimental.Kind; +import org.apache.beam.sdk.io.Write; import org.apache.beam.sdk.runners.PTransformMatcher; import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.DoFn; @@ -42,25 +43,6 @@ public class PTransformMatchers { private PTransformMatchers() {} /** - * Returns a {@link PTransformMatcher} which matches a {@link PTransform} if any of the provided - * matchers match the {@link PTransform}. - */ - public static PTransformMatcher anyOf( - final PTransformMatcher matcher, final PTransformMatcher... matchers) { - return new PTransformMatcher() { - @Override - public boolean matches(AppliedPTransform<?, ?, ?> application) { - for (PTransformMatcher component : matchers) { - if (component.matches(application)) { - return true; - } - } - return matcher.matches(application); - } - }; - } - - /** * Returns a {@link PTransformMatcher} that matches a {@link PTransform} if the class of the * {@link PTransform} is equal to the {@link Class} provided ot this matcher. 
* @param clazz @@ -195,4 +177,16 @@ public class PTransformMatchers { } }; } + + public static PTransformMatcher writeWithRunnerDeterminedSharding() { + return new PTransformMatcher() { + @Override + public boolean matches(AppliedPTransform<?, ?, ?> application) { + if (application.getTransform() instanceof Write.Bound) { + return ((Write.Bound) application.getTransform()).getSharding() == null; + } + return false; + } + }; + } } http://git-wip-us.apache.org/repos/asf/beam/blob/24613fd9/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformMatchersTest.java ---------------------------------------------------------------------- diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformMatchersTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformMatchersTest.java index 439a475..cace033 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformMatchersTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformMatchersTest.java @@ -28,6 +28,9 @@ import java.io.Serializable; import java.util.Collections; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.coders.VoidCoder; +import org.apache.beam.sdk.io.FileBasedSink; +import org.apache.beam.sdk.io.Write; +import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.runners.PTransformMatcher; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.AppliedPTransform; @@ -36,6 +39,7 @@ import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.Flatten; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.Sum; import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; import 
org.apache.beam.sdk.transforms.windowing.GlobalWindows; import org.apache.beam.sdk.transforms.windowing.Window; @@ -51,6 +55,7 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollection.IsBounded; import org.apache.beam.sdk.values.PCollectionList; +import org.apache.beam.sdk.values.PDone; import org.apache.beam.sdk.values.TaggedPValue; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; @@ -379,4 +384,42 @@ public class PTransformMatchersTest implements Serializable { assertThat(PTransformMatchers.emptyFlatten().matches(application), is(false)); } + + @Test + public void writeWithRunnerDeterminedSharding() { + Write.Bound<Integer> write = + Write.to( + new FileBasedSink<Integer>("foo", "bar") { + @Override + public FileBasedWriteOperation<Integer> createWriteOperation( + PipelineOptions options) { + return null; + } + }); + assertThat( + PTransformMatchers.writeWithRunnerDeterminedSharding().matches(appliedWrite(write)), + is(true)); + + Write.Bound<Integer> withStaticSharding = write.withNumShards(3); + assertThat( + PTransformMatchers.writeWithRunnerDeterminedSharding() + .matches(appliedWrite(withStaticSharding)), + is(false)); + + Write.Bound<Integer> withCustomSharding = + write.withSharding(Sum.integersGlobally().asSingletonView()); + assertThat( + PTransformMatchers.writeWithRunnerDeterminedSharding() + .matches(appliedWrite(withCustomSharding)), + is(false)); + } + + private AppliedPTransform<?, ?, ?> appliedWrite(Write.Bound<Integer> write) { + return AppliedPTransform.<PCollection<Integer>, PDone, Write.Bound<Integer>>of( + "Write", + Collections.<TaggedPValue>emptyList(), + Collections.<TaggedPValue>emptyList(), + write, + p); + } } http://git-wip-us.apache.org/repos/asf/beam/blob/24613fd9/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java ---------------------------------------------------------------------- 
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java index 06189a2..f56d225 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java @@ -42,7 +42,6 @@ import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.Pipeline.PipelineExecutionException; import org.apache.beam.sdk.PipelineResult; import org.apache.beam.sdk.io.Read; -import org.apache.beam.sdk.io.Write.Bound; import org.apache.beam.sdk.metrics.MetricResults; import org.apache.beam.sdk.metrics.MetricsEnvironment; import org.apache.beam.sdk.options.PipelineOptions; @@ -82,7 +81,7 @@ public class DirectRunner extends PipelineRunner<DirectPipelineResult> { private static Map<PTransformMatcher, PTransformOverrideFactory> defaultTransformOverrides = ImmutableMap.<PTransformMatcher, PTransformOverrideFactory>builder() .put( - PTransformMatchers.classEqualTo(Bound.class), + PTransformMatchers.writeWithRunnerDeterminedSharding(), new WriteWithShardingFactory()) /* Uses a view internally. 
*/ .put( PTransformMatchers.classEqualTo(CreatePCollectionView.class), http://git-wip-us.apache.org/repos/asf/beam/blob/24613fd9/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java index 50ec586..f206fb0 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java @@ -18,10 +18,11 @@ package org.apache.beam.runners.direct; -import static com.google.common.base.Preconditions.checkArgument; - import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Supplier; +import com.google.common.base.Suppliers; import com.google.common.collect.Iterables; +import java.io.Serializable; import java.util.Collections; import java.util.List; import java.util.Map; @@ -29,24 +30,17 @@ import java.util.concurrent.ThreadLocalRandom; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.io.Write; import org.apache.beam.sdk.io.Write.Bound; +import org.apache.beam.sdk.runners.PTransformOverrideFactory; import org.apache.beam.sdk.transforms.Count; import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.Flatten; -import org.apache.beam.sdk.transforms.GroupByKey; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.transforms.Values; -import org.apache.beam.sdk.transforms.windowing.DefaultTrigger; -import org.apache.beam.sdk.transforms.windowing.GlobalWindows; -import org.apache.beam.sdk.transforms.windowing.Window; -import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.transforms.View; 
import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PCollection.IsBounded; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PDone; import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.TaggedPValue; -import org.joda.time.Duration; /** * A {@link PTransformOverrideFactory} that overrides {@link Write} {@link PTransform PTransforms} @@ -54,108 +48,90 @@ import org.joda.time.Duration; * of shards is the log base 10 of the number of input records, with up to 2 additional shards. */ class WriteWithShardingFactory<InputT> - implements org.apache.beam.sdk.runners.PTransformOverrideFactory< - PCollection<InputT>, PDone, Write.Bound<InputT>> { + implements PTransformOverrideFactory<PCollection<InputT>, PDone, Bound<InputT>> { static final int MAX_RANDOM_EXTRA_SHARDS = 3; + @VisibleForTesting static final int MIN_SHARDS_FOR_LOG = 3; @Override - public PTransform<PCollection<InputT>, PDone> getReplacementTransform( - Bound<InputT> transform) { - if (transform.getSharding() == null) { - return new DynamicallyReshardedWrite<>(transform); - } - return transform; + public PTransform<PCollection<InputT>, PDone> getReplacementTransform(Bound<InputT> transform) { + return transform.withSharding(new LogElementShardsWithDrift<InputT>()); } @Override - public PCollection<InputT> getInput( - List<TaggedPValue> inputs, Pipeline p) { + public PCollection<InputT> getInput(List<TaggedPValue> inputs, Pipeline p) { return (PCollection<InputT>) Iterables.getOnlyElement(inputs).getValue(); } @Override - public Map<PValue, ReplacementOutput> mapOutputs( - List<TaggedPValue> outputs, PDone newOutput) { + public Map<PValue, ReplacementOutput> mapOutputs(List<TaggedPValue> outputs, PDone newOutput) { return Collections.emptyMap(); } - private static class DynamicallyReshardedWrite<T> extends PTransform<PCollection<T>, PDone> { - private final transient Write.Bound<T> original; - - private 
DynamicallyReshardedWrite(Bound<T> original) { - this.original = original; - } + private static class LogElementShardsWithDrift<T> + extends PTransform<PCollection<T>, PCollectionView<Integer>> { @Override - public PDone expand(PCollection<T> input) { - checkArgument( - IsBounded.BOUNDED == input.isBounded(), - "%s can only be applied to a Bounded PCollection", - getClass().getSimpleName()); - PCollection<T> records = - input.apply( - "RewindowInputs", - Window.<T>into(new GlobalWindows()) - .triggering(DefaultTrigger.of()) - .withAllowedLateness(Duration.ZERO) - .discardingFiredPanes()); - final PCollectionView<Long> numRecords = - records.apply("CountRecords", Count.<T>globally().asSingletonView()); - PCollection<T> resharded = - records - .apply( - "ApplySharding", - ParDo.withSideInputs(numRecords) - .of( - new KeyBasedOnCountFn<T>( - numRecords, - ThreadLocalRandom.current().nextInt(MAX_RANDOM_EXTRA_SHARDS)))) - .apply("GroupIntoShards", GroupByKey.<Integer, T>create()) - .apply("DropShardingKeys", Values.<Iterable<T>>create()) - .apply("FlattenShardIterables", Flatten.<T>iterables()); - // This is an inverted application to apply the expansion of the original Write PTransform - // without adding a new Write Transform Node, which would be overwritten the same way, leading - // to an infinite recursion. We cannot modify the number of shards, because that is determined - // at runtime. 
- return resharded.apply(original); + public PCollectionView<Integer> expand(PCollection<T> records) { + return records + .apply("CountRecords", Count.<T>globally()) + .apply("GenerateShardCount", ParDo.of(new CalculateShardsFn())) + .apply(View.<Integer>asSingleton()); } } @VisibleForTesting - static class KeyBasedOnCountFn<T> extends DoFn<T, KV<Integer, T>> { - @VisibleForTesting static final int MIN_SHARDS_FOR_LOG = 3; - private final PCollectionView<Long> numRecords; - private final int randomExtraShards; - private int currentShard; - private int maxShards = 0; - - KeyBasedOnCountFn(PCollectionView<Long> numRecords, int extraShards) { - this.numRecords = numRecords; - this.randomExtraShards = extraShards; + static class CalculateShardsFn extends DoFn<Long, Integer> { + private final Supplier<Integer> extraShardsSupplier; + + public CalculateShardsFn() { + this(new BoundedRandomIntSupplier(MAX_RANDOM_EXTRA_SHARDS)); + } + + /** + * Construct a {@link CalculateShardsFn} that always uses a constant number of specified extra + * shards. 
+ */ + @VisibleForTesting + CalculateShardsFn(int constantExtraShards) { + this(Suppliers.ofInstance(constantExtraShards)); + } + + private CalculateShardsFn(Supplier<Integer> extraShardsSupplier) { + this.extraShardsSupplier = extraShardsSupplier; } @ProcessElement - public void processElement(ProcessContext c) throws Exception { - if (maxShards == 0) { - maxShards = calculateShards(c.sideInput(numRecords)); - currentShard = ThreadLocalRandom.current().nextInt(maxShards); - } - int shard = currentShard; - currentShard = (currentShard + 1) % maxShards; - c.output(KV.of(shard, c.element())); + public void process(ProcessContext ctxt) { + ctxt.output(calculateShards(ctxt.element())); } private int calculateShards(long totalRecords) { - checkArgument( - totalRecords > 0, - "KeyBasedOnCountFn cannot be invoked on an element if there are no elements"); - if (totalRecords < MIN_SHARDS_FOR_LOG + randomExtraShards) { + if (totalRecords == 0) { + // Write out at least one shard, even if there is no input. + return 1; + } + // Windows get their own number of random extra shards. This is stored in a side input, so + // writers use a consistent number of keys. 
+ int extraShards = extraShardsSupplier.get(); + if (totalRecords < MIN_SHARDS_FOR_LOG + extraShards) { return (int) totalRecords; } // 100mil records before >7 output files int floorLogRecs = Double.valueOf(Math.log10(totalRecords)).intValue(); - int shards = Math.max(floorLogRecs, MIN_SHARDS_FOR_LOG) + randomExtraShards; - return shards; + return Math.max(floorLogRecs, MIN_SHARDS_FOR_LOG) + extraShards; + } + } + + private static class BoundedRandomIntSupplier implements Supplier<Integer>, Serializable { + private final int upperBound; + + private BoundedRandomIntSupplier(int upperBound) { + this.upperBound = upperBound; + } + + @Override + public Integer get() { + return ThreadLocalRandom.current().nextInt(0, upperBound); } } } http://git-wip-us.apache.org/repos/asf/beam/blob/24613fd9/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java index 0196a2d..51f3a87 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java @@ -19,7 +19,6 @@ package org.apache.beam.runners.direct; import static org.hamcrest.Matchers.allOf; -import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; @@ -28,8 +27,6 @@ import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.Matchers.not; import static org.junit.Assert.assertThat; -import com.google.common.base.Function; -import com.google.common.collect.Iterables; import java.io.File; import java.io.FileReader; 
import java.io.Reader; @@ -39,8 +36,7 @@ import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.UUID; -import java.util.concurrent.ThreadLocalRandom; -import org.apache.beam.runners.direct.WriteWithShardingFactory.KeyBasedOnCountFn; +import org.apache.beam.runners.direct.WriteWithShardingFactory.CalculateShardsFn; import org.apache.beam.sdk.coders.VarLongCoder; import org.apache.beam.sdk.io.Sink; import org.apache.beam.sdk.io.TextIO; @@ -48,12 +44,12 @@ import org.apache.beam.sdk.io.Write; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.DoFnTester; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.util.IOChannelUtils; import org.apache.beam.sdk.util.PCollectionViews; import org.apache.beam.sdk.util.WindowingStrategy; -import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; import org.hamcrest.Matchers; @@ -91,7 +87,7 @@ public class WriteWithShardingFactoryTest { p.run(); Collection<String> files = IOChannelUtils.getFactory(outputPath).match(targetLocation + "*"); - List<String> actuals = new ArrayList(strs.size()); + List<String> actuals = new ArrayList<>(strs.size()); for (String file : files) { CharBuffer buf = CharBuffer.allocate((int) new File(file).length()); try (Reader reader = new FileReader(file)) { @@ -120,96 +116,56 @@ public class WriteWithShardingFactoryTest { } @Test - public void withShardingSpecifiesOriginalTransform() { - Write.Bound<Object> original = Write.to(new TestSink()).withNumShards(3); - - assertThat(factory.getReplacementTransform(original), equalTo((Object) original)); - } - - @Test public void withNoShardingSpecifiedReturnsNewTransform() { Write.Bound<Object> original = Write.to(new TestSink()); 
assertThat(factory.getReplacementTransform(original), not(equalTo((Object) original))); } @Test - public void keyBasedOnCountFnWithOneElement() throws Exception { - PCollectionView<Long> elementCountView = - PCollectionViews.singletonView( - p, WindowingStrategy.globalDefault(), true, 0L, VarLongCoder.of()); - KeyBasedOnCountFn<String> fn = new KeyBasedOnCountFn<>(elementCountView, 0); - DoFnTester<String, KV<Integer, String>> fnTester = DoFnTester.of(fn); + public void keyBasedOnCountFnWithNoElements() throws Exception { + CalculateShardsFn fn = new CalculateShardsFn(0); + DoFnTester<Long, Integer> fnTester = DoFnTester.of(fn); - fnTester.setSideInput(elementCountView, GlobalWindow.INSTANCE, 1L); + List<Integer> outputs = fnTester.processBundle(0L); + assertThat( + outputs, containsInAnyOrder(1)); + } - List<KV<Integer, String>> outputs = fnTester.processBundle("foo", "bar", "bazbar"); + @Test + public void keyBasedOnCountFnWithOneElement() throws Exception { + CalculateShardsFn fn = new CalculateShardsFn(0); + DoFnTester<Long, Integer> fnTester = DoFnTester.of(fn); + + List<Integer> outputs = fnTester.processBundle(1L); assertThat( - outputs, containsInAnyOrder(KV.of(0, "foo"), KV.of(0, "bar"), KV.of(0, "bazbar"))); + outputs, containsInAnyOrder(1)); } @Test public void keyBasedOnCountFnWithTwoElements() throws Exception { - PCollectionView<Long> elementCountView = - PCollectionViews.singletonView( - p, WindowingStrategy.globalDefault(), true, 0L, VarLongCoder.of()); - KeyBasedOnCountFn<String> fn = new KeyBasedOnCountFn<>(elementCountView, 0); - DoFnTester<String, KV<Integer, String>> fnTester = DoFnTester.of(fn); - - fnTester.setSideInput(elementCountView, GlobalWindow.INSTANCE, 2L); + CalculateShardsFn fn = new CalculateShardsFn(0); + DoFnTester<Long, Integer> fnTester = DoFnTester.of(fn); - List<KV<Integer, String>> outputs = fnTester.processBundle("foo", "bar"); - assertThat( - outputs, - anyOf( - containsInAnyOrder(KV.of(0, "foo"), KV.of(1, "bar")), - 
containsInAnyOrder(KV.of(1, "foo"), KV.of(0, "bar")))); + List<Integer> outputs = fnTester.processBundle(2L); + assertThat(outputs, containsInAnyOrder(2)); } @Test public void keyBasedOnCountFnFewElementsThreeShards() throws Exception { - PCollectionView<Long> elementCountView = - PCollectionViews.singletonView( - p, WindowingStrategy.globalDefault(), true, 0L, VarLongCoder.of()); - KeyBasedOnCountFn<String> fn = new KeyBasedOnCountFn<>(elementCountView, 0); - DoFnTester<String, KV<Integer, String>> fnTester = DoFnTester.of(fn); - - fnTester.setSideInput(elementCountView, GlobalWindow.INSTANCE, 100L); + CalculateShardsFn fn = new CalculateShardsFn(0); + DoFnTester<Long, Integer> fnTester = DoFnTester.of(fn); - List<KV<Integer, String>> outputs = - fnTester.processBundle("foo", "bar", "baz", "foobar", "foobaz", "barbaz"); - assertThat( - Iterables.transform( - outputs, - new Function<KV<Integer, String>, Integer>() { - @Override - public Integer apply(KV<Integer, String> input) { - return input.getKey(); - } - }), - containsInAnyOrder(0, 0, 1, 1, 2, 2)); + List<Integer> outputs = fnTester.processBundle(5L); + assertThat(outputs, containsInAnyOrder(3)); } @Test public void keyBasedOnCountFnManyElements() throws Exception { - PCollectionView<Long> elementCountView = - PCollectionViews.singletonView( - p, WindowingStrategy.globalDefault(), true, 0L, VarLongCoder.of()); - KeyBasedOnCountFn<String> fn = new KeyBasedOnCountFn<>(elementCountView, 0); - DoFnTester<String, KV<Integer, String>> fnTester = DoFnTester.of(fn); - - double count = Math.pow(10, 10); - fnTester.setSideInput(elementCountView, GlobalWindow.INSTANCE, (long) count); + DoFn<Long, Integer> fn = new CalculateShardsFn(0); + DoFnTester<Long, Integer> fnTester = DoFnTester.of(fn); - List<String> strings = new ArrayList<>(); - for (int i = 0; i < 100; i++) { - strings.add(Long.toHexString(ThreadLocalRandom.current().nextLong())); - } - List<KV<Integer, String>> kvs = fnTester.processBundle(strings); - long 
maxKey = -1L; - for (KV<Integer, String> kv : kvs) { - maxKey = Math.max(maxKey, kv.getKey()); - } - assertThat(maxKey, equalTo(9L)); + List<Integer> shard = fnTester.processBundle((long) Math.pow(10, 10)); + assertThat(shard, containsInAnyOrder(10)); } @Test @@ -217,46 +173,25 @@ public class WriteWithShardingFactoryTest { PCollectionView<Long> elementCountView = PCollectionViews.singletonView( p, WindowingStrategy.globalDefault(), true, 0L, VarLongCoder.of()); - KeyBasedOnCountFn<String> fn = new KeyBasedOnCountFn<>(elementCountView, 10); - DoFnTester<String, KV<Integer, String>> fnTester = DoFnTester.of(fn); + CalculateShardsFn fn = new CalculateShardsFn(3); + DoFnTester<Long, Integer> fnTester = DoFnTester.of(fn); - long countValue = (long) KeyBasedOnCountFn.MIN_SHARDS_FOR_LOG + 3; + long countValue = (long) WriteWithShardingFactory.MIN_SHARDS_FOR_LOG + 3; fnTester.setSideInput(elementCountView, GlobalWindow.INSTANCE, countValue); - List<String> strings = new ArrayList<>(); - for (int i = 0; i < 100; i++) { - strings.add(Long.toHexString(ThreadLocalRandom.current().nextLong())); - } - List<KV<Integer, String>> kvs = fnTester.processBundle(strings); - long maxKey = -1L; - for (KV<Integer, String> kv : kvs) { - maxKey = Math.max(maxKey, kv.getKey()); - } - // 0 to n-1 shard ids. 
- assertThat(maxKey, equalTo(countValue - 1)); + List<Integer> kvs = fnTester.processBundle(10L); + assertThat(kvs, containsInAnyOrder(6)); } @Test public void keyBasedOnCountFnManyElementsExtraShards() throws Exception { - PCollectionView<Long> elementCountView = - PCollectionViews.singletonView( - p, WindowingStrategy.globalDefault(), true, 0L, VarLongCoder.of()); - KeyBasedOnCountFn<String> fn = new KeyBasedOnCountFn<>(elementCountView, 3); - DoFnTester<String, KV<Integer, String>> fnTester = DoFnTester.of(fn); + CalculateShardsFn fn = new CalculateShardsFn(3); + DoFnTester<Long, Integer> fnTester = DoFnTester.of(fn); double count = Math.pow(10, 10); - fnTester.setSideInput(elementCountView, GlobalWindow.INSTANCE, (long) count); - List<String> strings = new ArrayList<>(); - for (int i = 0; i < 100; i++) { - strings.add(Long.toHexString(ThreadLocalRandom.current().nextLong())); - } - List<KV<Integer, String>> kvs = fnTester.processBundle(strings); - long maxKey = -1L; - for (KV<Integer, String> kv : kvs) { - maxKey = Math.max(maxKey, kv.getKey()); - } - assertThat(maxKey, equalTo(12L)); + List<Integer> shards = fnTester.processBundle((long) count); + assertThat(shards, containsInAnyOrder(13)); } @Test