lindong28 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r762673330



##########
File path: flink-ml-core/src/main/java/org/apache/flink/ml/common/iteration/TerminateOnMaxIterOrTol.java
##########
@@ -16,29 +16,53 @@
  * limitations under the License.
  */
 
-package org.apache.flink.test.iteration.operators;
+package org.apache.flink.ml.common.iteration;
 
 import org.apache.flink.api.common.functions.FlatMapFunction;
 import org.apache.flink.iteration.IterationListener;
 import org.apache.flink.util.Collector;
 
-/** An termination criteria function that asks to stop after the specialized round. */
-public class RoundBasedTerminationCriteria
-        implements FlatMapFunction<EpochRecord, Integer>, IterationListener<Integer> {
+/**
+ * A FlatMapFunction that emits values iff the iteration's epochWatermark does not exceed a certain
+ * threshold and the loss exceeds a certain tolerance.
+ *
+ * <p>When the output of this FlatMapFunction is used as the termination criteria of an iteration
+ * body, the iteration will be executed for at most the given `maxIter` iterations. And the
+ * iteration will terminate once any input value is smaller than or equal to the given `tol`.
+ */
+public class TerminateOnMaxIterOrTol
+        implements IterationListener<Integer>, FlatMapFunction<Double, Integer> {
+
+    private final int maxIter;
 
-    private final int maxRound;
+    private final double tol;
 
-    public RoundBasedTerminationCriteria(int maxRound) {
-        this.maxRound = maxRound;
+    private double loss = Double.NEGATIVE_INFINITY;
+
+    public TerminateOnMaxIterOrTol(int maxIter, double tol) {
+        this.maxIter = maxIter;
+        this.tol = tol;
+    }
+
+    public TerminateOnMaxIterOrTol(int maxIter) {

Review comment:
       nit: It seems simpler to remove this constructor and let users use `TerminateOnMaxIter` instead.

##########
File path: flink-ml-core/src/main/java/org/apache/flink/ml/common/iteration/TerminateOnMaxIterOrTol.java
##########
@@ -16,29 +16,53 @@
  * limitations under the License.
  */
 
-package org.apache.flink.test.iteration.operators;
+package org.apache.flink.ml.common.iteration;
 
 import org.apache.flink.api.common.functions.FlatMapFunction;
 import org.apache.flink.iteration.IterationListener;
 import org.apache.flink.util.Collector;
 
-/** An termination criteria function that asks to stop after the specialized round. */
-public class RoundBasedTerminationCriteria
-        implements FlatMapFunction<EpochRecord, Integer>, IterationListener<Integer> {
+/**
+ * A FlatMapFunction that emits values iff the iteration's epochWatermark does not exceed a certain
+ * threshold and the loss exceeds a certain tolerance.
+ *
+ * <p>When the output of this FlatMapFunction is used as the termination criteria of an iteration
+ * body, the iteration will be executed for at most the given `maxIter` iterations. And the
+ * iteration will terminate once any input value is smaller than or equal to the given `tol`.
+ */
+public class TerminateOnMaxIterOrTol
+        implements IterationListener<Integer>, FlatMapFunction<Double, Integer> {
+
+    private final int maxIter;
 
-    private final int maxRound;
+    private final double tol;
 
-    public RoundBasedTerminationCriteria(int maxRound) {
-        this.maxRound = maxRound;
+    private double loss = Double.NEGATIVE_INFINITY;

Review comment:
       What happens if there are no input values in the first epoch? Do we expect the iteration to terminate (which seems to be the case with `loss = Double.NEGATIVE_INFINITY`)?
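
To make the concern concrete, here is a minimal plain-Java sketch of the check. It does not use any Flink types. The class name `TerminationSketch`, its method names, and the exact continue condition (`epochWatermark + 1 < maxIter && loss > tol`) are assumptions based on the Javadoc above, not the code in this PR.

```java
/**
 * Plain-Java model of the termination check; no Flink types are involved.
 * The class name, method names and the exact continue condition are assumptions.
 */
class TerminationSketch {
    private final int maxIter;
    private final double tol;
    // Last observed loss; the initial value decides what an epoch without input means.
    private double loss;

    TerminationSketch(int maxIter, double tol, double initialLoss) {
        this.maxIter = maxIter;
        this.tol = tol;
        this.loss = initialLoss;
    }

    /** Records one input value, overwriting the previous one. */
    void onValue(double value) {
        this.loss = value;
    }

    /** Returns true if another epoch should run after the given epoch watermark. */
    boolean shouldContinue(int epochWatermark) {
        return epochWatermark + 1 < maxIter && loss > tol;
    }
}
```

Under these assumptions, an epoch with no input and an initial loss of `Double.NEGATIVE_INFINITY` makes `shouldContinue(0)` return false, so the iteration would stop right away. Starting from `Double.POSITIVE_INFINITY` would keep it running until a value at or below `tol` arrives or `maxIter` is reached. Which of the two behaviors is intended is the open question.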

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModel.java
##########
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** This class implements {@link Model} for {@link LogisticRegression}. */
+public class LogisticRegressionModel
+        implements Model<LogisticRegressionModel>,
+                LogisticRegressionModelParams<LogisticRegressionModel> {
+
+    private Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    private Table modelData;
+
+    public LogisticRegressionModel() {
+        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                LogisticRegressionModelData.getModelDataStream(modelData),
+                path,
+                new LogisticRegressionModelData.ModelDataEncoder());
+    }
+
+    public static LogisticRegressionModel load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        LogisticRegressionModel model = ReadWriteUtils.loadStageParam(path);
+        Table modelData =
+                ReadWriteUtils.loadModelData(
+                        env, path, new LogisticRegressionModelData.ModelDataDecoder());
+        return model.setModelData(modelData);
+    }
+
+    @Override
+    public LogisticRegressionModel setModelData(Table... inputs) {
+        modelData = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelData};
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> inputStream = tEnv.toDataStream(inputs[0]);
+        final String broadcastModelKey = "broadcastModelKey";
+        DataStream<LogisticRegressionModelData> modelData =
+                LogisticRegressionModelData.getModelDataStream(this.modelData);
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(),
+                                BasicTypeInfo.DOUBLE_TYPE_INFO,
+                                TypeInformation.of(DenseVector.class)),
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldNames(),
+                                getPredictionCol(),
+                                getRawPredictionCol()));
+        DataStream<Row> predictionResult =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(inputStream),
+                        Collections.singletonMap(broadcastModelKey, modelData),
+                        inputList -> {
+                            DataStream inputData = inputList.get(0);
+                            return inputData.transform(
+                                    "doPrediction",
+                                    outputTypeInfo,
+                                    new PredictOperator(broadcastModelKey, getFeaturesCol()));
+                        });
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility operator used for prediction. */
+    private static class PredictOperator
+            extends AbstractUdfStreamOperator<Row, AbstractRichFunction>
+            implements OneInputStreamOperator<Row, Row> {
+
+        private final String broadcastModelKey;
+
+        private final String featuresCol;
+
+        private DenseVector coefficient;
+
+        public PredictOperator(String broadcastModelKey, String featuresCol) {
+            super(new AbstractRichFunction() {});
+            this.broadcastModelKey = broadcastModelKey;
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public void processElement(StreamRecord<Row> streamRecord) {
+            if (coefficient == null) {
+                LogisticRegressionModelData modelData =
+                        (LogisticRegressionModelData)
+                                userFunction

Review comment:
       nit: Can we call `getRuntimeContext()` directly instead of going through `userFunction`?

##########
File path: flink-ml-core/src/main/java/org/apache/flink/ml/common/iteration/TerminateOnMaxIterOrTol.java
##########
@@ -16,29 +16,53 @@
  * limitations under the License.
  */
 
-package org.apache.flink.test.iteration.operators;
+package org.apache.flink.ml.common.iteration;
 
 import org.apache.flink.api.common.functions.FlatMapFunction;
 import org.apache.flink.iteration.IterationListener;
 import org.apache.flink.util.Collector;
 
-/** An termination criteria function that asks to stop after the specialized round. */
-public class RoundBasedTerminationCriteria
-        implements FlatMapFunction<EpochRecord, Integer>, IterationListener<Integer> {
+/**
+ * A FlatMapFunction that emits values iff the iteration's epochWatermark does not exceed a certain
+ * threshold and the loss exceeds a certain tolerance.
+ *
+ * <p>When the output of this FlatMapFunction is used as the termination criteria of an iteration
+ * body, the iteration will be executed for at most the given `maxIter` iterations. And the
+ * iteration will terminate once any input value is smaller than or equal to the given `tol`.
+ */
+public class TerminateOnMaxIterOrTol
+        implements IterationListener<Integer>, FlatMapFunction<Double, Integer> {
+
+    private final int maxIter;
 
-    private final int maxRound;
+    private final double tol;
 
-    public RoundBasedTerminationCriteria(int maxRound) {
-        this.maxRound = maxRound;
+    private double loss = Double.NEGATIVE_INFINITY;
+
+    public TerminateOnMaxIterOrTol(int maxIter, double tol) {
+        this.maxIter = maxIter;
+        this.tol = tol;
+    }
+
+    public TerminateOnMaxIterOrTol(int maxIter) {
+        this.maxIter = maxIter;
+        this.tol = Double.NEGATIVE_INFINITY;
+    }
+
+    public TerminateOnMaxIterOrTol(double tol) {
+        this.maxIter = Integer.MAX_VALUE;
+        this.tol = tol;
     }
 
     @Override
-    public void flatMap(EpochRecord integer, Collector<Integer> collector) throws Exception {}
+    public void flatMap(Double value, Collector<Integer> out) {
+        this.loss = value;

Review comment:
       What if the input values increase within an epoch? Should we use the minimum input value observed in this epoch?
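
As a rough illustration of the suggestion, the sketch below keeps the minimum value seen in the current epoch instead of the last one, and resets it when the epoch ends. It is plain Java with hypothetical names (`MinLossPerEpoch`, `onValue`, `endEpochAndCheck`). Resetting to `Double.POSITIVE_INFINITY` is an assumption, and it means an epoch without input is not treated as converged.

```java
/** Tracks the smallest value observed within one epoch (hypothetical helper, not PR code). */
class MinLossPerEpoch {
    // POSITIVE_INFINITY is the identity for min, so any real input replaces it.
    private double minLoss = Double.POSITIVE_INFINITY;

    /** Records one input value, keeping only the minimum seen so far in this epoch. */
    void onValue(double value) {
        minLoss = Math.min(minLoss, value);
    }

    /**
     * Called at the end of an epoch. Returns true if any value in the epoch was at or
     * below tol, then resets the state for the next epoch.
     */
    boolean endEpochAndCheck(double tol) {
        boolean converged = minLoss <= tol;
        minLoss = Double.POSITIVE_INFINITY;
        return converged;
    }
}
```

With inputs `5e-4` followed by `0.2` and `tol = 1e-3`, this reports convergence, which matches the Javadoc's "any input value is smaller than or equal to the given `tol`". Keeping only the last value would miss it.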

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticGradient.java
##########
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+
+import java.io.Serializable;
+import java.util.List;
+
+/**
+ * Utility class to compute gradient and loss for logistic loss function.
+ *
+ * <p>See http://mlwiki.org/index.php/Logistic_Regression.
+ */
+public class LogisticGradient implements Serializable {
+
+    /** L2 regularization term. */
+    private final double l2;
+
+    public LogisticGradient(double l2) {
+        this.l2 = l2;
+    }
+
+    /**
+     * Computes weight sum and loss sum on a set of samples.
+     *
+     * @param dataPoints A sample set of train data.
+     * @param coefficient The model parameters.
+     * @return Weight sum and loss sum of the input data.
+     */
+    public final Tuple2<Double, Double> computeLoss(
+            List<LabeledPointWithWeight> dataPoints, DenseVector coefficient) {
+        double weightSum = 0.0;
+        double lossSum = 0.0;
+        for (LabeledPointWithWeight dataPoint : dataPoints) {
+            lossSum += dataPoint.weight * computeLoss(dataPoint, coefficient);
+            weightSum += dataPoint.weight;
+        }
+        if (Double.compare(0, l2) != 0) {
+            lossSum += l2 * Math.pow(BLAS.norm2(coefficient), 2);
+        }
+        return Tuple2.of(weightSum, lossSum);
+    }
+
+    /**
+     * Computes gradient on a set of samples.
+     *
+     * @param dataPoints A sample set of train data.
+     * @param coefficient The model parameters.
+     * @param cumGradient The accumulated gradients.
+     * @return Weight sum of the input data.
+     */
+    public final double computeGradient(

Review comment:
       nit: Can we remove the `final` keyword here for consistency with most other methods? Same for `computeLoss()`.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticGradient.java
##########
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+
+import java.io.Serializable;
+import java.util.List;
+
+/**
+ * Utility class to compute gradient and loss for logistic loss function.
+ *
+ * <p>See http://mlwiki.org/index.php/Logistic_Regression.
+ */
+public class LogisticGradient implements Serializable {
+
+    /** L2 regularization term. */
+    private final double l2;
+
+    public LogisticGradient(double l2) {
+        this.l2 = l2;
+    }
+
+    /**
+     * Computes weight sum and loss sum on a set of samples.
+     *
+     * @param dataPoints A sample set of train data.
+     * @param coefficient The model parameters.
+     * @return Weight sum and loss sum of the input data.
+     */
+    public final Tuple2<Double, Double> computeLoss(
+            List<LabeledPointWithWeight> dataPoints, DenseVector coefficient) {
+        double weightSum = 0.0;
+        double lossSum = 0.0;
+        for (LabeledPointWithWeight dataPoint : dataPoints) {
+            lossSum += dataPoint.weight * computeLoss(dataPoint, coefficient);
+            weightSum += dataPoint.weight;
+        }
+        if (Double.compare(0, l2) != 0) {
+            lossSum += l2 * Math.pow(BLAS.norm2(coefficient), 2);
+        }
+        return Tuple2.of(weightSum, lossSum);
+    }
+
+    /**
+     * Computes gradient on a set of samples.
+     *
+     * @param dataPoints A sample set of train data.
+     * @param coefficient The model parameters.
+     * @param cumGradient The accumulated gradients.
+     * @return Weight sum of the input data.
+     */
+    public final double computeGradient(
+            List<LabeledPointWithWeight> dataPoints,
+            DenseVector coefficient,
+            DenseVector cumGradient) {
+        double weightSum = 0.0;
+        for (LabeledPointWithWeight dataPoint : dataPoints) {
+            weightSum += dataPoint.weight;
+            computeGradient(dataPoint, coefficient, cumGradient);
+        }
+        if (Double.compare(0, l2) != 0) {
+            BLAS.axpy(this.l2 * 2, coefficient, cumGradient);
+        }
+        return weightSum;

Review comment:
       It seems that the `weightSum` returned by this method is never used. 
Could we remove this return value for simplicity?
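
For reference, a `void` variant could look roughly like the sketch below. It uses plain arrays instead of `DenseVector` and `BLAS` so that it is self-contained. The class name, the array-based signature, the 0/1 label encoding and the per-sample gradient formula are all assumptions, since the per-sample `computeGradient` is not shown in this hunk. Only the `2 * l2 * coefficient` term follows the `BLAS.axpy` call above.

```java
/** Array-based sketch of a batch gradient that returns nothing (hypothetical, not PR code). */
class LogisticGradientSketch {
    private final double l2;

    LogisticGradientSketch(double l2) {
        this.l2 = l2;
    }

    /** Accumulates the weighted logistic-loss gradient of all samples into cumGradient. */
    void computeGradient(
            double[][] features,
            double[] labels,
            double[] weights,
            double[] coefficient,
            double[] cumGradient) {
        for (int i = 0; i < features.length; i++) {
            double dot = 0.0;
            for (int j = 0; j < coefficient.length; j++) {
                dot += features[i][j] * coefficient[j];
            }
            // Assumes labels in {0, 1}: gradient of the log loss is (sigmoid(dot) - label) * x.
            double multiplier = weights[i] * (1.0 / (1.0 + Math.exp(-dot)) - labels[i]);
            for (int j = 0; j < coefficient.length; j++) {
                cumGradient[j] += multiplier * features[i][j];
            }
        }
        if (l2 != 0.0) {
            // Gradient of l2 * ||coefficient||^2, the array equivalent of axpy(2 * l2, ...).
            for (int j = 0; j < coefficient.length; j++) {
                cumGradient[j] += 2.0 * l2 * coefficient[j];
            }
        }
    }
}
```

A caller that still needs the weight sum can take it from `computeLoss()`, which already returns it in the code above.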

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasWeightCol.java
##########
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.StringParam;
+import org.apache.flink.ml.param.WithParams;
+
+/**
+ * Interface for the shared weight column param. If this is not set or empty, we treat all instance

Review comment:
       Currently we only check whether `getWeightCol() == null`, which I believe is simpler than additionally checking whether it is an empty string.

       Should we remove `or empty` here?
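
A small sketch of the null-only contract, in plain Java with a `Map` standing in for a row. The helper name `weightOf`, the `Map`-based row and the default weight of `1.0` for an unset column are assumptions for illustration. The Javadoc line above is cut off before it states the default.

```java
import java.util.Map;

/** Hypothetical helper showing a null-only check on the weight column name. */
class WeightColSketch {
    /** Returns the instance weight, or 1.0 (assumed default) when no weight column is set. */
    static double weightOf(Map<String, Double> row, String weightCol) {
        if (weightCol == null) {
            return 1.0;
        }
        Double weight = row.get(weightCol);
        if (weight == null) {
            // An empty or unknown column name is reported as an error rather than
            // silently treated as "not set".
            throw new IllegalArgumentException("Missing weight column: " + weightCol);
        }
        return weight;
    }
}
```

With this contract an empty string is simply an invalid column name, so the Javadoc only needs to describe the `null` case.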




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

