[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #237: [Flink-27826] Support training very high dimensional logistic regression

2023-05-29 Thread via GitHub


zhipeng93 commented on code in PR #237:
URL: https://github.com/apache/flink-ml/pull/237#discussion_r1209245957


##
flink-ml-lib/src/main/java/org/apache/flink/ml/common/updater/ModelUpdater.java:
##
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.updater;
+
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+/**
+ * A model updater that could be used to handle push/pull request from workers.
+ *
+ * Note that model updater should also ensure that model data is robust to failures.
+ */
+public interface ModelUpdater extends Serializable {
+
+/** Initialize the model data. */
+void open(long startFeatureIndex, long endFeatureIndex);
+
+/** Applies the push to update the model data, e.g., using gradient to update model. */
+void handlePush(long[] keys, double[] values);
+
+/** Applies the pull and return the retrieved model data. */
+double[] handlePull(long[] keys);

Review Comment:
   In this PR, we propose two types of roles to describe the iterative 
machine learning training process, following the idea of parameter servers. 
   - WorkerOp stores the training data and involves only local computation 
logic. When it needs to access model parameters, which requires distributed 
communication, it communicates with ServerOp via the `push/pull` primitives. A 
`push/pull` could carry sparse key-value pairs or dense values. Currently only 
sparse key-value pairs are supported.
   - ServerOp stores the model parameters and provides access to them for WorkerOps.
   - Subtasks of WorkerOp cannot talk to each other. Subtasks of ServerOp 
cannot talk to each other.
   
   `handlePush` and `handlePull` are the two operations with which the server 
answers requests from workers.
   The naming follows that of `push/pull`. It is possible, but not necessary, 
that `handlePush` handles keys that have already been updated with `handlePush`.
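
   To make the server-side contract above concrete, here is a minimal sketch of an updater for one server subtask. It is not the implementation in this PR: the class name `DenseSgdUpdaterSketch`, the dense in-memory segment, and the plain SGD step are all assumptions chosen for illustration. Only the three method signatures mirror the quoted `ModelUpdater` interface.

```java
import java.util.Arrays;

/** Hypothetical updater that owns the dense key range [startFeatureIndex, endFeatureIndex). */
class DenseSgdUpdaterSketch {
    private final double learningRate;
    private long startFeatureIndex;
    private double[] segment;

    DenseSgdUpdaterSketch(double learningRate) {
        this.learningRate = learningRate;
    }

    /** Allocates a zero-initialized segment for the keys this server is responsible for. */
    void open(long startFeatureIndex, long endFeatureIndex) {
        this.startFeatureIndex = startFeatureIndex;
        this.segment = new double[(int) (endFeatureIndex - startFeatureIndex)];
    }

    /** Treats the pushed values as gradients and applies one SGD step per key. */
    void handlePush(long[] keys, double[] values) {
        for (int i = 0; i < keys.length; i++) {
            segment[(int) (keys[i] - startFeatureIndex)] -= learningRate * values[i];
        }
    }

    /** Returns the current parameter for each requested key, in request order. */
    double[] handlePull(long[] keys) {
        double[] result = new double[keys.length];
        for (int i = 0; i < keys.length; i++) {
            result[i] = segment[(int) (keys[i] - startFeatureIndex)];
        }
        return result;
    }

    public static void main(String[] args) {
        DenseSgdUpdaterSketch server = new DenseSgdUpdaterSketch(0.5);
        server.open(100L, 104L);
        // A worker pushes sparse gradients for keys 101 and 103, then pulls three keys.
        server.handlePush(new long[] {101L, 103L}, new double[] {2.0, -4.0});
        System.out.println(Arrays.toString(server.handlePull(new long[] {100L, 101L, 103L})));
    }
}
```

   In this sketch a pull may request keys that were never pushed (key 100 stays at its initial value), which is one reading of why pushed and pulled key sets need not coincide.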



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org



[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #237: [Flink-27826] Support training very high dimensional logistic regression

2023-05-29 Thread via GitHub


zhipeng93 commented on code in PR #237:
URL: https://github.com/apache/flink-ml/pull/237#discussion_r1209250004


##
flink-ml-lib/src/main/java/org/apache/flink/ml/common/updater/ModelUpdater.java:
##
@@ -0,0 +1,52 @@
+
+package org.apache.flink.ml.common.updater;
+
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+/**
+ * A model updater that could be used to handle push/pull request from workers.
+ *
+ * Note that model updater should also ensure that model data is robust to failures.
+ */
+public interface ModelUpdater extends Serializable {
+
+/** Initialize the model data. */
+void open(long startFeatureIndex, long endFeatureIndex);
+
+/** Applies the push to update the model data, e.g., using gradient to update model. */
+void handlePush(long[] keys, double[] values);
+
+/** Applies the pull and return the retrieved model data. */
+double[] handlePull(long[] keys);
+
+/** Returns model pieces with the format of (startFeatureIdx, endFeatureIdx, modelValues). */
+Iterator<Tuple3<Long, Long, double[]>> getModelPieces();

Review Comment:
   The model segments are continuously updated/retrieved by push/pull (i.e., 
`handlePush` and `handlePull`).
   
   I have renamed `pieces` to `segments` in the PR and added the above 
description to the Javadoc.
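
   As an illustration of the (startFeatureIdx, endFeatureIdx, modelValues) format, the sketch below lazily cuts a dense model into segments. The classes `ModelSegmentSketch` and `SegmentedModelSketch`, the `maxSegmentSize` parameter, and the choice of an exclusive end index are assumptions for this example and are not taken from the PR.

```java
import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;

/** Hypothetical stand-in for a (startFeatureIdx, endFeatureIdx, modelValues) triple. */
class ModelSegmentSketch {
    final long startFeatureIdx; // inclusive
    final long endFeatureIdx; // exclusive (an assumption of this sketch)
    final double[] modelValues;

    ModelSegmentSketch(long startFeatureIdx, long endFeatureIdx, double[] modelValues) {
        this.startFeatureIdx = startFeatureIdx;
        this.endFeatureIdx = endFeatureIdx;
        this.modelValues = modelValues;
    }
}

class SegmentedModelSketch {
    /** Lazily splits a dense model into segments of at most maxSegmentSize values. */
    static Iterator<ModelSegmentSketch> getModelSegments(
            final long startFeatureIndex, final double[] model, final int maxSegmentSize) {
        if (maxSegmentSize <= 0) {
            throw new IllegalArgumentException("maxSegmentSize must be positive");
        }
        return new Iterator<ModelSegmentSketch>() {
            private int offset = 0;

            @Override
            public boolean hasNext() {
                return offset < model.length;
            }

            @Override
            public ModelSegmentSketch next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                int end = Math.min(offset + maxSegmentSize, model.length);
                // Only one segment copy is materialized at a time.
                ModelSegmentSketch segment = new ModelSegmentSketch(
                        startFeatureIndex + offset,
                        startFeatureIndex + end,
                        Arrays.copyOfRange(model, offset, end));
                offset = end;
                return segment;
            }
        };
    }
}
```

   Returning an iterator rather than a collection lets the caller write each segment out before the next one is created.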






[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #237: [Flink-27826] Support training very high dimensional logistic regression

2023-05-29 Thread via GitHub


zhipeng93 commented on code in PR #237:
URL: https://github.com/apache/flink-ml/pull/237#discussion_r1209256528


##
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java:
##
@@ -0,0 +1,380 @@
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledLargePointWithWeight;
+import org.apache.flink.ml.common.lossfunc.BinaryLogisticLoss;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.common.ps.training.IterationStageList;
+import org.apache.flink.ml.common.ps.training.ProcessStage;
+import org.apache.flink.ml.common.ps.training.PullStage;
+import org.apache.flink.ml.common.ps.training.PushStage;
+import org.apache.flink.ml.common.ps.training.SerializableConsumer;
+import org.apache.flink.ml.common.ps.training.TrainingContext;
+import org.apache.flink.ml.common.ps.training.TrainingUtils;
+import org.apache.flink.ml.common.updater.FTRL;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.runtime.util.ResettableIterator;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.function.SerializableFunction;
+import org.apache.flink.util.function.SerializableSupplier;
+
+import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap;
+import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the large scale logistic regression algorithm using FTRL optimizer.
+ *
+ * See https://en.wikipedia.org/wiki/Logistic_regression.
+ */
+public class LogisticRegressionWithFtrl

Review Comment:
   Good point. The current implementations of `LogisticRegressionWithFtrl` and 
`LogisticRegression` employ different optimizers, and I would like to update 
the implementation of `LogisticRegression` in a later PR.
   
   I would like to abstract an optimizer for the existing implementations of 
LogisticRegression, LinearSVC, etc. There are two possible options:
   - For each combination of optimizer and model, we construct a new Estimator, 
for example `LogisticRegressionWithSGD`, `LogisticRegressionWithFTRL`, 
`LogisticRegressionWithLBFGS`, etc.
   - We abstract the optimizer as a parameter of LogisticRegression. We have only 
one Estimator pair for each model and let users set different optimizers.
   
   I think option 2 is more intuitive, but we can talk offline about this.
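
   A rough sketch of what the second option could look like follows. None of these types exist in Flink ML: `OptimizerSketch`, `SgdSketch`, and `LogisticRegressionSketch` are hypothetical names, and the stateless `step` signature is a simplification (an optimizer such as FTRL would need per-coordinate state).

```java
/** Hypothetical optimizer abstraction; not an existing Flink ML interface. */
interface OptimizerSketch {
    /** Returns the updated coefficient given its current value and gradient. */
    double step(double coefficient, double gradient);
}

/** Plain gradient descent with a fixed learning rate. */
class SgdSketch implements OptimizerSketch {
    private final double learningRate;

    SgdSketch(double learningRate) {
        this.learningRate = learningRate;
    }

    @Override
    public double step(double coefficient, double gradient) {
        return coefficient - learningRate * gradient;
    }
}

/** One estimator per model; the optimizer is just another parameter. */
class LogisticRegressionSketch {
    private OptimizerSketch optimizer = new SgdSketch(0.1);

    LogisticRegressionSketch setOptimizer(OptimizerSketch optimizer) {
        this.optimizer = optimizer;
        return this;
    }

    /** Runs one update on a single example with a label in {0, 1}. */
    double[] trainStep(double[] coefficients, double[] features, double label) {
        double dot = 0.0;
        for (int i = 0; i < coefficients.length; i++) {
            dot += coefficients[i] * features[i];
        }
        double prediction = 1.0 / (1.0 + Math.exp(-dot));
        double[] updated = new double[coefficients.length];
        for (int i = 0; i < coefficients.length; i++) {
            // Gradient of the logistic loss with respect to coefficient i.
            double gradient = (prediction - label) * features[i];
            updated[i] = optimizer.step(coefficients[i], gradient);
        }
        return updated;
    }
}
```

   With this shape, switching optimizers is a call such as `new LogisticRegressionSketch().setOptimizer(new SgdSketch(1.0))` rather than a new Estimator class per combination.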






[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #237: [Flink-27826] Support training very high dimensional logistic regression

2023-05-29 Thread via GitHub


zhipeng93 commented on code in PR #237:
URL: https://github.com/apache/flink-ml/pull/237#discussion_r1209259740


##
flink-ml-servable-core/src/main/java/org/apache/flink/ml/common/feature/LabeledLargePointWithWeight.java:
##
@@ -0,0 +1,40 @@
+
+package org.apache.flink.ml.common.feature;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+
+/** A data point to represent values that use long as index and double as values. */
+public class LabeledLargePointWithWeight {
+public Tuple2 features;

Review Comment:
   `SparseVector` currently supports only `int` indices. However, the range of 
`int` cannot meet the requirements of high-dimensional data.
   
   It is a bit tricky to extend `SparseVector` to support `long` indices.
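
   For illustration, a sparse feature vector with `long` indices can be as small as the sketch below. `LongSparseVectorSketch` is a hypothetical class and the `Map<Long, Double>` coefficient store is an assumption made to keep the example dependency-free; the PR itself imports fastutil maps.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical sparse vector whose feature indices are longs instead of ints. */
class LongSparseVectorSketch {
    final long[] indices; // feature indices, may exceed Integer.MAX_VALUE
    final double[] values; // values[i] belongs to indices[i]

    LongSparseVectorSketch(long[] indices, double[] values) {
        if (indices.length != values.length) {
            throw new IllegalArgumentException("indices and values must have the same length");
        }
        this.indices = indices;
        this.values = values;
    }

    /** Dot product against sparse coefficients; missing coefficients count as zero. */
    double dot(Map<Long, Double> coefficients) {
        double sum = 0.0;
        for (int i = 0; i < indices.length; i++) {
            Double coefficient = coefficients.get(indices[i]);
            if (coefficient != null) {
                sum += coefficient * values[i];
            }
        }
        return sum;
    }

    public static void main(String[] args) {
        // 3_000_000_000 does not fit in an int (max 2_147_483_647).
        LongSparseVectorSketch features =
                new LongSparseVectorSketch(new long[] {1L, 3_000_000_000L}, new double[] {2.0, 4.0});
        Map<Long, Double> coefficients = new HashMap<>();
        coefficients.put(1L, 1.0);
        coefficients.put(3_000_000_000L, 0.5);
        System.out.println(features.dot(coefficients));
    }
}
```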






[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #237: [Flink-27826] Support training very high dimensional logistic regression

2023-05-29 Thread via GitHub


zhipeng93 commented on code in PR #237:
URL: https://github.com/apache/flink-ml/pull/237#discussion_r1209262254


##
flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java:
##
@@ -81,11 +82,49 @@ public DataFrame transform(DataFrame input) {
 public LogisticRegressionModelServable setModelData(InputStream... modelDataInputs)
 throws IOException {
 Preconditions.checkArgument(modelDataInputs.length == 1);
+List<LogisticRegressionModelData> modelPieces = new ArrayList<>();
+while (true) {
+try {
+LogisticRegressionModelData piece =
+LogisticRegressionModelData.decode(modelDataInputs[0]);

Review Comment:
   Storing all segments in a list would probably lead to an OOM here. When 
dealing with large models, we probably need to partition them into segments and 
store them one by one.
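
   One way to avoid holding every segment at once is to decode and hand off one segment at a time, as in the sketch below. The binary layout (a `long` start index, an `int` length, then that many doubles) and the names `SegmentStreamSketch` and `SegmentConsumerSketch` are assumptions for this example; this is not the format used by `LogisticRegressionModelData.decode`.

```java
import java.io.DataInputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;

/** Hypothetical callback invoked once per decoded segment. */
interface SegmentConsumerSketch {
    void accept(long startIndex, double[] values);
}

class SegmentStreamSketch {
    /**
     * Decodes segments one at a time and passes each to the consumer, so at most one
     * segment is held in memory. Assumed layout per segment: long startIndex, int length,
     * then `length` doubles. Returns the number of segments read.
     */
    static int forEachSegment(InputStream input, SegmentConsumerSketch consumer)
            throws IOException {
        DataInputStream in = new DataInputStream(input);
        int count = 0;
        while (true) {
            long startIndex;
            try {
                startIndex = in.readLong();
            } catch (EOFException endOfStream) {
                return count; // stream ended at a segment boundary
            }
            double[] values = new double[in.readInt()];
            for (int i = 0; i < values.length; i++) {
                values[i] = in.readDouble();
            }
            consumer.accept(startIndex, values);
            count++;
        }
    }
}
```

   The consumer could write each segment to its final storage location, after which the array becomes eligible for garbage collection before the next segment is read.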






[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #237: [Flink-27826] Support training very high dimensional logistic regression

2023-05-29 Thread via GitHub


zhipeng93 commented on code in PR #237:
URL: https://github.com/apache/flink-ml/pull/237#discussion_r1209268250


##
flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/ZerosToPushM.java:
##
@@ -0,0 +1,76 @@
+
+package org.apache.flink.ml.common.ps.message;
+
+import org.apache.flink.ml.util.Bits;
+import org.apache.flink.util.Preconditions;
+
+/**
+ * Message sent by worker to server that initializes the model as a dense array with defined range.
+ */
+public class ZerosToPushM implements Message {
+public final int workerId;
+public final int serverId;
+public final long startIndex;
+public final long endIndex;
+
+public static final MessageType MESSAGE_TYPE = MessageType.ZEROS_TO_PUSH;

Review Comment:
   Thanks for pointing this out. It has been removed.






[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #237: [Flink-27826] Support training very high dimensional logistic regression

2023-05-29 Thread via GitHub


zhipeng93 commented on code in PR #237:
URL: https://github.com/apache/flink-ml/pull/237#discussion_r1209272423


##
flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/MirrorWorkerOperator.java:
##
@@ -0,0 +1,122 @@
+
+package org.apache.flink.ml.common.ps;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.common.ps.message.ValuesPulledM;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * Merges the message from different servers for one pull request.
+ *
+ * Note that for each single-thread worker, there are at exactly #numServers pieces for each pull
+ * request in the feedback edge.
+ */
+public class MirrorWorkerOperator extends AbstractStreamOperator

Review Comment:
   I named it `mirror` for the following two reasons:
   - This operator merges the answers from servers and feeds them to workers. It 
is doing the concatenation for workers.
   - It is colocated with the Worker operator.
   
   I am open to other names.
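
   The merging described above can be sketched as follows. This is not the operator's actual code: `PullResponseMergerSketch` is a hypothetical class, and it assumes that each server owns a contiguous key range in ascending server-id order, so sorting the pieces by server id and concatenating them restores key order.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

/** Hypothetical merger that assembles one pull response from numServers pieces. */
class PullResponseMergerSketch {
    private static final class Piece {
        final int serverId;
        final double[] values;

        Piece(int serverId, double[] values) {
            this.serverId = serverId;
            this.values = values;
        }
    }

    private final int numServers;
    private final List<Piece> buffer = new ArrayList<>();

    PullResponseMergerSketch(int numServers) {
        this.numServers = numServers;
    }

    /** Buffers one piece; returns the merged values once all pieces arrived, else null. */
    double[] onPiece(int serverId, double[] values) {
        buffer.add(new Piece(serverId, values));
        if (buffer.size() < numServers) {
            return null;
        }
        // Pieces may arrive in any order, so restore server order before concatenating.
        buffer.sort(Comparator.comparingInt(p -> p.serverId));
        int total = 0;
        for (Piece piece : buffer) {
            total += piece.values.length;
        }
        double[] merged = new double[total];
        int offset = 0;
        for (Piece piece : buffer) {
            System.arraycopy(piece.values, 0, merged, offset, piece.values.length);
            offset += piece.values.length;
        }
        buffer.clear(); // ready for the next pull request
        return merged;
    }
}
```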






[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #237: [Flink-27826] Support training very high dimensional logistic regression

2023-05-29 Thread via GitHub


zhipeng93 commented on code in PR #237:
URL: https://github.com/apache/flink-ml/pull/237#discussion_r1209272423


##
flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/MirrorWorkerOperator.java:
##
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.common.ps.message.ValuesPulledM;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * Merges the message from different servers for one pull request.
+ *
+ * Note that for each single-thread worker, there are at exactly #numServers pieces for each pull
+ * request in the feedback edge.
+ */
+public class MirrorWorkerOperator extends AbstractStreamOperator

Review Comment:
   I named it `mirror` for the following two reasons:
   - This operator merges the answers from the servers and feeds them to the workers. It is 
doing the concatenation on behalf of the workers.
   - It is colocated with the Worker operator.
   
   I am also open to other names.





[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #237: [Flink-27826] Support training very high dimensional logistic regression

2023-05-29 Thread via GitHub


zhipeng93 commented on code in PR #237:
URL: https://github.com/apache/flink-ml/pull/237#discussion_r1209272423


##
flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/MirrorWorkerOperator.java:
##
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.common.ps.message.ValuesPulledM;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * Merges the message from different servers for one pull request.
+ *
+ * Note that for each single-thread worker, there are at exactly #numServers pieces for each pull
+ * request in the feedback edge.
+ */
+public class MirrorWorkerOperator extends AbstractStreamOperator

Review Comment:
   I named it `mirror` for the following two reasons:
   - This operator merges/concatenates the answers from the servers and feeds them to the 
workers. In the traditional parameter server architecture, the concatenation 
happens on the workers.
   - It is colocated with the Worker operator.
   
   I am also open to other names.






[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #237: [Flink-27826] Support training very high dimensional logistic regression

2023-05-29 Thread via GitHub


zhipeng93 commented on code in PR #237:
URL: https://github.com/apache/flink-ml/pull/237#discussion_r1209272423


##
flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/MirrorWorkerOperator.java:
##
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.common.ps.message.ValuesPulledM;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * Merges the message from different servers for one pull request.
+ *
+ * Note that for each single-thread worker, there are at exactly #numServers pieces for each pull
+ * request in the feedback edge.
+ */
+public class MirrorWorkerOperator extends AbstractStreamOperator

Review Comment:
   I named it `mirror` for the following two reasons:
   - This operator merges/concatenates the answers from ServerOperator and feeds them 
to WorkerOperator. In the traditional parameter server architecture, the 
concatenation happens on the workers.
   - It is colocated with the Worker operator.
   
   I am also open to other names.
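
The merging role described in this comment can be sketched roughly as below. This is a minimal illustration and not the PR's implementation: `PulledPiece` and `PullMerger` are hypothetical names, and it assumes that each server answers with one piece tagged by its server id and that pieces are concatenated in server-id order.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

/** Hypothetical piece of a pull answer: the values held by one server. */
final class PulledPiece {
    final int serverId;
    final double[] values;

    PulledPiece(int serverId, double[] values) {
        this.serverId = serverId;
        this.values = values;
    }
}

/** Buffers pieces until every server has answered, then concatenates them. */
final class PullMerger {
    private final int numServers;
    private final List<PulledPiece> buffer = new ArrayList<>();

    PullMerger(int numServers) {
        this.numServers = numServers;
    }

    /** Returns the merged answer once numServers pieces have arrived, otherwise null. */
    double[] add(PulledPiece piece) {
        buffer.add(piece);
        if (buffer.size() < numServers) {
            return null;
        }
        // Assumption: ordering by server id matches the order of the key ranges.
        buffer.sort(Comparator.comparingInt(p -> p.serverId));
        int total = 0;
        for (PulledPiece p : buffer) {
            total += p.values.length;
        }
        double[] merged = new double[total];
        int offset = 0;
        for (PulledPiece p : buffer) {
            System.arraycopy(p.values, 0, merged, offset, p.values.length);
            offset += p.values.length;
        }
        buffer.clear();
        return merged;
    }
}
```

With this shape the worker only ever sees one complete answer per pull request, which is the part of the worker's job that the operator takes over.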






[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #237: [Flink-27826] Support training very high dimensional logistic regression

2023-05-29 Thread via GitHub


zhipeng93 commented on code in PR #237:
URL: https://github.com/apache/flink-ml/pull/237#discussion_r1209650931


##
flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageType.java:
##
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.message;
+
+/** Message Type between workers and servers. */
+public enum MessageType {
+ZEROS_TO_PUSH((char) 0),

Review Comment:
   Thanks for pointing this out. I have renamed them to 
`INITIALIZE_MODEL_AS_ZERO`, `PULL_INDEX`, `PULLED_VALUE` and `PUSH_KV`, and also added 
Javadoc for these enum values.
   
   Note that `PULLED_VALUE` carries the `-ed` suffix because, unlike the other 
message types, it is sent from servers to workers. What do you think?
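
For illustration, the renamed enum could look roughly like the sketch below. The constant names come from the comment above. The char codes, the `type` field and the class name `MessageTypeSketch` are assumptions made for this example, modelled on the `(char) 0` constructor argument and the `MessageType.valueOf(type)` call visible in the quoted diffs.

```java
/** Sketch of the renamed message types; the char codes are assumed values. */
enum MessageTypeSketch {
    /** Sent from workers to servers to initialize the model as zeros. */
    INITIALIZE_MODEL_AS_ZERO((char) 0),
    /** Sent from workers to servers: the indices to pull. */
    PULL_INDEX((char) 1),
    /** Sent from servers to workers: the values that were pulled. */
    PULLED_VALUE((char) 2),
    /** Sent from workers to servers: the key-value pairs to push. */
    PUSH_KV((char) 3);

    final char type;

    MessageTypeSketch(char type) {
        this.type = type;
    }

    /** Looks up the message type for a char code read from a message header. */
    static MessageTypeSketch valueOf(char type) {
        for (MessageTypeSketch t : values()) {
            if (t.type == type) {
                return t;
            }
        }
        throw new IllegalArgumentException("Unknown message type: " + (int) type);
    }
}
```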






[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #237: [Flink-27826] Support training very high dimensional logistic regression

2023-05-29 Thread via GitHub


zhipeng93 commented on code in PR #237:
URL: https://github.com/apache/flink-ml/pull/237#discussion_r1209661397


##
flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageUtils.java:
##
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.message;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.util.Bits;
+
+/** Utility functions for processing messages. */
+public class MessageUtils {
+
+/** Retrieves the message type from the byte array. */
+public static MessageType getMessageType(byte[] bytesData) {
+char type = Bits.getChar(bytesData, 0);
+return MessageType.valueOf(type);
+}
+
+/** Reads a long array from the byte array starting from the given offset. */
+public static long[] readLongArray(byte[] bytesData, int offset) {
+int size = Bits.getInt(bytesData, offset);
+offset += Integer.BYTES;
+long[] result = new long[size];
+for (int i = 0; i < size; i++) {
+result[i] = Bits.getLong(bytesData, offset);
+offset += Long.BYTES;
+}
+return result;
+}
+
+/**
+ * Writes a long array to the byte array starting from the given offset.
+ *
+ * @return the next position to write on.
+ */
+public static int writeLongArray(long[] array, byte[] bytesData, int offset) {
+Bits.putInt(bytesData, offset, array.length);
+offset += Integer.BYTES;
+for (int i = 0; i < array.length; i++) {
+Bits.putLong(bytesData, offset, array[i]);
+offset += Long.BYTES;
+}
+return offset;
+}
+
+/** Returns the size of a long array in bytes. */
+public static int getLongArraySizeInBytes(long[] array) {
+return Integer.BYTES + array.length * Long.BYTES;
+}
+
+/** Reads a double array from the byte array starting from the given offset. */
+public static double[] readDoubleArray(byte[] bytesData, int offset) {
+int size = Bits.getInt(bytesData, offset);
+offset += Integer.BYTES;
+double[] result = new double[size];
+for (int i = 0; i < size; i++) {
+result[i] = Bits.getDouble(bytesData, offset);
+offset += Long.BYTES;
+}
+return result;
+}
+
+/**
+ * Writes a double array to the byte array starting from the given offset.
+ *
+ * @return the next position to write on.
+ */
+public static int writeDoubleArray(double[] array, byte[] bytesData, int offset) {

Review Comment:
   `Bits.java` is copied from `java.io.Bits` without any modifications. 
Moreover, `writeDoubleArray` is only used in the ps-infra, so shall we keep it 
here for now?
   
   I have updated the names to follow the convention of `Bits.java` and renamed 
the functions to `putXX` and `getXX`.
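
As a rough illustration of the `putXX`/`getXX` naming, here is a sketch of the length-prefixed long-array encoding from the quoted diff. It is not the PR's code: `ArrayCodecSketch` is a hypothetical class name, and `java.nio.ByteBuffer` is swapped in for the project's `Bits` helper so that the example is self-contained.

```java
import java.nio.ByteBuffer;

/** Sketch of length-prefixed array encoding with put/get naming. */
final class ArrayCodecSketch {

    /** Returns the encoded size: a 4-byte length prefix plus 8 bytes per element. */
    static int getLongArraySizeInBytes(long[] array) {
        return Integer.BYTES + array.length * Long.BYTES;
    }

    /** Writes the array at the given offset and returns the next position to write on. */
    static int putLongArray(long[] array, byte[] bytes, int offset) {
        ByteBuffer buffer = ByteBuffer.wrap(bytes);
        buffer.putInt(offset, array.length);
        offset += Integer.BYTES;
        for (long value : array) {
            buffer.putLong(offset, value);
            offset += Long.BYTES;
        }
        return offset;
    }

    /** Reads an array that was written by putLongArray at the given offset. */
    static long[] getLongArray(byte[] bytes, int offset) {
        ByteBuffer buffer = ByteBuffer.wrap(bytes);
        int size = buffer.getInt(offset);
        offset += Integer.BYTES;
        long[] result = new long[size];
        for (int i = 0; i < size; i++) {
            result[i] = buffer.getLong(offset);
            offset += Long.BYTES;
        }
        return result;
    }
}
```

Returning the next write position from `putLongArray` lets a caller chain several arrays into one message buffer without recomputing offsets.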






[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #237: [Flink-27826] Support training very high dimensional logistic regression

2023-05-30 Thread via GitHub


zhipeng93 commented on code in PR #237:
URL: https://github.com/apache/flink-ml/pull/237#discussion_r1209803567


##
flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/training/TrainingContext.java:
##
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.training;
+
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.runtime.util.ResettableIterator;
+
+import java.io.Serializable;
+
+/**
+ * Stores the context information that is alive during the training process. Note that the context
+ * information will be updated by each {@link IterationStage}.
+ *
+ * Note that subclasses should take care of the snapshot of object stored in {@link
+ * TrainingContext} if the object satisfies that: the write-process is followed by an {@link
+ * PullStage}, which is later again read by other stages.
+ */
+public interface TrainingContext extends Serializable {

Review Comment:
   It is indeed not a `listener`. How about we rename it as `MLSession`?
   
   We can talk about this offline.






[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #237: [Flink-27826] Support training very high dimensional logistic regression

2023-05-30 Thread via GitHub


zhipeng93 commented on code in PR #237:
URL: https://github.com/apache/flink-ml/pull/237#discussion_r1209843509


##
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java:
##
@@ -0,0 +1,380 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledLargePointWithWeight;
+import org.apache.flink.ml.common.lossfunc.BinaryLogisticLoss;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.common.ps.training.IterationStageList;
+import org.apache.flink.ml.common.ps.training.ProcessStage;
+import org.apache.flink.ml.common.ps.training.PullStage;
+import org.apache.flink.ml.common.ps.training.PushStage;
+import org.apache.flink.ml.common.ps.training.SerializableConsumer;
+import org.apache.flink.ml.common.ps.training.TrainingContext;
+import org.apache.flink.ml.common.ps.training.TrainingUtils;
+import org.apache.flink.ml.common.updater.FTRL;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.runtime.util.ResettableIterator;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.function.SerializableFunction;
+import org.apache.flink.util.function.SerializableSupplier;
+
+import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap;
+import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the large scale logistic regression algorithm using FTRL optimizer.
+ *
+ * See https://en.wikipedia.org/wiki/Logistic_regression.
+ */
+public class LogisticRegressionWithFtrl
+implements Estimator,
+LogisticRegressionWithFtrlParams {
+
+private final Map, Object> paramMap = new HashMap<>();
+
+public LogisticRegressionWithFtrl() {
+ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+}
+
+@Override
+public LogisticRegressionModel fit(Table... inputs) {
+Preconditions.checkArgument(inputs.length == 1);
+String classificationType = getMultiClass();
+Preconditions.checkArgument(
+"auto".equals(classificationType) || "binomial".equals(classificationType),
+"Multinomial classification is not supported yet. Supported options: [auto, binomial].");
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+DataStream trainData =
+tEnv.toDataStream(inputs[0])
+.map(
+(MapFunction)
+dataPoint -> {
+double weight =
+getWeightCol() == null
+? 1.0
+

[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #237: [Flink-27826] Support training very high dimensional logistic regression

2023-05-30 Thread via GitHub


zhipeng93 commented on code in PR #237:
URL: https://github.com/apache/flink-ml/pull/237#discussion_r1209952491


##
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java:
##
@@ -0,0 +1,380 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledLargePointWithWeight;
+import org.apache.flink.ml.common.lossfunc.BinaryLogisticLoss;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.common.ps.training.IterationStageList;
+import org.apache.flink.ml.common.ps.training.ProcessStage;
+import org.apache.flink.ml.common.ps.training.PullStage;
+import org.apache.flink.ml.common.ps.training.PushStage;
+import org.apache.flink.ml.common.ps.training.SerializableConsumer;
+import org.apache.flink.ml.common.ps.training.TrainingContext;
+import org.apache.flink.ml.common.ps.training.TrainingUtils;
+import org.apache.flink.ml.common.updater.FTRL;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.runtime.util.ResettableIterator;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.function.SerializableFunction;
+import org.apache.flink.util.function.SerializableSupplier;
+
+import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap;
+import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the large scale logistic regression algorithm using FTRL optimizer.
+ *
+ * See https://en.wikipedia.org/wiki/Logistic_regression.
+ */
+public class LogisticRegressionWithFtrl
+implements Estimator,
+LogisticRegressionWithFtrlParams {
+
+private final Map, Object> paramMap = new HashMap<>();
+
+public LogisticRegressionWithFtrl() {
+ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+}
+
+@Override
+public LogisticRegressionModel fit(Table... inputs) {
+Preconditions.checkArgument(inputs.length == 1);
+String classificationType = getMultiClass();
+Preconditions.checkArgument(
+"auto".equals(classificationType) || "binomial".equals(classificationType),
+"Multinomial classification is not supported yet. Supported options: [auto, binomial].");
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+DataStream trainData =
+tEnv.toDataStream(inputs[0])
+.map(
+(MapFunction)
+dataPoint -> {
+double weight =
+getWeightCol() == null
+? 1.0
+

[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #237: [Flink-27826] Support training very high dimensional logistic regression

2023-06-07 Thread via GitHub


zhipeng93 commented on code in PR #237:
URL: https://github.com/apache/flink-ml/pull/237#discussion_r1222525107


##
flink-ml-lib/src/main/java/org/apache/flink/ml/common/updater/ModelUpdater.java:
##
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.updater;
+
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+/**
+ * A model updater that could be used to handle push/pull request from workers.
+ *
+ * Note that model updater should also ensure that model data is robust to 
failures.
+ */
+public interface ModelUpdater extends Serializable {
+
+/** Initialize the model data. */
+void open(long startFeatureIndex, long endFeatureIndex);
+
+/** Applies the push to update the model data, e.g., using gradient to update model. */
+void handlePush(long[] keys, double[] values);
+
+/** Applies the pull and return the retrieved model data. */
+double[] handlePull(long[] keys);

Review Comment:
   After some offline discussion, we agreed to rename the functions to `update` 
and `get`.
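
To make the rename concrete, here is a minimal sketch of what the two methods could look like after the change. Only the names `update` and `get` come from this comment. The interface name `ModelUpdaterSketch`, the `SgdUpdaterSketch` class and its plain gradient-descent rule are assumptions for illustration, and the state-related methods of the real interface are left out.

```java
import java.util.HashMap;
import java.util.Map;

/** Sketch of the renamed interface; only the two renamed methods are shown. */
interface ModelUpdaterSketch {
    /** Applies the pushed values to the model data, e.g. gradients. */
    void update(long[] keys, double[] values);

    /** Returns the model data stored for the given keys. */
    double[] get(long[] keys);
}

/** Hypothetical in-memory updater that applies plain gradient descent. */
final class SgdUpdaterSketch implements ModelUpdaterSketch {
    private final Map<Long, Double> model = new HashMap<>();
    private final double learningRate;

    SgdUpdaterSketch(double learningRate) {
        this.learningRate = learningRate;
    }

    @Override
    public void update(long[] keys, double[] values) {
        for (int i = 0; i < keys.length; i++) {
            // weight <- weight - learningRate * gradient, starting from zero.
            model.merge(keys[i], -learningRate * values[i], Double::sum);
        }
    }

    @Override
    public double[] get(long[] keys) {
        double[] result = new double[keys.length];
        for (int i = 0; i < keys.length; i++) {
            // Keys that were never updated read as zero.
            result[i] = model.getOrDefault(keys[i], 0.0);
        }
        return result;
    }
}
```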






[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #237: [Flink-27826] Support training very high dimensional logistic regression

2023-06-07 Thread via GitHub


zhipeng93 commented on code in PR #237:
URL: https://github.com/apache/flink-ml/pull/237#discussion_r1222525949


##
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java:
##
@@ -0,0 +1,380 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledLargePointWithWeight;
+import org.apache.flink.ml.common.lossfunc.BinaryLogisticLoss;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.common.ps.training.IterationStageList;
+import org.apache.flink.ml.common.ps.training.ProcessStage;
+import org.apache.flink.ml.common.ps.training.PullStage;
+import org.apache.flink.ml.common.ps.training.PushStage;
+import org.apache.flink.ml.common.ps.training.SerializableConsumer;
+import org.apache.flink.ml.common.ps.training.TrainingContext;
+import org.apache.flink.ml.common.ps.training.TrainingUtils;
+import org.apache.flink.ml.common.updater.FTRL;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.runtime.util.ResettableIterator;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.function.SerializableFunction;
+import org.apache.flink.util.function.SerializableSupplier;
+
+import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap;
+import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the large scale logistic regression algorithm using FTRL optimizer.
+ *
+ * See https://en.wikipedia.org/wiki/Logistic_regression.
+ */
+public class LogisticRegressionWithFtrl

Review Comment:
   We have agreed to go with option-2 since it is more intuitive and simpler 
for users.






[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #237: [Flink-27826] Support training very high dimensional logistic regression

2023-06-07 Thread via GitHub


zhipeng93 commented on code in PR #237:
URL: https://github.com/apache/flink-ml/pull/237#discussion_r1209650931


##
flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageType.java:
##
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.message;
+
+/** Message Type between workers and servers. */
+public enum MessageType {
+ZEROS_TO_PUSH((char) 0),

Review Comment:
   Thanks for pointing this out. I have refactored the messages and renamed them to `INITIALIZE, PUSH, PULL, ALL_REDUCE`.
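
   For illustration only, the renamed types could be laid out as below. This is a minimal sketch, not the code in the PR: the name `MessageTypeSketch` and the char codes 0 to 3 are assumptions. Only the `(char)` tag and the `valueOf(char)` lookup pattern come from the quoted diff.

```java
/** Hypothetical sketch: message types tagged with a char code (the codes are assumed). */
enum MessageTypeSketch {
    INITIALIZE((char) 0),
    PUSH((char) 1),
    PULL((char) 2),
    ALL_REDUCE((char) 3);

    /** Code written at the head of each serialized message. */
    final char type;

    MessageTypeSketch(char type) {
        this.type = type;
    }

    /** Looks up the message type for a char code read from a message. */
    static MessageTypeSketch valueOf(char type) {
        for (MessageTypeSketch candidate : values()) {
            if (candidate.type == type) {
                return candidate;
            }
        }
        throw new IllegalArgumentException("Unknown message type code: " + (int) type);
    }
}
```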






[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #237: [Flink-27826] Support training very high dimensional logistic regression

2023-06-07 Thread via GitHub


zhipeng93 commented on code in PR #237:
URL: https://github.com/apache/flink-ml/pull/237#discussion_r1222527288


##
flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/message/MessageUtils.java:
##
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps.message;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.util.Bits;
+
+/** Utility functions for processing messages. */
+public class MessageUtils {
+
+/** Retrieves the message type from the byte array. */
+public static MessageType getMessageType(byte[] bytesData) {
+char type = Bits.getChar(bytesData, 0);
+return MessageType.valueOf(type);
+}
+
+/** Reads a long array from the byte array starting from the given offset. */
+public static long[] readLongArray(byte[] bytesData, int offset) {
+int size = Bits.getInt(bytesData, offset);
+offset += Integer.BYTES;
+long[] result = new long[size];
+for (int i = 0; i < size; i++) {
+result[i] = Bits.getLong(bytesData, offset);
+offset += Long.BYTES;
+}
+return result;
+}
+
+/**
+ * Writes a long array to the byte array starting from the given offset.
+ *
+ * @return the next position to write on.
+ */
+public static int writeLongArray(long[] array, byte[] bytesData, int offset) {
+Bits.putInt(bytesData, offset, array.length);
+offset += Integer.BYTES;
+for (int i = 0; i < array.length; i++) {
+Bits.putLong(bytesData, offset, array[i]);
+offset += Long.BYTES;
+}
+return offset;
+}
+
+/** Returns the size of a long array in bytes. */
+public static int getLongArraySizeInBytes(long[] array) {
+return Integer.BYTES + array.length * Long.BYTES;
+}
+
+/** Reads a double array from the byte array starting from the given offset. */
+public static double[] readDoubleArray(byte[] bytesData, int offset) {
+int size = Bits.getInt(bytesData, offset);
+offset += Integer.BYTES;
+double[] result = new double[size];
+for (int i = 0; i < size; i++) {
+result[i] = Bits.getDouble(bytesData, offset);
+offset += Long.BYTES;
+}
+return result;
+}
+
+/**
+ * Writes a double array to the byte array starting from the given offset.
+ *
+ * @return the next position to write on.
+ */
+public static int writeDoubleArray(double[] array, byte[] bytesData, int offset) {

Review Comment:
   After some offline discussion, we agreed to move these methods to our own version of `Bits.java`.
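
   As a rough sketch of what such helpers might look like once moved, the snippet below keeps the same layout as the methods above (a 4-byte length prefix followed by 8 bytes per element). `BitsSketch` and its method names are hypothetical, and it uses `java.nio.ByteBuffer` rather than the project's `Bits` primitives, whose exact signatures are not shown in this thread.

```java
import java.nio.ByteBuffer;

/** Hypothetical sketch of length-prefixed long-array helpers; not the actual Bits.java. */
final class BitsSketch {
    private BitsSketch() {}

    /** Bytes needed for a long array: a 4-byte length prefix plus 8 bytes per element. */
    static int getLongArraySizeInBytes(long[] array) {
        return Integer.BYTES + array.length * Long.BYTES;
    }

    /** Writes the array at the given offset and returns the next position to write on. */
    static int putLongArray(long[] array, byte[] bytes, int offset) {
        ByteBuffer buffer = ByteBuffer.wrap(bytes, offset, getLongArraySizeInBytes(array));
        buffer.putInt(array.length);
        for (long value : array) {
            buffer.putLong(value);
        }
        // The buffer position is an absolute index into the backing array.
        return buffer.position();
    }

    /** Reads a length-prefixed long array starting from the given offset. */
    static long[] getLongArray(byte[] bytes, int offset) {
        ByteBuffer buffer = ByteBuffer.wrap(bytes, offset, bytes.length - offset);
        long[] result = new long[buffer.getInt()];
        for (int i = 0; i < result.length; i++) {
            result[i] = buffer.getLong();
        }
        return result;
    }
}
```

   Returning the next write position lets a caller chain several arrays into one message buffer without recomputing offsets.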






[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #237: [Flink-27826] Support training very high dimensional logistic regression

2023-06-07 Thread via GitHub


zhipeng93 commented on code in PR #237:
URL: https://github.com/apache/flink-ml/pull/237#discussion_r1222527899


##
flink-ml-lib/src/main/java/org/apache/flink/ml/common/ps/MirrorWorkerOperator.java:
##
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.ps;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.common.ps.message.ValuesPulledM;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Preconditions;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * Merges the message from different servers for one pull request.
+ *
+ * Note that for each single-thread worker, there are exactly #numServers pieces for each pull
+ * request in the feedback edge.
+ */
+public class MirrorWorkerOperator extends AbstractStreamOperator

Review Comment:
   After some offline discussion, we agreed to rename the operator to `ResponseAssemblerOperator`.
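
   To make the new name concrete, here is a minimal, framework-free sketch of the assembling step: buffer the pieces of one pull response until all `numServers` of them have arrived, then merge them. The class `ResponseAssemblerSketch`, the `Piece` type, and the assumption that pieces concatenate in server-id order are all hypothetical; the real operator also has to keep its buffer in Flink state, which is omitted here.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

/** Hypothetical sketch of assembling one pull response from per-server pieces. */
final class ResponseAssemblerSketch {

    /** One piece of a pull response; the fields are assumptions for illustration. */
    static final class Piece {
        final int serverId;
        final double[] values;

        Piece(int serverId, double[] values) {
            this.serverId = serverId;
            this.values = values;
        }
    }

    private final int numServers;
    private final List<Piece> buffered = new ArrayList<>();

    ResponseAssemblerSketch(int numServers) {
        this.numServers = numServers;
    }

    /** Buffers a piece; returns the merged values once all pieces arrived, else null. */
    double[] onPiece(Piece piece) {
        buffered.add(piece);
        if (buffered.size() < numServers) {
            return null;
        }
        // Assumed: concatenating pieces in server-id order restores the request order.
        buffered.sort(Comparator.comparingInt(p -> p.serverId));
        int total = 0;
        for (Piece p : buffered) {
            total += p.values.length;
        }
        double[] merged = new double[total];
        int position = 0;
        for (Piece p : buffered) {
            System.arraycopy(p.values, 0, merged, position, p.values.length);
            position += p.values.length;
        }
        buffered.clear();
        return merged;
    }
}
```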






[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #237: [Flink-27826] Support training very high dimensional logistic regression

2023-06-07 Thread via GitHub


zhipeng93 commented on code in PR #237:
URL: https://github.com/apache/flink-ml/pull/237#discussion_r1222529588


##
flink-ml-servable-core/src/main/java/org/apache/flink/ml/common/feature/LabeledLargePointWithWeight.java:
##
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.feature;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+
+/** A data point to represent values that use long as index and double as values. */
+public class LabeledLargePointWithWeight {
+public Tuple2<long[], double[]> features;

Review Comment:
   After some offline discussion, we agreed that we should extend the current implementation of `Vector` and let it support `SparseLongDoubleVector`.
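
   A sparse vector with long indices might look roughly like the sketch below. It is only an illustration of the idea: `SparseLongDoubleVectorSketch`, its fields, and its methods are hypothetical and do not reflect the final Flink ML API.

```java
import java.util.Arrays;

/** Hypothetical sketch of a sparse vector with long indices; not the Flink ML class. */
final class SparseLongDoubleVectorSketch {
    final long size;
    final long[] indices; // strictly increasing, each in [0, size)
    final double[] values;

    SparseLongDoubleVectorSketch(long size, long[] indices, double[] values) {
        if (indices.length != values.length) {
            throw new IllegalArgumentException("indices and values must have the same length");
        }
        for (int i = 0; i < indices.length; i++) {
            boolean inRange = indices[i] >= 0 && indices[i] < size;
            boolean increasing = i == 0 || indices[i - 1] < indices[i];
            if (!inRange || !increasing) {
                throw new IllegalArgumentException("indices must be increasing and in range");
            }
        }
        this.size = size;
        this.indices = indices;
        this.values = values;
    }

    /** Returns the value at the given index, or 0.0 if the index is not stored. */
    double get(long index) {
        int position = Arrays.binarySearch(indices, index);
        return position >= 0 ? values[position] : 0.0;
    }

    /** Dot product with another sparse vector, walking both sorted index arrays. */
    double dot(SparseLongDoubleVectorSketch other) {
        double sum = 0.0;
        int i = 0;
        int j = 0;
        while (i < indices.length && j < other.indices.length) {
            if (indices[i] == other.indices[j]) {
                sum += values[i++] * other.values[j++];
            } else if (indices[i] < other.indices[j]) {
                i++;
            } else {
                j++;
            }
        }
        return sum;
    }
}
```

   Long indices matter here because the feature space can exceed the range of `int`, which is the case this PR targets.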






[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #237: [Flink-27826] Support training very high dimensional logistic regression

2023-06-07 Thread via GitHub


zhipeng93 commented on code in PR #237:
URL: https://github.com/apache/flink-ml/pull/237#discussion_r1222532049


##
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionWithFtrl.java:
##
@@ -0,0 +1,380 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledLargePointWithWeight;
+import org.apache.flink.ml.common.lossfunc.BinaryLogisticLoss;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.common.ps.training.IterationStageList;
+import org.apache.flink.ml.common.ps.training.ProcessStage;
+import org.apache.flink.ml.common.ps.training.PullStage;
+import org.apache.flink.ml.common.ps.training.PushStage;
+import org.apache.flink.ml.common.ps.training.SerializableConsumer;
+import org.apache.flink.ml.common.ps.training.TrainingContext;
+import org.apache.flink.ml.common.ps.training.TrainingUtils;
+import org.apache.flink.ml.common.updater.FTRL;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.runtime.util.ResettableIterator;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.function.SerializableFunction;
+import org.apache.flink.util.function.SerializableSupplier;
+
+import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap;
+import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the large scale logistic regression algorithm using FTRL optimizer.
+ *
+ * See https://en.wikipedia.org/wiki/Logistic_regression.
+ */
+public class LogisticRegressionWithFtrl
+implements Estimator<LogisticRegressionWithFtrl, LogisticRegressionModel>,
+LogisticRegressionWithFtrlParams<LogisticRegressionWithFtrl> {
+
+private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+public LogisticRegressionWithFtrl() {
+ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+}
+
+@Override
+public LogisticRegressionModel fit(Table... inputs) {
+Preconditions.checkArgument(inputs.length == 1);
+String classificationType = getMultiClass();
+Preconditions.checkArgument(
+"auto".equals(classificationType) || "binomial".equals(classificationType),
+"Multinomial classification is not supported yet. Supported options: [auto, binomial].");
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+DataStream<LabeledLargePointWithWeight> trainData =
+tEnv.toDataStream(inputs[0])
+.map(
+(MapFunction<Row, LabeledLargePointWithWeight>)
+dataPoint -> {
+double weight =
+getWeightCol() == null
+? 1.0
+