This is an automated email from the ASF dual-hosted git repository. gaoyunhaii pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink-ml.git
commit 2a95325ab860d189ce0a79eb512db04c60abdcf0 Author: Yun Gao <gaoyunhen...@gmail.com> AuthorDate: Sat Oct 30 22:12:52 2021 +0800 [hotfix][iteration] Do not wrap the broadcast partitioner --- .../operator/allround/AllRoundOperatorWrapper.java | 6 ++++ .../itcases/UnboundedStreamIterationITCase.java | 33 +++++++++++++++++----- 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AllRoundOperatorWrapper.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AllRoundOperatorWrapper.java index 1a6f65d..c28acd1 100644 --- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AllRoundOperatorWrapper.java +++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AllRoundOperatorWrapper.java @@ -31,6 +31,7 @@ import org.apache.flink.streaming.api.operators.StreamOperator; import org.apache.flink.streaming.api.operators.StreamOperatorFactory; import org.apache.flink.streaming.api.operators.StreamOperatorParameters; import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner; import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; import org.apache.flink.util.OutputTag; @@ -64,6 +65,11 @@ public class AllRoundOperatorWrapper<T> implements OperatorWrapper<T, IterationR @Override public StreamPartitioner<IterationRecord<T>> wrapStreamPartitioner( StreamPartitioner<T> streamPartitioner) { + // Do not wrap the BroadcastPartitioner since it executes differently. + if (streamPartitioner instanceof BroadcastPartitioner) { + return new BroadcastPartitioner<>(); + } + return new ProxyStreamPartitioner<>(streamPartitioner); } diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/itcases/UnboundedStreamIterationITCase.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/itcases/UnboundedStreamIterationITCase.java index c133fe1..e16bf02 100644 --- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/itcases/UnboundedStreamIterationITCase.java +++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/itcases/UnboundedStreamIterationITCase.java @@ -82,7 +82,7 @@ public class UnboundedStreamIterationITCase extends TestLogger { @Test(timeout = 60000) public void testVariableOnlyUnboundedIteration() throws Exception { // Create the test job - JobGraph jobGraph = createVariableOnlyJobGraph(4, 1000, true, 0, false, 1, result); + JobGraph jobGraph = createVariableOnlyJobGraph(4, 1000, true, 0, false, 1, false, result); miniCluster.submitJob(jobGraph); // Expected records is round * parallelism * numRecordsPerSource @@ -94,7 +94,7 @@ public class UnboundedStreamIterationITCase extends TestLogger { @Test(timeout = 60000) public void testVariableOnlyBoundedIteration() throws Exception { // Create the test job - JobGraph jobGraph = createVariableOnlyJobGraph(4, 1000, false, 0, false, 1, result); + JobGraph jobGraph = createVariableOnlyJobGraph(4, 1000, false, 0, false, 1, false, result); miniCluster.executeJobBlocking(jobGraph); assertEquals(8001, result.get().size()); @@ -107,6 +107,22 @@ public class UnboundedStreamIterationITCase extends TestLogger { } @Test(timeout = 60000) + public void testVariableOnlyBoundedIterationWithBroadcast() throws Exception { + // Create the test job + JobGraph jobGraph = createVariableOnlyJobGraph(4, 1000, false, 0, false, 1, true, result); + miniCluster.executeJobBlocking(jobGraph); + + assertEquals(8001, result.get().size()); + + // Expected records is round * parallelism * numRecordsPerSource * parallelism of reduce + // operators + Map<Integer, Tuple2<Integer, Integer>> roundsStat = + computeRoundStat(result.get(), 2 * 4 * 1000 * 1); + verifyResult(roundsStat, 2, 4000, 4 * (0 + 999) * 1000 / 2); + assertEquals(OutputRecord.Event.TERMINATED, result.get().take().getEvent()); + } + + @Test(timeout = 60000) public void testVariableAndConstantsUnboundedIteration() throws Exception { // Create the test job JobGraph jobGraph = createVariableAndConstantJobGraph(4, 1000, true, 0, false, 1, result); @@ -150,6 +166,7 @@ public class UnboundedStreamIterationITCase extends TestLogger { int period, boolean sync, int maxRound, + boolean doBroadcast, SharedReference<BlockingQueue<OutputRecord<Integer>>> result) { StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(1); @@ -161,12 +178,14 @@ public class UnboundedStreamIterationITCase extends TestLogger { DataStreamList.of(source), DataStreamList.of(), (variableStreams, dataStreams) -> { + DataStream<EpochRecord> variable = variableStreams.get(0); + if (doBroadcast) { + variable = variable.broadcast(); + } + SingleOutputStreamOperator<EpochRecord> reducer = - variableStreams - .<EpochRecord>get(0) - .process( - new ReduceAllRoundProcessFunction( - sync, maxRound)); + variable.process( + new ReduceAllRoundProcessFunction(sync, maxRound)); return new IterationBodyResult( DataStreamList.of( reducer.map(new IncrementEpochMap())