This is an automated email from the ASF dual-hosted git repository. lindong pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push: new 9af4e9d [FLINK-29786] VarianceThresholdSelector should implement HasInputCol 9af4e9d is described below commit 9af4e9da55747f07532d1755ca82190039f6d45b Author: jiangxin <jiangxin.ji...@alibaba-inc.com> AuthorDate: Fri Oct 28 15:18:25 2022 +0800 [FLINK-29786] VarianceThresholdSelector should implement HasInputCol This closes #167. --- .github/workflows/python-checks.yml | 3 +- .../operators/feature/variancethresholdselector.md | 35 ++++++++++++--------- .../feature/VarianceThresholdSelectorExample.java | 9 +++--- .../flink/ml/common/param/HasFeaturesCol.java | 8 ++++- .../VarianceThresholdSelector.java | 4 +-- .../VarianceThresholdSelectorModel.java | 12 ++++---- .../VarianceThresholdSelectorModelParams.java | 5 ++- .../ml/feature/VarianceThresholdSelectorTest.java | 36 +++++++++------------- .../feature/variancethresholdselector_example.py | 8 ++--- .../tests/test_variancethresholdselector.py | 16 +++++----- .../ml/lib/feature/variancethresholdselector.py | 17 +++++++--- flink-ml-python/pyflink/ml/lib/param.py | 3 ++ 12 files changed, 87 insertions(+), 69 deletions(-) diff --git a/.github/workflows/python-checks.yml b/.github/workflows/python-checks.yml index b98644b..7f41ac8 100644 --- a/.github/workflows/python-checks.yml +++ b/.github/workflows/python-checks.yml @@ -56,7 +56,8 @@ jobs: - name: Test the source code working-directory: flink-ml-python run: | - pytest pyflink/ml + pytest pyflink/ml/lib/feature + pytest pyflink/ml --ignore=pyflink/ml/lib/feature/tests/ pytest pyflink/examples diff --git a/docs/content/docs/operators/feature/variancethresholdselector.md b/docs/content/docs/operators/feature/variancethresholdselector.md index bc839ab..e8dccd9 100644 --- a/docs/content/docs/operators/feature/variancethresholdselector.md +++ b/docs/content/docs/operators/feature/variancethresholdselector.md @@ -34,9 +34,9 @@ variance 0 (i.e. 
features that have the same value in all samples) will be remov ### Input Columns -| Param name | Type | Default | Description | -|:------------|:-------|:-------------|:----------------| -| featuresCol | Vector | `"features"` | Feature vector. | +| Param name | Type | Default | Description | +|:------------|:-------|:----------|:----------------| +| inputCol | Vector | `"input"` | Input features. | ### Output Columns @@ -46,12 +46,20 @@ variance 0 (i.e. features that have the same value in all samples) will be remov ### Parameters +Below are the parameters required by `VarianceThresholdSelectorModel`. + +| Key | Default | Type | Required | Description | +|------------|------------|--------|----------|-----------------------| +| inputCol | `"input"` | String | no | Input column name. | +| outputCol | `"output"` | String | no | Output column name. | + +`VarianceThresholdSelector` needs parameters above and also below. + | Key | Default | Type | Required | Description | |-------------------|--------------|--------|----------|---------------------------------------------------------------------------| -| featuresCol | `"features"` | String | no | Features column name. | -| outputCol | `"output"` | String | no | Output column name. | | varianceThreshold | `0.0` | Double | no | Features with a variance not greater than this threshold will be removed. 
| + ### Examples {{< tabs examples >}} @@ -89,14 +97,14 @@ public class VarianceThresholdSelectorExample { Row.of(4, Vectors.dense(1.0, 9.0, 8.0, 5.0, 7.0, 4.0)), Row.of(5, Vectors.dense(9.0, 8.0, 6.0, 5.0, 4.0, 4.0)), Row.of(6, Vectors.dense(6.0, 9.0, 7.0, 0.0, 2.0, 0.0))); - Table trainTable = tEnv.fromDataStream(trainStream).as("id", "features"); + Table trainTable = tEnv.fromDataStream(trainStream).as("id", "input"); // Create a VarianceThresholdSelector object and initialize its parameters double threshold = 8.0; VarianceThresholdSelector varianceThresholdSelector = new VarianceThresholdSelector() .setVarianceThreshold(threshold) - .setFeaturesCol("features"); + .setInputCol("input"); // Train the VarianceThresholdSelector model. VarianceThresholdSelectorModel model = varianceThresholdSelector.fit(trainTable); @@ -109,11 +117,10 @@ public class VarianceThresholdSelectorExample { for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) { Row row = it.next(); DenseVector inputValue = - (DenseVector) row.getField(varianceThresholdSelector.getFeaturesCol()); + (DenseVector) row.getField(varianceThresholdSelector.getInputCol()); DenseVector outputValue = (DenseVector) row.getField(varianceThresholdSelector.getOutputCol()); - System.out.printf( - "Original Features: %-15s\tSelected Features: %s\n", inputValue, outputValue); + System.out.printf("Input Values: %-15s\tOutput Values: %s\n", inputValue, outputValue); } } } @@ -151,14 +158,14 @@ train_data = t_env.from_data_stream( (6, Vectors.dense(6.0, 9.0, 7.0, 0.0, 2.0, 0.0),), ], type_info=Types.ROW_NAMED( - ['id', 'features'], + ['id', 'input'], [Types.INT(), DenseVectorTypeInfo()]) )) # create a VarianceThresholdSelector object and initialize its parameters threshold = 8.0 variance_thread_selector = VarianceThresholdSelector()\ - .set_features_col("features")\ + .set_input_col("input")\ .set_variance_threshold(threshold) # train the VarianceThresholdSelector model @@ -171,9 +178,9 @@ output 
= model.transform(train_data)[0] print("Variance Threshold: " + str(threshold)) field_names = output.get_schema().get_field_names() for result in t_env.to_data_stream(output).execute_and_collect(): - input_value = result[field_names.index(variance_thread_selector.get_features_col())] + input_value = result[field_names.index(variance_thread_selector.get_input_col())] output_value = result[field_names.index(variance_thread_selector.get_output_col())] - print('Original Features: ' + str(input_value) + ' \tSelected Features: ' + str(output_value)) + print('Input Values: ' + str(input_value) + ' \tOutput Values: ' + str(output_value)) ``` diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VarianceThresholdSelectorExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VarianceThresholdSelectorExample.java index 9783fd8..d441a3b 100644 --- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VarianceThresholdSelectorExample.java +++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VarianceThresholdSelectorExample.java @@ -48,14 +48,14 @@ public class VarianceThresholdSelectorExample { Row.of(4, Vectors.dense(1.0, 9.0, 8.0, 5.0, 7.0, 4.0)), Row.of(5, Vectors.dense(9.0, 8.0, 6.0, 5.0, 4.0, 4.0)), Row.of(6, Vectors.dense(6.0, 9.0, 7.0, 0.0, 2.0, 0.0))); - Table trainTable = tEnv.fromDataStream(trainStream).as("id", "features"); + Table trainTable = tEnv.fromDataStream(trainStream).as("id", "input"); // Create a VarianceThresholdSelector object and initialize its parameters double threshold = 8.0; VarianceThresholdSelector varianceThresholdSelector = new VarianceThresholdSelector() .setVarianceThreshold(threshold) - .setFeaturesCol("features"); + .setInputCol("input"); // Train the VarianceThresholdSelector model. 
VarianceThresholdSelectorModel model = varianceThresholdSelector.fit(trainTable); @@ -68,11 +68,10 @@ public class VarianceThresholdSelectorExample { for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) { Row row = it.next(); DenseVector inputValue = - (DenseVector) row.getField(varianceThresholdSelector.getFeaturesCol()); + (DenseVector) row.getField(varianceThresholdSelector.getInputCol()); DenseVector outputValue = (DenseVector) row.getField(varianceThresholdSelector.getOutputCol()); - System.out.printf( - "Original Features: %-15s\tSelected Features: %s\n", inputValue, outputValue); + System.out.printf("Input Values: %-15s\tOutput Values: %s\n", inputValue, outputValue); } } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFeaturesCol.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFeaturesCol.java index d6909ce..8096823 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFeaturesCol.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFeaturesCol.java @@ -18,12 +18,18 @@ package org.apache.flink.ml.common.param; +import org.apache.flink.ml.api.Stage; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.param.ParamValidators; import org.apache.flink.ml.param.StringParam; import org.apache.flink.ml.param.WithParams; -/** Interface for the shared featuresCol param. */ +/** + * Interface for the shared featuresCol param. + * + * <p>{@link HasFeaturesCol} is typically used for {@link Stage}s that implement {@link + * HasLabelCol}. It is preferred to use {@link HasInputCol} for other cases. 
+ */ public interface HasFeaturesCol<T> extends WithParams<T> { Param<String> FEATURES_COL = new StringParam( diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelector.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelector.java index 90db8fc..53944b6 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelector.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelector.java @@ -60,14 +60,14 @@ public class VarianceThresholdSelector @Override public VarianceThresholdSelectorModel fit(Table... inputs) { Preconditions.checkArgument(inputs.length == 1); - final String featuresCol = getFeaturesCol(); + final String inputCol = getInputCol(); StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); DataStream<DenseVector> inputData = tEnv.toDataStream(inputs[0]) .map( (MapFunction<Row, DenseVector>) - value -> ((Vector) value.getField(featuresCol)).toDense()); + value -> ((Vector) value.getField(inputCol)).toDense()); DataStream<VarianceThresholdSelectorModelData> modelData = DataStreamUtils.aggregate( diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelectorModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelectorModel.java index 62c4909..a13dd06 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelectorModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelectorModel.java @@ -119,7 +119,7 @@ public class VarianceThresholdSelectorModel inputList -> { DataStream input = inputList.get(0); return input.map( - new PredictOutputFunction(getFeaturesCol(), 
broadcastModelKey), + new PredictOutputFunction(getInputCol(), broadcastModelKey), outputTypeInfo); }); @@ -129,13 +129,13 @@ public class VarianceThresholdSelectorModel /** This operator loads model data and predicts result. */ private static class PredictOutputFunction extends RichMapFunction<Row, Row> { - private final String featureCol; + private final String inputCol; private final String broadcastKey; private int expectedNumOfFeatures; private int[] indices = null; - public PredictOutputFunction(String featureCol, String broadcastKey) { - this.featureCol = featureCol; + public PredictOutputFunction(String inputCol, String broadcastKey) { + this.inputCol = inputCol; this.broadcastKey = broadcastKey; } @@ -149,11 +149,11 @@ public class VarianceThresholdSelectorModel indices = varianceThresholdSelectorModelData.indices; } - DenseVector inputVec = ((Vector) row.getField(featureCol)).toDense(); + DenseVector inputVec = ((Vector) row.getField(inputCol)).toDense(); Preconditions.checkArgument( inputVec.size() == expectedNumOfFeatures, "%s has %s features, but VarianceThresholdSelector is expecting %s features as input.", - featureCol, + inputCol, inputVec.size(), expectedNumOfFeatures); if (indices.length == 0) { diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelectorModelParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelectorModelParams.java index 8892257..5b6363d 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelectorModelParams.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelectorModelParams.java @@ -18,7 +18,7 @@ package org.apache.flink.ml.feature.variancethresholdselector; -import org.apache.flink.ml.common.param.HasFeaturesCol; +import org.apache.flink.ml.common.param.HasInputCol; import 
org.apache.flink.ml.common.param.HasOutputCol; /** @@ -26,5 +26,4 @@ import org.apache.flink.ml.common.param.HasOutputCol; * * @param <T> The class type of this instance. */ -public interface VarianceThresholdSelectorModelParams<T> - extends HasFeaturesCol<T>, HasOutputCol<T> {} +public interface VarianceThresholdSelectorModelParams<T> extends HasInputCol<T>, HasOutputCol<T> {} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VarianceThresholdSelectorTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VarianceThresholdSelectorTest.java index e4e841c..1230780 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VarianceThresholdSelectorTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VarianceThresholdSelectorTest.java @@ -93,8 +93,8 @@ public class VarianceThresholdSelectorTest extends AbstractTestBase { env.setRestartStrategy(RestartStrategies.noRestart()); tEnv = StreamTableEnvironment.create(env); - trainDataTable = tEnv.fromDataStream(env.fromCollection(TRAIN_DATA)).as("id", "features"); - predictDataTable = tEnv.fromDataStream(env.fromCollection(PREDICT_DATA)).as("features"); + trainDataTable = tEnv.fromDataStream(env.fromCollection(TRAIN_DATA)).as("id", "input"); + predictDataTable = tEnv.fromDataStream(env.fromCollection(PREDICT_DATA)).as("input"); } private static void verifyPredictionResult( @@ -113,15 +113,15 @@ public class VarianceThresholdSelectorTest extends AbstractTestBase { @Test public void testParam() { VarianceThresholdSelector varianceThresholdSelector = new VarianceThresholdSelector(); - assertEquals("features", varianceThresholdSelector.getFeaturesCol()); + assertEquals("input", varianceThresholdSelector.getInputCol()); assertEquals("output", varianceThresholdSelector.getOutputCol()); assertEquals(0.0, varianceThresholdSelector.getVarianceThreshold(), EPS); varianceThresholdSelector - .setFeaturesCol("test_feature") + .setInputCol("test_input") .setOutputCol("test_output") 
.setVarianceThreshold(0.5); - assertEquals("test_feature", varianceThresholdSelector.getFeaturesCol()); + assertEquals("test_input", varianceThresholdSelector.getInputCol()); assertEquals(0.5, varianceThresholdSelector.getVarianceThreshold(), EPS); assertEquals("test_output", varianceThresholdSelector.getOutputCol()); } @@ -129,11 +129,11 @@ public class VarianceThresholdSelectorTest extends AbstractTestBase { @Test public void testOutputSchema() { VarianceThresholdSelector varianceThresholdSelector = - new VarianceThresholdSelector().setOutputCol("output").setVarianceThreshold(0.5); + new VarianceThresholdSelector().setVarianceThreshold(0.5); VarianceThresholdSelectorModel model = varianceThresholdSelector.fit(trainDataTable); Table output = model.transform(trainDataTable)[0]; assertEquals( - Arrays.asList("id", "features", "output"), + Arrays.asList("id", "input", "output"), output.getResolvedSchema().getColumnNames()); } @@ -163,7 +163,7 @@ public class VarianceThresholdSelectorTest extends AbstractTestBase { public void testInputTypeConversion() throws Exception { trainDataTable = TestUtils.convertDataTypesToSparseInt( - tEnv, trainDataTable.select(Expressions.$("features"))); + tEnv, trainDataTable.select(Expressions.$("input"))); predictDataTable = TestUtils.convertDataTypesToSparseInt(tEnv, predictDataTable); assertArrayEquals( new Class<?>[] {SparseVector.class}, TestUtils.getColumnDataTypes(trainDataTable)); @@ -172,9 +172,7 @@ public class VarianceThresholdSelectorTest extends AbstractTestBase { TestUtils.getColumnDataTypes(predictDataTable)); VarianceThresholdSelector varianceThresholdSelector = - new VarianceThresholdSelector() - .setFeaturesCol("features") - .setVarianceThreshold(8.0); + new VarianceThresholdSelector().setVarianceThreshold(8.0); VarianceThresholdSelectorModel model = varianceThresholdSelector.fit(trainDataTable); Table output = model.transform(predictDataTable)[0]; verifyPredictionResult(output, 
varianceThresholdSelector.getOutputCol(), EXPECTED_OUTPUT); @@ -183,9 +181,7 @@ public class VarianceThresholdSelectorTest extends AbstractTestBase { @Test public void testSaveLoadAndPredict() throws Exception { VarianceThresholdSelector varianceThresholdSelector = - new VarianceThresholdSelector() - .setFeaturesCol("features") - .setVarianceThreshold(8.0); + new VarianceThresholdSelector().setVarianceThreshold(8.0); VarianceThresholdSelector loadedVarianceThresholdSelector = TestUtils.saveAndReload( tEnv, varianceThresholdSelector, tempFolder.newFolder().getAbsolutePath()); @@ -203,7 +199,7 @@ public class VarianceThresholdSelectorTest extends AbstractTestBase { public void testFitOnEmptyData() { Table emptyTable = tEnv.fromDataStream(env.fromCollection(TRAIN_DATA).filter(x -> x.getArity() == 0)) - .as("id", "features"); + .as("id", "input"); VarianceThresholdSelector varianceThresholdSelector = new VarianceThresholdSelector(); VarianceThresholdSelectorModel model = varianceThresholdSelector.fit(emptyTable); Table modelDataTable = model.getModelData()[0]; @@ -227,7 +223,7 @@ public class VarianceThresholdSelectorTest extends AbstractTestBase { Arrays.asList( Row.of(Vectors.dense(1.0, 2.0, 3.0, 4.0)), Row.of(Vectors.dense(0.1, 0.2, 0.3, 0.4)))); - Table predictTable = tEnv.fromDataStream(env.fromCollection(predictData)).as("features"); + Table predictTable = tEnv.fromDataStream(env.fromCollection(predictData)).as("input"); Table output = model.transform(predictTable)[0]; try { output.execute().print(); @@ -243,9 +239,7 @@ public class VarianceThresholdSelectorTest extends AbstractTestBase { @Test public void testGetModelData() throws Exception { VarianceThresholdSelector varianceThresholdSelector = - new VarianceThresholdSelector() - .setFeaturesCol("features") - .setVarianceThreshold(8.0); + new VarianceThresholdSelector().setVarianceThreshold(8.0); VarianceThresholdSelectorModel model = varianceThresholdSelector.fit(trainDataTable); Table modelData = 
model.getModelData()[0]; assertEquals( @@ -265,9 +259,7 @@ public class VarianceThresholdSelectorTest extends AbstractTestBase { @Test public void testSetModelData() throws Exception { VarianceThresholdSelector varianceThresholdSelector = - new VarianceThresholdSelector() - .setFeaturesCol("features") - .setVarianceThreshold(8.0); + new VarianceThresholdSelector().setVarianceThreshold(8.0); VarianceThresholdSelectorModel modelA = varianceThresholdSelector.fit(trainDataTable); Table modelData = modelA.getModelData()[0]; diff --git a/flink-ml-python/pyflink/examples/ml/feature/variancethresholdselector_example.py b/flink-ml-python/pyflink/examples/ml/feature/variancethresholdselector_example.py index 3cf7ad6..e49ffba 100644 --- a/flink-ml-python/pyflink/examples/ml/feature/variancethresholdselector_example.py +++ b/flink-ml-python/pyflink/examples/ml/feature/variancethresholdselector_example.py @@ -42,14 +42,14 @@ train_data = t_env.from_data_stream( (6, Vectors.dense(6.0, 9.0, 7.0, 0.0, 2.0, 0.0),), ], type_info=Types.ROW_NAMED( - ['id', 'features'], + ['id', 'input'], [Types.INT(), DenseVectorTypeInfo()]) )) # create a VarianceThresholdSelector object and initialize its parameters threshold = 8.0 variance_thread_selector = VarianceThresholdSelector()\ - .set_features_col("features")\ + .set_input_col("input")\ .set_variance_threshold(threshold) # train the VarianceThresholdSelector model @@ -62,6 +62,6 @@ output = model.transform(train_data)[0] print("Variance Threshold: " + str(threshold)) field_names = output.get_schema().get_field_names() for result in t_env.to_data_stream(output).execute_and_collect(): - input_value = result[field_names.index(variance_thread_selector.get_features_col())] + input_value = result[field_names.index(variance_thread_selector.get_input_col())] output_value = result[field_names.index(variance_thread_selector.get_output_col())] - print('Original Features: ' + str(input_value) + ' \tSelected Features: ' + str(output_value)) + 
print('Input Values: ' + str(input_value) + ' \tOutput Values: ' + str(output_value)) diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_variancethresholdselector.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_variancethresholdselector.py index 854b9bd..b7a49a4 100644 --- a/flink-ml-python/pyflink/ml/lib/feature/tests/test_variancethresholdselector.py +++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_variancethresholdselector.py @@ -38,7 +38,7 @@ class VarianceThresholdSelectorTest(PyFlinkMLTestCase): (6, Vectors.dense(6.0, 9.0, 7.0, 0.0, 2.0, 0.0),), ], type_info=Types.ROW_NAMED( - ['id', 'features'], + ['id', 'input'], [Types.INT(), DenseVectorTypeInfo()]) )) @@ -48,7 +48,7 @@ class VarianceThresholdSelectorTest(PyFlinkMLTestCase): (Vectors.dense(0.1, 0.2, 0.3, 0.4, 0.5, 0.6),), ], type_info=Types.ROW_NAMED( - ['features'], + ['input'], [DenseVectorTypeInfo()]) )) self.expected_output = [ @@ -57,13 +57,15 @@ class VarianceThresholdSelectorTest(PyFlinkMLTestCase): def test_param(self): variance_threshold_selector = VarianceThresholdSelector() - self.assertEqual("features", variance_threshold_selector.features_col) + self.assertEqual("input", variance_threshold_selector.input_col) self.assertEqual("output", variance_threshold_selector.output_col) self.assertEqual(0.0, variance_threshold_selector.variance_threshold) - variance_threshold_selector.set_output_col("test_output")\ - .set_variance_threshold(8.0) - self.assertEqual("features", variance_threshold_selector.features_col) + variance_threshold_selector.\ + set_input_col("test_input").\ + set_output_col("test_output").\ + set_variance_threshold(8.0) + self.assertEqual("test_input", variance_threshold_selector.input_col) self.assertEqual("test_output", variance_threshold_selector.output_col) self.assertEqual(8.0, variance_threshold_selector.variance_threshold) @@ -99,7 +101,7 @@ class VarianceThresholdSelectorTest(PyFlinkMLTestCase): (Vectors.dense(0.1, 0.2, 0.3, 0.4, 0.5),), ], 
type_info=Types.ROW_NAMED( - ['features'], + ['input'], [DenseVectorTypeInfo()]) )) with self.assertRaisesRegex(Exception, 'but VarianceThresholdSelector is expecting'): diff --git a/flink-ml-python/pyflink/ml/lib/feature/variancethresholdselector.py b/flink-ml-python/pyflink/ml/lib/feature/variancethresholdselector.py index 23b2c6e..4c39e2d 100644 --- a/flink-ml-python/pyflink/ml/lib/feature/variancethresholdselector.py +++ b/flink-ml-python/pyflink/ml/lib/feature/variancethresholdselector.py @@ -20,14 +20,23 @@ import typing from pyflink.ml.core.param import Param, FloatParam, ParamValidators from pyflink.ml.core.wrapper import JavaWithParams from pyflink.ml.lib.feature.common import JavaFeatureModel, JavaFeatureEstimator -from pyflink.ml.lib.param import HasFeaturesCol, HasOutputCol +from pyflink.ml.lib.param import HasInputCol, HasOutputCol -class _VarianceThresholdSelectorParams( +class _VarianceThresholdSelectorModelParams( JavaWithParams, - HasFeaturesCol, + HasInputCol, HasOutputCol ): + """ + Params for :class:`VarianceThresholdSelectorModel`. + """ + + def __init__(self, java_params): + super(_VarianceThresholdSelectorModelParams, self).__init__(java_params) + + +class _VarianceThresholdSelectorParams(_VarianceThresholdSelectorModelParams): """ Params for :class:`VarianceThresholdSelector`. """ @@ -53,7 +62,7 @@ class _VarianceThresholdSelectorParams( return self.get_variance_threshold() -class VarianceThresholdSelectorModel(JavaFeatureModel, _VarianceThresholdSelectorParams): +class VarianceThresholdSelectorModel(JavaFeatureModel, _VarianceThresholdSelectorModelParams): """ A Model which transforms data using the model data computed by :class:`VarianceThresholdSelector`. 
diff --git a/flink-ml-python/pyflink/ml/lib/param.py b/flink-ml-python/pyflink/ml/lib/param.py
index 4cf8e54..7bfbe2e 100644
--- a/flink-ml-python/pyflink/ml/lib/param.py
+++ b/flink-ml-python/pyflink/ml/lib/param.py
@@ -47,6 +47,9 @@ class HasDistanceMeasure(WithParams, ABC):
 class HasFeaturesCol(WithParams, ABC):
     """
     Base class for the shared feature_col param.
+
+    `HasFeaturesCol` is typically used for `Stage`s that implement `HasLabelCol`. It is preferred
+    to use `HasInputCol` for other cases.
     """
     FEATURES_COL: Param[str] = StringParam(
         "features_col",