This is an automated email from the ASF dual-hosted git repository.

lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git

commit adc1898920f2bf715efaa9fc7e343af8a61a49b6
Author: jiangxin <jiangxin.ji...@alibaba-inc.com>
AuthorDate: Tue Mar 14 09:46:28 2023 +0800

    [FLINK-31422] Add Servable for Logistic Regression Model
---
 flink-ml-benchmark/pom.xml                         |   7 ++
 .../java/org/apache/flink/ml/util/TestUtils.java   |  19 ++++
 flink-ml-lib/pom.xml                               |  14 +++
 .../logisticregression/LogisticRegression.java     |   2 +-
 .../LogisticRegressionModel.java                   |  51 ++++-----
 ...a.java => LogisticRegressionModelDataUtil.java} |  60 ++++++-----
 .../OnlineLogisticRegression.java                  |   8 +-
 .../OnlineLogisticRegressionModel.java             |  29 ++++--
 .../ml/classification/LogisticRegressionTest.java  |  84 ++++++++++++++-
 .../OnlineLogisticRegressionTest.java              |   6 +-
 flink-ml-servable-lib/pom.xml                      |  66 ++++++++++++
 .../LogisticRegressionModelData.java               |  76 ++++++++++++++
 .../LogisticRegressionModelParams.java             |   2 +-
 .../LogisticRegressionModelServable.java           | 116 +++++++++++++++++++++
 flink-ml-uber/pom.xml                              |   7 ++
 pom.xml                                            |   1 +
 tools/ci/stage.sh                                  |   2 +
 17 files changed, 469 insertions(+), 81 deletions(-)

diff --git a/flink-ml-benchmark/pom.xml b/flink-ml-benchmark/pom.xml
index ef6269ca..f584a3e3 100644
--- a/flink-ml-benchmark/pom.xml
+++ b/flink-ml-benchmark/pom.xml
@@ -44,6 +44,13 @@ under the License.
             <scope>provided</scope>
         </dependency>
 
+        <dependency>
+            <groupId>org.apache.flink</groupId>
+            <artifactId>flink-ml-servable-lib</artifactId>
+            <version>${project.version}</version>
+            <scope>provided</scope>
+        </dependency>
+
         <dependency>
             <groupId>org.apache.flink</groupId>
             <artifactId>flink-ml-core</artifactId>
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java b/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java
index 78a1fa34..ec97b48c 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/util/TestUtils.java
@@ -37,6 +37,7 @@ import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.linalg.DenseVector;
 import org.apache.flink.ml.linalg.Vector;
 import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo;
+import org.apache.flink.ml.servable.api.DataFrame;
 import org.apache.flink.ml.servable.api.TransformerServable;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -56,6 +57,7 @@ import java.io.DataOutputStream;
 import java.io.EOFException;
 import java.io.IOException;
 import java.io.OutputStream;
+import java.util.ArrayList;
 import java.util.Comparator;
 import java.util.List;
 
@@ -305,4 +307,21 @@ public class TestUtils {
         }
         return 0;
     }
+
+    /** Construct DataFrame from a list of Flink {@link Row}s. */
+    public static DataFrame constructDataFrame(
+            List<String> columnNames,
+            List<org.apache.flink.ml.servable.types.DataType> dataTypes,
+            List<Row> rows) {
+        List<org.apache.flink.ml.servable.api.Row> rowList = new ArrayList<>();
+        for (Row row : rows) {
+            List<Object> values = new ArrayList<>();
+            for (int i = 0; i < row.getArity(); i++) {
+                Object value = row.getField(i);
+                values.add(value);
+            }
+            rowList.add(new org.apache.flink.ml.servable.api.Row(values));
+        }
+        return new DataFrame(columnNames, dataTypes, rowList);
+    }
 }
diff --git a/flink-ml-lib/pom.xml b/flink-ml-lib/pom.xml
index 777c6b98..977e8e8c 100644
--- a/flink-ml-lib/pom.xml
+++ b/flink-ml-lib/pom.xml
@@ -37,6 +37,13 @@ under the License.
       <scope>provided</scope>
     </dependency>
 
+    <dependency>
+      <groupId>org.apache.flink</groupId>
+      <artifactId>flink-ml-servable-lib</artifactId>
+      <version>${project.version}</version>
+      <scope>provided</scope>
+    </dependency>
+
     <dependency>
       <groupId>org.apache.flink</groupId>
       <artifactId>flink-ml-core</artifactId>
@@ -124,6 +131,13 @@ under the License.
       <scope>test</scope>
       <type>test-jar</type>
     </dependency>
+    <dependency>
+      <groupId>org.apache.flink</groupId>
+      <artifactId>flink-ml-servable-lib</artifactId>
+      <version>${project.version}</version>
+      <scope>test</scope>
+      <type>test-jar</type>
+    </dependency>
   </dependencies>
 
   <build>
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
index eeb7338a..87cc650c 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
@@ -115,7 +115,7 @@ public class LogisticRegression
                 optimizer.optimize(initModelData, trainData, BinaryLogisticLoss.INSTANCE);
 
         DataStream<LogisticRegressionModelData> modelData =
-                rawModelData.map(vector -> new LogisticRegressionModelData(vector, 0));
+                rawModelData.map(vector -> new LogisticRegressionModelData(vector, 0L));
         LogisticRegressionModel model =
                 new LogisticRegressionModel().setModelData(tEnv.fromDataStream(modelData));
         ParamUtils.updateExistingParams(model, paramMap);
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java
index 675846a6..e777c5fa 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java
@@ -21,14 +21,13 @@ package org.apache.flink.ml.classification.logisticregression;
 import org.apache.flink.api.common.functions.RichMapFunction;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.typeutils.RowTypeInfo;
 import org.apache.flink.ml.api.Model;
 import org.apache.flink.ml.common.broadcast.BroadcastUtils;
 import org.apache.flink.ml.common.datastream.TableUtils;
-import org.apache.flink.ml.linalg.BLAS;
 import org.apache.flink.ml.linalg.DenseVector;
 import org.apache.flink.ml.linalg.Vector;
-import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
@@ -68,7 +67,7 @@ public class LogisticRegressionModel
         DataStream<Row> inputStream = tEnv.toDataStream(inputs[0]);
         final String broadcastModelKey = "broadcastModelKey";
         DataStream<LogisticRegressionModelData> modelDataStream =
-                LogisticRegressionModelData.getModelDataStream(modelDataTable);
+                LogisticRegressionModelDataUtil.getModelDataStream(modelDataTable);
         RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
         RowTypeInfo outputTypeInfo =
                 new RowTypeInfo(
@@ -87,7 +86,7 @@ public class LogisticRegressionModel
                         inputList -> {
                             DataStream inputData = inputList.get(0);
                             return inputData.map(
-                                    new PredictLabelFunction(broadcastModelKey, getFeaturesCol()),
+                                    new PredictLabelFunction(broadcastModelKey, paramMap),
                                     outputTypeInfo);
                         });
         return new Table[] {tEnv.fromDataStream(predictionResult)};
@@ -108,9 +107,9 @@ public class LogisticRegressionModel
     public void save(String path) throws IOException {
         ReadWriteUtils.saveMetadata(this, path);
         ReadWriteUtils.saveModelData(
-                LogisticRegressionModelData.getModelDataStream(modelDataTable),
+                LogisticRegressionModelDataUtil.getModelDataStream(modelDataTable),
                 path,
-                new LogisticRegressionModelData.ModelDataEncoder());
+                new LogisticRegressionModelDataUtil.ModelDataEncoder());
     }
 
     public static LogisticRegressionModel load(StreamTableEnvironment tEnv, String path)
@@ -118,10 +117,14 @@ public class LogisticRegressionModel
         LogisticRegressionModel model = ReadWriteUtils.loadStageParam(path);
         Table modelDataTable =
                 ReadWriteUtils.loadModelData(
-                        tEnv, path, new LogisticRegressionModelData.ModelDataDecoder());
+                        tEnv, path, new LogisticRegressionModelDataUtil.ModelDataDecoder());
         return model.setModelData(modelDataTable);
     }
 
+    public static LogisticRegressionModelServable loadServable(String path) throws IOException {
+        return LogisticRegressionModelServable.load(path);
+    }
+
     @Override
     public Map<Param<?>, Object> getParamMap() {
         return paramMap;
@@ -132,39 +135,29 @@ public class LogisticRegressionModel
 
         private final String broadcastModelKey;
 
-        private final String featuresCol;
+        private final Map<Param<?>, Object> params;
 
-        private DenseVector coefficient;
+        private LogisticRegressionModelServable servable;
 
-        public PredictLabelFunction(String broadcastModelKey, String featuresCol) {
+        public PredictLabelFunction(String broadcastModelKey, Map<Param<?>, Object> params) {
             this.broadcastModelKey = broadcastModelKey;
-            this.featuresCol = featuresCol;
+            this.params = params;
         }
 
         @Override
         public Row map(Row dataPoint) {
-            if (coefficient == null) {
+            if (servable == null) {
                 LogisticRegressionModelData modelData =
                         (LogisticRegressionModelData)
                                 getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0);
-                coefficient = modelData.coefficient;
+                servable = new LogisticRegressionModelServable(modelData);
+                ParamUtils.updateExistingParams(servable, params);
             }
-            DenseVector features = ((Vector) dataPoint.getField(featuresCol)).toDense();
-            Row predictionResult = predictOneDataPoint(features, coefficient);
-            return Row.join(dataPoint, predictionResult);
-        }
-    }
+            Vector features = (Vector) dataPoint.getField(servable.getFeaturesCol());
+
+            Tuple2<Double, DenseVector> predictionResult = servable.transform(features);
 
-    /**
-     * The main logic that predicts one input data point.
-     *
-     * @param feature The input feature.
-     * @param coefficient The model parameters.
-     * @return The prediction label and the raw probabilities.
-     */
-    protected static Row predictOneDataPoint(Vector feature, DenseVector coefficient) {
-        double dotValue = BLAS.dot(feature, coefficient);
-        double prob = 1 - 1.0 / (1.0 + Math.exp(dotValue));
-        return Row.of(dotValue >= 0 ? 1. : 0., Vectors.dense(1 - prob, prob));
+            return Row.join(dataPoint, Row.of(predictionResult.f0, predictionResult.f1));
+        }
     }
 }
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelDataUtil.java
similarity index 74%
rename from flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java
rename to flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelDataUtil.java
index da9bf7c4..e6acb7c7 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelDataUtil.java
@@ -24,39 +24,25 @@ import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
 import org.apache.flink.core.fs.FSDataInputStream;
-import org.apache.flink.core.memory.DataInputViewStreamWrapper;
-import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
 import org.apache.flink.table.api.internal.TableImpl;
 
+import java.io.ByteArrayOutputStream;
 import java.io.EOFException;
 import java.io.IOException;
 import java.io.OutputStream;
 import java.util.Random;
 
 /**
- * Model data of {@link LogisticRegressionModel} and {@link OnlineLogisticRegressionModel}.
- *
- * <p>This class also provides methods to convert model data from Table to Datastream, and classes
- * to save/load model data.
+ * The utility class which provides methods to convert model data from Table to Datastream, and
+ * classes to save/load model data.
  */
-public class LogisticRegressionModelData {
-
-    public DenseVector coefficient;
-    public long modelVersion;
-
-    public LogisticRegressionModelData() {}
-
-    public LogisticRegressionModelData(DenseVector coefficient, long modelVersion) {
-        this.coefficient = coefficient;
-        this.modelVersion = modelVersion;
-    }
+public class LogisticRegressionModelDataUtil {
 
     /**
      * Generates a Table containing a {@link LogisticRegressionModelData} instance with randomly
@@ -106,17 +92,36 @@ public class LogisticRegressionModelData {
                 .map(x -> new LogisticRegressionModelData(x.getFieldAs(0), x.getFieldAs(1)));
     }
 
+    /**
+     * Converts the table model to a data stream of bytes.
+     *
+     * @param modelDataTable The table of model data.
+     * @return The data stream of serialized model data.
+     */
+    public static DataStream<byte[]> getModelDataByteStream(Table modelDataTable) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) modelDataTable).getTableEnvironment();
+
+        return tEnv.toDataStream(modelDataTable)
+                .map(
+                        x -> {
+                            LogisticRegressionModelData modelData =
+                                    new LogisticRegressionModelData(
+                                            x.getFieldAs(0), x.getFieldAs(1));
+
+                            ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+                            modelData.encode(outputStream);
+                            return outputStream.toByteArray();
+                        });
+    }
+
     /** Data encoder for {@link LogisticRegression} and {@link OnlineLogisticRegression}. */
     public static class ModelDataEncoder implements Encoder<LogisticRegressionModelData> {
-        private final DenseVectorSerializer serializer = new DenseVectorSerializer();
 
         @Override
         public void encode(LogisticRegressionModelData modelData, OutputStream outputStream)
                 throws IOException {
-            DataOutputViewStreamWrapper dataOutputViewStreamWrapper =
-                    new DataOutputViewStreamWrapper(outputStream);
-            serializer.serialize(modelData.coefficient, dataOutputViewStreamWrapper);
-            dataOutputViewStreamWrapper.writeLong(modelData.modelVersion);
+            modelData.encode(outputStream);
         }
     }
 
@@ -127,17 +132,10 @@ public class LogisticRegressionModelData {
         public Reader<LogisticRegressionModelData> createReader(
                 Configuration configuration, FSDataInputStream inputStream) {
             return new Reader<LogisticRegressionModelData>() {
-                private final DenseVectorSerializer serializer = new DenseVectorSerializer();
-
                 @Override
                 public LogisticRegressionModelData read() throws IOException {
                     try {
-                        DataInputViewStreamWrapper dataInputViewStreamWrapper =
-                                new DataInputViewStreamWrapper(inputStream);
-                        DenseVector coefficient =
-                                serializer.deserialize(dataInputViewStreamWrapper);
-                        long modelVersion = dataInputViewStreamWrapper.readLong();
-                        return new LogisticRegressionModelData(coefficient, modelVersion);
+                        return LogisticRegressionModelData.decode(inputStream);
                     } catch (EOFException e) {
                         return null;
                     }
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java
index 79566a74..1bc19938 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java
@@ -89,7 +89,7 @@ public class OnlineLogisticRegression
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
         DataStream<LogisticRegressionModelData> modelDataStream =
-                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+                LogisticRegressionModelDataUtil.getModelDataStream(initModelDataTable);
 
         RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
         TypeInformation pointTypeInfo;
@@ -413,9 +413,9 @@ public class OnlineLogisticRegression
     public void save(String path) throws IOException {
         ReadWriteUtils.saveMetadata(this, path);
         ReadWriteUtils.saveModelData(
-                LogisticRegressionModelData.getModelDataStream(initModelDataTable),
+                LogisticRegressionModelDataUtil.getModelDataStream(initModelDataTable),
                 path,
-                new LogisticRegressionModelData.ModelDataEncoder());
+                new LogisticRegressionModelDataUtil.ModelDataEncoder());
     }
 
     public static OnlineLogisticRegression load(StreamTableEnvironment tEnv, String path)
@@ -423,7 +423,7 @@ public class OnlineLogisticRegression
         OnlineLogisticRegression onlineLogisticRegression = ReadWriteUtils.loadStageParam(path);
         Table modelDataTable =
                 ReadWriteUtils.loadModelData(
-                        tEnv, path, new LogisticRegressionModelData.ModelDataDecoder());
+                        tEnv, path, new LogisticRegressionModelDataUtil.ModelDataDecoder());
         onlineLogisticRegression.setInitialModelData(modelDataTable);
         return onlineLogisticRegression;
     }
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java
index eab5cf63..f0608613 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java
@@ -22,6 +22,7 @@ import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.typeutils.RowTypeInfo;
 import org.apache.flink.metrics.Gauge;
 import org.apache.flink.ml.api.Model;
@@ -48,8 +49,6 @@ import java.io.IOException;
 import java.util.HashMap;
 import java.util.Map;
 
-import static org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel.predictOneDataPoint;
-
 /**
  * A Model which classifies data using the model data computed by {@link OnlineLogisticRegression}.
  */
@@ -88,12 +87,12 @@ public class OnlineLogisticRegressionModel
         DataStream<Row> predictionResult =
                 tEnv.toDataStream(inputs[0])
                         .connect(
-                                LogisticRegressionModelData.getModelDataStream(modelDataTable)
+                                LogisticRegressionModelDataUtil.getModelDataStream(modelDataTable)
                                         .broadcast())
                         .transform(
                                 "PredictLabelOperator",
                                 outputTypeInfo,
-                                new PredictLabelOperator(inputTypeInfo, getFeaturesCol()));
+                                new PredictLabelOperator(inputTypeInfo, paramMap));
 
         return new Table[] {tEnv.fromDataStream(predictionResult)};
     }
@@ -103,14 +102,15 @@ public class OnlineLogisticRegressionModel
             implements TwoInputStreamOperator<Row, LogisticRegressionModelData, Row> {
         private final RowTypeInfo inputTypeInfo;
 
-        private final String featuresCol;
+        private final Map<Param<?>, Object> params;
         private ListState<Row> bufferedPointsState;
         private DenseVector coefficient;
         private long modelDataVersion = 0L;
+        private LogisticRegressionModelServable servable;
 
-        public PredictLabelOperator(RowTypeInfo inputTypeInfo, String featuresCol) {
+        public PredictLabelOperator(RowTypeInfo inputTypeInfo, Map<Param<?>, Object> params) {
             this.inputTypeInfo = inputTypeInfo;
-            this.featuresCol = featuresCol;
+            this.params = params;
         }
 
         @Override
@@ -156,15 +156,22 @@ public class OnlineLogisticRegressionModel
                 bufferedPointsState.add(dataPoint);
                 return;
             }
-            Vector features = (Vector) dataPoint.getField(featuresCol);
-            Row predictionResult = predictOneDataPoint(features, coefficient);
+            if (servable == null) {
+                servable =
+                        new LogisticRegressionModelServable(
+                                new LogisticRegressionModelData(coefficient, 0L));
+                ParamUtils.updateExistingParams(servable, params);
+            }
+            Vector features = (Vector) dataPoint.getField(servable.getFeaturesCol());
+            Tuple2<Double, DenseVector> predictionResult = servable.transform(features);
+
             output.collect(
                     new StreamRecord<>(
                             Row.join(
                                     dataPoint,
                                     Row.of(
-                                            predictionResult.getField(0),
-                                            predictionResult.getField(1),
+                                            predictionResult.f0,
+                                            predictionResult.f1,
                                             modelDataVersion))));
         }
     }
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
index ad9a5416..f899c281 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
@@ -24,11 +24,16 @@ import org.apache.flink.api.java.typeutils.RowTypeInfo;
 import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
 import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel;
 import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelData;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelDataUtil;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelServable;
 import org.apache.flink.ml.linalg.DenseVector;
 import org.apache.flink.ml.linalg.SparseVector;
 import org.apache.flink.ml.linalg.Vector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.servable.api.DataFrame;
+import org.apache.flink.ml.servable.types.BasicType;
+import org.apache.flink.ml.servable.types.DataTypes;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.TestUtils;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -45,10 +50,13 @@ import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
 
+import java.io.ByteArrayInputStream;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 
+import static org.apache.flink.ml.util.TestUtils.saveAndLoadServable;
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNull;
@@ -99,6 +107,8 @@ public class LogisticRegressionTest extends AbstractTestBase {
 
     private Table multinomialDataTable;
 
+    private DataFrame binomialDataDataFrame;
+
     @Before
     public void before() {
         env = TestUtils.getExecutionEnvironment();
@@ -122,6 +132,15 @@ public class LogisticRegressionTest extends AbstractTestBase {
                                             DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE
                                         },
                                         new String[] {"features", "label", "weight"})));
+        binomialDataDataFrame =
+                TestUtils.constructDataFrame(
+                        new ArrayList<>(Arrays.asList("features", "label", "weight")),
+                        new ArrayList<>(
+                                Arrays.asList(
+                                        DataTypes.VECTOR(BasicType.DOUBLE),
+                                        DataTypes.DOUBLE,
+                                        DataTypes.DOUBLE)),
+                        binomialTrainData);
     }
 
     @SuppressWarnings("ConstantConditions, unchecked")
@@ -143,6 +162,26 @@ public class LogisticRegressionTest extends AbstractTestBase {
         }
     }
 
+    private void verifyPredictionResult(
+            DataFrame output, String featuresCol, String predictionCol, String rawPredictionCol) {
+        int featuresColIndex = output.getIndex(featuresCol);
+        int predictionColIndex = output.getIndex(predictionCol);
+        int rawPredictionColIndex = output.getIndex(rawPredictionCol);
+
+        for (org.apache.flink.ml.servable.api.Row predictionRow : output.collect()) {
+            DenseVector feature = ((Vector) predictionRow.get(featuresColIndex)).toDense();
+            double prediction = (double) predictionRow.get(predictionColIndex);
+            DenseVector rawPrediction = (DenseVector) predictionRow.get(rawPredictionColIndex);
+            if (feature.get(0) <= 5) {
+                assertEquals(0, prediction, TOLERANCE);
+                assertTrue(rawPrediction.get(0) > 0.5);
+            } else {
+                assertEquals(1, prediction, TOLERANCE);
+                assertTrue(rawPrediction.get(0) < 0.5);
+            }
+        }
+    }
+
     @Test
     public void testParam() {
         LogisticRegression logisticRegression = new LogisticRegression();
@@ -268,7 +307,7 @@ public class LogisticRegressionTest extends AbstractTestBase {
         LogisticRegressionModel model = logisticRegression.fit(binomialDataTable);
         List<LogisticRegressionModelData> modelData =
                 IteratorUtils.toList(
-                        LogisticRegressionModelData.getModelDataStream(model.getModelData()[0])
+                        LogisticRegressionModelDataUtil.getModelDataStream(model.getModelData()[0])
                                 .executeAndCollect());
         assertEquals(1, modelData.size());
         assertArrayEquals(expectedCoefficient, modelData.get(0).coefficient.values, 0.1);
@@ -290,6 +329,47 @@ public class LogisticRegressionTest extends AbstractTestBase {
                 logisticRegression.getRawPredictionCol());
     }
 
+    @Test
+    public void testSaveLoadServableAndPredict() throws Exception {
+        LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight");
+        LogisticRegressionModel model = logisticRegression.fit(binomialDataTable);
+
+        LogisticRegressionModelServable servable =
+                saveAndLoadServable(
+                        tEnv,
+                        model,
+                        tempFolder.newFolder().getAbsolutePath(),
+                        LogisticRegressionModel::loadServable);
+
+        DataFrame output = servable.transform(binomialDataDataFrame);
+        verifyPredictionResult(
+                output,
+                servable.getFeaturesCol(),
+                servable.getPredictionCol(),
+                servable.getRawPredictionCol());
+    }
+
+    @Test
+    public void testSetModelDataToServable() throws Exception {
+        LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight");
+        LogisticRegressionModel model = logisticRegression.fit(binomialDataTable);
+        byte[] serializedModelData =
+                LogisticRegressionModelDataUtil.getModelDataByteStream(model.getModelData()[0])
+                        .executeAndCollect()
+                        .next();
+
+        LogisticRegressionModelServable servable = new LogisticRegressionModelServable();
+        ParamUtils.updateExistingParams(servable, model.getParamMap());
+        servable.setModelData(new ByteArrayInputStream(serializedModelData));
+
+        DataFrame output = servable.transform(binomialDataDataFrame);
+        verifyPredictionResult(
+                output,
+                servable.getFeaturesCol(),
+                servable.getPredictionCol(),
+                servable.getRawPredictionCol());
+    }
+
     @Test
     public void testMultinomialFit() {
         try {
@@ -349,7 +429,7 @@ public class LogisticRegressionTest extends 
AbstractTestBase {
                         .fit(binomialDataTable);
         List<LogisticRegressionModelData> modelData =
                 IteratorUtils.toList(
-                        
LogisticRegressionModelData.getModelDataStream(model.getModelData()[0])
+                        
LogisticRegressionModelDataUtil.getModelDataStream(model.getModelData()[0])
                                 .executeAndCollect());
         final double errorTol = 1e-3;
         assertArrayEquals(expectedCoefficient, 
modelData.get(0).coefficient.values, errorTol);
diff --git 
a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java
 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java
index 956e9472..cac9473c 100644
--- 
a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java
+++ 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java
@@ -30,6 +30,7 @@ import org.apache.flink.metrics.Gauge;
 import 
org.apache.flink.ml.classification.logisticregression.LogisticRegression;
 import 
org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel;
 import 
org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelData;
+import 
org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelDataUtil;
 import 
org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegression;
 import 
org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel;
 import org.apache.flink.ml.linalg.DenseVector;
@@ -293,7 +294,7 @@ public class OnlineLogisticRegressionTest extends 
TestLogger {
         tEnv.toDataStream(outputTable).addSink(outputSink);
 
         Table modelDataTable = onlineModel.getModelData()[0];
-        
LogisticRegressionModelData.getModelDataStream(modelDataTable).addSink(modelDataSink);
+        
LogisticRegressionModelDataUtil.getModelDataStream(modelDataTable).addSink(modelDataSink);
     }
 
     /** Blocks the thread until Model has set up init model data. */
@@ -512,7 +513,8 @@ public class OnlineLogisticRegressionTest extends 
TestLogger {
 
     @Test
     public void testGenerateRandomModelData() throws Exception {
-        Table modelDataTable = 
LogisticRegressionModelData.generateRandomModelData(tEnv, 2, 2022);
+        Table modelDataTable =
+                LogisticRegressionModelDataUtil.generateRandomModelData(tEnv, 
2, 2022);
         DataStream<Row> modelData = tEnv.toDataStream(modelDataTable);
         Row modelRow = (Row) 
IteratorUtils.toList(modelData.executeAndCollect()).get(0);
         Assert.assertEquals(2, ((DenseVector) modelRow.getField(0)).size());
diff --git a/flink-ml-servable-lib/pom.xml b/flink-ml-servable-lib/pom.xml
new file mode 100644
index 00000000..f556f74b
--- /dev/null
+++ b/flink-ml-servable-lib/pom.xml
@@ -0,0 +1,66 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 
http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+
+    <parent>
+        <artifactId>flink-ml-parent</artifactId>
+        <groupId>org.apache.flink</groupId>
+        <version>2.2-SNAPSHOT</version>
+    </parent>
+
+    <artifactId>flink-ml-servable-lib</artifactId>
+    <name>Flink ML : Servable : Lib</name>
+
+    <dependencies>
+        <dependency>
+            <groupId>org.apache.flink</groupId>
+            <artifactId>flink-ml-servable-core</artifactId>
+            <version>${project.version}</version>
+            <scope>provided</scope>
+        </dependency>
+
+        <dependency>
+            <groupId>org.apache.flink</groupId>
+            <artifactId>flink-core</artifactId>
+            <version>${flink.version}</version>
+            <scope>provided</scope>
+        </dependency>
+    </dependencies>
+
+    <build>
+        <plugins>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-jar-plugin</artifactId>
+                <executions>
+                    <execution>
+                        <goals>
+                            <goal>test-jar</goal>
+                        </goals>
+                    </execution>
+                </executions>
+            </plugin>
+        </plugins>
+    </build>
+
+</project>
\ No newline at end of file
diff --git 
a/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java
 
b/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java
new file mode 100644
index 00000000..28927e47
--- /dev/null
+++ 
b/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+
+/** Model data of {@link LogisticRegressionModelServable}. */
+public class LogisticRegressionModelData {
+
+    public DenseVector coefficient;
+
+    public long modelVersion;
+
+    public LogisticRegressionModelData() {}
+
+    public LogisticRegressionModelData(DenseVector coefficient, long 
modelVersion) {
+        this.coefficient = coefficient;
+        this.modelVersion = modelVersion;
+    }
+
+    /**
+     * Serializes the instance and writes to the output stream.
+     *
+     * @param outputStream The stream to write to.
+     */
+    @VisibleForTesting
+    public void encode(OutputStream outputStream) throws IOException {
+        DataOutputViewStreamWrapper dataOutputViewStreamWrapper =
+                new DataOutputViewStreamWrapper(outputStream);
+
+        DenseVectorSerializer serializer = new DenseVectorSerializer();
+        serializer.serialize(coefficient, dataOutputViewStreamWrapper);
+        dataOutputViewStreamWrapper.writeLong(modelVersion);
+    }
+
+    /**
+     * Reads and deserializes the model data from the input stream.
+     *
+     * @param inputStream The stream to read from.
+     * @return The model data instance.
+     */
+    static LogisticRegressionModelData decode(InputStream inputStream) throws 
IOException {
+        DataInputViewStreamWrapper dataInputViewStreamWrapper =
+                new DataInputViewStreamWrapper(inputStream);
+
+        DenseVectorSerializer serializer = new DenseVectorSerializer();
+        DenseVector coefficient = 
serializer.deserialize(dataInputViewStreamWrapper);
+        long modelVersion = dataInputViewStreamWrapper.readLong();
+
+        return new LogisticRegressionModelData(coefficient, modelVersion);
+    }
+}
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelParams.java
 
b/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelParams.java
similarity index 94%
rename from 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelParams.java
rename to 
flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelParams.java
index b15b63e6..800764d5 100644
--- 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelParams.java
+++ 
b/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelParams.java
@@ -23,7 +23,7 @@ import org.apache.flink.ml.common.param.HasPredictionCol;
 import org.apache.flink.ml.common.param.HasRawPredictionCol;
 
 /**
- * Params for {@link LogisticRegressionModel}.
+ * Params for LogisticRegressionModel and LogisticRegressionModelServable.
  *
  * @param <T> The class type of this instance.
  */
diff --git 
a/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java
 
b/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java
new file mode 100644
index 00000000..4cec8513
--- /dev/null
+++ 
b/flink-ml-servable-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelServable.java
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.servable.api.DataFrame;
+import org.apache.flink.ml.servable.api.ModelServable;
+import org.apache.flink.ml.servable.api.Row;
+import org.apache.flink.ml.servable.types.BasicType;
+import org.apache.flink.ml.servable.types.DataTypes;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ServableReadWriteUtils;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/** A Servable which can be used to classify data in online inference. */
+public class LogisticRegressionModelServable
+        implements ModelServable<LogisticRegressionModelServable>,
+                LogisticRegressionModelParams<LogisticRegressionModelServable> 
{
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    private LogisticRegressionModelData modelData;
+
+    public LogisticRegressionModelServable() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    LogisticRegressionModelServable(LogisticRegressionModelData modelData) {
+        this();
+        this.modelData = modelData;
+    }
+
+    @Override
+    public DataFrame transform(DataFrame input) {
+        List<Double> predictionResults = new ArrayList<>();
+        List<DenseVector> rawPredictionResults = new ArrayList<>();
+
+        int featuresColIndex = input.getIndex(getFeaturesCol());
+        for (Row row : input.collect()) {
+            Vector features = (Vector) row.get(featuresColIndex);
+            Tuple2<Double, DenseVector> dataPoint = transform(features);
+            predictionResults.add(dataPoint.f0);
+            rawPredictionResults.add(dataPoint.f1);
+        }
+
+        input.addColumn(getPredictionCol(), DataTypes.DOUBLE, 
predictionResults);
+        input.addColumn(
+                getRawPredictionCol(), DataTypes.VECTOR(BasicType.DOUBLE), 
rawPredictionResults);
+
+        return input;
+    }
+
+    public LogisticRegressionModelServable setModelData(InputStream... 
modelDataInputs)
+            throws IOException {
+        Preconditions.checkArgument(modelDataInputs.length == 1);
+
+        modelData = LogisticRegressionModelData.decode(modelDataInputs[0]);
+        return this;
+    }
+
+    public static LogisticRegressionModelServable load(String path) throws 
IOException {
+        LogisticRegressionModelServable servable =
+                ServableReadWriteUtils.loadServableParam(
+                        path, LogisticRegressionModelServable.class);
+
+        try (InputStream modelData = 
ServableReadWriteUtils.loadModelData(path)) {
+            servable.setModelData(modelData);
+            return servable;
+        }
+    }
+
+    /**
+     * The main logic that predicts one input data point.
+     *
+     * @param feature The input feature.
+     * @return The prediction label and the raw probabilities.
+     */
+    protected Tuple2<Double, DenseVector> transform(Vector feature) {
+        double dotValue = BLAS.dot(feature, modelData.coefficient);
+        double prob = 1 - 1.0 / (1.0 + Math.exp(dotValue));
+        return Tuple2.of(dotValue >= 0 ? 1. : 0., Vectors.dense(1 - prob, 
prob));
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+}
diff --git a/flink-ml-uber/pom.xml b/flink-ml-uber/pom.xml
index 2eedeedf..423373fa 100644
--- a/flink-ml-uber/pom.xml
+++ b/flink-ml-uber/pom.xml
@@ -53,6 +53,12 @@ under the License.
       <version>${project.version}</version>
     </dependency>
 
+    <dependency>
+      <groupId>org.apache.flink</groupId>
+      <artifactId>flink-ml-servable-lib</artifactId>
+      <version>${project.version}</version>
+    </dependency>
+
     <dependency>
       <groupId>org.apache.flink</groupId>
       <artifactId>flink-ml-lib</artifactId>
@@ -84,6 +90,7 @@ under the License.
                   <include>org.apache.flink:flink-ml-servable-core</include>
                   <include>org.apache.flink:flink-ml-core</include>
                   <include>org.apache.flink:flink-ml-iteration</include>
+                  <include>org.apache.flink:flink-ml-servable-lib</include>
                   <include>org.apache.flink:flink-ml-lib</include>
                   <include>org.apache.flink:flink-ml-benchmark</include>
                   <include>dev.ludovic.netlib:blas</include>
diff --git a/pom.xml b/pom.xml
index 743d246f..2a0af804 100644
--- a/pom.xml
+++ b/pom.xml
@@ -55,6 +55,7 @@ under the License.
     <module>flink-ml-servable-core</module>
     <module>flink-ml-core</module>
     <module>flink-ml-iteration</module>
+    <module>flink-ml-servable-lib</module>
     <module>flink-ml-lib</module>
     <module>flink-ml-tests</module>
     <module>flink-ml-uber</module>
diff --git a/tools/ci/stage.sh b/tools/ci/stage.sh
index c44c023d..4cf36227 100755
--- a/tools/ci/stage.sh
+++ b/tools/ci/stage.sh
@@ -23,11 +23,13 @@ STAGE_TESTS="tests"
 STAGE_MISC="misc"
 
 MODULES_CORE="\
+flink-ml-servable-core,\
 flink-ml-core,\
 flink-ml-iteration,\
 "
 
 MODULES_LIB="\
+flink-ml-servable-lib,\
 flink-ml-lib,\
 "
 


Reply via email to