This is an automated email from the ASF dual-hosted git repository. gaoyunhaii pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink-ml.git
commit 2ef6b8702295f246739a02d7335c1a7a1a010a94 Author: Yun Gao <gaoyunhen...@gmail.com> AuthorDate: Thu Nov 11 16:04:23 2021 +0800 [FLINK-24842][iteration] Make outputs depends on tails for the iteration body This closes #31. --- .../org/apache/flink/iteration/Iterations.java | 63 +++++++++++++++++----- .../flink/iteration/IterationConstructionTest.java | 37 +++++++------ 2 files changed, 73 insertions(+), 27 deletions(-) diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/Iterations.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/Iterations.java index 2a3fb39..514f31a 100644 --- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/Iterations.java +++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/Iterations.java @@ -20,6 +20,7 @@ package org.apache.flink.iteration; import org.apache.flink.annotation.Experimental; import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.GenericTypeInfo; import org.apache.flink.iteration.compile.DraftExecutionEnvironment; import org.apache.flink.iteration.operator.HeadOperator; import org.apache.flink.iteration.operator.HeadOperatorFactory; @@ -259,22 +260,29 @@ public class Iterations { tails.get(i).getTransformation().setCoLocationGroupKey(coLocationGroupKey); } + List<DataStream<?>> tailsAndCriteriaTails = new ArrayList<>(tails.getDataStreams()); checkState( mayHaveCriteria || iterationBodyResult.getTerminationCriteria() == null, "The current iteration type does not support the termination criteria."); if (iterationBodyResult.getTerminationCriteria() != null) { - addCriteriaStream( - iterationBodyResult.getTerminationCriteria(), - iterationId, - env, - draftEnv, - initVariableStreams, - headStreams, - totalInitVariableParallelism); + DataStreamList criteriaTails = + addCriteriaStream( + iterationBodyResult.getTerminationCriteria(), + iterationId, + env, + draftEnv, + initVariableStreams, + headStreams, + totalInitVariableParallelism); + tailsAndCriteriaTails.addAll(criteriaTails.getDataStreams()); } - return addOutputs(getActualDataStreams(iterationBodyResult.getOutputStreams(), draftEnv)); + DataStream<Object> tailsUnion = + unionAllTails(env, new DataStreamList(tailsAndCriteriaTails)); + + return addOutputs( + getActualDataStreams(iterationBodyResult.getOutputStreams(), draftEnv), tailsUnion); } private static DataStreamList addReplayer( @@ -315,7 +323,7 @@ public class Iterations { return new DataStreamList(result); } - private static void addCriteriaStream( + private static DataStreamList addCriteriaStream( DataStream<?> draftCriteriaStream, IterationID iterationId, StreamExecutionEnvironment env, @@ -364,9 +372,11 @@ public class Iterations { criteriaHeaders.get(0).getTransformation().setCoLocationGroupKey(coLocationGroupKey); criteriaTails.get(0).getTransformation().setCoLocationGroupKey(coLocationGroupKey); - // Now we notify all the head operators to count the criteria stream. + // Now we notify all the head operators to count the criteria streams. setCriteriaParallelism(headStreams, terminationCriteria.getParallelism()); setCriteriaParallelism(criteriaHeaders, terminationCriteria.getParallelism()); + + return criteriaTails; } @SuppressWarnings({"unchecked", "rawtypes"}) @@ -394,6 +404,24 @@ public class Iterations { return criteriaDraftEnv.getActualStream(draftMergedStream.getId()); } + @SuppressWarnings({"unchecked", "rawtypes"}) + private static DataStream<Object> unionAllTails( + StreamExecutionEnvironment env, DataStreamList tailsAndCriteriaTails) { + return Iterations.<DataStream>map( + tailsAndCriteriaTails, + tail -> + tail.filter(r -> false) + .name("filter-tail") + .returns(new GenericTypeInfo(Object.class)) + .setParallelism( + tail.getParallelism() > 0 + ? tail.getParallelism() + : env.getConfig().getParallelism())) + .stream() + .reduce(DataStream::union) + .get(); + } + private static List<TypeInformation<?>> getTypeInfos(DataStreamList dataStreams) { return map(dataStreams, DataStream::getType); } @@ -453,7 +481,8 @@ public class Iterations { .setParallelism(dataStream.getParallelism()))); } - private static DataStreamList addOutputs(DataStreamList dataStreams) { + @SuppressWarnings({"unchecked", "rawtypes"}) + private static DataStreamList addOutputs(DataStreamList dataStreams, DataStream tailsUnion) { return new DataStreamList( map( dataStreams, @@ -461,6 +490,16 @@ public class Iterations { IterationRecordTypeInfo<?> inputType = (IterationRecordTypeInfo<?>) dataStream.getType(); return dataStream + .union( + tailsUnion + .map(x -> x) + .name( + "tail-map-" + + dataStream + .getTransformation() + .getName()) + .returns(inputType) + .setParallelism(1)) .transform( "output-" + dataStream.getTransformation().getName(), inputType.getInnerTypeInfo(), diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/IterationConstructionTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/IterationConstructionTest.java index f2ec465..5844b5a 100644 --- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/IterationConstructionTest.java +++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/IterationConstructionTest.java @@ -63,7 +63,7 @@ public class IterationConstructionTest extends TestLogger { Arrays.asList( /* 0 */ "Source: Variable -> input-Variable", /* 1 */ "head-Variable", - /* 2 */ "tail-head-Variable"); + /* 2 */ "tail-head-Variable -> filter-tail"); List<Integer> expectedParallelisms = Arrays.asList(4, 4, 4); List<JobVertex> vertices = jobGraph.getVerticesSortedTopologicallyFromSources(); @@ -102,7 +102,7 @@ public class IterationConstructionTest extends TestLogger { /* 0 */ "Source: Variable", /* 1 */ "map -> input-map", /* 2 */ "head-map", - /* 3 */ "tail-head-map"); + /* 3 */ "tail-head-map -> filter-tail"); List<Integer> expectedParallelisms = Arrays.asList(4, 2, 2, 2); List<JobVertex> vertices = jobGraph.getVerticesSortedTopologicallyFromSources(); @@ -191,12 +191,14 @@ public class IterationConstructionTest extends TestLogger { /* 2 */ "Source: Constant -> input-Constant", /* 3 */ "head-Variable0", /* 4 */ "head-Variable1", - /* 5 */ "Processor -> output-SideOutput -> Sink: Sink", + /* 5 */ "Processor", /* 6 */ "Feedback0", - /* 7 */ "tail-Feedback0", + /* 7 */ "tail-Feedback0 -> filter-tail", /* 8 */ "Feedback1", - /* 9 */ "tail-Feedback1"); - List<Integer> expectedParallelisms = Arrays.asList(2, 3, 3, 2, 3, 4, 2, 2, 3, 3); + /* 9 */ "tail-Feedback1 -> filter-tail", + /* 10 */ "tail-map-SideOutput", + /* 11 */ "output-SideOutput -> Sink: Sink"); + List<Integer> expectedParallelisms = Arrays.asList(2, 3, 3, 2, 3, 4, 2, 2, 3, 3, 1, 4); JobGraph jobGraph = env.getStreamGraph().getJobGraph(); List<JobVertex> vertices = jobGraph.getVerticesSortedTopologicallyFromSources(); @@ -286,17 +288,19 @@ public class IterationConstructionTest extends TestLogger { /* 3 */ "Source: Termination -> input-Termination", /* 4 */ "head-Variable0", /* 5 */ "head-Variable1", - /* 6 */ "Processor -> output-SideOutput -> Sink: Sink", + /* 6 */ "Processor", /* 7 */ "Feedback0", - /* 8 */ "tail-Feedback0", + /* 8 */ "tail-Feedback0 -> filter-tail", /* 9 */ "Feedback1", - /* 10 */ "tail-Feedback1", + /* 10 */ "tail-Feedback1 -> filter-tail", /* 11 */ "Termination", /* 12 */ "head-Termination", /* 13 */ "criteria-merge", - /* 14 */ "tail-criteria-merge"); + /* 14 */ "tail-criteria-merge -> filter-tail", + /* 15 */ "tail-map-SideOutput", + /* 16 */ "output-SideOutput -> Sink: Sink"); List<Integer> expectedParallelisms = - Arrays.asList(2, 3, 3, 5, 2, 3, 4, 2, 2, 3, 3, 5, 5, 5, 5); + Arrays.asList(2, 3, 3, 5, 2, 3, 4, 2, 2, 3, 3, 5, 5, 5, 5, 1, 4); JobGraph jobGraph = env.getStreamGraph().getJobGraph(); List<JobVertex> vertices = jobGraph.getVerticesSortedTopologicallyFromSources(); @@ -380,14 +384,17 @@ public class IterationConstructionTest extends TestLogger { /* 2 */ "Source: Termination -> input-Termination", /* 3 */ "head-Variable", /* 4 */ "Replayer-Constant", - /* 5 */ "Processor -> output-SideOutput -> Sink: Sink", + /* 5 */ "Processor", /* 6 */ "Feedback", - /* 7 */ "tail-Feedback", + /* 7 */ "tail-Feedback -> filter-tail", /* 8 */ "Termination", /* 9 */ "head-Termination", /* 10 */ "criteria-merge", - /* 11 */ "tail-criteria-merge"); - List<Integer> expectedParallelisms = Arrays.asList(2, 3, 5, 2, 3, 4, 2, 2, 5, 5, 5, 5); + /* 11 */ "tail-criteria-merge -> filter-tail", + /* 12 */ "tail-map-SideOutput", + /* 13 */ "output-SideOutput -> Sink: Sink"); + List<Integer> expectedParallelisms = + Arrays.asList(2, 3, 5, 2, 3, 4, 2, 2, 5, 5, 5, 5, 1, 4); JobGraph jobGraph = env.getStreamGraph().getJobGraph(); List<JobVertex> vertices = jobGraph.getVerticesSortedTopologicallyFromSources();