robertwb commented on a change in pull request #13208:
URL: https://github.com/apache/beam/pull/13208#discussion_r513671483



##########
File path: 
runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/GroupIntoBatchesOverride.java
##########
@@ -103,43 +156,76 @@ public void process(ProcessContext c) {
     }
 
     @Override
-    public PTransformReplacement<PCollection<KV<K, V>>, PCollection<KV<K, 
Iterable<V>>>>
+    public PTransformReplacement<PCollection<KV<K, V>>, 
PCollection<KV<ShardedKey<K>, Iterable<V>>>>
         getReplacementTransform(
             AppliedPTransform<
-                    PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>, 
GroupIntoBatches<K, V>>
+                    PCollection<KV<K, V>>,
+                    PCollection<KV<ShardedKey<K>, Iterable<V>>>,
+                    GroupIntoBatches<K, V>.WithShardedKey>
                 transform) {
       return PTransformReplacement.of(
           PTransformReplacements.getSingletonMainInput(transform),
-          new StreamingGroupIntoBatches(runner, transform.getTransform()));
+          new StreamingGroupIntoBatchesWithShardedKey<>(runner, 
transform.getTransform()));
     }
 
     @Override
     public Map<PCollection<?>, ReplacementOutput> mapOutputs(
-        Map<TupleTag<?>, PCollection<?>> outputs, PCollection<KV<K, 
Iterable<V>>> newOutput) {
+        Map<TupleTag<?>, PCollection<?>> outputs,
+        PCollection<KV<ShardedKey<K>, Iterable<V>>> newOutput) {
       return ReplacementOutputs.singleton(outputs, newOutput);
     }
   }
 
   /**
-   * Specialized implementation of {@link GroupIntoBatches} for unbounded 
Dataflow pipelines. The
-   * override does the same thing as the original transform but additionally 
record the input to add
-   * corresponding properties during the graph translation.
+   * Specialized implementation of {@link GroupIntoBatches.WithShardedKey} for 
unbounded Dataflow
+   * pipelines. The override does the same thing as the original transform but 
additionally records
+   * the input of {@code GroupIntoBatchesDoFn} in order to append relevant 
step properties during
+   * the graph translation.
    */
-  static class StreamingGroupIntoBatches<K, V>
-      extends PTransform<PCollection<KV<K, V>>, PCollection<KV<K, 
Iterable<V>>>> {
+  static class StreamingGroupIntoBatchesWithShardedKey<K, V>
+      extends PTransform<PCollection<KV<K, V>>, PCollection<KV<ShardedKey<K>, 
Iterable<V>>>> {
 
     private final transient DataflowRunner runner;
-    private final GroupIntoBatches<K, V> original;
+    private final GroupIntoBatches<K, V>.WithShardedKey original;
 
-    public StreamingGroupIntoBatches(DataflowRunner runner, 
GroupIntoBatches<K, V> original) {
+    public StreamingGroupIntoBatchesWithShardedKey(
+        DataflowRunner runner, GroupIntoBatches<K, V>.WithShardedKey original) 
{
       this.runner = runner;
       this.original = original;
     }
 
     @Override
-    public PCollection<KV<K, Iterable<V>>> expand(PCollection<KV<K, V>> input) 
{
-      runner.maybeRecordPCollectionWithAutoSharding(input);
-      return input.apply(original);
+    public PCollection<KV<ShardedKey<K>, Iterable<V>>> 
expand(PCollection<KV<K, V>> input) {
+      PCollection<KV<ShardedKey<K>, V>> intermediate_input = ShardKeys(input);
+
+      runner.maybeRecordPCollectionWithAutoSharding(intermediate_input);
+
+      if (original.getMaxBufferingDuration() != null) {

Review comment:
       This doesn't look like it'll scale if more options are used. Why not 
just apply original? 

##########
File path: 
runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/GroupIntoBatchesOverride.java
##########
@@ -103,43 +156,76 @@ public void process(ProcessContext c) {
     }
 
     @Override
-    public PTransformReplacement<PCollection<KV<K, V>>, PCollection<KV<K, 
Iterable<V>>>>
+    public PTransformReplacement<PCollection<KV<K, V>>, 
PCollection<KV<ShardedKey<K>, Iterable<V>>>>
         getReplacementTransform(
             AppliedPTransform<
-                    PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>, 
GroupIntoBatches<K, V>>
+                    PCollection<KV<K, V>>,
+                    PCollection<KV<ShardedKey<K>, Iterable<V>>>,
+                    GroupIntoBatches<K, V>.WithShardedKey>
                 transform) {
       return PTransformReplacement.of(
           PTransformReplacements.getSingletonMainInput(transform),
-          new StreamingGroupIntoBatches(runner, transform.getTransform()));
+          new StreamingGroupIntoBatchesWithShardedKey<>(runner, 
transform.getTransform()));
     }
 
     @Override
     public Map<PCollection<?>, ReplacementOutput> mapOutputs(
-        Map<TupleTag<?>, PCollection<?>> outputs, PCollection<KV<K, 
Iterable<V>>> newOutput) {
+        Map<TupleTag<?>, PCollection<?>> outputs,
+        PCollection<KV<ShardedKey<K>, Iterable<V>>> newOutput) {
       return ReplacementOutputs.singleton(outputs, newOutput);
     }
   }
 
   /**
-   * Specialized implementation of {@link GroupIntoBatches} for unbounded 
Dataflow pipelines. The
-   * override does the same thing as the original transform but additionally 
record the input to add
-   * corresponding properties during the graph translation.
+   * Specialized implementation of {@link GroupIntoBatches.WithShardedKey} for 
unbounded Dataflow
+   * pipelines. The override does the same thing as the original transform but 
additionally records
+   * the input of {@code GroupIntoBatchesDoFn} in order to append relevant 
step properties during
+   * the graph translation.
    */
-  static class StreamingGroupIntoBatches<K, V>
-      extends PTransform<PCollection<KV<K, V>>, PCollection<KV<K, 
Iterable<V>>>> {
+  static class StreamingGroupIntoBatchesWithShardedKey<K, V>
+      extends PTransform<PCollection<KV<K, V>>, PCollection<KV<ShardedKey<K>, 
Iterable<V>>>> {
 
     private final transient DataflowRunner runner;
-    private final GroupIntoBatches<K, V> original;
+    private final GroupIntoBatches<K, V>.WithShardedKey original;
 
-    public StreamingGroupIntoBatches(DataflowRunner runner, 
GroupIntoBatches<K, V> original) {
+    public StreamingGroupIntoBatchesWithShardedKey(
+        DataflowRunner runner, GroupIntoBatches<K, V>.WithShardedKey original) 
{
       this.runner = runner;
       this.original = original;
     }
 
     @Override
-    public PCollection<KV<K, Iterable<V>>> expand(PCollection<KV<K, V>> input) 
{
-      runner.maybeRecordPCollectionWithAutoSharding(input);
-      return input.apply(original);
+    public PCollection<KV<ShardedKey<K>, Iterable<V>>> 
expand(PCollection<KV<K, V>> input) {
+      PCollection<KV<ShardedKey<K>, V>> intermediate_input = ShardKeys(input);
+
+      runner.maybeRecordPCollectionWithAutoSharding(intermediate_input);
+
+      if (original.getMaxBufferingDuration() != null) {
+        return intermediate_input.apply(
+            GroupIntoBatches.<ShardedKey<K>, V>ofSize(original.getBatchSize())
+                .withMaxBufferingDuration(original.getMaxBufferingDuration()));
+      } else {
+        return 
intermediate_input.apply(GroupIntoBatches.ofSize(original.getBatchSize()));
+      }
     }
   }
+
+  private static <K, V> PCollection<KV<ShardedKey<K>, V>> 
ShardKeys(PCollection<KV<K, V>> input) {

Review comment:
       Methods shouldn't be capitalized. 

##########
File path: 
sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupIntoBatches.java
##########
@@ -105,23 +108,83 @@ public long getBatchSize() {
   }
 
   /**
-   * Set a time limit (in processing time) on how long an incomplete batch of 
elements is allowed to
-   * be buffered. Once a batch is flushed to output, the timer is reset.
+   * Sets a time limit (in processing time) on how long an incomplete batch of 
elements is allowed
+   * to be buffered. Once a batch is flushed to output, the timer is reset.
    */
   public GroupIntoBatches<K, InputT> withMaxBufferingDuration(Duration 
duration) {
     checkArgument(
         duration.isLongerThan(Duration.ZERO), "max buffering duration should 
be a positive value");
     return new GroupIntoBatches<>(batchSize, duration);
   }
 
+  /**
+   * Outputs batched elements associated with sharded input keys. The sharding 
is determined by the
+   * runner to balance the load during the execution time. By default, apply 
no sharding so each key
+   * has one shard.
+   */
+  @Experimental
+  public WithShardedKey withShardedKey() {
+    return new WithShardedKey();
+  }
+
+  public class WithShardedKey
+      extends PTransform<
+          PCollection<KV<K, InputT>>, PCollection<KV<ShardedKey<K>, 
Iterable<InputT>>>> {
+
+    /** Returns the size of the batch. */
+    public long getBatchSize() {
+      return batchSize;
+    }
+
+    /** Returns the size of the batch. */
+    @Nullable
+    public Duration getMaxBufferingDuration() {
+      return maxBufferingDuration;
+    }
+
+    @Override
+    public PCollection<KV<ShardedKey<K>, Iterable<InputT>>> expand(
+        PCollection<KV<K, InputT>> input) {
+      Duration allowedLateness = 
input.getWindowingStrategy().getAllowedLateness();
+
+      checkArgument(
+          input.getCoder() instanceof KvCoder,
+          "coder specified in the input PCollection is not a KvCoder");
+      KvCoder<K, InputT> inputCoder = (KvCoder<K, InputT>) input.getCoder();
+      Coder<K> keyCoder = (Coder<K>) inputCoder.getCoderArguments().get(0);
+      Coder<InputT> valueCoder = (Coder<InputT>) 
inputCoder.getCoderArguments().get(1);
+
+      return input
+          .apply(
+              MapElements.via(
+                  new SimpleFunction<KV<K, InputT>, KV<ShardedKey<K>, 
InputT>>() {
+                    @Override
+                    public KV<ShardedKey<K>, InputT> apply(KV<K, InputT> 
input) {
+                      // By default every input key has only one shard.
+                      return KV.of(
+                          ShardedKey.of(input.getKey(), DEFAULT_SHARD_ID), 
input.getValue());

Review comment:
       A single subshard by default will make this virtually unusable for 
runners that don't implement the optimization (including batch Dataflow). 
Instead use something like the thread id here, or at the very least initialize 
DEFAULT_SHARD_ID to be different for each worker and add a small nonce. We 
could alternatively take a hint as to the number of subshards that would be 
nice (but that has its own downsides). 

##########
File path: 
sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupIntoBatches.java
##########
@@ -105,23 +108,83 @@ public long getBatchSize() {
   }
 
   /**
-   * Set a time limit (in processing time) on how long an incomplete batch of 
elements is allowed to
-   * be buffered. Once a batch is flushed to output, the timer is reset.
+   * Sets a time limit (in processing time) on how long an incomplete batch of 
elements is allowed
+   * to be buffered. Once a batch is flushed to output, the timer is reset.
    */
   public GroupIntoBatches<K, InputT> withMaxBufferingDuration(Duration 
duration) {
     checkArgument(
         duration.isLongerThan(Duration.ZERO), "max buffering duration should 
be a positive value");
     return new GroupIntoBatches<>(batchSize, duration);
   }
 
+  /**
+   * Outputs batched elements associated with sharded input keys. The sharding 
is determined by the
+   * runner to balance the load during the execution time. By default, apply 
no sharding so each key
+   * has one shard.
+   */
+  @Experimental
+  public WithShardedKey withShardedKey() {
+    return new WithShardedKey();
+  }
+
+  public class WithShardedKey
+      extends PTransform<
+          PCollection<KV<K, InputT>>, PCollection<KV<ShardedKey<K>, 
Iterable<InputT>>>> {
+
+    /** Returns the size of the batch. */
+    public long getBatchSize() {
+      return batchSize;
+    }
+
+    /** Returns the size of the batch. */
+    @Nullable
+    public Duration getMaxBufferingDuration() {
+      return maxBufferingDuration;
+    }
+
+    @Override
+    public PCollection<KV<ShardedKey<K>, Iterable<InputT>>> expand(
+        PCollection<KV<K, InputT>> input) {
+      Duration allowedLateness = 
input.getWindowingStrategy().getAllowedLateness();
+
+      checkArgument(
+          input.getCoder() instanceof KvCoder,
+          "coder specified in the input PCollection is not a KvCoder");
+      KvCoder<K, InputT> inputCoder = (KvCoder<K, InputT>) input.getCoder();
+      Coder<K> keyCoder = (Coder<K>) inputCoder.getCoderArguments().get(0);
+      Coder<InputT> valueCoder = (Coder<InputT>) 
inputCoder.getCoderArguments().get(1);
+
+      return input
+          .apply(
+              MapElements.via(
+                  new SimpleFunction<KV<K, InputT>, KV<ShardedKey<K>, 
InputT>>() {
+                    @Override
+                    public KV<ShardedKey<K>, InputT> apply(KV<K, InputT> 
input) {
+                      // By default every input key has only one shard.
+                      return KV.of(
+                          ShardedKey.of(input.getKey(), DEFAULT_SHARD_ID), 
input.getValue());
+                    }
+                  }))
+          .setCoder(KvCoder.of(ShardedKey.Coder.of(keyCoder), valueCoder))
+          .apply(

Review comment:
       Alternatively one could apply the original GroupIntoBatches that this 
was derived from here.

##########
File path: 
runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
##########
@@ -1281,8 +1286,12 @@ void addPCollectionRequiringIndexedFormat(PCollection<?> 
pcol) {
   }
 
   void maybeRecordPCollectionWithAutoSharding(PCollection<?> pcol) {
-    if (hasExperiment(options, "enable_streaming_auto_sharding")
-        && !hasExperiment(options, "beam_fn_api")) {
+    if (hasExperiment(options, "beam_fn_api")) {

Review comment:
       I think it makes the most sense here. 

##########
File path: 
runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/GroupIntoBatchesOverride.java
##########
@@ -92,9 +96,58 @@ public void process(ProcessContext c) {
     }
   }
 
+  static class BatchGroupIntoBatchesWithShardedKeyOverrideFactory<K, V>

Review comment:
       Do we need an override or is the default implementation good enough? 

##########
File path: 
sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupIntoBatches.java
##########
@@ -105,23 +108,83 @@ public long getBatchSize() {
   }
 
   /**
-   * Set a time limit (in processing time) on how long an incomplete batch of 
elements is allowed to
-   * be buffered. Once a batch is flushed to output, the timer is reset.
+   * Sets a time limit (in processing time) on how long an incomplete batch of 
elements is allowed
+   * to be buffered. Once a batch is flushed to output, the timer is reset.
    */
   public GroupIntoBatches<K, InputT> withMaxBufferingDuration(Duration 
duration) {
     checkArgument(
         duration.isLongerThan(Duration.ZERO), "max buffering duration should 
be a positive value");
     return new GroupIntoBatches<>(batchSize, duration);
   }
 
+  /**
+   * Outputs batched elements associated with sharded input keys. The sharding 
is determined by the
+   * runner to balance the load during the execution time. By default, apply 
no sharding so each key
+   * has one shard.
+   */
+  @Experimental
+  public WithShardedKey withShardedKey() {
+    return new WithShardedKey();
+  }
+
+  public class WithShardedKey

Review comment:
       Correct, we can't do this by default due to the coder change.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to