dmvk commented on a change in pull request #13353:
URL: https://github.com/apache/beam/pull/13353#discussion_r526641055



##########
File path: 
runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslatorTest.java
##########
@@ -120,4 +138,95 @@ public void testAutoBalanceShardKeyCacheMaxSize() throws 
Exception {
     assertThat(
         fn.getCache().size(), 
equalTo(FlinkAutoBalancedShardKeyShardingFunction.CACHE_MAX_SIZE));
   }
+
+  @Test
+  public void testStatefulParDoAfterCombineChaining() {
+    final JobGraph stablePartitioning = 
getStatefulParDoAfterCombineChainingJobGraph(true);
+    final JobGraph unstablePartitioning = 
getStatefulParDoAfterCombineChainingJobGraph(false);
+    // We expect an extra shuffle stage for unstable partitioning.
+    Assert.assertEquals(
+        1,
+        Iterables.size(unstablePartitioning.getVertices())
+            - Iterables.size(stablePartitioning.getVertices()));
+  }
+
+  private JobGraph getStatefulParDoAfterCombineChainingJobGraph(boolean 
stablePartitioning) {
+    final StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
+    final FlinkStreamingPipelineTranslator translator =
+        new FlinkStreamingPipelineTranslator(env, 
PipelineOptionsFactory.create());
+    final PipelineOptions pipelineOptions = PipelineOptionsFactory.create();
+    pipelineOptions.setRunner(FlinkRunner.class);
+    final Pipeline pipeline = Pipeline.create(pipelineOptions);
+    PCollection<KV<String, Long>> aggregate =
+        pipeline
+            .apply(Create.of("foo", "bar").withCoder(StringUtf8Coder.of()))
+            .apply(Count.perElement());
+    if (!stablePartitioning) {
+      // When we insert any element-wise "map" operation between aggregation 
and stateful ParDo, we
+      // can no longer assume that partitioning did not change, therefore we 
need an extra shuffle
+      aggregate = aggregate.apply(ParDo.of(new StatelessIdentityDoFn<>()));
+    }
+    aggregate.apply(ParDo.of(new StatefulNoopDoFn<>()));
+    translator.translate(pipeline);
+    return env.getStreamGraph().getJobGraph();
+  }
+
+  @Test
+  public void testStatefulParDoAfterGroupByKeyChaining() {
+    final JobGraph stablePartitioning = 
getStatefulParDoAfterGroupByKeyChainingJobGraph(true);
+    final JobGraph unstablePartitioning = 
getStatefulParDoAfterGroupByKeyChainingJobGraph(false);
+    // We expect an extra shuffle stage for unstable partitioning.
+    Assert.assertEquals(
+        1,
+        Iterables.size(unstablePartitioning.getVertices())
+            - Iterables.size(stablePartitioning.getVertices()));
+  }
+
+  private JobGraph getStatefulParDoAfterGroupByKeyChainingJobGraph(boolean 
stablePartitioning) {
+    final StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
+    final FlinkStreamingPipelineTranslator translator =
+        new FlinkStreamingPipelineTranslator(env, 
PipelineOptionsFactory.create());
+    final PipelineOptions pipelineOptions = PipelineOptionsFactory.create();
+    pipelineOptions.setRunner(FlinkRunner.class);
+    final Pipeline pipeline = Pipeline.create(pipelineOptions);
+    PCollection<KV<String, Iterable<Long>>> aggregate =
+        pipeline
+            .apply(
+                Create.of(KV.of("foo", 1L), KV.of("bar", 1L))
+                    .withCoder(KvCoder.of(StringUtf8Coder.of(), 
VarLongCoder.of())))
+            .apply(GroupByKey.create());

Review comment:
       I thought about it, but the methods are really small, so it's not worth 
decreasing readability here.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to