This is an automated email from the ASF dual-hosted git repository. lindong pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink-ml.git
commit adc1898920f2bf715efaa9fc7e343af8a61a49b6 Author: jiangxin <jiangxin.ji...@alibaba-inc.com> AuthorDate: Tue Mar 14 09:46:28 2023 +0800 [FLINK-31422] Add Servable for Logistic Regression Model --- flink-ml-benchmark/pom.xml | 7 ++ .../java/org/apache/flink/ml/util/TestUtils.java | 19 ++++ flink-ml-lib/pom.xml | 14 +++ .../logisticregression/LogisticRegression.java | 2 +- .../LogisticRegressionModel.java | 51 ++++----- ...a.java => LogisticRegressionModelDataUtil.java} | 60 ++++++----- .../OnlineLogisticRegression.java | 8 +- .../OnlineLogisticRegressionModel.java | 29 ++++-- .../ml/classification/LogisticRegressionTest.java | 84 ++++++++++++++- .../OnlineLogisticRegressionTest.java | 6 +- flink-ml-servable-lib/pom.xml | 66 ++++++++++++ .../LogisticRegressionModelData.java | 76 ++++++++++++++ .../LogisticRegressionModelParams.java | 2 +- .../LogisticRegressionModelServable.java | 116 +++++++++++++++++++++ flink-ml-uber/pom.xml | 7 ++ pom.xml | 1 + tools/ci/stage.sh | 2 + 17 files changed, 469 insertions(+), 81 deletions(-) diff --git a/flink-ml-benchmark/pom.xml b/flink-ml-benchmark/pom.xml index ef6269ca..f584a3e3 100644 --- a/flink-ml-benchmark/pom.xml +++ b/flink-ml-benchmark/pom.xml @@ -44,6 +44,13 @@ under the License. <scope>provided</scope> </dependency> + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-ml-servable-lib</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + <dependency> <groupId>org.apache.flink</groupId> <artifactId>flink-ml-core</artifactId> diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java b/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java index 78a1fa34..ec97b48c 100644 --- a/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java +++ b/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java @@ -37,6 +37,7 @@ import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.linalg.DenseVector; import org.apache.flink.ml.linalg.Vector; import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo; +import org.apache.flink.ml.servable.api.DataFrame; import org.apache.flink.ml.servable.api.TransformerServable; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -56,6 +57,7 @@ import java.io.DataOutputStream; import java.io.EOFException; import java.io.IOException; import java.io.OutputStream; +import java.util.ArrayList; import java.util.Comparator; import java.util.List; @@ -305,4 +307,21 @@ public class TestUtils { } return 0; } + + /** Construct DataFrame from a list of Flink {@link Row}s. */ + public static DataFrame constructDataFrame( + List<String> columnNames, + List<org.apache.flink.ml.servable.types.DataType> dataTypes, + List<Row> rows) { + List<org.apache.flink.ml.servable.api.Row> rowList = new ArrayList<>(); + for (Row row : rows) { + List<Object> values = new ArrayList<>(); + for (int i = 0; i < row.getArity(); i++) { + Object value = row.getField(i); + values.add(value); + } + rowList.add(new org.apache.flink.ml.servable.api.Row(values)); + } + return new DataFrame(columnNames, dataTypes, rowList); + } } diff --git a/flink-ml-lib/pom.xml b/flink-ml-lib/pom.xml index 777c6b98..977e8e8c 100644 --- a/flink-ml-lib/pom.xml +++ b/flink-ml-lib/pom.xml @@ -37,6 +37,13 @@ under the License. <scope>provided</scope> </dependency> + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-ml-servable-lib</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + <dependency> <groupId>org.apache.flink</groupId> <artifactId>flink-ml-core</artifactId> @@ -124,6 +131,13 @@ under the License. <scope>test</scope> <type>test-jar</type> </dependency> + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-ml-servable-lib</artifactId> + <version>${project.version}</version> + <scope>test</scope> + <type>test-jar</type> + </dependency> </dependencies> <build> diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java index eeb7338a..87cc650c 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java @@ -115,7 +115,7 @@ public class LogisticRegression optimizer.optimize(initModelData, trainData, BinaryLogisticLoss.INSTANCE); DataStream<LogisticRegressionModelData> modelData = - rawModelData.map(vector -> new LogisticRegressionModelData(vector, 0)); + rawModelData.map(vector -> new LogisticRegressionModelData(vector, 0L)); LogisticRegressionModel model = new LogisticRegressionModel().setModelData(tEnv.fromDataStream(modelData)); ParamUtils.updateExistingParams(model, paramMap); diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java index 675846a6..e777c5fa 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java @@ -21,14 +21,13 @@ package org.apache.flink.ml.classification.logisticregression; import org.apache.flink.api.common.functions.RichMapFunction; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.ml.api.Model; import org.apache.flink.ml.common.broadcast.BroadcastUtils; import org.apache.flink.ml.common.datastream.TableUtils; -import org.apache.flink.ml.linalg.BLAS; import org.apache.flink.ml.linalg.DenseVector; import org.apache.flink.ml.linalg.Vector; -import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.param.Param; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.ReadWriteUtils; @@ -68,7 +67,7 @@ public class LogisticRegressionModel DataStream<Row> inputStream = tEnv.toDataStream(inputs[0]); final String broadcastModelKey = "broadcastModelKey"; DataStream<LogisticRegressionModelData> modelDataStream = - LogisticRegressionModelData.getModelDataStream(modelDataTable); + LogisticRegressionModelDataUtil.getModelDataStream(modelDataTable); RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); RowTypeInfo outputTypeInfo = new RowTypeInfo( @@ -87,7 +86,7 @@ public class LogisticRegressionModel inputList -> { DataStream inputData = inputList.get(0); return inputData.map( - new PredictLabelFunction(broadcastModelKey, getFeaturesCol()), + new PredictLabelFunction(broadcastModelKey, paramMap), outputTypeInfo); }); return new Table[] {tEnv.fromDataStream(predictionResult)}; @@ -108,9 +107,9 @@ public class LogisticRegressionModel public void save(String path) throws IOException { ReadWriteUtils.saveMetadata(this, path); ReadWriteUtils.saveModelData( - LogisticRegressionModelData.getModelDataStream(modelDataTable), + LogisticRegressionModelDataUtil.getModelDataStream(modelDataTable), path, - new LogisticRegressionModelData.ModelDataEncoder()); + new LogisticRegressionModelDataUtil.ModelDataEncoder()); } public static LogisticRegressionModel load(StreamTableEnvironment tEnv, String path) @@ -118,10 +117,14 @@ public class LogisticRegressionModel LogisticRegressionModel model = ReadWriteUtils.loadStageParam(path); Table modelDataTable = ReadWriteUtils.loadModelData( - tEnv, path, new LogisticRegressionModelData.ModelDataDecoder()); + tEnv, path, new LogisticRegressionModelDataUtil.ModelDataDecoder()); return model.setModelData(modelDataTable); } + public static LogisticRegressionModelServable loadServable(String path) throws IOException { + return LogisticRegressionModelServable.load(path); + } + @Override public Map<Param<?>, Object> getParamMap() { return paramMap; @@ -132,39 +135,29 @@ public class LogisticRegressionModel private final String broadcastModelKey; - private final String featuresCol; + private final Map<Param<?>, Object> params; - private DenseVector coefficient; + private LogisticRegressionModelServable servable; - public PredictLabelFunction(String broadcastModelKey, String featuresCol) { + public PredictLabelFunction(String broadcastModelKey, Map<Param<?>, Object> params) { this.broadcastModelKey = broadcastModelKey; - this.featuresCol = featuresCol; + this.params = params; } @Override public Row map(Row dataPoint) { - if (coefficient == null) { + if (servable == null) { LogisticRegressionModelData modelData = (LogisticRegressionModelData) getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0); - coefficient = modelData.coefficient; + servable = new LogisticRegressionModelServable(modelData); + ParamUtils.updateExistingParams(servable, params); } - DenseVector features = ((Vector) dataPoint.getField(featuresCol)).toDense(); - Row predictionResult = predictOneDataPoint(features, coefficient); - return Row.join(dataPoint, predictionResult); - } - } + Vector features = (Vector) dataPoint.getField(servable.getFeaturesCol()); + + Tuple2<Double, DenseVector> predictionResult = servable.transform(features); - /** - * The main logic that predicts one input data point. - * - * @param feature The input feature. - * @param coefficient The model parameters. - * @return The prediction label and the raw probabilities. - */ - protected static Row predictOneDataPoint(Vector feature, DenseVector coefficient) { - double dotValue = BLAS.dot(feature, coefficient); - double prob = 1 - 1.0 / (1.0 + Math.exp(dotValue)); - return Row.of(dotValue >= 0 ? 1. : 0., Vectors.dense(1 - prob, prob)); + return Row.join(dataPoint, Row.of(predictionResult.f0, predictionResult.f1)); + } } } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelDataUtil.java similarity index 74% rename from flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java rename to flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelDataUtil.java index da9bf7c4..e6acb7c7 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelDataUtil.java @@ -24,39 +24,25 @@ import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.configuration.Configuration; import org.apache.flink.connector.file.src.reader.SimpleStreamFormat; import org.apache.flink.core.fs.FSDataInputStream; -import org.apache.flink.core.memory.DataInputViewStreamWrapper; -import org.apache.flink.core.memory.DataOutputViewStreamWrapper; import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.linalg.DenseVector; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; import org.apache.flink.table.api.internal.TableImpl; +import java.io.ByteArrayOutputStream; import java.io.EOFException; import java.io.IOException; import java.io.OutputStream; import java.util.Random; /** - * Model data of {@link LogisticRegressionModel} and {@link OnlineLogisticRegressionModel}. - * - * <p>This class also provides methods to convert model data from Table to Datastream, and classes - * to save/load model data. + * The utility class which provides methods to convert model data from Table to Datastream, and + * classes to save/load model data. */ -public class LogisticRegressionModelData { - - public DenseVector coefficient; - public long modelVersion; - - public LogisticRegressionModelData() {} - - public LogisticRegressionModelData(DenseVector coefficient, long modelVersion) { - this.coefficient = coefficient; - this.modelVersion = modelVersion; - } +public class LogisticRegressionModelDataUtil { /** * Generates a Table containing a {@link LogisticRegressionModelData} instance with randomly @@ -106,17 +92,36 @@ public class LogisticRegressionModelData { .map(x -> new LogisticRegressionModelData(x.getFieldAs(0), x.getFieldAs(1))); } + /** + * Converts the table model to a data stream of bytes. + * + * @param modelDataTable The table of model data. + * @return The data stream of serialized model data. + */ + public static DataStream<byte[]> getModelDataByteStream(Table modelDataTable) { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) modelDataTable).getTableEnvironment(); + + return tEnv.toDataStream(modelDataTable) + .map( + x -> { + LogisticRegressionModelData modelData = + new LogisticRegressionModelData( + x.getFieldAs(0), x.getFieldAs(1)); + + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + modelData.encode(outputStream); + return outputStream.toByteArray(); + }); + } + /** Data encoder for {@link LogisticRegression} and {@link OnlineLogisticRegression}. */ public static class ModelDataEncoder implements Encoder<LogisticRegressionModelData> { - private final DenseVectorSerializer serializer = new DenseVectorSerializer(); @Override public void encode(LogisticRegressionModelData modelData, OutputStream outputStream) throws IOException { - DataOutputViewStreamWrapper dataOutputViewStreamWrapper = - new DataOutputViewStreamWrapper(outputStream); - serializer.serialize(modelData.coefficient, dataOutputViewStreamWrapper); - dataOutputViewStreamWrapper.writeLong(modelData.modelVersion); + modelData.encode(outputStream); } } @@ -127,17 +132,10 @@ public class LogisticRegressionModelData { public Reader<LogisticRegressionModelData> createReader( Configuration configuration, FSDataInputStream inputStream) { return new Reader<LogisticRegressionModelData>() { - private final DenseVectorSerializer serializer = new DenseVectorSerializer(); - @Override public LogisticRegressionModelData read() throws IOException { try { - DataInputViewStreamWrapper dataInputViewStreamWrapper = - new DataInputViewStreamWrapper(inputStream); - DenseVector coefficient = - serializer.deserialize(dataInputViewStreamWrapper); - long modelVersion = dataInputViewStreamWrapper.readLong(); - return new LogisticRegressionModelData(coefficient, modelVersion); + return LogisticRegressionModelData.decode(inputStream); } catch (EOFException e) { return null; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java index 79566a74..1bc19938 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java @@ -89,7 +89,7 @@ public class OnlineLogisticRegression StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); DataStream<LogisticRegressionModelData> modelDataStream = - LogisticRegressionModelData.getModelDataStream(initModelDataTable); + LogisticRegressionModelDataUtil.getModelDataStream(initModelDataTable); RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); TypeInformation pointTypeInfo; @@ -413,9 +413,9 @@ public class OnlineLogisticRegression public void save(String path) throws IOException { ReadWriteUtils.saveMetadata(this, path); ReadWriteUtils.saveModelData( - LogisticRegressionModelData.getModelDataStream(initModelDataTable), + LogisticRegressionModelDataUtil.getModelDataStream(initModelDataTable), path, - new LogisticRegressionModelData.ModelDataEncoder()); + new LogisticRegressionModelDataUtil.ModelDataEncoder()); } public static OnlineLogisticRegression load(StreamTableEnvironment tEnv, String path) @@ -423,7 +423,7 @@ public class OnlineLogisticRegression OnlineLogisticRegression onlineLogisticRegression = ReadWriteUtils.loadStageParam(path); Table modelDataTable = ReadWriteUtils.loadModelData( - tEnv, path, new LogisticRegressionModelData.ModelDataDecoder()); + tEnv, path, new LogisticRegressionModelDataUtil.ModelDataDecoder()); onlineLogisticRegression.setInitialModelData(modelDataTable); return onlineLogisticRegression; } diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java index eab5cf63..f0608613 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java @@ -22,6 +22,7 @@ import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.metrics.Gauge; import org.apache.flink.ml.api.Model; @@ -48,8 +49,6 @@ import java.io.IOException; import java.util.HashMap; import java.util.Map; -import static org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel.predictOneDataPoint; - /** * A Model which classifies data using the model data computed by {@link OnlineLogisticRegression}. */ @@ -88,12 +87,12 @@ public class OnlineLogisticRegressionModel DataStream<Row> predictionResult = tEnv.toDataStream(inputs[0]) .connect( - LogisticRegressionModelData.getModelDataStream(modelDataTable) + LogisticRegressionModelDataUtil.getModelDataStream(modelDataTable) .broadcast()) .transform( "PredictLabelOperator", outputTypeInfo, - new PredictLabelOperator(inputTypeInfo, getFeaturesCol())); + new PredictLabelOperator(inputTypeInfo, paramMap)); return new Table[] {tEnv.fromDataStream(predictionResult)}; } @@ -103,14 +102,15 @@ public class OnlineLogisticRegressionModel implements TwoInputStreamOperator<Row, LogisticRegressionModelData, Row> { private final RowTypeInfo inputTypeInfo; - private final String featuresCol; + private final Map<Param<?>, Object> params; private ListState<Row> bufferedPointsState; private DenseVector coefficient; private long modelDataVersion = 0L; + private LogisticRegressionModelServable servable; - public PredictLabelOperator(RowTypeInfo inputTypeInfo, String featuresCol) { + public PredictLabelOperator(RowTypeInfo inputTypeInfo, Map<Param<?>, Object> params) { this.inputTypeInfo = inputTypeInfo; - this.featuresCol = featuresCol; + this.params = params; } @Override @@ -156,15 +156,22 @@ public class OnlineLogisticRegressionModel bufferedPointsState.add(dataPoint); return; } - Vector features = (Vector) dataPoint.getField(featuresCol); - Row predictionResult = predictOneDataPoint(features, coefficient); + if (servable == null) { + servable = + new LogisticRegressionModelServable( + new LogisticRegressionModelData(coefficient, 0L)); + ParamUtils.updateExistingParams(servable, params); + } + Vector features = (Vector) dataPoint.getField(servable.getFeaturesCol()); + Tuple2<Double, DenseVector> predictionResult = servable.transform(features); + output.collect( new StreamRecord<>( Row.join( dataPoint, Row.of( - predictionResult.getField(0), - predictionResult.getField(1), + predictionResult.f0, + predictionResult.f1, modelDataVersion)))); } } diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java index ad9a5416..f899c281 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java @@ -24,11 +24,16 @@ import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.ml.classification.logisticregression.LogisticRegression; import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel; import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelData; +import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelDataUtil; +import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelServable; import org.apache.flink.ml.linalg.DenseVector; import org.apache.flink.ml.linalg.SparseVector; import org.apache.flink.ml.linalg.Vector; import org.apache.flink.ml.linalg.Vectors; import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.servable.api.DataFrame; +import org.apache.flink.ml.servable.types.BasicType; +import org.apache.flink.ml.servable.types.DataTypes; import org.apache.flink.ml.util.ParamUtils; import org.apache.flink.ml.util.TestUtils; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -45,10 +50,13 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import java.io.ByteArrayInputStream; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; +import static org.apache.flink.ml.util.TestUtils.saveAndLoadServable; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; @@ -99,6 +107,8 @@ public class LogisticRegressionTest extends AbstractTestBase { private Table multinomialDataTable; + private DataFrame binomialDataDataFrame; + @Before public void before() { env = TestUtils.getExecutionEnvironment(); @@ -122,6 +132,15 @@ public class LogisticRegressionTest extends AbstractTestBase { DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE }, new String[] {"features", "label", "weight"}))); + binomialDataDataFrame = + TestUtils.constructDataFrame( + new ArrayList<>(Arrays.asList("features", "label", "weight")), + new ArrayList<>( + Arrays.asList( + DataTypes.VECTOR(BasicType.DOUBLE), + DataTypes.DOUBLE, + DataTypes.DOUBLE)), + binomialTrainData); } @SuppressWarnings("ConstantConditions, unchecked") @@ -143,6 +162,26 @@ public class LogisticRegressionTest extends AbstractTestBase { } } + private void verifyPredictionResult( + DataFrame output, String featuresCol, String predictionCol, String rawPredictionCol) { + int featuresColIndex = output.getIndex(featuresCol); + int predictionColIndex = output.getIndex(predictionCol); + int rawPredictionColIndex = output.getIndex(rawPredictionCol); + + for (org.apache.flink.ml.servable.api.Row predictionRow : output.collect()) { + DenseVector feature = ((Vector) predictionRow.get(featuresColIndex)).toDense(); + double prediction = (double) predictionRow.get(predictionColIndex); + DenseVector rawPrediction = (DenseVector) predictionRow.get(rawPredictionColIndex); + if (feature.get(0) <= 5) { + assertEquals(0, prediction, TOLERANCE); + assertTrue(rawPrediction.get(0) > 0.5); + } else { + assertEquals(1, prediction, TOLERANCE); + assertTrue(rawPrediction.get(0) < 0.5); + } + } + } + @Test public void testParam() { LogisticRegression logisticRegression = new LogisticRegression(); @@ -268,7 +307,7 @@ public class LogisticRegressionTest extends AbstractTestBase { LogisticRegressionModel model = logisticRegression.fit(binomialDataTable); List<LogisticRegressionModelData> modelData = IteratorUtils.toList( - LogisticRegressionModelData.getModelDataStream(model.getModelData()[0]) + LogisticRegressionModelDataUtil.getModelDataStream(model.getModelData()[0]) .executeAndCollect()); assertEquals(1, modelData.size()); assertArrayEquals(expectedCoefficient, modelData.get(0).coefficient.values, 0.1); @@ -290,6 +329,47 @@ public class LogisticRegressionTest extends AbstractTestBase { logisticRegression.getRawPredictionCol()); } + @Test + public void testSaveLoadServableAndPredict() throws Exception { + LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight"); + LogisticRegressionModel model = logisticRegression.fit(binomialDataTable); + + LogisticRegressionModelServable servable = + saveAndLoadServable( + tEnv, + model, + tempFolder.newFolder().getAbsolutePath(), + LogisticRegressionModel::loadServable); + + DataFrame output = servable.transform(binomialDataDataFrame); + verifyPredictionResult( + output, + servable.getFeaturesCol(), + servable.getPredictionCol(), + servable.getRawPredictionCol()); + } + + @Test + public void testSetModelDataToServable() throws Exception { + LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight"); + LogisticRegressionModel model = logisticRegression.fit(binomialDataTable); + byte[] serializedModelData = + LogisticRegressionModelDataUtil.getModelDataByteStream(model.getModelData()[0]) + .executeAndCollect() + .next(); + + LogisticRegressionModelServable servable = new LogisticRegressionModelServable(); + ParamUtils.updateExistingParams(servable, model.getParamMap()); + servable.setModelData(new ByteArrayInputStream(serializedModelData)); + + DataFrame output = servable.transform(binomialDataDataFrame); + verifyPredictionResult( + output, + servable.getFeaturesCol(), + servable.getPredictionCol(), + servable.getRawPredictionCol()); + } + @Test public void testMultinomialFit() { try { @@ -349,7 +429,7 @@ public class LogisticRegressionTest extends AbstractTestBase { .fit(binomialDataTable); List<LogisticRegressionModelData> modelData = IteratorUtils.toList( - LogisticRegressionModelData.getModelDataStream(model.getModelData()[0]) + LogisticRegressionModelDataUtil.getModelDataStream(model.getModelData()[0]) .executeAndCollect()); final double errorTol = 1e-3; assertArrayEquals(expectedCoefficient, modelData.get(0).coefficient.values, errorTol); diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java index 956e9472..cac9473c 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java @@ -30,6 +30,7 @@ import org.apache.flink.metrics.Gauge; import org.apache.flink.ml.classification.logisticregression.LogisticRegression; import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel; import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelData; +import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelDataUtil; import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegression; import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel; import org.apache.flink.ml.linalg.DenseVector; @@ -293,7 +294,7 @@ public class OnlineLogisticRegressionTest extends TestLogger { tEnv.toDataStream(outputTable).addSink(outputSink); Table modelDataTable = onlineModel.getModelData()[0]; - LogisticRegressionModelData.getModelDataStream(modelDataTable).addSink(modelDataSink); + LogisticRegressionModelDataUtil.getModelDataStream(modelDataTable).addSink(modelDataSink); } /** Blocks the thread until Model has set up init model data. */ @@ -512,7 +513,8 @@ public class OnlineLogisticRegressionTest extends TestLogger { @Test public void testGenerateRandomModelData() throws Exception { - Table modelDataTable = LogisticRegressionModelData.generateRandomModelData(tEnv, 2, 2022); + Table modelDataTable = + LogisticRegressionModelDataUtil.generateRandomModelData(tEnv, 2, 2022); DataStream<Row> modelData = tEnv.toDataStream(modelDataTable); Row modelRow = (Row) IteratorUtils.toList(modelData.executeAndCollect()).get(0); Assert.assertEquals(2, ((DenseVector) modelRow.getField(0)).size()); diff --git a/flink-ml-servable-lib/pom.xml b/flink-ml-servable-lib/pom.xml new file mode 100644 index 00000000..f556f74b --- /dev/null +++ b/flink-ml-servable-lib/pom.xml @@ -0,0 +1,66 @@ +<?xml version="1.0" encoding="UTF-8"?> +<!-- +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +--> +<project xmlns="http://maven.apache.org/POM/4.0.0" + xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + <modelVersion>4.0.0</modelVersion> + + <parent> + <artifactId>flink-ml-parent</artifactId> + <groupId>org.apache.flink</groupId> + <version>2.2-SNAPSHOT</version> + </parent> + + <artifactId>flink-ml-servable-lib</artifactId> + <name>Flink ML : Servable : Lib</name> + + <dependencies> + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-ml-servable-core</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-core</artifactId> + <version>${flink.version}</version> + <scope>provided</scope> + </dependency> + </dependencies> + + <build> + <plugins> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-jar-plugin</artifactId> + <executions> + <execution> + <goals> + <goal>test-jar</goal> + </goals> + </execution> + </executions> + </plugin> + </plugins> + </build> + +</project> \ No newline at end of file diff --git a/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java b/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java new file mode 100644 index 00000000..28927e47 --- /dev/null +++ b/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.logisticregression; + +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +/** Model data of {@link LogisticRegressionModelServable}. */ +public class LogisticRegressionModelData { + + public DenseVector coefficient; + + public long modelVersion; + + public LogisticRegressionModelData() {} + + public LogisticRegressionModelData(DenseVector coefficient, long modelVersion) { + this.coefficient = coefficient; + this.modelVersion = modelVersion; + } + + /** + * Serializes the instance and writes to the output stream. + * + * @param outputStream The stream to write to. + */ + @VisibleForTesting + public void encode(OutputStream outputStream) throws IOException { + DataOutputViewStreamWrapper dataOutputViewStreamWrapper = + new DataOutputViewStreamWrapper(outputStream); + + DenseVectorSerializer serializer = new DenseVectorSerializer(); + serializer.serialize(coefficient, dataOutputViewStreamWrapper); + dataOutputViewStreamWrapper.writeLong(modelVersion); + } + + /** + * Reads and deserializes the model data from the input stream. + * + * @param inputStream The stream to read from. + * @return The model data instance. + */ + static LogisticRegressionModelData decode(InputStream inputStream) throws IOException { + DataInputViewStreamWrapper dataInputViewStreamWrapper = + new DataInputViewStreamWrapper(inputStream); + + DenseVectorSerializer serializer = new DenseVectorSerializer(); + DenseVector coefficient = serializer.deserialize(dataInputViewStreamWrapper); + long modelVersion = dataInputViewStreamWrapper.readLong(); + + return new LogisticRegressionModelData(coefficient, modelVersion); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelParams.java b/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelParams.java similarity index 94% rename from flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelParams.java rename to flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelParams.java index b15b63e6..800764d5 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelParams.java +++ b/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelParams.java @@ -23,7 +23,7 @@ import org.apache.flink.ml.common.param.HasPredictionCol; import org.apache.flink.ml.common.param.HasRawPredictionCol; /** - * Params for {@link LogisticRegressionModel}. + * Params for LogisticRegressionModel and LogisticRegressionModelServable. * * @param <T> The class type of this instance. */ diff --git a/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java b/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java new file mode 100644 index 00000000..4cec8513 --- /dev/null +++ b/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.logisticregression; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.servable.api.DataFrame; +import org.apache.flink.ml.servable.api.ModelServable; +import org.apache.flink.ml.servable.api.Row; +import org.apache.flink.ml.servable.types.BasicType; +import org.apache.flink.ml.servable.types.DataTypes; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ServableReadWriteUtils; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.io.InputStream; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** A Servable which can be used to classifies data in online inference. */ +public class LogisticRegressionModelServable + implements ModelServable<LogisticRegressionModelServable>, + LogisticRegressionModelParams<LogisticRegressionModelServable> { + + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + private LogisticRegressionModelData modelData; + + public LogisticRegressionModelServable() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + LogisticRegressionModelServable(LogisticRegressionModelData modelData) { + this(); + this.modelData = modelData; + } + + @Override + public DataFrame transform(DataFrame input) { + List<Double> predictionResults = new ArrayList<>(); + List<DenseVector> rawPredictionResults = new ArrayList<>(); + + int featuresColIndex = input.getIndex(getFeaturesCol()); + for (Row row : input.collect()) { + Vector features = (Vector) row.get(featuresColIndex); + Tuple2<Double, DenseVector> dataPoint = transform(features); + predictionResults.add(dataPoint.f0); + rawPredictionResults.add(dataPoint.f1); + } + + input.addColumn(getPredictionCol(), DataTypes.DOUBLE, predictionResults); + input.addColumn( + getRawPredictionCol(), DataTypes.VECTOR(BasicType.DOUBLE), rawPredictionResults); + + return input; + } + + public LogisticRegressionModelServable setModelData(InputStream... modelDataInputs) + throws IOException { + Preconditions.checkArgument(modelDataInputs.length == 1); + + modelData = LogisticRegressionModelData.decode(modelDataInputs[0]); + return this; + } + + public static LogisticRegressionModelServable load(String path) throws IOException { + LogisticRegressionModelServable servable = + ServableReadWriteUtils.loadServableParam( + path, LogisticRegressionModelServable.class); + + try (InputStream modelData = ServableReadWriteUtils.loadModelData(path)) { + servable.setModelData(modelData); + return servable; + } + } + + /** + * The main logic that predicts one input data point. + * + * @param feature The input feature. + * @return The prediction label and the raw probabilities. + */ + protected Tuple2<Double, DenseVector> transform(Vector feature) { + double dotValue = BLAS.dot(feature, modelData.coefficient); + double prob = 1 - 1.0 / (1.0 + Math.exp(dotValue)); + return Tuple2.of(dotValue >= 0 ? 1. : 0., Vectors.dense(1 - prob, prob)); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } +} diff --git a/flink-ml-uber/pom.xml b/flink-ml-uber/pom.xml index 2eedeedf..423373fa 100644 --- a/flink-ml-uber/pom.xml +++ b/flink-ml-uber/pom.xml @@ -53,6 +53,12 @@ under the License. <version>${project.version}</version> </dependency> + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-ml-servable-lib</artifactId> + <version>${project.version}</version> + </dependency> + <dependency> <groupId>org.apache.flink</groupId> <artifactId>flink-ml-lib</artifactId> @@ -84,6 +90,7 @@ under the License. <include>org.apache.flink:flink-ml-servable-core</include> <include>org.apache.flink:flink-ml-core</include> <include>org.apache.flink:flink-ml-iteration</include> + <include>org.apache.flink:flink-ml-servable-lib</include> <include>org.apache.flink:flink-ml-lib</include> <include>org.apache.flink:flink-ml-benchmark</include> <include>dev.ludovic.netlib:blas</include> diff --git a/pom.xml b/pom.xml index 743d246f..2a0af804 100644 --- a/pom.xml +++ b/pom.xml @@ -55,6 +55,7 @@ under the License. <module>flink-ml-servable-core</module> <module>flink-ml-core</module> <module>flink-ml-iteration</module> + <module>flink-ml-servable-lib</module> <module>flink-ml-lib</module> <module>flink-ml-tests</module> <module>flink-ml-uber</module> diff --git a/tools/ci/stage.sh b/tools/ci/stage.sh index c44c023d..4cf36227 100755 --- a/tools/ci/stage.sh +++ b/tools/ci/stage.sh @@ -23,11 +23,13 @@ STAGE_TESTS="tests" STAGE_MISC="misc" MODULES_CORE="\ +flink-ml-servable-core,\ flink-ml-core,\ flink-ml-iteration,\ " MODULES_LIB="\ +flink-ml-servable-lib,\ flink-ml-lib,\ "