This is an automated email from the ASF dual-hosted git repository. hxb pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
     new 57650e1  [FLINK-26269][python] Add clustering algorithm support for KMeans in ML Python API
57650e1 is described below

commit 57650e155c2ee68c0d333144da2cc8f4e1f0585f
Author: huangxingbo <hxbks...@gmail.com>
AuthorDate: Mon Apr 25 15:02:52 2022 +0800

    [FLINK-26269][python] Add clustering algorithm support for KMeans in ML Python API

    This closes #91.
---
 flink-ml-python/pyflink/ml/core/wrapper.py         |   5 +-
 .../pyflink/ml/lib/classification/knn.py           |  13 +-
 .../ml/lib/classification/logisticregression.py    |  10 +-
 .../pyflink/ml/lib/classification/naivebayes.py    |  11 +-
 .../pyflink/ml/lib/clustering/__init__.py          |  17 ++
 .../pyflink/ml/lib/clustering/common.py            |  74 ++++++++
 .../pyflink/ml/lib/clustering/kmeans.py            | 192 ++++++++++++++++++++
 .../pyflink/ml/lib/clustering/tests/__init__.py    |  30 +++
 .../pyflink/ml/lib/clustering/tests/test_kmeans.py | 202 +++++++++++++++++++++
 flink-ml-python/pyflink/ml/lib/param.py            |  39 ++++
 10 files changed, 581 insertions(+), 12 deletions(-)

diff --git a/flink-ml-python/pyflink/ml/core/wrapper.py b/flink-ml-python/pyflink/ml/core/wrapper.py
index 104048b..f78f3e7 100644
--- a/flink-ml-python/pyflink/ml/core/wrapper.py
+++ b/flink-ml-python/pyflink/ml/core/wrapper.py
@@ -58,7 +58,10 @@ class JavaWithParams(WithParams, JavaWrapper):
         'weight_col': 'weightCol',
         'k': 'k',
         'model_type': 'modelType',
-        'smoothing': 'smoothing'
+        'smoothing': 'smoothing',
+        'init_mode': 'initMode',
+        'batch_strategy': 'batchStrategy',
+        'decay_factor': 'decayFactor'
     }
 
     def __init__(self, java_params):
diff --git a/flink-ml-python/pyflink/ml/lib/classification/knn.py b/flink-ml-python/pyflink/ml/lib/classification/knn.py
index 6264fd9..c084143 100644
--- a/flink-ml-python/pyflink/ml/lib/classification/knn.py
+++ b/flink-ml-python/pyflink/ml/lib/classification/knn.py
@@ -17,6 +17,8 @@
 ################################################################################
 from abc import ABC
 
+import typing
+
 from pyflink.ml.core.param import Param, IntParam, ParamValidators
 from pyflink.ml.core.wrapper import JavaWithParams
 from pyflink.ml.lib.classification.common import (JavaClassificationModel,
@@ -25,6 +27,7 @@ from pyflink.ml.lib.param import HasFeaturesCol, HasPredictionCol, HasLabelCol
 
 
 class _KNNModelParams(
+    JavaWithParams,
     HasFeaturesCol,
     HasPredictionCol,
     ABC
@@ -39,8 +42,11 @@ class _KNNModelParams(
         5,
         ParamValidators.gt(0))
 
+    def __init__(self, java_params):
+        super(_KNNModelParams, self).__init__(java_params)
+
     def set_k(self, value: int):
-        return self.set(self.K, value)
+        return typing.cast(_KNNModelParams, self.set(self.K, value))
 
     def get_k(self) -> int:
         return self.get(self.K)
@@ -51,9 +57,8 @@ class _KNNModelParams(
 
 
 class _KNNParams(
-    JavaWithParams,
-    HasLabelCol,
-    _KNNModelParams
+    _KNNModelParams,
+    HasLabelCol
 ):
     """
     Params for :class:`KNN`.
diff --git a/flink-ml-python/pyflink/ml/lib/classification/logisticregression.py b/flink-ml-python/pyflink/ml/lib/classification/logisticregression.py
index 35bc783..97adcd2 100644
--- a/flink-ml-python/pyflink/ml/lib/classification/logisticregression.py
+++ b/flink-ml-python/pyflink/ml/lib/classification/logisticregression.py
@@ -26,6 +26,7 @@ from pyflink.ml.lib.param import (HasWeightCol, HasMaxIter, HasReg, HasLearningR
 
 
 class _LogisticRegressionModelParams(
+    JavaWithParams,
     HasFeaturesCol,
     HasPredictionCol,
     HasRawPredictionCol,
@@ -34,11 +35,13 @@ class _LogisticRegressionModelParams(
     """
     Params for :class:`LogisticRegressionModel`.
     """
-    pass
+
+    def __init__(self, java_params):
+        super(_LogisticRegressionModelParams, self).__init__(java_params)
 
 
 class _LogisticRegressionParams(
-    JavaWithParams,
+    _LogisticRegressionModelParams,
     HasLabelCol,
     HasWeightCol,
     HasMaxIter,
@@ -46,8 +49,7 @@ class _LogisticRegressionParams(
     HasLearningRate,
     HasGlobalBatchSize,
     HasTol,
-    HasMultiClass,
-    _LogisticRegressionModelParams
+    HasMultiClass
 ):
     """
     Params for :class:`LogisticRegression`.
diff --git a/flink-ml-python/pyflink/ml/lib/classification/naivebayes.py b/flink-ml-python/pyflink/ml/lib/classification/naivebayes.py
index 77d0cc1..46b0c53 100644
--- a/flink-ml-python/pyflink/ml/lib/classification/naivebayes.py
+++ b/flink-ml-python/pyflink/ml/lib/classification/naivebayes.py
@@ -17,6 +17,8 @@
 ################################################################################
 from abc import ABC
 
+import typing
+
 from pyflink.ml.core.param import Param, StringParam, ParamValidators, FloatParam
 from pyflink.ml.core.wrapper import JavaWithParams
 from pyflink.ml.lib.classification.common import (JavaClassificationModel,
@@ -25,6 +27,7 @@ from pyflink.ml.lib.param import HasFeaturesCol, HasPredictionCol, HasLabelCol
 
 
 class _NaiveBayesModelParams(
+    JavaWithParams,
     HasFeaturesCol,
     HasPredictionCol,
     ABC
@@ -39,6 +42,9 @@ class _NaiveBayesModelParams(
         "multinomial",
         ParamValidators.in_array(["multinomial"]))
 
+    def __init__(self, java_params):
+        super(_NaiveBayesModelParams, self).__init__(java_params)
+
     def set_model_type(self, value: str):
         return self.set(self.MODEL_TYPE, value)
 
@@ -51,9 +57,8 @@ class _NaiveBayesModelParams(
 
 
 class _NaiveBayesParams(
-    JavaWithParams,
+    _NaiveBayesModelParams,
     HasLabelCol,
-    _NaiveBayesModelParams
 ):
     """
     Params for :class:`NaiveBayes`.
@@ -69,7 +74,7 @@ class _NaiveBayesParams(
         super(_NaiveBayesParams, self).__init__(java_params)
 
     def set_smoothing(self, value: float):
-        return self.set(self.SMOOTHING, value)
+        return typing.cast(_NaiveBayesParams, self.set(self.SMOOTHING, value))
 
     def get_smoothing(self) -> float:
         return self.get(self.SMOOTHING)
diff --git a/flink-ml-python/pyflink/ml/lib/clustering/__init__.py b/flink-ml-python/pyflink/ml/lib/clustering/__init__.py
new file mode 100644
index 0000000..65b48d4
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/clustering/__init__.py
@@ -0,0 +1,17 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
diff --git a/flink-ml-python/pyflink/ml/lib/clustering/common.py b/flink-ml-python/pyflink/ml/lib/clustering/common.py
new file mode 100644
index 0000000..2665db1
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/clustering/common.py
@@ -0,0 +1,74 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+from abc import ABC, abstractmethod
+
+from pyflink.ml.core.wrapper import JavaModel, JavaEstimator
+
+JAVA_CLUSTERING_PACKAGE_NAME = "org.apache.flink.ml.clustering"
+
+
+class JavaClusteringModel(JavaModel, ABC):
+    """
+    Wrapper class for a Java Clustering Model.
+    """
+
+    def __init__(self, java_model):
+        super(JavaClusteringModel, self).__init__(java_model)
+
+    @classmethod
+    def _java_model_path(cls) -> str:
+        return ".".join(
+            [JAVA_CLUSTERING_PACKAGE_NAME,
+             cls._java_model_package_name(),
+             cls._java_model_class_name()])
+
+    @classmethod
+    @abstractmethod
+    def _java_model_package_name(cls) -> str:
+        pass
+
+    @classmethod
+    @abstractmethod
+    def _java_model_class_name(cls) -> str:
+        pass
+
+
+class JavaClusteringEstimator(JavaEstimator, ABC):
+    """
+    Wrapper class for a Java Clustering Estimator.
+    """
+
+    def __init__(self):
+        super(JavaClusteringEstimator, self).__init__()
+
+    @classmethod
+    def _java_estimator_path(cls) -> str:
+        return ".".join(
+            [JAVA_CLUSTERING_PACKAGE_NAME,
+             cls._java_estimator_package_name(),
+             cls._java_estimator_class_name()])
+
+    @classmethod
+    @abstractmethod
+    def _java_estimator_package_name(cls) -> str:
+        pass
+
+    @classmethod
+    @abstractmethod
+    def _java_estimator_class_name(cls) -> str:
+        pass
diff --git a/flink-ml-python/pyflink/ml/lib/clustering/kmeans.py b/flink-ml-python/pyflink/ml/lib/clustering/kmeans.py
new file mode 100644
index 0000000..4037351
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/clustering/kmeans.py
@@ -0,0 +1,192 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+from abc import ABC
+
+import typing
+
+from pyflink.ml.core.param import ParamValidators, Param, IntParam, StringParam
+from pyflink.ml.core.wrapper import JavaWithParams
+from pyflink.ml.lib.clustering.common import JavaClusteringModel, JavaClusteringEstimator
+from pyflink.ml.lib.param import (HasDistanceMeasure, HasFeaturesCol, HasPredictionCol,
+                                  HasBatchStrategy, HasGlobalBatchSize, HasDecayFactor, HasSeed,
+                                  HasMaxIter)
+
+
+class _KMeansModelParams(
+    JavaWithParams,
+    HasDistanceMeasure,
+    HasFeaturesCol,
+    HasPredictionCol,
+    ABC
+):
+    """
+    Params for :class:`KMeansModel`.
+    """
+
+    K: Param[int] = IntParam(
+        "k",
+        "The max number of clusters to create.",
+        2,
+        ParamValidators.gt(1))
+
+    def __init__(self, java_params):
+        super(_KMeansModelParams, self).__init__(java_params)
+
+    def set_k(self, value: int):
+        return typing.cast(_KMeansModelParams, self.set(self.K, value))
+
+    def get_k(self) -> int:
+        return self.get(self.K)
+
+    @property
+    def k(self) -> int:
+        return self.get_k()
+
+
+class _KMeansParams(
+    _KMeansModelParams,
+    HasSeed,
+    HasMaxIter
+):
+    """
+    Params for :class:`KMeans`.
+    """
+    INIT_MODE: Param[str] = StringParam(
+        "init_mode",
+        "The initialization algorithm. Supported options: 'random'.",
+        "random",
+        ParamValidators.in_array(["random"]))
+
+    def __init__(self, java_params):
+        super(_KMeansParams, self).__init__(java_params)
+
+    def set_init_mode(self, value: str):
+        return typing.cast(_KMeansParams, self.set(self.INIT_MODE, value))
+
+    def get_init_mode(self) -> str:
+        return self.get(self.INIT_MODE)
+
+    @property
+    def init_mode(self):
+        return self.get_init_mode()
+
+
+class _OnlineKMeansParams(
+    _KMeansModelParams,
+    HasBatchStrategy,
+    HasGlobalBatchSize,
+    HasDecayFactor,
+    HasSeed,
+):
+    """
+    Params for :class:`OnlineKMeans`.
+    """
+
+    def __init__(self, java_params):
+        super(_OnlineKMeansParams, self).__init__(java_params)
+
+
+class KMeansModel(JavaClusteringModel, _KMeansModelParams):
+    """
+    A Model which clusters data into k clusters using the model data computed by :class:`KMeans`.
+    """
+
+    def __init__(self, java_model=None):
+        super(KMeansModel, self).__init__(java_model)
+
+    @classmethod
+    def _java_model_package_name(cls) -> str:
+        return "kmeans"
+
+    @classmethod
+    def _java_model_class_name(cls) -> str:
+        return "KMeansModel"
+
+
+class OnlineKMeansModel(JavaClusteringModel, _KMeansModelParams):
+    """
+    OnlineKMeansModel can be regarded as an advanced :class:`KMeansModel` operator which can update
+    model data in a streaming format, using the model data provided by :class:`OnlineKMeans`.
+    """
+
+    def __init__(self, java_model=None):
+        super(OnlineKMeansModel, self).__init__(java_model)
+
+    @classmethod
+    def _java_model_package_name(cls) -> str:
+        return "kmeans"
+
+    @classmethod
+    def _java_model_class_name(cls) -> str:
+        return "OnlineKMeansModel"
+
+
+class KMeans(JavaClusteringEstimator, _KMeansParams):
+    """
+    An Estimator which implements the k-means clustering algorithm.
+
+    See https://en.wikipedia.org/wiki/K-means_clustering.
+    """
+
+    def __init__(self):
+        super(KMeans, self).__init__()
+
+    @classmethod
+    def _create_model(cls, java_model) -> KMeansModel:
+        return KMeansModel(java_model)
+
+    @classmethod
+    def _java_estimator_package_name(cls) -> str:
+        return "kmeans"
+
+    @classmethod
+    def _java_estimator_class_name(cls) -> str:
+        return "KMeans"
+
+
+class OnlineKMeans(JavaClusteringEstimator, _OnlineKMeansParams):
+    """
+    OnlineKMeans extends the function of :class:`KMeans`, supporting training a K-Means model
+    continuously according to an unbounded stream of train data.
+
+    OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+    forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+    OnlineKMeans computes the new centroids from the weighted average between the original and the
+    estimated centroids. The weight of the estimated centroids is the number of points assigned to
+    them. The weight of the original centroids is also the number of points, but additionally
+    multiplying with the decay factor.
+
+    The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+    factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+    determined entirely by recent data. Lower values correspond to more forgetting.
+    """
+
+    def __init__(self):
+        super(OnlineKMeans, self).__init__()
+
+    @classmethod
+    def _create_model(cls, java_model) -> OnlineKMeansModel:
+        return OnlineKMeansModel(java_model)  # fitted model is the streaming OnlineKMeansModel
+
+    @classmethod
+    def _java_estimator_package_name(cls) -> str:
+        return "kmeans"
+
+    @classmethod
+    def _java_estimator_class_name(cls) -> str:
+        return "OnlineKMeans"
diff --git a/flink-ml-python/pyflink/ml/lib/clustering/tests/__init__.py b/flink-ml-python/pyflink/ml/lib/clustering/tests/__init__.py
new file mode 100644
index 0000000..6698191
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/clustering/tests/__init__.py
@@ -0,0 +1,30 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import os
+import sys
+from pathlib import Path
+
+# Because the project and the dependent `pyflink` project have the same directory structure,
+# we need to manually add `flink-ml-python` path to `sys.path` in the test of this project to change
+# the order of package search.
+flink_ml_python_dir = Path(__file__).parents[5]
+sys.path.append(str(flink_ml_python_dir))
+
+import pyflink
+
+pyflink.__path__.insert(0, os.path.join(flink_ml_python_dir, 'pyflink'))
diff --git a/flink-ml-python/pyflink/ml/lib/clustering/tests/test_kmeans.py b/flink-ml-python/pyflink/ml/lib/clustering/tests/test_kmeans.py
new file mode 100644
index 0000000..f4738c5
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/clustering/tests/test_kmeans.py
@@ -0,0 +1,202 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import os
+import typing
+from pyflink.common import Types, Row
+from typing import List, Dict, Set
+
+from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo, DenseVector
+from pyflink.ml.lib.clustering.kmeans import KMeans, KMeansModel, OnlineKMeans
+from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
+
+
+def group_features_by_prediction(
+        rows: List[Row], feature_index: int, prediction_index: int):
+    groups = {}  # type: Dict[int, Set]
+    for row in rows:
+        vector = typing.cast(DenseVector, row[feature_index])
+        predict = typing.cast(int, row[prediction_index])
+        if predict in groups:
+            group = groups[predict]
+        else:
+            group = set()
+            groups[predict] = group
+        group.add(vector)
+    return list(groups.values())  # one set of feature vectors per predicted cluster id
+
+
+class KMeansTest(PyFlinkMLTestCase):
+    def setUp(self):
+        super(KMeansTest, self).setUp()
+        self.data_table = self.t_env.from_data_stream(
+            self.env.from_collection([
+                (Vectors.dense([0.0, 0.0]),),
+                (Vectors.dense([0.0, 0.3]),),
+                (Vectors.dense([0.3, 3.0]),),
+                (Vectors.dense([9.0, 0.0]),),
+                (Vectors.dense([9.0, 0.6]),),
+                (Vectors.dense([9.6, 0.0]),),
+            ],
+                type_info=Types.ROW_NAMED(
+                    ['features'],
+                    [DenseVectorTypeInfo()])))
+        self.expected_groups = [
+            {DenseVector([0.0, 0.3]), DenseVector([0.3, 3.0]), DenseVector([0.0, 0.0])},
+            {DenseVector([9.6, 0.0]), DenseVector([9.0, 0.0]), DenseVector([9.0, 0.6])}]
+
+    def test_param(self):
+        kmeans = KMeans()
+        self.assertEqual('features', kmeans.get_features_col())
+        self.assertEqual('prediction', kmeans.get_prediction_col())
+        self.assertEqual('euclidean', kmeans.get_distance_measure())
+        self.assertEqual('random', kmeans.get_init_mode())
+        self.assertEqual(2, kmeans.get_k())
+        self.assertEqual(20, kmeans.get_max_iter())
+
+        kmeans.set_k(9) \
+            .set_features_col('test_feature') \
+            .set_prediction_col('test_prediction') \
+            .set_k(3) \
+            .set_max_iter(30) \
+            .set_seed(100)
+
+        self.assertEqual('test_feature', kmeans.get_features_col())
+        self.assertEqual('test_prediction', kmeans.get_prediction_col())
+        self.assertEqual(3, kmeans.get_k())
+        self.assertEqual(30, kmeans.get_max_iter())
+        self.assertEqual(100, kmeans.get_seed())
+
+    def test_output_schema(self):
+        input_table = self.data_table.alias('test_feature')
+        kmeans = KMeans().set_features_col('test_feature').set_prediction_col('test_prediction')
+
+        model = kmeans.fit(input_table)
+        output = model.transform(input_table)[0]
+
+        field_names = output.get_schema().get_field_names()
+        self.assertEqual(['test_feature', 'test_prediction'],
+                         field_names)
+
+        results = [result for result in self.t_env.to_data_stream(output).execute_and_collect()]
+        actual_groups = group_features_by_prediction(
+            results,
+            field_names.index(kmeans.features_col),
+            field_names.index(kmeans.prediction_col))
+
+        self.assertTrue(actual_groups[0] == self.expected_groups[0]
+                        and actual_groups[1] == self.expected_groups[1] or
+                        actual_groups[0] == self.expected_groups[1]
+                        and actual_groups[1] == self.expected_groups[0])
+
+    def test_fewer_distinct_points_than_cluster(self):
+        input_table = self.t_env.from_data_stream(
+            self.env.from_collection([
+                (Vectors.dense([0.0, 0.1]),),
+                (Vectors.dense([0.0, 0.1]),),
+                (Vectors.dense([0.0, 0.1]),),
+            ],
+                type_info=Types.ROW_NAMED(
+                    ['features'],
+                    [DenseVectorTypeInfo()])))
+
+        kmeans = KMeans().set_k(2)
+        model = kmeans.fit(input_table)
+        output = model.transform(input_table)[0]
+        results = [result for result in self.t_env.to_data_stream(output).execute_and_collect()]
+        field_names = output.get_schema().get_field_names()
+        actual_groups = group_features_by_prediction(
+            results,
+            field_names.index(kmeans.features_col),
+            field_names.index(kmeans.prediction_col))
+
+        expected_groups = [{DenseVector([0.0, 0.1])}]
+
+        self.assertEqual(actual_groups, expected_groups)
+
+    def test_fit_and_predict(self):
+        kmeans = KMeans().set_max_iter(2).set_k(2)
+        model = kmeans.fit(self.data_table)
+        output = model.transform(self.data_table)[0]
+
+        self.assertEqual(['features', 'prediction'], output.get_schema().get_field_names())
+        results = [result for result in self.t_env.to_data_stream(output).execute_and_collect()]
+        field_names = output.get_schema().get_field_names()
+        actual_groups = group_features_by_prediction(
+            results,
+            field_names.index(kmeans.features_col),
+            field_names.index(kmeans.prediction_col))
+
+        self.assertTrue(actual_groups[0] == self.expected_groups[0]
+                        and actual_groups[1] == self.expected_groups[1] or
+                        actual_groups[0] == self.expected_groups[1]
+                        and actual_groups[1] == self.expected_groups[0])
+
+    def test_save_load_and_predict(self):
+        kmeans = KMeans().set_max_iter(2).set_k(2)
+        model = kmeans.fit(self.data_table)
+        path = os.path.join(self.temp_dir, 'test_save_load_and_predict_kmeans_model')
+        model.save(path)
+        self.env.execute()
+        loaded_model = KMeansModel.load(self.t_env, path)  # type: KMeansModel
+        output = loaded_model.transform(self.data_table)[0]
+        self.assertEqual(
+            ['centroids', 'weights'],
+            loaded_model.get_model_data()[0].get_schema().get_field_names())
+
+        self.assertEqual(
+            ['features', 'prediction'],
+            output.get_schema().get_field_names())
+
+        results = [result for result in self.t_env.to_data_stream(output).execute_and_collect()]
+        field_names = output.get_schema().get_field_names()
+        actual_groups = group_features_by_prediction(
+            results,
+            field_names.index(kmeans.features_col),
+            field_names.index(kmeans.prediction_col))
+        self.assertTrue(actual_groups[0] == self.expected_groups[0]
+                        and actual_groups[1] == self.expected_groups[1] or
+                        actual_groups[0] == self.expected_groups[1]
+                        and actual_groups[1] == self.expected_groups[0])
+
+
+class OnlineKMeansTest(PyFlinkMLTestCase):
+    def setUp(self):
+        super(OnlineKMeansTest, self).setUp()
+
+    def test_param(self):
+        online_kmeans = OnlineKMeans()
+        self.assertEqual('features', online_kmeans.features_col)
+        self.assertEqual('prediction', online_kmeans.prediction_col)
+        self.assertEqual('count', online_kmeans.batch_strategy)
+        self.assertEqual('euclidean', online_kmeans.distance_measure)
+        self.assertEqual(32, online_kmeans.global_batch_size)
+        self.assertEqual(0., online_kmeans.decay_factor)
+
+        online_kmeans.set_features_col('test_feature') \
+            .set_prediction_col('test_prediction') \
+            .set_global_batch_size(5) \
+            .set_decay_factor(0.25) \
+            .set_seed(100)
+
+        self.assertEqual('test_feature', online_kmeans.features_col)
+        self.assertEqual('test_prediction', online_kmeans.prediction_col)
+        self.assertEqual('count', online_kmeans.batch_strategy)
+        self.assertEqual('euclidean', online_kmeans.distance_measure)
+        self.assertEqual(5, online_kmeans.global_batch_size)
+        self.assertEqual(0.25, online_kmeans.decay_factor)
+        self.assertEqual(100, online_kmeans.get_seed())
diff --git a/flink-ml-python/pyflink/ml/lib/param.py b/flink-ml-python/pyflink/ml/lib/param.py
index a24902c..51a607c 100644
--- a/flink-ml-python/pyflink/ml/lib/param.py
+++ b/flink-ml-python/pyflink/ml/lib/param.py
@@ -369,3 +369,42 @@ class HasWeightCol(WithParams, ABC):
     @property
     def weight_col(self):
         return self.get_weight_col()
+
+
+class HasBatchStrategy(WithParams, ABC):
+    """
+    Base class for the shared batch strategy param.
+    """
+    BATCH_STRATEGY: Param[str] = StringParam(
+        "batch_strategy",
+        "Strategy to create mini batch from online train data.",
+        "count",
+        ParamValidators.in_array(["count"]))
+
+    def get_batch_strategy(self) -> str:
+        return self.get(self.BATCH_STRATEGY)
+
+    @property
+    def batch_strategy(self):
+        return self.get_batch_strategy()
+
+
+class HasDecayFactor(WithParams, ABC):
+    """
+    Base class for the shared decay factor param.
+    """
+    DECAY_FACTOR: Param[float] = FloatParam(
+        "decay_factor",
+        "The forgetfulness of the previous centroids.",
+        0.,
+        ParamValidators.in_range(0, 1))
+
+    def set_decay_factor(self, value: float):
+        return self.set(self.DECAY_FACTOR, value)
+
+    def get_decay_factor(self) -> float:
+        return self.get(self.DECAY_FACTOR)
+
+    @property
+    def decay_factor(self):
+        return self.get_decay_factor()