This is an automated email from the ASF dual-hosted git repository. lindong pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push: new 4bd9083 [FLINK-27091] Add Transformer and Estimator of LinearSVC 4bd9083 is described below commit 4bd908312825e3fe4fd7566b15626ad766ebabc5 Author: Zhipeng Zhang <zhangzhipe...@gmail.com> AuthorDate: Sat May 7 15:53:35 2022 +0800 [FLINK-27091] Add Transformer and Estimator of LinearSVC This closes #93. --- .../ml/classification/linearsvc/LinearSVC.java | 123 ++++++++++++ .../classification/linearsvc/LinearSVCModel.java | 174 ++++++++++++++++ .../linearsvc/LinearSVCModelData.java | 111 +++++++++++ .../linearsvc/LinearSVCModelParams.java | 55 +++++ .../classification/linearsvc/LinearSVCParams.java | 44 ++++ .../apache/flink/ml/common/lossfunc/HingeLoss.java | 58 ++++++ ...isticRegressionTest.java => LinearSVCTest.java} | 222 ++++++++++----------- .../ml/classification/LogisticRegressionTest.java | 8 +- .../flink/ml/common/lossfunc/HingeLossTest.java | 57 ++++++ .../flink/ml/regression/LinearRegressionTest.java | 8 +- 10 files changed, 731 insertions(+), 129 deletions(-) diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVC.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVC.java new file mode 100644 index 0000000..4169d48 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVC.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.linearsvc; + +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.common.feature.LabeledPointWithWeight; +import org.apache.flink.ml.common.lossfunc.HingeLoss; +import org.apache.flink.ml.common.optimizer.Optimizer; +import org.apache.flink.ml.common.optimizer.SGD; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +/** + * An Estimator which implements the linear support vector classification. + * + * <p>See https://en.wikipedia.org/wiki/Support-vector_machine#Linear_SVM. + */ +public class LinearSVC implements Estimator<LinearSVC, LinearSVCModel>, LinearSVCParams<LinearSVC> { + + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public LinearSVC() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + @SuppressWarnings({"rawTypes", "ConstantConditions"}) + public LinearSVCModel fit(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + + DataStream<LabeledPointWithWeight> trainData = + tEnv.toDataStream(inputs[0]) + .map( + dataPoint -> { + double weight = + getWeightCol() == null + ? 1.0 + : (Double) dataPoint.getField(getWeightCol()); + double label = (Double) dataPoint.getField(getLabelCol()); + Preconditions.checkState( + Double.compare(0.0, label) == 0 + || Double.compare(1.0, label) == 0, + "LinearSVC only supports binary classification. But detected label: %s.", + label); + DenseVector features = + (DenseVector) dataPoint.getField(getFeaturesCol()); + return new LabeledPointWithWeight(features, label, weight); + }); + + DataStream<DenseVector> initModelData = + DataStreamUtils.reduce( + trainData.map(x -> x.getFeatures().size()), + (ReduceFunction<Integer>) + (t0, t1) -> { + Preconditions.checkState( + t0.equals(t1), + "The training data should all have same dimensions."); + return t0; + }) + .map(DenseVector::new); + + Optimizer optimizer = + new SGD( + getMaxIter(), + getLearningRate(), + getGlobalBatchSize(), + getTol(), + getReg(), + getElasticNet()); + DataStream<DenseVector> rawModelData = + optimizer.optimize(initModelData, trainData, HingeLoss.INSTANCE); + + DataStream<LinearSVCModelData> modelData = rawModelData.map(LinearSVCModelData::new); + LinearSVCModel model = new LinearSVCModel().setModelData(tEnv.fromDataStream(modelData)); + ReadWriteUtils.updateExistingParams(model, paramMap); + return model; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static LinearSVC load(StreamTableEnvironment tEnv, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModel.java new file mode 100644 index 0000000..253bbcc --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModel.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.linearsvc; + +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** A Model which classifies data using the model data computed by {@link LinearSVC}. */ +public class LinearSVCModel implements Model<LinearSVCModel>, LinearSVCModelParams<LinearSVCModel> { + + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + private Table modelDataTable; + + public LinearSVCModel() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + @SuppressWarnings("unchecked") + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Row> inputStream = tEnv.toDataStream(inputs[0]); + + final String broadcastModelKey = "broadcastModelKey"; + DataStream<LinearSVCModelData> modelDataStream = + LinearSVCModelData.getModelDataStream(modelDataTable); + + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), + BasicTypeInfo.DOUBLE_TYPE_INFO, + DenseVectorTypeInfo.INSTANCE), + ArrayUtils.addAll( + inputTypeInfo.getFieldNames(), + getPredictionCol(), + getRawPredictionCol())); + + DataStream<Row> predictionResult = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(inputStream), + Collections.singletonMap(broadcastModelKey, modelDataStream), + inputList -> { + DataStream inputData = inputList.get(0); + return inputData.map( + new PredictLabelFunction( + broadcastModelKey, getFeaturesCol(), getThreshold()), + outputTypeInfo); + }); + return new Table[] {tEnv.fromDataStream(predictionResult)}; + } + + @Override + public LinearSVCModel setModelData(Table... inputs) { + modelDataTable = inputs[0]; + return this; + } + + @Override + public Table[] getModelData() { + return new Table[] {modelDataTable}; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + ReadWriteUtils.saveModelData( + LinearSVCModelData.getModelDataStream(modelDataTable), + path, + new LinearSVCModelData.ModelDataEncoder()); + } + + public static LinearSVCModel load(StreamTableEnvironment tEnv, String path) throws IOException { + LinearSVCModel model = ReadWriteUtils.loadStageParam(path); + Table modelDataTable = + ReadWriteUtils.loadModelData(tEnv, path, new LinearSVCModelData.ModelDataDecoder()); + return model.setModelData(modelDataTable); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + /** A utility function used for prediction. */ + private static class PredictLabelFunction extends RichMapFunction<Row, Row> { + + private final String broadcastModelKey; + + private final String featuresCol; + + private final double threshold; + + private DenseVector coefficient; + + public PredictLabelFunction( + String broadcastModelKey, String featuresCol, double threshold) { + this.broadcastModelKey = broadcastModelKey; + this.featuresCol = featuresCol; + this.threshold = threshold; + } + + @Override + public Row map(Row dataPoint) { + if (coefficient == null) { + LinearSVCModelData modelData = + (LinearSVCModelData) + getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0); + coefficient = modelData.coefficient; + } + DenseVector features = (DenseVector) dataPoint.getField(featuresCol); + Row predictionResult = predictOneDataPoint(features, coefficient, threshold); + return Row.join(dataPoint, predictionResult); + } + } + + /** + * The main logic that predicts one input data point. + * + * @param feature The input feature. + * @param coefficient The model parameters. + * @param threshold The threshold for prediction. + * @return The prediction label and the raw predictions. + */ + private static Row predictOneDataPoint( + DenseVector feature, DenseVector coefficient, double threshold) { + double dotValue = BLAS.dot(feature, coefficient); + return Row.of(dotValue >= threshold ? 1.0 : 0.0, Vectors.dense(dotValue, -dotValue)); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModelData.java new file mode 100644 index 0000000..96e8a27 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModelData.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.linearsvc; + +import org.apache.flink.api.common.serialization.Encoder; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.connector.file.src.reader.SimpleStreamFormat; +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; + +import java.io.EOFException; +import java.io.IOException; +import java.io.OutputStream; + +/** + * Model data of {@link LinearSVCModel}. + * + * <p>This class also provides methods to convert model data from Table to Datastream, and classes + * to save/load model data. + */ +public class LinearSVCModelData { + + public DenseVector coefficient; + + public LinearSVCModelData(DenseVector coefficient) { + this.coefficient = coefficient; + } + + public LinearSVCModelData() {} + + /** + * Converts the table model to a data stream. + * + * @param modelData The table model data. + * @return The data stream model data. + */ + public static DataStream<LinearSVCModelData> getModelDataStream(Table modelData) { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment(); + return tEnv.toDataStream(modelData) + .map(x -> new LinearSVCModelData((DenseVector) x.getField(0))); + } + + /** Data encoder for {@link LinearSVCModel}. */ + public static class ModelDataEncoder implements Encoder<LinearSVCModelData> { + + @Override + public void encode(LinearSVCModelData modelData, OutputStream outputStream) + throws IOException { + DenseVectorSerializer.INSTANCE.serialize( + modelData.coefficient, new DataOutputViewStreamWrapper(outputStream)); + } + } + + /** Data decoder for {@link LinearSVCModel}. */ + public static class ModelDataDecoder extends SimpleStreamFormat<LinearSVCModelData> { + + @Override + public Reader<LinearSVCModelData> createReader( + Configuration configuration, FSDataInputStream inputStream) { + return new Reader<LinearSVCModelData>() { + + @Override + public LinearSVCModelData read() throws IOException { + try { + DenseVector coefficient = + DenseVectorSerializer.INSTANCE.deserialize( + new DataInputViewStreamWrapper(inputStream)); + return new LinearSVCModelData(coefficient); + } catch (EOFException e) { + return null; + } + } + + @Override + public void close() throws IOException { + inputStream.close(); + } + }; + } + + @Override + public TypeInformation<LinearSVCModelData> getProducedType() { + return TypeInformation.of(LinearSVCModelData.class); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModelParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModelParams.java new file mode 100644 index 0000000..9e02233 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModelParams.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.linearsvc; + +import org.apache.flink.ml.common.param.HasFeaturesCol; +import org.apache.flink.ml.common.param.HasPredictionCol; +import org.apache.flink.ml.common.param.HasRawPredictionCol; +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; + +/** + * Params for {@link LinearSVCModel}. + * + * @param <T> The class type of this instance. + */ +public interface LinearSVCModelParams<T> + extends HasFeaturesCol<T>, HasPredictionCol<T>, HasRawPredictionCol<T> { + /** + * Param for threshold in linear support vector classifier. It applies to the rawPrediction and + * can be any real number, where Inf makes all predictions 0.0 and -Inf makes all predictions + * 1.0. + */ + Param<Double> THRESHOLD = + new DoubleParam( + "threshold", + "Threshold in binary classification prediction applied to rawPrediction.", + 0.0, + ParamValidators.notNull()); + + default Double getThreshold() { + return get(THRESHOLD); + } + + default T setThreshold(Double value) { + set(THRESHOLD, value); + return (T) this; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCParams.java new file mode 100644 index 0000000..7754e89 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCParams.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.linearsvc; + +import org.apache.flink.ml.common.param.HasElasticNet; +import org.apache.flink.ml.common.param.HasGlobalBatchSize; +import org.apache.flink.ml.common.param.HasLabelCol; +import org.apache.flink.ml.common.param.HasLearningRate; +import org.apache.flink.ml.common.param.HasMaxIter; +import org.apache.flink.ml.common.param.HasReg; +import org.apache.flink.ml.common.param.HasTol; +import org.apache.flink.ml.common.param.HasWeightCol; + +/** + * Params for {@link LinearSVC}. + * + * @param <T> The class type of this instance. + */ +public interface LinearSVCParams<T> + extends HasLabelCol<T>, + HasWeightCol<T>, + HasMaxIter<T>, + HasReg<T>, + HasElasticNet<T>, + HasLearningRate<T>, + HasGlobalBatchSize<T>, + HasTol<T>, + LinearSVCModelParams<T> {} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/HingeLoss.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/HingeLoss.java new file mode 100644 index 0000000..eb0f3bf --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/HingeLoss.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.lossfunc; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.ml.classification.linearsvc.LinearSVC; +import org.apache.flink.ml.common.feature.LabeledPointWithWeight; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseVector; + +/** + * The loss function for hinge loss. See {@link LinearSVC} for example. + * + * <p>See https://en.wikipedia.org/wiki/Hinge_loss. + */ +@Internal +public class HingeLoss implements LossFunc { + public static final HingeLoss INSTANCE = new HingeLoss(); + + private HingeLoss() {} + + @Override + public double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient) { + double dot = BLAS.dot(dataPoint.getFeatures(), coefficient); + double labelScaled = 2 * dataPoint.getLabel() - 1; + return dataPoint.getWeight() * Math.max(0, 1 - labelScaled * dot); + } + + @Override + public void computeGradient( + LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient) { + double dot = BLAS.dot(dataPoint.getFeatures(), coefficient); + double labelScaled = 2 * dataPoint.getLabel() - 1; + if (1 - labelScaled * dot > 0) { + BLAS.axpy( + -labelScaled * dataPoint.getWeight(), + dataPoint.getFeatures(), + cumGradient, + dataPoint.getFeatures().size()); + } + } +} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LinearSVCTest.java similarity index 52% copy from flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java copy to flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LinearSVCTest.java index 8d06613..156244e 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LinearSVCTest.java @@ -23,9 +23,9 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.configuration.Configuration; -import org.apache.flink.ml.classification.logisticregression.LogisticRegression; -import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel; -import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelData; +import org.apache.flink.ml.classification.linearsvc.LinearSVC; +import org.apache.flink.ml.classification.linearsvc.LinearSVCModel; +import org.apache.flink.ml.classification.linearsvc.LinearSVCModelData; import org.apache.flink.ml.linalg.DenseVector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; @@ -39,7 +39,6 @@ import org.apache.flink.types.Row; import org.apache.commons.collections.IteratorUtils; import org.apache.commons.lang3.RandomUtils; -import org.apache.commons.lang3.exception.ExceptionUtils; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -53,10 +52,9 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; -/** Tests {@link LogisticRegression} and {@link LogisticRegressionModel}. */ -public class LogisticRegressionTest { +/** Tests {@link LinearSVC} and {@link LinearSVCModel}. */ +public class LinearSVCTest { @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); @@ -64,7 +62,7 @@ public class LogisticRegressionTest { private StreamTableEnvironment tEnv; - private static final List<Row> binomialTrainData = + private static final List<Row> trainData = Arrays.asList( Row.of(Vectors.dense(1, 2, 3, 4), 0., 1.), Row.of(Vectors.dense(2, 2, 3, 4), 0., 2.), @@ -77,27 +75,12 @@ public class LogisticRegressionTest { Row.of(Vectors.dense(14, 2, 3, 4), 1., 4.), Row.of(Vectors.dense(15, 2, 3, 4), 1., 5.)); - private static final List<Row> multinomialTrainData = - Arrays.asList( - Row.of(Vectors.dense(1, 2, 3, 4), 0., 1.), - Row.of(Vectors.dense(2, 2, 3, 4), 0., 2.), - Row.of(Vectors.dense(3, 2, 3, 4), 2., 3.), - Row.of(Vectors.dense(4, 2, 3, 4), 2., 4.), - Row.of(Vectors.dense(5, 2, 3, 4), 2., 5.), - Row.of(Vectors.dense(11, 2, 3, 4), 1., 1.), - Row.of(Vectors.dense(12, 2, 3, 4), 1., 2.), - Row.of(Vectors.dense(13, 2, 3, 4), 1., 3.), - Row.of(Vectors.dense(14, 2, 3, 4), 1., 4.), - Row.of(Vectors.dense(15, 2, 3, 4), 1., 5.)); - private static final double[] expectedCoefficient = - new double[] {0.525, -0.283, -0.425, -0.567}; + new double[] {0.470, -0.273, -0.410, -0.546}; private static final double TOLERANCE = 1e-7; - private Table binomialDataTable; - - private Table multinomialDataTable; + private Table trainDataTable; @Before public void before() { @@ -108,20 +91,11 @@ public class LogisticRegressionTest { env.enableCheckpointing(100); env.setRestartStrategy(RestartStrategies.noRestart()); tEnv = StreamTableEnvironment.create(env); - Collections.shuffle(binomialTrainData); - binomialDataTable = - tEnv.fromDataStream( - env.fromCollection( - binomialTrainData, - new RowTypeInfo( - new TypeInformation[] { - DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE - }, - new String[] {"features", "label", "weight"}))); - multinomialDataTable = + Collections.shuffle(trainData); + trainDataTable = tEnv.fromDataStream( env.fromCollection( - multinomialTrainData, + trainData, new RowTypeInfo( new TypeInformation[] { DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE @@ -140,31 +114,31 @@ public class LogisticRegressionTest { DenseVector rawPrediction = (DenseVector) predictionRow.getField(rawPredictionCol); if (feature.get(0) <= 5) { assertEquals(0, prediction, TOLERANCE); - assertTrue(rawPrediction.get(0) > 0.5); + assertTrue(rawPrediction.get(0) < 0); } else { assertEquals(1, prediction, TOLERANCE); - assertTrue(rawPrediction.get(0) < 0.5); + assertTrue(rawPrediction.get(0) > 0); } } } @Test public void testParam() { - LogisticRegression logisticRegression = new LogisticRegression(); - assertEquals("label", logisticRegression.getLabelCol()); - assertNull(logisticRegression.getWeightCol()); - assertEquals(20, logisticRegression.getMaxIter()); - assertEquals(0, logisticRegression.getReg(), TOLERANCE); - assertEquals(0, logisticRegression.getElasticNet(), TOLERANCE); - assertEquals(0.1, logisticRegression.getLearningRate(), TOLERANCE); - assertEquals(32, logisticRegression.getGlobalBatchSize()); - assertEquals(1e-6, logisticRegression.getTol(), TOLERANCE); - assertEquals("auto", logisticRegression.getMultiClass()); - assertEquals("features", logisticRegression.getFeaturesCol()); - assertEquals("prediction", logisticRegression.getPredictionCol()); - assertEquals("rawPrediction", logisticRegression.getRawPredictionCol()); + LinearSVC linearSVC = new LinearSVC(); + assertEquals("features", linearSVC.getFeaturesCol()); + assertEquals("label", linearSVC.getLabelCol()); + assertNull(linearSVC.getWeightCol()); + assertEquals(20, linearSVC.getMaxIter()); + assertEquals(1e-6, linearSVC.getTol(), TOLERANCE); + assertEquals(0.1, linearSVC.getLearningRate(), TOLERANCE); + assertEquals(32, linearSVC.getGlobalBatchSize()); + assertEquals(0, linearSVC.getReg(), TOLERANCE); + assertEquals(0, linearSVC.getElasticNet(), TOLERANCE); + assertEquals(0.0, linearSVC.getThreshold(), TOLERANCE); + assertEquals("prediction", linearSVC.getPredictionCol()); + assertEquals("rawPrediction", linearSVC.getRawPredictionCol()); - logisticRegression + linearSVC .setFeaturesCol("test_features") .setLabelCol("test_label") .setWeightCol("test_weight") @@ -174,34 +148,34 @@ public class LogisticRegressionTest { .setGlobalBatchSize(1000) .setReg(0.1) .setElasticNet(0.5) - .setMultiClass("binomial") + .setThreshold(0.5) .setPredictionCol("test_predictionCol") .setRawPredictionCol("test_rawPredictionCol"); - assertEquals("test_features", logisticRegression.getFeaturesCol()); - assertEquals("test_label", logisticRegression.getLabelCol()); - assertEquals("test_weight", logisticRegression.getWeightCol()); - assertEquals(1000, logisticRegression.getMaxIter()); - assertEquals(0.001, logisticRegression.getTol(), TOLERANCE); - assertEquals(0.5, logisticRegression.getLearningRate(), TOLERANCE); - assertEquals(1000, logisticRegression.getGlobalBatchSize()); - assertEquals(0.1, logisticRegression.getReg(), TOLERANCE); - assertEquals(0.5, logisticRegression.getElasticNet(), TOLERANCE); - assertEquals("binomial", logisticRegression.getMultiClass()); - assertEquals("test_predictionCol", logisticRegression.getPredictionCol()); - assertEquals("test_rawPredictionCol", logisticRegression.getRawPredictionCol()); + assertEquals("test_features", linearSVC.getFeaturesCol()); + assertEquals("test_label", linearSVC.getLabelCol()); + assertEquals("test_weight", linearSVC.getWeightCol()); + assertEquals(1000, linearSVC.getMaxIter()); + assertEquals(0.001, linearSVC.getTol(), TOLERANCE); + assertEquals(0.5, linearSVC.getLearningRate(), TOLERANCE); + assertEquals(1000, linearSVC.getGlobalBatchSize()); + assertEquals(0.1, linearSVC.getReg(), TOLERANCE); + assertEquals(0.5, linearSVC.getElasticNet(), TOLERANCE); + assertEquals(0.5, linearSVC.getThreshold(), TOLERANCE); + assertEquals("test_predictionCol", linearSVC.getPredictionCol()); + assertEquals("test_rawPredictionCol", linearSVC.getRawPredictionCol()); } @Test public void testOutputSchema() { - Table tempTable = binomialDataTable.as("test_features", "test_label", "test_weight"); - LogisticRegression logisticRegression = - new LogisticRegression() + Table tempTable = trainDataTable.as("test_features", "test_label", "test_weight"); + LinearSVC linearSVC = + new LinearSVC() .setFeaturesCol("test_features") .setLabelCol("test_label") .setWeightCol("test_weight") .setPredictionCol("test_predictionCol") .setRawPredictionCol("test_rawPredictionCol"); - Table output = logisticRegression.fit(binomialDataTable).transform(tempTable)[0]; + Table output = linearSVC.fit(trainDataTable).transform(tempTable)[0]; assertEquals( Arrays.asList( "test_features", @@ -214,42 +188,42 @@ public class LogisticRegressionTest { @Test public void testFitAndPredict() throws Exception { - LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight"); - Table output = logisticRegression.fit(binomialDataTable).transform(binomialDataTable)[0]; + LinearSVC linearSVC = new LinearSVC().setWeightCol("weight"); + Table output = linearSVC.fit(trainDataTable).transform(trainDataTable)[0]; verifyPredictionResult( output, - logisticRegression.getFeaturesCol(), - logisticRegression.getPredictionCol(), - logisticRegression.getRawPredictionCol()); + linearSVC.getFeaturesCol(), + linearSVC.getPredictionCol(), + linearSVC.getRawPredictionCol()); } @Test public void testSaveLoadAndPredict() throws Exception { - LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight"); - logisticRegression = + LinearSVC linearSVC = new LinearSVC().setWeightCol("weight"); + linearSVC = StageTestUtils.saveAndReload( - tEnv, logisticRegression, tempFolder.newFolder().getAbsolutePath()); - LogisticRegressionModel model = logisticRegression.fit(binomialDataTable); + tEnv, linearSVC, tempFolder.newFolder().getAbsolutePath()); + LinearSVCModel model = linearSVC.fit(trainDataTable); model = StageTestUtils.saveAndReload(tEnv, model, tempFolder.newFolder().getAbsolutePath()); assertEquals( Collections.singletonList("coefficient"), model.getModelData()[0].getResolvedSchema().getColumnNames()); - Table output = model.transform(binomialDataTable)[0]; + Table output = model.transform(trainDataTable)[0]; verifyPredictionResult( output, - logisticRegression.getFeaturesCol(), - logisticRegression.getPredictionCol(), - logisticRegression.getRawPredictionCol()); + linearSVC.getFeaturesCol(), + linearSVC.getPredictionCol(), + linearSVC.getRawPredictionCol()); } @Test @SuppressWarnings("unchecked") public void testGetModelData() throws Exception { - LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight"); - LogisticRegressionModel model = logisticRegression.fit(binomialDataTable); - List<LogisticRegressionModelData> modelData = + LinearSVC linearSVC = new LinearSVC().setWeightCol("weight"); + LinearSVCModel model = linearSVC.fit(trainDataTable); + List<LinearSVCModelData> modelData = IteratorUtils.toList( - LogisticRegressionModelData.getModelDataStream(model.getModelData()[0]) + LinearSVCModelData.getModelDataStream(model.getModelData()[0]) .executeAndCollect()); assertEquals(1, modelData.size()); assertArrayEquals(expectedCoefficient, modelData.get(0).coefficient.values, 0.1); @@ -257,68 +231,74 @@ public class LogisticRegressionTest { @Test public void testSetModelData() throws Exception { - LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight"); - LogisticRegressionModel model = logisticRegression.fit(binomialDataTable); + LinearSVC linearSVC = new LinearSVC().setWeightCol("weight"); + LinearSVCModel model = linearSVC.fit(trainDataTable); - LogisticRegressionModel newModel = new LogisticRegressionModel(); + LinearSVCModel newModel = new LinearSVCModel(); ReadWriteUtils.updateExistingParams(newModel, model.getParamMap()); newModel.setModelData(model.getModelData()); - Table output = newModel.transform(binomialDataTable)[0]; + Table output = newModel.transform(trainDataTable)[0]; verifyPredictionResult( output, - logisticRegression.getFeaturesCol(), - logisticRegression.getPredictionCol(), - logisticRegression.getRawPredictionCol()); - } - - @Test - public void testMultinomialFit() { - try { - new LogisticRegression().fit(multinomialDataTable); - env.execute(); - fail(); - } catch (Throwable e) { - assertEquals( - "Multinomial classification is not supported yet. Supported options: [auto, binomial].", - ExceptionUtils.getRootCause(e).getMessage()); - } + linearSVC.getFeaturesCol(), + linearSVC.getPredictionCol(), + linearSVC.getRawPredictionCol()); } @Test public void testMoreSubtaskThanData() throws Exception { env.setParallelism(12); - LogisticRegression logisticRegression = - new LogisticRegression().setWeightCol("weight").setGlobalBatchSize(128); - Table output = logisticRegression.fit(binomialDataTable).transform(binomialDataTable)[0]; + LinearSVC linearSVC = new LinearSVC().setWeightCol("weight").setGlobalBatchSize(128); + Table output = linearSVC.fit(trainDataTable).transform(trainDataTable)[0]; verifyPredictionResult( output, - logisticRegression.getFeaturesCol(), - logisticRegression.getPredictionCol(), - logisticRegression.getRawPredictionCol()); + linearSVC.getFeaturesCol(), + linearSVC.getPredictionCol(), + linearSVC.getRawPredictionCol()); } @Test public void testRegularization() throws Exception { checkRegularization(0, RandomUtils.nextDouble(0, 1), expectedCoefficient); - checkRegularization(0.1, 0, new double[] {0.484, -0.258, -0.388, -0.517}); - checkRegularization(0.1, 1, new double[] {0.417, -0.145, -0.312, -0.480}); - checkRegularization(0.1, 0.5, new double[] {0.451, -0.203, -0.351, -0.498}); + checkRegularization(0.1, 0, new double[] {0.437, -0.262, -0.393, -0.524}); + checkRegularization(0.1, 1, new double[] {0.426, -0.197, -0.329, -0.463}); + checkRegularization(0.1, 0.5, new double[] {0.419, -0.238, -0.372, -0.505}); + } + + @Test + public void testThreshold() throws Exception { + checkThreshold(-Double.MAX_VALUE, 1); + checkThreshold(Double.MAX_VALUE, 0); } @SuppressWarnings("unchecked") private void checkRegularization(double reg, double elasticNet, double[] expectedCoefficient) throws Exception { - LogisticRegressionModel model = - new LogisticRegression() + LinearSVCModel model = + new LinearSVC() .setWeightCol("weight") .setReg(reg) .setElasticNet(elasticNet) - .fit(binomialDataTable); - List<LogisticRegressionModelData> modelData = + .fit(trainDataTable); + List<LinearSVCModelData> modelData = IteratorUtils.toList( - LogisticRegressionModelData.getModelDataStream(model.getModelData()[0]) + LinearSVCModelData.getModelDataStream(model.getModelData()[0]) .executeAndCollect()); final double errorTol = 1e-3; assertArrayEquals(expectedCoefficient, modelData.get(0).coefficient.values, errorTol); } + + @SuppressWarnings("unchecked") + private void checkThreshold(double threshold, double target) throws Exception { + LinearSVC linearSVC = new LinearSVC().setWeightCol("weight"); + + Table predictions = + linearSVC.setThreshold(threshold).fit(trainDataTable).transform(trainDataTable)[0]; + + List<Row> predResult = + IteratorUtils.toList(tEnv.toDataStream(predictions).executeAndCollect()); + for (Row r : predResult) { + assertEquals(target, r.getField(linearSVC.getPredictionCol())); + } + } } diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java index 8d06613..4bdab6b 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java @@ -151,16 +151,16 @@ public class LogisticRegressionTest { @Test public void testParam() { LogisticRegression logisticRegression = new LogisticRegression(); + assertEquals("features", logisticRegression.getFeaturesCol()); assertEquals("label", logisticRegression.getLabelCol()); assertNull(logisticRegression.getWeightCol()); assertEquals(20, logisticRegression.getMaxIter()); - assertEquals(0, logisticRegression.getReg(), TOLERANCE); - assertEquals(0, logisticRegression.getElasticNet(), TOLERANCE); + assertEquals(1e-6, logisticRegression.getTol(), TOLERANCE); assertEquals(0.1, logisticRegression.getLearningRate(), TOLERANCE); assertEquals(32, logisticRegression.getGlobalBatchSize()); - assertEquals(1e-6, logisticRegression.getTol(), TOLERANCE); + assertEquals(0, logisticRegression.getReg(), TOLERANCE); + assertEquals(0, logisticRegression.getElasticNet(), TOLERANCE); assertEquals("auto", logisticRegression.getMultiClass()); - assertEquals("features", logisticRegression.getFeaturesCol()); assertEquals("prediction", logisticRegression.getPredictionCol()); assertEquals("rawPrediction", logisticRegression.getRawPredictionCol()); diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/HingeLossTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/HingeLossTest.java new file mode 100644 index 0000000..1cd165e --- /dev/null +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/HingeLossTest.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.lossfunc; + +import org.apache.flink.ml.common.feature.LabeledPointWithWeight; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; + +import org.junit.Test; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +/** Tests {@link HingeLoss}. */ +public class HingeLossTest { + private static final LabeledPointWithWeight dataPoint1 = + new LabeledPointWithWeight(Vectors.dense(1.0, -1.0, -1.0), 1.0, 2.0); + private static final LabeledPointWithWeight dataPoint2 = + new LabeledPointWithWeight(Vectors.dense(1.0, -1.0, 1.0), 1.0, 2.0); + private static final DenseVector coefficient = Vectors.dense(1.0, 1.0, 1.0); + private static final DenseVector cumGradient = Vectors.dense(0.0, 0.0, 0.0); + private static final double TOLERANCE = 1e-7; + + @Test + public void computeLoss() { + double loss = HingeLoss.INSTANCE.computeLoss(dataPoint1, coefficient); + assertEquals(4.0, loss, TOLERANCE); + + loss = HingeLoss.INSTANCE.computeLoss(dataPoint2, coefficient); + assertEquals(0.0, loss, TOLERANCE); + } + + @Test + public void computeGradient() { + HingeLoss.INSTANCE.computeGradient(dataPoint1, coefficient, cumGradient); + assertArrayEquals(new double[] {-2.0, 2.0, 2.0}, cumGradient.values, TOLERANCE); + + HingeLoss.INSTANCE.computeGradient(dataPoint2, coefficient, cumGradient); + assertArrayEquals(new double[] {-2.0, 2.0, 2.0}, cumGradient.values, TOLERANCE); + } +} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java index 58e89c4..3ea99d3 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java @@ -118,15 +118,15 @@ public class LinearRegressionTest { @Test public void testParam() { LinearRegression linearRegression = new LinearRegression(); + assertEquals("features", linearRegression.getFeaturesCol()); assertEquals("label", linearRegression.getLabelCol()); assertNull(linearRegression.getWeightCol()); assertEquals(20, linearRegression.getMaxIter()); - assertEquals(0, linearRegression.getReg(), TOLERANCE); - assertEquals(0, linearRegression.getElasticNet(), TOLERANCE); + assertEquals(1e-6, linearRegression.getTol(), TOLERANCE); assertEquals(0.1, linearRegression.getLearningRate(), TOLERANCE); assertEquals(32, linearRegression.getGlobalBatchSize()); - assertEquals(1e-6, linearRegression.getTol(), TOLERANCE); - assertEquals("features", linearRegression.getFeaturesCol()); + assertEquals(0, linearRegression.getReg(), TOLERANCE); + assertEquals(0, linearRegression.getElasticNet(), TOLERANCE); assertEquals("prediction", linearRegression.getPredictionCol()); linearRegression