This is an automated email from the ASF dual-hosted git repository.

dianfu pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push: new 3412970 [FLINK-23929][python] Operator with multiple outputs should not be chained with one of the outputs 3412970 is described below commit 3412970839eebcbeb42d8d1851d5e2fcc95e73a5 Author: Dian Fu <dia...@apache.org> AuthorDate: Tue Aug 24 10:41:40 2021 +0800 [FLINK-23929][python] Operator with multiple outputs should not be chained with one of the outputs This closes #16948. --- .../chain/PythonOperatorChainingOptimizer.java | 10 ++++-- .../chain/PythonOperatorChainingOptimizerTest.java | 42 ++++++++++++++++++++++ 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/flink-python/src/main/java/org/apache/flink/python/chain/PythonOperatorChainingOptimizer.java b/flink-python/src/main/java/org/apache/flink/python/chain/PythonOperatorChainingOptimizer.java index 4283c7c..e87650d 100644 --- a/flink-python/src/main/java/org/apache/flink/python/chain/PythonOperatorChainingOptimizer.java +++ b/flink-python/src/main/java/org/apache/flink/python/chain/PythonOperatorChainingOptimizer.java @@ -235,7 +235,7 @@ public class PythonOperatorChainingOptimizer { } } - if (isChainable(input, transform)) { + if (isChainable(input, transform, outputMap)) { Transformation<?> chainedTransformation = createChainedTransformation(input, transform); Set<Transformation<?>> outputTransformations = outputMap.get(transform); @@ -243,6 +243,7 @@ public class PythonOperatorChainingOptimizer { for (Transformation<?> output : outputTransformations) { replaceInput(output, transform, chainedTransformation); } + outputMap.put(chainedTransformation, outputTransformations); } chainInfo = ChainInfo.of(chainedTransformation, Arrays.asList(input, transform)); } @@ -368,11 +369,14 @@ public class PythonOperatorChainingOptimizer { } private static boolean isChainable( - Transformation<?> upTransform, Transformation<?> downTransform) { + Transformation<?> upTransform, + Transformation<?> downTransform, + Map<Transformation<?>, 
Set<Transformation<?>>> outputMap) { return upTransform.getParallelism() == downTransform.getParallelism() && upTransform.getMaxParallelism() == downTransform.getMaxParallelism() && upTransform.getSlotSharingGroup().equals(downTransform.getSlotSharingGroup()) - && areOperatorsChainable(upTransform, downTransform); + && areOperatorsChainable(upTransform, downTransform) + && outputMap.get(upTransform).size() == 1; } private static boolean areOperatorsChainable( diff --git a/flink-python/src/test/java/org/apache/flink/python/chain/PythonOperatorChainingOptimizerTest.java b/flink-python/src/test/java/org/apache/flink/python/chain/PythonOperatorChainingOptimizerTest.java index 2e24bc3..0895303 100644 --- a/flink-python/src/test/java/org/apache/flink/python/chain/PythonOperatorChainingOptimizerTest.java +++ b/flink-python/src/test/java/org/apache/flink/python/chain/PythonOperatorChainingOptimizerTest.java @@ -588,6 +588,48 @@ public class PythonOperatorChainingOptimizerTest { "f1"); } + @Test + public void testTransformationWithMultipleOutputs() { + PythonProcessOperator<?, ?> processOperator1 = + createProcessOperator("f1", Types.STRING(), Types.LONG()); + PythonProcessOperator<?, ?> processOperator2 = + createProcessOperator("f2", Types.STRING(), Types.LONG()); + PythonProcessOperator<?, ?> processOperator3 = + createProcessOperator("f3", Types.LONG(), Types.INT()); + + Transformation<?> sourceTransformation = mock(SourceTransformation.class); + Transformation<?> processTransformation1 = + new OneInputTransformation( + sourceTransformation, + "process", + processOperator1, + processOperator1.getProducedType(), + 2); + Transformation<?> processTransformation2 = + new OneInputTransformation( + processTransformation1, + "process", + processOperator2, + processOperator2.getProducedType(), + 2); + Transformation<?> processTransformation3 = + new OneInputTransformation( + processTransformation1, + "process", + processOperator3, + processOperator3.getProducedType(), + 2); + + 
List<Transformation<?>> transformations = new ArrayList<>(); + transformations.add(processTransformation2); + transformations.add(processTransformation3); + + List<Transformation<?>> optimized = + PythonOperatorChainingOptimizer.optimize(transformations); + // no chaining optimization occurred + assertEquals(4, optimized.size()); + } + // ----------------------- Utility Methods ----------------------- private void validateChainedPythonFunctions(