lindong28 commented on code in PR #237:
URL: https://github.com/apache/flink-ml/pull/237#discussion_r1194570809


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/updater/ModelUpdater.java:
##########
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.updater;
+
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+/**
+ * A model updater that could be used to handle push/pull request from workers.
+ *
+ * <p>Note that model updater should also ensure that model data is robust to failures.
+ */
+public interface ModelUpdater extends Serializable {
+
+    /** Initialize the model data. */
+    void open(long startFeatureIndex, long endFeatureIndex);
+
+    /** Applies the push to update the model data, e.g., using gradient to update model. */
+    void handlePush(long[] keys, double[] values);
+
+    /** Applies the pull and return the retrieved model data. */
+    double[] handlePull(long[] keys);
+
+    /** Returns model pieces with the format of (startFeatureIdx, endFeatureIdx, modelValues). */
+    Iterator<Tuple3<Long, Long, double[]>> getModelPieces();
+
+    /** Recover the model data from state. */

Review Comment:
   It would be useful to make the comment style consistent. E.g. Recover -> Recovers.
   
   Same for other comments.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/updater/ModelUpdater.java:
##########
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.updater;
+
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+/**
+ * A model updater that could be used to handle push/pull request from workers.
+ *
+ * <p>Note that model updater should also ensure that model data is robust to failures.
+ */
+public interface ModelUpdater extends Serializable {
+
+    /** Initialize the model data. */
+    void open(long startFeatureIndex, long endFeatureIndex);
+
+    /** Applies the push to update the model data, e.g., using gradient to update model. */
+    void handlePush(long[] keys, double[] values);
+
+    /** Applies the pull and return the retrieved model data. */
+    double[] handlePull(long[] keys);
+
+    /** Returns model pieces with the format of (startFeatureIdx, endFeatureIdx, modelValues). */
+    Iterator<Tuple3<Long, Long, double[]>> getModelPieces();

Review Comment:
   It would be useful to know what the expected output of this API is w.r.t. the invocation of other APIs (e.g. handlePush).
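   
   For illustration only, a sketch of how that contract could be spelled out in the javadoc (exact wording up to you):
   
   ```java
   /**
    * Returns the model data as (startFeatureIdx, endFeatureIdx, modelValues) segments.
    * The returned values should reflect all updates that have been applied via
    * {@code handlePush} before this call.
    */
   Iterator<Tuple3<Long, Long, double[]>> getModelPieces();
   ```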



##########
flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java:
##########
@@ -81,11 +82,49 @@ public DataFrame transform(DataFrame input) {
     public LogisticRegressionModelServable setModelData(InputStream... modelDataInputs)
             throws IOException {
         Preconditions.checkArgument(modelDataInputs.length == 1);
+        List<LogisticRegressionModelData> modelPieces = new ArrayList<>();
+        while (true) {
+            try {
+                LogisticRegressionModelData piece =
+                        LogisticRegressionModelData.decode(modelDataInputs[0]);
+                modelPieces.add(piece);
+            } catch (IOException e) {
+                // Reached the end of model stream.
+                break;
+            }
+        }
 
-        modelData = LogisticRegressionModelData.decode(modelDataInputs[0]);
+        modelData = mergePieces(modelPieces);
         return this;
     }
 
+    @VisibleForTesting
+    public static LogisticRegressionModelData mergePieces(

Review Comment:
   Would it be more intuitive to put this method in `LogisticRegressionModelData`?
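   
   For example, a rough sketch of the moved method (the field names `coefficient`, `startIndex`, `endIndex` and `modelVersion` are my assumptions based on how the constructor is used elsewhere in this PR):
   
   ```java
   public static LogisticRegressionModelData merge(List<LogisticRegressionModelData> segments) {
       // Order the segments by start index, then concatenate their coefficients.
       segments.sort(Comparator.comparingLong(s -> s.startIndex));
       long dim = segments.get(segments.size() - 1).endIndex;
       double[] merged = new double[(int) dim];
       for (LogisticRegressionModelData segment : segments) {
           System.arraycopy(
                   segment.coefficient.values, 0,
                   merged, (int) segment.startIndex,
                   (int) (segment.endIndex - segment.startIndex));
       }
       return new LogisticRegressionModelData(
               Vectors.dense(merged), 0L, dim, segments.get(0).modelVersion);
   }
   ```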



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/updater/FTRL.java:
##########
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.updater;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+/** The FTRL model updater. */

Review Comment:
   Would it be useful to provide a doc or reference link to explain what FTRL is?
   
   Maybe something like https://github.com/Angel-ML/angel/blob/master/docs/algo/ftrl_lr_spark.md.
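   
   For example, a sketch of the class javadoc (the McMahan et al. paper is the standard reference for FTRL-Proximal):
   
   ```java
   /**
    * The FTRL (Follow-The-Regularized-Leader) model updater, which applies gradients
    * pushed by workers using the FTRL-Proximal update rule.
    *
    * <p>See H. Brendan McMahan et al., "Ad Click Prediction: a View from the Trenches",
    * KDD 2013.
    */
   ```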



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageType.java:
##########
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.message;
+
+/** Message Type between workers and servers. */
+public enum MessageType {
+    ZEROS_TO_PUSH((char) 0),

Review Comment:
   How about using the following names:
   
   PUSH_ZERO, PUSH_KV, PULL_INDICE, PULL_VALUE
   
   I am not sure what `zero` means in `PUSH_ZERO`. Should we rename it to something like `INITIALIZE_MODEL`?
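   
   For example (naming illustrative only; the `type` field name is also just an assumption):
   
   ```java
   /** Message Type between workers and servers. */
   public enum MessageType {
       INITIALIZE_MODEL((char) 0),
       PUSH_KV((char) 1),
       PULL_INDEX((char) 2),
       PULL_VALUE((char) 3);
   
       public final char type;
   
       MessageType(char type) {
           this.type = type;
       }
   }
   ```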
   
   



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/MirrorWorkerOperator.java:
##########
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.common.ps.message.ValuesPulledM;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * Merges the message from different servers for one pull request.
+ *
+ * <p>Note that for each single-thread worker, there are at exactly #numServers pieces for each pull
+ * request in the feedback edge.
+ */
+public class MirrorWorkerOperator extends AbstractStreamOperator<byte[]>

Review Comment:
   It is not clear what `mirror` means here. Maybe we can discuss offline.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/updater/ModelUpdater.java:
##########
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.updater;
+
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+/**
+ * A model updater that could be used to handle push/pull request from workers.
+ *
+ * <p>Note that model updater should also ensure that model data is robust to failures.
+ */
+public interface ModelUpdater extends Serializable {

Review Comment:
   Given that `ModelUpdater` is used only by classes in the package `org.apache.flink.ml.common.ps`, would it be better to move it to that package?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/updater/ModelUpdater.java:
##########
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.updater;
+
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+/**
+ * A model updater that could be used to handle push/pull request from workers.
+ *
+ * <p>Note that model updater should also ensure that model data is robust to failures.
+ */
+public interface ModelUpdater extends Serializable {
+
+    /** Initialize the model data. */
+    void open(long startFeatureIndex, long endFeatureIndex);
+
+    /** Applies the push to update the model data, e.g., using gradient to update model. */
+    void handlePush(long[] keys, double[] values);
+
+    /** Applies the pull and return the retrieved model data. */
+    double[] handlePull(long[] keys);

Review Comment:
   What is the relationship between this method and `handlePush`? For example, does this only handle `keys` that have been updated with `handlePush()`?
   
   If it works like a map, maybe re-use the API of `Map` so that it is more intuitive.
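   
   For example, if the semantics are indeed map-like, a sketch could be (method names illustrative):
   
   ```java
   /** Updates the values of the given keys, e.g., by applying a gradient. */
   void put(long[] keys, double[] values);
   
   /** Returns the current values of the given keys. */
   double[] get(long[] keys);
   ```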



##########
flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java:
##########
@@ -81,11 +82,49 @@ public DataFrame transform(DataFrame input) {
     public LogisticRegressionModelServable setModelData(InputStream... modelDataInputs)
             throws IOException {
         Preconditions.checkArgument(modelDataInputs.length == 1);
+        List<LogisticRegressionModelData> modelPieces = new ArrayList<>();
+        while (true) {
+            try {
+                LogisticRegressionModelData piece =
+                        LogisticRegressionModelData.decode(modelDataInputs[0]);

Review Comment:
   Other `XXXModelData#decode` methods will finish reading the given input stream and return a self-contained model data instance. We will break this convention by having `LogisticRegressionModelData.decode` return a segment of the full model data.
   
   Would it be simpler to have `LogisticRegressionModelData` maintain a list of `LogisticRegressionModelDataSegment` internally?
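   
   A rough sketch of that design (`LogisticRegressionModelDataSegment`, the constructor and the end-of-stream check are all hypothetical):
   
   ```java
   public class LogisticRegressionModelData {
       private final List<LogisticRegressionModelDataSegment> segments;
   
       private LogisticRegressionModelData(List<LogisticRegressionModelDataSegment> segments) {
           this.segments = segments;
       }
   
       /** Reads the whole stream and returns one self-contained model data instance. */
       public static LogisticRegressionModelData decode(InputStream inputStream) throws IOException {
           List<LogisticRegressionModelDataSegment> segments = new ArrayList<>();
           DataInputStream dataInput = new DataInputStream(inputStream);
           while (dataInput.available() > 0) { // hypothetical end-of-stream check
               segments.add(LogisticRegressionModelDataSegment.decode(dataInput));
           }
           return new LogisticRegressionModelData(segments);
       }
   }
   ```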



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java:
##########
@@ -0,0 +1,380 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledLargePointWithWeight;
+import org.apache.flink.ml.common.lossfunc.BinaryLogisticLoss;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.common.ps.training.IterationStageList;
+import org.apache.flink.ml.common.ps.training.ProcessStage;
+import org.apache.flink.ml.common.ps.training.PullStage;
+import org.apache.flink.ml.common.ps.training.PushStage;
+import org.apache.flink.ml.common.ps.training.SerializableConsumer;
+import org.apache.flink.ml.common.ps.training.TrainingContext;
+import org.apache.flink.ml.common.ps.training.TrainingUtils;
+import org.apache.flink.ml.common.updater.FTRL;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.runtime.util.ResettableIterator;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.function.SerializableFunction;
+import org.apache.flink.util.function.SerializableSupplier;
+
+import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap;
+import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the large scale logistic regression algorithm using FTRL optimizer.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Logistic_regression.
+ */
+public class LogisticRegressionWithFtrl

Review Comment:
   Since we keep both `LogisticRegressionWithFtrl` and `LogisticRegression` and both classes implement the same algorithm, I suppose these two algorithms have different pros/cons that address different use-cases.
   
   Can you provide information to help users decide which algorithm to use?
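   
   For example, something along these lines in the class javadoc (the exact trade-off would need to be confirmed by you):
   
   ```java
   /**
    * <p>Compared with {@link LogisticRegression}, this implementation partitions the
    * model across parameter servers instead of keeping a full copy on every worker,
    * so it targets high-dimensional (sparse) models that do not fit into the memory
    * of a single task, at the cost of a more complex runtime.
    */
   ```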



##########
flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java:
##########
@@ -81,11 +82,49 @@ public DataFrame transform(DataFrame input) {
     public LogisticRegressionModelServable setModelData(InputStream... modelDataInputs)
             throws IOException {
         Preconditions.checkArgument(modelDataInputs.length == 1);
+        List<LogisticRegressionModelData> modelPieces = new ArrayList<>();
+        while (true) {
+            try {
+                LogisticRegressionModelData piece =
+                        LogisticRegressionModelData.decode(modelDataInputs[0]);
+                modelPieces.add(piece);

Review Comment:
   It is probably more common and intuitive to use `segment` instead of `piece`.
   
   We can find a lot of classes in Flink with `segment` in the class name.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/ServerOperator.java:
##########
@@ -0,0 +1,301 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.ps.message.IndicesToPullM;
+import org.apache.flink.ml.common.ps.message.KVsToPushM;
+import org.apache.flink.ml.common.ps.message.MessageType;
+import org.apache.flink.ml.common.ps.message.MessageUtils;
+import org.apache.flink.ml.common.ps.message.ValuesPulledM;
+import org.apache.flink.ml.common.ps.message.ZerosToPushM;
+import org.apache.flink.ml.common.updater.ModelUpdater;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.SerializableObject;
+
+import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+
+/**
+ * The server operator maintains the shared parameters. It receives push/pull requests from {@link
+ * WorkerOperator} and sends the answer request to {@link MirrorWorkerOperator}. It works closely
+ * with {@link ModelUpdater} in the following way:
+ *
+ * <ul>
+ *   <li>The server operator deals with the message from workers and decide when to process the
+ *       received message. (i.e., synchronous vs. asynchronous).
+ *   <li>The server operator calls {@link ModelUpdater#handlePush(long[], double[])} and {@link
+ *       ModelUpdater#handlePull(long[])} to process the messages in detail.
+ *   <li>The server operator ensures that {@link ModelUpdater} is robust to failures.

Review Comment:
   Instead of using `robust to failures`, it might be simpler and more explicit to say something like this:
   
   The server operator triggers checkpoint for {@link ModelUpdater}.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java:
##########
@@ -0,0 +1,380 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledLargePointWithWeight;
+import org.apache.flink.ml.common.lossfunc.BinaryLogisticLoss;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.common.ps.training.IterationStageList;
+import org.apache.flink.ml.common.ps.training.ProcessStage;
+import org.apache.flink.ml.common.ps.training.PullStage;
+import org.apache.flink.ml.common.ps.training.PushStage;
+import org.apache.flink.ml.common.ps.training.SerializableConsumer;
+import org.apache.flink.ml.common.ps.training.TrainingContext;
+import org.apache.flink.ml.common.ps.training.TrainingUtils;
+import org.apache.flink.ml.common.updater.FTRL;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.runtime.util.ResettableIterator;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.function.SerializableFunction;
+import org.apache.flink.util.function.SerializableSupplier;
+
+import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap;
+import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the large scale logistic regression algorithm using FTRL optimizer.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Logistic_regression.
+ */
+public class LogisticRegressionWithFtrl
+        implements Estimator<LogisticRegressionWithFtrl, LogisticRegressionModel>,
+                LogisticRegressionWithFtrlParams<LogisticRegressionWithFtrl> {
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public LogisticRegressionWithFtrl() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public LogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        String classificationType = getMultiClass();
+        Preconditions.checkArgument(
+                "auto".equals(classificationType) || "binomial".equals(classificationType),
+                "Multinomial classification is not supported yet. Supported options: [auto, binomial].");
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<LabeledLargePointWithWeight> trainData =
+                tEnv.toDataStream(inputs[0])
+                        .map(
+                                (MapFunction<Row, LabeledLargePointWithWeight>)
+                                        dataPoint -> {
+                                            double weight =
+                                                    getWeightCol() == null
+                                                            ? 1.0
+                                                            : ((Number)
+                                                                            dataPoint.getField(
+                                                                                    getWeightCol()))
+                                                                    .doubleValue();
+                                            double label =
+                                                    ((Number) dataPoint.getField(getLabelCol()))
+                                                            .doubleValue();
+                                            boolean isBinomial =
+                                                    Double.compare(0., label) == 0
+                                                            || Double.compare(1., label) == 0;
+                                            if (!isBinomial) {
+                                                throw new RuntimeException(
+                                                        "Multinomial classification is not supported yet. Supported options: [auto, binomial].");
+                                            }
+                                            Tuple2<long[], double[]> features =
+                                                    dataPoint.getFieldAs(getFeaturesCol());
+                                            return new LabeledLargePointWithWeight(
+                                                    features, label, weight);
+                                        });
+
+        DataStream<Long> modelDim;
+        if (getModelDim() > 0) {
+            modelDim = trainData.getExecutionEnvironment().fromElements(getModelDim());
+        } else {
+            modelDim =
+                    DataStreamUtils.reduce(
+                                    trainData.map(x -> x.features.f0[x.features.f0.length - 1]),
+                                    (ReduceFunction<Long>) Math::max)
+                            .map((MapFunction<Long, Long>) value -> value + 1);
+        }
+
+        LogisticRegressionWithFtrlTrainingContext trainingContext =
+                new LogisticRegressionWithFtrlTrainingContext(getParamMap());
+
+        IterationStageList<LogisticRegressionWithFtrlTrainingContext> iterationStages =
+                new IterationStageList<>(trainingContext);
+        iterationStages
+                .addTrainingStage(new ComputeIndices())
+                .addTrainingStage(
+                        new PullStage(
+                                (SerializableSupplier<long[]>) () -> trainingContext.pullIndices,
+                                (SerializableConsumer<double[]>)
+                                        x -> trainingContext.pulledValues = x))
+                .addTrainingStage(new ComputeGradients(BinaryLogisticLoss.INSTANCE))
+                .addTrainingStage(
+                        new PushStage(
+                                (SerializableSupplier<long[]>) () -> trainingContext.pushIndices,
+                                (SerializableSupplier<double[]>) () -> trainingContext.pushValues))
+                .setTerminationCriteria(
+                        (SerializableFunction<LogisticRegressionWithFtrlTrainingContext, Boolean>)
+                                o -> o.iterationId >= getMaxIter());
+        FTRL ftrl = new FTRL(getAlpha(), getBeta(), getReg(), getElasticNet());
+
+        DataStream<Tuple3<Long, Long, double[]>> rawModelData =
+                TrainingUtils.train(
+                        modelDim,
+                        trainData,
+                        ftrl,
+                        iterationStages,
+                        getNumServers(),
+                        getNumServerCores());
+
+        final long modelVersion = 0L;
+
+        DataStream<LogisticRegressionModelData> modelData =
+                rawModelData.map(
+                        tuple3 ->
+                                new LogisticRegressionModelData(
+                                        Vectors.dense(tuple3.f2),
+                                        tuple3.f0,
+                                        tuple3.f1,
+                                        modelVersion));
+
+        LogisticRegressionModel model =
+                new LogisticRegressionModel().setModelData(tEnv.fromDataStream(modelData));
+        ParamUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static LogisticRegressionWithFtrl load(StreamTableEnvironment tEnv, String path)
+            throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+}
+
+/**
+ * An iteration stage that samples a batch of training data and computes the indices needed to
+ * compute gradients.
+ */
+class ComputeIndices extends ProcessStage<LogisticRegressionWithFtrlTrainingContext> {
+
+    @Override
+    public void process(LogisticRegressionWithFtrlTrainingContext context) throws Exception {
+        context.readInNextBatchData();
+        context.pullIndices = computeIndices(context.batchData);
+    }
+
+    public static long[] computeIndices(List<LabeledLargePointWithWeight> dataPoints) {
+        LongOpenHashSet indices = new LongOpenHashSet();
+        for (LabeledLargePointWithWeight dataPoint : dataPoints) {
+            long[] notZeros = dataPoint.features.f0;
+            for (long index : notZeros) {
+                indices.add(index);
+            }
+        }
+
+        long[] sortedIndices = new long[indices.size()];
+        Iterator<Long> iterator = indices.iterator();
+        int i = 0;
+        while (iterator.hasNext()) {
+            sortedIndices[i++] = iterator.next();
+        }
+        Arrays.sort(sortedIndices);
+        return sortedIndices;
+    }
+}
+
+/**
+ * An iteration stage that uses the pulled model values and sampled batch data to compute the
+ * gradients.
+ */
+class ComputeGradients extends ProcessStage<LogisticRegressionWithFtrlTrainingContext> {

Review Comment:
   Since APIs of this class may be invoked directly outside `LogisticRegressionWithFtrl`, it seems more conventional and readable to move this class outside `LogisticRegressionWithFtrl`.
   



##########
flink-ml-servable-core/src/main/java/org/apache/flink/ml/common/feature/LabeledLargePointWithWeight.java:
##########
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.feature;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+
+/** A data point to represent values that use long as index and double as values. */
+public class LabeledLargePointWithWeight {
+    public Tuple2<long[], double[]> features;

Review Comment:
   Can you explain why we can't re-use `LabeledPointWithWeight`?
   
   If the features presented here encode a sparse vector, then we should be able to re-use `LabeledPointWithWeight` because `LabeledPointWithWeight#features` can be a SparseVector.
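   
   For reference, a sketch of the re-use (this assumes the feature indices fit into `int`, since `SparseVector` uses int indices; if long indices are the whole point here, it would be good to state that in the javadoc):
   
   ```java
   // assuming the long indices have been checked to fit into int:
   SparseVector features = Vectors.sparse(modelDim, intIndices, values);
   LabeledPointWithWeight dataPoint = new LabeledPointWithWeight(features, label, weight);
   ```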



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingContext.java:
##########
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.training;
+
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.runtime.util.ResettableIterator;
+
+import java.io.Serializable;
+
+/**
+ * Stores the context information that is alive during the training process. Note that the context
+ * information will be updated by each {@link IterationStage}.
+ *
+ * <p>Note that subclasses should take care of the snapshot of object stored in {@link
+ * TrainingContext} if the object satisfies that: the write-process is followed by an {@link
+ * PullStage}, which is later again read by other stages.
+ */
+public interface TrainingContext extends Serializable {

Review Comment:
   `context` is typically used for APIs that get states rather than writing states.
   
   Would it be more intuitive to name it `IterationStageListener`?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/IterationStageList.java:
##########
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.training;
+
+import org.apache.flink.util.function.SerializableFunction;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.Function;
+
+/**
+ * A list of iteration stages to express the logic of an iterative machine learning training
+ * process.
+ */
+public class IterationStageList<T extends TrainingContext> implements Serializable {
+    public final T context;
+    public Function<T, Boolean> shouldTerminate;
+    public List<IterationStage> stageList;
+
+    public IterationStageList(T context) {
+        this.stageList = new ArrayList<>();
+        this.context = context;
+    }
+
+    /** Sets the criteria of termination. */
+    public void setTerminationCriteria(SerializableFunction<T, Boolean> shouldTerminate) {
+        this.shouldTerminate = shouldTerminate;
+    }
+
+    /** Adds an iteration stage into the stage list. */
+    public IterationStageList<T> addTrainingStage(IterationStage stage) {

Review Comment:
   Given that the class name is `IterationStageList`, would it be simpler to name the method `add(...)` or `addStage(...)`?
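   
   The call sites would then read, for example:
   
   ```java
   iterationStages
           .addStage(new ComputeIndices())
           .addStage(new ComputeGradients(BinaryLogisticLoss.INSTANCE));
   ```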



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java:
##########
@@ -0,0 +1,380 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledLargePointWithWeight;
+import org.apache.flink.ml.common.lossfunc.BinaryLogisticLoss;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.common.ps.training.IterationStageList;
+import org.apache.flink.ml.common.ps.training.ProcessStage;
+import org.apache.flink.ml.common.ps.training.PullStage;
+import org.apache.flink.ml.common.ps.training.PushStage;
+import org.apache.flink.ml.common.ps.training.SerializableConsumer;
+import org.apache.flink.ml.common.ps.training.TrainingContext;
+import org.apache.flink.ml.common.ps.training.TrainingUtils;
+import org.apache.flink.ml.common.updater.FTRL;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.runtime.util.ResettableIterator;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.function.SerializableFunction;
+import org.apache.flink.util.function.SerializableSupplier;
+
+import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap;
+import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the large scale logistic regression algorithm using FTRL optimizer.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Logistic_regression.
+ */
+public class LogisticRegressionWithFtrl
+        implements Estimator<LogisticRegressionWithFtrl, LogisticRegressionModel>,
+                LogisticRegressionWithFtrlParams<LogisticRegressionWithFtrl> {
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public LogisticRegressionWithFtrl() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public LogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        String classificationType = getMultiClass();
+        Preconditions.checkArgument(
+                "auto".equals(classificationType) || "binomial".equals(classificationType),
+                "Multinomial classification is not supported yet. Supported options: [auto, binomial].");
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<LabeledLargePointWithWeight> trainData =
+                tEnv.toDataStream(inputs[0])
+                        .map(
+                                (MapFunction<Row, LabeledLargePointWithWeight>)
+                                        dataPoint -> {
+                                            double weight =
+                                                    getWeightCol() == null
+                                                            ? 1.0
+                                                            : ((Number)
+                                                                            dataPoint.getField(
+                                                                                    getWeightCol()))
+                                                                    .doubleValue();
+                                            double label =
+                                                    ((Number) dataPoint.getField(getLabelCol()))
+                                                            .doubleValue();
+                                            boolean isBinomial =
+                                                    Double.compare(0., label) == 0
+                                                            || Double.compare(1., label) == 0;
+                                            if (!isBinomial) {
+                                                throw new RuntimeException(
+                                                        "Multinomial classification is not supported yet. Supported options: [auto, binomial].");
+                                            }
+                                            Tuple2<long[], double[]> features =
+                                                    dataPoint.getFieldAs(getFeaturesCol());
+                                            return new LabeledLargePointWithWeight(
+                                                    features, label, weight);
+                                        });
+
+        DataStream<Long> modelDim;
+        if (getModelDim() > 0) {
+            modelDim = trainData.getExecutionEnvironment().fromElements(getModelDim());
+        } else {
+            modelDim =
+                    DataStreamUtils.reduce(
+                                    trainData.map(x -> x.features.f0[x.features.f0.length - 1]),
+                                    (ReduceFunction<Long>) Math::max)
+                            .map((MapFunction<Long, Long>) value -> value + 1);
+        }
+
+        LogisticRegressionWithFtrlTrainingContext trainingContext =
+                new LogisticRegressionWithFtrlTrainingContext(getParamMap());
+
+        IterationStageList<LogisticRegressionWithFtrlTrainingContext> iterationStages =
+                new IterationStageList<>(trainingContext);
+        iterationStages
+                .addTrainingStage(new ComputeIndices())
+                .addTrainingStage(
+                        new PullStage(
+                                (SerializableSupplier<long[]>) () -> trainingContext.pullIndices,
+                                (SerializableConsumer<double[]>)
+                                        x -> trainingContext.pulledValues = x))
+                .addTrainingStage(new ComputeGradients(BinaryLogisticLoss.INSTANCE))
+                .addTrainingStage(
+                        new PushStage(
+                                (SerializableSupplier<long[]>) () -> trainingContext.pushIndices,
+                                (SerializableSupplier<double[]>) () -> trainingContext.pushValues))
+                .setTerminationCriteria(
+                        (SerializableFunction<LogisticRegressionWithFtrlTrainingContext, Boolean>)
+                                o -> o.iterationId >= getMaxIter());
+        FTRL ftrl = new FTRL(getAlpha(), getBeta(), getReg(), getElasticNet());
+
+        DataStream<Tuple3<Long, Long, double[]>> rawModelData =
+                TrainingUtils.train(
+                        modelDim,
+                        trainData,
+                        ftrl,
+                        iterationStages,
+                        getNumServers(),
+                        getNumServerCores());
+
+        final long modelVersion = 0L;
+
+        DataStream<LogisticRegressionModelData> modelData =
+                rawModelData.map(
+                        tuple3 ->
+                                new LogisticRegressionModelData(
+                                        Vectors.dense(tuple3.f2),
+                                        tuple3.f0,
+                                        tuple3.f1,
+                                        modelVersion));
+
+        LogisticRegressionModel model =
+                new LogisticRegressionModel().setModelData(tEnv.fromDataStream(modelData));
+        ParamUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static LogisticRegressionWithFtrl load(StreamTableEnvironment tEnv, String path)
+            throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+}
+
+/**
+ * An iteration stage that samples a batch of training data and computes the indices needed to
+ * compute gradients.
+ */
+class ComputeIndices extends ProcessStage<LogisticRegressionWithFtrlTrainingContext> {
+
+    @Override
+    public void process(LogisticRegressionWithFtrlTrainingContext context) throws Exception {
+        context.readInNextBatchData();
+        context.pullIndices = computeIndices(context.batchData);
+    }
+
+    public static long[] computeIndices(List<LabeledLargePointWithWeight> dataPoints) {
+        LongOpenHashSet indices = new LongOpenHashSet();
+        for (LabeledLargePointWithWeight dataPoint : dataPoints) {
+            long[] notZeros = dataPoint.features.f0;
+            for (long index : notZeros) {
+                indices.add(index);
+            }
+        }
+
+        long[] sortedIndices = new long[indices.size()];
+        Iterator<Long> iterator = indices.iterator();
+        int i = 0;
+        while (iterator.hasNext()) {
+            sortedIndices[i++] = iterator.next();
+        }
+        Arrays.sort(sortedIndices);
+        return sortedIndices;
+    }
+}
+
+/**
+ * An iteration stage that uses the pulled model values and sampled batch data to compute the
+ * gradients.
+ */
+class ComputeGradients extends ProcessStage<LogisticRegressionWithFtrlTrainingContext> {
+    private final LossFunc lossFunc;
+
+    public ComputeGradients(LossFunc lossFunc) {
+        this.lossFunc = lossFunc;
+    }
+
+    @Override
+    public void process(LogisticRegressionWithFtrlTrainingContext context) throws IOException {
+        long[] indices = ComputeIndices.computeIndices(context.batchData);
+        double[] pulledModelValues = context.pulledValues;
+        double[] gradients = computeGradient(context.batchData, indices, pulledModelValues);
+
+        context.pushIndices = indices;
+        context.pushValues = gradients;
+    }
+
+    private double[] computeGradient(
+            List<LabeledLargePointWithWeight> batchData,
+            long[] sortedBatchIndices,
+            double[] pulledModelValues) {
+        Long2DoubleOpenHashMap coefficient = new Long2DoubleOpenHashMap(sortedBatchIndices.length);
+        for (int i = 0; i < sortedBatchIndices.length; i++) {
+            coefficient.put(sortedBatchIndices[i], pulledModelValues[i]);
+        }
+        Long2DoubleOpenHashMap cumGradients = new Long2DoubleOpenHashMap(sortedBatchIndices.length);
+
+        for (LabeledLargePointWithWeight dataPoint : batchData) {
+            double dot = dot(dataPoint.features, coefficient);
+            double multiplier = lossFunc.computeGradient(dataPoint.label, dot) * dataPoint.weight;
+
+            long[] featureIndices = dataPoint.features.f0;
+            double[] featureValues = dataPoint.features.f1;
+            double z;
+            for (int i = 0; i < featureIndices.length; i++) {
+                long currentIndex = featureIndices[i];
+                z = featureValues[i] * multiplier + cumGradients.getOrDefault(currentIndex, 0.);
+                cumGradients.put(currentIndex, z);
+            }
+        }
+        double[] cumGradientValues = new double[sortedBatchIndices.length];
+        for (int i = 0; i < sortedBatchIndices.length; i++) {
+            cumGradientValues[i] = cumGradients.get(sortedBatchIndices[i]);
+        }
+        return cumGradientValues;
+    }
+
+    private static double dot(
+            Tuple2<long[], double[]> features, Long2DoubleOpenHashMap coefficient) {
+        double dot = 0;
+        for (int i = 0; i < features.f0.length; i++) {
+            dot += features.f1[i] * coefficient.get(features.f0[i]);
+        }
+        return dot;
+    }
+}
+
+/** The context information of local computing process. */
+class LogisticRegressionWithFtrlTrainingContext

Review Comment:
   Would it be more intuitive to name it something like `FtrlIterationStageState`?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/ZerosToPushM.java:
##########
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.message;
+
+import org.apache.flink.ml.util.Bits;
+import org.apache.flink.util.Preconditions;
+
+/**
+ * Message sent by worker to server that initializes the model as a dense array with defined range.
+ */
+public class ZerosToPushM implements Message {
+    public final int workerId;
+    public final int serverId;
+    public final long startIndex;
+    public final long endIndex;
+
+    public static final MessageType MESSAGE_TYPE = MessageType.ZEROS_TO_PUSH;

Review Comment:
   Can we make this field `private` or even remove this field for simplicity?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageUtils.java:
##########
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.message;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.util.Bits;
+
+/** Utility functions for processing messages. */
+public class MessageUtils {
+
+    /** Retrieves the message type from the byte array. */
+    public static MessageType getMessageType(byte[] bytesData) {
+        char type = Bits.getChar(bytesData, 0);
+        return MessageType.valueOf(type);
+    }
+
+    /** Reads a long array from the byte array starting from the given offset. */
+    public static long[] readLongArray(byte[] bytesData, int offset) {
+        int size = Bits.getInt(bytesData, offset);
+        offset += Integer.BYTES;
+        long[] result = new long[size];
+        for (int i = 0; i < size; i++) {
+            result[i] = Bits.getLong(bytesData, offset);
+            offset += Long.BYTES;
+        }
+        return result;
+    }
+
+    /**
+     * Writes a long array to the byte array starting from the given offset.
+     *
+     * @return the next position to write on.
+     */
+    public static int writeLongArray(long[] array, byte[] bytesData, int offset) {
+        Bits.putInt(bytesData, offset, array.length);
+        offset += Integer.BYTES;
+        for (int i = 0; i < array.length; i++) {
+            Bits.putLong(bytesData, offset, array[i]);
+            offset += Long.BYTES;
+        }
+        return offset;
+    }
+
+    /** Returns the size of a long array in bytes. */
+    public static int getLongArraySizeInBytes(long[] array) {
+        return Integer.BYTES + array.length * Long.BYTES;
+    }
+
+    /** Reads a double array from the byte array starting from the given offset. */
+    public static double[] readDoubleArray(byte[] bytesData, int offset) {
+        int size = Bits.getInt(bytesData, offset);
+        offset += Integer.BYTES;
+        double[] result = new double[size];
+        for (int i = 0; i < size; i++) {
+            result[i] = Bits.getDouble(bytesData, offset);
+            offset += Long.BYTES;
+        }
+        return result;
+    }
+
+    /**
+     * Writes a double array to the byte array starting from the given offset.
+     *
+     * @return the next position to write on.
+     */
+    public static int writeDoubleArray(double[] array, byte[] bytesData, int offset) {

Review Comment:
   Would it be simpler to move these methods to `Bits.java` and make the method and parameter names consistent with the existing methods in `Bits`?
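   
   For example, candidate signatures mirroring the existing `Bits#getLong` / `Bits#putLong` naming (bodies would stay as they are here):
   
   ```java
   // in Bits.java
   long[] getLongArray(byte[] b, int off);
   int putLongArray(byte[] b, int off, long[] val);
   ```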



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/IndicesToPullM.java:
##########
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.message;
+
+import org.apache.flink.ml.util.Bits;
+import org.apache.flink.util.Preconditions;
+
+/** The indices one worker needs to pull from servers. */
+public class IndicesToPullM implements Message {
+    public final int serverId;
+    public final int workerId;
+    public final long[] indicesToPull;
+
+    public static final MessageType MESSAGE_TYPE = MessageType.INDICES_TO_PULL;
+
+    public IndicesToPullM(int serverId, int workerId, long[] indicesToPull) {
+        this.serverId = serverId;
+        this.workerId = workerId;
+        this.indicesToPull = indicesToPull;
+    }
+
+    public static IndicesToPullM fromBytes(byte[] bytesData) {

Review Comment:
   It seems simpler to rename `bytesData` as `bytes`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org
