weibozhao commented on code in PR #83: URL: https://github.com/apache/flink-ml/pull/83#discussion_r882484808
########## flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java: ########## @@ -0,0 +1,434 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.logisticregression; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.IterationBody; +import org.apache.flink.iteration.IterationBodyResult; +import org.apache.flink.iteration.Iterations; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; 
+import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.FunctionInitializationContext; +import org.apache.flink.runtime.state.FunctionSnapshotContext; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.collections.IteratorUtils; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan + * McMahan et al. + * + * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click + * prediction: a view from the trenches.</a> + */ +public class OnlineLogisticRegression + implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>, + OnlineLogisticRegressionParams<OnlineLogisticRegression> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private Table initModelDataTable; + + public OnlineLogisticRegression() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public OnlineLogisticRegressionModel fit(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<LogisticRegressionModelData> modelDataStream = + LogisticRegressionModelData.getModelDataStream(initModelDataTable); + + DataStream<Row> points = + tEnv.toDataStream(inputs[0]) + .map(new FeaturesExtractor(getFeaturesCol(), getLabelCol())); + + DataStream<DenseVector> initModelData = + modelDataStream.map( + (MapFunction<LogisticRegressionModelData, DenseVector>) + value -> value.coefficient); + + initModelData.getTransformation().setParallelism(1); + + IterationBody body = + new FtrlIterationBody( + getGlobalBatchSize(), getAlpha(), getBeta(), getReg(), getElasticNet()); + + DataStream<LogisticRegressionModelData> onlineModelData = + Iterations.iterateUnboundedStreams( + DataStreamList.of(initModelData), DataStreamList.of(points), body) + .get(0); + + Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData); + OnlineLogisticRegressionModel model = + new OnlineLogisticRegressionModel().setModelData(onlineModelDataTable); + ReadWriteUtils.updateExistingParams(model, paramMap); + return model; + } + + private static class FeaturesExtractor implements MapFunction<Row, Row> { + private final String featuresCol; + private final String labelCol; + + private FeaturesExtractor(String featuresCol, String labelCol) { + this.featuresCol = featuresCol; + this.labelCol = labelCol; + } + + @Override + public Row map(Row row) throws Exception { + return Row.of(row.getField(featuresCol), row.getField(labelCol)); + } + } + + /** + * Implementation of ftrl optimizer. In this implementation, gradients are calculated in + * distributed workers and reduce to one gradient. The reduced gradient is used to update model + * by ftrl method. 
+ * + * <p>See https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Ftrl + * + * <p>todo: makes ftrl to be a common optimizer and place it in org.apache.flink.ml.common in + * future. + */ + private static class FtrlIterationBody implements IterationBody { + private final int batchSize; + private final double alpha; + private final double beta; + private final double l1; + private final double l2; + + public FtrlIterationBody( + int batchSize, double alpha, double beta, double reg, double elasticNet) { + this.batchSize = batchSize; + this.alpha = alpha; + this.beta = beta; + this.l1 = elasticNet * reg; + this.l2 = (1 - elasticNet) * reg; + } + + @Override + public IterationBodyResult process( + DataStreamList variableStreams, DataStreamList dataStreams) { + DataStream<DenseVector> modelData = variableStreams.get(0); + + DataStream<Row> points = dataStreams.get(0); + int parallelism = points.getParallelism(); + Preconditions.checkState( + parallelism <= batchSize, + "There are more subtasks in the training process than the number " + + "of elements in each batch. 
Some subtasks might be idling forever."); + + DataStream<DenseVector[]> newGradient = + DataStreamUtils.generateBatchData(points, parallelism, batchSize) + .connect(modelData.broadcast()) + .transform( + "LocalGradientCalculator", + TypeInformation.of(DenseVector[].class), + new CalculateLocalGradient()) + .setParallelism(parallelism) + .countWindowAll(parallelism) + .reduce( + (ReduceFunction<DenseVector[]>) + (gradientInfo, newGradientInfo) -> { + for (int i = 0; + i < newGradientInfo[1].size(); + ++i) { + newGradientInfo[0].values[i] = + gradientInfo[0].values[i] + + newGradientInfo[0].values[i]; + newGradientInfo[1].values[i] = + gradientInfo[1].values[i] + + newGradientInfo[1].values[i]; + if (newGradientInfo[2] == null) { + newGradientInfo[2] = gradientInfo[2]; + } + } + return newGradientInfo; + }); + DataStream<DenseVector> feedbackModelData = + newGradient + .transform( + "ModelDataUpdater", + TypeInformation.of(DenseVector.class), + new UpdateModel(alpha, beta, l1, l2)) + .setParallelism(1); + + DataStream<LogisticRegressionModelData> outputModelData = + feedbackModelData.map(new CreateLrModelData()).setParallelism(1); + return new IterationBodyResult( + DataStreamList.of(feedbackModelData), DataStreamList.of(outputModelData)); + } + } + + private static class CreateLrModelData + implements MapFunction<DenseVector, LogisticRegressionModelData>, CheckpointedFunction { + private Long modelVersion = 1L; + private transient ListState<Long> modelVersionState; + + @Override + public LogisticRegressionModelData map(DenseVector denseVector) throws Exception { + return new LogisticRegressionModelData(denseVector, modelVersion++); + } + + @Override + public void snapshotState(FunctionSnapshotContext functionSnapshotContext) + throws Exception { + modelVersionState.update(Collections.singletonList(modelVersion)); + } + + @Override + public void initializeState(FunctionInitializationContext context) throws Exception { + modelVersionState = + 
context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>("modelVersionState", Long.class)); + } + } + + /** Updates model. */ + private static class UpdateModel extends AbstractStreamOperator<DenseVector> + implements OneInputStreamOperator<DenseVector[], DenseVector> { + private ListState<double[]> nParamState; + private ListState<double[]> zParamState; + private final double alpha; + private final double beta; + private final double l1; + private final double l2; + private double[] nParam; + private double[] zParam; + + public UpdateModel(double alpha, double beta, double l1, double l2) { + this.alpha = alpha; + this.beta = beta; + this.l1 = l1; + this.l2 = l2; + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + nParamState = + context.getOperatorStateStore() + .getListState(new ListStateDescriptor<>("nParamState", double[].class)); + zParamState = + context.getOperatorStateStore() + .getListState(new ListStateDescriptor<>("zParamState", double[].class)); + } + + @Override + public void processElement(StreamRecord<DenseVector[]> streamRecord) throws Exception { + DenseVector[] gradientInfo = streamRecord.getValue(); + double[] coefficient = gradientInfo[2].values; + double[] g = gradientInfo[0].values; + for (int i = 0; i < g.length; ++i) { + if (gradientInfo[1].values[i] != 0.0) { + g[i] = g[i] / gradientInfo[1].values[i]; + } + } + if (zParam == null) { + zParam = new double[g.length]; + nParam = new double[g.length]; + nParamState.add(nParam); + zParamState.add(zParam); + } + + for (int i = 0; i < zParam.length; ++i) { + double sigma = (Math.sqrt(nParam[i] + g[i] * g[i]) - Math.sqrt(nParam[i])) / alpha; + zParam[i] += g[i] - sigma * coefficient[i]; + nParam[i] += g[i] * g[i]; + + if (Math.abs(zParam[i]) <= l1) { + coefficient[i] = 0.0; + } else { + coefficient[i] = + ((zParam[i] < 0 ? 
-1 : 1) * l1 - zParam[i]) + / ((beta + Math.sqrt(nParam[i])) / alpha + l2); + } + } + output.collect(new StreamRecord<>(new DenseVector(coefficient))); + } + } + + private static class CalculateLocalGradient extends AbstractStreamOperator<DenseVector[]> + implements TwoInputStreamOperator<Row[], DenseVector, DenseVector[]> { + private ListState<DenseVector> modelDataState; + private ListState<Row[]> localBatchDataState; + private double[] gradient; + private double[] weight; + private int[] denseVectorIndices; + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + modelDataState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>("modelData", DenseVector.class)); + TypeInformation<Row[]> type = + ObjectArrayTypeInfo.getInfoFor(TypeInformation.of(Row.class)); + localBatchDataState = + context.getOperatorStateStore() + .getListState(new ListStateDescriptor<>("localBatch", type)); + } + + @Override + public void processElement1(StreamRecord<Row[]> pointsRecord) throws Exception { + localBatchDataState.add(pointsRecord.getValue()); + calculateGradient(); + } + + private void calculateGradient() throws Exception { + if (!modelDataState.get().iterator().hasNext() + || !localBatchDataState.get().iterator().hasNext()) { + return; + } + DenseVector modelData = + OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get(); + modelDataState.clear(); + + List<Row[]> pointsList = IteratorUtils.toList(localBatchDataState.get().iterator()); + Row[] points = pointsList.remove(0); + localBatchDataState.update(pointsList); + + for (Row point : points) { + Vector vec = point.getFieldAs(0); + double label = point.getFieldAs(1); + if (gradient == null) { + gradient = new double[vec.size()]; + weight = new double[gradient.length]; + if (vec instanceof DenseVector) { + denseVectorIndices = new int[vec.size()]; + for (int i = 0; i < denseVectorIndices.length; ++i) { + 
denseVectorIndices[i] = i; + } + } + } + + int[] indices; + double[] values; + if (vec instanceof DenseVector) { + DenseVector denseVector = (DenseVector) vec; + indices = denseVectorIndices; + values = denseVector.values; + } else { + SparseVector sparseVector = (SparseVector) vec; + indices = sparseVector.indices; + values = sparseVector.values; + } + double p = 0.0; + for (int i = 0; i < indices.length; ++i) { + int idx = indices[i]; + p += modelData.values[idx] * values[i]; + } + p = 1 / (1 + Math.exp(-p)); + for (int i = 0; i < indices.length; ++i) { + int idx = indices[i]; + gradient[idx] += (p - label) * values[i]; + weight[idx] += 1.0; + } + } + + if (points.length > 0) { + output.collect( + new StreamRecord<>( + new DenseVector[] { + new DenseVector(gradient), + new DenseVector(weight), Review Comment: `weight` will be renamed to `weightSum`. As noted in the comment above, I am adding a `weightCol` param, so this variable is just the sum of the weights. I will add Javadoc for the code that reduces the gradients. In my code, the model data is used differently from SGD. In online LR, the model data is used in two places: calculating the gradient locally and updating the model serially. In SGD, the model update works differently from online LR. If I changed the model update to follow the SGD approach, a lot of code might need to be rewritten without any obvious benefit. So I will keep the current model output format. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org