weibozhao commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r882484808


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan
+ * McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+
+        DataStream<Row> points =
+                tEnv.toDataStream(inputs[0])
+                        .map(new FeaturesExtractor(getFeaturesCol(), getLabelCol()));
+
+        DataStream<DenseVector> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector>)
+                                value -> value.coefficient);
+
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(), getAlpha(), getBeta(), getReg(), getElasticNet());
+
+        DataStream<LogisticRegressionModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegressionModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, Row> {
+        private final String featuresCol;
+        private final String labelCol;
+
+        private FeaturesExtractor(String featuresCol, String labelCol) {
+            this.featuresCol = featuresCol;
+            this.labelCol = labelCol;
+        }
+
+        @Override
+        public Row map(Row row) throws Exception {
+            return Row.of(row.getField(featuresCol), row.getField(labelCol));
+        }
+    }
+
+    /**
+     * Implementation of the FTRL optimizer. In this implementation, gradients are calculated on
+     * distributed workers and reduced to a single gradient, which is then used to update the model
+     * with the FTRL method.
+     *
+     * <p>See https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Ftrl
+     *
+     * <p>TODO: make FTRL a common optimizer and move it to org.apache.flink.ml.common in the
+     * future.
+     */
+    private static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+
+        public FtrlIterationBody(
+                int batchSize, double alpha, double beta, double reg, double elasticNet) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = elasticNet * reg;
+            this.l2 = (1 - elasticNet) * reg;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector> modelData = variableStreams.get(0);
+
+            DataStream<Row> points = dataStreams.get(0);
+            int parallelism = points.getParallelism();
+            Preconditions.checkState(
+                    parallelism <= batchSize,
+                    "There are more subtasks in the training process than the 
number "
+                            + "of elements in each batch. Some subtasks might 
be idling forever.");
+
+            DataStream<DenseVector[]> newGradient =
+                    DataStreamUtils.generateBatchData(points, parallelism, batchSize)
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "LocalGradientCalculator",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new CalculateLocalGradient())
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
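+                            // Element-wise sum of the local (gradient sum, weight sum)
+                            // pairs emitted by all subtasks for one global batch.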
+                            .reduce(
+                                    (ReduceFunction<DenseVector[]>)
+                                            (gradientInfo, newGradientInfo) -> {
+                                                for (int i = 0;
+                                                        i < newGradientInfo[1].size();
+                                                        ++i) {
+                                                    newGradientInfo[0].values[i] =
+                                                            gradientInfo[0].values[i]
+                                                                    + newGradientInfo[0].values[i];
+                                                    newGradientInfo[1].values[i] =
+                                                            gradientInfo[1].values[i]
+                                                                    + newGradientInfo[1].values[i];
+                                                    if (newGradientInfo[2] == null) {
+                                                        newGradientInfo[2] = gradientInfo[2];
+                                                    }
+                                                    }
+                                                }
+                                                return newGradientInfo;
+                                            });
+            DataStream<DenseVector> feedbackModelData =
+                    newGradient
+                            .transform(
+                                    "ModelDataUpdater",
+                                    TypeInformation.of(DenseVector.class),
+                                    new UpdateModel(alpha, beta, l1, l2))
+                            .setParallelism(1);
+
+            DataStream<LogisticRegressionModelData> outputModelData =
+                    feedbackModelData.map(new CreateLrModelData()).setParallelism(1);
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackModelData), DataStreamList.of(outputModelData));
+        }
+    }
+
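+    /** Wraps each updated coefficient vector into model data with an increasing model version. */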
+    private static class CreateLrModelData
+            implements MapFunction<DenseVector, LogisticRegressionModelData>, CheckpointedFunction {
+        private Long modelVersion = 1L;
+        private transient ListState<Long> modelVersionState;
+
+        @Override
+        public LogisticRegressionModelData map(DenseVector denseVector) throws Exception {
+            return new LogisticRegressionModelData(denseVector, modelVersion++);
+        }
+
+        @Override
+        public void snapshotState(FunctionSnapshotContext functionSnapshotContext)
+                throws Exception {
+            modelVersionState.update(Collections.singletonList(modelVersion));
+        }
+
+        @Override
+        public void initializeState(FunctionInitializationContext context) throws Exception {
+            modelVersionState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelVersionState", Long.class));
+        }
+    }
+
+    /** Updates the model with the reduced gradient using the FTRL-Proximal method. */
+    private static class UpdateModel extends AbstractStreamOperator<DenseVector>
+            implements OneInputStreamOperator<DenseVector[], DenseVector> {
+        private ListState<double[]> nParamState;
+        private ListState<double[]> zParamState;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private double[] nParam;
+        private double[] zParam;
+
+        public UpdateModel(double alpha, double beta, double l1, double l2) {
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            nParamState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("nParamState", double[].class));
+            zParamState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("zParamState", double[].class));
+        }
+
+        @Override
+        public void processElement(StreamRecord<DenseVector[]> streamRecord) throws Exception {
+            DenseVector[] gradientInfo = streamRecord.getValue();
+            double[] coefficient = gradientInfo[2].values;
+            double[] g = gradientInfo[0].values;
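+            // gradientInfo holds {gradient sum, weight sum, model coefficients};
+            // divide each gradient component by its weight sum to get the average.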
+            for (int i = 0; i < g.length; ++i) {
+                if (gradientInfo[1].values[i] != 0.0) {
+                    g[i] = g[i] / gradientInfo[1].values[i];
+                }
+            }
+            if (zParam == null) {
+                zParam = new double[g.length];
+                nParam = new double[g.length];
+                nParamState.add(nParam);
+                zParamState.add(zParam);
+            }
+
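+            // Per-coordinate FTRL-Proximal update (see the paper cited on the class Javadoc):
+            // sigma = (sqrt(n + g^2) - sqrt(n)) / alpha; z += g - sigma * w; n += g^2; then the
+            // new coefficient is solved in closed form under L1/L2 regularization.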
+            for (int i = 0; i < zParam.length; ++i) {
+                double sigma = (Math.sqrt(nParam[i] + g[i] * g[i]) - Math.sqrt(nParam[i])) / alpha;
+                zParam[i] += g[i] - sigma * coefficient[i];
+                nParam[i] += g[i] * g[i];
+
+                if (Math.abs(zParam[i]) <= l1) {
+                    coefficient[i] = 0.0;
+                } else {
+                    coefficient[i] =
+                            ((zParam[i] < 0 ? -1 : 1) * l1 - zParam[i])
+                                    / ((beta + Math.sqrt(nParam[i])) / alpha + l2);
+                }
+            }
+            output.collect(new StreamRecord<>(new DenseVector(coefficient)));
+        }
+    }
+
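+    /** Computes the local logistic-loss gradient over one batch of points on each worker. */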
+    private static class CalculateLocalGradient extends AbstractStreamOperator<DenseVector[]>
+            implements TwoInputStreamOperator<Row[], DenseVector, DenseVector[]> {
+        private ListState<DenseVector> modelDataState;
+        private ListState<Row[]> localBatchDataState;
+        private double[] gradient;
+        private double[] weight;
+        private int[] denseVectorIndices;
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", 
DenseVector.class));
+            TypeInformation<Row[]> type =
+                    
ObjectArrayTypeInfo.getInfoFor(TypeInformation.of(Row.class));
+            localBatchDataState =
+                    context.getOperatorStateStore()
+                            .getListState(new 
ListStateDescriptor<>("localBatch", type));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Row[]> pointsRecord) throws Exception {
+            localBatchDataState.add(pointsRecord.getValue());
+            calculateGradient();
+        }
+
+        private void calculateGradient() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchDataState.get().iterator().hasNext()) {
+                return;
+            }
+            DenseVector modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get();
+            modelDataState.clear();
+
+            List<Row[]> pointsList = IteratorUtils.toList(localBatchDataState.get().iterator());
+            Row[] points = pointsList.remove(0);
+            localBatchDataState.update(pointsList);
+
+            for (Row point : points) {
+                Vector vec = point.getFieldAs(0);
+                double label = point.getFieldAs(1);
+                if (gradient == null) {
+                    gradient = new double[vec.size()];
+                    weight = new double[gradient.length];
+                    if (vec instanceof DenseVector) {
+                        denseVectorIndices = new int[vec.size()];
+                        for (int i = 0; i < denseVectorIndices.length; ++i) {
+                            denseVectorIndices[i] = i;
+                        }
+                    }
+                }
+
+                int[] indices;
+                double[] values;
+                if (vec instanceof DenseVector) {
+                    DenseVector denseVector = (DenseVector) vec;
+                    indices = denseVectorIndices;
+                    values = denseVector.values;
+                } else {
+                    SparseVector sparseVector = (SparseVector) vec;
+                    indices = sparseVector.indices;
+                    values = sparseVector.values;
+                }
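+                // p = sigmoid(w . x); this sample contributes (p - label) * x to the
+                // gradient and 1.0 to the weight sum of each feature it touches.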
+                double p = 0.0;
+                for (int i = 0; i < indices.length; ++i) {
+                    int idx = indices[i];
+                    p += modelData.values[idx] * values[i];
+                }
+                p = 1 / (1 + Math.exp(-p));
+                for (int i = 0; i < indices.length; ++i) {
+                    int idx = indices[i];
+                    gradient[idx] += (p - label) * values[i];
+                    weight[idx] += 1.0;
+                }
+            }
+
+            if (points.length > 0) {
+                output.collect(
+                        new StreamRecord<>(
+                                new DenseVector[] {
+                                    new DenseVector(gradient),
+                                    new DenseVector(weight),

Review Comment:
   The `weight` will be renamed to `weightSum`. As noted in the comment above, I will add a `weightCol` param, and this variable is just the sum of the weights.
   
   I will add Javadoc for the code that reduces the gradients.
   
   In my code, the use of the model data is different from SGD. In online LR, the model data is used in two places: calculating the gradient locally and updating the model serially. But in SGD, the model update works differently from online LR. If I changed the model update to match SGD, a lot of code would have to be rewritten without any obvious benefit. So I will keep the model output format.
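   
   As a rough sketch of the averaging step that the new Javadoc will describe (the names `gradientSum`, `weightSum`, and `averageGradient` here are illustrative, not the final ones): slot 0 of the reduced array holds the per-feature gradient sums, slot 1 the per-feature weight sums, and dividing one by the other yields the average gradient that the FTRL update consumes.
   
   ```java
   // Illustrative sketch only, not the PR's final code: average the reduced
   // per-feature gradient sums by the per-feature weight sums.
   static double[] averageGradient(double[] gradientSum, double[] weightSum) {
       double[] avg = new double[gradientSum.length];
       for (int i = 0; i < gradientSum.length; ++i) {
           // Features not touched by any sample in this batch keep a zero gradient.
           avg[i] = weightSum[i] != 0.0 ? gradientSum[i] / weightSum[i] : 0.0;
       }
       return avg;
   }
   ```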



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org
