lindong28 commented on code in PR #83: URL: https://github.com/apache/flink-ml/pull/83#discussion_r884756711
########## flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java: ########## @@ -0,0 +1,434 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.logisticregression; + +import org.apache.flink.api.common.functions.FilterFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.IterationBody; +import org.apache.flink.iteration.IterationBodyResult; +import org.apache.flink.iteration.Iterations; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vector; +import 
org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.collections.IteratorUtils; + +import java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan + * McMahan et al. + * + * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click + * prediction: a view from the trenches.</a> + */ +public class OnlineLogisticRegression + implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>, + OnlineLogisticRegressionParams<OnlineLogisticRegression> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private Table initModelDataTable; + + public OnlineLogisticRegression() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + @SuppressWarnings("unchecked") + public OnlineLogisticRegressionModel fit(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<LogisticRegressionModelData> modelDataStream = + LogisticRegressionModelData.getModelDataStream(initModelDataTable); + + DataStream<Row> points = + tEnv.toDataStream(inputs[0]) + .map(new FeaturesExtractor(getFeaturesCol(), getLabelCol())); + + DataStream<DenseVector> initModelData = + modelDataStream.map( + (MapFunction<LogisticRegressionModelData, DenseVector>) + value -> value.coefficient); Review Comment: I believe the model version does need to take into account the version of the input model data, so that we can determine the order of models across program restarts. Here is an example scenario: - Let's say the global batch size is 100 and there are 1000 records in the input datastream. The online training program should generate 10 model versions after processing these 1000 records. - The training process finished processing 300 records and generated 3 model data records with versions 1, 2, and 3. After successfully making a checkpoint, the process exited due to machine failure. - The training process is restarted from the last successful checkpoint. It should continue reading the input datastream starting from the 301st record. It should also read the latest model data generated before it was restarted. Ideally, we should hide the machine failure from users, meaning that the sequence of model versions should be 1, 2, 3, 4, ..., 10 as if the failure had never happened. Therefore, we have to set the initial model version to the model version from the input model data. Does this make sense? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org