This is an automated email from the ASF dual-hosted git repository.

zhangzp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new 239788f  [FLINK-27170] Add Transformer and Estimator of OnlineLogisticRegression
239788f is described below

commit 239788f2b1f1f3a4e55ca112517980b598705a15
Author: weibo <wbz...@pku.edu.cn>
AuthorDate: Thu Jun 2 10:42:09 2022 +0800

    [FLINK-27170] Add Transformer and Estimator of OnlineLogisticRegression
    
    This closes #83.
---
 .../ml/common/datastream/DataStreamUtils.java      |  74 +++
 .../main/java/org/apache/flink/ml/linalg/BLAS.java |  46 +-
 .../ml/common/datastream/DataStreamUtilsTest.java  |  13 +
 .../java/org/apache/flink/ml/linalg/BLASTest.java  |  11 +
 .../logisticregression/LogisticRegression.java     |   2 +-
 .../LogisticRegressionModel.java                   |   2 +-
 .../LogisticRegressionModelData.java               |  66 +-
 .../OnlineLogisticRegression.java                  | 424 +++++++++++++
 .../OnlineLogisticRegressionModel.java             | 198 ++++++
 .../OnlineLogisticRegressionModelParams.java       |  50 ++
 .../OnlineLogisticRegressionParams.java            |  66 ++
 .../flink/ml/clustering/kmeans/OnlineKMeans.java   |  57 +-
 .../ml/classification/LogisticRegressionTest.java  |   2 +-
 .../OnlineLogisticRegressionTest.java              | 687 +++++++++++++++++++++
 .../apache/flink/ml/feature/MinMaxScalerTest.java  |  15 +-
 .../java/org/apache/flink/ml/util/TestUtils.java   |  14 +
 .../ml/lib/classification/logisticregression.py    | 121 +++-
 .../tests/test_logisticregression.py               |  41 +-
 18 files changed, 1803 insertions(+), 86 deletions(-)

diff --git 
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
 
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
index 10073b8..45ad02e 100644
--- 
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
+++ 
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
@@ -19,21 +19,32 @@
 package org.apache.flink.ml.common.datastream;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.functions.MapPartitionFunction;
 import org.apache.flink.api.common.functions.ReduceFunction;
 import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.typeutils.TypeExtractor;
 import org.apache.flink.iteration.operator.OperatorStateUtils;
 import org.apache.flink.runtime.state.StateInitializationContext;
 import org.apache.flink.runtime.state.StateSnapshotContext;
 import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
 import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
 import org.apache.flink.streaming.api.operators.BoundedOneInput;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.operators.TimestampedCollector;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.Arrays;
+import java.util.List;
 
 /** Provides utility functions for {@link DataStream}. */
 @Internal
@@ -182,4 +193,67 @@ public class DataStreamUtils {
             }
         }
     }
+
+    /**
+     * Splits the input data into global batches of {@code batchSize} records. After splitting, each
+     * global batch is further split into {@code downStreamParallelism} local batches, so that every
+     * downstream subtask receives exactly one (possibly empty) local batch per global batch.
+     *
+     * <p>Trailing records that do not fill a complete global batch are not emitted, because {@code
+     * countWindowAll} only fires once a window holds {@code batchSize} records.
+     *
+     * @param inputData The input stream.
+     * @param downStreamParallelism The number of local batches each global batch is split into.
+     *     Local batch {@code i} is routed to partition {@code i}, so this is expected to match the
+     *     parallelism of the operator consuming the returned stream.
+     * @param batchSize The number of records in each global batch.
+     * @return A stream of local batches.
+     * @throws IllegalArgumentException if {@code downStreamParallelism} or {@code batchSize} is not
+     *     positive.
+     */
+    public static <T> DataStream<T[]> generateBatchData(
+            DataStream<T> inputData, final int downStreamParallelism, int batchSize) {
+        if (downStreamParallelism <= 0) {
+            throw new IllegalArgumentException(
+                    "downStreamParallelism must be positive, but got " + downStreamParallelism);
+        }
+        if (batchSize <= 0) {
+            throw new IllegalArgumentException("batchSize must be positive, but got " + batchSize);
+        }
+        return inputData
+                .countWindowAll(batchSize)
+                .apply(new GlobalBatchCreator<>())
+                .flatMap(new GlobalBatchSplitter<>(downStreamParallelism))
+                // Local batch i goes to partition i; requires chunkId < numPartitions.
+                .partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f0)
+                .map(
+                        new MapFunction<Tuple2<Integer, T[]>, T[]>() {
+                            @Override
+                            public T[] map(Tuple2<Integer, T[]> integerTuple2) throws Exception {
+                                return integerTuple2.f1;
+                            }
+                        });
+    }
+
+    /**
+     * Collects all records of one count window into a single array, i.e. one global batch.
+     *
+     * <p>NOTE: the emitted array is created from an {@code Object[]} and cast to {@code T[]}, so its
+     * runtime component type is {@code Object} regardless of {@code T}.
+     */
+    private static class GlobalBatchCreator<T> implements AllWindowFunction<T, T[], GlobalWindow> {
+        @Override
+        @SuppressWarnings("unchecked")
+        public void apply(GlobalWindow window, Iterable<T> iterable, Collector<T[]> collector) {
+            // IteratorUtils.toList returns a raw List, and the generic array cast below cannot be
+            // checked because T is erased at runtime.
+            List<T> points = IteratorUtils.toList(iterable.iterator());
+            collector.collect(points.toArray((T[]) new Object[0]));
+        }
+    }
+
+    /**
+     * Cuts every global batch into {@code downStreamParallelism} contiguous local batches whose
+     * sizes differ by at most one, and tags each local batch with its index.
+     */
+    private static class GlobalBatchSplitter<T>
+            implements FlatMapFunction<T[], Tuple2<Integer, T[]>> {
+        private final int downStreamParallelism;
+
+        public GlobalBatchSplitter(int downStreamParallelism) {
+            this.downStreamParallelism = downStreamParallelism;
+        }
+
+        @Override
+        public void flatMap(T[] values, Collector<Tuple2<Integer, T[]>> collector) {
+            int baseSize = values.length / downStreamParallelism;
+            int numLarger = values.length % downStreamParallelism;
+
+            int start = 0;
+            int chunkId = 0;
+            // The first (length % parallelism) chunks take one extra element each...
+            while (chunkId < numLarger) {
+                int end = start + baseSize + 1;
+                collector.collect(Tuple2.of(chunkId, Arrays.copyOfRange(values, start, end)));
+                start = end;
+                chunkId++;
+            }
+            // ...and the remaining chunks take exactly baseSize elements.
+            while (chunkId < downStreamParallelism) {
+                int end = start + baseSize;
+                collector.collect(Tuple2.of(chunkId, Arrays.copyOfRange(values, start, end)));
+                start = end;
+                chunkId++;
+            }
+        }
+    }
 }
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java 
b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java
index 24cc3ef..c00f642 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java
@@ -65,12 +65,54 @@ public class BLAS {
         }
     }
 
-    /** x \cdot y . */
-    public static double dot(DenseVector x, DenseVector y) {
+    /**
+     * Computes the dot product of two vectors, i.e. {@code x \cdot y}.
+     *
+     * <p>A vector that is not a {@link SparseVector} is cast to {@link DenseVector}.
+     *
+     * @throws IllegalArgumentException if the two vectors have different sizes.
+     */
+    public static double dot(Vector x, Vector y) {
         Preconditions.checkArgument(x.size() == y.size(), "Vector size mismatched.");
+        if (x instanceof SparseVector) {
+            if (y instanceof SparseVector) {
+                return dot((SparseVector) x, (SparseVector) y);
+            } else {
+                // The dot product is symmetric, so the (dense, sparse) overload is reused.
+                return dot((DenseVector) y, (SparseVector) x);
+            }
+        } else {
+            if (y instanceof SparseVector) {
+                return dot((DenseVector) x, (SparseVector) y);
+            } else {
+                return dot((DenseVector) x, (DenseVector) y);
+            }
+        }
+    }
+
+    /** Computes the dot product of two dense vectors with the BLAS {@code ddot} routine. */
+    private static double dot(DenseVector x, DenseVector y) {
         return JAVA_BLAS.ddot(x.size(), x.values, 1, y.values, 1);
     }
 
+    /** Computes the dot product of a dense vector and a sparse vector over the sparse entries. */
+    private static double dot(DenseVector x, SparseVector y) {
+        double sum = 0.0;
+        int k = 0;
+        for (int idx : y.indices) {
+            sum += y.values[k++] * x.values[idx];
+        }
+        return sum;
+    }
+
+    /**
+     * Computes the dot product of two sparse vectors by merging their index arrays, which are
+     * assumed to be sorted in ascending order.
+     */
+    private static double dot(SparseVector x, SparseVector y) {
+        double sum = 0;
+        int i = 0;
+        int j = 0;
+        while (i < x.values.length && j < y.values.length) {
+            int xIndex = x.indices[i];
+            int yIndex = y.indices[j];
+            if (xIndex < yIndex) {
+                i++;
+            } else if (xIndex > yIndex) {
+                j++;
+            } else {
+                sum += x.values[i++] * y.values[j++];
+            }
+        }
+        return sum;
+    }
+
     /** \sqrt(\sum_i x_i * x_i) . */
     public static double norm2(DenseVector x) {
         return JAVA_BLAS.dnrm2(x.size(), x.values, 1);
diff --git 
a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
 
b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
index 7dc88c8..a968a0e 100644
--- 
a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
+++ 
b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
@@ -37,6 +37,7 @@ import org.junit.Test;
 import java.util.List;
 
 import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
 
 /** Tests the {@link DataStreamUtils}. */
@@ -74,6 +75,18 @@ public class DataStreamUtilsTest {
         assertArrayEquals(new long[] {190L}, 
sum.stream().mapToLong(Long::longValue).toArray());
     }
 
+    @Test
+    public void testGenerateBatchData() throws Exception {
+        // Records 0..19 (20 in total) are cut into global batches of 4, and each global batch is
+        // split for 2 downstream subtasks, which yields 10 local batches of 2 records each.
+        DataStream<Long> input =
+                env.fromParallelCollection(new NumberSequenceIterator(0L, 19L), Types.LONG);
+        List<Long[]> localBatches =
+                IteratorUtils.toList(
+                        DataStreamUtils.generateBatchData(input, 2, 4).executeAndCollect());
+        localBatches.forEach(batch -> assertEquals(2, batch.length));
+        assertEquals(10, localBatches.size());
+    }
+
     /** A simple implementation for a {@link MapPartitionFunction}. */
     private static class TestMapPartitionFunc extends 
RichMapPartitionFunction<Long, Integer> {
 
diff --git 
a/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java 
b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java
index 7055c62..21d68a9 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java
@@ -70,7 +70,18 @@ public class BLASTest {
     @Test
     public void testDot() {
         DenseVector anotherDenseVec = Vectors.dense(1, 2, 3, 4, 5);
+        SparseVector sparseVector1 =
+                Vectors.sparse(5, new int[] {1, 2, 4}, new double[] {1., 1., 4.});
+        SparseVector sparseVector2 =
+                Vectors.sparse(5, new int[] {1, 3, 4}, new double[] {1., 2., 1.});
+        // NOTE(review): the expected values below are consistent with inputDenseVec being
+        // (1, -2, 3, 4, -5); that field is defined outside this hunk, so confirm it there.
+        // Tests dot(dense, dense).
         assertEquals(-3, BLAS.dot(inputDenseVec, anotherDenseVec), TOLERANCE);
+        // Tests dot(dense, sparse).
+        assertEquals(-19, BLAS.dot(inputDenseVec, sparseVector1), TOLERANCE);
+        // Tests dot(sparse, dense).
+        assertEquals(1, BLAS.dot(sparseVector2, inputDenseVec), TOLERANCE);
+        // Tests dot(sparse, sparse): only indices 1 and 4 overlap, so 1 * 1 + 4 * 1 = 5.
+        assertEquals(5, BLAS.dot(sparseVector1, sparseVector2), TOLERANCE);
     }
 
     @Test
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
index df8c386..551b66a 100644
--- 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
@@ -115,7 +115,7 @@ public class LogisticRegression
                 optimizer.optimize(initModelData, trainData, 
BinaryLogisticLoss.INSTANCE);
 
         DataStream<LogisticRegressionModelData> modelData =
-                rawModelData.map(LogisticRegressionModelData::new);
+                rawModelData.map(vector -> new 
LogisticRegressionModelData(vector, 0));
         LogisticRegressionModel model =
                 new 
LogisticRegressionModel().setModelData(tEnv.fromDataStream(modelData));
         ReadWriteUtils.updateExistingParams(model, paramMap);
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java
index ac42142..675846a 100644
--- 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java
@@ -162,7 +162,7 @@ public class LogisticRegressionModel
      * @param coefficient The model parameters.
      * @return The prediction label and the raw probabilities.
      */
-    private static Row predictOneDataPoint(DenseVector feature, DenseVector 
coefficient) {
+    protected static Row predictOneDataPoint(Vector feature, DenseVector 
coefficient) {
         double dotValue = BLAS.dot(feature, coefficient);
         double prob = 1 - 1.0 / (1.0 + Math.exp(dotValue));
         return Row.of(dotValue >= 0 ? 1. : 0., Vectors.dense(1 - prob, prob));
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java
index d2f451f..a9a6285 100644
--- 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.ml.classification.logisticregression;
 
+import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.serialization.Encoder;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.configuration.Configuration;
@@ -25,10 +26,11 @@ import 
org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.Vector;
 import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
 import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
 import org.apache.flink.table.api.internal.TableImpl;
@@ -36,9 +38,10 @@ import org.apache.flink.table.api.internal.TableImpl;
 import java.io.EOFException;
 import java.io.IOException;
 import java.io.OutputStream;
+import java.util.Random;
 
 /**
- * Model data of {@link LogisticRegressionModel}.
+ * Model data of {@link LogisticRegressionModel} and {@link 
OnlineLogisticRegressionModel}.
  *
  * <p>This class also provides methods to convert model data from Table to 
Datastream, and classes
  * to save/load model data.
@@ -46,12 +49,49 @@ import java.io.OutputStream;
 public class LogisticRegressionModelData {
 
     public DenseVector coefficient;
+    public long modelVersion;
 
-    public LogisticRegressionModelData(DenseVector coefficient) {
+    public LogisticRegressionModelData() {}
+
+    public LogisticRegressionModelData(DenseVector coefficient, long 
modelVersion) {
         this.coefficient = coefficient;
+        this.modelVersion = modelVersion;
     }
 
-    public LogisticRegressionModelData() {}
+    /**
+     * Creates a single-row Table holding one {@link LogisticRegressionModelData} whose coefficient
+     * is filled from a {@link Random} seeded with {@code seed}, and whose model version is 0.
+     *
+     * @param tEnv The table environment used to create the table.
+     * @param dim The dimension of the generated coefficient.
+     * @param seed The seed of the random generator.
+     * @return A table containing one randomly initialized model data record.
+     */
+    public static Table generateRandomModelData(StreamTableEnvironment tEnv, int dim, int seed) {
+        StreamExecutionEnvironment env = TableUtils.getExecutionEnvironment(tEnv);
+        DataStream<LogisticRegressionModelData> randomModel =
+                env.fromElements(1).map(new RandomModelDataGenerator(dim, seed));
+        return tEnv.fromDataStream(randomModel);
+    }
+
+    /** Maps its (ignored) input element to model data with a seeded, pseudo-random coefficient. */
+    private static class RandomModelDataGenerator
+            implements MapFunction<Integer, LogisticRegressionModelData> {
+        private final int dim;
+        private final int seed;
+
+        public RandomModelDataGenerator(int dim, int seed) {
+            this.dim = dim;
+            this.seed = seed;
+        }
+
+        @Override
+        public LogisticRegressionModelData map(Integer ignored) throws Exception {
+            Random rng = new Random(seed);
+            DenseVector coefficient = new DenseVector(dim);
+            double[] values = coefficient.values;
+            for (int i = 0; i < dim; i++) {
+                values[i] = rng.nextDouble();
+            }
+            return new LogisticRegressionModelData(coefficient, 0L);
+        }
+    }
 
     /**
      * Converts the table model to a data stream.
@@ -63,21 +103,24 @@ public class LogisticRegressionModelData {
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) 
modelData).getTableEnvironment();
         return tEnv.toDataStream(modelData)
-                .map(x -> new LogisticRegressionModelData(((Vector) 
x.getField(0)).toDense()));
+                .map(x -> new LogisticRegressionModelData(x.getFieldAs(0), 
x.getFieldAs(1)));
     }
 
-    /** Data encoder for {@link LogisticRegressionModel}. */
+    /**
+     * Data encoder for {@link LogisticRegression} and {@link OnlineLogisticRegression}. Each record
+     * is written as the coefficient (via {@link DenseVectorSerializer}) followed by the model
+     * version as a long.
+     */
     public static class ModelDataEncoder implements Encoder<LogisticRegressionModelData> {
 
         @Override
         public void encode(LogisticRegressionModelData modelData, OutputStream outputStream)
                 throws IOException {
+            DataOutputViewStreamWrapper dataOutputViewStreamWrapper =
+                    new DataOutputViewStreamWrapper(outputStream);
             DenseVectorSerializer.INSTANCE.serialize(
-                    modelData.coefficient, new DataOutputViewStreamWrapper(outputStream));
+                    modelData.coefficient, dataOutputViewStreamWrapper);
+            // Written after the coefficient; the decoder must read the fields in the same order.
+            // NOTE(review): this changes the serialized format, so model data saved before this
+            // change (coefficient only) will not be decoded correctly by the new decoder.
+            dataOutputViewStreamWrapper.writeLong(modelData.modelVersion);
         }
     }
 
-    /** Data decoder for {@link LogisticRegressionModel}. */
+    /** Data decoder for {@link LogisticRegression} and {@link 
OnlineLogisticRegression}. */
     public static class ModelDataDecoder extends 
SimpleStreamFormat<LogisticRegressionModelData> {
 
         @Override
@@ -88,10 +131,13 @@ public class LogisticRegressionModelData {
                 @Override
                 public LogisticRegressionModelData read() throws IOException {
                     try {
+                        DataInputViewStreamWrapper dataInputViewStreamWrapper =
+                                new DataInputViewStreamWrapper(inputStream);
                         DenseVector coefficient =
                                 DenseVectorSerializer.INSTANCE.deserialize(
-                                        new 
DataInputViewStreamWrapper(inputStream));
-                        return new LogisticRegressionModelData(coefficient);
+                                        dataInputViewStreamWrapper);
+                        long modelVersion = 
dataInputViewStreamWrapper.readLong();
+                        return new LogisticRegressionModelData(coefficient, 
modelVersion);
                     } catch (EOFException e) {
                         return null;
                     }
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java
new file mode 100644
index 0000000..59a9102
--- /dev/null
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java
@@ -0,0 +1,424 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the online logistic regression algorithm. The 
online optimizer of
+ * this algorithm is The FTRL-Proximal proposed by H.Brendan McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200";>H. Brendan McMahan 
et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, 
OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    /** The initial model data; {@link #fit(Table...)} reads it, so it must be set beforehand. */
+    private Table initModelDataTable;
+
+    /** Creates an OnlineLogisticRegression with every parameter set to its default value. */
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+    }
+
+    /**
+     * Starts online training: the input records are streamed through an unbounded FTRL iteration
+     * that is seeded with the initial model data and emits updated model data as training goes on.
+     *
+     * @param inputs Exactly one table holding the features, label and optional weight columns.
+     * @return A model whose model data is the continuously updated model data stream.
+     * @throws IllegalArgumentException if the number of input tables is not one.
+     * @throws NullPointerException if the initial model data has not been set.
+     */
+    @Override
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(
+                inputs.length == 1, "Expected exactly one input table, but got %s.", inputs.length);
+        Preconditions.checkNotNull(
+                initModelDataTable, "The initial model data must be set before calling fit().");
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+
+        DataStream<Row> points =
+                tEnv.toDataStream(inputs[0])
+                        .map(
+                                new FeaturesLabelExtractor(
+                                        getFeaturesCol(), getLabelCol(), getWeightCol()));
+
+        DataStream<DenseVector> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector>)
+                                value -> value.coefficient);
+
+        // NOTE(review): presumably so that a single copy of the initial model enters the
+        // iteration — TODO confirm.
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(), getAlpha(), getBeta(), getReg(), getElasticNet());
+
+        DataStream<LogisticRegressionModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegressionModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /**
+     * Projects each input row onto (features, label), or onto (features, label, weight) when a
+     * weight column is configured.
+     */
+    private static class FeaturesLabelExtractor implements MapFunction<Row, Row> {
+        private final String featuresCol;
+        private final String labelCol;
+        private final String weightCol;
+
+        private FeaturesLabelExtractor(String featuresCol, String labelCol, String weightCol) {
+            this.featuresCol = featuresCol;
+            this.labelCol = labelCol;
+            this.weightCol = weightCol;
+        }
+
+        @Override
+        public Row map(Row row) throws Exception {
+            Object features = row.getField(featuresCol);
+            Object label = row.getField(labelCol);
+            if (weightCol != null) {
+                return Row.of(features, label, row.getField(weightCol));
+            }
+            return Row.of(features, label);
+        }
+    }
+
+    /**
+     * The iteration body of the FTRL optimizer. Gradients are computed by distributed workers and
+     * reduced into one final gradient, which is then used to update the model with the FTRL
+     * method. When the feature vectors are dense, this gives the same result as TensorFlow's FTRL.
+     * When they are sparse, the mean value in every feature dimension is used instead of the mean
+     * value over the whole vector, which gives better convergence.
+     *
+     * <p>See https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Ftrl
+     *
+     * <p>TODO: make FTRL a common optimizer and move it to org.apache.flink.ml.common.
+     */
+    private static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+
+        public FtrlIterationBody(
+                int batchSize, double alpha, double beta, double reg, double elasticNet) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            // Splits the total regularization strength into L1 and L2 parts by the elastic net mix.
+            this.l1 = elasticNet * reg;
+            this.l2 = (1 - elasticNet) * reg;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector> modelData = variableStreams.get(0);
+
+            DataStream<Row> points = dataStreams.get(0);
+            int parallelism = points.getParallelism();
+            // Each global batch is split into `parallelism` local batches, one per subtask; with a
+            // smaller batch size some local batches would always be empty.
+            Preconditions.checkState(
+                    parallelism <= batchSize,
+                    "There are more subtasks in the training process than the number "
+                            + "of elements in each batch. Some subtasks might be idling forever.");
+
+            // Each gradient record holds: [0] the summed gradient, [1] the per-dimension divisor
+            // applied to it by UpdateModel, [2] the current coefficient (may be null in some
+            // records, which is why the reduce keeps the first non-null one).
+            DataStream<DenseVector[]> newGradient =
+                    DataStreamUtils.generateBatchData(points, parallelism, batchSize)
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "LocalGradientCalculator",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new CalculateLocalGradient())
+                            .setParallelism(parallelism)
+                            // Presumably one local result per subtask per global batch; TODO
+                            // confirm against CalculateLocalGradient.
+                            .countWindowAll(parallelism)
+                            .reduce(
+                                    (ReduceFunction<DenseVector[]>)
+                                            (gradientInfo, newGradientInfo) -> {
+                                                BLAS.axpy(1.0, gradientInfo[0], newGradientInfo[0]);
+                                                BLAS.axpy(1.0, gradientInfo[1], newGradientInfo[1]);
+                                                if (newGradientInfo[2] == null) {
+                                                    newGradientInfo[2] = gradientInfo[2];
+                                                }
+                                                return newGradientInfo;
+                                            });
+            // The model is updated by a single subtask.
+            DataStream<DenseVector> feedbackModelData =
+                    newGradient
+                            .transform(
+                                    "ModelDataUpdater",
+                                    TypeInformation.of(DenseVector.class),
+                                    new UpdateModel(alpha, beta, l1, l2))
+                            .setParallelism(1);
+
+            DataStream<LogisticRegressionModelData> outputModelData =
+                    feedbackModelData.map(new CreateLrModelData()).setParallelism(1);
+            // The updated coefficient is fed back as the next round's variable stream; the
+            // versioned model data is the iteration output.
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackModelData), DataStreamList.of(outputModelData));
+        }
+    }
+
+    /**
+     * Wraps each updated coefficient vector into a {@link LogisticRegressionModelData} that carries
+     * an increasing model version, starting from 1. The next version to assign is checkpointed and
+     * restored, so versions keep increasing across failover.
+     */
+    private static class CreateLrModelData
+            implements MapFunction<DenseVector, LogisticRegressionModelData>, CheckpointedFunction {
+        private long modelVersion = 1L;
+        private transient ListState<Long> modelVersionState;
+
+        @Override
+        public LogisticRegressionModelData map(DenseVector denseVector) throws Exception {
+            return new LogisticRegressionModelData(denseVector, modelVersion++);
+        }
+
+        @Override
+        public void snapshotState(FunctionSnapshotContext functionSnapshotContext)
+                throws Exception {
+            modelVersionState.update(Collections.singletonList(modelVersion));
+        }
+
+        @Override
+        public void initializeState(FunctionInitializationContext context) throws Exception {
+            modelVersionState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelVersionState", Long.class));
+            // Restore the counter on recovery; without this it would restart at 1 and emit
+            // duplicate model versions.
+            OperatorStateUtils.getUniqueElement(modelVersionState, "modelVersionState")
+                    .ifPresent(x -> modelVersion = x);
+        }
+    }
+
+    /**
+     * Applies one FTRL-Proximal step to the model coefficients for every aggregated gradient it
+     * receives, and emits the updated coefficients.
+     *
+     * <p>Each input element is {@code [gradientSum, weightSum, coefficient]}. The per-dimension
+     * accumulators {@code n} (sum of squared gradients) and {@code z} are kept in operator state
+     * and restored on recovery, so training continues from where it stopped after a failover.
+     */
+    private static class UpdateModel extends AbstractStreamOperator<DenseVector>
+            implements OneInputStreamOperator<DenseVector[], DenseVector> {
+        private ListState<double[]> nParamState;
+        private ListState<double[]> zParamState;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        /** Per-dimension sum of squared gradients; null until the first update. */
+        private double[] nParam;
+        /** Per-dimension FTRL z accumulator; null until the first update. */
+        private double[] zParam;
+
+        public UpdateModel(double alpha, double beta, double l1, double l2) {
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            nParamState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("nParamState", double[].class));
+            zParamState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("zParamState", double[].class));
+            // Restore the accumulators on recovery; both states are empty on a fresh start, in
+            // which case they are lazily allocated by the first processed element.
+            nParam = OperatorStateUtils.getUniqueElement(nParamState, "nParamState").orElse(null);
+            zParam = OperatorStateUtils.getUniqueElement(zParamState, "zParamState").orElse(null);
+        }
+
+        @Override
+        public void processElement(StreamRecord<DenseVector[]> streamRecord) throws Exception {
+            DenseVector[] gradientInfo = streamRecord.getValue();
+            double[] coefficient = gradientInfo[2].values;
+            double[] g = gradientInfo[0].values;
+            double[] weightSum = gradientInfo[1].values;
+            // Turn the summed gradient into a per-dimension average; dimensions that received
+            // no weight keep their (zero) sum.
+            for (int i = 0; i < g.length; ++i) {
+                if (weightSum[i] != 0.0) {
+                    g[i] = g[i] / weightSum[i];
+                }
+            }
+            if (zParam == null || nParam == null) {
+                zParam = new double[g.length];
+                nParam = new double[g.length];
+            }
+
+            for (int i = 0; i < zParam.length; ++i) {
+                double sigma =
+                        (Math.sqrt(nParam[i] + g[i] * g[i]) - Math.sqrt(nParam[i])) / alpha;
+                zParam[i] += g[i] - sigma * coefficient[i];
+                nParam[i] += g[i] * g[i];
+
+                if (Math.abs(zParam[i]) <= l1) {
+                    coefficient[i] = 0.0;
+                } else {
+                    coefficient[i] =
+                            ((zParam[i] < 0 ? -1 : 1) * l1 - zParam[i])
+                                    / ((beta + Math.sqrt(nParam[i])) / alpha + l2);
+                }
+            }
+
+            // Replace (not append to) the checkpointed accumulators so each state always holds
+            // exactly one array, matching the getUniqueElement() lookup in initializeState().
+            nParamState.update(Collections.singletonList(nParam));
+            zParamState.update(Collections.singletonList(zParam));
+            output.collect(new StreamRecord<>(new DenseVector(coefficient)));
+        }
+    }
+
+    /**
+     * Computes the gradient of one local batch of training data against the latest model data.
+     *
+     * <p>Input 1 carries local batches of rows laid out as {@code [features, label]} or {@code
+     * [features, label, weight]}; input 2 carries the model coefficients. Both are buffered in
+     * operator state. Whenever model data and at least one batch are available, the oldest batch
+     * is consumed and {@code [gradientSum, weightSum, coefficient]} is emitted, where only subtask
+     * 0 forwards the coefficient and the other subtasks send {@code null} in its place.
+     */
+    private static class CalculateLocalGradient extends AbstractStreamOperator<DenseVector[]>
+            implements TwoInputStreamOperator<Row[], DenseVector, DenseVector[]> {
+        private ListState<DenseVector> modelDataState;
+        private ListState<Row[]> localBatchDataState;
+        /** Reusable per-dimension gradient accumulator; sized by the first row seen. */
+        private double[] gradient;
+        /** Reusable per-dimension weight accumulator; sized by the first row seen. */
+        private double[] weightSum;
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", DenseVector.class));
+            TypeInformation<Row[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(TypeInformation.of(Row.class));
+            localBatchDataState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Row[]> pointsRecord) throws Exception {
+            localBatchDataState.add(pointsRecord.getValue());
+            calculateGradient();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<DenseVector> modelDataRecord) throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            calculateGradient();
+        }
+
+        /**
+         * Consumes the buffered model data and the oldest buffered batch, if both are present,
+         * and emits the resulting weighted gradient sums.
+         */
+        private void calculateGradient() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchDataState.get().iterator().hasNext()) {
+                return;
+            }
+            DenseVector modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get();
+            modelDataState.clear();
+
+            // Consume only the oldest batch; later batches wait for the next model data.
+            List<Row[]> pointsList = IteratorUtils.toList(localBatchDataState.get().iterator());
+            Row[] points = pointsList.remove(0);
+            localBatchDataState.update(pointsList);
+
+            if (points.length == 0) {
+                // Nothing to emit for an empty local batch. Returning here also avoids touching
+                // the accumulators, which are still unallocated if no row has been seen yet.
+                // NOTE(review): on subtask 0 this drops the coefficient for this round, since it
+                // is only forwarded together with a non-empty batch -- confirm this is intended.
+                return;
+            }
+
+            for (Row point : points) {
+                Vector vec = point.getFieldAs(0);
+                double label = point.getFieldAs(1);
+                double weight = point.getArity() == 2 ? 1.0 : point.getFieldAs(2);
+                if (gradient == null) {
+                    gradient = new double[vec.size()];
+                    weightSum = new double[gradient.length];
+                }
+                double p = BLAS.dot(modelData, vec);
+                p = 1 / (1 + Math.exp(-p));
+                // Weighted logistic-loss gradient: weight * (p - label) * x. Weighting both the
+                // gradient and the weight sum makes the downstream division a weighted mean, for
+                // dense and sparse features alike.
+                double multiplier = weight * (p - label);
+                if (vec instanceof DenseVector) {
+                    DenseVector dvec = (DenseVector) vec;
+                    for (int i = 0; i < modelData.size(); ++i) {
+                        gradient[i] += multiplier * dvec.values[i];
+                        weightSum[i] += weight;
+                    }
+                } else {
+                    SparseVector svec = (SparseVector) vec;
+                    for (int i = 0; i < svec.indices.length; ++i) {
+                        int idx = svec.indices[i];
+                        gradient[idx] += multiplier * svec.values[i];
+                        weightSum[idx] += weight;
+                    }
+                }
+            }
+
+            output.collect(
+                    new StreamRecord<>(
+                            new DenseVector[] {
+                                new DenseVector(gradient),
+                                new DenseVector(weightSum),
+                                (getRuntimeContext().getIndexOfThisSubtask() == 0)
+                                        ? modelData
+                                        : null
+                            }));
+            // Reset the reusable accumulators for the next batch.
+            Arrays.fill(gradient, 0.0);
+            Arrays.fill(weightSum, 0.0);
+        }
+    }
+
+    /**
+     * Writes this estimator's parameters and its initial model data under {@code path}.
+     *
+     * @param path the directory to save to.
+     * @throws IOException if writing the metadata fails.
+     */
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        DataStream<LogisticRegressionModelData> initModelData =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+        ReadWriteUtils.saveModelData(
+                initModelData, path, new LogisticRegressionModelData.ModelDataEncoder());
+    }
+
+    /**
+     * Loads an {@link OnlineLogisticRegression} saved by {@link #save(String)}, restoring both its
+     * parameters and its initial model data.
+     *
+     * @param tEnv the table environment used to create the model data table.
+     * @param path the directory the estimator was saved to.
+     * @return the restored estimator.
+     * @throws IOException if reading the saved data fails.
+     */
+    public static OnlineLogisticRegression load(StreamTableEnvironment tEnv, String path)
+            throws IOException {
+        OnlineLogisticRegression estimator = ReadWriteUtils.loadStageParam(path);
+        return estimator.setInitialModelData(
+                ReadWriteUtils.loadModelData(
+                        tEnv, path, new LogisticRegressionModelData.ModelDataDecoder()));
+    }
+
+    /** Returns the map holding this stage's parameter values. */
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return this.paramMap;
+    }
+
+    /**
+     * Sets the table holding the model data that the online training process starts from.
+     *
+     * @param initModelDataTable the initial model data table.
+     * @return this estimator, to allow call chaining.
+     */
+    public OnlineLogisticRegression setInitialModelData(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        return this;
+    }
+}
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java
new file mode 100644
index 0000000..eab5cf6
--- /dev/null
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static 
org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel.predictOneDataPoint;
+
+/**
+ * A Model which classifies data using the model data computed by {@link OnlineLogisticRegression}.
+ *
+ * <p>The model data is consumed as a stream: each input row is predicted with the latest model
+ * data received so far, and the version of that model data is appended to the output. Rows that
+ * arrive before any model data are buffered and predicted once the first model data arrives.
+ */
+public class OnlineLogisticRegressionModel
+        implements Model<OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionModelParams<OnlineLogisticRegressionModel> {
+    /** Metric name of the gauge reporting the model data version currently used. */
+    public static final String MODEL_DATA_VERSION_GAUGE_KEY = "modelDataVersion";
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OnlineLogisticRegressionModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    /**
+     * Appends the prediction, raw prediction and model version columns to the single input
+     * table.
+     *
+     * @throws IllegalArgumentException if not exactly one input table is given.
+     * @throws NullPointerException if no model data has been set.
+     */
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkNotNull(
+                modelDataTable, "Model data must be set via setModelData() before transform().");
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(),
+                                Types.DOUBLE,
+                                TypeInformation.of(DenseVector.class),
+                                Types.LONG),
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldNames(),
+                                getPredictionCol(),
+                                getRawPredictionCol(),
+                                getModelVersionCol()));
+
+        DataStream<Row> predictionResult =
+                tEnv.toDataStream(inputs[0])
+                        .connect(
+                                LogisticRegressionModelData.getModelDataStream(modelDataTable)
+                                        .broadcast())
+                        .transform(
+                                "PredictLabelOperator",
+                                outputTypeInfo,
+                                new PredictLabelOperator(inputTypeInfo, getFeaturesCol()));
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /**
+     * Predicts input rows (input 1) with the latest model data received on the broadcast input
+     * (input 2), buffering rows in operator state until the first model data arrives.
+     */
+    private static class PredictLabelOperator extends AbstractStreamOperator<Row>
+            implements TwoInputStreamOperator<Row, LogisticRegressionModelData, Row> {
+        private final RowTypeInfo inputTypeInfo;
+
+        private final String featuresCol;
+        private ListState<Row> bufferedPointsState;
+        /** Latest coefficients; null until the first model data arrives. */
+        private DenseVector coefficient;
+        /** Version of {@link #coefficient}; also reported through a metric gauge. */
+        private long modelDataVersion = 0L;
+
+        public PredictLabelOperator(RowTypeInfo inputTypeInfo, String featuresCol) {
+            this.inputTypeInfo = inputTypeInfo;
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            bufferedPointsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("bufferedPoints", inputTypeInfo));
+        }
+
+        @Override
+        public void open() throws Exception {
+            super.open();
+
+            getRuntimeContext()
+                    .getMetricGroup()
+                    .gauge(
+                            MODEL_DATA_VERSION_GAUGE_KEY,
+                            (Gauge<String>) () -> Long.toString(modelDataVersion));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Row> streamRecord) throws Exception {
+            predictOrBuffer(streamRecord.getValue());
+        }
+
+        @Override
+        public void processElement2(StreamRecord<LogisticRegressionModelData> streamRecord)
+                throws Exception {
+            LogisticRegressionModelData modelData = streamRecord.getValue();
+            coefficient = modelData.coefficient;
+            modelDataVersion = modelData.modelVersion;
+            // The coefficient is now set, so predictOrBuffer() emits instead of re-buffering.
+            for (Row dataPoint : bufferedPointsState.get()) {
+                predictOrBuffer(dataPoint);
+            }
+            bufferedPointsState.clear();
+        }
+
+        /**
+         * Emits the input row joined with its prediction, raw prediction and model version, or
+         * buffers the row if no model data has arrived yet.
+         */
+        private void predictOrBuffer(Row dataPoint) throws Exception {
+            if (coefficient == null) {
+                bufferedPointsState.add(dataPoint);
+                return;
+            }
+            Vector features = (Vector) dataPoint.getField(featuresCol);
+            Row predictionResult = predictOneDataPoint(features, coefficient);
+            output.collect(
+                    new StreamRecord<>(
+                            Row.join(
+                                    dataPoint,
+                                    Row.of(
+                                            predictionResult.getField(0),
+                                            predictionResult.getField(1),
+                                            modelDataVersion))));
+        }
+    }
+
+    /** Saves the parameters of this model; the model data itself is not saved. */
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    /**
+     * Loads the parameters of a model saved by {@link #save(String)}. The returned model has no
+     * model data; set it with {@link #setModelData(Table...)} before calling transform.
+     */
+    public static OnlineLogisticRegressionModel load(StreamTableEnvironment tEnv, String path)
+            throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public OnlineLogisticRegressionModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+}
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModelParams.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModelParams.java
new file mode 100644
index 0000000..573cd48
--- /dev/null
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModelParams.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.common.param.HasRawPredictionCol;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+
+/**
+ * Params for {@link OnlineLogisticRegressionModel}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface OnlineLogisticRegressionModelParams<T>
+        extends HasFeaturesCol<T>, HasPredictionCol<T>, HasRawPredictionCol<T> {
+    /**
+     * Name of the output column holding the version of the model data each prediction was made
+     * with.
+     */
+    Param<String> MODEL_VERSION_COL =
+            new StringParam(
+                    "modelVersionCol",
+                    "Model version column name.",
+                    "modelVersion",
+                    ParamValidators.notNull());
+
+    default String getModelVersionCol() {
+        return get(MODEL_VERSION_COL);
+    }
+
+    default T setModelVersionCol(String value) {
+        // set() already returns T; no unchecked cast needed.
+        return set(MODEL_VERSION_COL, value);
+    }
+}
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionParams.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionParams.java
new file mode 100644
index 0000000..961b7c5
--- /dev/null
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionParams.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasElasticNet;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasReg;
+import org.apache.flink.ml.common.param.HasWeightCol;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link OnlineLogisticRegression}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface OnlineLogisticRegressionParams<T>
+        extends HasLabelCol<T>,
+                HasWeightCol<T>,
+                HasBatchStrategy<T>,
+                HasGlobalBatchSize<T>,
+                HasReg<T>,
+                HasElasticNet<T>,
+                OnlineLogisticRegressionModelParams<T> {
+
+    /** The alpha parameter of the FTRL optimizer. Defaults to 0.1; must be greater than 0. */
+    Param<Double> ALPHA =
+            new DoubleParam("alpha", "The alpha parameter of ftrl.", 0.1, ParamValidators.gt(0.0));
+
+    default Double getAlpha() {
+        return get(ALPHA);
+    }
+
+    default T setAlpha(Double value) {
+        return set(ALPHA, value);
+    }
+
+    /** The beta parameter of the FTRL optimizer. Defaults to 0.1; must be greater than 0. */
+    Param<Double> BETA =
+            new DoubleParam("beta", "The beta parameter of ftrl.", 0.1, ParamValidators.gt(0.0));
+
+    default Double getBeta() {
+        return get(BETA);
+    }
+
+    default T setBeta(Double value) {
+        return set(BETA, value);
+    }
+}
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
index b876b13..4112527 100644
--- 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
@@ -18,7 +18,6 @@
 
 package org.apache.flink.ml.clustering.kmeans;
 
-import org.apache.flink.api.common.functions.FlatMapFunction;
 import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.functions.ReduceFunction;
 import org.apache.flink.api.common.state.ListState;
@@ -31,6 +30,7 @@ import org.apache.flink.iteration.IterationBodyResult;
 import org.apache.flink.iteration.Iterations;
 import org.apache.flink.iteration.operator.OperatorStateUtils;
 import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
 import org.apache.flink.ml.common.distance.DistanceMeasure;
 import org.apache.flink.ml.linalg.BLAS;
 import org.apache.flink.ml.linalg.DenseVector;
@@ -41,22 +41,18 @@ import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
 import org.apache.flink.runtime.state.StateInitializationContext;
 import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
-import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
 import org.apache.flink.table.api.internal.TableImpl;
 import org.apache.flink.types.Row;
-import org.apache.flink.util.Collector;
 import org.apache.flink.util.Preconditions;
 
 import org.apache.commons.collections.IteratorUtils;
 
 import java.io.IOException;
-import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -169,10 +165,7 @@ public class OnlineKMeans
                             + "of elements in each batch. Some subtasks might 
be idling forever.");
 
             DataStream<KMeansModelData> newModelData =
-                    points.countWindowAll(batchSize)
-                            .apply(new GlobalBatchCreator())
-                            .flatMap(new GlobalBatchSplitter(parallelism))
-                            .rebalance()
+                    DataStreamUtils.generateBatchData(points, parallelism, 
batchSize)
                             .connect(modelData.broadcast())
                             .transform(
                                     "ModelDataLocalUpdater",
@@ -340,52 +333,6 @@ public class OnlineKMeans
         }
     }
 
-    /**
-     * An operator that splits a global batch into evenly-sized local batches, 
and distributes them
-     * to downstream operator.
-     */
-    private static class GlobalBatchSplitter
-            implements FlatMapFunction<DenseVector[], DenseVector[]> {
-        private final int downStreamParallelism;
-
-        private GlobalBatchSplitter(int downStreamParallelism) {
-            this.downStreamParallelism = downStreamParallelism;
-        }
-
-        @Override
-        public void flatMap(DenseVector[] values, Collector<DenseVector[]> 
collector) {
-            int div = values.length / downStreamParallelism;
-            int mod = values.length % downStreamParallelism;
-
-            int offset = 0;
-            int i = 0;
-
-            int size = div + 1;
-            for (; i < mod; i++) {
-                collector.collect(Arrays.copyOfRange(values, offset, offset + 
size));
-                offset += size;
-            }
-
-            size = div;
-            for (; i < downStreamParallelism; i++) {
-                collector.collect(Arrays.copyOfRange(values, offset, offset + 
size));
-                offset += size;
-            }
-        }
-    }
-
-    private static class GlobalBatchCreator
-            implements AllWindowFunction<DenseVector, DenseVector[], 
GlobalWindow> {
-        @Override
-        public void apply(
-                GlobalWindow timeWindow,
-                Iterable<DenseVector> iterable,
-                Collector<DenseVector[]> collector) {
-            List<DenseVector> points = 
IteratorUtils.toList(iterable.iterator());
-            collector.collect(points.toArray(new DenseVector[0]));
-        }
-    }
-
     /**
      * Sets the initial model data of the online training process with the 
provided model data
      * table.
diff --git 
a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
index 129fd2e..37815e0 100644
--- 
a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
+++ 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
@@ -250,7 +250,7 @@ public class LogisticRegressionTest {
         LogisticRegressionModel model = 
logisticRegression.fit(binomialDataTable);
         model = TestUtils.saveAndReload(tEnv, model, 
tempFolder.newFolder().getAbsolutePath());
         assertEquals(
-                Collections.singletonList("coefficient"),
+                Arrays.asList("coefficient", "modelVersion"),
                 model.getModelData()[0].getResolvedSchema().getColumnNames());
         Table output = model.transform(binomialDataTable)[0];
         verifyPredictionResult(
diff --git 
a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java
 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java
new file mode 100644
index 0000000..446548d
--- /dev/null
+++ 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java
@@ -0,0 +1,687 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.common.JobSubmissionResult;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.RestOptions;
+import org.apache.flink.metrics.Gauge;
+import 
org.apache.flink.ml.classification.logisticregression.LogisticRegression;
+import 
org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel;
+import 
org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelData;
+import 
org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegression;
+import 
org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.InMemorySinkFunction;
+import org.apache.flink.ml.util.InMemorySourceFunction;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.runtime.jobgraph.JobGraph;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
+import org.apache.flink.runtime.testutils.InMemoryReporter;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.TestLogger;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+import static 
org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel.MODEL_DATA_VERSION_GAUGE_KEY;
+
+/** Tests {@link OnlineLogisticRegression} and {@link OnlineLogisticRegressionModel}. */
+public class OnlineLogisticRegressionTest extends TestLogger {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    /** Absolute tolerance used when comparing floating point results. */
+    private static final double TOLERANCE = 1.0e-5;
+
+    /** Values of the three active entries of every sparse vector below. */
+    private static final double[] ONE_ARRAY = new double[] {1.0, 1.0, 1.0};
+
+    /** First dense training batch: label 0 for small feature values, label 1 for large ones. */
+    private static final Row[] TRAIN_DENSE_ROWS_1 =
+            new Row[] {
+                Row.of(Vectors.dense(0.1, 2.), 0.),
+                Row.of(Vectors.dense(0.2, 2.), 0.),
+                Row.of(Vectors.dense(0.3, 2.), 0.),
+                Row.of(Vectors.dense(0.4, 2.), 0.),
+                Row.of(Vectors.dense(0.5, 2.), 0.),
+                Row.of(Vectors.dense(11., 12.), 1.),
+                Row.of(Vectors.dense(12., 11.), 1.),
+                Row.of(Vectors.dense(13., 12.), 1.),
+                Row.of(Vectors.dense(14., 12.), 1.),
+                Row.of(Vectors.dense(15., 12.), 1.)
+            };
+
+    /** Second dense training batch, used to trigger a model update. */
+    private static final Row[] TRAIN_DENSE_ROWS_2 =
+            new Row[] {
+                Row.of(Vectors.dense(0.2, 3.), 0.),
+                Row.of(Vectors.dense(0.8, 1.), 0.),
+                Row.of(Vectors.dense(0.7, 1.), 0.),
+                Row.of(Vectors.dense(0.6, 2.), 0.),
+                Row.of(Vectors.dense(0.2, 2.), 0.),
+                Row.of(Vectors.dense(14., 17.), 1.),
+                Row.of(Vectors.dense(15., 10.), 1.),
+                Row.of(Vectors.dense(16., 16.), 1.),
+                Row.of(Vectors.dense(17., 10.), 1.),
+                Row.of(Vectors.dense(18., 13.), 1.)
+            };
+
+    /** Dense rows fed to the model for prediction. */
+    private static final Row[] PREDICT_DENSE_ROWS =
+            new Row[] {
+                Row.of(Vectors.dense(0.8, 2.7), 0.0), Row.of(Vectors.dense(15.5, 11.2), 1.0)
+            };
+
+    /**
+     * First sparse training batch of (features, label, weight) rows. Label-0 rows only use
+     * indices 0-4 and label-1 rows only use indices 5-9.
+     */
+    private static final Row[] TRAIN_SPARSE_ROWS_1 =
+            new Row[] {
+                Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0., 1.0),
+                Row.of(Vectors.sparse(10, new int[] {0, 2, 3}, ONE_ARRAY), 0., 1.4),
+                Row.of(Vectors.sparse(10, new int[] {0, 3, 4}, ONE_ARRAY), 0., 1.3),
+                Row.of(Vectors.sparse(10, new int[] {2, 3, 4}, ONE_ARRAY), 0., 1.4),
+                Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0., 1.6),
+                Row.of(Vectors.sparse(10, new int[] {6, 7, 8}, ONE_ARRAY), 1., 1.8),
+                Row.of(Vectors.sparse(10, new int[] {6, 8, 9}, ONE_ARRAY), 1., 1.9),
+                Row.of(Vectors.sparse(10, new int[] {5, 8, 9}, ONE_ARRAY), 1., 1.0),
+                Row.of(Vectors.sparse(10, new int[] {5, 6, 7}, ONE_ARRAY), 1., 1.1)
+            };
+
+    /** Second sparse training batch of (features, label, weight) rows. */
+    private static final Row[] TRAIN_SPARSE_ROWS_2 =
+            new Row[] {
+                Row.of(Vectors.sparse(10, new int[] {1, 2, 4}, ONE_ARRAY), 0., 1.0),
+                Row.of(Vectors.sparse(10, new int[] {2, 3, 4}, ONE_ARRAY), 0., 1.3),
+                Row.of(Vectors.sparse(10, new int[] {0, 2, 4}, ONE_ARRAY), 0., 1.4),
+                Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0., 1.0),
+                Row.of(Vectors.sparse(10, new int[] {6, 7, 9}, ONE_ARRAY), 1., 1.6),
+                Row.of(Vectors.sparse(10, new int[] {7, 8, 9}, ONE_ARRAY), 1., 1.8),
+                Row.of(Vectors.sparse(10, new int[] {5, 7, 9}, ONE_ARRAY), 1., 1.0),
+                Row.of(Vectors.sparse(10, new int[] {5, 6, 7}, ONE_ARRAY), 1., 1.5),
+                Row.of(Vectors.sparse(10, new int[] {5, 8, 9}, ONE_ARRAY), 1., 1.0)
+            };
+
+    /** Sparse rows fed to the model for prediction. */
+    private static final Row[] PREDICT_SPARSE_ROWS =
+            new Row[] {
+                Row.of(Vectors.sparse(10, new int[] {1, 3, 5}, ONE_ARRAY), 0.),
+                Row.of(Vectors.sparse(10, new int[] {5, 8, 9}, ONE_ARRAY), 1.)
+            };
+
+    private static final int DEFAULT_PARALLELISM = 4;
+    private static final int NUM_TASK_MANAGERS = 2;
+    private static final int NUM_SLOTS_PER_TASK_MANAGER = 2;
+
+    /** Last model data version observed by {@link #waitModelDataUpdate(JobID)}. */
+    private long currentModelDataVersion;
+
+    private InMemorySourceFunction<Row> trainDenseSource;
+    private InMemorySourceFunction<Row> predictDenseSource;
+    private InMemorySourceFunction<Row> trainSparseSource;
+    private InMemorySourceFunction<Row> predictSparseSource;
+    private InMemorySinkFunction<Row> outputSink;
+    private InMemorySinkFunction<LogisticRegressionModelData> modelDataSink;
+
+    // TODO: creates static mini cluster once for whole test class after dependency upgrades to
+    // Flink 1.15.
+    private InMemoryReporter reporter;
+    private MiniCluster miniCluster;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    private Table offlineTrainDenseTable;
+    private Table onlineTrainDenseTable;
+    private Table onlinePredictDenseTable;
+    private Table onlineTrainSparseTable;
+    private Table onlinePredictSparseTable;
+    private Table initDenseModel;
+    private Table initSparseModel;
+
+    @Before
+    public void before() throws Exception {
+        currentModelDataVersion = 0;
+
+        trainDenseSource = new InMemorySourceFunction<>();
+        predictDenseSource = new InMemorySourceFunction<>();
+        trainSparseSource = new InMemorySourceFunction<>();
+        predictSparseSource = new InMemorySourceFunction<>();
+        outputSink = new InMemorySinkFunction<>();
+        modelDataSink = new InMemorySinkFunction<>();
+
+        Configuration config = new Configuration();
+        config.set(RestOptions.BIND_PORT, "18081-19091");
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        // The in-memory reporter lets the tests read the model data version gauge of a job.
+        reporter = InMemoryReporter.create();
+        reporter.addToConfiguration(config);
+
+        miniCluster =
+                new MiniCluster(
+                        new MiniClusterConfiguration.Builder()
+                                .setConfiguration(config)
+                                .setNumTaskManagers(NUM_TASK_MANAGERS)
+                                .setNumSlotsPerTaskManager(NUM_SLOTS_PER_TASK_MANAGER)
+                                .build());
+        miniCluster.start();
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(DEFAULT_PARALLELISM);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        offlineTrainDenseTable =
+                tEnv.fromDataStream(env.fromElements(TRAIN_DENSE_ROWS_1)).as("features", "label");
+        onlineTrainDenseTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                trainDenseSource,
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(DenseVector.class), Types.DOUBLE
+                                        },
+                                        new String[] {"features", "label"})));
+
+        onlinePredictDenseTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                predictDenseSource,
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(DenseVector.class), Types.DOUBLE
+                                        },
+                                        new String[] {"features", "label"})));
+
+        onlineTrainSparseTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                trainSparseSource,
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(SparseVector.class),
+                                            Types.DOUBLE,
+                                            Types.DOUBLE
+                                        },
+                                        new String[] {"features", "label", "weight"})));
+
+        onlinePredictSparseTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                predictSparseSource,
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(SparseVector.class), Types.DOUBLE
+                                        },
+                                        new String[] {"features", "label"})));
+
+        initDenseModel =
+                tEnv.fromDataStream(
+                        env.fromElements(
+                                Row.of(
+                                        new DenseVector(
+                                                new double[] {
+                                                    0.41233679404769874, -0.18088118293232122
+                                                }),
+                                        0L)));
+        initSparseModel =
+                tEnv.fromDataStream(
+                        env.fromElements(
+                                Row.of(
+                                        new DenseVector(
+                                                new double[] {
+                                                    0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01,
+                                                    0.01, 0.01
+                                                }),
+                                        0L)));
+    }
+
+    @After
+    public void after() throws Exception {
+        // before() may fail before the cluster is created; do not mask that failure with an NPE.
+        if (miniCluster != null) {
+            miniCluster.close();
+        }
+    }
+
+    /**
+     * Performs transform() on the provided model with the predict table, and adds sinks for
+     * OnlineLogisticRegressionModel's transform output and model data.
+     *
+     * @param onlineModel The model to transform the predict data with.
+     * @param isSparse Whether to use the sparse predict table instead of the dense one.
+     */
+    private void transformAndOutputData(
+            OnlineLogisticRegressionModel onlineModel, boolean isSparse) {
+        Table predictTable = isSparse ? onlinePredictSparseTable : onlinePredictDenseTable;
+        Table outputTable = onlineModel.transform(predictTable)[0];
+        tEnv.toDataStream(outputTable).addSink(outputSink);
+
+        Table modelDataTable = onlineModel.getModelData()[0];
+        LogisticRegressionModelData.getModelDataStream(modelDataTable).addSink(modelDataSink);
+    }
+
+    /**
+     * Blocks the thread until at least {@link #DEFAULT_PARALLELISM} model data version gauges are
+     * registered for the job, and then until the model data version first changes.
+     */
+    private void waitInitModelDataSetup(JobID jobID) throws InterruptedException {
+        while (reporter.findMetrics(jobID, MODEL_DATA_VERSION_GAUGE_KEY).size()
+                < DEFAULT_PARALLELISM) {
+            Thread.sleep(100);
+        }
+        waitModelDataUpdate(jobID);
+    }
+
+    /**
+     * Blocks the thread until the minimum model data version over all registered gauges differs
+     * from {@link #currentModelDataVersion}, and then records that new version.
+     *
+     * @throws IllegalStateException if no model data version gauge is registered for the job.
+     */
+    @SuppressWarnings("unchecked")
+    private void waitModelDataUpdate(JobID jobID) throws InterruptedException {
+        while (true) {
+            long minModelDataVersion =
+                    reporter.findMetrics(jobID, MODEL_DATA_VERSION_GAUGE_KEY).values().stream()
+                            .map(x -> Long.parseLong(((Gauge<String>) x).getValue()))
+                            .min(Long::compareTo)
+                            .orElseThrow(
+                                    () ->
+                                            new IllegalStateException(
+                                                    "No model data version gauge registered."));
+            if (minModelDataVersion != currentModelDataVersion) {
+                currentModelDataVersion = minModelDataVersion;
+                return;
+            }
+            Thread.sleep(100);
+        }
+    }
+
+    /**
+     * Inserts the default predict data into the predict source, polls the prediction results, and
+     * asserts that their raw predictions match the expected ones, irrespective of order.
+     *
+     * @param expectedRawInfo The expected raw prediction vectors, one per predict row. The list is
+     *     not modified.
+     * @param isSparse Whether to use the sparse predict data instead of the dense one.
+     */
+    private void predictAndAssert(List<DenseVector> expectedRawInfo, boolean isSparse)
+            throws Exception {
+        Row[] predictRows = isSparse ? PREDICT_SPARSE_ROWS : PREDICT_DENSE_ROWS;
+        if (isSparse) {
+            predictSparseSource.addAll(predictRows);
+        } else {
+            predictDenseSource.addAll(predictRows);
+        }
+        List<Row> rawResult = outputSink.poll(predictRows.length);
+
+        List<DenseVector> actualRawInfo = new ArrayList<>(rawResult.size());
+        for (Row row : rawResult) {
+            // NOTE(review): assumes field 3 of each output row is the raw prediction vector.
+            actualRawInfo.add(row.getFieldAs(3));
+        }
+
+        // Results may arrive in any order, so compare sorted copies.
+        List<DenseVector> sortedExpected = new ArrayList<>(expectedRawInfo);
+        sortedExpected.sort(TestUtils::compare);
+        actualRawInfo.sort(TestUtils::compare);
+
+        Assert.assertEquals(sortedExpected.size(), actualRawInfo.size());
+        for (int i = 0; i < sortedExpected.size(); ++i) {
+            Assert.assertArrayEquals(
+                    sortedExpected.get(i).values, actualRawInfo.get(i).values, TOLERANCE);
+        }
+    }
+
+    /** Submits the job graph to the mini cluster and returns its job ID. */
+    private JobID submitJob(JobGraph jobGraph)
+            throws ExecutionException, InterruptedException, TimeoutException {
+        return miniCluster
+                .submitJob(jobGraph)
+                .thenApply(JobSubmissionResult::getJobID)
+                .get(1, TimeUnit.SECONDS);
+    }
+
+    @Test
+    public void testParam() {
+        OnlineLogisticRegression onlineLogisticRegression = new OnlineLogisticRegression();
+        Assert.assertEquals("features", onlineLogisticRegression.getFeaturesCol());
+        Assert.assertEquals("count", onlineLogisticRegression.getBatchStrategy());
+        Assert.assertEquals("label", onlineLogisticRegression.getLabelCol());
+        Assert.assertEquals(0.0, onlineLogisticRegression.getReg(), TOLERANCE);
+        Assert.assertEquals(0.0, onlineLogisticRegression.getElasticNet(), TOLERANCE);
+        Assert.assertEquals(0.1, onlineLogisticRegression.getAlpha(), TOLERANCE);
+        Assert.assertEquals(0.1, onlineLogisticRegression.getBeta(), TOLERANCE);
+        Assert.assertEquals(32, onlineLogisticRegression.getGlobalBatchSize());
+
+        // Every value differs from its default so that each setter is actually exercised.
+        onlineLogisticRegression
+                .setFeaturesCol("test_feature")
+                .setLabelCol("test_label")
+                .setGlobalBatchSize(5)
+                .setReg(0.5)
+                .setElasticNet(0.25)
+                .setAlpha(0.3)
+                .setBeta(0.2);
+
+        Assert.assertEquals("test_feature", onlineLogisticRegression.getFeaturesCol());
+        Assert.assertEquals("test_label", onlineLogisticRegression.getLabelCol());
+        Assert.assertEquals(0.5, onlineLogisticRegression.getReg(), TOLERANCE);
+        Assert.assertEquals(0.25, onlineLogisticRegression.getElasticNet(), TOLERANCE);
+        Assert.assertEquals(0.3, onlineLogisticRegression.getAlpha(), TOLERANCE);
+        Assert.assertEquals(0.2, onlineLogisticRegression.getBeta(), TOLERANCE);
+        Assert.assertEquals(5, onlineLogisticRegression.getGlobalBatchSize());
+
+        OnlineLogisticRegressionModel onlineLogisticRegressionModel =
+                new OnlineLogisticRegressionModel();
+        Assert.assertEquals("features", onlineLogisticRegressionModel.getFeaturesCol());
+        Assert.assertEquals("modelVersion", onlineLogisticRegressionModel.getModelVersionCol());
+        Assert.assertEquals("prediction", onlineLogisticRegressionModel.getPredictionCol());
+        Assert.assertEquals("rawPrediction", onlineLogisticRegressionModel.getRawPredictionCol());
+
+        onlineLogisticRegressionModel
+                .setFeaturesCol("test_feature")
+                .setPredictionCol("pred")
+                .setModelVersionCol("version")
+                .setRawPredictionCol("raw");
+
+        Assert.assertEquals("test_feature", onlineLogisticRegressionModel.getFeaturesCol());
+        Assert.assertEquals("version", onlineLogisticRegressionModel.getModelVersionCol());
+        Assert.assertEquals("pred", onlineLogisticRegressionModel.getPredictionCol());
+        Assert.assertEquals("raw", onlineLogisticRegressionModel.getRawPredictionCol());
+    }
+
+    @Test
+    public void testDenseFitAndPredict() throws Exception {
+        final List<DenseVector> expectedRawInfo1 =
+                Arrays.asList(
+                        new DenseVector(new double[] {0.04481034155642882, 0.9551896584435712}),
+                        new DenseVector(new double[] {0.5353966697318491, 0.4646033302681509}));
+        final List<DenseVector> expectedRawInfo2 =
+                Arrays.asList(
+                        new DenseVector(new double[] {0.013104324065967066, 0.9868956759340329}),
+                        new DenseVector(new double[] {0.5095144380001769, 0.49048556199982307}));
+        OnlineLogisticRegression onlineLogisticRegression =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol("features")
+                        .setLabelCol("label")
+                        .setPredictionCol("prediction")
+                        .setReg(0.2)
+                        .setElasticNet(0.5)
+                        .setGlobalBatchSize(10)
+                        .setInitialModelData(initDenseModel);
+        OnlineLogisticRegressionModel onlineModel =
+                onlineLogisticRegression.fit(onlineTrainDenseTable);
+        transformAndOutputData(onlineModel, false);
+
+        final JobID jobID = submitJob(env.getStreamGraph().getJobGraph());
+        trainDenseSource.addAll(TRAIN_DENSE_ROWS_1);
+
+        waitInitModelDataSetup(jobID);
+        predictAndAssert(expectedRawInfo1, false);
+
+        trainDenseSource.addAll(TRAIN_DENSE_ROWS_2);
+        waitModelDataUpdate(jobID);
+        predictAndAssert(expectedRawInfo2, false);
+    }
+
+    @Test
+    public void testSparseFitAndPredict() throws Exception {
+        final List<DenseVector> expectedRawInfo1 =
+                Arrays.asList(
+                        new DenseVector(new double[] {0.4452309884735286, 0.5547690115264714}),
+                        new DenseVector(new double[] {0.5105551725414953, 0.4894448274585047}));
+        final List<DenseVector> expectedRawInfo2 =
+                Arrays.asList(
+                        new DenseVector(new double[] {0.40310431554310666, 0.5968956844568933}),
+                        new DenseVector(new double[] {0.5249618837373886, 0.4750381162626114}));
+        OnlineLogisticRegression onlineLogisticRegression =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol("features")
+                        .setLabelCol("label")
+                        .setPredictionCol("prediction")
+                        .setReg(0.2)
+                        .setElasticNet(0.5)
+                        .setGlobalBatchSize(9)
+                        .setInitialModelData(initSparseModel);
+        OnlineLogisticRegressionModel onlineModel =
+                onlineLogisticRegression.fit(onlineTrainSparseTable);
+        transformAndOutputData(onlineModel, true);
+        final JobID jobID = submitJob(env.getStreamGraph().getJobGraph());
+
+        trainSparseSource.addAll(TRAIN_SPARSE_ROWS_1);
+        waitInitModelDataSetup(jobID);
+        predictAndAssert(expectedRawInfo1, true);
+
+        trainSparseSource.addAll(TRAIN_SPARSE_ROWS_2);
+        waitModelDataUpdate(jobID);
+        predictAndAssert(expectedRawInfo2, true);
+    }
+
+    @Test
+    public void testFitAndPredictWithWeightCol() throws Exception {
+        final List<DenseVector> expectedRawInfo1 =
+                Arrays.asList(
+                        new DenseVector(new double[] {0.452491993753382, 0.547508006246618}),
+                        new DenseVector(new double[] {0.5069192929506545, 0.4930807070493455}));
+        final List<DenseVector> expectedRawInfo2 =
+                Arrays.asList(
+                        new DenseVector(new double[] {0.41108882806164193, 0.5889111719383581}),
+                        new DenseVector(new double[] {0.5247727600974581, 0.4752272399025419}));
+        OnlineLogisticRegression onlineLogisticRegression =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol("features")
+                        .setLabelCol("label")
+                        .setWeightCol("weight")
+                        .setPredictionCol("prediction")
+                        .setReg(0.2)
+                        .setElasticNet(0.5)
+                        .setGlobalBatchSize(9)
+                        .setInitialModelData(initSparseModel);
+        OnlineLogisticRegressionModel onlineModel =
+                onlineLogisticRegression.fit(onlineTrainSparseTable);
+        transformAndOutputData(onlineModel, true);
+        final JobID jobID = submitJob(env.getStreamGraph().getJobGraph());
+
+        trainSparseSource.addAll(TRAIN_SPARSE_ROWS_1);
+        waitInitModelDataSetup(jobID);
+        predictAndAssert(expectedRawInfo1, true);
+
+        trainSparseSource.addAll(TRAIN_SPARSE_ROWS_2);
+        waitModelDataUpdate(jobID);
+        predictAndAssert(expectedRawInfo2, true);
+    }
+
+    @Test
+    public void testGenerateRandomModelData() throws Exception {
+        Table modelDataTable = LogisticRegressionModelData.generateRandomModelData(tEnv, 2, 2022);
+        DataStream<Row> modelData = tEnv.toDataStream(modelDataTable);
+        Row modelRow = (Row) IteratorUtils.toList(modelData.executeAndCollect()).get(0);
+        Assert.assertEquals(2, ((DenseVector) modelRow.getField(0)).size());
+        Assert.assertEquals(0L, modelRow.getField(1));
+    }
+
+    @Test
+    public void testInitWithLogisticRegression() throws Exception {
+        final List<DenseVector> expectedRawInfo1 =
+                Arrays.asList(
+                        new DenseVector(new double[] {0.037327343811250024, 0.96267265618875}),
+                        new DenseVector(new double[] {0.5684728224189707, 0.4315271775810293}));
+        final List<DenseVector> expectedRawInfo2 =
+                Arrays.asList(
+                        new DenseVector(new double[] {0.007758574555505882, 0.9922414254444941}),
+                        new DenseVector(new double[] {0.5257216567388069, 0.4742783432611931}));
+        LogisticRegression logisticRegression =
+                new LogisticRegression()
+                        .setLabelCol("label")
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction");
+        LogisticRegressionModel model = logisticRegression.fit(offlineTrainDenseTable);
+
+        OnlineLogisticRegression onlineLogisticRegression =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction")
+                        .setReg(0.2)
+                        .setElasticNet(0.5)
+                        .setGlobalBatchSize(10)
+                        .setInitialModelData(model.getModelData()[0]);
+
+        OnlineLogisticRegressionModel onlineModel =
+                onlineLogisticRegression.fit(onlineTrainDenseTable);
+        transformAndOutputData(onlineModel, false);
+        final JobID jobID = submitJob(env.getStreamGraph().getJobGraph());
+        trainDenseSource.addAll(TRAIN_DENSE_ROWS_1);
+        waitInitModelDataSetup(jobID);
+        predictAndAssert(expectedRawInfo1, false);
+
+        trainDenseSource.addAll(TRAIN_DENSE_ROWS_2);
+        waitModelDataUpdate(jobID);
+        predictAndAssert(expectedRawInfo2, false);
+    }
+
+    @Test
+    public void testBatchSizeLessThanParallelism() {
+        try {
+            env.setParallelism(DEFAULT_PARALLELISM);
+            trainDenseSource.addAll(TRAIN_DENSE_ROWS_1);
+            new OnlineLogisticRegression()
+                    .setInitialModelData(initDenseModel)
+                    .setReg(0.2)
+                    .setElasticNet(0.5)
+                    .setGlobalBatchSize(2)
+                    .setLabelCol("label")
+                    .fit(onlineTrainDenseTable);
+            Assert.fail("Expected IllegalStateException");
+        } catch (Exception e) {
+            // The expected exception may be wrapped; inspect the root cause.
+            Throwable exception = e;
+            while (exception.getCause() != null) {
+                exception = exception.getCause();
+            }
+            Assert.assertEquals(IllegalStateException.class, exception.getClass());
+            Assert.assertEquals(
+                    "There are more subtasks in the training process than the number "
+                            + "of elements in each batch. Some subtasks might be idling forever.",
+                    exception.getMessage());
+        }
+    }
+
+    @Test
+    public void testSaveAndReload() throws Exception {
+        final List<DenseVector> expectedRawInfo1 =
+                Arrays.asList(
+                        new DenseVector(new double[] {0.04481034155642882, 0.9551896584435712}),
+                        new DenseVector(new double[] {0.5353966697318491, 0.4646033302681509}));
+        final List<DenseVector> expectedRawInfo2 =
+                Arrays.asList(
+                        new DenseVector(new double[] {0.013104324065967066, 0.9868956759340329}),
+                        new DenseVector(new double[] {0.5095144380001769, 0.49048556199982307}));
+        OnlineLogisticRegression onlineLogisticRegression =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction")
+                        .setReg(0.2)
+                        .setElasticNet(0.5)
+                        .setGlobalBatchSize(10)
+                        .setInitialModelData(initDenseModel);
+
+        String savePath = tempFolder.newFolder().getAbsolutePath();
+        onlineLogisticRegression.save(savePath);
+        // Runs whatever save() added to env; presumably this writes the initial model data.
+        miniCluster.executeJobBlocking(env.getStreamGraph().getJobGraph());
+        OnlineLogisticRegression loadedOnlineLogisticRegression =
+                OnlineLogisticRegression.load(tEnv, savePath);
+        OnlineLogisticRegressionModel onlineModel =
+                loadedOnlineLogisticRegression.fit(onlineTrainDenseTable);
+        String modelSavePath = tempFolder.newFolder().getAbsolutePath();
+        onlineModel.save(modelSavePath);
+        OnlineLogisticRegressionModel loadedOnlineModel =
+                OnlineLogisticRegressionModel.load(tEnv, modelSavePath);
+        loadedOnlineModel.setModelData(onlineModel.getModelData());
+
+        transformAndOutputData(loadedOnlineModel, false);
+        final JobID jobID = submitJob(env.getStreamGraph().getJobGraph());
+
+        trainDenseSource.addAll(TRAIN_DENSE_ROWS_1);
+        waitInitModelDataSetup(jobID);
+        predictAndAssert(expectedRawInfo1, false);
+
+        trainDenseSource.addAll(TRAIN_DENSE_ROWS_2);
+        waitModelDataUpdate(jobID);
+        predictAndAssert(expectedRawInfo2, false);
+    }
+
+    @Test
+    public void testGetModelData() throws Exception {
+        OnlineLogisticRegression onlineLogisticRegression =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction")
+                        .setReg(0.2)
+                        .setElasticNet(0.5)
+                        .setGlobalBatchSize(10)
+                        .setInitialModelData(initDenseModel);
+        OnlineLogisticRegressionModel onlineModel =
+                onlineLogisticRegression.fit(onlineTrainDenseTable);
+        transformAndOutputData(onlineModel, false);
+
+        submitJob(env.getStreamGraph().getJobGraph());
+        trainDenseSource.addAll(TRAIN_DENSE_ROWS_1);
+        LogisticRegressionModelData actualModelData = modelDataSink.poll();
+
+        LogisticRegressionModelData expectedModelData =
+                new LogisticRegressionModelData(
+                        new DenseVector(new double[] {0.2994527071464283, -0.1412541067743284}),
+                        1L);
+        Assert.assertArrayEquals(
+                expectedModelData.coefficient.values,
+                actualModelData.coefficient.values,
+                TOLERANCE);
+        Assert.assertEquals(expectedModelData.modelVersion, actualModelData.modelVersion);
+    }
+
+    @Test
+    public void testSetModelData() throws Exception {
+        LogisticRegressionModelData modelData1 =
+                new LogisticRegressionModelData(new DenseVector(new double[] {0.085, -0.22}), 1L);
+
+        LogisticRegressionModelData modelData2 =
+                new LogisticRegressionModelData(new DenseVector(new double[] {0.075, -0.28}), 2L);
+
+        final List<DenseVector> expectedRawInfo1 =
+                Arrays.asList(
+                        new DenseVector(new double[] {0.6285496932692606, 0.3714503067307394}),
+                        new DenseVector(new double[] {0.7588710471221473, 0.24112895287785274}));
+        final List<DenseVector> expectedRawInfo2 =
+                Arrays.asList(
+                        new DenseVector(new double[] {0.6673003248270917, 0.3326996751729083}),
+                        new DenseVector(new double[] {0.8779865510655934, 0.12201344893440658}));
+
+        InMemorySourceFunction<LogisticRegressionModelData> modelDataSource =
+                new InMemorySourceFunction<>();
+        Table modelDataTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                modelDataSource,
+                                TypeInformation.of(LogisticRegressionModelData.class)));
+
+        OnlineLogisticRegressionModel onlineModel =
+                new OnlineLogisticRegressionModel()
+                        .setModelData(modelDataTable)
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction");
+        transformAndOutputData(onlineModel, false);
+        final JobID jobID = submitJob(env.getStreamGraph().getJobGraph());
+
+        modelDataSource.addAll(modelData1);
+        waitInitModelDataSetup(jobID);
+        predictAndAssert(expectedRawInfo1, false);
+
+        modelDataSource.addAll(modelData2);
+        waitModelDataUpdate(jobID);
+        predictAndAssert(expectedRawInfo2, false);
+    }
+}
diff --git 
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java
index 1a0cfc4..ebf7ab0 100644
--- 
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java
+++ 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java
@@ -47,6 +47,7 @@ import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 
+import static 
org.apache.flink.test.util.TestBaseUtils.compareResultCollections;
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 
@@ -79,17 +80,6 @@ public class MinMaxScalerTest {
                             Vectors.dense(0.5, 0.125),
                             Vectors.dense(0.75, 0.225)));
 
-    /** Note: this comparator imposes orderings that are inconsistent with 
equals. */
-    private static int compare(DenseVector first, DenseVector second) {
-        for (int i = 0; i < first.size(); i++) {
-            int cmp = Double.compare(first.get(i), second.get(i));
-            if (cmp != 0) {
-                return cmp;
-            }
-        }
-        return 0;
-    }
-
     @Before
     public void before() {
         Configuration config = new Configuration();
@@ -113,8 +103,7 @@ public class MinMaxScalerTest {
                                 (MapFunction<Row, DenseVector>)
                                         row -> (DenseVector) 
row.getField(outputCol));
         List<DenseVector> result = 
IteratorUtils.toList(stream.executeAndCollect());
-        result.sort(MinMaxScalerTest::compare);
-        assertEquals(expected, result);
+        compareResultCollections(expected, result, TestUtils::compare);
     }
 
     @Test
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/util/TestUtils.java 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/util/TestUtils.java
index 171aa66..5250588 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/util/TestUtils.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/util/TestUtils.java
@@ -24,6 +24,7 @@ import org.apache.flink.api.common.typeinfo.Types;
 import org.apache.flink.api.java.typeutils.RowTypeInfo;
 import org.apache.flink.ml.api.Stage;
 import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseVector;
 import org.apache.flink.ml.linalg.Vector;
 import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
 import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo;
@@ -33,6 +34,7 @@ import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
 import org.apache.flink.table.types.DataType;
 import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
 
 import org.apache.commons.lang3.ArrayUtils;
 
@@ -114,4 +116,16 @@ public class TestUtils {
                 .map(DataType::getConversionClass)
                 .toArray(Class<?>[]::new);
     }
+
+    /**
+     * Compares two dense vectors lexicographically: entries are compared pairwise from index 0
+     * with {@link Double#compare(double, double)} and the first non-zero result decides.
+     *
+     * <p>Note: this comparator imposes orderings that are inconsistent with equals.
+     *
+     * @param first the first vector to compare.
+     * @param second the second vector to compare; must have the same size as {@code first}.
+     * @return a negative, zero or positive value as {@code first} orders before, equal to or
+     *     after {@code second}.
+     */
+    public static int compare(DenseVector first, DenseVector second) {
+        Preconditions.checkArgument(first.size() == second.size(), "Vector size mismatched.");
+        int result = 0;
+        // Stop scanning as soon as one pair of entries differs.
+        for (int idx = 0; result == 0 && idx < first.size(); idx++) {
+            result = Double.compare(first.get(idx), second.get(idx));
+        }
+        return result;
+    }
 }
diff --git 
a/flink-ml-python/pyflink/ml/lib/classification/logisticregression.py 
b/flink-ml-python/pyflink/ml/lib/classification/logisticregression.py
index 60a49f2..fb3adef 100644
--- a/flink-ml-python/pyflink/ml/lib/classification/logisticregression.py
+++ b/flink-ml-python/pyflink/ml/lib/classification/logisticregression.py
@@ -17,12 +17,14 @@
 
################################################################################
 from abc import ABC
 
+from pyflink.ml.core.param import (ParamValidators, Param, StringParam, 
FloatParam)
 from pyflink.ml.core.wrapper import JavaWithParams
 from pyflink.ml.lib.classification.common import (JavaClassificationModel,
                                                   JavaClassificationEstimator)
 from pyflink.ml.lib.param import (HasWeightCol, HasMaxIter, HasReg, 
HasLearningRate,
                                   HasGlobalBatchSize, HasTol, HasMultiClass, 
HasFeaturesCol,
-                                  HasPredictionCol, HasRawPredictionCol, 
HasLabelCol, HasElasticNet)
+                                  HasPredictionCol, HasRawPredictionCol, 
HasLabelCol,
+                                  HasBatchStrategy, HasElasticNet)
 
 
 class _LogisticRegressionModelParams(
@@ -98,3 +100,120 @@ class LogisticRegression(JavaClassificationEstimator, 
_LogisticRegressionParams)
     @classmethod
     def _java_estimator_class_name(cls) -> str:
         return "LogisticRegression"
+
+
+class _OnlineLogisticRegressionModelParams(
+    JavaWithParams,
+    HasFeaturesCol,
+    HasPredictionCol,
+    HasRawPredictionCol,
+    ABC
+):
+    """
+    Params for :class:`OnlineLogisticRegressionModel`.
+    """
+    MODEL_VERSION_COL: Param[str] = StringParam(
+        "model_version_col",
+        "Model version column name.",
+        "model_version",
+        ParamValidators.not_null())
+
+    def __init__(self, java_params):
+        super(_OnlineLogisticRegressionModelParams, self).__init__(java_params)
+
+    def set_model_version_col(self, value: str):
+        return self.set(self.MODEL_VERSION_COL, value)
+
+    def get_model_version_col(self) -> str:
+        return self.get(self.MODEL_VERSION_COL)
+
+
+class _OnlineLogisticRegressionParams(
+    _OnlineLogisticRegressionModelParams,
+    HasBatchStrategy,
+    HasLabelCol,
+    HasWeightCol,
+    HasReg,
+    HasElasticNet,
+    HasGlobalBatchSize
+):
+    """
+    Params for :class:`OnlineLogisticRegression`.
+    """
+
+    ALPHA: Param[float] = FloatParam(
+        "alpha",
+        "The alpha parameter of ftrl.",
+        0.1,
+        ParamValidators.gt(0))
+
+    BETA: Param[float] = FloatParam(
+        "beta",
+        "The beta parameter of ftrl.",
+        0.1,
+        ParamValidators.gt(0))
+
+    def __init__(self, java_params):
+        super(_OnlineLogisticRegressionParams, self).__init__(java_params)
+
+    def set_alpha(self, alpha: float):
+        return self.set(self.ALPHA, alpha)
+
+    def get_alpha(self) -> float:
+        return self.get(self.ALPHA)
+
+    @property
+    def alpha(self) -> float:
+        return self.get_alpha()
+
+    def set_beta(self, beta: float):
+        return self.set(self.BETA, beta)
+
+    def get_beta(self) -> float:
+        return self.get(self.BETA)
+
+    @property
+    def beta(self) -> float:
+        return self.get_beta()
+
+
+class OnlineLogisticRegressionModel(JavaClassificationModel,
+                                    _OnlineLogisticRegressionModelParams):
+    """
+    A Model which classifies data using the model data computed by
+    :class:`OnlineLogisticRegression`.
+    """
+
+    def __init__(self, java_model=None):
+        super(OnlineLogisticRegressionModel, self).__init__(java_model)
+
+    @classmethod
+    def _java_model_package_name(cls) -> str:
+        return "logisticregression"
+
+    @classmethod
+    def _java_model_class_name(cls) -> str:
+        return "OnlineLogisticRegressionModel"
+
+
+class OnlineLogisticRegression(JavaClassificationEstimator, 
_OnlineLogisticRegressionParams):
+    """
+    An Estimator which implements the online logistic regression algorithm.
+
+    See H. Brendan McMahan et al., Ad click prediction: a view from the 
trenches.
+    """
+
+    def __init__(self):
+        super(OnlineLogisticRegression, self).__init__()
+
+    @classmethod
+    def _create_model(cls, java_model) -> OnlineLogisticRegressionModel:
+        return OnlineLogisticRegressionModel(java_model)
+
+    @classmethod
+    def _java_estimator_package_name(cls) -> str:
+        return "logisticregression"
+
+    @classmethod
+    def _java_estimator_class_name(cls) -> str:
+        return "OnlineLogisticRegression"
diff --git 
a/flink-ml-python/pyflink/ml/lib/classification/tests/test_logisticregression.py
 
b/flink-ml-python/pyflink/ml/lib/classification/tests/test_logisticregression.py
index 2f2f8bb..e460155 100644
--- 
a/flink-ml-python/pyflink/ml/lib/classification/tests/test_logisticregression.py
+++ 
b/flink-ml-python/pyflink/ml/lib/classification/tests/test_logisticregression.py
@@ -22,7 +22,7 @@ from pyflink.table import Table
 
 from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo, DenseVector
 from pyflink.ml.lib.classification.logisticregression import 
LogisticRegression, \
-    LogisticRegressionModel
+    LogisticRegressionModel, OnlineLogisticRegression
 from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
 
 
@@ -135,7 +135,9 @@ class LogisticRegressionTest(PyFlinkMLTestCase):
         regression.save(path)
         regression = LogisticRegression.load(self.t_env, path)  # type: 
LogisticRegression
         model = regression.fit(self.binomial_data_table)
-        
self.assertEqual(model.get_model_data()[0].get_schema().get_field_names(), 
['coefficient'])
+        self.assertEqual(
+            model.get_model_data()[0].get_schema().get_field_names(),
+            ['coefficient', 'modelVersion'])
         output = model.transform(self.binomial_data_table)[0]
         field_names = output.get_schema().get_field_names()
         self.verify_predict_result(
@@ -183,3 +185,38 @@ class LogisticRegressionTest(PyFlinkMLTestCase):
                 else:
                     self.assertAlmostEqual(1, prediction, delta=1e-7)
                     self.assertTrue(raw_prediction.get(0) < 0.5)
+
+
+class OnlineLogisticRegressionTest(PyFlinkMLTestCase):
+
+    def setUp(self):
+        super(OnlineLogisticRegressionTest, self).setUp()
+
+    def test_param(self):
+        online_logistic_regression = OnlineLogisticRegression()
+        self.assertEqual("features", online_logistic_regression.features_col)
+        self.assertEqual("count", online_logistic_regression.batch_strategy)
+        self.assertEqual("label", online_logistic_regression.label_col)
+        self.assertEqual(None, online_logistic_regression.weight_col)
+        self.assertEqual(0.0, online_logistic_regression.reg)
+        self.assertEqual(0.0, online_logistic_regression.elastic_net)
+        self.assertEqual(0.1, online_logistic_regression.alpha)
+        self.assertEqual(0.1, online_logistic_regression.beta)
+        self.assertEqual(32, online_logistic_regression.global_batch_size)
+
+        online_logistic_regression \
+            .set_features_col("test_feature") \
+            .set_label_col("test_label") \
+            .set_global_batch_size(5) \
+            .set_reg(0.5) \
+            .set_elastic_net(0.25) \
+            .set_alpha(0.1) \
+            .set_beta(0.2)
+
+        self.assertEqual("test_feature", 
online_logistic_regression.features_col)
+        self.assertEqual("test_label", online_logistic_regression.label_col)
+        self.assertEqual(0.5, online_logistic_regression.reg)
+        self.assertEqual(0.25, online_logistic_regression.elastic_net)
+        self.assertEqual(0.1, online_logistic_regression.alpha)
+        self.assertEqual(0.2, online_logistic_regression.beta)
+        self.assertEqual(5, online_logistic_regression.global_batch_size)

Reply via email to