This is an automated email from the ASF dual-hosted git repository.

zhangzp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new 80fd4dfb [FLINK-30933] Fix missing max watermark when executing join in iteration body
80fd4dfb is described below

commit 80fd4dfb843aee1d9cfd93130cfff016a9966b7b
Author: Zhipeng Zhang <zhangzhipe...@gmail.com>
AuthorDate: Wed Mar 15 14:36:47 2023 +0800

    [FLINK-30933] Fix missing max watermark when executing join in iteration body
    
    This closes #206.
---
 .../flink/iteration/operator/HeadOperator.java     |   2 +-
 .../flink/iteration/operator/OutputOperator.java   |   5 +
 .../BoundedPerRoundStreamIterationITCase.java      | 115 ++++++++++++++++++++-
 3 files changed, 116 insertions(+), 6 deletions(-)

diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java
index bdbe657a..e5238929 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java
@@ -566,7 +566,7 @@ public class HeadOperator extends 
AbstractStreamOperator<IterationRecord<?>>
 
         private MailboxExecutorWithYieldTimeout(MailboxExecutor 
mailboxExecutor) {
             this.mailboxExecutor = mailboxExecutor;
-            this.timer = new Timer();
+            this.timer = new Timer(true);
         }
 
         @Override
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OutputOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OutputOperator.java
index a584c5f4..d0e69712 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OutputOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OutputOperator.java
@@ -19,9 +19,11 @@
 package org.apache.flink.iteration.operator;
 
 import org.apache.flink.iteration.IterationRecord;
+import org.apache.flink.iteration.IterationRecord.Type;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.ChainingStrategy;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 
 /**
@@ -48,6 +50,9 @@ public class OutputOperator<T> extends 
AbstractStreamOperator<T>
         if (streamRecord.getValue().getType() == IterationRecord.Type.RECORD) {
             reusable.replace(streamRecord.getValue().getValue(), 
streamRecord.getTimestamp());
             output.collect(reusable);
+        } else if (streamRecord.getValue().getType() == Type.EPOCH_WATERMARK
+                && streamRecord.getValue().getEpoch() == Integer.MAX_VALUE) {
+            output.emitWatermark(new Watermark(Long.MAX_VALUE));
         }
     }
 }
diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java
index 5f453b72..6b79b66c 100644
--- a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java
+++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java
@@ -18,18 +18,28 @@
 
 package org.apache.flink.test.iteration;
 
+import org.apache.flink.api.common.functions.JoinFunction;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.Types;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
 import org.apache.flink.iteration.IterationBodyResult;
 import org.apache.flink.iteration.IterationConfig;
 import org.apache.flink.iteration.Iterations;
 import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIter;
 import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.minicluster.MiniCluster;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.sink.SinkFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.test.iteration.operators.CollectSink;
 import org.apache.flink.test.iteration.operators.EpochRecord;
 import org.apache.flink.test.iteration.operators.OutputRecord;
@@ -61,14 +71,16 @@ public class BoundedPerRoundStreamIterationITCase extends 
TestLogger {
 
     private MiniCluster miniCluster;
 
-    private SharedReference<BlockingQueue<OutputRecord<Integer>>> result;
+    private SharedReference<BlockingQueue<OutputRecord<Integer>>> 
collectedOutputRecord;
+    private SharedReference<BlockingQueue<Long>> collectedWatermarks;
 
     @Before
     public void setup() throws Exception {
         miniCluster = new MiniCluster(createMiniClusterConfiguration(2, 2));
         miniCluster.start();
 
-        result = sharedObjects.add(new LinkedBlockingQueue<>());
+        collectedOutputRecord = sharedObjects.add(new LinkedBlockingQueue<>());
+        collectedWatermarks = sharedObjects.add(new LinkedBlockingQueue<>());
     }
 
     @After
@@ -80,15 +92,50 @@ public class BoundedPerRoundStreamIterationITCase extends 
TestLogger {
 
     @Test
     public void testPerRoundIteration() throws Exception {
-        JobGraph jobGraph = createPerRoundJobGraph(4, 1000, 5, result);
+        JobGraph jobGraph = createPerRoundJobGraph(4, 1000, 5, 
collectedOutputRecord);
         miniCluster.executeJobBlocking(jobGraph);
 
-        assertEquals(5, result.get().size());
+        assertEquals(5, collectedOutputRecord.get().size());
         Map<Integer, Tuple2<Integer, Integer>> roundsStat =
-                computeRoundStat(result.get(), OutputRecord.Event.TERMINATED, 
5);
+                computeRoundStat(collectedOutputRecord.get(), 
OutputRecord.Event.TERMINATED, 5);
         verifyResult(roundsStat, 5, 1, 4 * (0 + 999) * 1000 / 2);
     }
 
+    @Test
+    public void testPerRoundIterationWithJoin() throws Exception {
+        StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
+        env.setParallelism(4);
+
+        DataStream<Tuple2<Long, Integer>> input1 = 
env.fromElements(Tuple2.of(1L, 1));
+
+        DataStream<Tuple2<Long, Long>> input2 = env.fromElements(Tuple2.of(1L, 
2L));
+
+        DataStream<Tuple2<Long, Long>> iterationWithJoinResult =
+                Iterations.iterateBoundedStreamsUntilTermination(
+                                DataStreamList.of(input1),
+                                ReplayableDataStreamList.replay(input2),
+                                IterationConfig.newBuilder()
+                                        .setOperatorLifeCycle(
+                                                
IterationConfig.OperatorLifeCycle.PER_ROUND)
+                                        .build(),
+                                new IterationBodyWithJoin())
+                        .get(0);
+        DataStream<Long> watermarks =
+                iterationWithJoinResult.transform(
+                        "CollectingWatermark", Types.LONG, new 
CollectingWatermark());
+
+        watermarks.addSink(new LongSink(collectedWatermarks));
+
+        JobGraph graph = env.getStreamGraph().getJobGraph();
+        miniCluster.executeJobBlocking(graph);
+
+        assertEquals(env.getParallelism(), collectedWatermarks.get().size());
+        collectedWatermarks
+                .get()
+                .iterator()
+                .forEachRemaining(x -> assertEquals(Long.MAX_VALUE, (long) x));
+    }
+
     private static JobGraph createPerRoundJobGraph(
             int numSources,
             int numRecordsPerSource,
@@ -148,4 +195,62 @@ public class BoundedPerRoundStreamIterationITCase extends 
TestLogger {
 
         return env.getStreamGraph().getJobGraph();
     }
+
+    private static class IterationBodyWithJoin implements IterationBody {
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<Long, Integer>> input1 = variableStreams.get(0);
+            DataStream<Tuple2<Long, Long>> input2 = dataStreams.get(0);
+
+            DataStream<Long> terminationCriteria = input1.flatMap(new 
TerminateOnMaxIter(1));
+
+            DataStream<Tuple2<Long, Long>> res =
+                    input1.join(input2)
+                            .where(x -> x.f0)
+                            .equalTo(x -> x.f0)
+                            .window(EndOfStreamWindows.get())
+                            .apply(
+                                    new JoinFunction<
+                                            Tuple2<Long, Integer>,
+                                            Tuple2<Long, Long>,
+                                            Tuple2<Long, Long>>() {
+                                        @Override
+                                        public Tuple2<Long, Long> join(
+                                                Tuple2<Long, Integer> 
longIntegerTuple2,
+                                                Tuple2<Long, Long> 
longLongTuple2) {
+                                            return longLongTuple2;
+                                        }
+                                    });
+
+            return new IterationBodyResult(
+                    DataStreamList.of(input1), DataStreamList.of(res), 
terminationCriteria);
+        }
+    }
+
+    private static class LongSink implements SinkFunction<Long> {
+        private final SharedReference<BlockingQueue<Long>> collectedLong;
+
+        public LongSink(SharedReference<BlockingQueue<Long>> collectedLong) {
+            this.collectedLong = collectedLong;
+        }
+
+        @Override
+        public void invoke(Long value, Context context) {
+            collectedLong.get().add(value);
+        }
+    }
+
+    private static class CollectingWatermark extends 
AbstractStreamOperator<Long>
+            implements OneInputStreamOperator<Tuple2<Long, Long>, Long> {
+
+        @Override
+        public void processElement(StreamRecord<Tuple2<Long, Long>> 
streamRecord) {}
+
+        @Override
+        public void processWatermark(Watermark mark) throws Exception {
+            super.processWatermark(mark);
+            output.collect(new StreamRecord<>(mark.getTimestamp()));
+        }
+    }
 }

Reply via email to