lindong28 commented on a change in pull request #70: URL: https://github.com/apache/flink-ml/pull/70#discussion_r834235458
########## File path: flink-ml-core/src/main/java/org/apache/flink/ml/api/Stage.java ########## @@ -40,6 +40,6 @@ */ @PublicEvolving public interface Stage<T extends Stage<T>> extends WithParams<T>, Serializable { - /** Saves this stage to the given path. */ + /** Saves the metadata and bounded model data of this stage to the given path. */ Review comment: Could we replace `bounded model data` with `bounded data` so that the description is a bit more general? ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java ########## @@ -0,0 +1,464 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.clustering; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.RestOptions; +import org.apache.flink.metrics.Gauge; +import org.apache.flink.ml.clustering.kmeans.KMeans; +import org.apache.flink.ml.clustering.kmeans.KMeansModel; +import org.apache.flink.ml.clustering.kmeans.KMeansModelData; +import org.apache.flink.ml.clustering.kmeans.OnlineKMeans; +import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel; +import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.util.InMemorySinkFunction; +import org.apache.flink.ml.util.InMemorySourceFunction; +import org.apache.flink.runtime.minicluster.MiniCluster; +import org.apache.flink.runtime.minicluster.MiniClusterConfiguration; +import org.apache.flink.runtime.testutils.InMemoryReporter; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.CollectionUtils; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + 
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction; + +/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */ +public class OnlineKMeansTest extends AbstractTestBase { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + + private static final DenseVector[] trainData1 = + new DenseVector[] { + Vectors.dense(10.0, 0.0), + Vectors.dense(10.0, 0.3), + Vectors.dense(10.3, 0.0), + Vectors.dense(-10.0, 0.0), + Vectors.dense(-10.0, 0.6), + Vectors.dense(-10.6, 0.0) + }; + private static final DenseVector[] trainData2 = + new DenseVector[] { + Vectors.dense(10.0, 100.0), + Vectors.dense(10.0, 100.3), + Vectors.dense(10.3, 100.0), + Vectors.dense(-10.0, -100.0), + Vectors.dense(-10.0, -100.6), + Vectors.dense(-10.6, -100.0) + }; + private static final DenseVector[] predictData = + new DenseVector[] { + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3), + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3) + }; + private static final List<Set<DenseVector>> expectedGroups1 = + Arrays.asList( + new HashSet<>( + Arrays.asList( + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3))), + new HashSet<>( + Arrays.asList( + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3)))); + private static final List<Set<DenseVector>> expectedGroups2 = + Collections.singletonList( + new HashSet<>( + Arrays.asList( + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3), + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3)))); + + private static final int defaultParallelism = 4; + private static final int numTaskManagers = 2; + private static final int numSlotsPerTaskManager = 2; + + private int currentModelDataVersion; + + private InMemorySourceFunction<DenseVector> trainSource; + private InMemorySourceFunction<DenseVector> 
predictSource; + private InMemorySinkFunction<Row> outputSink; + private InMemorySinkFunction<KMeansModelData> modelDataSink; + + private InMemoryReporter reporter; + private MiniCluster miniCluster; + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + + private Table offlineTrainTable; + private Table trainTable; + private Table predictTable; + + @Before + public void before() throws Exception { + currentModelDataVersion = 0; + + trainSource = new InMemorySourceFunction<>(); + predictSource = new InMemorySourceFunction<>(); + outputSink = new InMemorySinkFunction<>(); + modelDataSink = new InMemorySinkFunction<>(); + + Configuration config = new Configuration(); + config.set(RestOptions.BIND_PORT, "18081-19091"); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + reporter = InMemoryReporter.createWithRetainedMetrics(); + reporter.addToConfiguration(config); + + miniCluster = + new MiniCluster( + new MiniClusterConfiguration.Builder() + .setConfiguration(config) + .setNumTaskManagers(numTaskManagers) + .setNumSlotsPerTaskManager(numSlotsPerTaskManager) + .build()); + miniCluster.start(); Review comment: According to `AbstractTestBase`'s Javadoc, it could save a significant amount of time to reuse the same mini-cluster for tests. Could we reuse the same mini-cluster for tests using an approach similar to `AbstractTestBase`'s? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java ########## @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.clustering.kmeans; + +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.metrics.Gauge; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.common.distance.DistanceMeasure; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.co.CoProcessFunction; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * OnlineKMeansModel can be regarded as an advanced {@link KMeansModel} operator which can update + * model data in a streaming format, using the model data provided by {@link OnlineKMeans}. 
+ */ +public class OnlineKMeansModel + implements Model<OnlineKMeansModel>, KMeansModelParams<OnlineKMeansModel> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private Table modelDataTable; + + public OnlineKMeansModel() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public OnlineKMeansModel setModelData(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + modelDataTable = inputs[0]; + return this; + } + + @Override + public Table[] getModelData() { + return new Table[] {modelDataTable}; + } + + @Override + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.INT), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol())); + + DataStream<Row> predictionResult = + KMeansModelData.getModelDataStream(modelDataTable) + .broadcast() + .connect(tEnv.toDataStream(inputs[0])) + .process( + new PredictLabelFunction( + getFeaturesCol(), + DistanceMeasure.getInstance(getDistanceMeasure())), + outputTypeInfo); + + return new Table[] {tEnv.fromDataStream(predictionResult)}; + } + + /** A utility function used for prediction. */ + private static class PredictLabelFunction extends CoProcessFunction<KMeansModelData, Row, Row> { + private final String featuresCol; + + private final DistanceMeasure distanceMeasure; + + private DenseVector[] centroids; + + // TODO: replace this with a complete solution of reading first model data from unbounded + // model data stream before processing the first predict data. 
+ private final List<Row> bufferedPoints = new ArrayList<>(); + + // TODO: replace this simple implementation of model data version with the formal API to + // track model version after its design is settled. + private int modelDataVersion; Review comment: After the operator is opened and before the operator processes any input value, the value of this variable would already be exposed. It will be better to explicitly define it. How about we set the default value of this variable to be 0, which could mean `the model version is unknown`? And the valid model versions should start from 1. ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java ########## @@ -0,0 +1,464 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.clustering; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.RestOptions; +import org.apache.flink.metrics.Gauge; +import org.apache.flink.ml.clustering.kmeans.KMeans; +import org.apache.flink.ml.clustering.kmeans.KMeansModel; +import org.apache.flink.ml.clustering.kmeans.KMeansModelData; +import org.apache.flink.ml.clustering.kmeans.OnlineKMeans; +import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel; +import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.util.InMemorySinkFunction; +import org.apache.flink.ml.util.InMemorySourceFunction; +import org.apache.flink.runtime.minicluster.MiniCluster; +import org.apache.flink.runtime.minicluster.MiniClusterConfiguration; +import org.apache.flink.runtime.testutils.InMemoryReporter; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.CollectionUtils; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + 
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction; + +/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */ +public class OnlineKMeansTest extends AbstractTestBase { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + + private static final DenseVector[] trainData1 = + new DenseVector[] { + Vectors.dense(10.0, 0.0), + Vectors.dense(10.0, 0.3), + Vectors.dense(10.3, 0.0), + Vectors.dense(-10.0, 0.0), + Vectors.dense(-10.0, 0.6), + Vectors.dense(-10.6, 0.0) + }; + private static final DenseVector[] trainData2 = + new DenseVector[] { + Vectors.dense(10.0, 100.0), + Vectors.dense(10.0, 100.3), + Vectors.dense(10.3, 100.0), + Vectors.dense(-10.0, -100.0), + Vectors.dense(-10.0, -100.6), + Vectors.dense(-10.6, -100.0) + }; + private static final DenseVector[] predictData = + new DenseVector[] { + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3), + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3) + }; + private static final List<Set<DenseVector>> expectedGroups1 = + Arrays.asList( + new HashSet<>( + Arrays.asList( + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3))), + new HashSet<>( + Arrays.asList( + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3)))); + private static final List<Set<DenseVector>> expectedGroups2 = + Collections.singletonList( + new HashSet<>( + Arrays.asList( + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3), + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3)))); + + private static final int defaultParallelism = 4; + private static final int numTaskManagers = 2; + private static final int numSlotsPerTaskManager = 2; + + private int currentModelDataVersion; + + private InMemorySourceFunction<DenseVector> trainSource; + private InMemorySourceFunction<DenseVector> 
predictSource; + private InMemorySinkFunction<Row> outputSink; + private InMemorySinkFunction<KMeansModelData> modelDataSink; + + private InMemoryReporter reporter; + private MiniCluster miniCluster; + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + + private Table offlineTrainTable; + private Table trainTable; Review comment: Given that we already have `offlineTrainTable`, how about renaming this table as `onlineTrainTable`? ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/InMemorySourceFunction.java ########## @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.util; + +import org.apache.flink.configuration.Configuration; +import org.apache.flink.streaming.api.functions.source.RichSourceFunction; +import org.apache.flink.streaming.api.functions.source.SourceFunction; + +import java.util.Arrays; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +/** A {@link SourceFunction} implementation that can directly receive records from tests. 
*/ +@SuppressWarnings({"unchecked", "rawtypes"}) +public class InMemorySourceFunction<T> extends RichSourceFunction<T> { + private static final Map<UUID, BlockingQueue> queueMap = new ConcurrentHashMap<>(); + private final UUID id; + private BlockingQueue<T> queue; + private boolean isRunning = true; + + public InMemorySourceFunction() { + id = UUID.randomUUID(); + queue = new LinkedBlockingQueue(); + queueMap.put(id, queue); + } + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + queue = queueMap.get(id); + } + + @Override + public void close() throws Exception { + super.close(); + queueMap.remove(id); + } + + @Override + public void run(SourceContext<T> context) throws InterruptedException { + while (isRunning) { + context.collect(queue.poll(1, TimeUnit.MINUTES)); Review comment: The current approach means that, after a graceful shutdown is requested, the source operator might need to wait up to 1 minute before it can actually be shut down. How about we use `BlockingQueue<Optional<T>> queue`, let `cancel()` insert `Optional.empty()` into the queue, and have `run()` ignore empty values? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java ########## @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.clustering.kmeans; + +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.metrics.Gauge; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.common.distance.DistanceMeasure; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.co.CoProcessFunction; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * OnlineKMeansModel can be regarded as an advanced {@link KMeansModel} operator which can update + * model data in a streaming format, using the model data provided by {@link OnlineKMeans}. 
+ */ +public class OnlineKMeansModel + implements Model<OnlineKMeansModel>, KMeansModelParams<OnlineKMeansModel> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private Table modelDataTable; + + public OnlineKMeansModel() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public OnlineKMeansModel setModelData(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + modelDataTable = inputs[0]; + return this; + } + + @Override + public Table[] getModelData() { + return new Table[] {modelDataTable}; + } + + @Override + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.INT), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol())); + + DataStream<Row> predictionResult = + KMeansModelData.getModelDataStream(modelDataTable) + .broadcast() + .connect(tEnv.toDataStream(inputs[0])) + .process( + new PredictLabelFunction( + getFeaturesCol(), + DistanceMeasure.getInstance(getDistanceMeasure())), + outputTypeInfo); + + return new Table[] {tEnv.fromDataStream(predictionResult)}; + } + + /** A utility function used for prediction. */ + private static class PredictLabelFunction extends CoProcessFunction<KMeansModelData, Row, Row> { + private final String featuresCol; + + private final DistanceMeasure distanceMeasure; + + private DenseVector[] centroids; + + // TODO: replace this with a complete solution of reading first model data from unbounded + // model data stream before processing the first predict data. 
+ private final List<Row> bufferedPoints = new ArrayList<>(); + + // TODO: replace this simple implementation of model data version with the formal API to + // track model version after its design is settled. + private int modelDataVersion; + + public PredictLabelFunction(String featuresCol, DistanceMeasure distanceMeasure) { + this.featuresCol = featuresCol; + this.distanceMeasure = distanceMeasure; + } + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + + getRuntimeContext() + .getMetricGroup() + .gauge( + "modelDataVersion", Review comment: Could we put the metric name in a static final variable? Then `OnlineKMeansTest` could use the variable instead of copying the string. ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java ########## @@ -0,0 +1,464 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.clustering; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.RestOptions; +import org.apache.flink.metrics.Gauge; +import org.apache.flink.ml.clustering.kmeans.KMeans; +import org.apache.flink.ml.clustering.kmeans.KMeansModel; +import org.apache.flink.ml.clustering.kmeans.KMeansModelData; +import org.apache.flink.ml.clustering.kmeans.OnlineKMeans; +import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel; +import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.util.InMemorySinkFunction; +import org.apache.flink.ml.util.InMemorySourceFunction; +import org.apache.flink.runtime.minicluster.MiniCluster; +import org.apache.flink.runtime.minicluster.MiniClusterConfiguration; +import org.apache.flink.runtime.testutils.InMemoryReporter; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.CollectionUtils; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + 
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction; + +/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */ +public class OnlineKMeansTest extends AbstractTestBase { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + + private static final DenseVector[] trainData1 = + new DenseVector[] { + Vectors.dense(10.0, 0.0), + Vectors.dense(10.0, 0.3), + Vectors.dense(10.3, 0.0), + Vectors.dense(-10.0, 0.0), + Vectors.dense(-10.0, 0.6), + Vectors.dense(-10.6, 0.0) + }; + private static final DenseVector[] trainData2 = + new DenseVector[] { + Vectors.dense(10.0, 100.0), + Vectors.dense(10.0, 100.3), + Vectors.dense(10.3, 100.0), + Vectors.dense(-10.0, -100.0), + Vectors.dense(-10.0, -100.6), + Vectors.dense(-10.6, -100.0) + }; + private static final DenseVector[] predictData = + new DenseVector[] { + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3), + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3) + }; + private static final List<Set<DenseVector>> expectedGroups1 = + Arrays.asList( + new HashSet<>( + Arrays.asList( + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3))), + new HashSet<>( + Arrays.asList( + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3)))); + private static final List<Set<DenseVector>> expectedGroups2 = + Collections.singletonList( + new HashSet<>( + Arrays.asList( + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3), + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3)))); + + private static final int defaultParallelism = 4; + private static final int numTaskManagers = 2; + private static final int numSlotsPerTaskManager = 2; + + private int currentModelDataVersion; + + private InMemorySourceFunction<DenseVector> trainSource; + private InMemorySourceFunction<DenseVector> 
predictSource; + private InMemorySinkFunction<Row> outputSink; + private InMemorySinkFunction<KMeansModelData> modelDataSink; + + private InMemoryReporter reporter; + private MiniCluster miniCluster; + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + + private Table offlineTrainTable; + private Table trainTable; + private Table predictTable; + + @Before + public void before() throws Exception { + currentModelDataVersion = 0; + + trainSource = new InMemorySourceFunction<>(); + predictSource = new InMemorySourceFunction<>(); + outputSink = new InMemorySinkFunction<>(); + modelDataSink = new InMemorySinkFunction<>(); + + Configuration config = new Configuration(); + config.set(RestOptions.BIND_PORT, "18081-19091"); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + reporter = InMemoryReporter.createWithRetainedMetrics(); + reporter.addToConfiguration(config); + + miniCluster = + new MiniCluster( + new MiniClusterConfiguration.Builder() + .setConfiguration(config) + .setNumTaskManagers(numTaskManagers) + .setNumSlotsPerTaskManager(numSlotsPerTaskManager) + .build()); + miniCluster.start(); + + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(defaultParallelism); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + + Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build(); + + offlineTrainTable = + tEnv.fromDataStream(env.fromElements(trainData1), schema).as("features"); + trainTable = + tEnv.fromDataStream( + env.addSource(trainSource, DenseVectorTypeInfo.INSTANCE), schema) + .as("features"); + predictTable = + tEnv.fromDataStream( + env.addSource(predictSource, DenseVectorTypeInfo.INSTANCE), schema) + .as("features"); + } + + @After + public void after() throws Exception { + miniCluster.close(); + } + + /** + * Performs transform() on the 
provided model with predictTable, and adds sinks for + * OnlineKMeansModel's transform output and model data. + */ + private void transformAndOutputData(OnlineKMeansModel onlineModel) { + Table outputTable = onlineModel.transform(predictTable)[0]; + tEnv.toDataStream(outputTable).addSink(outputSink); + + Table modelDataTable = onlineModel.getModelData()[0]; + KMeansModelData.getModelDataStream(modelDataTable).addSink(modelDataSink); + } + + /** Blocks the thread until Model has set up init model data. */ + private void waitInitModelDataSetup() throws InterruptedException { + while (reporter.findMetrics("modelDataVersion").size() < defaultParallelism) { + Thread.sleep(100); + } + waitModelDataUpdate(); + } + + /** Blocks the thread until the Model has received the next model-data-update event. */ + @SuppressWarnings("unchecked") + private void waitModelDataUpdate() throws InterruptedException { + do { + int tmpModelDataVersion = + reporter.findMetrics("modelDataVersion").values().stream() + .map(x -> Integer.parseInt(((Gauge<String>) x).getValue())) + .min(Integer::compareTo) + .orElse(currentModelDataVersion); Review comment: Given that `waitModelDataUpdate()` is only called after or within `waitInitModelDataSetup()`, the `orElse(...)` will never be needed, right? Would it be simpler to replace `orElse(...)` with `get()`? ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/InMemorySourceFunction.java ########## @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.util; + +import org.apache.flink.configuration.Configuration; +import org.apache.flink.streaming.api.functions.source.RichSourceFunction; +import org.apache.flink.streaming.api.functions.source.SourceFunction; + +import java.util.Arrays; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +/** A {@link SourceFunction} implementation that can directly receive records from tests. */ +@SuppressWarnings({"unchecked", "rawtypes"}) +public class InMemorySourceFunction<T> extends RichSourceFunction<T> { + private static final Map<UUID, BlockingQueue> queueMap = new ConcurrentHashMap<>(); + private final UUID id; + private BlockingQueue<T> queue; + private boolean isRunning = true; Review comment: Should it be `volatile`? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java ########## @@ -0,0 +1,437 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.clustering.kmeans; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.IterationBody; +import org.apache.flink.iteration.IterationBodyResult; +import org.apache.flink.iteration.Iterations; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.distance.DistanceMeasure; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; 
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.collections.IteratorUtils; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.function.Supplier; + +/** + * OnlineKMeans extends the function of {@link KMeans}, supporting to train a K-Means model + * continuously according to an unbounded stream of train data. + * + * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate + * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired, + * OnlineKMeans computes the new centroids from the weighted average between the original and the + * estimated centroids. The weight of the estimated centroids is the number of points assigned to + * them. The weight of the original centroids is also the number of points, but additionally + * multiplying with the decay factor. + * + * <p>The decay factor scales the contribution of the clusters as estimated thus far. If decay + * factor is 1, all batches are weighted equally. If decay factor is 0, new centroids are determined + * entirely by recent data. Lower values correspond to more forgetting. 
+ */ +public class OnlineKMeans + implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private Table initModelDataTable; + + public OnlineKMeans() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public OnlineKMeansModel fit(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv(); Review comment: Could this line be removed? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java ########## @@ -0,0 +1,437 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.clustering.kmeans; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.IterationBody; +import org.apache.flink.iteration.IterationBodyResult; +import org.apache.flink.iteration.Iterations; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.distance.DistanceMeasure; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.api.windowing.windows.GlobalWindow; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl; +import org.apache.flink.table.api.internal.TableImpl; +import 
org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.collections.IteratorUtils; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.function.Supplier; + +/** + * OnlineKMeans extends the function of {@link KMeans}, supporting to train a K-Means model + * continuously according to an unbounded stream of train data. + * + * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate + * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired, + * OnlineKMeans computes the new centroids from the weighted average between the original and the + * estimated centroids. The weight of the estimated centroids is the number of points assigned to + * them. The weight of the original centroids is also the number of points, but additionally + * multiplying with the decay factor. + * + * <p>The decay factor scales the contribution of the clusters as estimated thus far. If decay + * factor is 1, all batches are weighted equally. If decay factor is 0, new centroids are determined + * entirely by recent data. Lower values correspond to more forgetting. + */ +public class OnlineKMeans + implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private Table initModelDataTable; + + public OnlineKMeans() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public OnlineKMeansModel fit(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv(); + + DataStream<DenseVector> points = + tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol())); + + DataStream<KMeansModelData> initModelData = + KMeansModelData.getModelDataStream(initModelDataTable); + initModelData.getTransformation().setParallelism(1); + + IterationBody body = + new OnlineKMeansIterationBody( + DistanceMeasure.getInstance(getDistanceMeasure()), + getDecayFactor(), + getGlobalBatchSize()); + + DataStream<KMeansModelData> finalModelData = + Iterations.iterateUnboundedStreams( + DataStreamList.of(initModelData), DataStreamList.of(points), body) + .get(0); + + Table finalModelDataTable = tEnv.fromDataStream(finalModelData); + OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable); + ReadWriteUtils.updateExistingParams(model, paramMap); + return model; + } + + /** Saves the metadata AND bounded model data table (if exists) to the given path. 
*/ + @Override + public void save(String path) throws IOException { + if (initModelDataTable != null) { + ReadWriteUtils.saveModelData( + KMeansModelData.getModelDataStream(initModelDataTable), + path, + new KMeansModelData.ModelDataEncoder()); + } + + ReadWriteUtils.saveMetadata(this, path); + } + + public static OnlineKMeans load(StreamExecutionEnvironment env, String path) + throws IOException { + OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path); + + String initModelDataPath = ReadWriteUtils.getDataPath(path); + if (Files.exists(Paths.get(initModelDataPath))) { + StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + + DataStream<KMeansModelData> initModelDataStream = + ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder()); + + kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream); + } + + return kMeans; + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + private static class OnlineKMeansIterationBody implements IterationBody { + private final DistanceMeasure distanceMeasure; + private final double decayFactor; + private final int batchSize; + + public OnlineKMeansIterationBody( + DistanceMeasure distanceMeasure, double decayFactor, int batchSize) { + this.distanceMeasure = distanceMeasure; + this.decayFactor = decayFactor; + this.batchSize = batchSize; + } + + @Override + public IterationBodyResult process( + DataStreamList variableStreams, DataStreamList dataStreams) { + DataStream<KMeansModelData> modelData = variableStreams.get(0); + DataStream<DenseVector> points = dataStreams.get(0); + + int parallelism = points.getParallelism(); + + DataStream<KMeansModelData> newModelData = + points.countWindowAll(batchSize) + .apply(new GlobalBatchCreator()) + .flatMap(new GlobalBatchSplitter(parallelism)) + .rebalance() + .connect(modelData.broadcast()) + .transform( + "ModelDataLocalUpdater", + TypeInformation.of(KMeansModelData.class), + new 
ModelDataLocalUpdater(distanceMeasure, decayFactor)) + .setParallelism(parallelism) + .countWindowAll(parallelism) + .reduce(new ModelDataGlobalReducer()); + + return new IterationBodyResult( + DataStreamList.of(newModelData), DataStreamList.of(modelData)); + } + } + + private static class ModelDataGlobalReducer implements ReduceFunction<KMeansModelData> { + @Override + public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newModelData) { + DenseVector weights = modelData.weights; + DenseVector[] centroids = modelData.centroids; + DenseVector newWeights = newModelData.weights; + DenseVector[] newCentroids = newModelData.centroids; + + int k = newCentroids.length; + int dim = newCentroids[0].size(); + + for (int i = 0; i < k; i++) { + for (int j = 0; j < dim; j++) { + centroids[i].values[j] = + (centroids[i].values[j] * weights.values[i] + + newCentroids[i].values[j] * newWeights.values[i]) + / Math.max(weights.values[i] + newWeights.values[i], 1e-16); + } + weights.values[i] += newWeights.values[i]; + } + + return new KMeansModelData(centroids, weights); + } + } + + private static class ModelDataLocalUpdater extends AbstractStreamOperator<KMeansModelData> + implements TwoInputStreamOperator<DenseVector[], KMeansModelData, KMeansModelData> { + private final DistanceMeasure distanceMeasure; + private final double decayFactor; + private ListState<DenseVector[]> localBatchState; + private ListState<KMeansModelData> modelDataState; + + private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, double decayFactor) { + this.distanceMeasure = distanceMeasure; + this.decayFactor = decayFactor; + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + + TypeInformation<DenseVector[]> type = + ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE); + localBatchState = + context.getOperatorStateStore() + .getListState(new ListStateDescriptor<>("localBatch", type)); + + 
modelDataState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>("modelData", KMeansModelData.class)); + } + + @Override + public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception { + localBatchState.add(pointsRecord.getValue()); + alignAndComputeModelData(); + } + + @Override + public void processElement2(StreamRecord<KMeansModelData> modelDataRecord) + throws Exception { + modelDataState.add(modelDataRecord.getValue()); + alignAndComputeModelData(); + } + + private void alignAndComputeModelData() throws Exception { + if (!modelDataState.get().iterator().hasNext() + || !localBatchState.get().iterator().hasNext()) { + return; + } + + KMeansModelData modelData = + OperatorStateUtils.getUniqueElement(modelDataState, "modelData") + .orElseThrow((Supplier<Exception>) NullPointerException::new); + DenseVector[] centroids = modelData.centroids; + DenseVector weights = modelData.weights; + modelDataState.clear(); + + List<DenseVector[]> pointsList = IteratorUtils.toList(localBatchState.get().iterator()); + DenseVector[] points = pointsList.remove(0); + localBatchState.update(pointsList); + + int dim = centroids[0].size(); + int k = centroids.length; + int parallelism = getRuntimeContext().getNumberOfParallelSubtasks(); + + // Computes new centroids. + DenseVector[] sums = new DenseVector[k]; + int[] counts = new int[k]; + + for (int i = 0; i < k; i++) { + sums[i] = new DenseVector(dim); + counts[i] = 0; + } + for (DenseVector point : points) { + int closestCentroidId = + KMeans.findClosestCentroidId(centroids, point, distanceMeasure); + counts[closestCentroidId]++; + for (int j = 0; j < dim; j++) { + sums[closestCentroidId].values[j] += point.values[j]; + } + } + + // Considers weight and decay factor when updating centroids. 
+ BLAS.scal(decayFactor / parallelism, weights); + for (int i = 0; i < k; i++) { + if (counts[i] == 0) { + continue; + } + + DenseVector centroid = centroids[i]; + weights.values[i] = weights.values[i] + counts[i]; + double lambda = counts[i] / weights.values[i]; + + BLAS.scal(1.0 - lambda, centroid); + BLAS.axpy(lambda / counts[i], sums[i], centroid); + } + + output.collect(new StreamRecord<>(new KMeansModelData(centroids, weights))); + } + } + + private static class FeaturesExtractor implements MapFunction<Row, DenseVector> { + private final String featuresCol; + + private FeaturesExtractor(String featuresCol) { + this.featuresCol = featuresCol; + } + + @Override + public DenseVector map(Row row) throws Exception { + return (DenseVector) row.getField(featuresCol); + } + } + + // An operator that splits a global batch into evenly-sized local batches, and distributes them + // to downstream operator. + private static class GlobalBatchSplitter + implements FlatMapFunction<DenseVector[], DenseVector[]> { + private final int downStreamParallelism; + + private GlobalBatchSplitter(int downStreamParallelism) { + this.downStreamParallelism = downStreamParallelism; + } + + @Override + public void flatMap(DenseVector[] values, Collector<DenseVector[]> collector) { + // Calculate the batch sizes to be distributed on each subtask. 
+ List<Integer> sizes = new ArrayList<>(); + for (int i = 0; i < downStreamParallelism; i++) { + int start = i * values.length / downStreamParallelism; + int end = (i + 1) * values.length / downStreamParallelism; + sizes.add(end - start); + } + + int offset = 0; + for (Integer size : sizes) { + collector.collect(Arrays.copyOfRange(values, offset, offset + size)); + offset += size; + } + } + } + + private static class GlobalBatchCreator + implements AllWindowFunction<DenseVector, DenseVector[], GlobalWindow> { + @Override + public void apply( + GlobalWindow timeWindow, + Iterable<DenseVector> iterable, + Collector<DenseVector[]> collector) { + List<DenseVector> points = IteratorUtils.toList(iterable.iterator()); + collector.collect(points.toArray(new DenseVector[0])); + } + } + + /** + * Sets the initial model data of the online training process with the provided model data + * table. + * + * <p>This would override the effect of previously invoked {@link #setInitialModelData(Table)} + * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, double)}. + */ + public OnlineKMeans setInitialModelData(Table initModelDataTable) { + this.initModelDataTable = initModelDataTable; + return this; + } + + /** + * Sets the initial model data of the online training process with randomly created centroids. + * + * <p>This would override the effect of previously invoked {@link #setInitialModelData(Table)} + * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, double)}. + * + * @param tEnv The stream table environment to create the centroids in. + * @param dim The dimension of the centroids to create. + * @param k The number of centroids to create. + * @param weight The weight of the centroids to create. 
+ */ + public OnlineKMeans setRandomCentroids( + StreamTableEnvironment tEnv, int dim, int k, double weight) { + StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv(); Review comment: While I believe we will have a public API on Table to expose its env, I am not sure we will also have a public API on `StreamTableEnvironmentImpl` to expose its StreamExecutionEnvironment. It will be better to reduce the use of internal APIs if possible. How about we just pass `StreamExecutionEnvironment` as the method's parameter? This would also be more consistent with `load(...)`. ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java ########## @@ -47,8 +47,16 @@ public DenseVector[] centroids; + public DenseVector weights; + + public KMeansModelData(DenseVector[] centroids, DenseVector weights) { + this.centroids = centroids; + this.weights = weights; + } + public KMeansModelData(DenseVector[] centroids) { this.centroids = centroids; + this.weights = new DenseVector(centroids.length); Review comment: It seems a bit weird to have the weights be a vector of 0s. Would it be simpler to remove this constructor and explicitly specify weights in tests? The weights can be specified easily, e.g. with `Vectors.dense(1, 1)`. ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java ########## @@ -0,0 +1,437 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.clustering.kmeans; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.IterationBody; +import org.apache.flink.iteration.IterationBodyResult; +import org.apache.flink.iteration.Iterations; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.distance.DistanceMeasure; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; 
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.collections.IteratorUtils; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.function.Supplier; + +/** + * OnlineKMeans extends the function of {@link KMeans}, supporting to train a K-Means model + * continuously according to an unbounded stream of train data. + * + * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate + * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired, + * OnlineKMeans computes the new centroids from the weighted average between the original and the + * estimated centroids. The weight of the estimated centroids is the number of points assigned to + * them. The weight of the original centroids is also the number of points, but additionally + * multiplying with the decay factor. + * + * <p>The decay factor scales the contribution of the clusters as estimated thus far. If decay + * factor is 1, all batches are weighted equally. If decay factor is 0, new centroids are determined + * entirely by recent data. Lower values correspond to more forgetting. 
+ */ +public class OnlineKMeans + implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private Table initModelDataTable; + + public OnlineKMeans() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public OnlineKMeansModel fit(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv(); + + DataStream<DenseVector> points = + tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol())); + + DataStream<KMeansModelData> initModelData = + KMeansModelData.getModelDataStream(initModelDataTable); + initModelData.getTransformation().setParallelism(1); + + IterationBody body = + new OnlineKMeansIterationBody( + DistanceMeasure.getInstance(getDistanceMeasure()), + getDecayFactor(), + getGlobalBatchSize()); + + DataStream<KMeansModelData> finalModelData = + Iterations.iterateUnboundedStreams( + DataStreamList.of(initModelData), DataStreamList.of(points), body) + .get(0); + + Table finalModelDataTable = tEnv.fromDataStream(finalModelData); + OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable); + ReadWriteUtils.updateExistingParams(model, paramMap); + return model; + } + + /** Saves the metadata AND bounded model data table (if exists) to the given path. 
*/ + @Override + public void save(String path) throws IOException { + if (initModelDataTable != null) { + ReadWriteUtils.saveModelData( + KMeansModelData.getModelDataStream(initModelDataTable), + path, + new KMeansModelData.ModelDataEncoder()); + } + + ReadWriteUtils.saveMetadata(this, path); + } + + public static OnlineKMeans load(StreamExecutionEnvironment env, String path) + throws IOException { + OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path); + + String initModelDataPath = ReadWriteUtils.getDataPath(path); + if (Files.exists(Paths.get(initModelDataPath))) { + StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + + DataStream<KMeansModelData> initModelDataStream = + ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder()); + + kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream); + } + + return kMeans; + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + private static class OnlineKMeansIterationBody implements IterationBody { + private final DistanceMeasure distanceMeasure; + private final double decayFactor; + private final int batchSize; + + public OnlineKMeansIterationBody( + DistanceMeasure distanceMeasure, double decayFactor, int batchSize) { + this.distanceMeasure = distanceMeasure; + this.decayFactor = decayFactor; + this.batchSize = batchSize; + } + + @Override + public IterationBodyResult process( + DataStreamList variableStreams, DataStreamList dataStreams) { + DataStream<KMeansModelData> modelData = variableStreams.get(0); + DataStream<DenseVector> points = dataStreams.get(0); + + int parallelism = points.getParallelism(); + + DataStream<KMeansModelData> newModelData = + points.countWindowAll(batchSize) + .apply(new GlobalBatchCreator()) + .flatMap(new GlobalBatchSplitter(parallelism)) + .rebalance() + .connect(modelData.broadcast()) + .transform( + "ModelDataLocalUpdater", + TypeInformation.of(KMeansModelData.class), + new 
ModelDataLocalUpdater(distanceMeasure, decayFactor)) + .setParallelism(parallelism) + .countWindowAll(parallelism) + .reduce(new ModelDataGlobalReducer()); + + return new IterationBodyResult( + DataStreamList.of(newModelData), DataStreamList.of(modelData)); + } + } + + private static class ModelDataGlobalReducer implements ReduceFunction<KMeansModelData> { + @Override + public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newModelData) { + DenseVector weights = modelData.weights; + DenseVector[] centroids = modelData.centroids; + DenseVector newWeights = newModelData.weights; + DenseVector[] newCentroids = newModelData.centroids; + + int k = newCentroids.length; + int dim = newCentroids[0].size(); + + for (int i = 0; i < k; i++) { + for (int j = 0; j < dim; j++) { + centroids[i].values[j] = + (centroids[i].values[j] * weights.values[i] + + newCentroids[i].values[j] * newWeights.values[i]) + / Math.max(weights.values[i] + newWeights.values[i], 1e-16); + } + weights.values[i] += newWeights.values[i]; + } + + return new KMeansModelData(centroids, weights); + } + } + + private static class ModelDataLocalUpdater extends AbstractStreamOperator<KMeansModelData> + implements TwoInputStreamOperator<DenseVector[], KMeansModelData, KMeansModelData> { + private final DistanceMeasure distanceMeasure; + private final double decayFactor; + private ListState<DenseVector[]> localBatchState; + private ListState<KMeansModelData> modelDataState; + + private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, double decayFactor) { + this.distanceMeasure = distanceMeasure; + this.decayFactor = decayFactor; + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + + TypeInformation<DenseVector[]> type = + ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE); + localBatchState = + context.getOperatorStateStore() + .getListState(new ListStateDescriptor<>("localBatch", type)); + + 
modelDataState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>("modelData", KMeansModelData.class)); + } + + @Override + public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception { + localBatchState.add(pointsRecord.getValue()); + alignAndComputeModelData(); + } + + @Override + public void processElement2(StreamRecord<KMeansModelData> modelDataRecord) + throws Exception { + modelDataState.add(modelDataRecord.getValue()); + alignAndComputeModelData(); + } + + private void alignAndComputeModelData() throws Exception { + if (!modelDataState.get().iterator().hasNext() + || !localBatchState.get().iterator().hasNext()) { + return; + } + + KMeansModelData modelData = + OperatorStateUtils.getUniqueElement(modelDataState, "modelData") + .orElseThrow((Supplier<Exception>) NullPointerException::new); + DenseVector[] centroids = modelData.centroids; + DenseVector weights = modelData.weights; + modelDataState.clear(); + + List<DenseVector[]> pointsList = IteratorUtils.toList(localBatchState.get().iterator()); + DenseVector[] points = pointsList.remove(0); + localBatchState.update(pointsList); + + int dim = centroids[0].size(); + int k = centroids.length; + int parallelism = getRuntimeContext().getNumberOfParallelSubtasks(); + + // Computes new centroids. + DenseVector[] sums = new DenseVector[k]; + int[] counts = new int[k]; + + for (int i = 0; i < k; i++) { + sums[i] = new DenseVector(dim); + counts[i] = 0; + } + for (DenseVector point : points) { + int closestCentroidId = + KMeans.findClosestCentroidId(centroids, point, distanceMeasure); + counts[closestCentroidId]++; + for (int j = 0; j < dim; j++) { + sums[closestCentroidId].values[j] += point.values[j]; + } + } + + // Considers weight and decay factor when updating centroids. 
+ BLAS.scal(decayFactor / parallelism, weights); + for (int i = 0; i < k; i++) { + if (counts[i] == 0) { + continue; + } + + DenseVector centroid = centroids[i]; + weights.values[i] = weights.values[i] + counts[i]; + double lambda = counts[i] / weights.values[i]; + + BLAS.scal(1.0 - lambda, centroid); + BLAS.axpy(lambda / counts[i], sums[i], centroid); + } + + output.collect(new StreamRecord<>(new KMeansModelData(centroids, weights))); + } + } + + private static class FeaturesExtractor implements MapFunction<Row, DenseVector> { + private final String featuresCol; + + private FeaturesExtractor(String featuresCol) { + this.featuresCol = featuresCol; + } + + @Override + public DenseVector map(Row row) throws Exception { + return (DenseVector) row.getField(featuresCol); + } + } + + // An operator that splits a global batch into evenly-sized local batches, and distributes them + // to downstream operator. + private static class GlobalBatchSplitter + implements FlatMapFunction<DenseVector[], DenseVector[]> { + private final int downStreamParallelism; + + private GlobalBatchSplitter(int downStreamParallelism) { + this.downStreamParallelism = downStreamParallelism; + } + + @Override + public void flatMap(DenseVector[] values, Collector<DenseVector[]> collector) { + // Calculate the batch sizes to be distributed on each subtask. 
+ List<Integer> sizes = new ArrayList<>(); + for (int i = 0; i < downStreamParallelism; i++) { + int start = i * values.length / downStreamParallelism; + int end = (i + 1) * values.length / downStreamParallelism; + sizes.add(end - start); + } + + int offset = 0; + for (Integer size : sizes) { + collector.collect(Arrays.copyOfRange(values, offset, offset + size)); + offset += size; + } + } + } + + private static class GlobalBatchCreator + implements AllWindowFunction<DenseVector, DenseVector[], GlobalWindow> { + @Override + public void apply( + GlobalWindow timeWindow, + Iterable<DenseVector> iterable, + Collector<DenseVector[]> collector) { + List<DenseVector> points = IteratorUtils.toList(iterable.iterator()); + collector.collect(points.toArray(new DenseVector[0])); + } + } + + /** + * Sets the initial model data of the online training process with the provided model data + * table. + * + * <p>This would override the effect of previously invoked {@link #setInitialModelData(Table)} + * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, double)}. + */ + public OnlineKMeans setInitialModelData(Table initModelDataTable) { + this.initModelDataTable = initModelDataTable; + return this; + } + + /** + * Sets the initial model data of the online training process with randomly created centroids. + * + * <p>This would override the effect of previously invoked {@link #setInitialModelData(Table)} + * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, double)}. + * + * @param tEnv The stream table environment to create the centroids in. + * @param dim The dimension of the centroids to create. + * @param k The number of centroids to create. + * @param weight The weight of the centroids to create. + */ + public OnlineKMeans setRandomCentroids( + StreamTableEnvironment tEnv, int dim, int k, double weight) { Review comment: Would it be better to still have set `k` as the parameter of `OnlineKMeans`, so that it is as similar to `KMeans` as possible? 
The number of centroids is a key aspect of the kmeans algorithm (including its online version) and users would want to know what its value is. It would be harder for users to get this information if we don't specify it as OnlineKMeans's parameter. ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java ########## @@ -0,0 +1,437 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.flink.ml.clustering.kmeans; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.IterationBody; +import org.apache.flink.iteration.IterationBodyResult; +import org.apache.flink.iteration.Iterations; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.distance.DistanceMeasure; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.api.windowing.windows.GlobalWindow; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl; +import org.apache.flink.table.api.internal.TableImpl; +import 
org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.collections.IteratorUtils; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.function.Supplier; + +/** + * OnlineKMeans extends the function of {@link KMeans}, supporting to train a K-Means model + * continuously according to an unbounded stream of train data. + * + * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate + * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired, + * OnlineKMeans computes the new centroids from the weighted average between the original and the + * estimated centroids. The weight of the estimated centroids is the number of points assigned to + * them. The weight of the original centroids is also the number of points, but additionally + * multiplying with the decay factor. + * + * <p>The decay factor scales the contribution of the clusters as estimated thus far. If decay + * factor is 1, all batches are weighted equally. If decay factor is 0, new centroids are determined + * entirely by recent data. Lower values correspond to more forgetting. + */ +public class OnlineKMeans + implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private Table initModelDataTable; + + public OnlineKMeans() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public OnlineKMeansModel fit(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv(); + + DataStream<DenseVector> points = + tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol())); + + DataStream<KMeansModelData> initModelData = + KMeansModelData.getModelDataStream(initModelDataTable); + initModelData.getTransformation().setParallelism(1); + + IterationBody body = + new OnlineKMeansIterationBody( + DistanceMeasure.getInstance(getDistanceMeasure()), + getDecayFactor(), + getGlobalBatchSize()); + + DataStream<KMeansModelData> finalModelData = + Iterations.iterateUnboundedStreams( + DataStreamList.of(initModelData), DataStreamList.of(points), body) + .get(0); + + Table finalModelDataTable = tEnv.fromDataStream(finalModelData); + OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable); + ReadWriteUtils.updateExistingParams(model, paramMap); + return model; + } + + /** Saves the metadata AND bounded model data table (if exists) to the given path. 
*/ + @Override + public void save(String path) throws IOException { + if (initModelDataTable != null) { + ReadWriteUtils.saveModelData( + KMeansModelData.getModelDataStream(initModelDataTable), + path, + new KMeansModelData.ModelDataEncoder()); + } + + ReadWriteUtils.saveMetadata(this, path); + } + + public static OnlineKMeans load(StreamExecutionEnvironment env, String path) + throws IOException { + OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path); + + String initModelDataPath = ReadWriteUtils.getDataPath(path); + if (Files.exists(Paths.get(initModelDataPath))) { + StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + + DataStream<KMeansModelData> initModelDataStream = + ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder()); + + kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream); + } + + return kMeans; + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + private static class OnlineKMeansIterationBody implements IterationBody { + private final DistanceMeasure distanceMeasure; + private final double decayFactor; + private final int batchSize; + + public OnlineKMeansIterationBody( + DistanceMeasure distanceMeasure, double decayFactor, int batchSize) { + this.distanceMeasure = distanceMeasure; + this.decayFactor = decayFactor; + this.batchSize = batchSize; + } + + @Override + public IterationBodyResult process( + DataStreamList variableStreams, DataStreamList dataStreams) { + DataStream<KMeansModelData> modelData = variableStreams.get(0); + DataStream<DenseVector> points = dataStreams.get(0); + + int parallelism = points.getParallelism(); + + DataStream<KMeansModelData> newModelData = + points.countWindowAll(batchSize) + .apply(new GlobalBatchCreator()) + .flatMap(new GlobalBatchSplitter(parallelism)) + .rebalance() + .connect(modelData.broadcast()) + .transform( + "ModelDataLocalUpdater", + TypeInformation.of(KMeansModelData.class), + new 
ModelDataLocalUpdater(distanceMeasure, decayFactor)) + .setParallelism(parallelism) + .countWindowAll(parallelism) + .reduce(new ModelDataGlobalReducer()); + + return new IterationBodyResult( + DataStreamList.of(newModelData), DataStreamList.of(modelData)); + } + } + + private static class ModelDataGlobalReducer implements ReduceFunction<KMeansModelData> { + @Override + public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newModelData) { + DenseVector weights = modelData.weights; + DenseVector[] centroids = modelData.centroids; + DenseVector newWeights = newModelData.weights; + DenseVector[] newCentroids = newModelData.centroids; + + int k = newCentroids.length; + int dim = newCentroids[0].size(); + + for (int i = 0; i < k; i++) { + for (int j = 0; j < dim; j++) { + centroids[i].values[j] = + (centroids[i].values[j] * weights.values[i] + + newCentroids[i].values[j] * newWeights.values[i]) + / Math.max(weights.values[i] + newWeights.values[i], 1e-16); + } + weights.values[i] += newWeights.values[i]; + } + + return new KMeansModelData(centroids, weights); + } + } + + private static class ModelDataLocalUpdater extends AbstractStreamOperator<KMeansModelData> Review comment: Could we add Java doc explaining the algorithm used in this operator? Same for `ModelDataGlobalReducer`. ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java ########## @@ -0,0 +1,437 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.clustering.kmeans; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.IterationBody; +import org.apache.flink.iteration.IterationBodyResult; +import org.apache.flink.iteration.Iterations; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.distance.DistanceMeasure; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; 
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.collections.IteratorUtils; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.function.Supplier; + +/** + * OnlineKMeans extends the function of {@link KMeans}, supporting to train a K-Means model + * continuously according to an unbounded stream of train data. + * + * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate + * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired, + * OnlineKMeans computes the new centroids from the weighted average between the original and the + * estimated centroids. The weight of the estimated centroids is the number of points assigned to + * them. The weight of the original centroids is also the number of points, but additionally + * multiplying with the decay factor. + * + * <p>The decay factor scales the contribution of the clusters as estimated thus far. If decay + * factor is 1, all batches are weighted equally. If decay factor is 0, new centroids are determined + * entirely by recent data. Lower values correspond to more forgetting. 
+ */ +public class OnlineKMeans + implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private Table initModelDataTable; + + public OnlineKMeans() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public OnlineKMeansModel fit(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv(); + + DataStream<DenseVector> points = + tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol())); + + DataStream<KMeansModelData> initModelData = + KMeansModelData.getModelDataStream(initModelDataTable); + initModelData.getTransformation().setParallelism(1); + + IterationBody body = + new OnlineKMeansIterationBody( + DistanceMeasure.getInstance(getDistanceMeasure()), + getDecayFactor(), + getGlobalBatchSize()); + + DataStream<KMeansModelData> finalModelData = + Iterations.iterateUnboundedStreams( + DataStreamList.of(initModelData), DataStreamList.of(points), body) + .get(0); + + Table finalModelDataTable = tEnv.fromDataStream(finalModelData); + OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable); + ReadWriteUtils.updateExistingParams(model, paramMap); + return model; + } + + /** Saves the metadata AND bounded model data table (if exists) to the given path. 
*/ + @Override + public void save(String path) throws IOException { + if (initModelDataTable != null) { + ReadWriteUtils.saveModelData( + KMeansModelData.getModelDataStream(initModelDataTable), + path, + new KMeansModelData.ModelDataEncoder()); + } + + ReadWriteUtils.saveMetadata(this, path); + } + + public static OnlineKMeans load(StreamExecutionEnvironment env, String path) + throws IOException { + OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path); + + String initModelDataPath = ReadWriteUtils.getDataPath(path); + if (Files.exists(Paths.get(initModelDataPath))) { + StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + + DataStream<KMeansModelData> initModelDataStream = + ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder()); + + kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream); + } + + return kMeans; + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + private static class OnlineKMeansIterationBody implements IterationBody { + private final DistanceMeasure distanceMeasure; + private final double decayFactor; + private final int batchSize; + + public OnlineKMeansIterationBody( + DistanceMeasure distanceMeasure, double decayFactor, int batchSize) { + this.distanceMeasure = distanceMeasure; + this.decayFactor = decayFactor; + this.batchSize = batchSize; + } + + @Override + public IterationBodyResult process( + DataStreamList variableStreams, DataStreamList dataStreams) { + DataStream<KMeansModelData> modelData = variableStreams.get(0); + DataStream<DenseVector> points = dataStreams.get(0); + + int parallelism = points.getParallelism(); + + DataStream<KMeansModelData> newModelData = + points.countWindowAll(batchSize) + .apply(new GlobalBatchCreator()) + .flatMap(new GlobalBatchSplitter(parallelism)) + .rebalance() + .connect(modelData.broadcast()) + .transform( + "ModelDataLocalUpdater", + TypeInformation.of(KMeansModelData.class), + new 
ModelDataLocalUpdater(distanceMeasure, decayFactor)) + .setParallelism(parallelism) + .countWindowAll(parallelism) + .reduce(new ModelDataGlobalReducer()); + + return new IterationBodyResult( + DataStreamList.of(newModelData), DataStreamList.of(modelData)); + } + } + + private static class ModelDataGlobalReducer implements ReduceFunction<KMeansModelData> { + @Override + public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newModelData) { + DenseVector weights = modelData.weights; + DenseVector[] centroids = modelData.centroids; + DenseVector newWeights = newModelData.weights; + DenseVector[] newCentroids = newModelData.centroids; + + int k = newCentroids.length; + int dim = newCentroids[0].size(); + + for (int i = 0; i < k; i++) { + for (int j = 0; j < dim; j++) { + centroids[i].values[j] = + (centroids[i].values[j] * weights.values[i] + + newCentroids[i].values[j] * newWeights.values[i]) + / Math.max(weights.values[i] + newWeights.values[i], 1e-16); + } + weights.values[i] += newWeights.values[i]; + } + + return new KMeansModelData(centroids, weights); + } + } + + private static class ModelDataLocalUpdater extends AbstractStreamOperator<KMeansModelData> + implements TwoInputStreamOperator<DenseVector[], KMeansModelData, KMeansModelData> { + private final DistanceMeasure distanceMeasure; + private final double decayFactor; + private ListState<DenseVector[]> localBatchState; + private ListState<KMeansModelData> modelDataState; + + private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, double decayFactor) { + this.distanceMeasure = distanceMeasure; + this.decayFactor = decayFactor; + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + + TypeInformation<DenseVector[]> type = + ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE); + localBatchState = + context.getOperatorStateStore() + .getListState(new ListStateDescriptor<>("localBatch", type)); + + 
modelDataState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>("modelData", KMeansModelData.class)); + } + + @Override + public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception { + localBatchState.add(pointsRecord.getValue()); + alignAndComputeModelData(); + } + + @Override + public void processElement2(StreamRecord<KMeansModelData> modelDataRecord) + throws Exception { + modelDataState.add(modelDataRecord.getValue()); + alignAndComputeModelData(); + } + + private void alignAndComputeModelData() throws Exception { + if (!modelDataState.get().iterator().hasNext() + || !localBatchState.get().iterator().hasNext()) { + return; + } + + KMeansModelData modelData = + OperatorStateUtils.getUniqueElement(modelDataState, "modelData") + .orElseThrow((Supplier<Exception>) NullPointerException::new); + DenseVector[] centroids = modelData.centroids; + DenseVector weights = modelData.weights; + modelDataState.clear(); + + List<DenseVector[]> pointsList = IteratorUtils.toList(localBatchState.get().iterator()); + DenseVector[] points = pointsList.remove(0); + localBatchState.update(pointsList); + + int dim = centroids[0].size(); + int k = centroids.length; + int parallelism = getRuntimeContext().getNumberOfParallelSubtasks(); + + // Computes new centroids. + DenseVector[] sums = new DenseVector[k]; + int[] counts = new int[k]; + + for (int i = 0; i < k; i++) { + sums[i] = new DenseVector(dim); + counts[i] = 0; + } + for (DenseVector point : points) { + int closestCentroidId = + KMeans.findClosestCentroidId(centroids, point, distanceMeasure); + counts[closestCentroidId]++; + for (int j = 0; j < dim; j++) { + sums[closestCentroidId].values[j] += point.values[j]; + } + } + + // Considers weight and decay factor when updating centroids. 
+ BLAS.scal(decayFactor / parallelism, weights); + for (int i = 0; i < k; i++) { + if (counts[i] == 0) { + continue; + } + + DenseVector centroid = centroids[i]; + weights.values[i] = weights.values[i] + counts[i]; + double lambda = counts[i] / weights.values[i]; + + BLAS.scal(1.0 - lambda, centroid); + BLAS.axpy(lambda / counts[i], sums[i], centroid); + } + + output.collect(new StreamRecord<>(new KMeansModelData(centroids, weights))); + } + } + + private static class FeaturesExtractor implements MapFunction<Row, DenseVector> { + private final String featuresCol; + + private FeaturesExtractor(String featuresCol) { + this.featuresCol = featuresCol; + } + + @Override + public DenseVector map(Row row) throws Exception { Review comment: Would it be simpler to remove `throws Exception` here? Same for other methods. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org