This is an automated email from the ASF dual-hosted git repository. gaoyunhaii pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push: new 4935d03 [hotfix][iteration] Updates onEpochWatermarkIncremented() and onIterationTerminated() to throw Exception 4935d03 is described below commit 4935d03898c8c89bdda61b9cfbe10936b687785d Author: Dong Lin <lindon...@gmail.com> AuthorDate: Tue Dec 21 14:47:26 2021 +0800 [hotfix][iteration] Updates onEpochWatermarkIncremented() and onIterationTerminated() to throw Exception This closes #41. --- .../apache/flink/iteration/IterationListener.java | 5 ++- .../operator/AbstractWrapperOperator.java | 4 +- .../logisticregression/LogisticRegression.java | 12 ++--- .../apache/flink/ml/clustering/kmeans/KMeans.java | 52 ++++++++++------------ 4 files changed, 33 insertions(+), 40 deletions(-) diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/IterationListener.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/IterationListener.java index 9c451f2..73b323f 100644 --- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/IterationListener.java +++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/IterationListener.java @@ -46,7 +46,8 @@ public interface IterationListener<T> { * the invocation of this method. * @param collector The collector for returning result values. */ - void onEpochWatermarkIncremented(int epochWatermark, Context context, Collector<T> collector); + void onEpochWatermarkIncremented(int epochWatermark, Context context, Collector<T> collector) + throws Exception; /** * This callback is invoked after the execution of the iteration body has terminated. @@ -56,7 +57,7 @@ public interface IterationListener<T> { * the invocation of this method. * @param collector The collector for returning result values. */ - void onIterationTerminated(Context context, Collector<T> collector); + void onIterationTerminated(Context context, Collector<T> collector) throws Exception; /** * Information available in an invocation of the callbacks defined in the diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/AbstractWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/AbstractWrapperOperator.java index 86dbc2f..d790729 100644 --- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/AbstractWrapperOperator.java +++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/AbstractWrapperOperator.java @@ -123,8 +123,8 @@ public abstract class AbstractWrapperOperator<T> } @SuppressWarnings({"unchecked", "rawtypes"}) - protected void notifyEpochWatermarkIncrement( - IterationListener<?> listener, int epochWatermark) { + protected void notifyEpochWatermarkIncrement(IterationListener<?> listener, int epochWatermark) + throws Exception { if (epochWatermark != Integer.MAX_VALUE) { listener.onEpochWatermarkIncremented( epochWatermark, diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java index 9df0cdf..602b9c6 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java @@ -361,8 +361,8 @@ public class LogisticRegression @Override public void onEpochWatermarkIncremented( - int epochWatermark, Context context, Collector<double[]> collector) { - // TODO: let this method throws exception. + int epochWatermark, Context context, Collector<double[]> collector) + throws Exception { if (epochWatermark == 0) { coefficient = new DenseVector(feedbackBuffer); coefficientDim = coefficient.size(); @@ -372,12 +372,8 @@ public class LogisticRegression updateModel(); } Arrays.fill(gradient.values, 0); - try { - if (trainData == null) { - trainData = IteratorUtils.toList(trainDataState.get().iterator()); - } - } catch (Exception e) { - throw new RuntimeException(e); + if (trainData == null) { + trainData = IteratorUtils.toList(trainDataState.get().iterator()); } miniBatchData = getMiniBatchData(trainData, localBatchSize); Tuple2<Double, Double> weightSumAndLossSum = diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java index 1c1a47e..f9b704b 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java @@ -276,37 +276,33 @@ public class KMeans implements Estimator<KMeans, KMeansModel>, KMeansParams<KMea @Override public void onEpochWatermarkIncremented( - int epochWatermark, Context context, Collector<Tuple2<Integer, DenseVector>> out) { - // TODO: update onEpochWatermarkIncremented to throw Exception. - try { - List<DenseVector[]> list = IteratorUtils.toList(centroids.get().iterator()); - if (list.size() != 1) { - throw new RuntimeException( - "The operator received " - + list.size() - + " list of centroids in this round"); - } - DenseVector[] centroidValues = list.get(0); - - for (DenseVector point : points.get()) { - double minDistance = Double.MAX_VALUE; - int closestCentroidId = -1; - - for (int i = 0; i < centroidValues.length; i++) { - DenseVector centroid = centroidValues[i]; - double distance = distanceMeasure.distance(centroid, point); - if (distance < minDistance) { - minDistance = distance; - closestCentroidId = i; - } + int epochWatermark, Context context, Collector<Tuple2<Integer, DenseVector>> out) + throws Exception { + List<DenseVector[]> list = IteratorUtils.toList(centroids.get().iterator()); + if (list.size() != 1) { + throw new RuntimeException( + "The operator received " + + list.size() + + " list of centroids in this round"); + } + DenseVector[] centroidValues = list.get(0); + + for (DenseVector point : points.get()) { + double minDistance = Double.MAX_VALUE; + int closestCentroidId = -1; + + for (int i = 0; i < centroidValues.length; i++) { + DenseVector centroid = centroidValues[i]; + double distance = distanceMeasure.distance(centroid, point); + if (distance < minDistance) { + minDistance = distance; + closestCentroidId = i; } - - output.collect(new StreamRecord<>(Tuple2.of(closestCentroidId, point))); } - centroids.clear(); - } catch (Exception e) { - throw new RuntimeException(e); + + output.collect(new StreamRecord<>(Tuple2.of(closestCentroidId, point))); } + centroids.clear(); } @Override