This is an automated email from the ASF dual-hosted git repository. zhangzp pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push: new 80fd4dfb [FLINK-30933] Fix missing max watermark when executing join in iteration body 80fd4dfb is described below commit 80fd4dfb843aee1d9cfd93130cfff016a9966b7b Author: Zhipeng Zhang <zhangzhipe...@gmail.com> AuthorDate: Wed Mar 15 14:36:47 2023 +0800 [FLINK-30933] Fix missing max watermark when executing join in iteration body This closes #206. --- .../flink/iteration/operator/HeadOperator.java | 2 +- .../flink/iteration/operator/OutputOperator.java | 5 + .../BoundedPerRoundStreamIterationITCase.java | 115 ++++++++++++++++++++- 3 files changed, 116 insertions(+), 6 deletions(-) diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java index bdbe657a..e5238929 100644 --- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java +++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java @@ -566,7 +566,7 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>> private MailboxExecutorWithYieldTimeout(MailboxExecutor mailboxExecutor) { this.mailboxExecutor = mailboxExecutor; - this.timer = new Timer(); + this.timer = new Timer(true); } @Override diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OutputOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OutputOperator.java index a584c5f4..d0e69712 100644 --- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OutputOperator.java +++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OutputOperator.java @@ -19,9 +19,11 @@ package org.apache.flink.iteration.operator; import org.apache.flink.iteration.IterationRecord; +import org.apache.flink.iteration.IterationRecord.Type; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.ChainingStrategy; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; /** @@ -48,6 +50,9 @@ public class OutputOperator<T> extends AbstractStreamOperator<T> if (streamRecord.getValue().getType() == IterationRecord.Type.RECORD) { reusable.replace(streamRecord.getValue().getValue(), streamRecord.getTimestamp()); output.collect(reusable); + } else if (streamRecord.getValue().getType() == Type.EPOCH_WATERMARK + && streamRecord.getValue().getEpoch() == Integer.MAX_VALUE) { + output.emitWatermark(new Watermark(Long.MAX_VALUE)); } } } diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java index 5f453b72..6b79b66c 100644 --- a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java +++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java @@ -18,18 +18,28 @@ package org.apache.flink.test.iteration; +import org.apache.flink.api.common.functions.JoinFunction; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.IterationBody; import org.apache.flink.iteration.IterationBodyResult; import org.apache.flink.iteration.IterationConfig; import org.apache.flink.iteration.Iterations; import org.apache.flink.iteration.ReplayableDataStreamList; +import org.apache.flink.ml.common.datastream.EndOfStreamWindows; +import org.apache.flink.ml.common.iteration.TerminateOnMaxIter; import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.runtime.minicluster.MiniCluster; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.sink.SinkFunction; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.test.iteration.operators.CollectSink; import org.apache.flink.test.iteration.operators.EpochRecord; import org.apache.flink.test.iteration.operators.OutputRecord; @@ -61,14 +71,16 @@ public class BoundedPerRoundStreamIterationITCase extends TestLogger { private MiniCluster miniCluster; - private SharedReference<BlockingQueue<OutputRecord<Integer>>> result; + private SharedReference<BlockingQueue<OutputRecord<Integer>>> collectedOutputRecord; + private SharedReference<BlockingQueue<Long>> collectedWatermarks; @Before public void setup() throws Exception { miniCluster = new MiniCluster(createMiniClusterConfiguration(2, 2)); miniCluster.start(); - result = sharedObjects.add(new LinkedBlockingQueue<>()); + collectedOutputRecord = sharedObjects.add(new LinkedBlockingQueue<>()); + collectedWatermarks = sharedObjects.add(new LinkedBlockingQueue<>()); } @After @@ -80,15 +92,50 @@ public class BoundedPerRoundStreamIterationITCase extends TestLogger { @Test public void testPerRoundIteration() throws Exception { - JobGraph jobGraph = createPerRoundJobGraph(4, 1000, 5, result); + JobGraph jobGraph = createPerRoundJobGraph(4, 1000, 5, collectedOutputRecord); miniCluster.executeJobBlocking(jobGraph); - assertEquals(5, result.get().size()); + assertEquals(5, collectedOutputRecord.get().size()); Map<Integer, Tuple2<Integer, Integer>> roundsStat = - computeRoundStat(result.get(), OutputRecord.Event.TERMINATED, 5); + computeRoundStat(collectedOutputRecord.get(), OutputRecord.Event.TERMINATED, 5); verifyResult(roundsStat, 5, 1, 4 * (0 + 999) * 1000 / 2); } + @Test + public void testPerRoundIterationWithJoin() throws Exception { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(4); + + DataStream<Tuple2<Long, Integer>> input1 = env.fromElements(Tuple2.of(1L, 1)); + + DataStream<Tuple2<Long, Long>> input2 = env.fromElements(Tuple2.of(1L, 2L)); + + DataStream<Tuple2<Long, Long>> iterationWithJoinResult = + Iterations.iterateBoundedStreamsUntilTermination( + DataStreamList.of(input1), + ReplayableDataStreamList.replay(input2), + IterationConfig.newBuilder() + .setOperatorLifeCycle( + IterationConfig.OperatorLifeCycle.PER_ROUND) + .build(), + new IterationBodyWithJoin()) + .get(0); + DataStream<Long> watermarks = + iterationWithJoinResult.transform( + "CollectingWatermark", Types.LONG, new CollectingWatermark()); + + watermarks.addSink(new LongSink(collectedWatermarks)); + + JobGraph graph = env.getStreamGraph().getJobGraph(); + miniCluster.executeJobBlocking(graph); + + assertEquals(env.getParallelism(), collectedWatermarks.get().size()); + collectedWatermarks + .get() + .iterator() + .forEachRemaining(x -> assertEquals(Long.MAX_VALUE, (long) x)); + } + private static JobGraph createPerRoundJobGraph( int numSources, int numRecordsPerSource, @@ -148,4 +195,62 @@ public class BoundedPerRoundStreamIterationITCase extends TestLogger { return env.getStreamGraph().getJobGraph(); } + + private static class IterationBodyWithJoin implements IterationBody { + @Override + public IterationBodyResult process( + DataStreamList variableStreams, DataStreamList dataStreams) { + DataStream<Tuple2<Long, Integer>> input1 = variableStreams.get(0); + DataStream<Tuple2<Long, Long>> input2 = dataStreams.get(0); + + DataStream<Long> terminationCriteria = input1.flatMap(new TerminateOnMaxIter(1)); + + DataStream<Tuple2<Long, Long>> res = + input1.join(input2) + .where(x -> x.f0) + .equalTo(x -> x.f0) + .window(EndOfStreamWindows.get()) + .apply( + new JoinFunction< + Tuple2<Long, Integer>, + Tuple2<Long, Long>, + Tuple2<Long, Long>>() { + @Override + public Tuple2<Long, Long> join( + Tuple2<Long, Integer> longIntegerTuple2, + Tuple2<Long, Long> longLongTuple2) { + return longLongTuple2; + } + }); + + return new IterationBodyResult( + DataStreamList.of(input1), DataStreamList.of(res), terminationCriteria); + } + } + + private static class LongSink implements SinkFunction<Long> { + private final SharedReference<BlockingQueue<Long>> collectedLong; + + public LongSink(SharedReference<BlockingQueue<Long>> collectedLong) { + this.collectedLong = collectedLong; + } + + @Override + public void invoke(Long value, Context context) { + collectedLong.get().add(value); + } + } + + private static class CollectingWatermark extends AbstractStreamOperator<Long> + implements OneInputStreamOperator<Tuple2<Long, Long>, Long> { + + @Override + public void processElement(StreamRecord<Tuple2<Long, Long>> streamRecord) {} + + @Override + public void processWatermark(Watermark mark) throws Exception { + super.processWatermark(mark); + output.collect(new StreamRecord<>(mark.getTimestamp())); + } + } }