This is an automated email from the ASF dual-hosted git repository. gaoyunhaii pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink-ml.git
commit 5c69eeea18f5302f780512bb1e5fb3d898172989
Author: Yun Gao <gaoyunhen...@gmail.com>
AuthorDate: Mon Nov 15 11:32:43 2021 +0800

    [hotfix][iteration] Return more fine-grained operator class for the WrapperFactory
---
 .../flink/iteration/operator/OperatorWrapper.java      | 10 ++++++++++
 .../iteration/operator/WrapperOperatorFactory.java     |  2 +-
 .../operator/allround/AllRoundOperatorWrapper.java     | 17 +++++++++++++++++
 .../operator/perround/PerRoundOperatorWrapper.java     | 17 +++++++++++++++++
 .../DraftExecutionEnvironmentSwitchWrapperTest.java    | 12 ++++++++++++
 .../ml/common/broadcast/operator/BroadcastWrapper.java | 15 +++++++++++++++
 6 files changed, 72 insertions(+), 1 deletion(-)

diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorWrapper.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorWrapper.java
index 1d7bbbe..e7fd379 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorWrapper.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorWrapper.java
@@ -35,6 +35,16 @@ public interface OperatorWrapper<T, R> extends Serializable {
             StreamOperatorParameters<R> operatorParameters,
             StreamOperatorFactory<T> operatorFactory);
 
+    /**
+     * Returns the class of the operator created by this wrapper for the given operator factory.
+     *
+     * @param classLoader the class loader used to resolve the class of the wrapped operator.
+     * @param operatorFactory the factory of the operator being wrapped.
+     * @return the concrete class of the resulting (wrapper) operator.
+     */
+    Class<? extends StreamOperator> getStreamOperatorClass(
+            ClassLoader classLoader, StreamOperatorFactory<T> operatorFactory);
+
     <KEY> KeySelector<R, KEY> wrapKeySelector(KeySelector<T, KEY> keySelector);
 
     StreamPartitioner<R> wrapStreamPartitioner(StreamPartitioner<T> streamPartitioner);
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/WrapperOperatorFactory.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/WrapperOperatorFactory.java
index dfb325a..c6a9a06 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/WrapperOperatorFactory.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/WrapperOperatorFactory.java
@@ -48,7 +48,7 @@ public class WrapperOperatorFactory<OUT>
 
     @Override
     public Class<? extends StreamOperator> getStreamOperatorClass(ClassLoader classLoader) {
-        return AbstractWrapperOperator.class;
+        return wrapper.getStreamOperatorClass(classLoader, operatorFactory);
     }
 
     @VisibleForTesting
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AllRoundOperatorWrapper.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AllRoundOperatorWrapper.java
index c28acd1..0cf54c0 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AllRoundOperatorWrapper.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AllRoundOperatorWrapper.java
@@ -57,6 +57,23 @@ public class AllRoundOperatorWrapper<T> implements OperatorWrapper<T, IterationR
     }
 
     @Override
+    public Class<? extends StreamOperator> getStreamOperatorClass(
+            ClassLoader classLoader, StreamOperatorFactory<T> operatorFactory) {
+        Class<? extends StreamOperator> operatorClass =
+                operatorFactory.getStreamOperatorClass(classLoader);
+        if (OneInputStreamOperator.class.isAssignableFrom(operatorClass)) {
+            return OneInputAllRoundWrapperOperator.class;
+        } else if (TwoInputStreamOperator.class.isAssignableFrom(operatorClass)) {
+            return TwoInputAllRoundWrapperOperator.class;
+        } else if (MultipleInputStreamOperator.class.isAssignableFrom(operatorClass)) {
+            return MultipleInputAllRoundWrapperOperator.class;
+        } else {
+            throw new UnsupportedOperationException(
+                    "Unsupported operator class for all-round wrapper: " + operatorClass);
+        }
+    }
+
+    @Override
     public <KEY> KeySelector<IterationRecord<T>, KEY> wrapKeySelector(
             KeySelector<T, KEY> keySelector) {
         return new ProxyKeySelector<>(keySelector);
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/PerRoundOperatorWrapper.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/PerRoundOperatorWrapper.java
index ffa2221..87ee6aa 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/PerRoundOperatorWrapper.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/PerRoundOperatorWrapper.java
@@ -57,6 +57,23 @@ public class PerRoundOperatorWrapper<T> implements OperatorWrapper<T, IterationR
     }
 
     @Override
+    public Class<? extends StreamOperator> getStreamOperatorClass(
+            ClassLoader classLoader, StreamOperatorFactory<T> operatorFactory) {
+        Class<? extends StreamOperator> operatorClass =
+                operatorFactory.getStreamOperatorClass(classLoader);
+        if (OneInputStreamOperator.class.isAssignableFrom(operatorClass)) {
+            return OneInputPerRoundWrapperOperator.class;
+        } else if (TwoInputStreamOperator.class.isAssignableFrom(operatorClass)) {
+            return TwoInputPerRoundWrapperOperator.class;
+        } else if (MultipleInputStreamOperator.class.isAssignableFrom(operatorClass)) {
+            return MultipleInputPerRoundWrapperOperator.class;
+        } else {
+            throw new UnsupportedOperationException(
+                    "Unsupported operator class for per-round wrapper: " + operatorClass);
+        }
+    }
+
+    @Override
     public <KEY> KeySelector<IterationRecord<T>, KEY> wrapKeySelector(
             KeySelector<T, KEY> keySelector) {
         return new ProxyKeySelector<>(keySelector);
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/compile/DraftExecutionEnvironmentSwitchWrapperTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/compile/DraftExecutionEnvironmentSwitchWrapperTest.java
index 86df0fc..b544003 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/compile/DraftExecutionEnvironmentSwitchWrapperTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/compile/DraftExecutionEnvironmentSwitchWrapperTest.java
@@ -113,6 +113,12 @@ public class DraftExecutionEnvironmentSwitchWrapperTest extends TestLogger {
         }
 
         @Override
+        public Class<? extends StreamOperator> getStreamOperatorClass(
+                ClassLoader classLoader, StreamOperatorFactory<T> operatorFactory) {
+            return StreamMap.class;
+        }
+
+        @Override
         public <KEY> KeySelector<T, KEY> wrapKeySelector(KeySelector<T, KEY> keySelector) {
             return keySelector;
         }
@@ -142,6 +148,12 @@ public class DraftExecutionEnvironmentSwitchWrapperTest extends TestLogger {
         }
 
         @Override
+        public Class<? extends StreamOperator> getStreamOperatorClass(
+                ClassLoader classLoader, StreamOperatorFactory<T> operatorFactory) {
+            return StreamFilter.class;
+        }
+
+        @Override
         public <KEY> KeySelector<T, KEY> wrapKeySelector(KeySelector<T, KEY> keySelector) {
             return keySelector;
         }
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
index 2e3f88d..2a18c85 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastWrapper.java
@@ -75,6 +75,21 @@ public class BroadcastWrapper<T> implements OperatorWrapper<T, T> {
     }
 
     @Override
+    public Class<? extends StreamOperator> getStreamOperatorClass(
+            ClassLoader classLoader, StreamOperatorFactory<T> operatorFactory) {
+        Class<? extends StreamOperator> operatorClass =
+                operatorFactory.getStreamOperatorClass(classLoader);
+        if (OneInputStreamOperator.class.isAssignableFrom(operatorClass)) {
+            return OneInputBroadcastWrapperOperator.class;
+        } else if (TwoInputStreamOperator.class.isAssignableFrom(operatorClass)) {
+            return TwoInputBroadcastWrapperOperator.class;
+        } else {
+            throw new UnsupportedOperationException(
+                    "Unsupported operator class for with-broadcast wrapper: " + operatorClass);
+        }
+    }
+
+    @Override
     public <KEY> KeySelector<T, KEY> wrapKeySelector(KeySelector<T, KEY> keySelector) {
         return keySelector;
     }