This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git

commit 2a95325ab860d189ce0a79eb512db04c60abdcf0
Author: Yun Gao <gaoyunhen...@gmail.com>
AuthorDate: Sat Oct 30 22:12:52 2021 +0800

    [hotfix][iteration] Do not wrap the broadcast partitioner
---
 .../operator/allround/AllRoundOperatorWrapper.java |  6 ++++
 .../itcases/UnboundedStreamIterationITCase.java    | 33 +++++++++++++++++-----
 2 files changed, 32 insertions(+), 7 deletions(-)

diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AllRoundOperatorWrapper.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AllRoundOperatorWrapper.java
index 1a6f65d..c28acd1 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AllRoundOperatorWrapper.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AllRoundOperatorWrapper.java
@@ -31,6 +31,7 @@ import org.apache.flink.streaming.api.operators.StreamOperator;
 import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
 import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
 import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner;
 import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
 import org.apache.flink.util.OutputTag;
 
@@ -64,6 +65,11 @@ public class AllRoundOperatorWrapper<T> implements OperatorWrapper<T, IterationR
     @Override
     public StreamPartitioner<IterationRecord<T>> wrapStreamPartitioner(
             StreamPartitioner<T> streamPartitioner) {
+        // Do not wrap the BroadcastPartitioner since it executes differently.
+        if (streamPartitioner instanceof BroadcastPartitioner) {
+            return new BroadcastPartitioner<>();
+        }
+
         return new ProxyStreamPartitioner<>(streamPartitioner);
     }
 
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/itcases/UnboundedStreamIterationITCase.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/itcases/UnboundedStreamIterationITCase.java
index c133fe1..e16bf02 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/itcases/UnboundedStreamIterationITCase.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/itcases/UnboundedStreamIterationITCase.java
@@ -82,7 +82,7 @@ public class UnboundedStreamIterationITCase extends TestLogger {
     @Test(timeout = 60000)
     public void testVariableOnlyUnboundedIteration() throws Exception {
         // Create the test job
-        JobGraph jobGraph = createVariableOnlyJobGraph(4, 1000, true, 0, false, 1, result);
+        JobGraph jobGraph = createVariableOnlyJobGraph(4, 1000, true, 0, false, 1, false, result);
         miniCluster.submitJob(jobGraph);
 
         // Expected records is round * parallelism * numRecordsPerSource
@@ -94,7 +94,7 @@ public class UnboundedStreamIterationITCase extends TestLogger {
     @Test(timeout = 60000)
     public void testVariableOnlyBoundedIteration() throws Exception {
         // Create the test job
-        JobGraph jobGraph = createVariableOnlyJobGraph(4, 1000, false, 0, false, 1, result);
+        JobGraph jobGraph = createVariableOnlyJobGraph(4, 1000, false, 0, false, 1, false, result);
         miniCluster.executeJobBlocking(jobGraph);
 
         assertEquals(8001, result.get().size());
@@ -107,6 +107,22 @@ public class UnboundedStreamIterationITCase extends TestLogger {
     }
 
     @Test(timeout = 60000)
+    public void testVariableOnlyBoundedIterationWithBroadcast() throws Exception {
+        // Create the test job
+        JobGraph jobGraph = createVariableOnlyJobGraph(4, 1000, false, 0, false, 1, true, result);
+        miniCluster.executeJobBlocking(jobGraph);
+
+        assertEquals(8001, result.get().size());
+
+        // Expected records is round * parallelism * numRecordsPerSource * parallelism of reduce
+        // operators
+        Map<Integer, Tuple2<Integer, Integer>> roundsStat =
+                computeRoundStat(result.get(), 2 * 4 * 1000 * 1);
+        verifyResult(roundsStat, 2, 4000, 4 * (0 + 999) * 1000 / 2);
+        assertEquals(OutputRecord.Event.TERMINATED, result.get().take().getEvent());
+    }
+
+    @Test(timeout = 60000)
     public void testVariableAndConstantsUnboundedIteration() throws Exception {
         // Create the test job
         JobGraph jobGraph = createVariableAndConstantJobGraph(4, 1000, true, 0, false, 1, result);
@@ -150,6 +166,7 @@ public class UnboundedStreamIterationITCase extends TestLogger {
             int period,
             boolean sync,
             int maxRound,
+            boolean doBroadcast,
             SharedReference<BlockingQueue<OutputRecord<Integer>>> result) {
         StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
         env.setParallelism(1);
@@ -161,12 +178,14 @@ public class UnboundedStreamIterationITCase extends TestLogger {
                         DataStreamList.of(source),
                         DataStreamList.of(),
                         (variableStreams, dataStreams) -> {
+                            DataStream<EpochRecord> variable = variableStreams.get(0);
+                            if (doBroadcast) {
+                                variable = variable.broadcast();
+                            }
+
                             SingleOutputStreamOperator<EpochRecord> reducer =
-                                    variableStreams
-                                            .<EpochRecord>get(0)
-                                            .process(
-                                                    new ReduceAllRoundProcessFunction(
-                                                            sync, maxRound));
+                                    variable.process(
+                                            new ReduceAllRoundProcessFunction(sync, maxRound));
                             return new IterationBodyResult(
                                     DataStreamList.of(
                                             reducer.map(new IncrementEpochMap())

Reply via email to