lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r834235458



##########
File path: flink-ml-core/src/main/java/org/apache/flink/ml/api/Stage.java
##########
@@ -40,6 +40,6 @@
  */
 @PublicEvolving
 public interface Stage<T extends Stage<T>> extends WithParams<T>, Serializable 
{
-    /** Saves this stage to the given path. */
+    /** Saves the metadata and bounded model data of this stage to the given 
path. */

Review comment:
       Could we replace `bounded model data` with `bounded data` so that the 
description is a bit more general?

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
##########
@@ -0,0 +1,464 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.RestOptions;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeans;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.InMemorySinkFunction;
+import org.apache.flink.ml.util.InMemorySourceFunction;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
+import org.apache.flink.runtime.testutils.InMemoryReporter;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static 
org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */
+public class OnlineKMeansTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private static final int defaultParallelism = 4;
+    private static final int numTaskManagers = 2;
+    private static final int numSlotsPerTaskManager = 2;
+
+    private int currentModelDataVersion;
+
+    private InMemorySourceFunction<DenseVector> trainSource;
+    private InMemorySourceFunction<DenseVector> predictSource;
+    private InMemorySinkFunction<Row> outputSink;
+    private InMemorySinkFunction<KMeansModelData> modelDataSink;
+
+    private InMemoryReporter reporter;
+    private MiniCluster miniCluster;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    private Table offlineTrainTable;
+    private Table trainTable;
+    private Table predictTable;
+
+    @Before
+    public void before() throws Exception {
+        currentModelDataVersion = 0;
+
+        trainSource = new InMemorySourceFunction<>();
+        predictSource = new InMemorySourceFunction<>();
+        outputSink = new InMemorySinkFunction<>();
+        modelDataSink = new InMemorySinkFunction<>();
+
+        Configuration config = new Configuration();
+        config.set(RestOptions.BIND_PORT, "18081-19091");
+        
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, 
true);
+        reporter = InMemoryReporter.createWithRetainedMetrics();
+        reporter.addToConfiguration(config);
+
+        miniCluster =
+                new MiniCluster(
+                        new MiniClusterConfiguration.Builder()
+                                .setConfiguration(config)
+                                .setNumTaskManagers(numTaskManagers)
+                                
.setNumSlotsPerTaskManager(numSlotsPerTaskManager)
+                                .build());
+        miniCluster.start();

Review comment:
       According to `AbstractTestBase`'s Javadoc, it could save a significant 
amount of time to reuse the same mini-cluster for tests.
   
   Could we reuse the same mini-cluster across tests, using an approach similar 
to that of `AbstractTestBase`?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java
##########
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.co.CoProcessFunction;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * OnlineKMeansModel can be regarded as an advanced {@link KMeansModel} 
operator which can update
+ * model data in a streaming format, using the model data provided by {@link 
OnlineKMeans}.
+ */
+public class OnlineKMeansModel
+        implements Model<OnlineKMeansModel>, 
KMeansModelParams<OnlineKMeansModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OnlineKMeansModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = 
TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), 
Types.INT),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), 
getPredictionCol()));
+
+        DataStream<Row> predictionResult =
+                KMeansModelData.getModelDataStream(modelDataTable)
+                        .broadcast()
+                        .connect(tEnv.toDataStream(inputs[0]))
+                        .process(
+                                new PredictLabelFunction(
+                                        getFeaturesCol(),
+                                        
DistanceMeasure.getInstance(getDistanceMeasure())),
+                                outputTypeInfo);
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility function used for prediction. */
+    private static class PredictLabelFunction extends 
CoProcessFunction<KMeansModelData, Row, Row> {
+        private final String featuresCol;
+
+        private final DistanceMeasure distanceMeasure;
+
+        private DenseVector[] centroids;
+
+        // TODO: replace this with a complete solution of reading first model 
data from unbounded
+        // model data stream before processing the first predict data.
+        private final List<Row> bufferedPoints = new ArrayList<>();
+
+        // TODO: replace this simple implementation of model data version with 
the formal API to
+        // track model version after its design is settled.
+        private int modelDataVersion;

Review comment:
       After the operator is opened and before the operator processes any input 
value, the value of this variable would already be exposed. It would be better 
to define its initial value explicitly.
   
   How about we set the default value of this variable to be 0, which could 
mean `the model version is unknown`? And the valid model versions should start 
from 1.

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
##########
@@ -0,0 +1,464 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.RestOptions;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeans;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.InMemorySinkFunction;
+import org.apache.flink.ml.util.InMemorySourceFunction;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
+import org.apache.flink.runtime.testutils.InMemoryReporter;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static 
org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */
+public class OnlineKMeansTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private static final int defaultParallelism = 4;
+    private static final int numTaskManagers = 2;
+    private static final int numSlotsPerTaskManager = 2;
+
+    private int currentModelDataVersion;
+
+    private InMemorySourceFunction<DenseVector> trainSource;
+    private InMemorySourceFunction<DenseVector> predictSource;
+    private InMemorySinkFunction<Row> outputSink;
+    private InMemorySinkFunction<KMeansModelData> modelDataSink;
+
+    private InMemoryReporter reporter;
+    private MiniCluster miniCluster;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    private Table offlineTrainTable;
+    private Table trainTable;

Review comment:
       Given that we already have `offlineTrainTable`, how about renaming this 
table as `onlineTrainTable`?

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/util/InMemorySourceFunction.java
##########
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+
+import java.util.Arrays;
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+/** A {@link SourceFunction} implementation that can directly receive records 
from tests. */
+@SuppressWarnings({"unchecked", "rawtypes"})
+public class InMemorySourceFunction<T> extends RichSourceFunction<T> {
+    private static final Map<UUID, BlockingQueue> queueMap = new 
ConcurrentHashMap<>();
+    private final UUID id;
+    private BlockingQueue<T> queue;
+    private boolean isRunning = true;
+
+    public InMemorySourceFunction() {
+        id = UUID.randomUUID();
+        queue = new LinkedBlockingQueue();
+        queueMap.put(id, queue);
+    }
+
+    @Override
+    public void open(Configuration parameters) throws Exception {
+        super.open(parameters);
+        queue = queueMap.get(id);
+    }
+
+    @Override
+    public void close() throws Exception {
+        super.close();
+        queueMap.remove(id);
+    }
+
+    @Override
+    public void run(SourceContext<T> context) throws InterruptedException {
+        while (isRunning) {
+            context.collect(queue.poll(1, TimeUnit.MINUTES));

Review comment:
       The current approach means that, after a graceful shutdown is requested, 
the source operator might need to wait up to 1 minute before this operator can 
actually be shut down.
   
   How about we use `BlockingQueue<Optional<T>> queue`, let `cancel()` insert 
`Optional.empty()` into the queue, and have `run()` ignore empty values?
   

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java
##########
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.co.CoProcessFunction;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * OnlineKMeansModel can be regarded as an advanced {@link KMeansModel} 
operator which can update
+ * model data in a streaming format, using the model data provided by {@link 
OnlineKMeans}.
+ */
+public class OnlineKMeansModel
+        implements Model<OnlineKMeansModel>, 
KMeansModelParams<OnlineKMeansModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OnlineKMeansModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = 
TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), 
Types.INT),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), 
getPredictionCol()));
+
+        DataStream<Row> predictionResult =
+                KMeansModelData.getModelDataStream(modelDataTable)
+                        .broadcast()
+                        .connect(tEnv.toDataStream(inputs[0]))
+                        .process(
+                                new PredictLabelFunction(
+                                        getFeaturesCol(),
+                                        
DistanceMeasure.getInstance(getDistanceMeasure())),
+                                outputTypeInfo);
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility function used for prediction. */
+    private static class PredictLabelFunction extends 
CoProcessFunction<KMeansModelData, Row, Row> {
+        private final String featuresCol;
+
+        private final DistanceMeasure distanceMeasure;
+
+        private DenseVector[] centroids;
+
+        // TODO: replace this with a complete solution of reading first model 
data from unbounded
+        // model data stream before processing the first predict data.
+        private final List<Row> bufferedPoints = new ArrayList<>();
+
+        // TODO: replace this simple implementation of model data version with 
the formal API to
+        // track model version after its design is settled.
+        private int modelDataVersion;
+
+        public PredictLabelFunction(String featuresCol, DistanceMeasure 
distanceMeasure) {
+            this.featuresCol = featuresCol;
+            this.distanceMeasure = distanceMeasure;
+        }
+
+        @Override
+        public void open(Configuration parameters) throws Exception {
+            super.open(parameters);
+
+            getRuntimeContext()
+                    .getMetricGroup()
+                    .gauge(
+                            "modelDataVersion",

Review comment:
       Could we put the metric name in a static final variable? Then 
`OnlineKMeansTest` could use the variable instead of copying the string.

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
##########
@@ -0,0 +1,464 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.RestOptions;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeans;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.InMemorySinkFunction;
+import org.apache.flink.ml.util.InMemorySourceFunction;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
+import org.apache.flink.runtime.testutils.InMemoryReporter;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static 
org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */
+public class OnlineKMeansTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private static final int defaultParallelism = 4;
+    private static final int numTaskManagers = 2;
+    private static final int numSlotsPerTaskManager = 2;
+
+    private int currentModelDataVersion;
+
+    private InMemorySourceFunction<DenseVector> trainSource;
+    private InMemorySourceFunction<DenseVector> predictSource;
+    private InMemorySinkFunction<Row> outputSink;
+    private InMemorySinkFunction<KMeansModelData> modelDataSink;
+
+    private InMemoryReporter reporter;
+    private MiniCluster miniCluster;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    private Table offlineTrainTable;
+    private Table trainTable;
+    private Table predictTable;
+
+    @Before
+    public void before() throws Exception {
+        currentModelDataVersion = 0;
+
+        trainSource = new InMemorySourceFunction<>();
+        predictSource = new InMemorySourceFunction<>();
+        outputSink = new InMemorySinkFunction<>();
+        modelDataSink = new InMemorySinkFunction<>();
+
+        Configuration config = new Configuration();
+        config.set(RestOptions.BIND_PORT, "18081-19091");
+        
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, 
true);
+        reporter = InMemoryReporter.createWithRetainedMetrics();
+        reporter.addToConfiguration(config);
+
+        miniCluster =
+                new MiniCluster(
+                        new MiniClusterConfiguration.Builder()
+                                .setConfiguration(config)
+                                .setNumTaskManagers(numTaskManagers)
+                                
.setNumSlotsPerTaskManager(numSlotsPerTaskManager)
+                                .build());
+        miniCluster.start();
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(defaultParallelism);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Schema schema = Schema.newBuilder().column("f0", 
DataTypes.of(DenseVector.class)).build();
+
+        offlineTrainTable =
+                tEnv.fromDataStream(env.fromElements(trainData1), 
schema).as("features");
+        trainTable =
+                tEnv.fromDataStream(
+                                env.addSource(trainSource, 
DenseVectorTypeInfo.INSTANCE), schema)
+                        .as("features");
+        predictTable =
+                tEnv.fromDataStream(
+                                env.addSource(predictSource, 
DenseVectorTypeInfo.INSTANCE), schema)
+                        .as("features");
+    }
+
+    @After
+    public void after() throws Exception {
+        miniCluster.close();
+    }
+
+    /**
+     * Performs transform() on the provided model with predictTable, and adds 
sinks for
+     * OnlineKMeansModel's transform output and model data.
+     */
+    private void transformAndOutputData(OnlineKMeansModel onlineModel) {
+        Table outputTable = onlineModel.transform(predictTable)[0];
+        tEnv.toDataStream(outputTable).addSink(outputSink);
+
+        Table modelDataTable = onlineModel.getModelData()[0];
+        
KMeansModelData.getModelDataStream(modelDataTable).addSink(modelDataSink);
+    }
+
+    /** Blocks the thread until Model has set up init model data. */
+    private void waitInitModelDataSetup() throws InterruptedException {
+        while (reporter.findMetrics("modelDataVersion").size() < 
defaultParallelism) {
+            Thread.sleep(100);
+        }
+        waitModelDataUpdate();
+    }
+
+    /** Blocks the thread until the Model has received the next 
model-data-update event. */
+    @SuppressWarnings("unchecked")
+    private void waitModelDataUpdate() throws InterruptedException {
+        do {
+            int tmpModelDataVersion =
+                    reporter.findMetrics("modelDataVersion").values().stream()
+                            .map(x -> Integer.parseInt(((Gauge<String>) 
x).getValue()))
+                            .min(Integer::compareTo)
+                            .orElse(currentModelDataVersion);

Review comment:
       Given that `waitModelDataUpdate()` is only called after or within 
`waitInitModelDataSetup()`, the `orElse(...)` will never be needed, right?
   
   Would it be simpler to replace `orElse(...)` with `get()`?

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/util/InMemorySourceFunction.java
##########
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+
+import java.util.Arrays;
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+/** A {@link SourceFunction} implementation that can directly receive records 
from tests. */
+@SuppressWarnings({"unchecked", "rawtypes"})
+public class InMemorySourceFunction<T> extends RichSourceFunction<T> {
+    private static final Map<UUID, BlockingQueue> queueMap = new 
ConcurrentHashMap<>();
+    private final UUID id;
+    private BlockingQueue<T> queue;
+    private boolean isRunning = true;

Review comment:
       Should `isRunning` be `volatile`?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import 
org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the function of {@link KMeans}, supporting to train a 
K-Means model
+ * continuously according to an unbounded stream of train data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, 
generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current 
batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between 
the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of 
points assigned to
+ * them. The weight of the original centroids is also the number of points, 
but additionally
+ * multiplying with the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated 
thus far. If decay
+ * factor is 1, all batches are weighted equally. If decay factor is 0, new 
centroids are determined
+ * entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, 
OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) 
tEnv).execEnv();

Review comment:
       Could this line (the `env` declaration) be removed? `env` does not 
appear to be used anywhere else in `fit(...)`.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import 
org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the function of {@link KMeans}, supporting to train a 
K-Means model
+ * continuously according to an unbounded stream of train data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, 
generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current 
batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between 
the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of 
points assigned to
+ * them. The weight of the original centroids is also the number of points, 
but additionally
+ * multiplying with the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated 
thus far. If decay
+ * factor is 1, all batches are weighted equally. If decay factor is 0, new 
centroids are determined
+ * entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, 
OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) 
tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new 
FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> finalModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), 
DataStreamList.of(points), body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelData);
+        OnlineKMeansModel model = new 
OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if exists) to the 
given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String 
path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new 
KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = 
tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int 
batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, 
decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), 
DataStreamList.of(modelData));
+        }
+    }
+
+    private static class ModelDataGlobalReducer implements 
ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, 
KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * 
newWeights.values[i])
+                                    / Math.max(weights.values[i] + 
newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+
+    private static class ModelDataLocalUpdater extends 
AbstractStreamOperator<KMeansModelData>
+            implements TwoInputStreamOperator<DenseVector[], KMeansModelData, 
KMeansModelData> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private ListState<DenseVector[]> localBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, double 
decayFactor) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws 
Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    
ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new 
ListStateDescriptor<>("localBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", 
KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) 
throws Exception {
+            localBatchState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<KMeansModelData> 
modelDataRecord)
+                throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            KMeansModelData modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, 
"modelData")
+                            .orElseThrow((Supplier<Exception>) 
NullPointerException::new);
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector weights = modelData.weights;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = 
IteratorUtils.toList(localBatchState.get().iterator());
+            DenseVector[] points = pointsList.remove(0);
+            localBatchState.update(pointsList);
+
+            int dim = centroids[0].size();
+            int k = centroids.length;
+            int parallelism = 
getRuntimeContext().getNumberOfParallelSubtasks();
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int[] counts = new int[k];
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, 
distanceMeasure);
+                counts[closestCentroidId]++;
+                for (int j = 0; j < dim; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];
+                }
+            }
+
+            // Considers weight and decay factor when updating centroids.
+            BLAS.scal(decayFactor / parallelism, weights);
+            for (int i = 0; i < k; i++) {
+                if (counts[i] == 0) {
+                    continue;
+                }
+
+                DenseVector centroid = centroids[i];
+                weights.values[i] = weights.values[i] + counts[i];
+                double lambda = counts[i] / weights.values[i];
+
+                BLAS.scal(1.0 - lambda, centroid);
+                BLAS.axpy(lambda / counts[i], sums[i], centroid);
+            }
+
+            output.collect(new StreamRecord<>(new KMeansModelData(centroids, 
weights)));
+        }
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, 
DenseVector> {
+        private final String featuresCol;
+
+        private FeaturesExtractor(String featuresCol) {
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public DenseVector map(Row row) throws Exception {
+            return (DenseVector) row.getField(featuresCol);
+        }
+    }
+
+    // An operator that splits a global batch into evenly-sized local batches, 
and distributes them
+    // to downstream operator.
+    private static class GlobalBatchSplitter
+            implements FlatMapFunction<DenseVector[], DenseVector[]> {
+        private final int downStreamParallelism;
+
+        private GlobalBatchSplitter(int downStreamParallelism) {
+            this.downStreamParallelism = downStreamParallelism;
+        }
+
+        @Override
+        public void flatMap(DenseVector[] values, Collector<DenseVector[]> 
collector) {
+            // Calculate the batch sizes to be distributed on each subtask.
+            List<Integer> sizes = new ArrayList<>();
+            for (int i = 0; i < downStreamParallelism; i++) {
+                int start = i * values.length / downStreamParallelism;
+                int end = (i + 1) * values.length / downStreamParallelism;
+                sizes.add(end - start);
+            }
+
+            int offset = 0;
+            for (Integer size : sizes) {
+                collector.collect(Arrays.copyOfRange(values, offset, offset + 
size));
+                offset += size;
+            }
+        }
+    }
+
+    private static class GlobalBatchCreator
+            implements AllWindowFunction<DenseVector, DenseVector[], 
GlobalWindow> {
+        @Override
+        public void apply(
+                GlobalWindow timeWindow,
+                Iterable<DenseVector> iterable,
+                Collector<DenseVector[]> collector) {
+            List<DenseVector> points = 
IteratorUtils.toList(iterable.iterator());
+            collector.collect(points.toArray(new DenseVector[0]));
+        }
+    }
+
+    /**
+     * Sets the initial model data of the online training process with the 
provided model data
+     * table.
+     *
+     * <p>This would override the effect of previously invoked {@link 
#setInitialModelData(Table)}
+     * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, 
double)}.
+     */
+    public OnlineKMeans setInitialModelData(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        return this;
+    }
+
+    /**
+     * Sets the initial model data of the online training process with 
randomly created centroids.
+     *
+     * <p>This would override the effect of previously invoked {@link 
#setInitialModelData(Table)}
+     * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, 
double)}.
+     *
+     * @param tEnv The stream table environment to create the centroids in.
+     * @param dim The dimension of the centroids to create.
+     * @param k The number of centroids to create.
+     * @param weight The weight of the centroids to create.
+     */
+    public OnlineKMeans setRandomCentroids(
+            StreamTableEnvironment tEnv, int dim, int k, double weight) {
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) 
tEnv).execEnv();

Review comment:
       While I believe we will have a public API on Table to expose its env, I am 
not sure we will also have a public API on `StreamTableEnvironmentImpl` to expose 
its StreamExecutionEnvironment. 
   
   It would be better to reduce the use of internal APIs if possible. How about 
we just pass `StreamExecutionEnvironment` as the method's parameter? This would 
also be more consistent with `load(...)`.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java
##########
@@ -47,8 +47,16 @@
 
     public DenseVector[] centroids;
 
+    public DenseVector weights;
+
+    public KMeansModelData(DenseVector[] centroids, DenseVector weights) {
+        this.centroids = centroids;
+        this.weights = weights;
+    }
+
     public KMeansModelData(DenseVector[] centroids) {
         this.centroids = centroids;
+        this.weights = new DenseVector(centroids.length);

Review comment:
       It seems a bit weird for the weights to be a vector of 0s.
   
   Would it be simpler to remove this constructor and explicitly specify 
weights in tests? The weights can be specified easily, e.g. by `Vectors.dense(1, 1)`.
   

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import 
org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the function of {@link KMeans}, supporting to train a 
K-Means model
+ * continuously according to an unbounded stream of train data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, 
generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current 
batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between 
the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of 
points assigned to
+ * them. The weight of the original centroids is also the number of points, 
but additionally
+ * multiplying with the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated 
thus far. If decay
+ * factor is 1, all batches are weighted equally. If decay factor is 0, new 
centroids are determined
+ * entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, 
OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) 
tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new 
FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> finalModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), 
DataStreamList.of(points), body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelData);
+        OnlineKMeansModel model = new 
OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if exists) to the 
given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String 
path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new 
KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = 
tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int 
batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, 
decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), 
DataStreamList.of(modelData));
+        }
+    }
+
+    private static class ModelDataGlobalReducer implements 
ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, 
KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * 
newWeights.values[i])
+                                    / Math.max(weights.values[i] + 
newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+
+    private static class ModelDataLocalUpdater extends 
AbstractStreamOperator<KMeansModelData>
+            implements TwoInputStreamOperator<DenseVector[], KMeansModelData, 
KMeansModelData> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private ListState<DenseVector[]> localBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, double 
decayFactor) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws 
Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    
ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new 
ListStateDescriptor<>("localBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", 
KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) 
throws Exception {
+            localBatchState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<KMeansModelData> 
modelDataRecord)
+                throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            KMeansModelData modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, 
"modelData")
+                            .orElseThrow((Supplier<Exception>) 
NullPointerException::new);
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector weights = modelData.weights;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = 
IteratorUtils.toList(localBatchState.get().iterator());
+            DenseVector[] points = pointsList.remove(0);
+            localBatchState.update(pointsList);
+
+            int dim = centroids[0].size();
+            int k = centroids.length;
+            int parallelism = 
getRuntimeContext().getNumberOfParallelSubtasks();
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int[] counts = new int[k];
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, 
distanceMeasure);
+                counts[closestCentroidId]++;
+                for (int j = 0; j < dim; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];
+                }
+            }
+
+            // Considers weight and decay factor when updating centroids.
+            BLAS.scal(decayFactor / parallelism, weights);
+            for (int i = 0; i < k; i++) {
+                if (counts[i] == 0) {
+                    continue;
+                }
+
+                DenseVector centroid = centroids[i];
+                weights.values[i] = weights.values[i] + counts[i];
+                double lambda = counts[i] / weights.values[i];
+
+                BLAS.scal(1.0 - lambda, centroid);
+                BLAS.axpy(lambda / counts[i], sums[i], centroid);
+            }
+
+            output.collect(new StreamRecord<>(new KMeansModelData(centroids, 
weights)));
+        }
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, 
DenseVector> {
+        private final String featuresCol;
+
+        private FeaturesExtractor(String featuresCol) {
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public DenseVector map(Row row) throws Exception {
+            return (DenseVector) row.getField(featuresCol);
+        }
+    }
+
+    // An operator that splits a global batch into evenly-sized local batches, 
and distributes them
+    // to downstream operator.
+    private static class GlobalBatchSplitter
+            implements FlatMapFunction<DenseVector[], DenseVector[]> {
+        private final int downStreamParallelism;
+
+        private GlobalBatchSplitter(int downStreamParallelism) {
+            this.downStreamParallelism = downStreamParallelism;
+        }
+
+        @Override
+        public void flatMap(DenseVector[] values, Collector<DenseVector[]> 
collector) {
+            // Calculate the batch sizes to be distributed on each subtask.
+            List<Integer> sizes = new ArrayList<>();
+            for (int i = 0; i < downStreamParallelism; i++) {
+                int start = i * values.length / downStreamParallelism;
+                int end = (i + 1) * values.length / downStreamParallelism;
+                sizes.add(end - start);
+            }
+
+            int offset = 0;
+            for (Integer size : sizes) {
+                collector.collect(Arrays.copyOfRange(values, offset, offset + 
size));
+                offset += size;
+            }
+        }
+    }
+
+    private static class GlobalBatchCreator
+            implements AllWindowFunction<DenseVector, DenseVector[], 
GlobalWindow> {
+        @Override
+        public void apply(
+                GlobalWindow timeWindow,
+                Iterable<DenseVector> iterable,
+                Collector<DenseVector[]> collector) {
+            List<DenseVector> points = 
IteratorUtils.toList(iterable.iterator());
+            collector.collect(points.toArray(new DenseVector[0]));
+        }
+    }
+
+    /**
+     * Sets the initial model data of the online training process with the 
provided model data
+     * table.
+     *
+     * <p>This would override the effect of previously invoked {@link 
#setInitialModelData(Table)}
+     * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, 
double)}.
+     */
+    public OnlineKMeans setInitialModelData(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        return this;
+    }
+
+    /**
+     * Sets the initial model data of the online training process with 
randomly created centroids.
+     *
+     * <p>This would override the effect of previously invoked {@link 
#setInitialModelData(Table)}
+     * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, 
double)}.
+     *
+     * @param tEnv The stream table environment to create the centroids in.
+     * @param dim The dimension of the centroids to create.
+     * @param k The number of centroids to create.
+     * @param weight The weight of the centroids to create.
+     */
+    public OnlineKMeans setRandomCentroids(
+            StreamTableEnvironment tEnv, int dim, int k, double weight) {

Review comment:
       Would it be better to still have `k` as a parameter of 
`OnlineKMeans`, so that it is as similar to `KMeans` as possible?
   
   The number of centroids is a key aspect of the kmeans algorithm (including 
its online version) and users would want to know what its value is. It would be 
harder for users to get this information if we don't specify it as 
OnlineKMeans's parameter.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import 
org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the function of {@link KMeans}, supporting to train a 
K-Means model
+ * continuously according to an unbounded stream of train data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, 
generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current 
batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between 
the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of 
points assigned to
+ * them. The weight of the original centroids is also the number of points, 
but additionally
+ * multiplying with the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated 
thus far. If decay
+ * factor is 1, all batches are weighted equally. If decay factor is 0, new 
centroids are determined
+ * entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, 
OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) 
tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new 
FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> finalModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), 
DataStreamList.of(points), body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelData);
+        OnlineKMeansModel model = new 
OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if exists) to the 
given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String 
path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new 
KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = 
tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int 
batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, 
decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), 
DataStreamList.of(modelData));
+        }
+    }
+
+    private static class ModelDataGlobalReducer implements 
ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, 
KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * 
newWeights.values[i])
+                                    / Math.max(weights.values[i] + 
newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+
+    private static class ModelDataLocalUpdater extends 
AbstractStreamOperator<KMeansModelData>

Review comment:
       Could we add Javadoc explaining the algorithm used in this operator? 
The same applies to `ModelDataGlobalReducer`.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import 
org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} by supporting
+ * continuous training of a K-Means model on an unbounded stream of training
+ * data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, 
generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current 
batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between 
the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of 
points assigned to
+ * them. The weight of the original centroids is also the number of points,
+ * additionally multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated
+ * thus far. If the decay factor is 1, all batches are weighted equally. If the
+ * decay factor is 0, new centroids are determined entirely by recent data.
+ * Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, 
OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+    /** Creates an instance and fills {@code paramMap} via ParamUtils.initializeMapWithDefaultValues. */
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+    /**
+     * Builds the unbounded training job and returns a model backed by its model data stream.
+     *
+     * <p>The single input table is converted to a stream and mapped to vectors by
+     * {@code FeaturesExtractor} using the features column, the stream derived from
+     * {@code initModelDataTable} provides the initial model data, and
+     * {@code Iterations.iterateUnboundedStreams} runs {@code OnlineKMeansIterationBody} over
+     * both. The first output stream of the iteration becomes the model data table of the
+     * returned model, which then receives this estimator's parameters through
+     * {@code ReadWriteUtils.updateExistingParams}.
+     *
+     * @param inputs exactly one table; any other length fails {@code Preconditions.checkArgument}.
+     * @return an {@link OnlineKMeansModel} whose model data table wraps the iteration output.
+     */
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) 
tEnv).execEnv(); // NOTE(review): env is not used in the rest of this method
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new 
FeaturesExtractor(getFeaturesCol()));
+        // NOTE(review): initModelDataTable is not null-checked here; confirm it is set before fit().
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> finalModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), 
DataStreamList.of(points), body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelData);
+        OnlineKMeansModel model = new 
OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /**
+     * Saves the metadata and, if {@code initModelDataTable} is set, the bounded initial model
+     * data to the given path.
+     *
+     * <p>The model data is written first with {@code KMeansModelData.ModelDataEncoder}; the
+     * metadata is written afterwards with {@code ReadWriteUtils.saveMetadata}.
+     *
+     * @param path the target path handed to both {@code ReadWriteUtils} calls.
+     * @throws IOException declared by the signature; presumably raised when writing fails.
+     */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+    /**
+     * Loads an {@link OnlineKMeans} from the given path.
+     *
+     * <p>The stage is restored with {@code ReadWriteUtils.loadStageParam}. If the data path
+     * returned by {@code ReadWriteUtils.getDataPath} exists, the model data found under
+     * {@code path} is loaded and stored as the initial model data table; otherwise
+     * {@code initModelDataTable} is left as restored.
+     *
+     * @param env the execution environment used to read the model data and create the table.
+     * @param path the path a stage was previously saved to.
+     * @return the restored estimator.
+     * @throws IOException declared by the signature; presumably raised when reading fails.
+     */
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String 
path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+        // NOTE(review): Files/Paths use the default file system; confirm this works for remote paths.
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new 
KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = 
tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return kMeans;
+    }
+    /** Returns the map holding this stage's parameter values (the internal map, not a copy). */
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+    /**
+     * Iteration body that turns the model data variable stream and the points stream into
+     * updated model data.
+     *
+     * <p>Points are grouped into count windows of {@code batchSize}, passed through
+     * {@code GlobalBatchCreator} and {@code GlobalBatchSplitter} (presumably one batch per window,
+     * split into {@code parallelism} local batches; neither class is shown here) and rebalanced.
+     * {@code ModelDataLocalUpdater}, running at the parallelism of the points stream, combines
+     * the local batches with the broadcast model data. Its results are grouped into count
+     * windows of {@code parallelism} records and merged by {@code ModelDataGlobalReducer}.
+     */
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize; // number of points per count window on the points stream
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int 
batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, 
decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+            // Presumably (feedback, output): newModelData is fed back, modelData is the output.
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), 
DataStreamList.of(modelData));
+        }
+    }
+    /**
+     * Merges two {@link KMeansModelData} instances into one.
+     *
+     * <p>For each centroid index, the merged centroid is the average of the two input centroids
+     * weighted by their respective weights, and the merged weight is the sum of the two weights.
+     * The centroid and weight arrays of the first argument are updated in place and wrapped in
+     * the returned instance.
+     */
+    private static class ModelDataGlobalReducer implements 
ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, 
KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+            // Math.max(..., 1e-16) keeps the divisor away from zero when both weights are zero.
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * 
newWeights.values[i])
+                                    / Math.max(weights.values[i] + 
newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i]; // summed only after the average above
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+    /**
+     * Two-input operator that updates model data with one local batch of points at a time.
+     *
+     * <p>The first input delivers local batches ({@code DenseVector[]}) and the second input
+     * delivers model data; both are buffered in operator list state. Whenever model data and at
+     * least one local batch are buffered, the operator consumes the model data and the first
+     * batch, assigns each point to its closest centroid using {@code distanceMeasure}, scales the
+     * existing weights by {@code decayFactor / parallelism}, and, for every centroid that received
+     * points, adds the point count to its weight and moves it to the weighted average of its old
+     * position and the mean of the assigned points. The updated centroids and weights are emitted
+     * as a new {@link KMeansModelData}.
+     */
+    private static class ModelDataLocalUpdater extends 
AbstractStreamOperator<KMeansModelData>
+            implements TwoInputStreamOperator<DenseVector[], KMeansModelData, 
KMeansModelData> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private ListState<DenseVector[]> localBatchState; // buffered local batches
+        private ListState<KMeansModelData> modelDataState; // buffered model data
+
+        private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, double 
decayFactor) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws 
Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    
ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new 
ListStateDescriptor<>("localBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", 
KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) 
throws Exception {
+            localBatchState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<KMeansModelData> 
modelDataRecord)
+                throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+        /** Performs one update if both model data and at least one local batch are buffered. */
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchState.get().iterator().hasNext()) {
+                return;
+            }
+            // Takes the buffered model data and clears it; the next update needs new model data.
+            KMeansModelData modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, 
"modelData")
+                            .orElseThrow((Supplier<Exception>) 
NullPointerException::new);
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector weights = modelData.weights;
+            modelDataState.clear();
+            // Removes the first buffered local batch and writes the remaining ones back to state.
+            List<DenseVector[]> pointsList = 
IteratorUtils.toList(localBatchState.get().iterator());
+            DenseVector[] points = pointsList.remove(0);
+            localBatchState.update(pointsList);
+
+            int dim = centroids[0].size();
+            int k = centroids.length;
+            int parallelism = 
getRuntimeContext().getNumberOfParallelSubtasks();
+
+            // Accumulates, per centroid, the sum and the count of the points closest to it.
+            DenseVector[] sums = new DenseVector[k];
+            int[] counts = new int[k];
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, 
distanceMeasure);
+                counts[closestCentroidId]++;
+                for (int j = 0; j < dim; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];
+                }
+            }
+
+            // Scales old weights by decayFactor / parallelism, then blends in this batch per centroid.
+            BLAS.scal(decayFactor / parallelism, weights);
+            for (int i = 0; i < k; i++) {
+                if (counts[i] == 0) {
+                    continue; // no points assigned: centroid unchanged, weight stays scaled
+                }
+
+                DenseVector centroid = centroids[i];
+                weights.values[i] = weights.values[i] + counts[i];
+                double lambda = counts[i] / weights.values[i];
+                // c = (1 - lambda) * c + lambda * (sums[i] / counts[i]), assuming BLAS scal/axpy semantics.
+                BLAS.scal(1.0 - lambda, centroid);
+                BLAS.axpy(lambda / counts[i], sums[i], centroid);
+            }
+
+            output.collect(new StreamRecord<>(new KMeansModelData(centroids, 
weights)));
+        }
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, 
DenseVector> {
+        private final String featuresCol;
+
+        private FeaturesExtractor(String featuresCol) {
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public DenseVector map(Row row) throws Exception {

Review comment:
       Would it be simpler to remove `throws Exception` here? Same for other 
methods.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to