Repository: beam
Updated Branches:
  refs/heads/master 063fbd4ed -> 932bb823f


Inject Sharding Strategy in the Direct Runner

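Instead of replacing every Write transform with a DirectRunner-specific
reimplementation, the DirectRunner now matches only Write transforms that
leave sharding to the runner and injects a sharding strategy into them via
Write.Bound#withSharding. This removes the need for a WriteBundles override
that depended on implementation details of both the Write transform and
the DirectRunner.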


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/24613fd9
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/24613fd9
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/24613fd9

Branch: refs/heads/master
Commit: 24613fd9c46d491740d57646aa37d0fe7fc10579
Parents: 063fbd4
Author: Thomas Groh <tg...@google.com>
Authored: Fri Feb 17 16:52:27 2017 -0800
Committer: Thomas Groh <tg...@google.com>
Committed: Mon Feb 27 11:09:21 2017 -0800

----------------------------------------------------------------------
 .../core/construction/PTransformMatchers.java   |  32 ++---
 .../construction/PTransformMatchersTest.java    |  43 ++++++
 .../beam/runners/direct/DirectRunner.java       |   3 +-
 .../direct/WriteWithShardingFactory.java        | 144 ++++++++-----------
 .../direct/WriteWithShardingFactoryTest.java    | 141 +++++-------------
 5 files changed, 155 insertions(+), 208 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/24613fd9/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java
index 7b05ed1..05b632b 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java
@@ -19,6 +19,7 @@ package org.apache.beam.runners.core.construction;
 
 import org.apache.beam.sdk.annotations.Experimental;
 import org.apache.beam.sdk.annotations.Experimental.Kind;
+import org.apache.beam.sdk.io.Write;
 import org.apache.beam.sdk.runners.PTransformMatcher;
 import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.DoFn;
@@ -42,25 +43,6 @@ public class PTransformMatchers {
   private PTransformMatchers() {}
 
   /**
-   * Returns a {@link PTransformMatcher} which matches a {@link PTransform} if any of the provided
-   * matchers match the {@link PTransform}.
-   */
-  public static PTransformMatcher anyOf(
-      final PTransformMatcher matcher, final PTransformMatcher... matchers) {
-    return new PTransformMatcher() {
-      @Override
-      public boolean matches(AppliedPTransform<?, ?, ?> application) {
-        for (PTransformMatcher component : matchers) {
-          if (component.matches(application)) {
-            return true;
-          }
-        }
-        return matcher.matches(application);
-      }
-    };
-  }
-
-  /**
    * Returns a {@link PTransformMatcher} that matches a {@link PTransform} if the class of the
    * {@link PTransform} is equal to the {@link Class} provided ot this matcher.
    * @param clazz
@@ -195,4 +177,21 @@ public class PTransformMatchers {
       }
     };
   }
+
+  /**
+   * Returns a {@link PTransformMatcher} that matches a {@link Write.Bound} which does not specify
+   * a sharding ({@link Write.Bound#getSharding()} returns {@code null}), leaving the number of
+   * shards to be determined by the runner.
+   */
+  public static PTransformMatcher writeWithRunnerDeterminedSharding() {
+    return new PTransformMatcher() {
+      @Override
+      public boolean matches(AppliedPTransform<?, ?, ?> application) {
+        if (application.getTransform() instanceof Write.Bound) {
+          return ((Write.Bound<?>) application.getTransform()).getSharding() == null;
+        }
+        return false;
+      }
+    };
+  }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/24613fd9/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformMatchersTest.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformMatchersTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformMatchersTest.java
index 439a475..cace033 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformMatchersTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformMatchersTest.java
@@ -28,6 +28,9 @@ import java.io.Serializable;
 import java.util.Collections;
 import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.coders.VoidCoder;
+import org.apache.beam.sdk.io.FileBasedSink;
+import org.apache.beam.sdk.io.Write;
+import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.runners.PTransformMatcher;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.AppliedPTransform;
@@ -36,6 +39,7 @@ import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.Flatten;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.Sum;
 import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
 import org.apache.beam.sdk.transforms.windowing.Window;
@@ -51,6 +55,7 @@ import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollection.IsBounded;
 import org.apache.beam.sdk.values.PCollectionList;
+import org.apache.beam.sdk.values.PDone;
 import org.apache.beam.sdk.values.TaggedPValue;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.TupleTagList;
@@ -379,4 +384,42 @@ public class PTransformMatchersTest implements Serializable {
 
     assertThat(PTransformMatchers.emptyFlatten().matches(application), is(false));
   }
+
+  @Test
+  public void writeWithRunnerDeterminedSharding() {
+    Write.Bound<Integer> write =
+        Write.to(
+            new FileBasedSink<Integer>("foo", "bar") {
+              @Override
+              public FileBasedWriteOperation<Integer> createWriteOperation(
+                  PipelineOptions options) {
+                return null;
+              }
+            });
+    assertThat(
+        PTransformMatchers.writeWithRunnerDeterminedSharding().matches(appliedWrite(write)),
+        is(true));
+
+    Write.Bound<Integer> withStaticSharding = write.withNumShards(3);
+    assertThat(
+        PTransformMatchers.writeWithRunnerDeterminedSharding()
+            .matches(appliedWrite(withStaticSharding)),
+        is(false));
+
+    Write.Bound<Integer> withCustomSharding =
+        write.withSharding(Sum.integersGlobally().asSingletonView());
+    assertThat(
+        PTransformMatchers.writeWithRunnerDeterminedSharding()
+            .matches(appliedWrite(withCustomSharding)),
+        is(false));
+  }
+
+  private AppliedPTransform<?, ?, ?> appliedWrite(Write.Bound<Integer> write) {
+    return AppliedPTransform.<PCollection<Integer>, PDone, Write.Bound<Integer>>of(
+        "Write",
+        Collections.<TaggedPValue>emptyList(),
+        Collections.<TaggedPValue>emptyList(),
+        write,
+        p);
+  }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/24613fd9/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
index 06189a2..f56d225 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
@@ -42,7 +42,6 @@ import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.Pipeline.PipelineExecutionException;
 import org.apache.beam.sdk.PipelineResult;
 import org.apache.beam.sdk.io.Read;
-import org.apache.beam.sdk.io.Write.Bound;
 import org.apache.beam.sdk.metrics.MetricResults;
 import org.apache.beam.sdk.metrics.MetricsEnvironment;
 import org.apache.beam.sdk.options.PipelineOptions;
@@ -82,7 +81,7 @@ public class DirectRunner extends PipelineRunner<DirectPipelineResult> {
   private static Map<PTransformMatcher, PTransformOverrideFactory> defaultTransformOverrides =
       ImmutableMap.<PTransformMatcher, PTransformOverrideFactory>builder()
           .put(
-              PTransformMatchers.classEqualTo(Bound.class),
+              PTransformMatchers.writeWithRunnerDeterminedSharding(),
               new WriteWithShardingFactory()) /* Uses a view internally. */
           .put(
               PTransformMatchers.classEqualTo(CreatePCollectionView.class),

http://git-wip-us.apache.org/repos/asf/beam/blob/24613fd9/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java
index 50ec586..f206fb0 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java
@@ -18,10 +18,11 @@
 
 package org.apache.beam.runners.direct;
 
-import static com.google.common.base.Preconditions.checkArgument;
-
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Supplier;
+import com.google.common.base.Suppliers;
 import com.google.common.collect.Iterables;
+import java.io.Serializable;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
@@ -29,24 +30,21 @@ import java.util.concurrent.ThreadLocalRandom;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.io.Write;
 import org.apache.beam.sdk.io.Write.Bound;
+import org.apache.beam.sdk.runners.PTransformOverrideFactory;
 import org.apache.beam.sdk.transforms.Count;
 import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.sdk.transforms.Flatten;
-import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
-import org.apache.beam.sdk.transforms.Values;
+import org.apache.beam.sdk.transforms.View;
 import org.apache.beam.sdk.transforms.windowing.DefaultTrigger;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
 import org.apache.beam.sdk.transforms.windowing.Window;
-import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
-import org.apache.beam.sdk.values.PCollection.IsBounded;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.PDone;
 import org.apache.beam.sdk.values.PValue;
 import org.apache.beam.sdk.values.TaggedPValue;
 import org.joda.time.Duration;
 
 /**
  * A {@link PTransformOverrideFactory} that overrides {@link Write} {@link PTransform PTransforms}
@@ -54,108 +52,109 @@ import org.joda.time.Duration;
  * of shards is the log base 10 of the number of input records, with up to 2 additional shards.
  */
 class WriteWithShardingFactory<InputT>
-    implements org.apache.beam.sdk.runners.PTransformOverrideFactory<
-        PCollection<InputT>, PDone, Write.Bound<InputT>> {
+    implements PTransformOverrideFactory<PCollection<InputT>, PDone, Bound<InputT>> {
   static final int MAX_RANDOM_EXTRA_SHARDS = 3;
+  @VisibleForTesting static final int MIN_SHARDS_FOR_LOG = 3;
 
   @Override
-  public PTransform<PCollection<InputT>, PDone> getReplacementTransform(
-      Bound<InputT> transform) {
-    if (transform.getSharding() == null) {
-      return new DynamicallyReshardedWrite<>(transform);
-    }
-    return transform;
+  public PTransform<PCollection<InputT>, PDone> getReplacementTransform(Bound<InputT> transform) {
+    return transform.withSharding(new LogElementShardsWithDrift<InputT>());
   }
 
   @Override
-  public PCollection<InputT> getInput(
-      List<TaggedPValue> inputs, Pipeline p) {
+  public PCollection<InputT> getInput(List<TaggedPValue> inputs, Pipeline p) {
     return (PCollection<InputT>) Iterables.getOnlyElement(inputs).getValue();
   }
 
   @Override
-  public Map<PValue, ReplacementOutput> mapOutputs(
-      List<TaggedPValue> outputs, PDone newOutput) {
+  public Map<PValue, ReplacementOutput> mapOutputs(List<TaggedPValue> outputs, PDone newOutput) {
     return Collections.emptyMap();
   }
 
-  private static class DynamicallyReshardedWrite<T> extends PTransform<PCollection<T>, PDone> {
-    private final transient Write.Bound<T> original;
-
-    private DynamicallyReshardedWrite(Bound<T> original) {
-      this.original = original;
-    }
+  /**
+   * Computes the shard count for a {@link Write} whose sharding is left to the runner, and
+   * exposes it as a singleton {@link PCollectionView} so that all writers see the same value.
+   */
+  private static class LogElementShardsWithDrift<T>
+      extends PTransform<PCollection<T>, PCollectionView<Integer>> {
 
     @Override
-    public PDone expand(PCollection<T> input) {
-      checkArgument(
-          IsBounded.BOUNDED == input.isBounded(),
-          "%s can only be applied to a Bounded PCollection",
-          getClass().getSimpleName());
-      PCollection<T> records =
-          input.apply(
-              "RewindowInputs",
-              Window.<T>into(new GlobalWindows())
-                  .triggering(DefaultTrigger.of())
-                  .withAllowedLateness(Duration.ZERO)
-                  .discardingFiredPanes());
-      final PCollectionView<Long> numRecords =
-          records.apply("CountRecords", Count.<T>globally().asSingletonView());
-      PCollection<T> resharded =
-          records
-              .apply(
-                  "ApplySharding",
-                  ParDo.withSideInputs(numRecords)
-                      .of(
-                          new KeyBasedOnCountFn<T>(
-                              numRecords,
-                              ThreadLocalRandom.current().nextInt(MAX_RANDOM_EXTRA_SHARDS))))
-              .apply("GroupIntoShards", GroupByKey.<Integer, T>create())
-              .apply("DropShardingKeys", Values.<Iterable<T>>create())
-              .apply("FlattenShardIterables", Flatten.<T>iterables());
-      // This is an inverted application to apply the expansion of the original Write PTransform
-      // without adding a new Write Transform Node, which would be overwritten the same way, leading
-      // to an infinite recursion. We cannot modify the number of shards, because that is determined
-      // at runtime.
-      return resharded.apply(original);
+    public PCollectionView<Integer> expand(PCollection<T> records) {
+      // Count.globally() supports its default (zero) count only for input in GlobalWindows, and
+      // the shard count must be a single value shared by every writer, so count all records in
+      // one global window regardless of how the input to the Write is windowed.
+      return records
+          .apply(
+              "RewindowInputs",
+              Window.<T>into(new GlobalWindows())
+                  .triggering(DefaultTrigger.of())
+                  .withAllowedLateness(Duration.ZERO)
+                  .discardingFiredPanes())
+          .apply("CountRecords", Count.<T>globally())
+          .apply("GenerateShardCount", ParDo.of(new CalculateShardsFn()))
+          .apply(View.<Integer>asSingleton());
     }
   }
 
+  /**
+   * Converts a record count into a shard count. Empty input gets one shard; a count below
+   * {@code MIN_SHARDS_FOR_LOG} plus the extra shards gets one shard per record; otherwise the
+   * result is {@code max(floor(log10(count)), MIN_SHARDS_FOR_LOG)} plus the extra shards.
+   */
   @VisibleForTesting
-  static class KeyBasedOnCountFn<T> extends DoFn<T, KV<Integer, T>> {
-    @VisibleForTesting static final int MIN_SHARDS_FOR_LOG = 3;
-    private final PCollectionView<Long> numRecords;
-    private final int randomExtraShards;
-    private int currentShard;
-    private int maxShards = 0;
-
-    KeyBasedOnCountFn(PCollectionView<Long> numRecords, int extraShards) {
-      this.numRecords = numRecords;
-      this.randomExtraShards = extraShards;
+  static class CalculateShardsFn extends DoFn<Long, Integer> {
+    private final Supplier<Integer> extraShardsSupplier;
+
+    public CalculateShardsFn() {
+      this(new BoundedRandomIntSupplier(MAX_RANDOM_EXTRA_SHARDS));
+    }
+
+    /**
+     * Construct a {@link CalculateShardsFn} that always uses a constant number of specified extra
+     * shards.
+     */
+    @VisibleForTesting
+    CalculateShardsFn(int constantExtraShards) {
+      this(Suppliers.ofInstance(constantExtraShards));
+    }
+
+    private CalculateShardsFn(Supplier<Integer> extraShardsSupplier) {
+      this.extraShardsSupplier = extraShardsSupplier;
     }
 
     @ProcessElement
-    public void processElement(ProcessContext c) throws Exception {
-      if (maxShards == 0) {
-        maxShards = calculateShards(c.sideInput(numRecords));
-        currentShard = ThreadLocalRandom.current().nextInt(maxShards);
-      }
-      int shard = currentShard;
-      currentShard = (currentShard + 1) % maxShards;
-      c.output(KV.of(shard, c.element()));
+    public void process(ProcessContext ctxt) {
+      ctxt.output(calculateShards(ctxt.element()));
     }
 
     private int calculateShards(long totalRecords) {
-      checkArgument(
-          totalRecords > 0,
-          "KeyBasedOnCountFn cannot be invoked on an element if there are no elements");
-      if (totalRecords < MIN_SHARDS_FOR_LOG + randomExtraShards) {
+      if (totalRecords == 0) {
+        // Write out at least one shard, even if there is no input.
+        return 1;
+      }
+      // Extra shards are drawn once per computed count; the resulting shard count is published
+      // through a singleton side input, so all writers use a consistent number of keys.
+      int extraShards = extraShardsSupplier.get();
+      if (totalRecords < MIN_SHARDS_FOR_LOG + extraShards) {
         return (int) totalRecords;
       }
       // 100mil records before >7 output files
       int floorLogRecs = Double.valueOf(Math.log10(totalRecords)).intValue();
-      int shards = Math.max(floorLogRecs, MIN_SHARDS_FOR_LOG) + randomExtraShards;
-      return shards;
+      return Math.max(floorLogRecs, MIN_SHARDS_FOR_LOG) + extraShards;
+    }
+  }
+
+  /** A serializable {@link Supplier} of random integers in {@code [0, upperBound)}. */
+  private static class BoundedRandomIntSupplier implements Supplier<Integer>, Serializable {
+    private final int upperBound;
+
+    private BoundedRandomIntSupplier(int upperBound) {
+      this.upperBound = upperBound;
+    }
+
+    @Override
+    public Integer get() {
+      return ThreadLocalRandom.current().nextInt(0, upperBound);
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/24613fd9/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java
index 0196a2d..51f3a87 100644
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java
@@ -19,7 +19,6 @@
 package org.apache.beam.runners.direct;
 
 import static org.hamcrest.Matchers.allOf;
-import static org.hamcrest.Matchers.anyOf;
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.greaterThan;
@@ -28,8 +27,6 @@ import static org.hamcrest.Matchers.lessThan;
 import static org.hamcrest.Matchers.not;
 import static org.junit.Assert.assertThat;
 
-import com.google.common.base.Function;
-import com.google.common.collect.Iterables;
 import java.io.File;
 import java.io.FileReader;
 import java.io.Reader;
@@ -39,8 +36,7 @@ import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
 import java.util.UUID;
-import java.util.concurrent.ThreadLocalRandom;
-import org.apache.beam.runners.direct.WriteWithShardingFactory.KeyBasedOnCountFn;
+import org.apache.beam.runners.direct.WriteWithShardingFactory.CalculateShardsFn;
 import org.apache.beam.sdk.coders.VarLongCoder;
 import org.apache.beam.sdk.io.Sink;
 import org.apache.beam.sdk.io.TextIO;
@@ -48,12 +44,12 @@ import org.apache.beam.sdk.io.Write;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.DoFnTester;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
 import org.apache.beam.sdk.util.IOChannelUtils;
 import org.apache.beam.sdk.util.PCollectionViews;
 import org.apache.beam.sdk.util.WindowingStrategy;
-import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.hamcrest.Matchers;
@@ -91,7 +87,7 @@ public class WriteWithShardingFactoryTest {
     p.run();
 
     Collection<String> files = IOChannelUtils.getFactory(outputPath).match(targetLocation + "*");
-    List<String> actuals = new ArrayList(strs.size());
+    List<String> actuals = new ArrayList<>(strs.size());
     for (String file : files) {
       CharBuffer buf = CharBuffer.allocate((int) new File(file).length());
       try (Reader reader = new FileReader(file)) {
@@ -120,96 +116,56 @@ public class WriteWithShardingFactoryTest {
   }
 
   @Test
-  public void withShardingSpecifiesOriginalTransform() {
-    Write.Bound<Object> original = Write.to(new TestSink()).withNumShards(3);
-
-    assertThat(factory.getReplacementTransform(original), equalTo((Object) original));
-  }
-
-  @Test
   public void withNoShardingSpecifiedReturnsNewTransform() {
     Write.Bound<Object> original = Write.to(new TestSink());
     assertThat(factory.getReplacementTransform(original), not(equalTo((Object) original)));
   }
 
   @Test
-  public void keyBasedOnCountFnWithOneElement() throws Exception {
-    PCollectionView<Long> elementCountView =
-        PCollectionViews.singletonView(
-            p, WindowingStrategy.globalDefault(), true, 0L, VarLongCoder.of());
-    KeyBasedOnCountFn<String> fn = new KeyBasedOnCountFn<>(elementCountView, 0);
-    DoFnTester<String, KV<Integer, String>> fnTester = DoFnTester.of(fn);
+  public void keyBasedOnCountFnWithNoElements() throws Exception {
+    CalculateShardsFn fn = new CalculateShardsFn(0);
+    DoFnTester<Long, Integer> fnTester = DoFnTester.of(fn);
 
-    fnTester.setSideInput(elementCountView, GlobalWindow.INSTANCE, 1L);
+    List<Integer> outputs = fnTester.processBundle(0L);
+    assertThat(
+        outputs, containsInAnyOrder(1));
+  }
 
-    List<KV<Integer, String>> outputs = fnTester.processBundle("foo", "bar", "bazbar");
+  @Test
+  public void keyBasedOnCountFnWithOneElement() throws Exception {
+    CalculateShardsFn fn = new CalculateShardsFn(0);
+    DoFnTester<Long, Integer> fnTester = DoFnTester.of(fn);
+
+    List<Integer> outputs = fnTester.processBundle(1L);
     assertThat(
-        outputs, containsInAnyOrder(KV.of(0, "foo"), KV.of(0, "bar"), KV.of(0, "bazbar")));
+        outputs, containsInAnyOrder(1));
   }
 
   @Test
   public void keyBasedOnCountFnWithTwoElements() throws Exception {
-    PCollectionView<Long> elementCountView =
-        PCollectionViews.singletonView(
-            p, WindowingStrategy.globalDefault(), true, 0L, VarLongCoder.of());
-    KeyBasedOnCountFn<String> fn = new KeyBasedOnCountFn<>(elementCountView, 0);
-    DoFnTester<String, KV<Integer, String>> fnTester = DoFnTester.of(fn);
-
-    fnTester.setSideInput(elementCountView, GlobalWindow.INSTANCE, 2L);
+    CalculateShardsFn fn = new CalculateShardsFn(0);
+    DoFnTester<Long, Integer> fnTester = DoFnTester.of(fn);
 
-    List<KV<Integer, String>> outputs = fnTester.processBundle("foo", "bar");
-    assertThat(
-        outputs,
-        anyOf(
-            containsInAnyOrder(KV.of(0, "foo"), KV.of(1, "bar")),
-            containsInAnyOrder(KV.of(1, "foo"), KV.of(0, "bar"))));
+    List<Integer> outputs = fnTester.processBundle(2L);
+    assertThat(outputs, containsInAnyOrder(2));
   }
 
   @Test
   public void keyBasedOnCountFnFewElementsThreeShards() throws Exception {
-    PCollectionView<Long> elementCountView =
-        PCollectionViews.singletonView(
-            p, WindowingStrategy.globalDefault(), true, 0L, VarLongCoder.of());
-    KeyBasedOnCountFn<String> fn = new KeyBasedOnCountFn<>(elementCountView, 0);
-    DoFnTester<String, KV<Integer, String>> fnTester = DoFnTester.of(fn);
-
-    fnTester.setSideInput(elementCountView, GlobalWindow.INSTANCE, 100L);
+    CalculateShardsFn fn = new CalculateShardsFn(0);
+    DoFnTester<Long, Integer> fnTester = DoFnTester.of(fn);
 
-    List<KV<Integer, String>> outputs =
-        fnTester.processBundle("foo", "bar", "baz", "foobar", "foobaz", "barbaz");
-    assertThat(
-        Iterables.transform(
-            outputs,
-            new Function<KV<Integer, String>, Integer>() {
-              @Override
-              public Integer apply(KV<Integer, String> input) {
-                return input.getKey();
-              }
-            }),
-        containsInAnyOrder(0, 0, 1, 1, 2, 2));
+    List<Integer> outputs = fnTester.processBundle(5L);
+    assertThat(outputs, containsInAnyOrder(3));
   }
 
   @Test
   public void keyBasedOnCountFnManyElements() throws Exception {
-    PCollectionView<Long> elementCountView =
-        PCollectionViews.singletonView(
-            p, WindowingStrategy.globalDefault(), true, 0L, VarLongCoder.of());
-    KeyBasedOnCountFn<String> fn = new KeyBasedOnCountFn<>(elementCountView, 0);
-    DoFnTester<String, KV<Integer, String>> fnTester = DoFnTester.of(fn);
-
-    double count = Math.pow(10, 10);
-    fnTester.setSideInput(elementCountView, GlobalWindow.INSTANCE, (long) count);
+    DoFn<Long, Integer> fn = new CalculateShardsFn(0);
+    DoFnTester<Long, Integer> fnTester = DoFnTester.of(fn);
 
-    List<String> strings = new ArrayList<>();
-    for (int i = 0; i < 100; i++) {
-      strings.add(Long.toHexString(ThreadLocalRandom.current().nextLong()));
-    }
-    List<KV<Integer, String>> kvs = fnTester.processBundle(strings);
-    long maxKey = -1L;
-    for (KV<Integer, String> kv : kvs) {
-      maxKey = Math.max(maxKey, kv.getKey());
-    }
-    assertThat(maxKey, equalTo(9L));
+    List<Integer> shard = fnTester.processBundle((long) Math.pow(10, 10));
+    assertThat(shard, containsInAnyOrder(10));
   }
 
   @Test
@@ -217,46 +173,25 @@ public class WriteWithShardingFactoryTest {
     PCollectionView<Long> elementCountView =
         PCollectionViews.singletonView(
             p, WindowingStrategy.globalDefault(), true, 0L, VarLongCoder.of());
-    KeyBasedOnCountFn<String> fn = new KeyBasedOnCountFn<>(elementCountView, 10);
-    DoFnTester<String, KV<Integer, String>> fnTester = DoFnTester.of(fn);
+    CalculateShardsFn fn = new CalculateShardsFn(3);
+    DoFnTester<Long, Integer> fnTester = DoFnTester.of(fn);
 
-    long countValue = (long) KeyBasedOnCountFn.MIN_SHARDS_FOR_LOG + 3;
+    long countValue = (long) WriteWithShardingFactory.MIN_SHARDS_FOR_LOG + 3;
     fnTester.setSideInput(elementCountView, GlobalWindow.INSTANCE, countValue);
 
-    List<String> strings = new ArrayList<>();
-    for (int i = 0; i < 100; i++) {
-      strings.add(Long.toHexString(ThreadLocalRandom.current().nextLong()));
-    }
-    List<KV<Integer, String>> kvs = fnTester.processBundle(strings);
-    long maxKey = -1L;
-    for (KV<Integer, String> kv : kvs) {
-      maxKey = Math.max(maxKey, kv.getKey());
-    }
-    // 0 to n-1 shard ids.
-    assertThat(maxKey, equalTo(countValue - 1));
+    List<Integer> shards = fnTester.processBundle(countValue);
+    assertThat(shards, containsInAnyOrder(6));
   }
 
   @Test
   public void keyBasedOnCountFnManyElementsExtraShards() throws Exception {
-    PCollectionView<Long> elementCountView =
-        PCollectionViews.singletonView(
-            p, WindowingStrategy.globalDefault(), true, 0L, VarLongCoder.of());
-    KeyBasedOnCountFn<String> fn = new KeyBasedOnCountFn<>(elementCountView, 3);
-    DoFnTester<String, KV<Integer, String>> fnTester = DoFnTester.of(fn);
+    CalculateShardsFn fn = new CalculateShardsFn(3);
+    DoFnTester<Long, Integer> fnTester = DoFnTester.of(fn);
 
     double count = Math.pow(10, 10);
-    fnTester.setSideInput(elementCountView, GlobalWindow.INSTANCE, (long) count);
 
-    List<String> strings = new ArrayList<>();
-    for (int i = 0; i < 100; i++) {
-      strings.add(Long.toHexString(ThreadLocalRandom.current().nextLong()));
-    }
-    List<KV<Integer, String>> kvs = fnTester.processBundle(strings);
-    long maxKey = -1L;
-    for (KV<Integer, String> kv : kvs) {
-      maxKey = Math.max(maxKey, kv.getKey());
-    }
-    assertThat(maxKey, equalTo(12L));
+    List<Integer> shards = fnTester.processBundle((long) count);
+    assertThat(shards, containsInAnyOrder(13));
   }
 
   @Test

Reply via email to