This is an automated email from the ASF dual-hosted git repository. gaoyunhaii pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push: new a368ebb [FLINK-24722][iteration] Fix the issues in supporting keyed stream inside the iteration body a368ebb is described below commit a368ebb17affae872f0ea9eb7bb9576fb56612ee Author: Yun Gao <gaoyunhen...@gmail.com> AuthorDate: Tue Nov 2 17:06:12 2021 +0800 [FLINK-24722][iteration] Fix the issues in supporting keyed stream inside the iteration body This closes #22. --- .../flink/iteration/operator/OperatorUtils.java | 23 +++++++++ .../allround/AbstractAllRoundWrapperOperator.java | 4 +- .../allround/OneInputAllRoundWrapperOperator.java | 6 ++- .../perround/AbstractPerRoundWrapperOperator.java | 10 ++-- .../flink/iteration/proxy/ProxyKeySelector.java | 4 ++ .../iteration/proxy/ProxyStreamPartitioner.java | 11 +++++ .../BoundedAllRoundStreamIterationITCase.java | 31 ++++++++++-- .../BoundedPerRoundStreamIterationITCase.java | 14 +++++- .../iteration/UnboundedStreamIterationITCase.java | 28 +++++++++-- .../operators/StatefulProcessFunction.java | 55 ++++++++++++++++++++++ 10 files changed, 171 insertions(+), 15 deletions(-) diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java index 292a90e..25d200b 100644 --- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java +++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java @@ -18,15 +18,18 @@ package org.apache.flink.iteration.operator; +import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.fs.Path; import org.apache.flink.iteration.IterationID; import org.apache.flink.iteration.config.IterationOptions; +import org.apache.flink.iteration.proxy.ProxyKeySelector; import org.apache.flink.iteration.utils.ReflectionUtils; import 
org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.statefun.flink.core.feedback.FeedbackChannel; import org.apache.flink.statefun.flink.core.feedback.FeedbackConsumer; import org.apache.flink.statefun.flink.core.feedback.FeedbackKey; +import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; import org.apache.flink.streaming.api.operators.StreamOperator; import org.apache.flink.util.ExceptionUtils; @@ -39,6 +42,8 @@ import java.util.Random; import java.util.UUID; import java.util.concurrent.Executor; +import static org.apache.flink.util.Preconditions.checkState; + /** Utility class for operators. */ public class OperatorUtils { @@ -83,6 +88,24 @@ public class OperatorUtils { } } + public static StreamConfig createWrappedOperatorConfig(StreamConfig wrapperConfig) { + StreamConfig wrappedConfig = new StreamConfig(wrapperConfig.getConfiguration().clone()); + for (int i = 0; i < wrappedConfig.getNumberOfNetworkInputs(); ++i) { + KeySelector keySelector = + wrapperConfig.getStatePartitioner(i, OperatorUtils.class.getClassLoader()); + if (keySelector != null) { + checkState( + keySelector instanceof ProxyKeySelector, + "The state partitioner for the wrapper operator should always be ProxyKeySelector, but it is " + + keySelector); + wrappedConfig.setStatePartitioner( + i, ((ProxyKeySelector) keySelector).getWrappedKeySelector()); + } + } + + return wrappedConfig; + } + public static Path getDataCachePath(Configuration configuration, String[] localSpillPaths) { String pathStr = configuration.get(IterationOptions.DATA_CACHE_PATH); if (pathStr == null) { diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java index 0ea742c..180477c 100644 --- 
a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java +++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java @@ -29,6 +29,7 @@ import org.apache.flink.iteration.IterationListener; import org.apache.flink.iteration.IterationRecord; import org.apache.flink.iteration.operator.AbstractWrapperOperator; import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.iteration.operator.OperatorUtils; import org.apache.flink.metrics.MetricGroup; import org.apache.flink.metrics.groups.OperatorMetricGroup; import org.apache.flink.runtime.checkpoint.CheckpointOptions; @@ -84,7 +85,8 @@ public abstract class AbstractAllRoundWrapperOperator<T, S extends StreamOperato StreamOperatorFactoryUtil.<T, S>createOperator( operatorFactory, (StreamTask) parameters.getContainingTask(), - parameters.getStreamConfig(), + OperatorUtils.createWrappedOperatorConfig( + parameters.getStreamConfig()), proxyOutput, parameters.getOperatorEventDispatcher()) .f0; diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperator.java index 6a2b9a0..7c725d9 100644 --- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperator.java +++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperator.java @@ -78,8 +78,10 @@ public class OneInputAllRoundWrapperOperator<IN, OUT> @Override public void setKeyContextElement(StreamRecord<IterationRecord<IN>> record) throws Exception { - reusedInput.replace(record.getValue().getValue(), record.getTimestamp()); - wrappedOperator.setKeyContextElement(reusedInput); + if (record.getValue().getType() == IterationRecord.Type.RECORD) { + 
reusedInput.replace(record.getValue().getValue(), record.getTimestamp()); + wrappedOperator.setKeyContextElement(reusedInput); + } } @Override diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java index 3903340..cc4ac36 100644 --- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java +++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java @@ -27,6 +27,7 @@ import org.apache.flink.contrib.streaming.state.RocksDBKeyedStateBackend; import org.apache.flink.core.memory.ManagedMemoryUseCase; import org.apache.flink.iteration.IterationRecord; import org.apache.flink.iteration.operator.AbstractWrapperOperator; +import org.apache.flink.iteration.operator.OperatorUtils; import org.apache.flink.iteration.proxy.state.ProxyStateSnapshotContext; import org.apache.flink.iteration.proxy.state.ProxyStreamOperatorStateContext; import org.apache.flink.iteration.utils.ReflectionUtils; @@ -124,7 +125,8 @@ public abstract class AbstractPerRoundWrapperOperator<T, S extends StreamOperato StreamOperatorFactoryUtil.<T, S>createOperator( clonedOperatorFactory, (StreamTask) parameters.getContainingTask(), - parameters.getStreamConfig(), + OperatorUtils.createWrappedOperatorConfig( + parameters.getStreamConfig()), proxyOutput, parameters.getOperatorEventDispatcher()) .f0; @@ -294,7 +296,9 @@ public abstract class AbstractPerRoundWrapperOperator<T, S extends StreamOperato private <T> void setKeyContextElement(StreamRecord<T> record, KeySelector<T, ?> selector) throws Exception { - if (selector != null) { + if (selector != null + && ((IterationRecord<?>) record.getValue()).getType() + == IterationRecord.Type.RECORD) { Object key = selector.getKey(record.getValue()); 
setCurrentKey(key); } @@ -335,7 +339,7 @@ public abstract class AbstractPerRoundWrapperOperator<T, S extends StreamOperato return null; } - return stateHandler.getKeyedStateStore().orElse(null); + return stateHandler.getCurrentKey(); } protected void reportOrForwardLatencyMarker(LatencyMarker marker) { diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/ProxyKeySelector.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/ProxyKeySelector.java index 1ac64ec..f3615f7 100644 --- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/ProxyKeySelector.java +++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/ProxyKeySelector.java @@ -30,6 +30,10 @@ public class ProxyKeySelector<T, KEY> implements KeySelector<IterationRecord<T>, this.wrappedKeySelector = wrappedKeySelector; } + public KeySelector<T, KEY> getWrappedKeySelector() { + return wrappedKeySelector; + } + @Override public KEY getKey(IterationRecord<T> record) throws Exception { return wrappedKeySelector.getKey(record.getValue()); diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/ProxyStreamPartitioner.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/ProxyStreamPartitioner.java index 4accb32..525f12a 100644 --- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/ProxyStreamPartitioner.java +++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/ProxyStreamPartitioner.java @@ -44,6 +44,12 @@ public class ProxyStreamPartitioner<T> extends StreamPartitioner<IterationRecord } @Override + public void setup(int numberOfChannels) { + super.setup(numberOfChannels); + wrappedStreamPartitioner.setup(numberOfChannels); + } + + @Override public StreamPartitioner<IterationRecord<T>> copy() { return new ProxyStreamPartitioner<>(wrappedStreamPartitioner.copy()); } @@ -87,4 +93,9 @@ public class ProxyStreamPartitioner<T> extends StreamPartitioner<IterationRecord 
return selectChannel(record); } } + + @Override + public String toString() { + return wrappedStreamPartitioner.toString(); + } } diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedAllRoundStreamIterationITCase.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedAllRoundStreamIterationITCase.java index 5084c78..1b28374 100644 --- a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedAllRoundStreamIterationITCase.java +++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedAllRoundStreamIterationITCase.java @@ -37,6 +37,7 @@ import org.apache.flink.test.iteration.operators.IncrementEpochMap; import org.apache.flink.test.iteration.operators.OutputRecord; import org.apache.flink.test.iteration.operators.RoundBasedTerminationCriteria; import org.apache.flink.test.iteration.operators.SequenceSource; +import org.apache.flink.test.iteration.operators.StatefulProcessFunction; import org.apache.flink.test.iteration.operators.TwoInputReduceAllRoundProcessFunction; import org.apache.flink.testutils.junit.SharedObjects; import org.apache.flink.testutils.junit.SharedReference; @@ -130,6 +131,8 @@ public class BoundedAllRoundStreamIterationITCase extends TestLogger { // If termination criteria is created only with the constants streams, it would not have // records after the round 1 if the input is not replayed. int numOfRound = terminationCriteriaFollowsConstantsStreams ? 
1 : 5; + assertEquals(numOfRound + 1, result.get().size()); + Map<Integer, Tuple2<Integer, Integer>> roundsStat = computeRoundStat( result.get(), OutputRecord.Event.EPOCH_WATERMARK_INCREMENTED, numOfRound); @@ -184,9 +187,19 @@ public class BoundedAllRoundStreamIterationITCase extends TestLogger { .process( new TwoInputReduceAllRoundProcessFunction( sync, maxRound)); + return new IterationBodyResult( DataStreamList.of( - reducer.map(new IncrementEpochMap()) + reducer.partitionCustom( + (k, numPartitions) -> k % numPartitions, + EpochRecord::getValue) + .map(x -> x) + .keyBy(EpochRecord::getValue) + .process( + new StatefulProcessFunction< + EpochRecord>() {}) + .setParallelism(4) + .map(new IncrementEpochMap()) .setParallelism(numSources)), DataStreamList.of( reducer.getSideOutput( @@ -237,10 +250,20 @@ public class BoundedAllRoundStreamIterationITCase extends TestLogger { .process( new TwoInputReduceAllRoundProcessFunction( sync, maxRound)); + + SingleOutputStreamOperator<EpochRecord> feedbackStream = + reducer.partitionCustom( + (k, numPartitions) -> k % numPartitions, + EpochRecord::getValue) + .map(x -> x) + .keyBy(EpochRecord::getValue) + .process(new StatefulProcessFunction<EpochRecord>() {}) + .setParallelism(4) + .map(new IncrementEpochMap()) + .setParallelism(numSources); + return new IterationBodyResult( - DataStreamList.of( - reducer.map(new IncrementEpochMap()) - .setParallelism(numSources)), + DataStreamList.of(feedbackStream), DataStreamList.of( reducer.getSideOutput( new OutputTag<OutputRecord<Integer>>( diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java index cb36d0c..8bc6f1f 100644 --- a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java +++ 
b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java @@ -34,6 +34,7 @@ import org.apache.flink.test.iteration.operators.CollectSink; import org.apache.flink.test.iteration.operators.EpochRecord; import org.apache.flink.test.iteration.operators.OutputRecord; import org.apache.flink.test.iteration.operators.SequenceSource; +import org.apache.flink.test.iteration.operators.StatefulProcessFunction; import org.apache.flink.test.iteration.operators.TwoInputReducePerRoundOperator; import org.apache.flink.testutils.junit.SharedObjects; import org.apache.flink.testutils.junit.SharedReference; @@ -123,7 +124,18 @@ public class BoundedPerRoundStreamIterationITCase extends TestLogger { .setParallelism(1); return new IterationBodyResult( - DataStreamList.of(reducer.filter(x -> x < maxRound)), + DataStreamList.of( + reducer.partitionCustom( + (k, numPartitions) -> k % numPartitions, + x -> x) + .map(x -> x) + .keyBy(x -> x) + .process( + new StatefulProcessFunction< + Integer>() {}) + .setParallelism(4) + .filter(x -> x < maxRound) + .setParallelism(1)), DataStreamList.of( reducer.getSideOutput( TwoInputReducePerRoundOperator.OUTPUT_TAG)), diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/UnboundedStreamIterationITCase.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/UnboundedStreamIterationITCase.java index 6d80f23..f3f2272 100644 --- a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/UnboundedStreamIterationITCase.java +++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/UnboundedStreamIterationITCase.java @@ -38,6 +38,7 @@ import org.apache.flink.test.iteration.operators.IncrementEpochMap; import org.apache.flink.test.iteration.operators.OutputRecord; import org.apache.flink.test.iteration.operators.ReduceAllRoundProcessFunction; import org.apache.flink.test.iteration.operators.SequenceSource; +import 
org.apache.flink.test.iteration.operators.StatefulProcessFunction; import org.apache.flink.test.iteration.operators.TwoInputReduceAllRoundProcessFunction; import org.apache.flink.testutils.junit.SharedObjects; import org.apache.flink.testutils.junit.SharedReference; @@ -192,7 +193,16 @@ public class UnboundedStreamIterationITCase extends TestLogger { new ReduceAllRoundProcessFunction(sync, maxRound)); return new IterationBodyResult( DataStreamList.of( - reducer.map(new IncrementEpochMap()) + reducer.partitionCustom( + (k, numPartitions) -> k % numPartitions, + EpochRecord::getValue) + .map(x -> x) + .keyBy(EpochRecord::getValue) + .process( + new StatefulProcessFunction< + EpochRecord>() {}) + .setParallelism(4) + .map(new IncrementEpochMap()) .setParallelism(numSources)), DataStreamList.of( reducer.getSideOutput( @@ -234,10 +244,20 @@ public class UnboundedStreamIterationITCase extends TestLogger { .process( new TwoInputReduceAllRoundProcessFunction( sync, maxRound)); + + SingleOutputStreamOperator<EpochRecord> feedbackStream = + reducer.partitionCustom( + (k, numPartitions) -> k % numPartitions, + EpochRecord::getValue) + .map(x -> x) + .keyBy(EpochRecord::getValue) + .process(new StatefulProcessFunction<EpochRecord>() {}) + .setParallelism(4) + .map(new IncrementEpochMap()) + .setParallelism(numSources); + return new IterationBodyResult( - DataStreamList.of( - reducer.map(new IncrementEpochMap()) - .setParallelism(numSources)), + DataStreamList.of(feedbackStream), DataStreamList.of( reducer.getSideOutput( new OutputTag<OutputRecord<Integer>>( diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/StatefulProcessFunction.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/StatefulProcessFunction.java new file mode 100644 index 0000000..47415f5 --- /dev/null +++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/StatefulProcessFunction.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.test.iteration.operators; + +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.streaming.api.functions.KeyedProcessFunction; +import org.apache.flink.util.Collector; + +/** + * This is a function that uses keyed state so that we could verify the correctness of using keyed + * stream inside the iteration. + */ +public class StatefulProcessFunction<T> extends KeyedProcessFunction<Integer, T, T> { + + private ValueState<Integer> state; + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + this.state = + getRuntimeContext().getState(new ValueStateDescriptor<>("state", Integer.class)); + } + + @Override + public void processElement(T value, Context ctx, Collector<T> out) throws Exception { + if (state.value() == null) { + state.update(0); + + // Try to register a timer + ctx.timerService().registerEventTimeTimer(1000L); + } else { + state.update(state.value() + 1); + } + + out.collect(value); + } +}