This is an automated email from the ASF dual-hosted git repository. zhangzp pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push: new 239788f [FLINK-27170] Add Transformer and Estimator of OnlineLogisticRegression 239788f is described below commit 239788f2b1f1f3a4e55ca112517980b598705a15 Author: weibo <wbz...@pku.edu.cn> AuthorDate: Thu Jun 2 10:42:09 2022 +0800 [FLINK-27170] Add Transformer and Estimator of OnlineLogisticRegression This closes #83. --- .../ml/common/datastream/DataStreamUtils.java | 74 +++ .../main/java/org/apache/flink/ml/linalg/BLAS.java | 46 +- .../ml/common/datastream/DataStreamUtilsTest.java | 13 + .../java/org/apache/flink/ml/linalg/BLASTest.java | 11 + .../logisticregression/LogisticRegression.java | 2 +- .../LogisticRegressionModel.java | 2 +- .../LogisticRegressionModelData.java | 66 +- .../OnlineLogisticRegression.java | 424 +++++++++++++ .../OnlineLogisticRegressionModel.java | 198 ++++++ .../OnlineLogisticRegressionModelParams.java | 50 ++ .../OnlineLogisticRegressionParams.java | 66 ++ .../flink/ml/clustering/kmeans/OnlineKMeans.java | 57 +- .../ml/classification/LogisticRegressionTest.java | 2 +- .../OnlineLogisticRegressionTest.java | 687 +++++++++++++++++++++ .../apache/flink/ml/feature/MinMaxScalerTest.java | 15 +- .../java/org/apache/flink/ml/util/TestUtils.java | 14 + .../ml/lib/classification/logisticregression.py | 121 +++- .../tests/test_logisticregression.py | 41 +- 18 files changed, 1803 insertions(+), 86 deletions(-) diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java index 10073b8..45ad02e 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java @@ -19,21 +19,32 @@ package org.apache.flink.ml.common.datastream; import org.apache.flink.annotation.Internal; +import 
org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.MapPartitionFunction; import org.apache.flink.api.common.functions.ReduceFunction; import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.TypeExtractor; import org.apache.flink.iteration.operator.OperatorStateUtils; import org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.runtime.state.StateSnapshotContext; import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction; import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; import org.apache.flink.streaming.api.operators.BoundedOneInput; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.api.operators.TimestampedCollector; +import org.apache.flink.streaming.api.windowing.windows.GlobalWindow; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.util.Collector; + +import org.apache.commons.collections.IteratorUtils; + +import java.util.Arrays; +import java.util.List; /** Provides utility functions for {@link DataStream}. */ @Internal @@ -182,4 +193,67 @@ public class DataStreamUtils { } } } + + /** + * Splits the input data into global batches of batchSize. After splitting, each global batch is + * further split into local batches for downstream operators with each worker has one batch. 
+ */ + public static <T> DataStream<T[]> generateBatchData( + DataStream<T> inputData, final int downStreamParallelism, int batchSize) { + return inputData + .countWindowAll(batchSize) + .apply(new GlobalBatchCreator<>()) + .flatMap(new GlobalBatchSplitter<>(downStreamParallelism)) + .partitionCustom((chunkId, numPartitions) -> chunkId, x -> x.f0) + .map( + new MapFunction<Tuple2<Integer, T[]>, T[]>() { + @Override + public T[] map(Tuple2<Integer, T[]> integerTuple2) throws Exception { + return integerTuple2.f1; + } + }); + } + + /** Splits the input data into global batches. */ + private static class GlobalBatchCreator<T> implements AllWindowFunction<T, T[], GlobalWindow> { + @Override + public void apply(GlobalWindow timeWindow, Iterable<T> iterable, Collector<T[]> collector) { + List<T> points = IteratorUtils.toList(iterable.iterator()); + collector.collect(points.toArray((T[]) new Object[0])); + } + } + + /** + * An operator that splits a global batch into evenly-sized local batches, and distributes them + * to downstream operator. 
+ */ + private static class GlobalBatchSplitter<T> + implements FlatMapFunction<T[], Tuple2<Integer, T[]>> { + private final int downStreamParallelism; + + public GlobalBatchSplitter(int downStreamParallelism) { + this.downStreamParallelism = downStreamParallelism; + } + + @Override + public void flatMap(T[] values, Collector<Tuple2<Integer, T[]>> collector) { + int div = values.length / downStreamParallelism; + int mod = values.length % downStreamParallelism; + + int offset = 0; + int i = 0; + + int size = div + 1; + for (; i < mod; i++) { + collector.collect(Tuple2.of(i, Arrays.copyOfRange(values, offset, offset + size))); + offset += size; + } + + size = div; + for (; i < downStreamParallelism; i++) { + collector.collect(Tuple2.of(i, Arrays.copyOfRange(values, offset, offset + size))); + offset += size; + } + } + } } diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java index 24cc3ef..c00f642 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java @@ -65,12 +65,54 @@ public class BLAS { } } - /** x \cdot y . */ - public static double dot(DenseVector x, DenseVector y) { + /** Computes the dot of the two vectors (y \dot x). 
*/ + public static double dot(Vector x, Vector y) { Preconditions.checkArgument(x.size() == y.size(), "Vector size mismatched."); + if (x instanceof SparseVector) { + if (y instanceof SparseVector) { + return dot((SparseVector) x, (SparseVector) y); + } else { + return dot((DenseVector) y, (SparseVector) x); + } + } else { + if (y instanceof SparseVector) { + return dot((DenseVector) x, (SparseVector) y); + } else { + return dot((DenseVector) x, (DenseVector) y); + } + } + } + + private static double dot(DenseVector x, DenseVector y) { return JAVA_BLAS.ddot(x.size(), x.values, 1, y.values, 1); } + private static double dot(DenseVector x, SparseVector y) { + double dotValue = 0.0; + for (int i = 0; i < y.indices.length; ++i) { + dotValue += y.values[i] * x.values[y.indices[i]]; + } + return dotValue; + } + + private static double dot(SparseVector x, SparseVector y) { + double dotValue = 0; + int p0 = 0; + int p1 = 0; + while (p0 < x.values.length && p1 < y.values.length) { + if (x.indices[p0] == y.indices[p1]) { + dotValue += x.values[p0] * y.values[p1]; + p0++; + p1++; + } else if (x.indices[p0] < y.indices[p1]) { + p0++; + } else { + p1++; + } + } + return dotValue; + } + /** \sqrt(\sum_i x_i * x_i) . */ public static double norm2(DenseVector x) { return JAVA_BLAS.dnrm2(x.size(), x.values, 1); diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java index 7dc88c8..a968a0e 100644 --- a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java +++ b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java @@ -37,6 +37,7 @@ import org.junit.Test; import java.util.List; import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; /** Tests the {@link DataStreamUtils}. 
*/ @@ -74,6 +75,18 @@ public class DataStreamUtilsTest { assertArrayEquals(new long[] {190L}, sum.stream().mapToLong(Long::longValue).toArray()); } + @Test + public void testGenerateBatchData() throws Exception { + DataStream<Long> dataStream = + env.fromParallelCollection(new NumberSequenceIterator(0L, 19L), Types.LONG); + DataStream<Long[]> result = DataStreamUtils.generateBatchData(dataStream, 2, 4); + List<Long[]> batches = IteratorUtils.toList(result.executeAndCollect()); + for (Long[] batch : batches) { + assertEquals(2, batch.length); + } + assertEquals(10, batches.size()); + } + /** A simple implementation for a {@link MapPartitionFunction}. */ private static class TestMapPartitionFunc extends RichMapPartitionFunction<Long, Integer> { diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java index 7055c62..21d68a9 100644 --- a/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java +++ b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java @@ -70,7 +70,18 @@ public class BLASTest { @Test public void testDot() { DenseVector anotherDenseVec = Vectors.dense(1, 2, 3, 4, 5); + SparseVector sparseVector1 = + Vectors.sparse(5, new int[] {1, 2, 4}, new double[] {1., 1., 4.}); + SparseVector sparseVector2 = + Vectors.sparse(5, new int[] {1, 3, 4}, new double[] {1., 2., 1.}); + // Tests dot(dense, dense). assertEquals(-3, BLAS.dot(inputDenseVec, anotherDenseVec), TOLERANCE); + // Tests dot(dense, sparse). + assertEquals(-19, BLAS.dot(inputDenseVec, sparseVector1), TOLERANCE); + // Tests dot(sparse, dense). + assertEquals(1, BLAS.dot(sparseVector2, inputDenseVec), TOLERANCE); + // Tests dot(sparse, sparse). 
+ assertEquals(5, BLAS.dot(sparseVector1, sparseVector2), TOLERANCE); } @Test diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java index df8c386..551b66a 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java @@ -115,7 +115,7 @@ public class LogisticRegression optimizer.optimize(initModelData, trainData, BinaryLogisticLoss.INSTANCE); DataStream<LogisticRegressionModelData> modelData = - rawModelData.map(LogisticRegressionModelData::new); + rawModelData.map(vector -> new LogisticRegressionModelData(vector, 0)); LogisticRegressionModel model = new LogisticRegressionModel().setModelData(tEnv.fromDataStream(modelData)); ReadWriteUtils.updateExistingParams(model, paramMap); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java index ac42142..675846a 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java @@ -162,7 +162,7 @@ public class LogisticRegressionModel * @param coefficient The model parameters. * @return The prediction label and the raw probabilities. */ - private static Row predictOneDataPoint(DenseVector feature, DenseVector coefficient) { + protected static Row predictOneDataPoint(Vector feature, DenseVector coefficient) { double dotValue = BLAS.dot(feature, coefficient); double prob = 1 - 1.0 / (1.0 + Math.exp(dotValue)); return Row.of(dotValue >= 0 ? 1. 
: 0., Vectors.dense(1 - prob, prob)); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java index d2f451f..a9a6285 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java @@ -18,6 +18,7 @@ package org.apache.flink.ml.classification.logisticregression; +import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.serialization.Encoder; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.configuration.Configuration; @@ -25,10 +26,11 @@ import org.apache.flink.connector.file.src.reader.SimpleStreamFormat; import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.Vector; import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer; import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; import org.apache.flink.table.api.internal.TableImpl; @@ -36,9 +38,10 @@ import org.apache.flink.table.api.internal.TableImpl; import java.io.EOFException; import java.io.IOException; import java.io.OutputStream; +import java.util.Random; /** - * Model data of {@link LogisticRegressionModel}. + * Model data of {@link LogisticRegressionModel} and {@link OnlineLogisticRegressionModel}. 
* * <p>This class also provides methods to convert model data from Table to Datastream, and classes * to save/load model data. @@ -46,12 +49,49 @@ import java.io.OutputStream; public class LogisticRegressionModelData { public DenseVector coefficient; + public long modelVersion; - public LogisticRegressionModelData(DenseVector coefficient) { + public LogisticRegressionModelData() {} + + public LogisticRegressionModelData(DenseVector coefficient, long modelVersion) { this.coefficient = coefficient; + this.modelVersion = modelVersion; } - public LogisticRegressionModelData() {} + /** + * Generates a Table containing a {@link LogisticRegressionModelData} instance with randomly + * generated coefficient. + * + * @param tEnv The environment where to create the table. + * @param dim The size of generated coefficient. + * @param seed Random seed. + */ + public static Table generateRandomModelData(StreamTableEnvironment tEnv, int dim, int seed) { + StreamExecutionEnvironment env = TableUtils.getExecutionEnvironment(tEnv); + return tEnv.fromDataStream( + env.fromElements(1).map(new RandomModelDataGenerator(dim, seed))); + } + + private static class RandomModelDataGenerator + implements MapFunction<Integer, LogisticRegressionModelData> { + private final int dim; + private final int seed; + + public RandomModelDataGenerator(int dim, int seed) { + this.dim = dim; + this.seed = seed; + } + + @Override + public LogisticRegressionModelData map(Integer integer) throws Exception { + DenseVector vector = new DenseVector(dim); + Random random = new Random(seed); + for (int j = 0; j < dim; j++) { + vector.values[j] = random.nextDouble(); + } + return new LogisticRegressionModelData(vector, 0L); + } + } /** * Converts the table model to a data stream. 
@@ -63,21 +103,24 @@ public class LogisticRegressionModelData { StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment(); return tEnv.toDataStream(modelData) - .map(x -> new LogisticRegressionModelData(((Vector) x.getField(0)).toDense())); + .map(x -> new LogisticRegressionModelData(x.getFieldAs(0), x.getFieldAs(1))); } - /** Data encoder for {@link LogisticRegressionModel}. */ + /** Data encoder for {@link LogisticRegression} and {@link OnlineLogisticRegression}. */ public static class ModelDataEncoder implements Encoder<LogisticRegressionModelData> { @Override public void encode(LogisticRegressionModelData modelData, OutputStream outputStream) throws IOException { + DataOutputViewStreamWrapper dataOutputViewStreamWrapper = + new DataOutputViewStreamWrapper(outputStream); DenseVectorSerializer.INSTANCE.serialize( - modelData.coefficient, new DataOutputViewStreamWrapper(outputStream)); + modelData.coefficient, dataOutputViewStreamWrapper); + dataOutputViewStreamWrapper.writeLong(modelData.modelVersion); } } - /** Data decoder for {@link LogisticRegressionModel}. */ + /** Data decoder for {@link LogisticRegression} and {@link OnlineLogisticRegression}. 
*/ public static class ModelDataDecoder extends SimpleStreamFormat<LogisticRegressionModelData> { @Override @@ -88,10 +131,13 @@ public class LogisticRegressionModelData { @Override public LogisticRegressionModelData read() throws IOException { try { + DataInputViewStreamWrapper dataInputViewStreamWrapper = + new DataInputViewStreamWrapper(inputStream); DenseVector coefficient = DenseVectorSerializer.INSTANCE.deserialize( - new DataInputViewStreamWrapper(inputStream)); - return new LogisticRegressionModelData(coefficient); + dataInputViewStreamWrapper); + long modelVersion = dataInputViewStreamWrapper.readLong(); + return new LogisticRegressionModelData(coefficient, modelVersion); } catch (EOFException e) { return null; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java new file mode 100644 index 0000000..59a9102 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java @@ -0,0 +1,424 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification.logisticregression; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.IterationBody; +import org.apache.flink.iteration.IterationBodyResult; +import org.apache.flink.iteration.Iterations; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.FunctionInitializationContext; +import org.apache.flink.runtime.state.FunctionSnapshotContext; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; 
+ +import org.apache.commons.collections.IteratorUtils; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * An Estimator which implements the online logistic regression algorithm. The online optimizer of + * this algorithm is The FTRL-Proximal proposed by H.Brendan McMahan et al. + * + * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click + * prediction: a view from the trenches.</a> + */ +public class OnlineLogisticRegression + implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>, + OnlineLogisticRegressionParams<OnlineLogisticRegression> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private Table initModelDataTable; + + public OnlineLogisticRegression() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public OnlineLogisticRegressionModel fit(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<LogisticRegressionModelData> modelDataStream = + LogisticRegressionModelData.getModelDataStream(initModelDataTable); + + DataStream<Row> points = + tEnv.toDataStream(inputs[0]) + .map( + new FeaturesLabelExtractor( + getFeaturesCol(), getLabelCol(), getWeightCol())); + + DataStream<DenseVector> initModelData = + modelDataStream.map( + (MapFunction<LogisticRegressionModelData, DenseVector>) + value -> value.coefficient); + + initModelData.getTransformation().setParallelism(1); + + IterationBody body = + new FtrlIterationBody( + getGlobalBatchSize(), getAlpha(), getBeta(), getReg(), getElasticNet()); + + DataStream<LogisticRegressionModelData> onlineModelData = + Iterations.iterateUnboundedStreams( + DataStreamList.of(initModelData), DataStreamList.of(points), body) + .get(0); + + Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData); + OnlineLogisticRegressionModel model = + new OnlineLogisticRegressionModel().setModelData(onlineModelDataTable); + ReadWriteUtils.updateExistingParams(model, paramMap); + return model; + } + + private static class FeaturesLabelExtractor implements MapFunction<Row, Row> { + private final String featuresCol; + private final String labelCol; + private final String weightCol; + + private FeaturesLabelExtractor(String featuresCol, String labelCol, String weightCol) { + this.featuresCol = featuresCol; + this.labelCol = labelCol; + this.weightCol = weightCol; + } + + @Override + public Row map(Row row) throws Exception { + if (weightCol == null) { + return Row.of(row.getField(featuresCol), row.getField(labelCol)); + } else { + return Row.of( + row.getField(featuresCol), row.getField(labelCol), row.getField(weightCol)); + } + } + } + + /** + * In the implementation of ftrl optimizer, gradients are calculated in distributed workers and + * reduce them to 
one final gradient. The reduced gradient is used to update model by ftrl + * method. When the feature vector is dense, it can get the same result as tensorflow's ftrl. If + * feature vector is sparse, we use the mean value in every feature dim instead of mean value of + * whole vector, which can get a better convergence. + * + * <p>See https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Ftrl + * + * <p>todo: makes ftrl to be a common optimizer and place it in org.apache.flink.ml.common. + */ + private static class FtrlIterationBody implements IterationBody { + private final int batchSize; + private final double alpha; + private final double beta; + private final double l1; + private final double l2; + + public FtrlIterationBody( + int batchSize, double alpha, double beta, double reg, double elasticNet) { + this.batchSize = batchSize; + this.alpha = alpha; + this.beta = beta; + this.l1 = elasticNet * reg; + this.l2 = (1 - elasticNet) * reg; + } + + @Override + public IterationBodyResult process( + DataStreamList variableStreams, DataStreamList dataStreams) { + DataStream<DenseVector> modelData = variableStreams.get(0); + + DataStream<Row> points = dataStreams.get(0); + int parallelism = points.getParallelism(); + Preconditions.checkState( + parallelism <= batchSize, + "There are more subtasks in the training process than the number " + + "of elements in each batch. 
Some subtasks might be idling forever."); + + DataStream<DenseVector[]> newGradient = + DataStreamUtils.generateBatchData(points, parallelism, batchSize) + .connect(modelData.broadcast()) + .transform( + "LocalGradientCalculator", + TypeInformation.of(DenseVector[].class), + new CalculateLocalGradient()) + .setParallelism(parallelism) + .countWindowAll(parallelism) + .reduce( + (ReduceFunction<DenseVector[]>) + (gradientInfo, newGradientInfo) -> { + BLAS.axpy(1.0, gradientInfo[0], newGradientInfo[0]); + BLAS.axpy(1.0, gradientInfo[1], newGradientInfo[1]); + if (newGradientInfo[2] == null) { + newGradientInfo[2] = gradientInfo[2]; + } + return newGradientInfo; + }); + DataStream<DenseVector> feedbackModelData = + newGradient + .transform( + "ModelDataUpdater", + TypeInformation.of(DenseVector.class), + new UpdateModel(alpha, beta, l1, l2)) + .setParallelism(1); + + DataStream<LogisticRegressionModelData> outputModelData = + feedbackModelData.map(new CreateLrModelData()).setParallelism(1); + return new IterationBodyResult( + DataStreamList.of(feedbackModelData), DataStreamList.of(outputModelData)); + } + } + + private static class CreateLrModelData + implements MapFunction<DenseVector, LogisticRegressionModelData>, CheckpointedFunction { + private Long modelVersion = 1L; + private transient ListState<Long> modelVersionState; + + @Override + public LogisticRegressionModelData map(DenseVector denseVector) throws Exception { + return new LogisticRegressionModelData(denseVector, modelVersion++); + } + + @Override + public void snapshotState(FunctionSnapshotContext functionSnapshotContext) + throws Exception { + modelVersionState.update(Collections.singletonList(modelVersion)); + } + + @Override + public void initializeState(FunctionInitializationContext context) throws Exception { + modelVersionState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>("modelVersionState", Long.class)); + } + } + + /** Updates model. 
*/ + private static class UpdateModel extends AbstractStreamOperator<DenseVector> + implements OneInputStreamOperator<DenseVector[], DenseVector> { + private ListState<double[]> nParamState; + private ListState<double[]> zParamState; + private final double alpha; + private final double beta; + private final double l1; + private final double l2; + private double[] nParam; + private double[] zParam; + + public UpdateModel(double alpha, double beta, double l1, double l2) { + this.alpha = alpha; + this.beta = beta; + this.l1 = l1; + this.l2 = l2; + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + nParamState = + context.getOperatorStateStore() + .getListState(new ListStateDescriptor<>("nParamState", double[].class)); + zParamState = + context.getOperatorStateStore() + .getListState(new ListStateDescriptor<>("zParamState", double[].class)); + } + + @Override + public void processElement(StreamRecord<DenseVector[]> streamRecord) throws Exception { + DenseVector[] gradientInfo = streamRecord.getValue(); + double[] coefficient = gradientInfo[2].values; + double[] g = gradientInfo[0].values; + for (int i = 0; i < g.length; ++i) { + if (gradientInfo[1].values[i] != 0.0) { + g[i] = g[i] / gradientInfo[1].values[i]; + } + } + if (zParam == null) { + zParam = new double[g.length]; + nParam = new double[g.length]; + nParamState.add(nParam); + zParamState.add(zParam); + } + + for (int i = 0; i < zParam.length; ++i) { + double sigma = (Math.sqrt(nParam[i] + g[i] * g[i]) - Math.sqrt(nParam[i])) / alpha; + zParam[i] += g[i] - sigma * coefficient[i]; + nParam[i] += g[i] * g[i]; + + if (Math.abs(zParam[i]) <= l1) { + coefficient[i] = 0.0; + } else { + coefficient[i] = + ((zParam[i] < 0 ? 
-1 : 1) * l1 - zParam[i]) + / ((beta + Math.sqrt(nParam[i])) / alpha + l2); + } + } + output.collect(new StreamRecord<>(new DenseVector(coefficient))); + } + } + + private static class CalculateLocalGradient extends AbstractStreamOperator<DenseVector[]> + implements TwoInputStreamOperator<Row[], DenseVector, DenseVector[]> { + private ListState<DenseVector> modelDataState; + private ListState<Row[]> localBatchDataState; + private double[] gradient; + private double[] weightSum; + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + modelDataState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>("modelData", DenseVector.class)); + TypeInformation<Row[]> type = + ObjectArrayTypeInfo.getInfoFor(TypeInformation.of(Row.class)); + localBatchDataState = + context.getOperatorStateStore() + .getListState(new ListStateDescriptor<>("localBatch", type)); + } + + @Override + public void processElement1(StreamRecord<Row[]> pointsRecord) throws Exception { + localBatchDataState.add(pointsRecord.getValue()); + calculateGradient(); + } + + private void calculateGradient() throws Exception { + if (!modelDataState.get().iterator().hasNext() + || !localBatchDataState.get().iterator().hasNext()) { + return; + } + DenseVector modelData = + OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get(); + modelDataState.clear(); + + List<Row[]> pointsList = IteratorUtils.toList(localBatchDataState.get().iterator()); + Row[] points = pointsList.remove(0); + localBatchDataState.update(pointsList); + + for (Row point : points) { + Vector vec = point.getFieldAs(0); + double label = point.getFieldAs(1); + double weight = point.getArity() == 2 ? 
1.0 : point.getFieldAs(2); + if (gradient == null) { + gradient = new double[vec.size()]; + weightSum = new double[gradient.length]; + } + double p = BLAS.dot(modelData, vec); + p = 1 / (1 + Math.exp(-p)); + if (vec instanceof DenseVector) { + DenseVector dvec = (DenseVector) vec; + for (int i = 0; i < modelData.size(); ++i) { + gradient[i] += (p - label) * dvec.values[i]; + weightSum[i] += 1.0; + } + } else { + SparseVector svec = (SparseVector) vec; + for (int i = 0; i < svec.indices.length; ++i) { + int idx = svec.indices[i]; + gradient[idx] += (p - label) * svec.values[i]; + weightSum[idx] += weight; + } + } + } + + if (points.length > 0) { + output.collect( + new StreamRecord<>( + new DenseVector[] { + new DenseVector(gradient), + new DenseVector(weightSum), + (getRuntimeContext().getIndexOfThisSubtask() == 0) + ? modelData + : null + })); + } + Arrays.fill(gradient, 0.0); + Arrays.fill(weightSum, 0.0); + } + + @Override + public void processElement2(StreamRecord<DenseVector> modelDataRecord) throws Exception { + modelDataState.add(modelDataRecord.getValue()); + calculateGradient(); + } + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + ReadWriteUtils.saveModelData( + LogisticRegressionModelData.getModelDataStream(initModelDataTable), + path, + new LogisticRegressionModelData.ModelDataEncoder()); + } + + public static OnlineLogisticRegression load(StreamTableEnvironment tEnv, String path) + throws IOException { + OnlineLogisticRegression onlineLogisticRegression = ReadWriteUtils.loadStageParam(path); + Table modelDataTable = + ReadWriteUtils.loadModelData( + tEnv, path, new LogisticRegressionModelData.ModelDataDecoder()); + onlineLogisticRegression.setInitialModelData(modelDataTable); + return onlineLogisticRegression; + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + /** + * Sets the initial model data of the online training process with the provided 
model data + * table. + */ + public OnlineLogisticRegression setInitialModelData(Table initModelDataTable) { + this.initModelDataTable = initModelDataTable; + return this; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java new file mode 100644 index 0000000..eab5cf6 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification.logisticregression; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.metrics.Gauge; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel.predictOneDataPoint; + +/** + * A Model which classifies data using the model data computed by {@link OnlineLogisticRegression}. 
+ */ +public class OnlineLogisticRegressionModel + implements Model<OnlineLogisticRegressionModel>, + OnlineLogisticRegressionModelParams<OnlineLogisticRegressionModel> { + public static final String MODEL_DATA_VERSION_GAUGE_KEY = "modelDataVersion"; + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private Table modelDataTable; + + public OnlineLogisticRegressionModel() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), + Types.DOUBLE, + TypeInformation.of(DenseVector.class), + Types.LONG), + ArrayUtils.addAll( + inputTypeInfo.getFieldNames(), + getPredictionCol(), + getRawPredictionCol(), + getModelVersionCol())); + + DataStream<Row> predictionResult = + tEnv.toDataStream(inputs[0]) + .connect( + LogisticRegressionModelData.getModelDataStream(modelDataTable) + .broadcast()) + .transform( + "PredictLabelOperator", + outputTypeInfo, + new PredictLabelOperator(inputTypeInfo, getFeaturesCol())); + + return new Table[] {tEnv.fromDataStream(predictionResult)}; + } + + /** A utility operator used for prediction. 
*/ + private static class PredictLabelOperator extends AbstractStreamOperator<Row> + implements TwoInputStreamOperator<Row, LogisticRegressionModelData, Row> { + private final RowTypeInfo inputTypeInfo; + + private final String featuresCol; + private ListState<Row> bufferedPointsState; + private DenseVector coefficient; + private long modelDataVersion = 0L; + + public PredictLabelOperator(RowTypeInfo inputTypeInfo, String featuresCol) { + this.inputTypeInfo = inputTypeInfo; + this.featuresCol = featuresCol; + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + bufferedPointsState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>("bufferedPoints", inputTypeInfo)); + } + + @Override + public void open() throws Exception { + super.open(); + + getRuntimeContext() + .getMetricGroup() + .gauge( + MODEL_DATA_VERSION_GAUGE_KEY, + (Gauge<String>) () -> Long.toString(modelDataVersion)); + } + + @Override + public void processElement1(StreamRecord<Row> streamRecord) throws Exception { + processElement(streamRecord); + } + + @Override + public void processElement2(StreamRecord<LogisticRegressionModelData> streamRecord) + throws Exception { + LogisticRegressionModelData modelData = streamRecord.getValue(); + coefficient = modelData.coefficient; + modelDataVersion = modelData.modelVersion; + for (Row dataPoint : bufferedPointsState.get()) { + processElement(new StreamRecord<>(dataPoint)); + } + bufferedPointsState.clear(); + } + + public void processElement(StreamRecord<Row> streamRecord) throws Exception { + Row dataPoint = streamRecord.getValue(); + if (coefficient == null) { + bufferedPointsState.add(dataPoint); + return; + } + Vector features = (Vector) dataPoint.getField(featuresCol); + Row predictionResult = predictOneDataPoint(features, coefficient); + output.collect( + new StreamRecord<>( + Row.join( + dataPoint, + Row.of( + predictionResult.getField(0), + 
predictionResult.getField(1), + modelDataVersion)))); + } + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static OnlineLogisticRegressionModel load(StreamTableEnvironment tEnv, String path) + throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + @Override + public OnlineLogisticRegressionModel setModelData(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + modelDataTable = inputs[0]; + return this; + } + + @Override + public Table[] getModelData() { + return new Table[] {modelDataTable}; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModelParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModelParams.java new file mode 100644 index 0000000..573cd48 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModelParams.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification.logisticregression; + +import org.apache.flink.ml.common.param.HasFeaturesCol; +import org.apache.flink.ml.common.param.HasPredictionCol; +import org.apache.flink.ml.common.param.HasRawPredictionCol; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringParam; + +/** + * Params for {@link OnlineLogisticRegressionModel}. + * + * @param <T> The class type of this instance. + */ +public interface OnlineLogisticRegressionModelParams<T> + extends HasFeaturesCol<T>, HasPredictionCol<T>, HasRawPredictionCol<T> { + Param<String> MODEL_VERSION_COL = + new StringParam( + "modelVersionCol", + "Model version column name.", + "modelVersion", + ParamValidators.notNull()); + + default String getModelVersionCol() { + return get(MODEL_VERSION_COL); + } + + default T setModelVersionCol(String value) { + set(MODEL_VERSION_COL, value); + return (T) this; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionParams.java new file mode 100644 index 0000000..961b7c5 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionParams.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.logisticregression; + +import org.apache.flink.ml.common.param.HasBatchStrategy; +import org.apache.flink.ml.common.param.HasElasticNet; +import org.apache.flink.ml.common.param.HasGlobalBatchSize; +import org.apache.flink.ml.common.param.HasLabelCol; +import org.apache.flink.ml.common.param.HasReg; +import org.apache.flink.ml.common.param.HasWeightCol; +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; + +/** + * Params of {@link OnlineLogisticRegression}. + * + * @param <T> The class type of this instance. 
+ */ +public interface OnlineLogisticRegressionParams<T> + extends HasLabelCol<T>, + HasWeightCol<T>, + HasBatchStrategy<T>, + HasGlobalBatchSize<T>, + HasReg<T>, + HasElasticNet<T>, + OnlineLogisticRegressionModelParams<T> { + + Param<Double> ALPHA = + new DoubleParam("alpha", "The alpha parameter of ftrl.", 0.1, ParamValidators.gt(0.0)); + + Param<Double> BETA = + new DoubleParam("beta", "The beta parameter of ftrl.", 0.1, ParamValidators.gt(0.0)); + + default Double getAlpha() { + return get(ALPHA); + } + + default T setAlpha(Double value) { + return set(ALPHA, value); + } + + default Double getBeta() { + return get(BETA); + } + + default T setBeta(Double value) { + return set(BETA, value); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java index b876b13..4112527 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java @@ -18,7 +18,6 @@ package org.apache.flink.ml.clustering.kmeans; -import org.apache.flink.api.common.functions.FlatMapFunction; import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.ReduceFunction; import org.apache.flink.api.common.state.ListState; @@ -31,6 +30,7 @@ import org.apache.flink.iteration.IterationBodyResult; import org.apache.flink.iteration.Iterations; import org.apache.flink.iteration.operator.OperatorStateUtils; import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.datastream.DataStreamUtils; import org.apache.flink.ml.common.distance.DistanceMeasure; import org.apache.flink.ml.linalg.BLAS; import org.apache.flink.ml.linalg.DenseVector; @@ -41,22 +41,18 @@ import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; import 
org.apache.flink.runtime.state.StateInitializationContext; import org.apache.flink.streaming.api.datastream.DataStream; -import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; -import org.apache.flink.streaming.api.windowing.windows.GlobalWindow; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; import org.apache.flink.table.api.internal.TableImpl; import org.apache.flink.types.Row; -import org.apache.flink.util.Collector; import org.apache.flink.util.Preconditions; import org.apache.commons.collections.IteratorUtils; import java.io.IOException; -import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -169,10 +165,7 @@ public class OnlineKMeans + "of elements in each batch. Some subtasks might be idling forever."); DataStream<KMeansModelData> newModelData = - points.countWindowAll(batchSize) - .apply(new GlobalBatchCreator()) - .flatMap(new GlobalBatchSplitter(parallelism)) - .rebalance() + DataStreamUtils.generateBatchData(points, parallelism, batchSize) .connect(modelData.broadcast()) .transform( "ModelDataLocalUpdater", @@ -340,52 +333,6 @@ public class OnlineKMeans } } - /** - * An operator that splits a global batch into evenly-sized local batches, and distributes them - * to downstream operator. 
- */ - private static class GlobalBatchSplitter - implements FlatMapFunction<DenseVector[], DenseVector[]> { - private final int downStreamParallelism; - - private GlobalBatchSplitter(int downStreamParallelism) { - this.downStreamParallelism = downStreamParallelism; - } - - @Override - public void flatMap(DenseVector[] values, Collector<DenseVector[]> collector) { - int div = values.length / downStreamParallelism; - int mod = values.length % downStreamParallelism; - - int offset = 0; - int i = 0; - - int size = div + 1; - for (; i < mod; i++) { - collector.collect(Arrays.copyOfRange(values, offset, offset + size)); - offset += size; - } - - size = div; - for (; i < downStreamParallelism; i++) { - collector.collect(Arrays.copyOfRange(values, offset, offset + size)); - offset += size; - } - } - } - - private static class GlobalBatchCreator - implements AllWindowFunction<DenseVector, DenseVector[], GlobalWindow> { - @Override - public void apply( - GlobalWindow timeWindow, - Iterable<DenseVector> iterable, - Collector<DenseVector[]> collector) { - List<DenseVector> points = IteratorUtils.toList(iterable.iterator()); - collector.collect(points.toArray(new DenseVector[0])); - } - } - /** * Sets the initial model data of the online training process with the provided model data * table. 
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java index 129fd2e..37815e0 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java @@ -250,7 +250,7 @@ public class LogisticRegressionTest { LogisticRegressionModel model = logisticRegression.fit(binomialDataTable); model = TestUtils.saveAndReload(tEnv, model, tempFolder.newFolder().getAbsolutePath()); assertEquals( - Collections.singletonList("coefficient"), + Arrays.asList("coefficient", "modelVersion"), model.getModelData()[0].getResolvedSchema().getColumnNames()); Table output = model.transform(binomialDataTable)[0]; verifyPredictionResult( diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java new file mode 100644 index 0000000..446548d --- /dev/null +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java @@ -0,0 +1,687 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification; + +import org.apache.flink.api.common.JobID; +import org.apache.flink.api.common.JobSubmissionResult; +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.RestOptions; +import org.apache.flink.metrics.Gauge; +import org.apache.flink.ml.classification.logisticregression.LogisticRegression; +import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel; +import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelData; +import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegression; +import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.InMemorySinkFunction; +import org.apache.flink.ml.util.InMemorySourceFunction; +import org.apache.flink.ml.util.TestUtils; +import org.apache.flink.runtime.jobgraph.JobGraph; +import org.apache.flink.runtime.minicluster.MiniCluster; +import org.apache.flink.runtime.minicluster.MiniClusterConfiguration; +import org.apache.flink.runtime.testutils.InMemoryReporter; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import 
org.apache.flink.types.Row; +import org.apache.flink.util.TestLogger; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import static org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel.MODEL_DATA_VERSION_GAUGE_KEY; + +/** Tests {@link OnlineLogisticRegression} and {@link OnlineLogisticRegressionModel}. */ +public class OnlineLogisticRegressionTest extends TestLogger { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + + private static final double[] ONE_ARRAY = new double[] {1.0, 1.0, 1.0}; + + private static final Row[] TRAIN_DENSE_ROWS_1 = + new Row[] { + Row.of(Vectors.dense(0.1, 2.), 0.), + Row.of(Vectors.dense(0.2, 2.), 0.), + Row.of(Vectors.dense(0.3, 2.), 0.), + Row.of(Vectors.dense(0.4, 2.), 0.), + Row.of(Vectors.dense(0.5, 2.), 0.), + Row.of(Vectors.dense(11., 12.), 1.), + Row.of(Vectors.dense(12., 11.), 1.), + Row.of(Vectors.dense(13., 12.), 1.), + Row.of(Vectors.dense(14., 12.), 1.), + Row.of(Vectors.dense(15., 12.), 1.) + }; + + private static final Row[] TRAIN_DENSE_ROWS_2 = + new Row[] { + Row.of(Vectors.dense(0.2, 3.), 0.), + Row.of(Vectors.dense(0.8, 1.), 0.), + Row.of(Vectors.dense(0.7, 1.), 0.), + Row.of(Vectors.dense(0.6, 2.), 0.), + Row.of(Vectors.dense(0.2, 2.), 0.), + Row.of(Vectors.dense(14., 17.), 1.), + Row.of(Vectors.dense(15., 10.), 1.), + Row.of(Vectors.dense(16., 16.), 1.), + Row.of(Vectors.dense(17., 10.), 1.), + Row.of(Vectors.dense(18., 13.), 1.) 
+ }; + + private static final Row[] PREDICT_DENSE_ROWS = + new Row[] { + Row.of(Vectors.dense(0.8, 2.7), 0.0), Row.of(Vectors.dense(15.5, 11.2), 1.0) + }; + + private static final Row[] TRAIN_SPARSE_ROWS_1 = + new Row[] { + Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0., 1.0), + Row.of(Vectors.sparse(10, new int[] {0, 2, 3}, ONE_ARRAY), 0., 1.4), + Row.of(Vectors.sparse(10, new int[] {0, 3, 4}, ONE_ARRAY), 0., 1.3), + Row.of(Vectors.sparse(10, new int[] {2, 3, 4}, ONE_ARRAY), 0., 1.4), + Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0., 1.6), + Row.of(Vectors.sparse(10, new int[] {6, 7, 8}, ONE_ARRAY), 1., 1.8), + Row.of(Vectors.sparse(10, new int[] {6, 8, 9}, ONE_ARRAY), 1., 1.9), + Row.of(Vectors.sparse(10, new int[] {5, 8, 9}, ONE_ARRAY), 1., 1.0), + Row.of(Vectors.sparse(10, new int[] {5, 6, 7}, ONE_ARRAY), 1., 1.1) + }; + + private static final Row[] TRAIN_SPARSE_ROWS_2 = + new Row[] { + Row.of(Vectors.sparse(10, new int[] {1, 2, 4}, ONE_ARRAY), 0., 1.0), + Row.of(Vectors.sparse(10, new int[] {2, 3, 4}, ONE_ARRAY), 0., 1.3), + Row.of(Vectors.sparse(10, new int[] {0, 2, 4}, ONE_ARRAY), 0., 1.4), + Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0., 1.0), + Row.of(Vectors.sparse(10, new int[] {6, 7, 9}, ONE_ARRAY), 1., 1.6), + Row.of(Vectors.sparse(10, new int[] {7, 8, 9}, ONE_ARRAY), 1., 1.8), + Row.of(Vectors.sparse(10, new int[] {5, 7, 9}, ONE_ARRAY), 1., 1.0), + Row.of(Vectors.sparse(10, new int[] {5, 6, 7}, ONE_ARRAY), 1., 1.5), + Row.of(Vectors.sparse(10, new int[] {5, 8, 9}, ONE_ARRAY), 1., 1.0) + }; + + private static final Row[] PREDICT_SPARSE_ROWS = + new Row[] { + Row.of(Vectors.sparse(10, new int[] {1, 3, 5}, ONE_ARRAY), 0.), + Row.of(Vectors.sparse(10, new int[] {5, 8, 9}, ONE_ARRAY), 1.) 
+ }; + + private static final int defaultParallelism = 4; + private static final int numTaskManagers = 2; + private static final int numSlotsPerTaskManager = 2; + + private long currentModelDataVersion; + + private InMemorySourceFunction<Row> trainDenseSource; + private InMemorySourceFunction<Row> predictDenseSource; + private InMemorySourceFunction<Row> trainSparseSource; + private InMemorySourceFunction<Row> predictSparseSource; + private InMemorySinkFunction<Row> outputSink; + private InMemorySinkFunction<LogisticRegressionModelData> modelDataSink; + + // TODO: creates static mini cluster once for whole test class after dependency upgrades to + // Flink 1.15. + private InMemoryReporter reporter; + private MiniCluster miniCluster; + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + + private Table offlineTrainDenseTable; + private Table onlineTrainDenseTable; + private Table onlinePredictDenseTable; + private Table onlineTrainSparseTable; + private Table onlinePredictSparseTable; + private Table initDenseModel; + private Table initSparseModel; + + @Before + public void before() throws Exception { + currentModelDataVersion = 0; + + trainDenseSource = new InMemorySourceFunction<>(); + predictDenseSource = new InMemorySourceFunction<>(); + trainSparseSource = new InMemorySourceFunction<>(); + predictSparseSource = new InMemorySourceFunction<>(); + outputSink = new InMemorySinkFunction<>(); + modelDataSink = new InMemorySinkFunction<>(); + + Configuration config = new Configuration(); + config.set(RestOptions.BIND_PORT, "18081-19091"); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + reporter = InMemoryReporter.create(); + reporter.addToConfiguration(config); + + miniCluster = + new MiniCluster( + new MiniClusterConfiguration.Builder() + .setConfiguration(config) + .setNumTaskManagers(numTaskManagers) + .setNumSlotsPerTaskManager(numSlotsPerTaskManager) + .build()); + miniCluster.start(); + + 
env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(defaultParallelism); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + + offlineTrainDenseTable = + tEnv.fromDataStream(env.fromElements(TRAIN_DENSE_ROWS_1)).as("features", "label"); + onlineTrainDenseTable = + tEnv.fromDataStream( + env.addSource( + trainDenseSource, + new RowTypeInfo( + new TypeInformation[] { + TypeInformation.of(DenseVector.class), Types.DOUBLE + }, + new String[] {"features", "label"}))); + + onlinePredictDenseTable = + tEnv.fromDataStream( + env.addSource( + predictDenseSource, + new RowTypeInfo( + new TypeInformation[] { + TypeInformation.of(DenseVector.class), Types.DOUBLE + }, + new String[] {"features", "label"}))); + + onlineTrainSparseTable = + tEnv.fromDataStream( + env.addSource( + trainSparseSource, + new RowTypeInfo( + new TypeInformation[] { + TypeInformation.of(SparseVector.class), + Types.DOUBLE, + Types.DOUBLE + }, + new String[] {"features", "label", "weight"}))); + + onlinePredictSparseTable = + tEnv.fromDataStream( + env.addSource( + predictSparseSource, + new RowTypeInfo( + new TypeInformation[] { + TypeInformation.of(SparseVector.class), Types.DOUBLE + }, + new String[] {"features", "label"}))); + + initDenseModel = + tEnv.fromDataStream( + env.fromElements( + Row.of( + new DenseVector( + new double[] { + 0.41233679404769874, -0.18088118293232122 + }), + 0L))); + initSparseModel = + tEnv.fromDataStream( + env.fromElements( + Row.of( + new DenseVector( + new double[] { + 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, + 0.01, 0.01 + }), + 0L))); + } + + @After + public void after() throws Exception { + miniCluster.close(); + } + + /** + * Performs transform() on the provided model with predictTable, and adds sinks for + * OnlineLogisticRegressionModel's transform output and model data. 
+ */ + private void transformAndOutputData( + OnlineLogisticRegressionModel onlineModel, boolean isSparse) { + Table outputTable = + onlineModel + .transform(isSparse ? onlinePredictSparseTable : onlinePredictDenseTable)[ + 0]; + tEnv.toDataStream(outputTable).addSink(outputSink); + + Table modelDataTable = onlineModel.getModelData()[0]; + LogisticRegressionModelData.getModelDataStream(modelDataTable).addSink(modelDataSink); + } + + /** Blocks the thread until Model has set up init model data. */ + private void waitInitModelDataSetup(JobID jobID) throws InterruptedException { + while (reporter.findMetrics(jobID, MODEL_DATA_VERSION_GAUGE_KEY).size() + < defaultParallelism) { + Thread.sleep(100); + } + waitModelDataUpdate(jobID); + } + + /** Blocks the thread until the Model has received the next model-data-update event. */ + @SuppressWarnings("unchecked") + private void waitModelDataUpdate(JobID jobID) throws InterruptedException { + do { + long tmpModelDataVersion = + reporter.findMetrics(jobID, MODEL_DATA_VERSION_GAUGE_KEY).values().stream() + .map(x -> Long.parseLong(((Gauge<String>) x).getValue())) + .min(Long::compareTo) + .get(); + if (tmpModelDataVersion == currentModelDataVersion) { + Thread.sleep(100); + } else { + currentModelDataVersion = tmpModelDataVersion; + break; + } + } while (true); + } +
+
+    /**
+     * Inserts the default predict data into the predict queue, fetches the prediction results,
+     * and asserts that the raw prediction of every row is as expected.
+     *
+     * @param expectedRawInfo The expected raw prediction vectors, in any order. The list is
+     *     sorted in place.
+     * @param isSparse Whether the sparse predict source and rows are used instead of the dense
+     *     ones.
+     */
+    private void predictAndAssert(List<DenseVector> expectedRawInfo, boolean isSparse)
+            throws Exception {
+        Row[] predictRows = isSparse ? PREDICT_SPARSE_ROWS : PREDICT_DENSE_ROWS;
+        if (isSparse) {
+            predictSparseSource.addAll(predictRows);
+        } else {
+            predictDenseSource.addAll(predictRows);
+        }
+        List<Row> rawResult = outputSink.poll(predictRows.length);
+        List<DenseVector> resultDetail = new ArrayList<>(rawResult.size());
+        for (Row row : rawResult) {
+            // Field 3 is the raw prediction: features, label and prediction precede it.
+            resultDetail.add(row.getFieldAs(3));
+        }
+        resultDetail.sort(TestUtils::compare);
+        expectedRawInfo.sort(TestUtils::compare);
+        // Fail clearly on a count mismatch instead of skipping rows or running out of bounds.
+        Assert.assertEquals(expectedRawInfo.size(), resultDetail.size());
+        for (int i = 0; i < resultDetail.size(); ++i) {
+            double[] realData = resultDetail.get(i).values;
+            double[] expectedData = expectedRawInfo.get(i).values;
+            Assert.assertArrayEquals(expectedData, realData, 1.0e-5);
+        }
+    }
+ + private JobID submitJob(JobGraph jobGraph) + throws ExecutionException, InterruptedException, TimeoutException { + return miniCluster + .submitJob(jobGraph) + .thenApply(JobSubmissionResult::getJobID) + .get(1, TimeUnit.SECONDS); + } + + @Test + public void testParam() { + OnlineLogisticRegression onlineLogisticRegression = new OnlineLogisticRegression(); + Assert.assertEquals("features", onlineLogisticRegression.getFeaturesCol()); + Assert.assertEquals("count", onlineLogisticRegression.getBatchStrategy()); + Assert.assertEquals("label", onlineLogisticRegression.getLabelCol()); + Assert.assertEquals(0.0, onlineLogisticRegression.getReg(), 1.0e-5); + Assert.assertEquals(0.0, onlineLogisticRegression.getElasticNet(), 1.0e-5); + Assert.assertEquals(0.1, onlineLogisticRegression.getAlpha(), 1.0e-5); + Assert.assertEquals(0.1, onlineLogisticRegression.getBeta(), 1.0e-5); + Assert.assertEquals(32, onlineLogisticRegression.getGlobalBatchSize()); + + onlineLogisticRegression + .setFeaturesCol("test_feature") + .setLabelCol("test_label") + .setGlobalBatchSize(5) + .setReg(0.5) + .setElasticNet(0.25) + .setAlpha(0.1) + .setBeta(0.2); + + Assert.assertEquals("test_feature", onlineLogisticRegression.getFeaturesCol()); + Assert.assertEquals("test_label", onlineLogisticRegression.getLabelCol()); + Assert.assertEquals(0.5, onlineLogisticRegression.getReg(), 1.0e-5); + Assert.assertEquals(0.25, 
onlineLogisticRegression.getElasticNet(), 1.0e-5); + Assert.assertEquals(0.1, onlineLogisticRegression.getAlpha(), 1.0e-5); + Assert.assertEquals(0.2, onlineLogisticRegression.getBeta(), 1.0e-5); + Assert.assertEquals(5, onlineLogisticRegression.getGlobalBatchSize()); + + OnlineLogisticRegressionModel onlineLogisticRegressionModel = + new OnlineLogisticRegressionModel(); + Assert.assertEquals("features", onlineLogisticRegressionModel.getFeaturesCol()); + Assert.assertEquals("modelVersion", onlineLogisticRegressionModel.getModelVersionCol()); + Assert.assertEquals("prediction", onlineLogisticRegressionModel.getPredictionCol()); + Assert.assertEquals("rawPrediction", onlineLogisticRegressionModel.getRawPredictionCol()); + + onlineLogisticRegressionModel + .setFeaturesCol("test_feature") + .setPredictionCol("pred") + .setModelVersionCol("version") + .setRawPredictionCol("raw"); + + Assert.assertEquals("test_feature", onlineLogisticRegressionModel.getFeaturesCol()); + Assert.assertEquals("version", onlineLogisticRegressionModel.getModelVersionCol()); + Assert.assertEquals("pred", onlineLogisticRegressionModel.getPredictionCol()); + Assert.assertEquals("raw", onlineLogisticRegressionModel.getRawPredictionCol()); + } + + @Test + public void testDenseFitAndPredict() throws Exception { + final List<DenseVector> expectedRawInfo1 = + Arrays.asList( + new DenseVector(new double[] {0.04481034155642882, 0.9551896584435712}), + new DenseVector(new double[] {0.5353966697318491, 0.4646033302681509})); + final List<DenseVector> expectedRawInfo2 = + Arrays.asList( + new DenseVector(new double[] {0.013104324065967066, 0.9868956759340329}), + new DenseVector(new double[] {0.5095144380001769, 0.49048556199982307})); + OnlineLogisticRegression onlineLogisticRegression = + new OnlineLogisticRegression() + .setFeaturesCol("features") + .setLabelCol("label") + .setPredictionCol("prediction") + .setReg(0.2) + .setElasticNet(0.5) + .setGlobalBatchSize(10) + 
.setInitialModelData(initDenseModel); + OnlineLogisticRegressionModel onlineModel = + onlineLogisticRegression.fit(onlineTrainDenseTable); + transformAndOutputData(onlineModel, false); + + final JobID jobID = submitJob(env.getStreamGraph().getJobGraph()); + trainDenseSource.addAll(TRAIN_DENSE_ROWS_1); + + waitInitModelDataSetup(jobID); + predictAndAssert(expectedRawInfo1, false); + + trainDenseSource.addAll(TRAIN_DENSE_ROWS_2); + waitModelDataUpdate(jobID); + predictAndAssert(expectedRawInfo2, false); + } + + @Test + public void testSparseFitAndPredict() throws Exception { + final List<DenseVector> expectedRawInfo1 = + Arrays.asList( + new DenseVector(new double[] {0.4452309884735286, 0.5547690115264714}), + new DenseVector(new double[] {0.5105551725414953, 0.4894448274585047})); + final List<DenseVector> expectedRawInfo2 = + Arrays.asList( + new DenseVector(new double[] {0.40310431554310666, 0.5968956844568933}), + new DenseVector(new double[] {0.5249618837373886, 0.4750381162626114})); + OnlineLogisticRegression onlineLogisticRegression = + new OnlineLogisticRegression() + .setFeaturesCol("features") + .setLabelCol("label") + .setPredictionCol("prediction") + .setReg(0.2) + .setElasticNet(0.5) + .setGlobalBatchSize(9) + .setInitialModelData(initSparseModel); + OnlineLogisticRegressionModel onlineModel = + onlineLogisticRegression.fit(onlineTrainSparseTable); + transformAndOutputData(onlineModel, true); + final JobID jobID = submitJob(env.getStreamGraph().getJobGraph()); + + trainSparseSource.addAll(TRAIN_SPARSE_ROWS_1); + waitInitModelDataSetup(jobID); + predictAndAssert(expectedRawInfo1, true); + + trainSparseSource.addAll(TRAIN_SPARSE_ROWS_2); + waitModelDataUpdate(jobID); + predictAndAssert(expectedRawInfo2, true); + } + + @Test + public void testFitAndPredictWithWeightCol() throws Exception { + final List<DenseVector> expectedRawInfo1 = + Arrays.asList( + new DenseVector(new double[] {0.452491993753382, 0.547508006246618}), + new DenseVector(new double[] 
{0.5069192929506545, 0.4930807070493455})); + final List<DenseVector> expectedRawInfo2 = + Arrays.asList( + new DenseVector(new double[] {0.41108882806164193, 0.5889111719383581}), + new DenseVector(new double[] {0.5247727600974581, 0.4752272399025419})); + OnlineLogisticRegression onlineLogisticRegression = + new OnlineLogisticRegression() + .setFeaturesCol("features") + .setLabelCol("label") + .setWeightCol("weight") + .setPredictionCol("prediction") + .setReg(0.2) + .setElasticNet(0.5) + .setGlobalBatchSize(9) + .setInitialModelData(initSparseModel); + OnlineLogisticRegressionModel onlineModel = + onlineLogisticRegression.fit(onlineTrainSparseTable); + transformAndOutputData(onlineModel, true); + final JobID jobID = submitJob(env.getStreamGraph().getJobGraph()); + + trainSparseSource.addAll(TRAIN_SPARSE_ROWS_1); + waitInitModelDataSetup(jobID); + predictAndAssert(expectedRawInfo1, true); + + trainSparseSource.addAll(TRAIN_SPARSE_ROWS_2); + waitModelDataUpdate(jobID); + predictAndAssert(expectedRawInfo2, true); + } + + @Test + public void testGenerateRandomModelData() throws Exception { + Table modelDataTable = LogisticRegressionModelData.generateRandomModelData(tEnv, 2, 2022); + DataStream<Row> modelData = tEnv.toDataStream(modelDataTable); + Row modelRow = (Row) IteratorUtils.toList(modelData.executeAndCollect()).get(0); + Assert.assertEquals(2, ((DenseVector) modelRow.getField(0)).size()); + Assert.assertEquals(0L, modelRow.getField(1)); + } + + @Test + public void testInitWithLogisticRegression() throws Exception { + final List<DenseVector> expectedRawInfo1 = + Arrays.asList( + new DenseVector(new double[] {0.037327343811250024, 0.96267265618875}), + new DenseVector(new double[] {0.5684728224189707, 0.4315271775810293})); + final List<DenseVector> expectedRawInfo2 = + Arrays.asList( + new DenseVector(new double[] {0.007758574555505882, 0.9922414254444941}), + new DenseVector(new double[] {0.5257216567388069, 0.4742783432611931})); + LogisticRegression 
logisticRegression = + new LogisticRegression() + .setLabelCol("label") + .setFeaturesCol("features") + .setPredictionCol("prediction"); + LogisticRegressionModel model = logisticRegression.fit(offlineTrainDenseTable); + + OnlineLogisticRegression onlineLogisticRegression = + new OnlineLogisticRegression() + .setFeaturesCol("features") + .setPredictionCol("prediction") + .setReg(0.2) + .setElasticNet(0.5) + .setGlobalBatchSize(10) + .setInitialModelData(model.getModelData()[0]); + + OnlineLogisticRegressionModel onlineModel = + onlineLogisticRegression.fit(onlineTrainDenseTable); + transformAndOutputData(onlineModel, false); + final JobID jobID = submitJob(env.getStreamGraph().getJobGraph()); + trainDenseSource.addAll(TRAIN_DENSE_ROWS_1); + waitInitModelDataSetup(jobID); + predictAndAssert(expectedRawInfo1, false); + + trainDenseSource.addAll(TRAIN_DENSE_ROWS_2); + waitModelDataUpdate(jobID); + predictAndAssert(expectedRawInfo2, false); + } + + @Test + public void testBatchSizeLessThanParallelism() { + try { + env.setParallelism(defaultParallelism); + trainDenseSource.addAll(TRAIN_DENSE_ROWS_1); + new OnlineLogisticRegression() + .setInitialModelData(initDenseModel) + .setReg(0.2) + .setElasticNet(0.5) + .setGlobalBatchSize(2) + .setLabelCol("label") + .fit(onlineTrainDenseTable); + Assert.fail("Expected IllegalStateException"); + } catch (Exception e) { + Throwable exception = e; + while (exception.getCause() != null) { + exception = exception.getCause(); + } + Assert.assertEquals(IllegalStateException.class, exception.getClass()); + Assert.assertEquals( + "There are more subtasks in the training process than the number " + + "of elements in each batch. 
Some subtasks might be idling forever.", + exception.getMessage()); + } + } + + @Test + public void testSaveAndReload() throws Exception { + final List<DenseVector> expectedRawInfo1 = + Arrays.asList( + new DenseVector(new double[] {0.04481034155642882, 0.9551896584435712}), + new DenseVector(new double[] {0.5353966697318491, 0.4646033302681509})); + final List<DenseVector> expectedRawInfo2 = + Arrays.asList( + new DenseVector(new double[] {0.013104324065967066, 0.9868956759340329}), + new DenseVector(new double[] {0.5095144380001769, 0.49048556199982307})); + OnlineLogisticRegression onlineLogisticRegression = + new OnlineLogisticRegression() + .setFeaturesCol("features") + .setPredictionCol("prediction") + .setReg(0.2) + .setElasticNet(0.5) + .setGlobalBatchSize(10) + .setInitialModelData(initDenseModel); + + String savePath = tempFolder.newFolder().getAbsolutePath(); + onlineLogisticRegression.save(savePath); + miniCluster.executeJobBlocking(env.getStreamGraph().getJobGraph()); + OnlineLogisticRegression loadedOnlineLogisticRegression = + OnlineLogisticRegression.load(tEnv, savePath); + OnlineLogisticRegressionModel onlineModel = + loadedOnlineLogisticRegression.fit(onlineTrainDenseTable); + String modelSavePath = tempFolder.newFolder().getAbsolutePath(); + onlineModel.save(modelSavePath); + OnlineLogisticRegressionModel loadedOnlineModel = + OnlineLogisticRegressionModel.load(tEnv, modelSavePath); + loadedOnlineModel.setModelData(onlineModel.getModelData()); + + transformAndOutputData(loadedOnlineModel, false); + final JobID jobID = submitJob(env.getStreamGraph().getJobGraph()); + + trainDenseSource.addAll(TRAIN_DENSE_ROWS_1); + waitInitModelDataSetup(jobID); + predictAndAssert(expectedRawInfo1, false); + + trainDenseSource.addAll(TRAIN_DENSE_ROWS_2); + waitModelDataUpdate(jobID); + predictAndAssert(expectedRawInfo2, false); + } + + @Test + public void testGetModelData() throws Exception { + OnlineLogisticRegression onlineLogisticRegression = + new 
OnlineLogisticRegression() + .setFeaturesCol("features") + .setPredictionCol("prediction") + .setReg(0.2) + .setElasticNet(0.5) + .setGlobalBatchSize(10) + .setInitialModelData(initDenseModel); + OnlineLogisticRegressionModel onlineModel = + onlineLogisticRegression.fit(onlineTrainDenseTable); + transformAndOutputData(onlineModel, false); + + submitJob(env.getStreamGraph().getJobGraph()); + trainDenseSource.addAll(TRAIN_DENSE_ROWS_1); + LogisticRegressionModelData actualModelData = modelDataSink.poll(); + + LogisticRegressionModelData expectedModelData = + new LogisticRegressionModelData( + new DenseVector(new double[] {0.2994527071464283, -0.1412541067743284}), + 1L); + Assert.assertArrayEquals( + expectedModelData.coefficient.values, actualModelData.coefficient.values, 1e-5); + Assert.assertEquals(expectedModelData.modelVersion, actualModelData.modelVersion); + } + + @Test + public void testSetModelData() throws Exception { + LogisticRegressionModelData modelData1 = + new LogisticRegressionModelData(new DenseVector(new double[] {0.085, -0.22}), 1L); + + LogisticRegressionModelData modelData2 = + new LogisticRegressionModelData(new DenseVector(new double[] {0.075, -0.28}), 2L); + + final List<DenseVector> expectedRawInfo1 = + Arrays.asList( + new DenseVector(new double[] {0.6285496932692606, 0.3714503067307394}), + new DenseVector(new double[] {0.7588710471221473, 0.24112895287785274})); + final List<DenseVector> expectedRawInfo2 = + Arrays.asList( + new DenseVector(new double[] {0.6673003248270917, 0.3326996751729083}), + new DenseVector(new double[] {0.8779865510655934, 0.12201344893440658})); + + InMemorySourceFunction<LogisticRegressionModelData> modelDataSource = + new InMemorySourceFunction<>(); + Table modelDataTable = + tEnv.fromDataStream( + env.addSource( + modelDataSource, + TypeInformation.of(LogisticRegressionModelData.class))); + + OnlineLogisticRegressionModel onlineModel = + new OnlineLogisticRegressionModel() + .setModelData(modelDataTable) + 
.setFeaturesCol("features") + .setPredictionCol("prediction"); + transformAndOutputData(onlineModel, false); + final JobID jobID = submitJob(env.getStreamGraph().getJobGraph()); + + modelDataSource.addAll(modelData1); + waitInitModelDataSetup(jobID); + predictAndAssert(expectedRawInfo1, false); + + modelDataSource.addAll(modelData2); + waitModelDataUpdate(jobID); + predictAndAssert(expectedRawInfo2, false); + } +} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java index 1a0cfc4..ebf7ab0 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java @@ -47,6 +47,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import static org.apache.flink.test.util.TestBaseUtils.compareResultCollections; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; @@ -79,17 +80,6 @@ public class MinMaxScalerTest { Vectors.dense(0.5, 0.125), Vectors.dense(0.75, 0.225))); - /** Note: this comparator imposes orderings that are inconsistent with equals. 
*/ - private static int compare(DenseVector first, DenseVector second) { - for (int i = 0; i < first.size(); i++) { - int cmp = Double.compare(first.get(i), second.get(i)); - if (cmp != 0) { - return cmp; - } - } - return 0; - } - @Before public void before() { Configuration config = new Configuration(); @@ -113,8 +103,7 @@ public class MinMaxScalerTest { (MapFunction<Row, DenseVector>) row -> (DenseVector) row.getField(outputCol)); List<DenseVector> result = IteratorUtils.toList(stream.executeAndCollect()); - result.sort(MinMaxScalerTest::compare); - assertEquals(expected, result); + compareResultCollections(expected, result, TestUtils::compare); } @Test diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/util/TestUtils.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/util/TestUtils.java index 171aa66..5250588 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/util/TestUtils.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/util/TestUtils.java @@ -24,6 +24,7 @@ import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.ml.api.Stage; import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.DenseVector; import org.apache.flink.ml.linalg.Vector; import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo; @@ -33,6 +34,7 @@ import org.apache.flink.table.api.Table; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; import org.apache.flink.table.types.DataType; import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; import org.apache.commons.lang3.ArrayUtils; @@ -114,4 +116,16 @@ public class TestUtils { .map(DataType::getConversionClass) .toArray(Class<?>[]::new); } + + /** Note: this comparator imposes orderings that are inconsistent with equals. 
*/ + public static int compare(DenseVector first, DenseVector second) { + Preconditions.checkArgument(first.size() == second.size(), "Vector size mismatched."); + for (int i = 0; i < first.size(); i++) { + int cmp = Double.compare(first.get(i), second.get(i)); + if (cmp != 0) { + return cmp; + } + } + return 0; + } } diff --git a/flink-ml-python/pyflink/ml/lib/classification/logisticregression.py b/flink-ml-python/pyflink/ml/lib/classification/logisticregression.py index 60a49f2..fb3adef 100644 --- a/flink-ml-python/pyflink/ml/lib/classification/logisticregression.py +++ b/flink-ml-python/pyflink/ml/lib/classification/logisticregression.py @@ -17,12 +17,14 @@ ################################################################################ from abc import ABC +from pyflink.ml.core.param import (ParamValidators, Param, StringParam, FloatParam) from pyflink.ml.core.wrapper import JavaWithParams from pyflink.ml.lib.classification.common import (JavaClassificationModel, JavaClassificationEstimator) from pyflink.ml.lib.param import (HasWeightCol, HasMaxIter, HasReg, HasLearningRate, HasGlobalBatchSize, HasTol, HasMultiClass, HasFeaturesCol, - HasPredictionCol, HasRawPredictionCol, HasLabelCol, HasElasticNet) + HasPredictionCol, HasRawPredictionCol, HasLabelCol, + HasBatchStrategy, HasElasticNet) class _LogisticRegressionModelParams( @@ -98,3 +100,120 @@ class LogisticRegression(JavaClassificationEstimator, _LogisticRegressionParams) @classmethod def _java_estimator_class_name(cls) -> str: return "LogisticRegression" + + +class _OnlineLogisticRegressionModelParams( + JavaWithParams, + HasFeaturesCol, + HasPredictionCol, + HasRawPredictionCol, + ABC +): + """ + Params for :class:`OnlineLogisticRegressionModel`. 
+ """ + MODEL_VERSION_COL: Param[str] = StringParam( + "model_version_col", + "Model version column name.", + "model_version", + ParamValidators.not_null()) + + def __init__(self, java_params): + super(_OnlineLogisticRegressionModelParams, self).__init__(java_params) + + def set_model_version_col(self, value: str): + return self.set(self.MODEL_VERSION_COL, value) + + def get_model_version_col(self) -> str: + return self.get(self.MODEL_VERSION_COL) + + +class _OnlineLogisticRegressionParams( + _OnlineLogisticRegressionModelParams, + HasBatchStrategy, + HasLabelCol, + HasWeightCol, + HasReg, + HasElasticNet, + HasGlobalBatchSize +): + """ + Params for :class:`OnlineLogisticRegression`. + """ + + ALPHA: Param[float] = FloatParam( + "alpha", + "The alpha parameter of ftrl.", + 0.1, + ParamValidators.gt(0)) + + BETA: Param[float] = FloatParam( + "beta", + "The beta parameter of ftrl.", + 0.1, + ParamValidators.gt(0)) + + def __init__(self, java_params): + super(_OnlineLogisticRegressionParams, self).__init__(java_params) + + def set_alpha(self, alpha: float): + return self.set(self.ALPHA, alpha) + + def get_alpha(self) -> float: + return self.get(self.ALPHA) + + @property + def alpha(self) -> float: + return self.get_alpha() + + def set_beta(self, beta: float): + return self.set(self.BETA, beta) + + def get_beta(self) -> float: + return self.get(self.BETA) + + @property + def beta(self) -> float: + return self.get_beta() + + +class OnlineLogisticRegressionModel(JavaClassificationModel, + _OnlineLogisticRegressionModelParams): + """ + A Model which classifies data using the model data computed by + :class:`OnlineLogisticRegression`. 
+ """ + + def __init__(self, java_model=None): + super(OnlineLogisticRegressionModel, self).__init__(java_model) + + @classmethod + def _java_model_package_name(cls) -> str: + return "logisticregression" + + @classmethod + def _java_model_class_name(cls) -> str: + return "OnlineLogisticRegressionModel" + + +class OnlineLogisticRegression(JavaClassificationEstimator, _OnlineLogisticRegressionParams): + """ + An Estimator which implements the online logistic regression algorithm. + + See H. Brendan McMahan et al., Ad click prediction: a view from the trenches. + """ + + def __init__(self): + super(OnlineLogisticRegression, self).__init__() + + @classmethod + def _create_model(cls, java_model) -> OnlineLogisticRegressionModel: + return OnlineLogisticRegressionModel(java_model) + + @classmethod + def _java_estimator_package_name(cls) -> str: + return "logisticregression" + + @classmethod + def _java_estimator_class_name(cls) -> str: + return "OnlineLogisticRegression" diff --git a/flink-ml-python/pyflink/ml/lib/classification/tests/test_logisticregression.py b/flink-ml-python/pyflink/ml/lib/classification/tests/test_logisticregression.py index 2f2f8bb..e460155 100644 --- a/flink-ml-python/pyflink/ml/lib/classification/tests/test_logisticregression.py +++ b/flink-ml-python/pyflink/ml/lib/classification/tests/test_logisticregression.py @@ -22,7 +22,7 @@ from pyflink.table import Table from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo, DenseVector from pyflink.ml.lib.classification.logisticregression import LogisticRegression, \ - LogisticRegressionModel + LogisticRegressionModel, OnlineLogisticRegression from pyflink.ml.tests.test_utils import PyFlinkMLTestCase @@ -135,7 +135,9 @@ class LogisticRegressionTest(PyFlinkMLTestCase): regression.save(path) regression = LogisticRegression.load(self.t_env, path) # type: LogisticRegression model = regression.fit(self.binomial_data_table) - self.assertEqual(model.get_model_data()[0].get_schema().get_field_names(), 
['coefficient']) + self.assertEqual( + model.get_model_data()[0].get_schema().get_field_names(), + ['coefficient', 'modelVersion']) output = model.transform(self.binomial_data_table)[0] field_names = output.get_schema().get_field_names() self.verify_predict_result( @@ -183,3 +185,38 @@ class LogisticRegressionTest(PyFlinkMLTestCase): else: self.assertAlmostEqual(1, prediction, delta=1e-7) self.assertTrue(raw_prediction.get(0) < 0.5) + + +class OnlineLogisticRegressionTest(PyFlinkMLTestCase): + + def setUp(self): + super(OnlineLogisticRegressionTest, self).setUp() + + def test_param(self): + online_logistic_regression = OnlineLogisticRegression() + self.assertEqual("features", online_logistic_regression.features_col) + self.assertEqual("count", online_logistic_regression.batch_strategy) + self.assertEqual("label", online_logistic_regression.label_col) + self.assertEqual(None, online_logistic_regression.weight_col) + self.assertEqual(0.0, online_logistic_regression.reg) + self.assertEqual(0.0, online_logistic_regression.elastic_net) + self.assertEqual(0.1, online_logistic_regression.alpha) + self.assertEqual(0.1, online_logistic_regression.beta) + self.assertEqual(32, online_logistic_regression.global_batch_size) + + online_logistic_regression \ + .set_features_col("test_feature") \ + .set_label_col("test_label") \ + .set_global_batch_size(5) \ + .set_reg(0.5) \ + .set_elastic_net(0.25) \ + .set_alpha(0.1) \ + .set_beta(0.2) + + self.assertEqual("test_feature", online_logistic_regression.features_col) + self.assertEqual("test_label", online_logistic_regression.label_col) + self.assertEqual(0.5, online_logistic_regression.reg) + self.assertEqual(0.25, online_logistic_regression.elastic_net) + self.assertEqual(0.1, online_logistic_regression.alpha) + self.assertEqual(0.2, online_logistic_regression.beta) + self.assertEqual(5, online_logistic_regression.global_batch_size)