This is an automated email from the ASF dual-hosted git repository. zhangzp pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
     new 3ab3273  [FLINK-25552] Add Estimator and Transformer for MinMaxScaler
3ab3273 is described below

commit 3ab327394769d5bd4f88be04428d33320b0928a3
Author: weibo <weibo.z...@alibaba-inc.com>
AuthorDate: Mon Mar 21 19:33:03 2022 +0800

    [FLINK-25552] Add Estimator and Transformer for MinMaxScaler

    This closes #54.
---
 .../flink/ml/classification/knn/KnnModelData.java  |  10 +-
 .../ml/feature/minmaxscaler/MinMaxScaler.java      | 203 +++++++++++++++++++
 .../ml/feature/minmaxscaler/MinMaxScalerModel.java | 183 +++++++++++++++++
 .../minmaxscaler/MinMaxScalerModelData.java}       |  69 +++----
 .../feature/minmaxscaler/MinMaxScalerParams.java   |  62 ++++++
 .../apache/flink/ml/feature/MinMaxScalerTest.java  | 218 +++++++++++++++++++++
 6 files changed, 699 insertions(+), 46 deletions(-)

diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java
index 89051e6..4bf0adb 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java
@@ -82,13 +82,11 @@ public class KnnModelData {
     /** Encoder for {@link KnnModelData}. */
     public static class ModelDataEncoder implements Encoder<KnnModelData> {
         @Override
-        public void encode(KnnModelData knnModelData, OutputStream outputStream)
-                throws IOException {
+        public void encode(KnnModelData modelData, OutputStream outputStream) throws IOException {
             DataOutputView dataOutputView = new DataOutputViewStreamWrapper(outputStream);
-            DenseMatrixSerializer.INSTANCE.serialize(knnModelData.packedFeatures, dataOutputView);
-            DenseVectorSerializer.INSTANCE.serialize(
-                    knnModelData.featureNormSquares, dataOutputView);
-            DenseVectorSerializer.INSTANCE.serialize(knnModelData.labels, dataOutputView);
+            // Field order must match KnnModelData.ModelDataDecoder, which reads packedFeatures,
+            // then featureNormSquares, then labels.
+            DenseMatrixSerializer.INSTANCE.serialize(modelData.packedFeatures, dataOutputView);
+            DenseVectorSerializer.INSTANCE.serialize(modelData.featureNormSquares, dataOutputView);
+            DenseVectorSerializer.INSTANCE.serialize(modelData.labels, dataOutputView);
         }
     }
 
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java
new file mode 100644
index 0000000..19a9f6f
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.minmaxscaler;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the MinMaxScaler algorithm. This algorithm rescales feature values
+ * to a common range [min, max] defined by the user.
+ *
+ * <blockquote>
+ *
+ * $$ Rescaled(value) = \frac{value - E_{min}}{E_{max} - E_{min}} * (max - min) + min $$
+ *
+ * </blockquote>
+ *
+ * <p>For the case \(E_{max} == E_{min}\), \(Rescaled(value) = 0.5 * (max + min)\).
+ *
+ * <p>See https://en.wikipedia.org/wiki/Feature_scaling#Rescaling_(min-max_normalization).
+ */
+public class MinMaxScaler
+        implements Estimator<MinMaxScaler, MinMaxScalerModel>, MinMaxScalerParams<MinMaxScaler> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public MinMaxScaler() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    /**
+     * Computes the per-dimension minimum and maximum of the features column of the single input
+     * table and returns a {@link MinMaxScalerModel} carrying them as model data.
+     *
+     * @throws IllegalArgumentException if not exactly one input table is given.
+     */
+    @Override
+    public MinMaxScalerModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        final String featureCol = getFeaturesCol();
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<DenseVector> features =
+                tEnv.toDataStream(inputs[0])
+                        .map(
+                                (MapFunction<Row, DenseVector>)
+                                        value -> (DenseVector) value.getField(featureCol));
+        // Two-stage reduce: each partition emits its (min, max) pair, then a single-parallelism
+        // operator reduces all of those vectors again. Every min vector is <= its max vector, so
+        // the second pass still yields the global min and the global max.
+        DataStream<DenseVector> minMaxValues =
+                features.transform(
+                                "reduceInEachPartition",
+                                features.getType(),
+                                new MinMaxReduceFunctionOperator())
+                        .transform(
+                                "reduceInFinalPartition",
+                                features.getType(),
+                                new MinMaxReduceFunctionOperator())
+                        .setParallelism(1);
+        DataStream<MinMaxScalerModelData> modelData =
+                DataStreamUtils.mapPartition(
+                        minMaxValues,
+                        new RichMapPartitionFunction<DenseVector, MinMaxScalerModelData>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<DenseVector> values,
+                                    Collector<MinMaxScalerModelData> out) {
+                                // The final reduce emits the min vector followed by the max
+                                // vector, or nothing at all when it received no input.
+                                Iterator<DenseVector> iter = values.iterator();
+                                Preconditions.checkState(
+                                        iter.hasNext(), "The training data must not be empty.");
+                                DenseVector minVector = iter.next();
+                                DenseVector maxVector = iter.next();
+                                out.collect(new MinMaxScalerModelData(minVector, maxVector));
+                            }
+                        });
+
+        MinMaxScalerModel model =
+                new MinMaxScalerModel().setModelData(tEnv.fromDataStream(modelData));
+        ReadWriteUtils.updateExistingParams(model, getParamMap());
+        return model;
+    }
+
+    /**
+     * A stream operator to compute the min and max values in each partition of the input bounded
+     * data stream.
+     */
+    private static class MinMaxReduceFunctionOperator extends AbstractStreamOperator<DenseVector>
+            implements OneInputStreamOperator<DenseVector, DenseVector>, BoundedOneInput {
+        private ListState<DenseVector> minState;
+        private ListState<DenseVector> maxState;
+
+        private DenseVector minVector;
+        private DenseVector maxVector;
+
+        /** Emits the min vector and then the max vector, unless no element was seen. */
+        @Override
+        public void endInput() {
+            if (minVector != null) {
+                output.collect(new StreamRecord<>(minVector));
+                output.collect(new StreamRecord<>(maxVector));
+            }
+        }
+
+        @Override
+        public void processElement(StreamRecord<DenseVector> streamRecord) {
+            DenseVector currentValue = streamRecord.getValue();
+            if (minVector == null) {
+                // Copy rather than alias the first element so later updates never write into
+                // the input record's array.
+                int vecSize = currentValue.size();
+                minVector = new DenseVector(vecSize);
+                maxVector = new DenseVector(vecSize);
+                System.arraycopy(currentValue.values, 0, minVector.values, 0, vecSize);
+                System.arraycopy(currentValue.values, 0, maxVector.values, 0, vecSize);
+            } else {
+                Preconditions.checkArgument(
+                        currentValue.size() == maxVector.size(),
+                        "CurrentValue should has same size with maxVector.");
+                for (int i = 0; i < currentValue.size(); ++i) {
+                    minVector.values[i] = Math.min(minVector.values[i], currentValue.values[i]);
+                    maxVector.values[i] = Math.max(maxVector.values[i], currentValue.values[i]);
+                }
+            }
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            minState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "minState", TypeInformation.of(DenseVector.class)));
+            maxState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "maxState", TypeInformation.of(DenseVector.class)));
+
+            OperatorStateUtils.getUniqueElement(minState, "minState").ifPresent(x -> minVector = x);
+            OperatorStateUtils.getUniqueElement(maxState, "maxState").ifPresent(x -> maxVector = x);
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            minState.clear();
+            maxState.clear();
+            if (minVector != null) {
+                minState.add(minVector);
+                maxState.add(maxVector);
+            }
+        }
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static MinMaxScaler load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModel.java
new file mode 100644
index 0000000..762d74a
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModel.java
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.minmaxscaler;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A Model which performs min-max scaling using the model data computed by {@link MinMaxScaler}.
+ */
+public class MinMaxScalerModel
+        implements Model<MinMaxScalerModel>, MinMaxScalerParams<MinMaxScalerModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public MinMaxScalerModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public MinMaxScalerModel setModelData(Table... inputs) {
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    /**
+     * Appends a column, named by the prediction column param, that holds each input feature vector
+     * rescaled into [min, max].
+     *
+     * @throws IllegalArgumentException if not exactly one input table is given, or if min > max.
+     */
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        // An inverted range would silently map features onto [max, min] in reverse order.
+        Preconditions.checkArgument(
+                getMin() <= getMax(),
+                "The lower bound min (%s) must not be greater than the upper bound max (%s).",
+                getMin(),
+                getMax());
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> data = tEnv.toDataStream(inputs[0]);
+        DataStream<MinMaxScalerModelData> minMaxScalerModel =
+                MinMaxScalerModelData.getModelDataStream(modelDataTable);
+        final String broadcastModelKey = "broadcastModelKey";
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(),
+                                TypeInformation.of(DenseVector.class)),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(data),
+                        Collections.singletonMap(broadcastModelKey, minMaxScalerModel),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(
+                                    new PredictOutputFunction(
+                                            broadcastModelKey,
+                                            getMax(),
+                                            getMin(),
+                                            getFeaturesCol()),
+                                    outputTypeInfo);
+                        });
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                MinMaxScalerModelData.getModelDataStream(modelDataTable),
+                path,
+                new MinMaxScalerModelData.ModelDataEncoder());
+    }
+
+    /**
+     * Loads model data from path.
+     *
+     * @param env Stream execution environment.
+     * @param path Model path.
+     * @return MinMaxScalerModel model.
+     */
+    public static MinMaxScalerModel load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+        MinMaxScalerModel model = ReadWriteUtils.loadStageParam(path);
+        DataStream<MinMaxScalerModelData> modelData =
+                ReadWriteUtils.loadModelData(
+                        env, path, new MinMaxScalerModelData.ModelDataDecoder());
+        return model.setModelData(tEnv.fromDataStream(modelData));
+    }
+
+    /** A map function that rescales the feature column of each row using the broadcast model. */
+    private static class PredictOutputFunction extends RichMapFunction<Row, Row> {
+        private final String featureCol;
+        private final String broadcastKey;
+        private final double upperBound;
+        private final double lowerBound;
+        private DenseVector scaleVector;
+        private DenseVector offsetVector;
+
+        public PredictOutputFunction(
+                String broadcastKey, double upperBound, double lowerBound, String featureCol) {
+            this.upperBound = upperBound;
+            this.lowerBound = lowerBound;
+            this.broadcastKey = broadcastKey;
+            this.featureCol = featureCol;
+        }
+
+        @Override
+        public Row map(Row row) {
+            if (scaleVector == null) {
+                // Computed once from the broadcast model data, so that each record only needs
+                // output[i] = feature[i] * scale[i] + offset[i].
+                MinMaxScalerModelData minMaxScalerModelData =
+                        (MinMaxScalerModelData)
+                                getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+                DenseVector minVector = minMaxScalerModelData.minVector;
+                DenseVector maxVector = minMaxScalerModelData.maxVector;
+                scaleVector = new DenseVector(minVector.size());
+                offsetVector = new DenseVector(minVector.size());
+                for (int i = 0; i < maxVector.size(); ++i) {
+                    // NOTE(review): the absolute tolerance of 1e-5 treats a dimension as constant
+                    // (output = midpoint of [min, max]) even when its values legitimately vary on
+                    // a smaller scale.
+                    if (Math.abs(minVector.values[i] - maxVector.values[i]) < 1.0e-5) {
+                        scaleVector.values[i] = 0.0;
+                        offsetVector.values[i] = (upperBound + lowerBound) / 2;
+                    } else {
+                        scaleVector.values[i] =
+                                (upperBound - lowerBound)
+                                        / (maxVector.values[i] - minVector.values[i]);
+                        offsetVector.values[i] =
+                                lowerBound - minVector.values[i] * scaleVector.values[i];
+                    }
+                }
+            }
+            DenseVector feature = (DenseVector) row.getField(featureCol);
+            // Fail with a clear message instead of an ArrayIndexOutOfBoundsException (shorter
+            // input) or a silently truncated output vector (longer input).
+            Preconditions.checkArgument(
+                    feature.size() == scaleVector.size(),
+                    "The input vector size (%s) does not match the model vector size (%s).",
+                    feature.size(),
+                    scaleVector.size());
+            DenseVector outputVector = new DenseVector(scaleVector.size());
+            for (int i = 0; i < scaleVector.size(); ++i) {
+                outputVector.values[i] =
+                        feature.values[i] * scaleVector.values[i] + offsetVector.values[i];
+            }
+            return Row.join(row, Row.of(outputVector));
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModelData.java
similarity index 56%
copy from flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java
copy to flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModelData.java
index 89051e6..301eadd 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModelData.java
@@ -16,7 +16,7 @@
  * limitations under the License.
  */
 
-package org.apache.flink.ml.classification.knn;
+package org.apache.flink.ml.feature.minmaxscaler;
 
 import org.apache.flink.api.common.serialization.Encoder;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
@@ -27,9 +27,7 @@ import org.apache.flink.core.memory.DataInputView;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
-import org.apache.flink.ml.linalg.DenseMatrix;
 import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.typeinfo.DenseMatrixSerializer;
 import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.table.api.Table;
@@ -41,24 +39,21 @@ import java.io.IOException;
 import java.io.OutputStream;
 
 /**
- * Model data of {@link KnnModel}.
+ * Model data of {@link MinMaxScalerModel}.
  *
  * <p>This class also provides methods to convert model data from Table to a data stream, and
  * classes to save/load model data.
  */
-public class KnnModelData {
+public class MinMaxScalerModelData {
+    /** Per-dimension minimum of the training feature vectors. */
+    public DenseVector minVector;
 
-    public DenseMatrix packedFeatures;
-    public DenseVector featureNormSquares;
-    public DenseVector labels;
+    /** Per-dimension maximum of the training feature vectors. */
+    public DenseVector maxVector;
 
-    public KnnModelData() {}
+    public MinMaxScalerModelData() {}
 
-    public KnnModelData(
-            DenseMatrix packedFeatures, DenseVector featureNormSquares, DenseVector labels) {
-        this.packedFeatures = packedFeatures;
-        this.featureNormSquares = featureNormSquares;
-        this.labels = labels;
+    public MinMaxScalerModelData(DenseVector minVector, DenseVector maxVector) {
+        this.minVector = minVector;
+        this.maxVector = maxVector;
     }
 
     /**
@@ -67,47 +62,41 @@ public class KnnModelData {
      * @param modelDataTable The table model data.
      * @return The data stream model data.
      */
-    public static DataStream<KnnModelData> getModelDataStream(Table modelDataTable) {
+    public static DataStream<MinMaxScalerModelData> getModelDataStream(Table modelDataTable) {
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) modelDataTable).getTableEnvironment();
         return tEnv.toDataStream(modelDataTable)
                 .map(
                         x ->
-                                new KnnModelData(
-                                        (DenseMatrix) x.getField(0),
-                                        (DenseVector) x.getField(1),
-                                        (DenseVector) x.getField(2)));
+                                // Column 0 is read as minVector and column 1 as maxVector.
+                                new MinMaxScalerModelData(
+                                        (DenseVector) x.getField(0), (DenseVector) x.getField(1)));
     }
 
-    /** Encoder for {@link KnnModelData}. */
-    public static class ModelDataEncoder implements Encoder<KnnModelData> {
+    /** Encoder for {@link MinMaxScalerModelData}. */
+    public static class ModelDataEncoder implements Encoder<MinMaxScalerModelData> {
         @Override
-        public void encode(KnnModelData knnModelData, OutputStream outputStream)
+        public void encode(MinMaxScalerModelData modelData, OutputStream outputStream)
                 throws IOException {
             DataOutputView dataOutputView = new DataOutputViewStreamWrapper(outputStream);
-            DenseMatrixSerializer.INSTANCE.serialize(knnModelData.packedFeatures, dataOutputView);
-            DenseVectorSerializer.INSTANCE.serialize(
-                    knnModelData.featureNormSquares, dataOutputView);
-            DenseVectorSerializer.INSTANCE.serialize(knnModelData.labels, dataOutputView);
+            // Field order must match ModelDataDecoder, which reads minVector and then maxVector.
+            DenseVectorSerializer.INSTANCE.serialize(modelData.minVector, dataOutputView);
+            DenseVectorSerializer.INSTANCE.serialize(modelData.maxVector, dataOutputView);
         }
     }
 
-    /** Decoder for {@link KnnModelData}. */
-    public static class ModelDataDecoder extends SimpleStreamFormat<KnnModelData> {
+    /** Decoder for {@link MinMaxScalerModelData}. */
+    public static class ModelDataDecoder extends SimpleStreamFormat<MinMaxScalerModelData> {
         @Override
-        public Reader<KnnModelData> createReader(Configuration config, FSDataInputStream stream) {
-            return new Reader<KnnModelData>() {
-
-                private final DataInputView source = new DataInputViewStreamWrapper(stream);
+        public Reader<MinMaxScalerModelData> createReader(
+                Configuration config, FSDataInputStream stream) {
+            return new Reader<MinMaxScalerModelData>() {
 
                 @Override
-                public KnnModelData read() throws IOException {
+                public MinMaxScalerModelData read() throws IOException {
+                    DataInputView source = new DataInputViewStreamWrapper(stream);
                     try {
-                        DenseMatrix matrix = DenseMatrixSerializer.INSTANCE.deserialize(source);
-                        DenseVector normSquares =
-                                DenseVectorSerializer.INSTANCE.deserialize(source);
-                        DenseVector labels = DenseVectorSerializer.INSTANCE.deserialize(source);
-                        return new KnnModelData(matrix, normSquares, labels);
+                        DenseVector minVector = DenseVectorSerializer.INSTANCE.deserialize(source);
+                        DenseVector maxVector = DenseVectorSerializer.INSTANCE.deserialize(source);
+                        return new MinMaxScalerModelData(minVector, maxVector);
                     } catch (EOFException e) {
+                        // End of the model file: per the StreamFormat.Reader contract, null means
+                        // there are no more records.
                         return null;
                     }
@@ -121,8 +110,8 @@ public class KnnModelData {
         }
 
         @Override
-        public TypeInformation<KnnModelData> getProducedType() {
-            return TypeInformation.of(KnnModelData.class);
+        public TypeInformation<MinMaxScalerModelData> getProducedType() {
+            return TypeInformation.of(MinMaxScalerModelData.class);
         }
     }
 }
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerParams.java
new file mode 100644
index 0000000..aade500
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerParams.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.minmaxscaler;
+
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params for {@link MinMaxScaler} and {@link MinMaxScalerModel}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface MinMaxScalerParams<T> extends HasFeaturesCol<T>, HasPredictionCol<T> {
+    // NOTE(review): MIN and MAX are only validated as non-null here; nothing in this interface
+    // enforces min <= max.
+    /** Lower bound of the output feature range; defaults to 0.0. */
+    Param<Double> MIN =
+            new DoubleParam(
+                    "min",
+                    "Lower bound of the output feature range.",
+                    0.0,
+                    ParamValidators.notNull());
+
+    default Double getMin() {
+        return get(MIN);
+    }
+
+    default T setMin(Double value) {
+        return set(MIN, value);
+    }
+
+    /** Upper bound of the output feature range; defaults to 1.0. */
+    Param<Double> MAX =
+            new DoubleParam(
+                    "max",
+                    "Upper bound of the output feature range.",
+                    1.0,
+                    ParamValidators.notNull());
+
+    default Double getMax() {
+        return get(MAX);
+    }
+
+    default T setMax(Double value) {
+        return set(MAX, value);
+    }
+}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java
new file mode 100644
index 0000000..24ec885
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.minmaxscaler.MinMaxScaler;
+import org.apache.flink.ml.feature.minmaxscaler.MinMaxScalerModel;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+/** Tests {@link MinMaxScaler} and {@link MinMaxScalerModel}. */
+public class MinMaxScalerTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainDataTable;
+    private Table predictDataTable;
+    private static final List<Row> TRAIN_DATA =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(Vectors.dense(0.0, 3.0)),
+                            Row.of(Vectors.dense(2.1, 0.0)),
+                            Row.of(Vectors.dense(4.1, 5.1)),
+                            Row.of(Vectors.dense(6.1, 8.1)),
+                            Row.of(Vectors.dense(200, 400))));
+    private static final List<Row> PREDICT_DATA =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(Vectors.dense(150.0, 90.0)),
+                            Row.of(Vectors.dense(50.0, 40.0)),
+                            Row.of(Vectors.dense(100.0, 50.0))));
+    private static final double EPS = 1.0e-5;
+    // TRAIN_DATA has min (0, 0) and max (200, 400), so with the default range [0, 1] each
+    // expected vector is a PREDICT_DATA row divided element-wise by (200, 400), in sorted order.
+    private static final List<DenseVector> EXPECTED_DATA =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Vectors.dense(0.25, 0.1),
+                            Vectors.dense(0.5, 0.125),
+                            Vectors.dense(0.75, 0.225)));
+
+    /**
+     * Compares two vectors element by element and returns the first non-zero {@link
+     * Double#compare} result. The loop is bounded by {@code first.size()}.
+     *
+     * <p>Note: this comparator imposes orderings that are inconsistent with equals.
+     */
+    private static int compare(DenseVector first, DenseVector second) {
+        for (int i = 0; i < first.size(); i++) {
+            int cmp = Double.compare(first.get(i), second.get(i));
+            if (cmp != 0) {
+                return cmp;
+            }
+        }
+        return 0;
+    }
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        trainDataTable = tEnv.fromDataStream(env.fromCollection(TRAIN_DATA)).as("features");
+        predictDataTable = tEnv.fromDataStream(env.fromCollection(PREDICT_DATA)).as("features");
+    }
+
+    /** Collects {@code outputCol} from {@code output}, sorts it and compares it to expected. */
+    private static void verifyPredictionResult(
+            Table output, String outputCol, List<DenseVector> expected) throws Exception {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
+        DataStream<DenseVector> stream =
+                tEnv.toDataStream(output)
+                        .map(
+                                (MapFunction<Row, DenseVector>)
+                                        row -> (DenseVector) row.getField(outputCol));
+        List<DenseVector> result = IteratorUtils.toList(stream.executeAndCollect());
+        result.sort(MinMaxScalerTest::compare);
+        assertEquals(expected, result);
+    }
+
+    @Test
+    public void testParam() {
+        MinMaxScaler minMaxScaler = new MinMaxScaler();
+        assertEquals("features", minMaxScaler.getFeaturesCol());
+        assertEquals("prediction", minMaxScaler.getPredictionCol());
+        assertEquals(0.0, minMaxScaler.getMin(), EPS);
+        assertEquals(1.0, minMaxScaler.getMax(), EPS);
+        minMaxScaler
+                .setFeaturesCol("test_features")
+                .setPredictionCol("test_output")
+                .setMin(1.0)
+                .setMax(4.0);
+        assertEquals("test_features", minMaxScaler.getFeaturesCol());
+        assertEquals(1.0, minMaxScaler.getMin(), EPS);
+        assertEquals(4.0, minMaxScaler.getMax(), EPS);
+        assertEquals("test_output", minMaxScaler.getPredictionCol());
+    }
+
+    @Test
+    public void testFeaturePredictionParam() {
+        MinMaxScaler minMaxScaler =
+                new MinMaxScaler()
+                        .setFeaturesCol("test_features")
+                        .setPredictionCol("test_output")
+                        .setMin(1.0)
+                        .setMax(4.0);
+
+        MinMaxScalerModel model = minMaxScaler.fit(trainDataTable.as("test_features"));
+        Table output = model.transform(predictDataTable.as("test_features"))[0];
+        assertEquals(
+                Arrays.asList("test_features", "test_output"),
+                output.getResolvedSchema().getColumnNames());
+    }
+
+    @Test
+    public void testMaxValueEqualsMinValueButPredictValueNotEquals() throws Exception {
+        // A single training row makes min == max in every dimension, so every output element is
+        // the midpoint (max + min) / 2 = 5.0 regardless of the predicted input.
+        List<Row> trainData =
+                new ArrayList<>(Collections.singletonList(Row.of(Vectors.dense(40.0, 80.0))));
+        Table trainTable = tEnv.fromDataStream(env.fromCollection(trainData)).as("features");
+        List<Row> predictData =
+                new ArrayList<>(Collections.singletonList(Row.of(Vectors.dense(30.0, 50.0))));
+        Table predictDataTable =
+                tEnv.fromDataStream(env.fromCollection(predictData)).as("features");
+        MinMaxScaler minMaxScaler = new MinMaxScaler().setMax(10.0).setMin(0.0);
+        MinMaxScalerModel model = minMaxScaler.fit(trainTable);
+        Table result = model.transform(predictDataTable)[0];
+        verifyPredictionResult(
+                result,
+                minMaxScaler.getPredictionCol(),
+                Collections.singletonList(Vectors.dense(5.0, 5.0)));
+    }
+
+    @Test
+    public void testFitAndPredict() throws Exception {
+        MinMaxScaler minMaxScaler = new MinMaxScaler();
+        MinMaxScalerModel minMaxScalerModel = minMaxScaler.fit(trainDataTable);
+        Table output = minMaxScalerModel.transform(predictDataTable)[0];
+        verifyPredictionResult(output, minMaxScaler.getPredictionCol(), EXPECTED_DATA);
+    }
+
+    @Test
+    public void testSaveLoadAndPredict() throws Exception {
+        MinMaxScaler minMaxScaler = new MinMaxScaler();
+        MinMaxScaler loadedMinMaxScaler =
+                StageTestUtils.saveAndReload(
+                        env, minMaxScaler, tempFolder.newFolder().getAbsolutePath());
+        MinMaxScalerModel model = loadedMinMaxScaler.fit(trainDataTable);
+        MinMaxScalerModel loadedModel =
+                StageTestUtils.saveAndReload(env, model, tempFolder.newFolder().getAbsolutePath());
+        assertEquals(
+                Arrays.asList("minVector", "maxVector"),
+                model.getModelData()[0].getResolvedSchema().getColumnNames());
+        Table output = loadedModel.transform(predictDataTable)[0];
+        verifyPredictionResult(output, minMaxScaler.getPredictionCol(), EXPECTED_DATA);
+    }
+
+    @Test
+    public void testGetModelData() throws Exception {
+        MinMaxScaler minMaxScaler = new MinMaxScaler();
+        MinMaxScalerModel minMaxScalerModel = minMaxScaler.fit(trainDataTable);
+        Table modelData = minMaxScalerModel.getModelData()[0];
+        assertEquals(
+                Arrays.asList("minVector", "maxVector"),
+                modelData.getResolvedSchema().getColumnNames());
+        DataStream<Row> output = tEnv.toDataStream(modelData);
+        List<Row> modelRows = IteratorUtils.toList(output.executeAndCollect());
+        // The per-dimension min and max of TRAIN_DATA.
+        assertEquals(new DenseVector(new double[] {0.0, 0.0}), modelRows.get(0).getField(0));
+        assertEquals(new DenseVector(new double[] {200.0, 400.0}), modelRows.get(0).getField(1));
+    }
+
+    @Test
+    public void testSetModelData() throws Exception {
+        MinMaxScaler minMaxScaler = new MinMaxScaler();
+        MinMaxScalerModel modelA = minMaxScaler.fit(trainDataTable);
+        Table modelData = modelA.getModelData()[0];
+        MinMaxScalerModel modelB = new MinMaxScalerModel().setModelData(modelData);
+        ReadWriteUtils.updateExistingParams(modelB, modelA.getParamMap());
+        Table output = modelB.transform(predictDataTable)[0];
+        verifyPredictionResult(output, minMaxScaler.getPredictionCol(), EXPECTED_DATA);
+    }
+}