This is an automated email from the ASF dual-hosted git repository. lindong pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push: new f34dbb7 [FLINK-27093] Add Transformer and Estimator for LinearRegression f34dbb7 is described below commit f34dbb708a0d151caa5afab593d04badeb549224 Author: Zhipeng Zhang <zhangzhipe...@gmail.com> AuthorDate: Fri May 6 16:02:55 2022 +0800 [FLINK-27093] Add Transformer and Estimator for LinearRegression This closes #90. --- .../flink/ml/common/datastream/AllReduceImpl.java | 2 + .../ml/common/datastream/DataStreamUtils.java | 83 ++++- .../main/java/org/apache/flink/ml/linalg/BLAS.java | 19 +- .../ml/common/datastream/DataStreamUtilsTest.java | 11 + .../java/org/apache/flink/ml/linalg/BLASTest.java | 16 + .../logisticregression/LogisticGradient.java | 97 ----- .../logisticregression/LogisticRegression.java | 401 +++------------------ .../LogisticRegressionModel.java | 80 ++-- .../LogisticRegressionParams.java | 2 + .../ml/common/lossfunc/BinaryLogisticLoss.java | 50 +++ .../flink/ml/common/lossfunc/LeastSquareLoss.java | 50 +++ .../apache/flink/ml/common/lossfunc/LossFunc.java | 51 +++ .../flink/ml/common/optimizer/Optimizer.java | 46 +++ .../ml/common/optimizer/RegularizationUtils.java | 92 +++++ .../org/apache/flink/ml/common/optimizer/SGD.java | 390 ++++++++++++++++++++ .../flink/ml/common/param/HasElasticNet.java | 47 +++ .../linearregression/LinearRegression.java | 122 +++++++ .../linearregression/LinearRegressionModel.java} | 113 +++--- .../LinearRegressionModelData.java | 111 ++++++ .../LinearRegressionModelParams.java | 29 ++ .../linearregression/LinearRegressionParams.java} | 12 +- .../ml/classification/LogisticRegressionTest.java | 101 ++++-- .../ml/common/lossfunc/BinaryLogisticLossTest.java | 53 +++ .../ml/common/lossfunc/LeastSquareLossTest.java | 51 +++ .../common/optimizer/RegularizationUtilsTest.java | 47 +++ .../flink/ml/regression/LinearRegressionTest.java | 255 +++++++++++++ 26 files changed, 1721 insertions(+), 610 deletions(-) diff --git 
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/AllReduceImpl.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/AllReduceImpl.java index 760b5db..167572a 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/AllReduceImpl.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/AllReduceImpl.java @@ -18,6 +18,7 @@ package org.apache.flink.ml.common.datastream; +import org.apache.flink.annotation.Internal; import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.common.functions.RichFlatMapFunction; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; @@ -49,6 +50,7 @@ import java.util.Map; * <li>All workers do reduce on all data it received and then broadcast partial results to others. * <li>All workers merge partial results into final result. */ +@Internal class AllReduceImpl { @VisibleForTesting static final int CHUNK_SIZE = 1024 * 4; diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java index 7eea6b0..10073b8 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java @@ -18,12 +18,16 @@ package org.apache.flink.ml.common.datastream; +import org.apache.flink.annotation.Internal; import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.functions.ReduceFunction; import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.iteration.operator.OperatorStateUtils; import org.apache.flink.runtime.state.StateInitializationContext; +import 
org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; import org.apache.flink.streaming.api.operators.BoundedOneInput; @@ -32,6 +36,7 @@ import org.apache.flink.streaming.api.operators.TimestampedCollector; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; /** Provides utility functions for {@link DataStream}. */ +@Internal public class DataStreamUtils { /** * Applies allReduceSum on the input data stream. The input data stream is supposed to contain @@ -55,8 +60,8 @@ public class DataStreamUtils { * * @param input The input data stream. * @param func The user defined mapPartition function. - * @param <IN> The class type of the input element. - * @param <OUT> The class type of output element. + * @param <IN> The class type of the input. + * @param <OUT> The class type of output. * @return The result data stream. */ public static <IN, OUT> DataStream<OUT> mapPartition( @@ -67,6 +72,28 @@ public class DataStreamUtils { .setParallelism(input.getParallelism()); } + /** + * Applies a {@link ReduceFunction} on a bounded data stream. The output stream contains at most + * one stream record and its parallelism is one. + * + * @param input The input data stream. + * @param func The user defined reduce function. + * @param <T> The class type of the input. + * @return The result data stream. 
+ */ + public static <T> DataStream<T> reduce(DataStream<T> input, ReduceFunction<T> func) { + DataStream<T> partialReducedStream = + input.transform("reduce", input.getType(), new ReduceOperator<>(func)) + .setParallelism(input.getParallelism()); + if (partialReducedStream.getParallelism() == 1) { + return partialReducedStream; + } else { + return partialReducedStream + .transform("reduce", input.getType(), new ReduceOperator<>(func)) + .setParallelism(1); + } + } + /** * A stream operator to apply {@link MapPartitionFunction} on each partition of the input * bounded data stream. @@ -103,4 +130,56 @@ public class DataStreamUtils { valuesState.add(input.getValue()); } } + + /** A stream operator to apply {@link ReduceFunction} on the input bounded data stream. */ + private static class ReduceOperator<T> extends AbstractUdfStreamOperator<T, ReduceFunction<T>> + implements OneInputStreamOperator<T, T>, BoundedOneInput { + /** The temp result of the reduce function. */ + private T result; + + private ListState<T> state; + + public ReduceOperator(ReduceFunction<T> userFunction) { + super(userFunction); + } + + @Override + public void endInput() { + if (result != null) { + output.collect(new StreamRecord<>(result)); + } + } + + @Override + public void processElement(StreamRecord<T> streamRecord) throws Exception { + if (result == null) { + result = streamRecord.getValue(); + } else { + result = userFunction.reduce(streamRecord.getValue(), result); + } + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + state = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<T>( + "state", + getOperatorConfig() + .getTypeSerializerIn( + 0, getClass().getClassLoader()))); + result = OperatorStateUtils.getUniqueElement(state, "state").orElse(null); + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + 
state.clear(); + if (result != null) { + state.add(result); + } + } + } } diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java index 383bf66..24cc3ef 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java @@ -34,10 +34,16 @@ public class BLAS { /** y += a * x . */ public static void axpy(double a, Vector x, DenseVector y) { Preconditions.checkArgument(x.size() == y.size(), "Vector size mismatched."); + axpy(a, x, y, x.size()); + } + + /** y += a * x for the first k dimensions, with the other dimensions unchanged. */ + public static void axpy(double a, Vector x, DenseVector y, int k) { + Preconditions.checkArgument(x.size() >= k && y.size() >= k); if (x instanceof SparseVector) { - axpy(a, (SparseVector) x, y); + axpy(a, (SparseVector) x, y, k); } else { - axpy(a, (DenseVector) x, y); + axpy(a, (DenseVector) x, y, k); } } @@ -112,13 +118,16 @@ public class BLAS { 1); } - private static void axpy(double a, DenseVector x, DenseVector y) { - JAVA_BLAS.daxpy(x.size(), a, x.values, 1, y.values, 1); + private static void axpy(double a, DenseVector x, DenseVector y, int k) { + JAVA_BLAS.daxpy(k, a, x.values, 1, y.values, 1); } - private static void axpy(double a, SparseVector x, DenseVector y) { + private static void axpy(double a, SparseVector x, DenseVector y, int k) { for (int i = 0; i < x.indices.length; i++) { int index = x.indices[i]; + if (index >= k) { + return; + } y.values[index] += a * x.values[i]; } } diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java index 7933859..7dc88c8 100644 --- a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java +++ 
b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java @@ -19,6 +19,7 @@ package org.apache.flink.ml.common.datastream; import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.functions.ReduceFunction; import org.apache.flink.api.common.functions.RichMapPartitionFunction; import org.apache.flink.api.common.restartstrategy.RestartStrategies; import org.apache.flink.api.common.typeinfo.Types; @@ -63,6 +64,16 @@ public class DataStreamUtilsTest { new int[] {5, 5, 5, 5}, counts.stream().mapToInt(Integer::intValue).toArray()); } + @Test + public void testReduce() throws Exception { + DataStream<Long> dataStream = + env.fromParallelCollection(new NumberSequenceIterator(0L, 19L), Types.LONG); + DataStream<Long> result = + DataStreamUtils.reduce(dataStream, (ReduceFunction<Long>) Long::sum); + List<Long> sum = IteratorUtils.toList(result.executeAndCollect()); + assertArrayEquals(new long[] {190L}, sum.stream().mapToLong(Long::longValue).toArray()); + } + /** A simple implementation for a {@link MapPartitionFunction}. */ private static class TestMapPartitionFunc extends RichMapPartitionFunction<Long, Integer> { diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java index 7bcd853..7055c62 100644 --- a/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java +++ b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java @@ -51,6 +51,22 @@ public class BLASTest { assertArrayEquals(expectedResult, anotherDenseVec.values, TOLERANCE); } + @Test + public void testAxpyK() { + // Tests axpy(dense, dense, k). + DenseVector anotherDenseVec = Vectors.dense(1, 2, 3); + BLAS.axpy(1, inputDenseVec, anotherDenseVec, 3); + double[] expectedResult = new double[] {2, 0, 6}; + assertArrayEquals(expectedResult, anotherDenseVec.values, TOLERANCE); + + // Tests axpy(sparse, dense, k). 
+ SparseVector sparseVec = Vectors.sparse(5, new int[] {0, 2, 4}, new double[] {1, 3, 5}); + anotherDenseVec = Vectors.dense(1, 2, 3, 4, 5, 6, 7); + BLAS.axpy(2, sparseVec, anotherDenseVec, 5); + expectedResult = new double[] {3, 2, 9, 4, 15, 6, 7}; + assertArrayEquals(expectedResult, anotherDenseVec.values, TOLERANCE); + } + @Test public void testDot() { DenseVector anotherDenseVec = Vectors.dense(1, 2, 3, 4, 5); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticGradient.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticGradient.java deleted file mode 100644 index 13f753b..0000000 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticGradient.java +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.ml.classification.logisticregression; - -import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.ml.common.feature.LabeledPointWithWeight; -import org.apache.flink.ml.linalg.BLAS; -import org.apache.flink.ml.linalg.DenseVector; - -import java.io.Serializable; -import java.util.List; - -/** - * Utility class to compute gradient and loss for logistic loss function. - * - * <p>See http://mlwiki.org/index.php/Logistic_Regression. - */ -public class LogisticGradient implements Serializable { - - /** L2 regularization term. */ - private final double l2; - - public LogisticGradient(double l2) { - this.l2 = l2; - } - - /** - * Computes weight sum and loss sum on a set of samples. - * - * @param dataPoints A sample set of train data. - * @param coefficient The model parameters. - * @return Weight sum and loss sum of the input data. - */ - public Tuple2<Double, Double> computeLoss( - List<LabeledPointWithWeight> dataPoints, DenseVector coefficient) { - double weightSum = 0.0; - double lossSum = 0.0; - for (LabeledPointWithWeight dataPoint : dataPoints) { - lossSum += dataPoint.getWeight() * computeLoss(dataPoint, coefficient); - weightSum += dataPoint.getWeight(); - } - if (Double.compare(0, l2) != 0) { - lossSum += l2 * Math.pow(BLAS.norm2(coefficient), 2); - } - return Tuple2.of(weightSum, lossSum); - } - - /** - * Computes gradient on a set of samples. - * - * @param dataPoints A sample set of train data. - * @param coefficient The model parameters. - * @param cumGradient The accumulated gradients. 
- */ - public void computeGradient( - List<LabeledPointWithWeight> dataPoints, - DenseVector coefficient, - DenseVector cumGradient) { - for (LabeledPointWithWeight dataPoint : dataPoints) { - computeGradient(dataPoint, coefficient, cumGradient); - } - if (Double.compare(0, l2) != 0) { - BLAS.axpy(l2 * 2, coefficient, cumGradient); - } - } - - private double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient) { - double dot = BLAS.dot(dataPoint.getFeatures(), coefficient); - double labelScaled = 2 * dataPoint.getLabel() - 1; - return Math.log(1 + Math.exp(-dot * labelScaled)); - } - - private void computeGradient( - LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient) { - double dot = BLAS.dot(dataPoint.getFeatures(), coefficient); - double labelScaled = 2 * dataPoint.getLabel() - 1; - double multiplier = - dataPoint.getWeight() * (-labelScaled / (Math.exp(dot * labelScaled) + 1)); - BLAS.axpy(multiplier, dataPoint.getFeatures(), cumGradient); - } -} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java index 58cf0ce..f64fc10 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java @@ -18,54 +18,26 @@ package org.apache.flink.ml.classification.logisticregression; -import org.apache.flink.api.common.state.ListState; -import org.apache.flink.api.common.state.ListStateDescriptor; -import org.apache.flink.api.common.typeinfo.BasicTypeInfo; -import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; -import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.iteration.DataStreamList; -import 
org.apache.flink.iteration.IterationBody; -import org.apache.flink.iteration.IterationBodyResult; -import org.apache.flink.iteration.IterationConfig; -import org.apache.flink.iteration.IterationListener; -import org.apache.flink.iteration.Iterations; -import org.apache.flink.iteration.ReplayableDataStreamList; -import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.api.common.functions.ReduceFunction; import org.apache.flink.ml.api.Estimator; import org.apache.flink.ml.common.datastream.DataStreamUtils; import org.apache.flink.ml.common.feature.LabeledPointWithWeight; -import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol; -import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.common.lossfunc.BinaryLogisticLoss; +import org.apache.flink.ml.common.optimizer.Optimizer; +import org.apache.flink.ml.common.optimizer.SGD; import org.apache.flink.ml.linalg.DenseVector; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; -import org.apache.flink.runtime.state.StateInitializationContext; -import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.datastream.DataStream; -import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; -import org.apache.flink.streaming.api.operators.AbstractStreamOperator; -import org.apache.flink.streaming.api.operators.BoundedOneInput; -import org.apache.flink.streaming.api.operators.OneInputStreamOperator; -import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; -import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; import org.apache.flink.table.api.internal.TableImpl; -import org.apache.flink.util.Collector; -import org.apache.flink.util.OutputTag; import org.apache.flink.util.Preconditions; -import 
org.apache.commons.collections.IteratorUtils; - import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; import java.util.HashMap; -import java.util.List; import java.util.Map; -import java.util.Random; /** * An Estimator which implements the logistic regression algorithm. @@ -82,21 +54,6 @@ public class LogisticRegression ParamUtils.initializeMapWithDefaultValues(paramMap, this); } - @Override - public Map<Param<?>, Object> getParamMap() { - return paramMap; - } - - @Override - public void save(String path) throws IOException { - ReadWriteUtils.saveMetadata(this, path); - } - - public static LogisticRegression load(StreamTableEnvironment tEnv, String path) - throws IOException { - return ReadWriteUtils.loadStageParam(path); - } - @Override @SuppressWarnings({"rawTypes", "ConstantConditions"}) public LogisticRegressionModel fit(Table... inputs) { @@ -107,15 +64,16 @@ public class LogisticRegression "Multinomial classification is not supported yet. Supported options: [auto, binomial]."); StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<LabeledPointWithWeight> trainData = tEnv.toDataStream(inputs[0]) .map( dataPoint -> { - Double weight = + double weight = getWeightCol() == null ? 
1.0 : (Double) dataPoint.getField(getWeightCol()); - Double label = (Double) dataPoint.getField(getLabelCol()); + double label = (Double) dataPoint.getField(getLabelCol()); boolean isBinomial = Double.compare(0., label) == 0 || Double.compare(1., label) == 0; @@ -127,327 +85,50 @@ public class LogisticRegression (DenseVector) dataPoint.getField(getFeaturesCol()); return new LabeledPointWithWeight(features, label, weight); }); - DataStream<double[]> initModelData = - trainData - .transform("getModelDim", BasicTypeInfo.INT_TYPE_INFO, new GetModelDim()) - .setParallelism(1) - .broadcast() - .map(double[]::new); - DataStream<LogisticRegressionModelData> modelData = train(trainData, initModelData); + DataStream<DenseVector> initModelData = + DataStreamUtils.reduce( + trainData.map(x -> x.getFeatures().size()), + (ReduceFunction<Integer>) + (t0, t1) -> { + Preconditions.checkState( + t0.equals(t1), + "The training data should all have same dimensions."); + return t0; + }) + .map(DenseVector::new); + + Optimizer optimizer = + new SGD( + getMaxIter(), + getLearningRate(), + getGlobalBatchSize(), + getTol(), + getReg(), + getElasticNet()); + DataStream<DenseVector> rawModelData = + optimizer.optimize(initModelData, trainData, BinaryLogisticLoss.INSTANCE); + + DataStream<LogisticRegressionModelData> modelData = + rawModelData.map(LogisticRegressionModelData::new); LogisticRegressionModel model = new LogisticRegressionModel().setModelData(tEnv.fromDataStream(modelData)); ReadWriteUtils.updateExistingParams(model, paramMap); return model; } - /** Gets the dimension of the model data. 
*/ - private static class GetModelDim extends AbstractStreamOperator<Integer> - implements OneInputStreamOperator<LabeledPointWithWeight, Integer>, BoundedOneInput { - - private int dim = 0; - - private ListState<Integer> dimState; - - @Override - public void endInput() { - output.collect(new StreamRecord<>(dim)); - } - - @Override - public void processElement(StreamRecord<LabeledPointWithWeight> streamRecord) { - if (dim == 0) { - dim = streamRecord.getValue().getFeatures().size(); - } else { - if (dim != streamRecord.getValue().getFeatures().size()) { - throw new RuntimeException( - "The training data should all have same dimensions."); - } - } - } - - @Override - public void initializeState(StateInitializationContext context) throws Exception { - super.initializeState(context); - dimState = - context.getOperatorStateStore() - .getListState( - new ListStateDescriptor<>( - "dimState", BasicTypeInfo.INT_TYPE_INFO)); - dim = OperatorStateUtils.getUniqueElement(dimState, "dimState").orElse(0); - } - - @Override - public void snapshotState(StateSnapshotContext context) throws Exception { - dimState.clear(); - dimState.add(dim); - } - } - - /** - * Does machine learning training on the input data with the initialized model data. - * - * @param trainData The training data. - * @param initModelData The initialized model. - * @return The trained model data. 
- */ - private DataStream<LogisticRegressionModelData> train( - DataStream<LabeledPointWithWeight> trainData, DataStream<double[]> initModelData) { - LogisticGradient logisticGradient = new LogisticGradient(getReg()); - DataStreamList resultList = - Iterations.iterateBoundedStreamsUntilTermination( - DataStreamList.of(initModelData), - ReplayableDataStreamList.notReplay(trainData), - IterationConfig.newBuilder().build(), - new TrainIterationBody( - logisticGradient, - getGlobalBatchSize(), - getLearningRate(), - getMaxIter(), - getTol())); - return resultList.get(0); + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); } - /** The iteration implementation for training process. */ - private static class TrainIterationBody implements IterationBody { - - private final LogisticGradient logisticGradient; - - private final int globalBatchSize; - - private final double learningRate; - - private final int maxIter; - - private final double tol; - - public TrainIterationBody( - LogisticGradient logisticGradient, - int globalBatchSize, - double learningRate, - int maxIter, - double tol) { - this.logisticGradient = logisticGradient; - this.globalBatchSize = globalBatchSize; - this.learningRate = learningRate; - this.maxIter = maxIter; - this.tol = tol; - } - - @Override - public IterationBodyResult process( - DataStreamList variableStreams, DataStreamList dataStreams) { - // The variable stream at the first iteration is the initialized model data. - // In the following iterations, it contains: the computed gradient, weightSum and - // lossSum. 
- DataStream<double[]> variableStream = variableStreams.get(0); - DataStream<LabeledPointWithWeight> trainData = dataStreams.get(0); - final OutputTag<LogisticRegressionModelData> modelDataOutputTag = - new OutputTag<LogisticRegressionModelData>("MODEL_OUTPUT") {}; - SingleOutputStreamOperator<double[]> gradientAndWeightAndLoss = - trainData - .connect(variableStream) - .transform( - "CacheDataAndDoTrain", - PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO, - new CacheDataAndDoTrain( - logisticGradient, - globalBatchSize, - learningRate, - modelDataOutputTag)); - DataStreamList feedbackVariableStream = - IterationBody.forEachRound( - DataStreamList.of(gradientAndWeightAndLoss), - input -> { - DataStream<double[]> feedback = - DataStreamUtils.allReduceSum(input.get(0)); - return DataStreamList.of(feedback); - }); - DataStream<Integer> terminationCriteria = - feedbackVariableStream - .get(0) - .map( - reducedGradientAndWeightAndLoss -> { - double[] value = (double[]) reducedGradientAndWeightAndLoss; - return value[value.length - 1] / value[value.length - 2]; - }) - .flatMap(new TerminateOnMaxIterOrTol(maxIter, tol)); - return new IterationBodyResult( - DataStreamList.of(feedbackVariableStream.get(0)), - DataStreamList.of(gradientAndWeightAndLoss.getSideOutput(modelDataOutputTag)), - terminationCriteria); - } + public static LogisticRegression load(StreamTableEnvironment tEnv, String path) + throws IOException { + return ReadWriteUtils.loadStageParam(path); } - /** - * A stream operator that caches the training data in the first iteration and updates the model - * using gradients iteratively. The first input is the training data, and the second input is - * the initialized model data or feedback of gradient, weight and loss. 
- */ - private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]> - implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>, - IterationListener<double[]> { - - private final int globalBatchSize; - - private int localBatchSize; - - private final double learningRate; - - private final LogisticGradient logisticGradient; - - private DenseVector gradient; - - private DenseVector coefficient; - - private int coefficientDim; - - private ListState<DenseVector> coefficientState; - - private List<LabeledPointWithWeight> trainData; - - private ListState<LabeledPointWithWeight> trainDataState; - - private final Random random = new Random(2021); - - private List<LabeledPointWithWeight> miniBatchData; - - /** The buffer for feedback record: {gradient, weightSum, loss}. */ - private double[] feedbackBuffer; - - private ListState<double[]> feedbackBufferState; - - private final OutputTag<LogisticRegressionModelData> modelDataOutputTag; - - public CacheDataAndDoTrain( - LogisticGradient logisticGradient, - int globalBatchSize, - double learningRate, - OutputTag<LogisticRegressionModelData> modelDataOutputTag) { - this.logisticGradient = logisticGradient; - this.globalBatchSize = globalBatchSize; - this.learningRate = learningRate; - this.modelDataOutputTag = modelDataOutputTag; - } - - @Override - public void open() { - int numTasks = getRuntimeContext().getNumberOfParallelSubtasks(); - int taskId = getRuntimeContext().getIndexOfThisSubtask(); - localBatchSize = globalBatchSize / numTasks; - if (globalBatchSize % numTasks > taskId) { - localBatchSize++; - } - this.miniBatchData = new ArrayList<>(localBatchSize); - } - - private List<LabeledPointWithWeight> getMiniBatchData( - List<LabeledPointWithWeight> fullBatchData, int batchSize) { - miniBatchData.clear(); - for (int i = 0; i < batchSize; i++) { - miniBatchData.add(fullBatchData.get(random.nextInt(fullBatchData.size()))); - } - return miniBatchData; - } - - private void 
updateModel() { - System.arraycopy(feedbackBuffer, 0, gradient.values, 0, gradient.size()); - double weightSum = feedbackBuffer[coefficientDim]; - BLAS.axpy(-learningRate / weightSum, gradient, coefficient); - } - - @Override - public void onEpochWatermarkIncremented( - int epochWatermark, Context context, Collector<double[]> collector) - throws Exception { - if (epochWatermark == 0) { - coefficient = new DenseVector(feedbackBuffer); - coefficientDim = coefficient.size(); - feedbackBuffer = new double[coefficientDim + 2]; - gradient = new DenseVector(coefficientDim); - } else { - updateModel(); - } - Arrays.fill(gradient.values, 0); - if (trainData == null) { - trainData = IteratorUtils.toList(trainDataState.get().iterator()); - } - if (trainData.size() > 0) { - miniBatchData = getMiniBatchData(trainData, localBatchSize); - Tuple2<Double, Double> weightSumAndLossSum = - logisticGradient.computeLoss(miniBatchData, coefficient); - logisticGradient.computeGradient(miniBatchData, coefficient, gradient); - System.arraycopy(gradient.values, 0, feedbackBuffer, 0, gradient.size()); - feedbackBuffer[coefficientDim] = weightSumAndLossSum.f0; - feedbackBuffer[coefficientDim + 1] = weightSumAndLossSum.f1; - collector.collect(feedbackBuffer); - } - } - - @Override - public void onIterationTerminated(Context context, Collector<double[]> collector) { - trainDataState.clear(); - coefficientState.clear(); - feedbackBufferState.clear(); - if (getRuntimeContext().getIndexOfThisSubtask() == 0) { - updateModel(); - context.output(modelDataOutputTag, new LogisticRegressionModelData(coefficient)); - } - } - - @Override - public void processElement1(StreamRecord<LabeledPointWithWeight> streamRecord) - throws Exception { - trainDataState.add(streamRecord.getValue()); - } - - @Override - public void processElement2(StreamRecord<double[]> streamRecord) { - feedbackBuffer = streamRecord.getValue(); - } - - @Override - public void initializeState(StateInitializationContext context) throws 
Exception { - super.initializeState(context); - trainDataState = - context.getOperatorStateStore() - .getListState( - new ListStateDescriptor<>( - "trainDataState", - TypeInformation.of(LabeledPointWithWeight.class))); - coefficientState = - context.getOperatorStateStore() - .getListState( - new ListStateDescriptor<>( - "coefficientState", - TypeInformation.of(DenseVector.class))); - OperatorStateUtils.getUniqueElement(coefficientState, "coefficientState") - .ifPresent(x -> coefficient = x); - feedbackBufferState = - context.getOperatorStateStore() - .getListState( - new ListStateDescriptor<>( - "feedbackBufferState", - PrimitiveArrayTypeInfo - .DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO)); - OperatorStateUtils.getUniqueElement(feedbackBufferState, "feedbackBufferState") - .ifPresent(x -> feedbackBuffer = x); - if (coefficient != null) { - coefficientDim = coefficient.size(); - gradient = new DenseVector(new double[coefficientDim]); - } - } - - @Override - public void snapshotState(StateSnapshotContext context) throws Exception { - coefficientState.clear(); - if (coefficient != null) { - coefficientState.add(coefficient); - } - feedbackBufferState.clear(); - if (feedbackBuffer != null) { - feedbackBufferState.add(feedbackBuffer); - } - } + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java index 3c30a8d..247a4d5 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java @@ -21,7 +21,6 @@ package org.apache.flink.ml.classification.logisticregression; import org.apache.flink.api.common.functions.RichMapFunction; import 
org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.ml.api.Model; import org.apache.flink.ml.common.broadcast.BroadcastUtils; @@ -59,40 +58,6 @@ public class LogisticRegressionModel ParamUtils.initializeMapWithDefaultValues(paramMap, this); } - @Override - public Map<Param<?>, Object> getParamMap() { - return paramMap; - } - - @Override - public void save(String path) throws IOException { - ReadWriteUtils.saveMetadata(this, path); - ReadWriteUtils.saveModelData( - LogisticRegressionModelData.getModelDataStream(modelDataTable), - path, - new LogisticRegressionModelData.ModelDataEncoder()); - } - - public static LogisticRegressionModel load(StreamTableEnvironment tEnv, String path) - throws IOException { - LogisticRegressionModel model = ReadWriteUtils.loadStageParam(path); - Table modelDataTable = - ReadWriteUtils.loadModelData( - tEnv, path, new LogisticRegressionModelData.ModelDataDecoder()); - return model.setModelData(modelDataTable); - } - - @Override - public LogisticRegressionModel setModelData(Table... inputs) { - modelDataTable = inputs[0]; - return this; - } - - @Override - public Table[] getModelData() { - return new Table[] {modelDataTable}; - } - @Override @SuppressWarnings("unchecked") public Table[] transform(Table... inputs) { @@ -127,6 +92,40 @@ public class LogisticRegressionModel return new Table[] {tEnv.fromDataStream(predictionResult)}; } + @Override + public LogisticRegressionModel setModelData(Table... 
inputs) { + modelDataTable = inputs[0]; + return this; + } + + @Override + public Table[] getModelData() { + return new Table[] {modelDataTable}; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + ReadWriteUtils.saveModelData( + LogisticRegressionModelData.getModelDataStream(modelDataTable), + path, + new LogisticRegressionModelData.ModelDataEncoder()); + } + + public static LogisticRegressionModel load(StreamTableEnvironment tEnv, String path) + throws IOException { + LogisticRegressionModel model = ReadWriteUtils.loadStageParam(path); + Table modelDataTable = + ReadWriteUtils.loadModelData( + tEnv, path, new LogisticRegressionModelData.ModelDataDecoder()); + return model.setModelData(modelDataTable); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + /** A utility function used for prediction. */ private static class PredictLabelFunction extends RichMapFunction<Row, Row> { @@ -150,22 +149,21 @@ public class LogisticRegressionModel coefficient = modelData.coefficient; } DenseVector features = (DenseVector) dataPoint.getField(featuresCol); - Tuple2<Double, DenseVector> predictionResult = predictRaw(features, coefficient); - return Row.join(dataPoint, Row.of(predictionResult.f0, predictionResult.f1)); + Row predictionResult = predictOneDataPoint(features, coefficient); + return Row.join(dataPoint, predictionResult); } } /** - * The main logic that predicts one input record. + * The main logic that predicts one input data point. * * @param feature The input feature. * @param coefficient The model parameters. * @return The prediction label and the raw probabilities. 
*/ - private static Tuple2<Double, DenseVector> predictRaw( - DenseVector feature, DenseVector coefficient) { + private static Row predictOneDataPoint(DenseVector feature, DenseVector coefficient) { double dotValue = BLAS.dot(feature, coefficient); double prob = 1 - 1.0 / (1.0 + Math.exp(dotValue)); - return new Tuple2<>(dotValue >= 0 ? 1. : 0., Vectors.dense(1 - prob, prob)); + return Row.of(dotValue >= 0 ? 1. : 0., Vectors.dense(1 - prob, prob)); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionParams.java index c9b5919..f016978 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionParams.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionParams.java @@ -18,6 +18,7 @@ package org.apache.flink.ml.classification.logisticregression; +import org.apache.flink.ml.common.param.HasElasticNet; import org.apache.flink.ml.common.param.HasGlobalBatchSize; import org.apache.flink.ml.common.param.HasLabelCol; import org.apache.flink.ml.common.param.HasLearningRate; @@ -37,6 +38,7 @@ public interface LogisticRegressionParams<T> HasWeightCol<T>, HasMaxIter<T>, HasReg<T>, + HasElasticNet<T>, HasLearningRate<T>, HasGlobalBatchSize<T>, HasTol<T>, diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLoss.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLoss.java new file mode 100644 index 0000000..cd24c06 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLoss.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.lossfunc; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.ml.classification.logisticregression.LogisticRegression; +import org.apache.flink.ml.common.feature.LabeledPointWithWeight; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseVector; + +/** The loss function for binary logistic loss. See {@link LogisticRegression} for example. 
*/ +@Internal +public class BinaryLogisticLoss implements LossFunc { + public static final BinaryLogisticLoss INSTANCE = new BinaryLogisticLoss(); + + private BinaryLogisticLoss() {} + + @Override + public double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient) { + double dot = BLAS.dot(dataPoint.getFeatures(), coefficient); + double labelScaled = 2 * dataPoint.getLabel() - 1; + return dataPoint.getWeight() * Math.log(1 + Math.exp(-dot * labelScaled)); + } + + @Override + public void computeGradient( + LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient) { + double dot = BLAS.dot(dataPoint.getFeatures(), coefficient); + double labelScaled = 2 * dataPoint.getLabel() - 1; + double multiplier = + dataPoint.getWeight() * (-labelScaled / (Math.exp(dot * labelScaled) + 1)); + BLAS.axpy(multiplier, dataPoint.getFeatures(), cumGradient, dataPoint.getFeatures().size()); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LeastSquareLoss.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LeastSquareLoss.java new file mode 100644 index 0000000..ea64649 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LeastSquareLoss.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.lossfunc; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.ml.common.feature.LabeledPointWithWeight; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.regression.linearregression.LinearRegression; + +/** The loss function for least square loss. See {@link LinearRegression} for example. */ +@Internal +public class LeastSquareLoss implements LossFunc { + public static final LeastSquareLoss INSTANCE = new LeastSquareLoss(); + + private LeastSquareLoss() {} + + @Override + public double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient) { + double dot = BLAS.dot(dataPoint.getFeatures(), coefficient); + return dataPoint.getWeight() * 0.5 * Math.pow(dot - dataPoint.getLabel(), 2); + } + + @Override + public void computeGradient( + LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient) { + double dot = BLAS.dot(dataPoint.getFeatures(), coefficient); + BLAS.axpy( + (dot - dataPoint.getLabel()) * dataPoint.getWeight(), + dataPoint.getFeatures(), + cumGradient, + dataPoint.getFeatures().size()); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java new file mode 100644 index 0000000..a90967a --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.lossfunc; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.ml.common.feature.LabeledPointWithWeight; +import org.apache.flink.ml.linalg.DenseVector; + +import java.io.Serializable; + +/** + * A loss function computes the loss and gradient for the given coefficient and training data. + */ +@Internal +public interface LossFunc extends Serializable { + + /** + * Computes the loss on the given data point. + * + * @param dataPoint A training data point. + * @param coefficient The model parameters. + * @return The loss of the input data. + */ + double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient); + + /** + * Computes the gradient on the given data point and adds the computed gradient to cumGradient. + * + * @param dataPoint A training data point. + * @param coefficient The model parameters. + * @param cumGradient The accumulated gradient. 
+ */ + void computeGradient( + LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient); +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/Optimizer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/Optimizer.java new file mode 100644 index 0000000..647741d --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/Optimizer.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.optimizer; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.ml.common.feature.LabeledPointWithWeight; +import org.apache.flink.ml.common.lossfunc.LossFunc; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.streaming.api.datastream.DataStream; + +/** + * An optimizer is a function to modify the weight of a machine learning model, which aims to find + * the optimal parameter configuration for a machine learning model. Examples of optimizers could be + * stochastic gradient descent (SGD), L-BFGS, etc. 
+ */ +@Internal +public interface Optimizer { + /** + * Optimizes the given loss function using the initial model data and the bounded training data. + * + * @param initModelData The initial model data. + * @param trainData The training data. + * @param lossFunc The loss function to optimize. + * @return The fitted model data. + */ + DataStream<DenseVector> optimize( + DataStream<DenseVector> initModelData, + DataStream<LabeledPointWithWeight> trainData, + LossFunc lossFunc); +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/RegularizationUtils.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/RegularizationUtils.java new file mode 100644 index 0000000..3d36d9a --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/RegularizationUtils.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.optimizer; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseVector; + +/** + * A utility class for algorithms that need to handle regularization. 
The regularization term is + * defined as: + * + * <p>elasticNet * reg * norm1(coefficient) + (1 - elasticNet) * (reg/2) * (norm2(coefficient))^2 + * + * <p>See https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.ElasticNet.html. + */ +@Internal +class RegularizationUtils { + + /** + * Regularize the model coefficient. The gradient of each dimension could be computed as: + * {elasticNet * reg * Math.signum(c_i) + (1 - elasticNet) * reg * c_i}. Here c_i is the value of + * coefficient at the i-th dimension. + * + * @param coefficient The model coefficient. + * @param reg The reg param. + * @param elasticNet The elasticNet param. + * @param learningRate The learningRate param. + * @return The loss introduced by regularization. + */ + public static double regularize( + DenseVector coefficient, + final double reg, + final double elasticNet, + final double learningRate) { + + if (Double.compare(reg, 0) == 0) { + return 0; + } else if (Double.compare(elasticNet, 0) == 0) { + // Only L2 regularization. + double loss = reg / 2 * BLAS.norm2(coefficient); + BLAS.scal(1 - learningRate * reg, coefficient); + return loss; + } else if (Double.compare(elasticNet, 1) == 0) { + // Only L1 regularization. + double loss = 0; + double[] coefficientArray = coefficient.values; + for (int i = 0; i < coefficientArray.length; i++) { + if (Double.compare(coefficientArray[i], 0) == 0) { + continue; + } + loss += elasticNet * reg * Math.signum(coefficientArray[i]); + coefficientArray[i] -= + learningRate * elasticNet * reg * Math.signum(coefficientArray[i]); + } + return loss; + } else { + // Both L1 and L2 are not zero. 
+ double loss = 0; + double[] coefficientArray = coefficient.values; + for (int i = 0; i < coefficientArray.length; i++) { + loss += + elasticNet * reg * Math.signum(coefficientArray[i]) + + (1 - elasticNet) + * (reg / 2) + * coefficientArray[i] + * coefficientArray[i]; + coefficientArray[i] -= + (learningRate + * (elasticNet * reg * Math.signum(coefficientArray[i]) + + (1 - elasticNet) * reg * coefficientArray[i])); + } + return loss; + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java new file mode 100644 index 0000000..2f78004 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java @@ -0,0 +1,390 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.optimizer; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.IterationBody; +import org.apache.flink.iteration.IterationBodyResult; +import org.apache.flink.iteration.IterationConfig; +import org.apache.flink.iteration.IterationListener; +import org.apache.flink.iteration.Iterations; +import org.apache.flink.iteration.ReplayableDataStreamList; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.common.feature.LabeledPointWithWeight; +import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol; +import org.apache.flink.ml.common.lossfunc.LossFunc; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.regression.linearregression.LinearRegression; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.Collector; +import org.apache.flink.util.OutputTag; + +import org.apache.commons.collections.IteratorUtils; + +import java.io.Serializable; +import 
java.util.Arrays; +import java.util.List; + +/** + * Stochastic Gradient Descent (SGD) is the most widely used optimizer for optimizing machine + * learning models. It iteratively makes small adjustments to the machine learning model according + * to the gradient at each step, to decrease the error of the model. + * + * <p>See https://en.wikipedia.org/wiki/Stochastic_gradient_descent. + */ +@Internal +public class SGD implements Optimizer { + /** Params for SGD optimizer. */ + private final SGDParams params; + + public SGD( + int maxIter, + double learningRate, + int globalBatchSize, + double tol, + double reg, + double elasticNet) { + this.params = new SGDParams(maxIter, learningRate, globalBatchSize, tol, reg, elasticNet); + } + + @Override + public DataStream<DenseVector> optimize( + DataStream<DenseVector> initModelData, + DataStream<LabeledPointWithWeight> trainData, + LossFunc lossFunc) { + DataStreamList resultList = + Iterations.iterateBoundedStreamsUntilTermination( + DataStreamList.of( + initModelData.broadcast().map(modelVec -> modelVec.values)), + ReplayableDataStreamList.notReplay(trainData.rebalance().map(x -> x)), + IterationConfig.newBuilder().build(), + new TrainIterationBody(lossFunc, params)); + return resultList.get(0); + } + + /** The iteration implementation for training process. */ + private static class TrainIterationBody implements IterationBody { + private final LossFunc lossFunc; + private final SGDParams params; + + public TrainIterationBody(LossFunc lossFunc, SGDParams params) { + this.lossFunc = lossFunc; + this.params = params; + } + + @Override + public IterationBodyResult process( + DataStreamList variableStreams, DataStreamList dataStreams) { + // The variable stream at the first iteration is the initialized model data. + // In the following iterations, it contains: [the model update, totalWeight, and + // totalLoss]. 
+ DataStream<double[]> variableStream = variableStreams.get(0); + DataStream<LabeledPointWithWeight> trainData = dataStreams.get(0); + final OutputTag<DenseVector> modelDataOutputTag = + new OutputTag<DenseVector>("MODEL_OUTPUT") {}; + + SingleOutputStreamOperator<double[]> modelUpdateAndWeightAndLoss = + trainData + .connect(variableStream) + .transform( + "CacheDataAndDoTrain", + PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO, + new CacheDataAndDoTrain(lossFunc, params, modelDataOutputTag)); + + DataStreamList feedbackVariableStream = + IterationBody.forEachRound( + DataStreamList.of(modelUpdateAndWeightAndLoss), + input -> { + DataStream<double[]> feedback = + DataStreamUtils.allReduceSum(input.get(0)); + return DataStreamList.of(feedback); + }); + + DataStream<Integer> terminationCriteria = + feedbackVariableStream + .get(0) + .map( + reducedUpdateAndWeightAndLoss -> { + double[] value = (double[]) reducedUpdateAndWeightAndLoss; + return value[value.length - 1] / value[value.length - 2]; + }) + .flatMap(new TerminateOnMaxIterOrTol(params.maxIter, params.tol)); + + return new IterationBodyResult( + DataStreamList.of(feedbackVariableStream.get(0)), + DataStreamList.of( + modelUpdateAndWeightAndLoss.getSideOutput(modelDataOutputTag)), + terminationCriteria); + } + } + + /** + * A stream operator that caches the training data in the first iteration and updates the model + * iteratively. The first input is the training data, and the second input is the initial model + * data or feedback of model update, totalWeight, and totalLoss. + */ + private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]> + implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>, + IterationListener<double[]> { + /** Optimizer-related parameters. */ + private final SGDParams params; + + /** The loss function to optimize. */ + private final LossFunc lossFunc; + + /** The outputTag to output the model data when iteration ends. 
*/ + private final OutputTag<DenseVector> modelDataOutputTag; + + /** The cached training data. */ + private List<LabeledPointWithWeight> trainData; + + private ListState<LabeledPointWithWeight> trainDataState; + + /** The start index (offset) of the next mini-batch data for training. */ + private int nextBatchOffset = 0; + + private ListState<Integer> nextBatchOffsetState; + + /** The model coefficient. */ + private DenseVector coefficient; + + private ListState<DenseVector> coefficientState; + + /** The dimension of the coefficient. */ + private int coefficientDim; + + /** + * The double array to sync among all workers. For example, when training {@link + * LinearRegression}, this double array consists of [modelUpdate, totalWeight, totalLoss]. + */ + private double[] feedbackArray; + + private ListState<double[]> feedbackArrayState; + + /** The batch size on this partition. */ + private int localBatchSize; + + private CacheDataAndDoTrain( + LossFunc lossFunc, SGDParams params, OutputTag<DenseVector> modelDataOutputTag) { + this.lossFunc = lossFunc; + this.params = params; + this.modelDataOutputTag = modelDataOutputTag; + } + + @Override + public void open() { + int numTasks = getRuntimeContext().getNumberOfParallelSubtasks(); + int taskId = getRuntimeContext().getIndexOfThisSubtask(); + localBatchSize = params.globalBatchSize / numTasks; + if (params.globalBatchSize % numTasks > taskId) { + localBatchSize++; + } + } + + private double getTotalWeight() { + return feedbackArray[coefficientDim]; + } + + private void setTotalWeight(double totalWeight) { + feedbackArray[coefficientDim] = totalWeight; + } + + private double getTotalLoss() { + return feedbackArray[coefficientDim + 1]; + } + + private void setTotalLoss(double totalLoss) { + feedbackArray[coefficientDim + 1] = totalLoss; + } + + private void updateModel() { + if (getTotalWeight() > 0) { + BLAS.axpy( + -params.learningRate / getTotalWeight(), + new DenseVector(feedbackArray), + coefficient, + 
coefficientDim); + double regLoss = + RegularizationUtils.regularize( + coefficient, params.reg, params.elasticNet, params.learningRate); + setTotalLoss(getTotalLoss() + regLoss); + } + } + + @Override + public void onEpochWatermarkIncremented( + int epochWatermark, Context context, Collector<double[]> collector) + throws Exception { + if (epochWatermark == 0) { + coefficient = new DenseVector(feedbackArray); + coefficientDim = coefficient.size(); + feedbackArray = new double[coefficient.size() + 2]; + } else { + updateModel(); + } + + if (trainData == null) { + trainData = IteratorUtils.toList(trainDataState.get().iterator()); + } + + // TODO: supports efficient shuffle of training set on each partition. + if (trainData.size() > 0) { + List<LabeledPointWithWeight> miniBatchData = + trainData.subList( + nextBatchOffset, + Math.min(nextBatchOffset + localBatchSize, trainData.size())); + nextBatchOffset += localBatchSize; + nextBatchOffset = nextBatchOffset >= trainData.size() ? 0 : nextBatchOffset; + + // Does the training. 
+ Arrays.fill(feedbackArray, 0); + double totalLoss = 0; + double totalWeight = 0; + DenseVector cumGradientsWrapper = new DenseVector(feedbackArray); + for (LabeledPointWithWeight dataPoint : miniBatchData) { + totalLoss += lossFunc.computeLoss(dataPoint, coefficient); + lossFunc.computeGradient(dataPoint, coefficient, cumGradientsWrapper); + totalWeight += dataPoint.getWeight(); + } + setTotalLoss(totalLoss); + setTotalWeight(totalWeight); + + collector.collect(feedbackArray); + } + } + + @Override + public void onIterationTerminated(Context context, Collector<double[]> collector) { + trainDataState.clear(); + if (getRuntimeContext().getIndexOfThisSubtask() == 0) { + updateModel(); + context.output(modelDataOutputTag, coefficient); + } + } + + @Override + public void processElement1(StreamRecord<LabeledPointWithWeight> streamRecord) + throws Exception { + trainDataState.add(streamRecord.getValue()); + } + + @Override + public void processElement2(StreamRecord<double[]> streamRecord) { + feedbackArray = streamRecord.getValue(); + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + coefficientState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "coefficientState", DenseVectorTypeInfo.INSTANCE)); + OperatorStateUtils.getUniqueElement(coefficientState, "coefficientState") + .ifPresent(x -> coefficient = x); + if (coefficient != null) { + coefficientDim = coefficient.size(); + } + + feedbackArrayState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "feedbackArrayState", + PrimitiveArrayTypeInfo + .DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO)); + OperatorStateUtils.getUniqueElement(feedbackArrayState, "feedbackArrayState") + .ifPresent(x -> feedbackArray = x); + + trainDataState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "trainDataState", + 
TypeInformation.of(LabeledPointWithWeight.class))); + + nextBatchOffsetState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "nextBatchOffsetState", BasicTypeInfo.INT_TYPE_INFO)); + nextBatchOffset = + OperatorStateUtils.getUniqueElement( + nextBatchOffsetState, "nextBatchOffsetState") + .orElse(0); + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + coefficientState.clear(); + if (coefficient != null) { + coefficientState.add(coefficient); + } + + feedbackArrayState.clear(); + if (feedbackArray != null) { + feedbackArrayState.add(feedbackArray); + } + + nextBatchOffsetState.clear(); + nextBatchOffsetState.add(nextBatchOffset); + } + } + + /** Parameters for {@link SGD}. */ + private static class SGDParams implements Serializable { + public final int maxIter; + public final double learningRate; + public final int globalBatchSize; + public final double tol; + public final double reg; + public final double elasticNet; + + private SGDParams( + int maxIter, + double learningRate, + int globalBatchSize, + double tol, + double reg, + double elasticNet) { + this.maxIter = maxIter; + this.learningRate = learningRate; + this.globalBatchSize = globalBatchSize; + this.tol = tol; + this.reg = reg; + this.elasticNet = elasticNet; + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasElasticNet.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasElasticNet.java new file mode 100644 index 0000000..467308a --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasElasticNet.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.param; + +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.WithParams; + +/** + * Interface for the shared elasticNet param, which specifies the mixing of L1 and L2 penalty: + * + * <ul> + * <li>If the value is zero, it is L2 penalty. + * <li>If the value is one, it is L1 penalty. + * <li>For value in (0,1), it is a combination of L1 and L2 penalty. + * </ul> + */ +public interface HasElasticNet<T> extends WithParams<T> { + Param<Double> ELASTIC_NET = + new DoubleParam( + "elasticNet", "ElasticNet parameter.", 0.0, ParamValidators.inRange(0.0, 1.0)); + + default double getElasticNet() { + return get(ELASTIC_NET); + } + + default T setElasticNet(Double value) { + return set(ELASTIC_NET, value); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegression.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegression.java new file mode 100644 index 0000000..7edffc4 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegression.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.regression.linearregression; + +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.common.feature.LabeledPointWithWeight; +import org.apache.flink.ml.common.lossfunc.LeastSquareLoss; +import org.apache.flink.ml.common.optimizer.Optimizer; +import org.apache.flink.ml.common.optimizer.SGD; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +/** + * An Estimator which implements the linear regression algorithm. + * + * <p>See https://en.wikipedia.org/wiki/Linear_regression. 
+ */ +public class LinearRegression + implements Estimator<LinearRegression, LinearRegressionModel>, + LinearRegressionParams<LinearRegression> { + + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public LinearRegression() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + @SuppressWarnings({"rawTypes", "ConstantConditions"}) + public LinearRegressionModel fit(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<LabeledPointWithWeight> trainData = + tEnv.toDataStream(inputs[0]) + .map( + dataPoint -> { + double weight = + getWeightCol() == null + ? 1.0 + : (Double) dataPoint.getField(getWeightCol()); + double label = (Double) dataPoint.getField(getLabelCol()); + DenseVector features = + (DenseVector) dataPoint.getField(getFeaturesCol()); + return new LabeledPointWithWeight(features, label, weight); + }); + + DataStream<DenseVector> initModelData = + DataStreamUtils.reduce( + trainData.map(x -> x.getFeatures().size()), + (ReduceFunction<Integer>) + (t0, t1) -> { + Preconditions.checkState( + t0.equals(t1), + "The training data should all have same dimensions."); + return t0; + }) + .map(DenseVector::new); + + Optimizer optimizer = + new SGD( + getMaxIter(), + getLearningRate(), + getGlobalBatchSize(), + getTol(), + getReg(), + getElasticNet()); + DataStream<DenseVector> rawModelData = + optimizer.optimize(initModelData, trainData, LeastSquareLoss.INSTANCE); + + DataStream<LinearRegressionModelData> modelData = + rawModelData.map(LinearRegressionModelData::new); + LinearRegressionModel model = + new LinearRegressionModel().setModelData(tEnv.fromDataStream(modelData)); + ReadWriteUtils.updateExistingParams(model, paramMap); + return model; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static 
LinearRegression load(StreamTableEnvironment tEnv, String path) + throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModel.java similarity index 68% copy from flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java copy to flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModel.java index 3c30a8d..7fb9019 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModel.java @@ -16,19 +16,16 @@ * limitations under the License. 
*/ -package org.apache.flink.ml.classification.logisticregression; +package org.apache.flink.ml.regression.linearregression; import org.apache.flink.api.common.functions.RichMapFunction; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; -import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.ml.api.Model; import org.apache.flink.ml.common.broadcast.BroadcastUtils; import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.linalg.BLAS; import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -46,53 +43,19 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; -/** A Model which classifies data using the model data computed by {@link LogisticRegression}. */ -public class LogisticRegressionModel - implements Model<LogisticRegressionModel>, - LogisticRegressionModelParams<LogisticRegressionModel> { +/** A Model which predicts data using the model data computed by {@link LinearRegression}. 
*/ +public class LinearRegressionModel + implements Model<LinearRegressionModel>, + LinearRegressionModelParams<LinearRegressionModel> { private final Map<Param<?>, Object> paramMap = new HashMap<>(); private Table modelDataTable; - public LogisticRegressionModel() { + public LinearRegressionModel() { ParamUtils.initializeMapWithDefaultValues(paramMap, this); } - @Override - public Map<Param<?>, Object> getParamMap() { - return paramMap; - } - - @Override - public void save(String path) throws IOException { - ReadWriteUtils.saveMetadata(this, path); - ReadWriteUtils.saveModelData( - LogisticRegressionModelData.getModelDataStream(modelDataTable), - path, - new LogisticRegressionModelData.ModelDataEncoder()); - } - - public static LogisticRegressionModel load(StreamTableEnvironment tEnv, String path) - throws IOException { - LogisticRegressionModel model = ReadWriteUtils.loadStageParam(path); - Table modelDataTable = - ReadWriteUtils.loadModelData( - tEnv, path, new LogisticRegressionModelData.ModelDataDecoder()); - return model.setModelData(modelDataTable); - } - - @Override - public LogisticRegressionModel setModelData(Table... inputs) { - modelDataTable = inputs[0]; - return this; - } - - @Override - public Table[] getModelData() { - return new Table[] {modelDataTable}; - } - @Override @SuppressWarnings("unchecked") public Table[] transform(Table... 
inputs) { @@ -101,19 +64,14 @@ public class LogisticRegressionModel (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); DataStream<Row> inputStream = tEnv.toDataStream(inputs[0]); final String broadcastModelKey = "broadcastModelKey"; - DataStream<LogisticRegressionModelData> modelDataStream = - LogisticRegressionModelData.getModelDataStream(modelDataTable); + DataStream<LinearRegressionModelData> modelDataStream = + LinearRegressionModelData.getModelDataStream(modelDataTable); RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); RowTypeInfo outputTypeInfo = new RowTypeInfo( ArrayUtils.addAll( - inputTypeInfo.getFieldTypes(), - BasicTypeInfo.DOUBLE_TYPE_INFO, - TypeInformation.of(DenseVector.class)), - ArrayUtils.addAll( - inputTypeInfo.getFieldNames(), - getPredictionCol(), - getRawPredictionCol())); + inputTypeInfo.getFieldTypes(), BasicTypeInfo.DOUBLE_TYPE_INFO), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol())); DataStream<Row> predictionResult = BroadcastUtils.withBroadcastStream( Collections.singletonList(inputStream), @@ -127,6 +85,40 @@ public class LogisticRegressionModel return new Table[] {tEnv.fromDataStream(predictionResult)}; } + @Override + public LinearRegressionModel setModelData(Table... 
inputs) { + modelDataTable = inputs[0]; + return this; + } + + @Override + public Table[] getModelData() { + return new Table[] {modelDataTable}; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + ReadWriteUtils.saveModelData( + LinearRegressionModelData.getModelDataStream(modelDataTable), + path, + new LinearRegressionModelData.ModelDataEncoder()); + } + + public static LinearRegressionModel load(StreamTableEnvironment tEnv, String path) + throws IOException { + LinearRegressionModel model = ReadWriteUtils.loadStageParam(path); + Table modelDataTable = + ReadWriteUtils.loadModelData( + tEnv, path, new LinearRegressionModelData.ModelDataDecoder()); + return model.setModelData(modelDataTable); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + /** A utility function used for prediction. */ private static class PredictLabelFunction extends RichMapFunction<Row, Row> { @@ -144,28 +136,25 @@ public class LogisticRegressionModel @Override public Row map(Row dataPoint) { if (coefficient == null) { - LogisticRegressionModelData modelData = - (LogisticRegressionModelData) + LinearRegressionModelData modelData = + (LinearRegressionModelData) getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0); coefficient = modelData.coefficient; } DenseVector features = (DenseVector) dataPoint.getField(featuresCol); - Tuple2<Double, DenseVector> predictionResult = predictRaw(features, coefficient); - return Row.join(dataPoint, Row.of(predictionResult.f0, predictionResult.f1)); + Row predictionResult = predictOneDataPoint(features, coefficient); + return Row.join(dataPoint, predictionResult); } } /** - * The main logic that predicts one input record. + * The main logic that predicts one input data point. * * @param feature The input feature. * @param coefficient The model parameters. * @return The prediction label and the raw probabilities. 
*/ - private static Tuple2<Double, DenseVector> predictRaw( - DenseVector feature, DenseVector coefficient) { - double dotValue = BLAS.dot(feature, coefficient); - double prob = 1 - 1.0 / (1.0 + Math.exp(dotValue)); - return new Tuple2<>(dotValue >= 0 ? 1. : 0., Vectors.dense(1 - prob, prob)); + private static Row predictOneDataPoint(DenseVector feature, DenseVector coefficient) { + return Row.of(BLAS.dot(feature, coefficient)); } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModelData.java new file mode 100644 index 0000000..278d5b3 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModelData.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.regression.linearregression; + +import org.apache.flink.api.common.serialization.Encoder; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.connector.file.src.reader.SimpleStreamFormat; +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; + +import java.io.EOFException; +import java.io.IOException; +import java.io.OutputStream; + +/** + * Model data of {@link LinearRegressionModel}. + * + * <p>This class also provides methods to convert model data from Table to Datastream, and classes + * to save/load model data. + */ +public class LinearRegressionModelData { + + public DenseVector coefficient; + + public LinearRegressionModelData(DenseVector coefficient) { + this.coefficient = coefficient; + } + + public LinearRegressionModelData() {} + + /** + * Converts the table model to a data stream. + * + * @param modelData The table model data. + * @return The data stream model data. + */ + public static DataStream<LinearRegressionModelData> getModelDataStream(Table modelData) { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment(); + return tEnv.toDataStream(modelData) + .map(x -> new LinearRegressionModelData((DenseVector) x.getField(0))); + } + + /** Data encoder for {@link LinearRegressionModel}. 
*/ + public static class ModelDataEncoder implements Encoder<LinearRegressionModelData> { + + @Override + public void encode(LinearRegressionModelData modelData, OutputStream outputStream) + throws IOException { + DenseVectorSerializer.INSTANCE.serialize( + modelData.coefficient, new DataOutputViewStreamWrapper(outputStream)); + } + } + + /** Data decoder for {@link LinearRegressionModel}. */ + public static class ModelDataDecoder extends SimpleStreamFormat<LinearRegressionModelData> { + + @Override + public Reader<LinearRegressionModelData> createReader( + Configuration configuration, FSDataInputStream inputStream) { + return new Reader<LinearRegressionModelData>() { + + @Override + public LinearRegressionModelData read() throws IOException { + try { + DenseVector coefficient = + DenseVectorSerializer.INSTANCE.deserialize( + new DataInputViewStreamWrapper(inputStream)); + return new LinearRegressionModelData(coefficient); + } catch (EOFException e) { + return null; + } + } + + @Override + public void close() throws IOException { + inputStream.close(); + } + }; + } + + @Override + public TypeInformation<LinearRegressionModelData> getProducedType() { + return TypeInformation.of(LinearRegressionModelData.class); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModelParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModelParams.java new file mode 100644 index 0000000..687fed6 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModelParams.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.regression.linearregression; + +import org.apache.flink.ml.common.param.HasFeaturesCol; +import org.apache.flink.ml.common.param.HasPredictionCol; + +/** + * Params for {@link LinearRegressionModel}. + * + * @param <T> The class type of this instance. + */ +public interface LinearRegressionModelParams<T> extends HasFeaturesCol<T>, HasPredictionCol<T> {} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionParams.java similarity index 83% copy from flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionParams.java copy to flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionParams.java index c9b5919..a5ca6cc 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionParams.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionParams.java @@ -16,29 +16,29 @@ * limitations under the License. 
*/ -package org.apache.flink.ml.classification.logisticregression; +package org.apache.flink.ml.regression.linearregression; +import org.apache.flink.ml.common.param.HasElasticNet; import org.apache.flink.ml.common.param.HasGlobalBatchSize; import org.apache.flink.ml.common.param.HasLabelCol; import org.apache.flink.ml.common.param.HasLearningRate; import org.apache.flink.ml.common.param.HasMaxIter; -import org.apache.flink.ml.common.param.HasMultiClass; import org.apache.flink.ml.common.param.HasReg; import org.apache.flink.ml.common.param.HasTol; import org.apache.flink.ml.common.param.HasWeightCol; /** - * Params for {@link LogisticRegression}. + * Params for {@link LinearRegression}. * * @param <T> The class type of this instance. */ -public interface LogisticRegressionParams<T> +public interface LinearRegressionParams<T> extends HasLabelCol<T>, HasWeightCol<T>, HasMaxIter<T>, HasReg<T>, + HasElasticNet<T>, HasLearningRate<T>, HasGlobalBatchSize<T>, HasTol<T>, - HasMultiClass<T>, - LogisticRegressionModelParams<T> {} + LinearRegressionModelParams<T> {} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java index 88882c6..8d06613 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java @@ -28,6 +28,7 @@ import org.apache.flink.ml.classification.logisticregression.LogisticRegressionM import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelData; import org.apache.flink.ml.linalg.DenseVector; import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; import org.apache.flink.ml.util.ReadWriteUtils; import org.apache.flink.ml.util.StageTestUtils; import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; @@ -37,6 +38,7 @@ import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; import org.apache.flink.types.Row; import org.apache.commons.collections.IteratorUtils; +import org.apache.commons.lang3.RandomUtils; import org.apache.commons.lang3.exception.ExceptionUtils; import org.junit.Before; import org.junit.Rule; @@ -49,7 +51,6 @@ import java.util.List; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -90,7 +91,7 @@ public class LogisticRegressionTest { Row.of(Vectors.dense(15, 2, 3, 4), 1., 5.)); private static final double[] expectedCoefficient = - new double[] {0.528, -0.286, -0.429, -0.572}; + new double[] {0.525, -0.283, -0.425, -0.567}; private static final double TOLERANCE = 1e-7; @@ -114,9 +115,7 @@ public class LogisticRegressionTest { binomialTrainData, new RowTypeInfo( new TypeInformation[] { - TypeInformation.of(DenseVector.class), - Types.DOUBLE, - Types.DOUBLE + DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE }, new String[] {"features", "label", "weight"}))); multinomialDataTable = @@ -125,14 +124,12 @@ public class LogisticRegressionTest { multinomialTrainData, new RowTypeInfo( new TypeInformation[] { - TypeInformation.of(DenseVector.class), - Types.DOUBLE, - Types.DOUBLE + DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE }, new String[] {"features", "label", "weight"}))); } - @SuppressWarnings("ConstantConditions") + @SuppressWarnings("ConstantConditions, unchecked") private void verifyPredictionResult( Table output, String featuresCol, String predictionCol, String rawPredictionCol) throws Exception { @@ -154,17 +151,18 @@ public class LogisticRegressionTest { @Test public void testParam() { LogisticRegression logisticRegression = 
new LogisticRegression(); - assertEquals(logisticRegression.getLabelCol(), "label"); + assertEquals("label", logisticRegression.getLabelCol()); assertNull(logisticRegression.getWeightCol()); - assertEquals(logisticRegression.getMaxIter(), 20); - assertEquals(logisticRegression.getReg(), 0, TOLERANCE); - assertEquals(logisticRegression.getLearningRate(), 0.1, TOLERANCE); - assertEquals(logisticRegression.getGlobalBatchSize(), 32); - assertEquals(logisticRegression.getTol(), 1e-6, TOLERANCE); - assertEquals(logisticRegression.getMultiClass(), "auto"); - assertEquals(logisticRegression.getFeaturesCol(), "features"); - assertEquals(logisticRegression.getPredictionCol(), "prediction"); - assertEquals(logisticRegression.getRawPredictionCol(), "rawPrediction"); + assertEquals(20, logisticRegression.getMaxIter()); + assertEquals(0, logisticRegression.getReg(), TOLERANCE); + assertEquals(0, logisticRegression.getElasticNet(), TOLERANCE); + assertEquals(0.1, logisticRegression.getLearningRate(), TOLERANCE); + assertEquals(32, logisticRegression.getGlobalBatchSize()); + assertEquals(1e-6, logisticRegression.getTol(), TOLERANCE); + assertEquals("auto", logisticRegression.getMultiClass()); + assertEquals("features", logisticRegression.getFeaturesCol()); + assertEquals("prediction", logisticRegression.getPredictionCol()); + assertEquals("rawPrediction", logisticRegression.getRawPredictionCol()); logisticRegression .setFeaturesCol("test_features") @@ -175,20 +173,22 @@ public class LogisticRegressionTest { .setLearningRate(0.5) .setGlobalBatchSize(1000) .setReg(0.1) + .setElasticNet(0.5) .setMultiClass("binomial") .setPredictionCol("test_predictionCol") .setRawPredictionCol("test_rawPredictionCol"); - assertEquals(logisticRegression.getFeaturesCol(), "test_features"); - assertEquals(logisticRegression.getLabelCol(), "test_label"); - assertEquals(logisticRegression.getWeightCol(), "test_weight"); - assertEquals(logisticRegression.getMaxIter(), 1000); - 
assertEquals(logisticRegression.getTol(), 0.001, TOLERANCE); - assertEquals(logisticRegression.getLearningRate(), 0.5, TOLERANCE); - assertEquals(logisticRegression.getGlobalBatchSize(), 1000); - assertEquals(logisticRegression.getReg(), 0.1, TOLERANCE); - assertEquals(logisticRegression.getMultiClass(), "binomial"); - assertEquals(logisticRegression.getPredictionCol(), "test_predictionCol"); - assertEquals(logisticRegression.getRawPredictionCol(), "test_rawPredictionCol"); + assertEquals("test_features", logisticRegression.getFeaturesCol()); + assertEquals("test_label", logisticRegression.getLabelCol()); + assertEquals("test_weight", logisticRegression.getWeightCol()); + assertEquals(1000, logisticRegression.getMaxIter()); + assertEquals(0.001, logisticRegression.getTol(), TOLERANCE); + assertEquals(0.5, logisticRegression.getLearningRate(), TOLERANCE); + assertEquals(1000, logisticRegression.getGlobalBatchSize()); + assertEquals(0.1, logisticRegression.getReg(), TOLERANCE); + assertEquals(0.5, logisticRegression.getElasticNet(), TOLERANCE); + assertEquals("binomial", logisticRegression.getMultiClass()); + assertEquals("test_predictionCol", logisticRegression.getPredictionCol()); + assertEquals("test_rawPredictionCol", logisticRegression.getRawPredictionCol()); } @Test @@ -243,15 +243,16 @@ public class LogisticRegressionTest { } @Test + @SuppressWarnings("unchecked") public void testGetModelData() throws Exception { LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight"); LogisticRegressionModel model = logisticRegression.fit(binomialDataTable); - LogisticRegressionModelData modelData = - LogisticRegressionModelData.getModelDataStream(model.getModelData()[0]) - .executeAndCollect() - .next(); - assertNotNull(modelData); - assertArrayEquals(expectedCoefficient, modelData.coefficient.values, 0.1); + List<LogisticRegressionModelData> modelData = + IteratorUtils.toList( + 
LogisticRegressionModelData.getModelDataStream(model.getModelData()[0]) + .executeAndCollect()); + assertEquals(1, modelData.size()); + assertArrayEquals(expectedCoefficient, modelData.get(0).coefficient.values, 0.1); } @Test @@ -286,7 +287,8 @@ public class LogisticRegressionTest { @Test public void testMoreSubtaskThanData() throws Exception { env.setParallelism(12); - LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight"); + LogisticRegression logisticRegression = + new LogisticRegression().setWeightCol("weight").setGlobalBatchSize(128); Table output = logisticRegression.fit(binomialDataTable).transform(binomialDataTable)[0]; verifyPredictionResult( output, @@ -294,4 +296,29 @@ public class LogisticRegressionTest { logisticRegression.getPredictionCol(), logisticRegression.getRawPredictionCol()); } + + @Test + public void testRegularization() throws Exception { + checkRegularization(0, RandomUtils.nextDouble(0, 1), expectedCoefficient); + checkRegularization(0.1, 0, new double[] {0.484, -0.258, -0.388, -0.517}); + checkRegularization(0.1, 1, new double[] {0.417, -0.145, -0.312, -0.480}); + checkRegularization(0.1, 0.5, new double[] {0.451, -0.203, -0.351, -0.498}); + } + + @SuppressWarnings("unchecked") + private void checkRegularization(double reg, double elasticNet, double[] expectedCoefficient) + throws Exception { + LogisticRegressionModel model = + new LogisticRegression() + .setWeightCol("weight") + .setReg(reg) + .setElasticNet(elasticNet) + .fit(binomialDataTable); + List<LogisticRegressionModelData> modelData = + IteratorUtils.toList( + LogisticRegressionModelData.getModelDataStream(model.getModelData()[0]) + .executeAndCollect()); + final double errorTol = 1e-3; + assertArrayEquals(expectedCoefficient, modelData.get(0).coefficient.values, errorTol); + } } diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLossTest.java 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLossTest.java new file mode 100644 index 0000000..22ce2de --- /dev/null +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLossTest.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.lossfunc; + +import org.apache.flink.ml.common.feature.LabeledPointWithWeight; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; + +import org.junit.Test; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +/** Tests {@link BinaryLogisticLoss}. 
*/ +public class BinaryLogisticLossTest { + private static final LabeledPointWithWeight dataPoint = + new LabeledPointWithWeight(Vectors.dense(1.0, 2.0, 3.0), 1.0, 2.0); + private static final DenseVector coefficient = Vectors.dense(1.0, 1.0, 1.0); + private static final DenseVector cumGradient = Vectors.dense(0.0, 0.0, 0.0); + private static final double TOLERANCE = 1e-7; + + @Test + public void computeLoss() { + double loss = BinaryLogisticLoss.INSTANCE.computeLoss(dataPoint, coefficient); + assertEquals(0.0049513, loss, TOLERANCE); + } + + @Test + public void computeGradient() { + BinaryLogisticLoss.INSTANCE.computeGradient(dataPoint, coefficient, cumGradient); + assertArrayEquals( + new double[] {-0.0049452, -0.0098904, -0.0148357}, cumGradient.values, TOLERANCE); + BinaryLogisticLoss.INSTANCE.computeGradient(dataPoint, coefficient, cumGradient); + assertArrayEquals( + new double[] {-0.0098904, -0.0197809, -0.0296714}, cumGradient.values, TOLERANCE); + } +} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/LeastSquareLossTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/LeastSquareLossTest.java new file mode 100644 index 0000000..ee2d030 --- /dev/null +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/LeastSquareLossTest.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.lossfunc; + +import org.apache.flink.ml.common.feature.LabeledPointWithWeight; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; + +import org.junit.Test; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +/** Tests {@link LeastSquareLoss}. */ +public class LeastSquareLossTest { + private static final LabeledPointWithWeight dataPoint = + new LabeledPointWithWeight(Vectors.dense(1.0, 2.0, 3.0), 1.0, 2.0); + private static final DenseVector coefficient = Vectors.dense(1.0, 1.0, 1.0); + private static final DenseVector cumGradient = Vectors.dense(0.0, 0.0, 0.0); + private static final double TOLERANCE = 1e-7; + + @Test + public void computeLoss() { + double loss = LeastSquareLoss.INSTANCE.computeLoss(dataPoint, coefficient); + assertEquals(25.0, loss, TOLERANCE); + } + + @Test + public void computeGradient() { + LeastSquareLoss.INSTANCE.computeGradient(dataPoint, coefficient, cumGradient); + assertArrayEquals(new double[] {10.0, 20.0, 30.0}, cumGradient.values, TOLERANCE); + LeastSquareLoss.INSTANCE.computeGradient(dataPoint, coefficient, cumGradient); + assertArrayEquals(new double[] {20.0, 40.0, 60.0}, cumGradient.values, TOLERANCE); + } +} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/optimizer/RegularizationUtilsTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/optimizer/RegularizationUtilsTest.java new file mode 100644 index 0000000..83d6883 --- /dev/null +++ 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/optimizer/RegularizationUtilsTest.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.optimizer; + +import org.apache.flink.ml.linalg.DenseVector; + +import org.apache.commons.lang3.RandomUtils; +import org.junit.Test; + +import static org.junit.Assert.assertArrayEquals; + +/** Tests {@link RegularizationUtils}. 
*/ +public class RegularizationUtilsTest { + private static final double learningRate = 0.1; + private static final double TOLERANCE = 1e-7; + private static final DenseVector coefficient = new DenseVector(new double[] {1.0, -2.0, 0}); + + @Test + public void testRegularization() { + checkRegularization(0, RandomUtils.nextDouble(0, 1), new double[] {1, -2.0, 0}); + checkRegularization(0.1, 0, new double[] {0.99, -1.98, 0}); + checkRegularization(0.1, 1, new double[] {0.99, -1.99, 0}); + checkRegularization(0.1, 0.1, new double[] {0.99, -1.981, 0}); + } + + private void checkRegularization(double reg, double elasticNet, double[] expectedCoefficient) { + DenseVector clonedCoefficient = coefficient.clone(); + RegularizationUtils.regularize(clonedCoefficient, reg, elasticNet, learningRate); + assertArrayEquals(expectedCoefficient, clonedCoefficient.values, TOLERANCE); + } +} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java new file mode 100644 index 0000000..58e89c4 --- /dev/null +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.regression; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.regression.linearregression.LinearRegression; +import org.apache.flink.ml.regression.linearregression.LinearRegressionModel; +import org.apache.flink.ml.regression.linearregression.LinearRegressionModelData; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.ml.util.StageTestUtils; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.apache.commons.lang3.RandomUtils; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +/** Tests {@link LinearRegression} and {@link LinearRegressionModel}. 
*/ +public class LinearRegressionTest { + + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + + private StreamExecutionEnvironment env; + + private StreamTableEnvironment tEnv; + + private static final List<Row> trainData = + Arrays.asList( + Row.of(Vectors.dense(2, 1), 4.0, 1.0), + Row.of(Vectors.dense(3, 2), 7.0, 1.0), + Row.of(Vectors.dense(4, 3), 10.0, 1.0), + Row.of(Vectors.dense(2, 4), 10.0, 1.0), + Row.of(Vectors.dense(2, 2), 6.0, 1.0), + Row.of(Vectors.dense(4, 3), 10.0, 1.0), + Row.of(Vectors.dense(1, 2), 5.0, 1.0), + Row.of(Vectors.dense(5, 3), 11.0, 1.0)); + + private static final double[] expectedCoefficient = new double[] {1.141, 1.829}; + + private static final double TOLERANCE = 1e-7; + + private static final double PREDICTION_TOLERANCE = 0.1; + + private static final double COEFFICIENT_TOLERANCE = 0.1; + + private Table trainDataTable; + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + Collections.shuffle(trainData); + trainDataTable = + tEnv.fromDataStream( + env.fromCollection( + trainData, + new RowTypeInfo( + new TypeInformation[] { + DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE + }, + new String[] {"features", "label", "weight"}))); + } + + @SuppressWarnings("unchecked") + private void verifyPredictionResult(Table output, String labelCol, String predictionCol) + throws Exception { + List<Row> predResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + for (Row predictionRow : predResult) { + double label = (double) predictionRow.getField(labelCol); + double prediction = (double) predictionRow.getField(predictionCol); + 
assertTrue(Math.abs(prediction - label) / label < PREDICTION_TOLERANCE); + } + } + + @Test + public void testParam() { + LinearRegression linearRegression = new LinearRegression(); + assertEquals("label", linearRegression.getLabelCol()); + assertNull(linearRegression.getWeightCol()); + assertEquals(20, linearRegression.getMaxIter()); + assertEquals(0, linearRegression.getReg(), TOLERANCE); + assertEquals(0, linearRegression.getElasticNet(), TOLERANCE); + assertEquals(0.1, linearRegression.getLearningRate(), TOLERANCE); + assertEquals(32, linearRegression.getGlobalBatchSize()); + assertEquals(1e-6, linearRegression.getTol(), TOLERANCE); + assertEquals("features", linearRegression.getFeaturesCol()); + assertEquals("prediction", linearRegression.getPredictionCol()); + + linearRegression + .setFeaturesCol("test_features") + .setLabelCol("test_label") + .setWeightCol("test_weight") + .setMaxIter(1000) + .setTol(0.001) + .setLearningRate(0.5) + .setGlobalBatchSize(1000) + .setReg(0.1) + .setElasticNet(0.5) + .setPredictionCol("test_predictionCol"); + assertEquals("test_features", linearRegression.getFeaturesCol()); + assertEquals("test_label", linearRegression.getLabelCol()); + assertEquals("test_weight", linearRegression.getWeightCol()); + assertEquals(1000, linearRegression.getMaxIter()); + assertEquals(0.001, linearRegression.getTol(), TOLERANCE); + assertEquals(0.5, linearRegression.getLearningRate(), TOLERANCE); + assertEquals(1000, linearRegression.getGlobalBatchSize()); + assertEquals(0.1, linearRegression.getReg(), TOLERANCE); + assertEquals(0.5, linearRegression.getElasticNet(), TOLERANCE); + assertEquals("test_predictionCol", linearRegression.getPredictionCol()); + } + + @Test + public void testOutputSchema() { + Table tempTable = trainDataTable.as("test_features", "test_label", "test_weight"); + LinearRegression linearRegression = + new LinearRegression() + .setFeaturesCol("test_features") + .setLabelCol("test_label") + .setWeightCol("test_weight") + 
.setPredictionCol("test_predictionCol"); + Table output = linearRegression.fit(trainDataTable).transform(tempTable)[0]; + assertEquals( + Arrays.asList("test_features", "test_label", "test_weight", "test_predictionCol"), + output.getResolvedSchema().getColumnNames()); + } + + @Test + public void testFitAndPredict() throws Exception { + LinearRegression linearRegression = new LinearRegression().setWeightCol("weight"); + Table output = linearRegression.fit(trainDataTable).transform(trainDataTable)[0]; + verifyPredictionResult( + output, linearRegression.getLabelCol(), linearRegression.getPredictionCol()); + } + + @Test + public void testSaveLoadAndPredict() throws Exception { + LinearRegression linearRegression = new LinearRegression().setWeightCol("weight"); + linearRegression = + StageTestUtils.saveAndReload( + tEnv, linearRegression, tempFolder.newFolder().getAbsolutePath()); + LinearRegressionModel model = linearRegression.fit(trainDataTable); + model = StageTestUtils.saveAndReload(tEnv, model, tempFolder.newFolder().getAbsolutePath()); + assertEquals( + Collections.singletonList("coefficient"), + model.getModelData()[0].getResolvedSchema().getColumnNames()); + Table output = model.transform(trainDataTable)[0]; + verifyPredictionResult( + output, linearRegression.getLabelCol(), linearRegression.getPredictionCol()); + } + + @Test + public void testGetModelData() throws Exception { + LinearRegression linearRegression = new LinearRegression().setWeightCol("weight"); + LinearRegressionModel model = linearRegression.fit(trainDataTable); + List<LinearRegressionModelData> modelData = + IteratorUtils.toList( + LinearRegressionModelData.getModelDataStream(model.getModelData()[0]) + .executeAndCollect()); + assertNotNull(modelData); + assertEquals(1, modelData.size()); + assertArrayEquals( + expectedCoefficient, modelData.get(0).coefficient.values, COEFFICIENT_TOLERANCE); + } + + @Test + public void testSetModelData() throws Exception { + LinearRegression linearRegression 
= new LinearRegression().setWeightCol("weight"); + LinearRegressionModel model = linearRegression.fit(trainDataTable); + + LinearRegressionModel newModel = new LinearRegressionModel(); + ReadWriteUtils.updateExistingParams(newModel, model.getParamMap()); + newModel.setModelData(model.getModelData()); + Table output = newModel.transform(trainDataTable)[0]; + verifyPredictionResult( + output, linearRegression.getLabelCol(), linearRegression.getPredictionCol()); + } + + @Test + public void testMoreSubtaskThanData() throws Exception { + env.setParallelism(12); + LinearRegression linearRegression = + new LinearRegression().setWeightCol("weight").setGlobalBatchSize(128); + Table output = linearRegression.fit(trainDataTable).transform(trainDataTable)[0]; + verifyPredictionResult( + output, linearRegression.getLabelCol(), linearRegression.getPredictionCol()); + } + + @Test + public void testRegularization() throws Exception { + checkRegularization(0, RandomUtils.nextDouble(0, 1), expectedCoefficient); + checkRegularization(0.1, 0, new double[] {1.165, 1.780}); + checkRegularization(0.1, 1, new double[] {1.143, 1.812}); + checkRegularization(0.1, 0.5, new double[] {1.154, 1.796}); + } + + @SuppressWarnings("unchecked") + private void checkRegularization(double reg, double elasticNet, double[] expectedCoefficient) + throws Exception { + LinearRegressionModel model = + new LinearRegression() + .setWeightCol("weight") + .setReg(reg) + .setElasticNet(elasticNet) + .fit(trainDataTable); + List<LinearRegressionModelData> modelData = + IteratorUtils.toList( + LinearRegressionModelData.getModelDataStream(model.getModelData()[0]) + .executeAndCollect()); + final double errorTol = 1e-3; + assertArrayEquals(expectedCoefficient, modelData.get(0).coefficient.values, errorTol); + } +}