This is an automated email from the ASF dual-hosted git repository. lindong pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push: new a04a5aef [FLINK-31373] Fix NPE thrown in ProxyOutput a04a5aef is described below commit a04a5aef0ae7fc1f0bcb1df0766842eaf03f0d68 Author: JiangXin <jiangxin.ji...@alibaba-inc.com> AuthorDate: Thu Apr 20 15:04:52 2023 +0800 [FLINK-31373] Fix NPE thrown in ProxyOutput This closes #233. --- .../allround/AbstractAllRoundWrapperOperator.java | 4 ++ .../MultipleInputAllRoundWrapperOperator.java | 2 + .../allround/OneInputAllRoundWrapperOperator.java | 2 + .../allround/TwoInputAllRoundWrapperOperator.java | 2 + .../BoundedAllRoundStreamIterationITCase.java | 64 ++++++++++++++++++++++ 5 files changed, 74 insertions(+) diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java index d3b46ea9..bd503e9b 100644 --- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java +++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java @@ -181,12 +181,16 @@ public abstract class AbstractAllRoundWrapperOperator<T, S extends StreamOperato @Override public void finish() throws Exception { + setIterationContextRound(Integer.MAX_VALUE); wrappedOperator.finish(); + clearIterationContextRound(); } @Override public void close() throws Exception { + setIterationContextRound(Integer.MAX_VALUE); wrappedOperator.close(); + clearIterationContextRound(); } @Override diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperator.java index 4d942989..949419f3 100644 --- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperator.java +++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperator.java @@ -83,7 +83,9 @@ public class MultipleInputAllRoundWrapperOperator<OUT> super.endInput(i); if (wrappedOperator instanceof BoundedMultiInput) { + setIterationContextRound(Integer.MAX_VALUE); ((BoundedMultiInput) wrappedOperator).endInput(i); + clearIterationContextRound(); } } diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperator.java index 7c725d9f..bbfd0b30 100644 --- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperator.java +++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperator.java @@ -87,7 +87,9 @@ public class OneInputAllRoundWrapperOperator<IN, OUT> @Override public void endInput() throws Exception { if (wrappedOperator instanceof BoundedOneInput) { + setIterationContextRound(Integer.MAX_VALUE); ((BoundedOneInput) wrappedOperator).endInput(); + clearIterationContextRound(); } } } diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperator.java index bedcccfc..a91f6c43 100644 --- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperator.java +++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperator.java @@ -116,7 +116,9 @@ public class TwoInputAllRoundWrapperOperator<IN1, IN2, OUT> super.endInput(i); if (wrappedOperator instanceof BoundedMultiInput) { + setIterationContextRound(Integer.MAX_VALUE); ((BoundedMultiInput) wrappedOperator).endInput(i); + clearIterationContextRound(); } } } diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedAllRoundStreamIterationITCase.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedAllRoundStreamIterationITCase.java index 5768e880..6deda501 100644 --- a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedAllRoundStreamIterationITCase.java +++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedAllRoundStreamIterationITCase.java @@ -18,6 +18,7 @@ package org.apache.flink.test.iteration; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.iteration.DataStreamList; import org.apache.flink.iteration.IterationBody; @@ -32,6 +33,10 @@ import org.apache.flink.runtime.minicluster.MiniCluster; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.test.iteration.operators.CollectSink; import org.apache.flink.test.iteration.operators.EpochRecord; import org.apache.flink.test.iteration.operators.IncrementEpochMap; @@ -44,6 +49,7 @@ import org.apache.flink.testutils.junit.SharedReference; import org.apache.flink.util.OutputTag; import org.apache.flink.util.TestLogger; +import org.apache.commons.collections.IteratorUtils; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -51,6 +57,7 @@ import org.junit.Test; import javax.annotation.Nullable; +import java.util.List; import java.util.Map; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; @@ -154,6 +161,37 @@ public class BoundedAllRoundStreamIterationITCase extends TestLogger { assertEquals(OutputRecord.Event.TERMINATED, result.get().take().getEvent()); } + @Test + public void testBoundedIterationWithEndInput() throws Exception { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(1); + env.getConfig().enableObjectReuse(); + + DataStream<Integer> inputStream = env.fromElements(1, 2, 3); + + DataStreamList outputs = + Iterations.iterateBoundedStreamsUntilTermination( + DataStreamList.of(inputStream), + ReplayableDataStreamList.replay(inputStream), + IterationConfig.newBuilder().build(), + (variableStreams, dataStreams) -> { + DataStream<Integer> variables = variableStreams.get(0); + DataStream<Integer> result = + dataStreams + .<Integer>get(0) + .transform( + "sum", + BasicTypeInfo.INT_TYPE_INFO, + new SumOperator()); + return new IterationBodyResult( + DataStreamList.of(variables), + DataStreamList.of(result), + variables.flatMap(new TerminateOnMaxIter<>(10))); + }); + List<Integer> result = IteratorUtils.toList(outputs.get(0).executeAndCollect()); + result.forEach(r -> r.equals(60)); + } + private static JobGraph createVariableOnlyJobGraph( int numSources, int numRecordsPerSource, @@ -277,4 +315,30 @@ public class BoundedAllRoundStreamIterationITCase extends TestLogger { return env.getStreamGraph().getJobGraph(); } + + private static class SumOperator extends AbstractStreamOperator<Integer> + implements OneInputStreamOperator<Integer, Integer>, BoundedOneInput { + + private int sum = 0; + + @Override + public void processElement(StreamRecord<Integer> element) { + sum += element.getValue(); + } + + @Override + public void endInput() { + output.collect(new StreamRecord<>(sum)); + } + + @Override + public void finish() { + output.collect(new StreamRecord<>(sum)); + } + + @Override + public void close() { + output.collect(new StreamRecord<>(sum)); + } + } }