zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r762777943
##########
File path: flink-ml-core/src/main/java/org/apache/flink/ml/common/iteration/TerminateOnMaxIterOrTol.java
##########
@@ -16,29 +16,53 @@
  * limitations under the License.
  */
 
-package org.apache.flink.test.iteration.operators;
+package org.apache.flink.ml.common.iteration;
 
 import org.apache.flink.api.common.functions.FlatMapFunction;
 import org.apache.flink.iteration.IterationListener;
 import org.apache.flink.util.Collector;
 
-/** An termination criteria function that asks to stop after the specialized round. */
-public class RoundBasedTerminationCriteria
-        implements FlatMapFunction<EpochRecord, Integer>, IterationListener<Integer> {
+/**
+ * A FlatMapFunction that emits values iff the iteration's epochWatermark does not exceed a certain
+ * threshold and the loss exceeds a certain tolerance.
+ *
+ * <p>When the output of this FlatMapFunction is used as the termination criteria of an iteration
+ * body, the iteration will be executed for at most the given `maxIter` iterations. And the
+ * iteration will terminate once any input value is smaller than or equal to the given `tol`.
+ */
+public class TerminateOnMaxIterOrTol
+        implements IterationListener<Integer>, FlatMapFunction<Double, Integer> {
+
+    private final int maxIter;
 
-    private final int maxRound;
+    private final double tol;
 
-    public RoundBasedTerminationCriteria(int maxRound) {
-        this.maxRound = maxRound;
+    private double loss = Double.NEGATIVE_INFINITY;
+
+    public TerminateOnMaxIterOrTol(int maxIter, double tol) {
+        this.maxIter = maxIter;
+        this.tol = tol;
+    }
+
+    public TerminateOnMaxIterOrTol(int maxIter) {
+        this.maxIter = maxIter;
+        this.tol = Double.NEGATIVE_INFINITY;
+    }
+
+    public TerminateOnMaxIterOrTol(double tol) {
+        this.maxIter = Integer.MAX_VALUE;
+        this.tol = tol;
     }
 
     @Override
-    public void flatMap(EpochRecord integer, Collector<Integer> collector) throws Exception {}
+    public void flatMap(Double value, Collector<Integer> out) {
+        this.loss = value;

Review comment:
   I think throwing an exception to enforce only one input per epoch is a viable solution. Users are not supposed to use `TerminateOnMaxIterOrTol` or `TerminateOnMaxIter` for asynchronous iterations: in that setting, each worker is expected to decide its own termination based on its own loss.
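The reviewer's suggestion, rejecting a second loss value within the same epoch, can be sketched without a Flink dependency. Everything below is hypothetical: `TerminationSketch`, `onLoss` and `onEpochEnd` are illustrative names, not Flink ML API. `onLoss` stands in for `flatMap`, `onEpochEnd` for the `IterationListener` epoch-watermark callback, and a plain `List<Integer>` replaces Flink's `Collector`. The sketch assumes epochs are numbered from 0 and that, as with a termination-criteria stream, emitting a record means "keep iterating" while an epoch with no output ends the iteration.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical, Flink-free sketch of the termination check discussed in this review.
 * It is not the Flink ML implementation; method names are illustrative only.
 */
class TerminationSketch {
    private final int maxIter;
    private final double tol;

    // NEGATIVE_INFINITY doubles as the "no loss received in this epoch yet" marker.
    private double loss = Double.NEGATIVE_INFINITY;

    public TerminationSketch(int maxIter, double tol) {
        this.maxIter = maxIter;
        this.tol = tol;
    }

    /** Stands in for flatMap: records this epoch's loss and rejects a second value. */
    public void onLoss(double value) {
        if (loss != Double.NEGATIVE_INFINITY) {
            throw new IllegalStateException("Each epoch should receive only one loss value.");
        }
        loss = value;
    }

    /** Stands in for the epoch callback: emits one record only while epochs remain and loss > tol. */
    public void onEpochEnd(int epochWatermark, List<Integer> out) {
        // Assumption: epochs are numbered from 0, so epochWatermark + 1 epochs have finished.
        // If no loss arrived this epoch, loss is still NEGATIVE_INFINITY and the check fails.
        if (epochWatermark + 1 < maxIter && loss > tol) {
            out.add(0);
        }
        loss = Double.NEGATIVE_INFINITY; // reset for the next epoch
    }

    public static void main(String[] args) {
        TerminationSketch criteria = new TerminationSketch(3, 0.1);
        double[] losses = {0.5, 0.05};
        for (int epoch = 0; epoch < losses.length; epoch++) {
            List<Integer> out = new ArrayList<>();
            criteria.onLoss(losses[epoch]);
            criteria.onEpochEnd(epoch, out);
            // An empty output for an epoch means the iteration would stop there.
            System.out.println("epoch " + epoch + " continue: " + !out.isEmpty());
        }
    }
}
```

Running `main` prints `epoch 0 continue: true` followed by `epoch 1 continue: false`, because the second loss (0.05) is not above `tol` (0.1). Reusing `NEGATIVE_INFINITY` as the reset marker keeps the check cheap, but a genuine loss of negative infinity would slip past the duplicate check; a separate boolean flag avoids that if it matters.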