This is an automated email from the ASF dual-hosted git repository.

hxb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new 57650e1  [FLINK-26269][python] Add clustering algorithm support for KMeans in ML Python API
57650e1 is described below

commit 57650e155c2ee68c0d333144da2cc8f4e1f0585f
Author: huangxingbo <hxbks...@gmail.com>
AuthorDate: Mon Apr 25 15:02:52 2022 +0800

    [FLINK-26269][python] Add clustering algorithm support for KMeans in ML Python API
    
    This closes #91.
---
 flink-ml-python/pyflink/ml/core/wrapper.py         |   5 +-
 .../pyflink/ml/lib/classification/knn.py           |  13 +-
 .../ml/lib/classification/logisticregression.py    |  10 +-
 .../pyflink/ml/lib/classification/naivebayes.py    |  11 +-
 .../pyflink/ml/lib/clustering/__init__.py          |  17 ++
 .../pyflink/ml/lib/clustering/common.py            |  74 ++++++++
 .../pyflink/ml/lib/clustering/kmeans.py            | 192 ++++++++++++++++++++
 .../pyflink/ml/lib/clustering/tests/__init__.py    |  30 +++
 .../pyflink/ml/lib/clustering/tests/test_kmeans.py | 202 +++++++++++++++++++++
 flink-ml-python/pyflink/ml/lib/param.py            |  39 ++++
 10 files changed, 581 insertions(+), 12 deletions(-)

diff --git a/flink-ml-python/pyflink/ml/core/wrapper.py b/flink-ml-python/pyflink/ml/core/wrapper.py
index 104048b..f78f3e7 100644
--- a/flink-ml-python/pyflink/ml/core/wrapper.py
+++ b/flink-ml-python/pyflink/ml/core/wrapper.py
@@ -58,7 +58,10 @@ class JavaWithParams(WithParams, JavaWrapper):
         'weight_col': 'weightCol',
         'k': 'k',
         'model_type': 'modelType',
-        'smoothing': 'smoothing'
+        'smoothing': 'smoothing',
+        'init_mode': 'initMode',
+        'batch_strategy': 'batchStrategy',
+        'decay_factor': 'decayFactor'
     }
 
     def __init__(self, java_params):
diff --git a/flink-ml-python/pyflink/ml/lib/classification/knn.py b/flink-ml-python/pyflink/ml/lib/classification/knn.py
index 6264fd9..c084143 100644
--- a/flink-ml-python/pyflink/ml/lib/classification/knn.py
+++ b/flink-ml-python/pyflink/ml/lib/classification/knn.py
@@ -17,6 +17,8 @@
 ################################################################################
 from abc import ABC
 
+import typing
+
 from pyflink.ml.core.param import Param, IntParam, ParamValidators
 from pyflink.ml.core.wrapper import JavaWithParams
 from pyflink.ml.lib.classification.common import (JavaClassificationModel,
@@ -25,6 +27,7 @@ from pyflink.ml.lib.param import HasFeaturesCol, HasPredictionCol, HasLabelCol
 
 
 class _KNNModelParams(
+    JavaWithParams,
     HasFeaturesCol,
     HasPredictionCol,
     ABC
@@ -39,8 +42,11 @@ class _KNNModelParams(
         5,
         ParamValidators.gt(0))
 
+    def __init__(self, java_params):
+        super(_KNNModelParams, self).__init__(java_params)
+
     def set_k(self, value: int):
-        return self.set(self.K, value)
+        return typing.cast(_KNNModelParams, self.set(self.K, value))
 
     def get_k(self) -> int:
         return self.get(self.K)
@@ -51,9 +57,8 @@ class _KNNModelParams(
 
 
 class _KNNParams(
-    JavaWithParams,
-    HasLabelCol,
-    _KNNModelParams
+    _KNNModelParams,
+    HasLabelCol
 ):
     """
     Params for :class:`KNN`.
diff --git a/flink-ml-python/pyflink/ml/lib/classification/logisticregression.py b/flink-ml-python/pyflink/ml/lib/classification/logisticregression.py
index 35bc783..97adcd2 100644
--- a/flink-ml-python/pyflink/ml/lib/classification/logisticregression.py
+++ b/flink-ml-python/pyflink/ml/lib/classification/logisticregression.py
@@ -26,6 +26,7 @@ from pyflink.ml.lib.param import (HasWeightCol, HasMaxIter, HasReg, HasLearningR
 
 
 class _LogisticRegressionModelParams(
+    JavaWithParams,
     HasFeaturesCol,
     HasPredictionCol,
     HasRawPredictionCol,
@@ -34,11 +35,13 @@ class _LogisticRegressionModelParams(
     """
     Params for :class:`LogisticRegressionModel`.
     """
-    pass
+
+    def __init__(self, java_params):
+        super(_LogisticRegressionModelParams, self).__init__(java_params)
 
 
 class _LogisticRegressionParams(
-    JavaWithParams,
+    _LogisticRegressionModelParams,
     HasLabelCol,
     HasWeightCol,
     HasMaxIter,
@@ -46,8 +49,7 @@ class _LogisticRegressionParams(
     HasLearningRate,
     HasGlobalBatchSize,
     HasTol,
-    HasMultiClass,
-    _LogisticRegressionModelParams
+    HasMultiClass
 ):
     """
     Params for :class:`LogisticRegression`.
diff --git a/flink-ml-python/pyflink/ml/lib/classification/naivebayes.py b/flink-ml-python/pyflink/ml/lib/classification/naivebayes.py
index 77d0cc1..46b0c53 100644
--- a/flink-ml-python/pyflink/ml/lib/classification/naivebayes.py
+++ b/flink-ml-python/pyflink/ml/lib/classification/naivebayes.py
@@ -17,6 +17,8 @@
 ################################################################################
 from abc import ABC
 
+import typing
+
 from pyflink.ml.core.param import Param, StringParam, ParamValidators, FloatParam
 from pyflink.ml.core.wrapper import JavaWithParams
 from pyflink.ml.lib.classification.common import (JavaClassificationModel,
@@ -25,6 +27,7 @@ from pyflink.ml.lib.param import HasFeaturesCol, HasPredictionCol, HasLabelCol
 
 
 class _NaiveBayesModelParams(
+    JavaWithParams,
     HasFeaturesCol,
     HasPredictionCol,
     ABC
@@ -39,6 +42,9 @@ class _NaiveBayesModelParams(
         "multinomial",
         ParamValidators.in_array(["multinomial"]))
 
+    def __init__(self, java_params):
+        super(_NaiveBayesModelParams, self).__init__(java_params)
+
     def set_model_type(self, value: str):
         return self.set(self.MODEL_TYPE, value)
 
@@ -51,9 +57,8 @@ class _NaiveBayesModelParams(
 
 
 class _NaiveBayesParams(
-    JavaWithParams,
+    _NaiveBayesModelParams,
     HasLabelCol,
-    _NaiveBayesModelParams
 ):
     """
     Params for :class:`NaiveBayes`.
@@ -69,7 +74,7 @@ class _NaiveBayesParams(
         super(_NaiveBayesParams, self).__init__(java_params)
 
     def set_smoothing(self, value: float):
-        return self.set(self.SMOOTHING, value)
+        return typing.cast(_NaiveBayesParams, self.set(self.SMOOTHING, value))
 
     def get_smoothing(self) -> float:
         return self.get(self.SMOOTHING)
diff --git a/flink-ml-python/pyflink/ml/lib/clustering/__init__.py b/flink-ml-python/pyflink/ml/lib/clustering/__init__.py
new file mode 100644
index 0000000..65b48d4
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/clustering/__init__.py
@@ -0,0 +1,17 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
diff --git a/flink-ml-python/pyflink/ml/lib/clustering/common.py b/flink-ml-python/pyflink/ml/lib/clustering/common.py
new file mode 100644
index 0000000..2665db1
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/clustering/common.py
@@ -0,0 +1,74 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+from abc import ABC, abstractmethod
+
+from pyflink.ml.core.wrapper import JavaModel, JavaEstimator
+
+JAVA_CLUSTERING_PACKAGE_NAME = "org.apache.flink.ml.clustering"
+
+
+class JavaClusteringModel(JavaModel, ABC):
+    """
+    Wrapper class for a Java Clustering Model.
+    """
+
+    def __init__(self, java_model):
+        super(JavaClusteringModel, self).__init__(java_model)
+
+    @classmethod
+    def _java_model_path(cls) -> str:
+        return ".".join(
+            [JAVA_CLUSTERING_PACKAGE_NAME,
+             cls._java_model_package_name(),
+             cls._java_model_class_name()])
+
+    @classmethod
+    @abstractmethod
+    def _java_model_package_name(cls) -> str:
+        pass
+
+    @classmethod
+    @abstractmethod
+    def _java_model_class_name(cls) -> str:
+        pass
+
+
+class JavaClusteringEstimator(JavaEstimator, ABC):
+    """
+    Wrapper class for a Java Clustering Estimator.
+    """
+
+    def __init__(self):
+        super(JavaClusteringEstimator, self).__init__()
+
+    @classmethod
+    def _java_estimator_path(cls):
+        return ".".join(
+            [JAVA_CLUSTERING_PACKAGE_NAME,
+             cls._java_estimator_package_name(),
+             cls._java_estimator_class_name()])
+
+    @classmethod
+    @abstractmethod
+    def _java_estimator_package_name(cls) -> str:
+        pass
+
+    @classmethod
+    @abstractmethod
+    def _java_estimator_class_name(cls) -> str:
+        pass
diff --git a/flink-ml-python/pyflink/ml/lib/clustering/kmeans.py b/flink-ml-python/pyflink/ml/lib/clustering/kmeans.py
new file mode 100644
index 0000000..4037351
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/clustering/kmeans.py
@@ -0,0 +1,192 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+from abc import ABC
+
+import typing
+
+from pyflink.ml.core.param import ParamValidators, Param, IntParam, StringParam
+from pyflink.ml.core.wrapper import JavaWithParams
+from pyflink.ml.lib.clustering.common import JavaClusteringModel, JavaClusteringEstimator
+from pyflink.ml.lib.param import (HasDistanceMeasure, HasFeaturesCol, HasPredictionCol,
+                                  HasBatchStrategy, HasGlobalBatchSize, HasDecayFactor, HasSeed,
+                                  HasMaxIter)
+
+
+class _KMeansModelParams(
+    JavaWithParams,
+    HasDistanceMeasure,
+    HasFeaturesCol,
+    HasPredictionCol,
+    ABC
+):
+    """
+    Params for :class:`KMeansModel`.
+    """
+
+    K: Param[int] = IntParam(
+        "k",
+        "The max number of clusters to create.",
+        2,
+        ParamValidators.gt(1))
+
+    def __init__(self, java_params):
+        super(_KMeansModelParams, self).__init__(java_params)
+
+    def set_k(self, value: int):
+        return typing.cast(_KMeansModelParams, self.set(self.K, value))
+
+    def get_k(self) -> int:
+        return self.get(self.K)
+
+    @property
+    def k(self) -> int:
+        return self.get_k()
+
+
+class _KMeansParams(
+    _KMeansModelParams,
+    HasSeed,
+    HasMaxIter
+):
+    """
+    Params for :class:`KMeans`.
+    """
+    INIT_MODE: Param[str] = StringParam(
+        "init_mode",
+        "The initialization algorithm. Supported options: 'random'.",
+        "random",
+        ParamValidators.in_array(["random"]))
+
+    def __init__(self, java_params):
+        super(_KMeansParams, self).__init__(java_params)
+
+    def set_init_mode(self, value: str):
+        return self.set(self.INIT_MODE, value)
+
+    def get_init_mode(self) -> str:
+        return self.get(self.INIT_MODE)
+
+    @property
+    def init_mode(self):
+        return self.get_init_mode()
+
+
+class _OnlineKMeansParams(
+    _KMeansModelParams,
+    HasBatchStrategy,
+    HasGlobalBatchSize,
+    HasDecayFactor,
+    HasSeed,
+):
+    """
+    Params of :class:OnlineKMeans.
+    """
+
+    def __init__(self, java_params):
+        super(_OnlineKMeansParams, self).__init__(java_params)
+
+
+class KMeansModel(JavaClusteringModel, _KMeansModelParams):
+    """
+    A Model which clusters data into k clusters using the model data computed by :class:`KMeans`.
+    """
+
+    def __init__(self, java_model=None):
+        super(KMeansModel, self).__init__(java_model)
+
+    @classmethod
+    def _java_model_package_name(cls) -> str:
+        return "kmeans"
+
+    @classmethod
+    def _java_model_class_name(cls) -> str:
+        return "KMeansModel"
+
+
+class OnlineKMeansModel(JavaClusteringModel, _KMeansModelParams):
+    """
+    OnlineKMeansModel can be regarded as an advanced :class:`KMeansModel` operator which can update
+    model data in a streaming format, using the model data provided by :class:`OnlineKMeans`.
+    """
+
+    def __init__(self, java_model=None):
+        super(OnlineKMeansModel, self).__init__(java_model)
+
+    @classmethod
+    def _java_model_package_name(cls) -> str:
+        return "kmeans"
+
+    @classmethod
+    def _java_model_class_name(cls) -> str:
+        return "OnlineKMeansModel"
+
+
+class KMeans(JavaClusteringEstimator, _KMeansParams):
+    """
+    An Estimator which implements the k-means clustering algorithm.
+
+    See https://en.wikipedia.org/wiki/K-means_clustering.
+    """
+
+    def __init__(self):
+        super(KMeans, self).__init__()
+
+    @classmethod
+    def _create_model(cls, java_model) -> KMeansModel:
+        return KMeansModel(java_model)
+
+    @classmethod
+    def _java_estimator_package_name(cls) -> str:
+        return "kmeans"
+
+    @classmethod
+    def _java_estimator_class_name(cls) -> str:
+        return "KMeans"
+
+
+class OnlineKMeans(JavaClusteringEstimator, _OnlineKMeansParams):
+    """
+    OnlineKMeans extends the function of :class:`KMeans`, supporting to train a K-Means model
+    continuously according to an unbounded stream of train data.
+
+    OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+    forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+    OnlineKMeans computes the new centroids from the weighted average between the original and the
+    estimated centroids. The weight of the estimated centroids is the number of points assigned to
+    them. The weight of the original centroids is also the number of points, but additionally
+    multiplying with the decay factor.
+
+    The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+    factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+    determined entirely by recent data. Lower values correspond to more forgetting.
+    """
+
+    def __init__(self):
+        super(OnlineKMeans, self).__init__()
+
+    @classmethod
+    def _create_model(cls, java_model) -> KMeansModel:
+        return KMeansModel(java_model)
+
+    @classmethod
+    def _java_estimator_package_name(cls) -> str:
+        return "kmeans"
+
+    @classmethod
+    def _java_estimator_class_name(cls) -> str:
+        return "OnlineKMeans"
diff --git a/flink-ml-python/pyflink/ml/lib/clustering/tests/__init__.py b/flink-ml-python/pyflink/ml/lib/clustering/tests/__init__.py
new file mode 100644
index 0000000..6698191
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/clustering/tests/__init__.py
@@ -0,0 +1,30 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import os
+import sys
+from pathlib import Path
+
+# Because the project and the dependent `pyflink` project have the same directory structure,
+# we need to manually add `flink-ml-python` path to `sys.path` in the test of this project to change
+# the order of package search.
+flink_ml_python_dir = Path(__file__).parents[5]
+sys.path.append(str(flink_ml_python_dir))
+
+import pyflink
+
+pyflink.__path__.insert(0, os.path.join(flink_ml_python_dir, 'pyflink'))
diff --git a/flink-ml-python/pyflink/ml/lib/clustering/tests/test_kmeans.py b/flink-ml-python/pyflink/ml/lib/clustering/tests/test_kmeans.py
new file mode 100644
index 0000000..f4738c5
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/clustering/tests/test_kmeans.py
@@ -0,0 +1,202 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import os
+import typing
+from pyflink.common import Types, Row
+from typing import List, Dict, Set
+
+from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo, DenseVector
+from pyflink.ml.lib.clustering.kmeans import KMeans, KMeansModel, OnlineKMeans
+from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
+
+
+def group_features_by_prediction(
+        rows: List[Row], feature_index: int, prediction_index: int):
+    map = {}  # type: Dict[int, Set]
+    for row in rows:
+        vector = typing.cast(DenseVector, row[feature_index])
+        predict = typing.cast(int, row[prediction_index])
+        if predict in map:
+            l = map[predict]
+        else:
+            l = set()
+            map[predict] = l
+        l.add(vector)
+    return [item for item in map.values()]
+
+
+class KMeansTest(PyFlinkMLTestCase):
+    def setUp(self):
+        super(KMeansTest, self).setUp()
+        self.data_table = self.t_env.from_data_stream(
+            self.env.from_collection([
+                (Vectors.dense([0.0, 0.0]),),
+                (Vectors.dense([0.0, 0.3]),),
+                (Vectors.dense([0.3, 3.0]),),
+                (Vectors.dense([9.0, 0.0]),),
+                (Vectors.dense([9.0, 0.6]),),
+                (Vectors.dense([9.6, 0.0]),),
+            ],
+                type_info=Types.ROW_NAMED(
+                    ['features'],
+                    [DenseVectorTypeInfo()])))
+        self.expected_groups = [
+            {DenseVector([0.0, 0.3]), DenseVector([0.3, 3.0]), DenseVector([0.0, 0.0])},
+            {DenseVector([9.6, 0.0]), DenseVector([9.0, 0.0]), DenseVector([9.0, 0.6])}]
+
+    def test_param(self):
+        kmeans = KMeans()
+        self.assertEqual('features', kmeans.get_features_col())
+        self.assertEqual('prediction', kmeans.get_prediction_col())
+        self.assertEqual('euclidean', kmeans.get_distance_measure())
+        self.assertEqual('random', kmeans.get_init_mode())
+        self.assertEqual(2, kmeans.get_k())
+        self.assertEqual(20, kmeans.get_max_iter())
+
+        kmeans.set_k(9) \
+            .set_features_col('test_feature') \
+            .set_prediction_col('test_prediction') \
+            .set_k(3) \
+            .set_max_iter(30) \
+            .set_seed(100)
+
+        self.assertEqual('test_feature', kmeans.get_features_col())
+        self.assertEqual('test_prediction', kmeans.get_prediction_col())
+        self.assertEqual(3, kmeans.get_k())
+        self.assertEqual(30, kmeans.get_max_iter())
+        self.assertEqual(100, kmeans.get_seed())
+
+    def test_output_schema(self):
+        input = self.data_table.alias('test_feature')
+        kmeans = KMeans().set_features_col('test_feature').set_prediction_col('test_prediction')
+
+        model = kmeans.fit(input)
+        output = model.transform(input)[0]
+
+        field_names = output.get_schema().get_field_names()
+        self.assertEqual(['test_feature', 'test_prediction'],
+                         field_names)
+
+        results = [result for result in self.t_env.to_data_stream(output).execute_and_collect()]
+        actual_groups = group_features_by_prediction(
+            results,
+            field_names.index(kmeans.features_col),
+            field_names.index(kmeans.prediction_col))
+
+        self.assertTrue(actual_groups[0] == self.expected_groups[0]
+                        and actual_groups[1] == self.expected_groups[1] or
+                        actual_groups[0] == self.expected_groups[1]
+                        and actual_groups[1] == self.expected_groups[0])
+
+    def test_fewer_distinct_points_than_cluster(self):
+        input = self.t_env.from_data_stream(
+            self.env.from_collection([
+                (Vectors.dense([0.0, 0.1]),),
+                (Vectors.dense([0.0, 0.1]),),
+                (Vectors.dense([0.0, 0.1]),),
+            ],
+                type_info=Types.ROW_NAMED(
+                    ['features'],
+                    [DenseVectorTypeInfo()])))
+
+        kmeans = KMeans().set_k(2)
+        model = kmeans.fit(input)
+        output = model.transform(input)[0]
+        results = [result for result in self.t_env.to_data_stream(output).execute_and_collect()]
+        field_names = output.get_schema().get_field_names()
+        actual_groups = group_features_by_prediction(
+            results,
+            field_names.index(kmeans.features_col),
+            field_names.index(kmeans.prediction_col))
+
+        expected_groups = [{DenseVector([0.0, 0.1])}]
+
+        self.assertEqual(actual_groups, expected_groups)
+
+    def test_fit_and_predict(self):
+        kmeans = KMeans().set_max_iter(2).set_k(2)
+        model = kmeans.fit(self.data_table)
+        output = model.transform(self.data_table)[0]
+
+        self.assertEqual(['features', 'prediction'], output.get_schema().get_field_names())
+        results = [result for result in self.t_env.to_data_stream(output).execute_and_collect()]
+        field_names = output.get_schema().get_field_names()
+        actual_groups = group_features_by_prediction(
+            results,
+            field_names.index(kmeans.features_col),
+            field_names.index(kmeans.prediction_col))
+
+        self.assertTrue(actual_groups[0] == self.expected_groups[0]
+                        and actual_groups[1] == self.expected_groups[1] or
+                        actual_groups[0] == self.expected_groups[1]
+                        and actual_groups[1] == self.expected_groups[0])
+
+    def test_save_load_and_predict(self):
+        kmeans = KMeans().set_max_iter(2).set_k(2)
+        model = kmeans.fit(self.data_table)
+        path = os.path.join(self.temp_dir, 'test_save_load_and_predict_kmeans_model')
+        model.save(path)
+        self.env.execute()
+        loaded_model = KMeansModel.load(self.t_env, path)  # type: KMeansModel
+        output = loaded_model.transform(self.data_table)[0]
+        self.assertEqual(
+            ['centroids', 'weights'],
+            loaded_model.get_model_data()[0].get_schema().get_field_names())
+
+        self.assertEqual(
+            ['features', 'prediction'],
+            output.get_schema().get_field_names())
+
+        results = [result for result in self.t_env.to_data_stream(output).execute_and_collect()]
+        field_names = output.get_schema().get_field_names()
+        actual_groups = group_features_by_prediction(
+            results,
+            field_names.index(kmeans.features_col),
+            field_names.index(kmeans.prediction_col))
+        self.assertTrue(actual_groups[0] == self.expected_groups[0]
+                        and actual_groups[1] == self.expected_groups[1] or
+                        actual_groups[0] == self.expected_groups[1]
+                        and actual_groups[1] == self.expected_groups[0])
+
+
+class OnlineKMeansTest(PyFlinkMLTestCase):
+    def setUp(self):
+        super(OnlineKMeansTest, self).setUp()
+
+    def test_param(self):
+        online_kmeans = OnlineKMeans()
+        self.assertEqual('features', online_kmeans.features_col)
+        self.assertEqual('prediction', online_kmeans.prediction_col)
+        self.assertEqual('count', online_kmeans.batch_strategy)
+        self.assertEqual('euclidean', online_kmeans.distance_measure)
+        self.assertEqual(32, online_kmeans.global_batch_size)
+        self.assertEqual(0., online_kmeans.decay_factor)
+
+        online_kmeans.set_features_col('test_feature') \
+            .set_prediction_col('test_prediction') \
+            .set_global_batch_size(5) \
+            .set_decay_factor(0.25) \
+            .set_seed(100)
+
+        self.assertEqual('test_feature', online_kmeans.features_col)
+        self.assertEqual('test_prediction', online_kmeans.prediction_col)
+        self.assertEqual('count', online_kmeans.batch_strategy)
+        self.assertEqual('euclidean', online_kmeans.distance_measure)
+        self.assertEqual(5, online_kmeans.global_batch_size)
+        self.assertEqual(0.25, online_kmeans.decay_factor)
+        self.assertEqual(100, online_kmeans.get_seed())
diff --git a/flink-ml-python/pyflink/ml/lib/param.py b/flink-ml-python/pyflink/ml/lib/param.py
index a24902c..51a607c 100644
--- a/flink-ml-python/pyflink/ml/lib/param.py
+++ b/flink-ml-python/pyflink/ml/lib/param.py
@@ -369,3 +369,42 @@ class HasWeightCol(WithParams, ABC):
     @property
     def weight_col(self):
         return self.get_weight_col()
+
+
+class HasBatchStrategy(WithParams, ABC):
+    """
+    Base class for the shared batch strategy param.
+    """
+    BATCH_STRATEGY: Param[str] = StringParam(
+        "batch_strategy",
+        "Strategy to create mini batch from online train data.",
+        "count",
+        ParamValidators.in_array(["count"]))
+
+    def get_batch_strategy(self) -> str:
+        return self.get(self.BATCH_STRATEGY)
+
+    @property
+    def batch_strategy(self):
+        return self.get_batch_strategy()
+
+
+class HasDecayFactor(WithParams, ABC):
+    """
+    Base class for the shared decay factor param.
+    """
+    DECAY_FACTOR: Param[float] = FloatParam(
+        "decay_factor",
+        "The forgetfulness of the previous centroids.",
+        0.,
+        ParamValidators.in_range(0, 1))
+
+    def set_decay_factor(self, value: float):
+        return self.set(self.DECAY_FACTOR, value)
+
+    def get_decay_factor(self) -> float:
+        return self.get(self.DECAY_FACTOR)
+
+    @property
+    def decay_factor(self):
+        return self.get(self.DECAY_FACTOR)
