This is an automated email from the ASF dual-hosted git repository.

lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new a04a5aef [FLINK-31373] Fix NPE thrown in ProxyOutput
a04a5aef is described below

commit a04a5aef0ae7fc1f0bcb1df0766842eaf03f0d68
Author: JiangXin <jiangxin.ji...@alibaba-inc.com>
AuthorDate: Thu Apr 20 15:04:52 2023 +0800

    [FLINK-31373] Fix NPE thrown in ProxyOutput
    
    This closes #233.
---
 .../allround/AbstractAllRoundWrapperOperator.java  |  4 ++
 .../MultipleInputAllRoundWrapperOperator.java      |  2 +
 .../allround/OneInputAllRoundWrapperOperator.java  |  2 +
 .../allround/TwoInputAllRoundWrapperOperator.java  |  2 +
 .../BoundedAllRoundStreamIterationITCase.java      | 64 ++++++++++++++++++++++
 5 files changed, 74 insertions(+)

diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java
index d3b46ea9..bd503e9b 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java
@@ -181,12 +181,16 @@ public abstract class AbstractAllRoundWrapperOperator<T, S extends StreamOperato
 
     @Override
     public void finish() throws Exception {
+        setIterationContextRound(Integer.MAX_VALUE); // FLINK-31373: set a round before delegating
         wrappedOperator.finish();
+        clearIterationContextRound(); // not reached if the wrapped finish() throws
     }
 
     @Override
     public void close() throws Exception {
+        setIterationContextRound(Integer.MAX_VALUE); // FLINK-31373: set a round before delegating
         wrappedOperator.close();
+        clearIterationContextRound(); // not reached if the wrapped close() throws
     }
 
     @Override
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperator.java
index 4d942989..949419f3 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperator.java
@@ -83,7 +83,9 @@ public class MultipleInputAllRoundWrapperOperator<OUT>
         super.endInput(i);
 
         if (wrappedOperator instanceof BoundedMultiInput) {
+            setIterationContextRound(Integer.MAX_VALUE);
             ((BoundedMultiInput) wrappedOperator).endInput(i);
+            clearIterationContextRound();
         }
     }
 
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperator.java
index 7c725d9f..bbfd0b30 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperator.java
@@ -87,7 +87,9 @@ public class OneInputAllRoundWrapperOperator<IN, OUT>
     @Override
     public void endInput() throws Exception {
         if (wrappedOperator instanceof BoundedOneInput) {
+            setIterationContextRound(Integer.MAX_VALUE); // FLINK-31373: set a round before delegating
             ((BoundedOneInput) wrappedOperator).endInput();
+            clearIterationContextRound(); // not reached if the wrapped endInput() throws
         }
     }
 }
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperator.java
index bedcccfc..a91f6c43 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperator.java
@@ -116,7 +116,9 @@ public class TwoInputAllRoundWrapperOperator<IN1, IN2, OUT>
         super.endInput(i);
 
         if (wrappedOperator instanceof BoundedMultiInput) {
+            setIterationContextRound(Integer.MAX_VALUE);
             ((BoundedMultiInput) wrappedOperator).endInput(i);
+            clearIterationContextRound();
         }
     }
 }
diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedAllRoundStreamIterationITCase.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedAllRoundStreamIterationITCase.java
index 5768e880..6deda501 100644
--- a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedAllRoundStreamIterationITCase.java
+++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedAllRoundStreamIterationITCase.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.test.iteration;
 
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.iteration.DataStreamList;
 import org.apache.flink.iteration.IterationBody;
@@ -32,6 +33,10 @@ import org.apache.flink.runtime.minicluster.MiniCluster;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.test.iteration.operators.CollectSink;
 import org.apache.flink.test.iteration.operators.EpochRecord;
 import org.apache.flink.test.iteration.operators.IncrementEpochMap;
@@ -44,6 +49,7 @@ import org.apache.flink.testutils.junit.SharedReference;
 import org.apache.flink.util.OutputTag;
 import org.apache.flink.util.TestLogger;
 
+import org.apache.commons.collections.IteratorUtils;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Rule;
@@ -51,6 +57,7 @@ import org.junit.Test;
 
 import javax.annotation.Nullable;
 
+import java.util.List;
 import java.util.Map;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.LinkedBlockingQueue;
@@ -154,6 +161,37 @@ public class BoundedAllRoundStreamIterationITCase extends TestLogger {
         assertEquals(OutputRecord.Event.TERMINATED, 
result.get().take().getEvent());
     }
 
+    @Test
+    public void testBoundedIterationWithEndInput() throws Exception {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        env.setParallelism(1);
+        env.getConfig().enableObjectReuse();
+
+        DataStream<Integer> inputStream = env.fromElements(1, 2, 3);
+
+        DataStreamList outputs =
+                Iterations.iterateBoundedStreamsUntilTermination(
+                        DataStreamList.of(inputStream),
+                        ReplayableDataStreamList.replay(inputStream),
+                        IterationConfig.newBuilder().build(),
+                        (variableStreams, dataStreams) -> {
+                            DataStream<Integer> variables = variableStreams.get(0);
+                            DataStream<Integer> result =
+                                    dataStreams
+                                            .<Integer>get(0)
+                                            .transform(
+                                                    "sum",
+                                                    BasicTypeInfo.INT_TYPE_INFO,
+                                                    new SumOperator());
+                            return new IterationBodyResult(
+                                    DataStreamList.of(variables),
+                                    DataStreamList.of(result),
+                                    variables.flatMap(new TerminateOnMaxIter<>(10)));
+                        });
+        List<Integer> result = IteratorUtils.toList(outputs.get(0).executeAndCollect());
+        result.forEach(r -> assertEquals(60, (int) r)); // check each total SumOperator emits
+    }
+
     private static JobGraph createVariableOnlyJobGraph(
             int numSources,
             int numRecordsPerSource,
@@ -277,4 +315,30 @@ public class BoundedAllRoundStreamIterationITCase extends TestLogger {
 
         return env.getStreamGraph().getJobGraph();
     }
+
+    private static class SumOperator extends AbstractStreamOperator<Integer>
+            implements OneInputStreamOperator<Integer, Integer>, BoundedOneInput {
+        private int sum = 0; // running total of every processed value
+
+        @Override
+        public void processElement(StreamRecord<Integer> element) {
+            sum += element.getValue();
+        }
+
+        @Override
+        public void endInput() {
+            output.collect(new StreamRecord<>(sum));
+        }
+
+        @Override
+        public void finish() {
+            output.collect(new StreamRecord<>(sum));
+        }
+
+        @Override
+        public void close() throws Exception {
+            output.collect(new StreamRecord<>(sum));
+            super.close(); // let AbstractStreamOperator run its own cleanup
+        }
+    }
 }

Reply via email to