This is an automated email from the ASF dual-hosted git repository.

lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new f34dbb7  [FLINK-27093] Add Transformer and Estimator for LinearRegression
f34dbb7 is described below

commit f34dbb708a0d151caa5afab593d04badeb549224
Author: Zhipeng Zhang <zhangzhipe...@gmail.com>
AuthorDate: Fri May 6 16:02:55 2022 +0800

    [FLINK-27093] Add Transformer and Estimator for LinearRegression
    
    This closes #90.
---
 .../flink/ml/common/datastream/AllReduceImpl.java  |   2 +
 .../ml/common/datastream/DataStreamUtils.java      |  83 ++++-
 .../main/java/org/apache/flink/ml/linalg/BLAS.java |  19 +-
 .../ml/common/datastream/DataStreamUtilsTest.java  |  11 +
 .../java/org/apache/flink/ml/linalg/BLASTest.java  |  16 +
 .../logisticregression/LogisticGradient.java       |  97 -----
 .../logisticregression/LogisticRegression.java     | 401 +++------------------
 .../LogisticRegressionModel.java                   |  80 ++--
 .../LogisticRegressionParams.java                  |   2 +
 .../ml/common/lossfunc/BinaryLogisticLoss.java     |  50 +++
 .../flink/ml/common/lossfunc/LeastSquareLoss.java  |  50 +++
 .../apache/flink/ml/common/lossfunc/LossFunc.java  |  51 +++
 .../flink/ml/common/optimizer/Optimizer.java       |  46 +++
 .../ml/common/optimizer/RegularizationUtils.java   |  92 +++++
 .../org/apache/flink/ml/common/optimizer/SGD.java  | 390 ++++++++++++++++++++
 .../flink/ml/common/param/HasElasticNet.java       |  47 +++
 .../linearregression/LinearRegression.java         | 122 +++++++
 .../linearregression/LinearRegressionModel.java}   | 113 +++---
 .../LinearRegressionModelData.java                 | 111 ++++++
 .../LinearRegressionModelParams.java               |  29 ++
 .../linearregression/LinearRegressionParams.java}  |  12 +-
 .../ml/classification/LogisticRegressionTest.java  | 101 ++++--
 .../ml/common/lossfunc/BinaryLogisticLossTest.java |  53 +++
 .../ml/common/lossfunc/LeastSquareLossTest.java    |  51 +++
 .../common/optimizer/RegularizationUtilsTest.java  |  47 +++
 .../flink/ml/regression/LinearRegressionTest.java  | 255 +++++++++++++
 26 files changed, 1721 insertions(+), 610 deletions(-)

diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/AllReduceImpl.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/AllReduceImpl.java
index 760b5db..167572a 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/AllReduceImpl.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/AllReduceImpl.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.ml.common.datastream;
 
+import org.apache.flink.annotation.Internal;
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.functions.RichFlatMapFunction;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
@@ -49,6 +50,7 @@ import java.util.Map;
  * <li>All workers do reduce on all the data they received and then broadcast partial results to others.
  * <li>All workers merge the partial results into the final result.
  */
+@Internal
 class AllReduceImpl {
 
     @VisibleForTesting static final int CHUNK_SIZE = 1024 * 4;
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
index 7eea6b0..10073b8 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
@@ -18,12 +18,16 @@
 
 package org.apache.flink.ml.common.datastream;
 
+import org.apache.flink.annotation.Internal;
 import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
 import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.typeutils.TypeExtractor;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
 import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
 import org.apache.flink.streaming.api.operators.BoundedOneInput;
@@ -32,6 +36,7 @@ import org.apache.flink.streaming.api.operators.TimestampedCollector;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 
 /** Provides utility functions for {@link DataStream}. */
+@Internal
 public class DataStreamUtils {
     /**
      * Applies allReduceSum on the input data stream. The input data stream is supposed to contain
@@ -55,8 +60,8 @@ public class DataStreamUtils {
      *
      * @param input The input data stream.
      * @param func The user defined mapPartition function.
-     * @param <IN> The class type of the input element.
-     * @param <OUT> The class type of output element.
+     * @param <IN> The class type of the input.
+     * @param <OUT> The class type of the output.
      * @return The result data stream.
      */
     public static <IN, OUT> DataStream<OUT> mapPartition(
@@ -67,6 +72,28 @@ public class DataStreamUtils {
                 .setParallelism(input.getParallelism());
     }
 
+    /**
+     * Applies a {@link ReduceFunction} on a bounded data stream. The output stream contains at most
+     * one stream record and its parallelism is one.
+     *
+     * @param input The input data stream.
+     * @param func The user defined reduce function.
+     * @param <T> The class type of the input.
+     * @return The result data stream.
+     */
+    public static <T> DataStream<T> reduce(DataStream<T> input, ReduceFunction<T> func) {
+        DataStream<T> partialReducedStream =
+                input.transform("reduce", input.getType(), new ReduceOperator<>(func))
+                        .setParallelism(input.getParallelism());
+        if (partialReducedStream.getParallelism() == 1) {
+            return partialReducedStream;
+        } else {
+            return partialReducedStream
+                    .transform("reduce", input.getType(), new ReduceOperator<>(func))
+                    .setParallelism(1);
+        }
+        }
+    }
+
     /**
      * A stream operator to apply {@link MapPartitionFunction} on each partition of the input
      * bounded data stream.
@@ -103,4 +130,56 @@ public class DataStreamUtils {
             valuesState.add(input.getValue());
         }
     }
+
+    /** A stream operator to apply {@link ReduceFunction} on the input bounded data stream. */
+    private static class ReduceOperator<T> extends AbstractUdfStreamOperator<T, ReduceFunction<T>>
+            implements OneInputStreamOperator<T, T>, BoundedOneInput {
+        /** The temp result of the reduce function. */
+        private T result;
+
+        private ListState<T> state;
+
+        public ReduceOperator(ReduceFunction<T> userFunction) {
+            super(userFunction);
+        }
+
+        @Override
+        public void endInput() {
+            if (result != null) {
+                output.collect(new StreamRecord<>(result));
+            }
+        }
+
+        @Override
+        public void processElement(StreamRecord<T> streamRecord) throws Exception {
+            if (result == null) {
+                result = streamRecord.getValue();
+            } else {
+                result = userFunction.reduce(streamRecord.getValue(), result);
+            }
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            state =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<T>(
+                                            "state",
+                                            getOperatorConfig()
+                                                    .getTypeSerializerIn(
+                                                            0, getClass().getClassLoader())));
+            result = OperatorStateUtils.getUniqueElement(state, "state").orElse(null);
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            state.clear();
+            if (result != null) {
+                state.add(result);
+            }
+        }
+    }
 }
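
A minimal usage sketch of the new DataStreamUtils.reduce() (the class name and
environment setup here are illustrative, not part of this commit):

    import org.apache.flink.api.common.functions.ReduceFunction;
    import org.apache.flink.ml.common.datastream.DataStreamUtils;
    import org.apache.flink.streaming.api.datastream.DataStream;
    import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;

    public class ReduceExample {
        public static void main(String[] args) throws Exception {
            StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
            // A bounded stream; reduce() combines within each parallel partition
            // first, then merges the partial results at parallelism 1.
            DataStream<Long> input = env.fromElements(1L, 2L, 3L, 4L);
            DataStream<Long> sum = DataStreamUtils.reduce(input, (ReduceFunction<Long>) Long::sum);
            sum.executeAndCollect().forEachRemaining(System.out::println); // prints 10
        }
    }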
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java
index 383bf66..24cc3ef 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java
@@ -34,10 +34,16 @@ public class BLAS {
     /** y += a * x . */
     public static void axpy(double a, Vector x, DenseVector y) {
         Preconditions.checkArgument(x.size() == y.size(), "Vector size mismatched.");
+        axpy(a, x, y, x.size());
+    }
+
+    /** y += a * x for the first k dimensions, with the other dimensions unchanged. */
+    public static void axpy(double a, Vector x, DenseVector y, int k) {
+        Preconditions.checkArgument(x.size() >= k && y.size() >= k);
         if (x instanceof SparseVector) {
-            axpy(a, (SparseVector) x, y);
+            axpy(a, (SparseVector) x, y, k);
         } else {
-            axpy(a, (DenseVector) x, y);
+            axpy(a, (DenseVector) x, y, k);
         }
     }
 
@@ -112,13 +118,16 @@ public class BLAS {
                 1);
     }
 
-    private static void axpy(double a, DenseVector x, DenseVector y) {
-        JAVA_BLAS.daxpy(x.size(), a, x.values, 1, y.values, 1);
+    private static void axpy(double a, DenseVector x, DenseVector y, int k) {
+        JAVA_BLAS.daxpy(k, a, x.values, 1, y.values, 1);
     }
 
-    private static void axpy(double a, SparseVector x, DenseVector y) {
+    private static void axpy(double a, SparseVector x, DenseVector y, int k) {
         for (int i = 0; i < x.indices.length; i++) {
             int index = x.indices[i];
+            if (index >= k) {
+                return;
+            }
             y.values[index] += a * x.values[i];
         }
     }
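
A worked example of the new k-limited axpy (the class name is illustrative):

    import org.apache.flink.ml.linalg.BLAS;
    import org.apache.flink.ml.linalg.DenseVector;
    import org.apache.flink.ml.linalg.Vectors;

    public class AxpyKExample {
        public static void main(String[] args) {
            DenseVector x = Vectors.dense(1, 2, 3);
            DenseVector y = Vectors.dense(10, 20, 30);
            // y += 2 * x, but only for the first k = 2 dimensions; y[2] is untouched.
            BLAS.axpy(2.0, x, y, 2);
            System.out.println(java.util.Arrays.toString(y.values)); // [12.0, 24.0, 30.0]
        }
    }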
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
index 7933859..7dc88c8 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
@@ -19,6 +19,7 @@
 package org.apache.flink.ml.common.datastream;
 
 import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
 import org.apache.flink.api.common.functions.RichMapPartitionFunction;
 import org.apache.flink.api.common.restartstrategy.RestartStrategies;
 import org.apache.flink.api.common.typeinfo.Types;
@@ -63,6 +64,16 @@ public class DataStreamUtilsTest {
                 new int[] {5, 5, 5, 5}, counts.stream().mapToInt(Integer::intValue).toArray());
     }
 
+    @Test
+    public void testReduce() throws Exception {
+        DataStream<Long> dataStream =
+                env.fromParallelCollection(new NumberSequenceIterator(0L, 19L), Types.LONG);
+        DataStream<Long> result =
+                DataStreamUtils.reduce(dataStream, (ReduceFunction<Long>) Long::sum);
+        List<Long> sum = IteratorUtils.toList(result.executeAndCollect());
+        assertArrayEquals(new long[] {190L}, sum.stream().mapToLong(Long::longValue).toArray());
+    }
+
     /** A simple implementation for a {@link MapPartitionFunction}. */
    private static class TestMapPartitionFunc extends RichMapPartitionFunction<Long, Integer> {
 
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java
index 7bcd853..7055c62 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java
@@ -51,6 +51,22 @@ public class BLASTest {
         assertArrayEquals(expectedResult, anotherDenseVec.values, TOLERANCE);
     }
 
+    @Test
+    public void testAxpyK() {
+        // Tests axpy(dense, dense, k).
+        DenseVector anotherDenseVec = Vectors.dense(1, 2, 3);
+        BLAS.axpy(1, inputDenseVec, anotherDenseVec, 3);
+        double[] expectedResult = new double[] {2, 0, 6};
+        assertArrayEquals(expectedResult, anotherDenseVec.values, TOLERANCE);
+
+        // Tests axpy(sparse, dense, k).
+        SparseVector sparseVec = Vectors.sparse(5, new int[] {0, 2, 4}, new double[] {1, 3, 5});
+        anotherDenseVec = Vectors.dense(1, 2, 3, 4, 5, 6, 7);
+        BLAS.axpy(2, sparseVec, anotherDenseVec, 5);
+        expectedResult = new double[] {3, 2, 9, 4, 15, 6, 7};
+        assertArrayEquals(expectedResult, anotherDenseVec.values, TOLERANCE);
+    }
+
     @Test
     public void testDot() {
         DenseVector anotherDenseVec = Vectors.dense(1, 2, 3, 4, 5);
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticGradient.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticGradient.java
deleted file mode 100644
index 13f753b..0000000
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticGradient.java
+++ /dev/null
@@ -1,97 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.ml.classification.logisticregression;
-
-import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
-import org.apache.flink.ml.linalg.BLAS;
-import org.apache.flink.ml.linalg.DenseVector;
-
-import java.io.Serializable;
-import java.util.List;
-
-/**
- * Utility class to compute gradient and loss for logistic loss function.
- *
- * <p>See http://mlwiki.org/index.php/Logistic_Regression.
- */
-public class LogisticGradient implements Serializable {
-
-    /** L2 regularization term. */
-    private final double l2;
-
-    public LogisticGradient(double l2) {
-        this.l2 = l2;
-    }
-
-    /**
-     * Computes weight sum and loss sum on a set of samples.
-     *
-     * @param dataPoints A sample set of train data.
-     * @param coefficient The model parameters.
-     * @return Weight sum and loss sum of the input data.
-     */
-    public Tuple2<Double, Double> computeLoss(
-            List<LabeledPointWithWeight> dataPoints, DenseVector coefficient) {
-        double weightSum = 0.0;
-        double lossSum = 0.0;
-        for (LabeledPointWithWeight dataPoint : dataPoints) {
-            lossSum += dataPoint.getWeight() * computeLoss(dataPoint, coefficient);
-            weightSum += dataPoint.getWeight();
-        }
-        if (Double.compare(0, l2) != 0) {
-            lossSum += l2 * Math.pow(BLAS.norm2(coefficient), 2);
-        }
-        return Tuple2.of(weightSum, lossSum);
-    }
-
-    /**
-     * Computes gradient on a set of samples.
-     *
-     * @param dataPoints A sample set of train data.
-     * @param coefficient The model parameters.
-     * @param cumGradient The accumulated gradients.
-     */
-    public void computeGradient(
-            List<LabeledPointWithWeight> dataPoints,
-            DenseVector coefficient,
-            DenseVector cumGradient) {
-        for (LabeledPointWithWeight dataPoint : dataPoints) {
-            computeGradient(dataPoint, coefficient, cumGradient);
-        }
-        if (Double.compare(0, l2) != 0) {
-            BLAS.axpy(l2 * 2, coefficient, cumGradient);
-        }
-    }
-
-    private double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient) {
-        double dot = BLAS.dot(dataPoint.getFeatures(), coefficient);
-        double labelScaled = 2 * dataPoint.getLabel() - 1;
-        return Math.log(1 + Math.exp(-dot * labelScaled));
-    }
-
-    private void computeGradient(
-            LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient) {
-        double dot = BLAS.dot(dataPoint.getFeatures(), coefficient);
-        double labelScaled = 2 * dataPoint.getLabel() - 1;
-        double multiplier =
-                dataPoint.getWeight() * (-labelScaled / (Math.exp(dot * labelScaled) + 1));
-        BLAS.axpy(multiplier, dataPoint.getFeatures(), cumGradient);
-    }
-}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
index 58cf0ce..f64fc10 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
@@ -18,54 +18,26 @@
 
 package org.apache.flink.ml.classification.logisticregression;
 
-import org.apache.flink.api.common.state.ListState;
-import org.apache.flink.api.common.state.ListStateDescriptor;
-import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
-import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
-import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.iteration.DataStreamList;
-import org.apache.flink.iteration.IterationBody;
-import org.apache.flink.iteration.IterationBodyResult;
-import org.apache.flink.iteration.IterationConfig;
-import org.apache.flink.iteration.IterationListener;
-import org.apache.flink.iteration.Iterations;
-import org.apache.flink.iteration.ReplayableDataStreamList;
-import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.api.common.functions.ReduceFunction;
 import org.apache.flink.ml.api.Estimator;
 import org.apache.flink.ml.common.datastream.DataStreamUtils;
 import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
-import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
-import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.common.lossfunc.BinaryLogisticLoss;
+import org.apache.flink.ml.common.optimizer.Optimizer;
+import org.apache.flink.ml.common.optimizer.SGD;
 import org.apache.flink.ml.linalg.DenseVector;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
-import org.apache.flink.runtime.state.StateInitializationContext;
-import org.apache.flink.runtime.state.StateSnapshotContext;
 import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
-import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
-import org.apache.flink.streaming.api.operators.BoundedOneInput;
-import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
-import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
-import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
 import org.apache.flink.table.api.internal.TableImpl;
-import org.apache.flink.util.Collector;
-import org.apache.flink.util.OutputTag;
 import org.apache.flink.util.Preconditions;
 
-import org.apache.commons.collections.IteratorUtils;
-
 import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.HashMap;
-import java.util.List;
 import java.util.Map;
-import java.util.Random;
 
 /**
  * An Estimator which implements the logistic regression algorithm.
@@ -82,21 +54,6 @@ public class LogisticRegression
         ParamUtils.initializeMapWithDefaultValues(paramMap, this);
     }
 
-    @Override
-    public Map<Param<?>, Object> getParamMap() {
-        return paramMap;
-    }
-
-    @Override
-    public void save(String path) throws IOException {
-        ReadWriteUtils.saveMetadata(this, path);
-    }
-
-    public static LogisticRegression load(StreamTableEnvironment tEnv, String path)
-            throws IOException {
-        return ReadWriteUtils.loadStageParam(path);
-    }
-
     @Override
     @SuppressWarnings({"rawTypes", "ConstantConditions"})
     public LogisticRegressionModel fit(Table... inputs) {
@@ -107,15 +64,16 @@ public class LogisticRegression
                 "Multinomial classification is not supported yet. Supported 
options: [auto, binomial].");
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+
         DataStream<LabeledPointWithWeight> trainData =
                 tEnv.toDataStream(inputs[0])
                         .map(
                                 dataPoint -> {
-                                    Double weight =
+                                    double weight =
                                             getWeightCol() == null
                                                     ? 1.0
                                                     : (Double) dataPoint.getField(getWeightCol());
-                                    Double label = (Double) dataPoint.getField(getLabelCol());
+                                    double label = (Double) dataPoint.getField(getLabelCol());
                                     boolean isBinomial =
                                             Double.compare(0., label) == 0
                                                     || Double.compare(1., label) == 0;
@@ -127,327 +85,50 @@ public class LogisticRegression
                                             (DenseVector) dataPoint.getField(getFeaturesCol());
                                     return new LabeledPointWithWeight(features, label, weight);
                                 });
-        DataStream<double[]> initModelData =
-                trainData
-                        .transform("getModelDim", BasicTypeInfo.INT_TYPE_INFO, new GetModelDim())
-                        .setParallelism(1)
-                        .broadcast()
-                        .map(double[]::new);
 
-        DataStream<LogisticRegressionModelData> modelData = train(trainData, initModelData);
+        DataStream<DenseVector> initModelData =
+                DataStreamUtils.reduce(
+                                trainData.map(x -> x.getFeatures().size()),
+                                (ReduceFunction<Integer>)
+                                        (t0, t1) -> {
+                                            Preconditions.checkState(
+                                                    t0.equals(t1),
+                                                    "The training data should all have same dimensions.");
+                                            return t0;
+                                        })
+                        .map(DenseVector::new);
+
+        Optimizer optimizer =
+                new SGD(
+                        getMaxIter(),
+                        getLearningRate(),
+                        getGlobalBatchSize(),
+                        getTol(),
+                        getReg(),
+                        getElasticNet());
+        DataStream<DenseVector> rawModelData =
+                optimizer.optimize(initModelData, trainData, BinaryLogisticLoss.INSTANCE);
+
+        DataStream<LogisticRegressionModelData> modelData =
+                rawModelData.map(LogisticRegressionModelData::new);
         LogisticRegressionModel model =
                 new LogisticRegressionModel().setModelData(tEnv.fromDataStream(modelData));
         ReadWriteUtils.updateExistingParams(model, paramMap);
         return model;
     }
 
-    /** Gets the dimension of the model data. */
-    private static class GetModelDim extends AbstractStreamOperator<Integer>
-            implements OneInputStreamOperator<LabeledPointWithWeight, Integer>, BoundedOneInput {
-
-        private int dim = 0;
-
-        private ListState<Integer> dimState;
-
-        @Override
-        public void endInput() {
-            output.collect(new StreamRecord<>(dim));
-        }
-
-        @Override
-        public void processElement(StreamRecord<LabeledPointWithWeight> streamRecord) {
-            if (dim == 0) {
-                dim = streamRecord.getValue().getFeatures().size();
-            } else {
-                if (dim != streamRecord.getValue().getFeatures().size()) {
-                    throw new RuntimeException(
-                            "The training data should all have same 
dimensions.");
-                }
-            }
-        }
-
-        @Override
-        public void initializeState(StateInitializationContext context) throws Exception {
-            super.initializeState(context);
-            dimState =
-                    context.getOperatorStateStore()
-                            .getListState(
-                                    new ListStateDescriptor<>(
-                                            "dimState", BasicTypeInfo.INT_TYPE_INFO));
-            dim = OperatorStateUtils.getUniqueElement(dimState, "dimState").orElse(0);
-        }
-
-        @Override
-        public void snapshotState(StateSnapshotContext context) throws Exception {
-            dimState.clear();
-            dimState.add(dim);
-        }
-    }
-
-    /**
-     * Does machine learning training on the input data with the initialized model data.
-     *
-     * @param trainData The training data.
-     * @param initModelData The initialized model.
-     * @return The trained model data.
-     */
-    private DataStream<LogisticRegressionModelData> train(
-            DataStream<LabeledPointWithWeight> trainData, DataStream<double[]> initModelData) {
-        LogisticGradient logisticGradient = new LogisticGradient(getReg());
-        DataStreamList resultList =
-                Iterations.iterateBoundedStreamsUntilTermination(
-                        DataStreamList.of(initModelData),
-                        ReplayableDataStreamList.notReplay(trainData),
-                        IterationConfig.newBuilder().build(),
-                        new TrainIterationBody(
-                                logisticGradient,
-                                getGlobalBatchSize(),
-                                getLearningRate(),
-                                getMaxIter(),
-                                getTol()));
-        return resultList.get(0);
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
     }
 
-    /** The iteration implementation for training process. */
-    private static class TrainIterationBody implements IterationBody {
-
-        private final LogisticGradient logisticGradient;
-
-        private final int globalBatchSize;
-
-        private final double learningRate;
-
-        private final int maxIter;
-
-        private final double tol;
-
-        public TrainIterationBody(
-                LogisticGradient logisticGradient,
-                int globalBatchSize,
-                double learningRate,
-                int maxIter,
-                double tol) {
-            this.logisticGradient = logisticGradient;
-            this.globalBatchSize = globalBatchSize;
-            this.learningRate = learningRate;
-            this.maxIter = maxIter;
-            this.tol = tol;
-        }
-
-        @Override
-        public IterationBodyResult process(
-                DataStreamList variableStreams, DataStreamList dataStreams) {
-            // The variable stream at the first iteration is the initialized model data.
-            // In the following iterations, it contains: the computed gradient, weightSum and
-            // lossSum.
-            DataStream<double[]> variableStream = variableStreams.get(0);
-            DataStream<LabeledPointWithWeight> trainData = dataStreams.get(0);
-            final OutputTag<LogisticRegressionModelData> modelDataOutputTag =
-                    new OutputTag<LogisticRegressionModelData>("MODEL_OUTPUT") {};
-            SingleOutputStreamOperator<double[]> gradientAndWeightAndLoss =
-                    trainData
-                            .connect(variableStream)
-                            .transform(
-                                    "CacheDataAndDoTrain",
-                                    PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
-                                    new CacheDataAndDoTrain(
-                                            logisticGradient,
-                                            globalBatchSize,
-                                            learningRate,
-                                            modelDataOutputTag));
-            DataStreamList feedbackVariableStream =
-                    IterationBody.forEachRound(
-                            DataStreamList.of(gradientAndWeightAndLoss),
-                            input -> {
-                                DataStream<double[]> feedback =
-                                        DataStreamUtils.allReduceSum(input.get(0));
-                                return DataStreamList.of(feedback);
-                            });
-            DataStream<Integer> terminationCriteria =
-                    feedbackVariableStream
-                            .get(0)
-                            .map(
-                                    reducedGradientAndWeightAndLoss -> {
-                                        double[] value = (double[]) reducedGradientAndWeightAndLoss;
-                                        return value[value.length - 1] / value[value.length - 2];
-                                    })
-                            .flatMap(new TerminateOnMaxIterOrTol(maxIter, tol));
-            return new IterationBodyResult(
-                    DataStreamList.of(feedbackVariableStream.get(0)),
-                    DataStreamList.of(gradientAndWeightAndLoss.getSideOutput(modelDataOutputTag)),
-                    terminationCriteria);
-        }
+    public static LogisticRegression load(StreamTableEnvironment tEnv, String path)
+            throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
     }
 
-    /**
-     * A stream operator that caches the training data in the first iteration and updates the model
-     * using gradients iteratively. The first input is the training data, and the second input is
-     * the initialized model data or feedback of gradient, weight and loss.
-     */
-    private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]>
-            implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>,
-                    IterationListener<double[]> {
-
-        private final int globalBatchSize;
-
-        private int localBatchSize;
-
-        private final double learningRate;
-
-        private final LogisticGradient logisticGradient;
-
-        private DenseVector gradient;
-
-        private DenseVector coefficient;
-
-        private int coefficientDim;
-
-        private ListState<DenseVector> coefficientState;
-
-        private List<LabeledPointWithWeight> trainData;
-
-        private ListState<LabeledPointWithWeight> trainDataState;
-
-        private final Random random = new Random(2021);
-
-        private List<LabeledPointWithWeight> miniBatchData;
-
-        /** The buffer for feedback record: {gradient, weightSum, loss}. */
-        private double[] feedbackBuffer;
-
-        private ListState<double[]> feedbackBufferState;
-
-        private final OutputTag<LogisticRegressionModelData> modelDataOutputTag;
-
-        public CacheDataAndDoTrain(
-                LogisticGradient logisticGradient,
-                int globalBatchSize,
-                double learningRate,
-                OutputTag<LogisticRegressionModelData> modelDataOutputTag) {
-            this.logisticGradient = logisticGradient;
-            this.globalBatchSize = globalBatchSize;
-            this.learningRate = learningRate;
-            this.modelDataOutputTag = modelDataOutputTag;
-        }
-
-        @Override
-        public void open() {
-            int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
-            int taskId = getRuntimeContext().getIndexOfThisSubtask();
-            localBatchSize = globalBatchSize / numTasks;
-            if (globalBatchSize % numTasks > taskId) {
-                localBatchSize++;
-            }
-            this.miniBatchData = new ArrayList<>(localBatchSize);
-        }
-
-        private List<LabeledPointWithWeight> getMiniBatchData(
-                List<LabeledPointWithWeight> fullBatchData, int batchSize) {
-            miniBatchData.clear();
-            for (int i = 0; i < batchSize; i++) {
-                miniBatchData.add(fullBatchData.get(random.nextInt(fullBatchData.size())));
-            }
-            return miniBatchData;
-        }
-
-        private void updateModel() {
-            System.arraycopy(feedbackBuffer, 0, gradient.values, 0, gradient.size());
-            double weightSum = feedbackBuffer[coefficientDim];
-            BLAS.axpy(-learningRate / weightSum, gradient, coefficient);
-        }
-
-        @Override
-        public void onEpochWatermarkIncremented(
-                int epochWatermark, Context context, Collector<double[]> collector)
-                throws Exception {
-            if (epochWatermark == 0) {
-                coefficient = new DenseVector(feedbackBuffer);
-                coefficientDim = coefficient.size();
-                feedbackBuffer = new double[coefficientDim + 2];
-                gradient = new DenseVector(coefficientDim);
-            } else {
-                updateModel();
-            }
-            Arrays.fill(gradient.values, 0);
-            if (trainData == null) {
-                trainData = IteratorUtils.toList(trainDataState.get().iterator());
-            }
-            if (trainData.size() > 0) {
-                miniBatchData = getMiniBatchData(trainData, localBatchSize);
-                Tuple2<Double, Double> weightSumAndLossSum =
-                        logisticGradient.computeLoss(miniBatchData, coefficient);
-                logisticGradient.computeGradient(miniBatchData, coefficient, gradient);
-                System.arraycopy(gradient.values, 0, feedbackBuffer, 0, gradient.size());
-                feedbackBuffer[coefficientDim] = weightSumAndLossSum.f0;
-                feedbackBuffer[coefficientDim + 1] = weightSumAndLossSum.f1;
-                collector.collect(feedbackBuffer);
-            }
-        }
-
-        @Override
-        public void onIterationTerminated(Context context, Collector<double[]> collector) {
-            trainDataState.clear();
-            coefficientState.clear();
-            feedbackBufferState.clear();
-            if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
-                updateModel();
-                context.output(modelDataOutputTag, new LogisticRegressionModelData(coefficient));
-            }
-        }
-
-        @Override
-        public void processElement1(StreamRecord<LabeledPointWithWeight> streamRecord)
-                throws Exception {
-            trainDataState.add(streamRecord.getValue());
-        }
-
-        @Override
-        public void processElement2(StreamRecord<double[]> streamRecord) {
-            feedbackBuffer = streamRecord.getValue();
-        }
-
-        @Override
-        public void initializeState(StateInitializationContext context) throws Exception {
-            super.initializeState(context);
-            trainDataState =
-                    context.getOperatorStateStore()
-                            .getListState(
-                                    new ListStateDescriptor<>(
-                                            "trainDataState",
-                                            TypeInformation.of(LabeledPointWithWeight.class)));
-            coefficientState =
-                    context.getOperatorStateStore()
-                            .getListState(
-                                    new ListStateDescriptor<>(
-                                            "coefficientState",
-                                            TypeInformation.of(DenseVector.class)));
-            OperatorStateUtils.getUniqueElement(coefficientState, "coefficientState")
-                    .ifPresent(x -> coefficient = x);
-            feedbackBufferState =
-                    context.getOperatorStateStore()
-                            .getListState(
-                                    new ListStateDescriptor<>(
-                                            "feedbackBufferState",
-                                            PrimitiveArrayTypeInfo
-                                                    .DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO));
-            OperatorStateUtils.getUniqueElement(feedbackBufferState, "feedbackBufferState")
-                    .ifPresent(x -> feedbackBuffer = x);
-            if (coefficient != null) {
-                coefficientDim = coefficient.size();
-                gradient = new DenseVector(new double[coefficientDim]);
-            }
-        }
-
-        @Override
-        public void snapshotState(StateSnapshotContext context) throws Exception {
-            coefficientState.clear();
-            if (coefficient != null) {
-                coefficientState.add(coefficient);
-            }
-            feedbackBufferState.clear();
-            if (feedbackBuffer != null) {
-                feedbackBufferState.add(feedbackBuffer);
-            }
-        }
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
     }
 }
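
An end-to-end sketch of the refactored fit() path (the toy data follows the
patterns used in the tests; the class name, schema, and exact type handling
here are illustrative assumptions):

    import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
    import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel;
    import org.apache.flink.ml.linalg.Vectors;
    import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
    import org.apache.flink.table.api.Table;
    import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
    import org.apache.flink.types.Row;

    import java.util.Arrays;

    public class LogisticRegressionSketch {
        public static void main(String[] args) {
            StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
            // Two toy rows of (features, label); binomial labels must be 0.0 or 1.0.
            Table trainTable =
                    tEnv.fromDataStream(
                                    env.fromCollection(
                                            Arrays.asList(
                                                    Row.of(Vectors.dense(1, 2), 0.),
                                                    Row.of(Vectors.dense(2, 3), 1.))))
                            .as("features", "label");
            LogisticRegression lr = new LogisticRegression().setMaxIter(20).setReg(0.1);
            LogisticRegressionModel model = lr.fit(trainTable);
            model.transform(trainTable)[0].execute().print();
        }
    }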
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java
index 3c30a8d..247a4d5 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java
@@ -21,7 +21,6 @@ package org.apache.flink.ml.classification.logisticregression;
 import org.apache.flink.api.common.functions.RichMapFunction;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.typeutils.RowTypeInfo;
 import org.apache.flink.ml.api.Model;
 import org.apache.flink.ml.common.broadcast.BroadcastUtils;
@@ -59,40 +58,6 @@ public class LogisticRegressionModel
         ParamUtils.initializeMapWithDefaultValues(paramMap, this);
     }
 
-    @Override
-    public Map<Param<?>, Object> getParamMap() {
-        return paramMap;
-    }
-
-    @Override
-    public void save(String path) throws IOException {
-        ReadWriteUtils.saveMetadata(this, path);
-        ReadWriteUtils.saveModelData(
-                LogisticRegressionModelData.getModelDataStream(modelDataTable),
-                path,
-                new LogisticRegressionModelData.ModelDataEncoder());
-    }
-
-    public static LogisticRegressionModel load(StreamTableEnvironment tEnv, String path)
-            throws IOException {
-        LogisticRegressionModel model = ReadWriteUtils.loadStageParam(path);
-        Table modelDataTable =
-                ReadWriteUtils.loadModelData(
-                        tEnv, path, new LogisticRegressionModelData.ModelDataDecoder());
-        return model.setModelData(modelDataTable);
-    }
-
-    @Override
-    public LogisticRegressionModel setModelData(Table... inputs) {
-        modelDataTable = inputs[0];
-        return this;
-    }
-
-    @Override
-    public Table[] getModelData() {
-        return new Table[] {modelDataTable};
-    }
-
     @Override
     @SuppressWarnings("unchecked")
     public Table[] transform(Table... inputs) {
@@ -127,6 +92,40 @@ public class LogisticRegressionModel
         return new Table[] {tEnv.fromDataStream(predictionResult)};
     }
 
+    @Override
+    public LogisticRegressionModel setModelData(Table... inputs) {
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                LogisticRegressionModelData.getModelDataStream(modelDataTable),
+                path,
+                new LogisticRegressionModelData.ModelDataEncoder());
+    }
+
+    public static LogisticRegressionModel load(StreamTableEnvironment tEnv, String path)
+            throws IOException {
+        LogisticRegressionModel model = ReadWriteUtils.loadStageParam(path);
+        Table modelDataTable =
+                ReadWriteUtils.loadModelData(
+                        tEnv, path, new LogisticRegressionModelData.ModelDataDecoder());
+        return model.setModelData(modelDataTable);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
     /** A utility function used for prediction. */
    private static class PredictLabelFunction extends RichMapFunction<Row, Row> {
 
@@ -150,22 +149,21 @@ public class LogisticRegressionModel
                 coefficient = modelData.coefficient;
             }
             DenseVector features = (DenseVector) dataPoint.getField(featuresCol);
-            Tuple2<Double, DenseVector> predictionResult = predictRaw(features, coefficient);
-            return Row.join(dataPoint, Row.of(predictionResult.f0, predictionResult.f1));
+            Row predictionResult = predictOneDataPoint(features, coefficient);
+            return Row.join(dataPoint, predictionResult);
         }
     }
 
     /**
-     * The main logic that predicts one input record.
+     * The main logic that predicts one input data point.
      *
      * @param feature The input feature.
      * @param coefficient The model parameters.
      * @return The prediction label and the raw probabilities.
      */
-    private static Tuple2<Double, DenseVector> predictRaw(
-            DenseVector feature, DenseVector coefficient) {
+    private static Row predictOneDataPoint(DenseVector feature, DenseVector coefficient) {
         double dotValue = BLAS.dot(feature, coefficient);
         double prob = 1 - 1.0 / (1.0 + Math.exp(dotValue));
-        return new Tuple2<>(dotValue >= 0 ? 1. : 0., Vectors.dense(1 - prob, prob));
+        return Row.of(dotValue >= 0 ? 1. : 0., Vectors.dense(1 - prob, prob));
     }
 }
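
A worked instance of the sigmoid arithmetic in predictOneDataPoint() (the
values and class name are illustrative):

    public class SigmoidCheck {
        public static void main(String[] args) {
            double dotValue = 0.5; // stand-in for BLAS.dot(feature, coefficient)
            double prob = 1 - 1.0 / (1.0 + Math.exp(dotValue)); // sigmoid(0.5)
            System.out.println(prob);                    // ~0.6225 = P(label = 1)
            System.out.println(dotValue >= 0 ? 1. : 0.); // predicted label: 1.0
        }
    }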
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionParams.java
index c9b5919..f016978 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionParams.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionParams.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.ml.classification.logisticregression;
 
+import org.apache.flink.ml.common.param.HasElasticNet;
 import org.apache.flink.ml.common.param.HasGlobalBatchSize;
 import org.apache.flink.ml.common.param.HasLabelCol;
 import org.apache.flink.ml.common.param.HasLearningRate;
@@ -37,6 +38,7 @@ public interface LogisticRegressionParams<T>
                 HasWeightCol<T>,
                 HasMaxIter<T>,
                 HasReg<T>,
+                HasElasticNet<T>,
                 HasLearningRate<T>,
                 HasGlobalBatchSize<T>,
                 HasTol<T>,
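
With HasElasticNet mixed in, the estimator exposes the elastic net mixing
parameter alongside reg; a small sketch (accessor names assumed from the usual
flink-ml param conventions, and getElasticNet() is already called in fit()):

    import org.apache.flink.ml.classification.logisticregression.LogisticRegression;

    public class ElasticNetParamSketch {
        public static void main(String[] args) {
            // elasticNet = 0.0 is pure L2, 1.0 is pure L1; values in between mix both.
            LogisticRegression lr = new LogisticRegression().setReg(0.1).setElasticNet(0.5);
            System.out.println(lr.getElasticNet()); // 0.5
        }
    }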
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLoss.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLoss.java
new file mode 100644
index 0000000..cd24c06
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLoss.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.lossfunc;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+
+/** The loss function for binary logistic loss. See {@link LogisticRegression} for an example. */
+@Internal
+public class BinaryLogisticLoss implements LossFunc {
+    public static final BinaryLogisticLoss INSTANCE = new BinaryLogisticLoss();
+
+    private BinaryLogisticLoss() {}
+
+    @Override
+    public double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient) {
+        double dot = BLAS.dot(dataPoint.getFeatures(), coefficient);
+        double labelScaled = 2 * dataPoint.getLabel() - 1;
+        return dataPoint.getWeight() * Math.log(1 + Math.exp(-dot * labelScaled));
+    }
+
+    @Override
+    public void computeGradient(
+            LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient) {
+        double dot = BLAS.dot(dataPoint.getFeatures(), coefficient);
+        double labelScaled = 2 * dataPoint.getLabel() - 1;
+        double multiplier =
+                dataPoint.getWeight() * (-labelScaled / (Math.exp(dot * labelScaled) + 1));
+        BLAS.axpy(multiplier, dataPoint.getFeatures(), cumGradient, dataPoint.getFeatures().size());
+    }
+}
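
A worked example of the loss and gradient formulas above (the numbers and
class name are illustrative):

    import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
    import org.apache.flink.ml.common.lossfunc.BinaryLogisticLoss;
    import org.apache.flink.ml.linalg.DenseVector;
    import org.apache.flink.ml.linalg.Vectors;

    public class BinaryLogisticLossExample {
        public static void main(String[] args) {
            // label 1.0 is rescaled to +1 internally; dot = 1 * 0.5 + 2 * 0.5 = 1.5.
            LabeledPointWithWeight dataPoint =
                    new LabeledPointWithWeight(Vectors.dense(1, 2), 1.0, 1.0);
            DenseVector coefficient = Vectors.dense(0.5, 0.5);
            double loss = BinaryLogisticLoss.INSTANCE.computeLoss(dataPoint, coefficient);
            System.out.println(loss); // log(1 + e^-1.5) ~ 0.2014

            DenseVector cumGradient = new DenseVector(2);
            BinaryLogisticLoss.INSTANCE.computeGradient(dataPoint, coefficient, cumGradient);
            System.out.println(java.util.Arrays.toString(cumGradient.values)); // ~[-0.1824, -0.3649]
        }
    }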
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LeastSquareLoss.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LeastSquareLoss.java
new file mode 100644
index 0000000..ea64649
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LeastSquareLoss.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.lossfunc;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+
+/** The loss function for least square loss. See {@link LinearRegression} for an example. */
+@Internal
+public class LeastSquareLoss implements LossFunc {
+    public static final LeastSquareLoss INSTANCE = new LeastSquareLoss();
+
+    private LeastSquareLoss() {}
+
+    @Override
+    public double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient) {
+        double dot = BLAS.dot(dataPoint.getFeatures(), coefficient);
+        return dataPoint.getWeight() * 0.5 * Math.pow(dot - dataPoint.getLabel(), 2);
+    }
+
+    @Override
+    public void computeGradient(
+            LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient) {
+        double dot = BLAS.dot(dataPoint.getFeatures(), coefficient);
+        BLAS.axpy(
+                (dot - dataPoint.getLabel()) * dataPoint.getWeight(),
+                dataPoint.getFeatures(),
+                cumGradient,
+                dataPoint.getFeatures().size());
+    }
+}
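
The same kind of worked example for the least-squares formulas (numbers and
class name are illustrative):

    import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
    import org.apache.flink.ml.common.lossfunc.LeastSquareLoss;
    import org.apache.flink.ml.linalg.DenseVector;
    import org.apache.flink.ml.linalg.Vectors;

    public class LeastSquareLossExample {
        public static void main(String[] args) {
            // dot = 1 * 0.5 + 2 * 0.5 = 1.5, label = 1.0, weight = 1.0.
            LabeledPointWithWeight dataPoint =
                    new LabeledPointWithWeight(Vectors.dense(1, 2), 1.0, 1.0);
            DenseVector coefficient = Vectors.dense(0.5, 0.5);
            System.out.println(LeastSquareLoss.INSTANCE.computeLoss(dataPoint, coefficient));
            // 0.5 * (1.5 - 1.0)^2 = 0.125

            DenseVector cumGradient = new DenseVector(2);
            LeastSquareLoss.INSTANCE.computeGradient(dataPoint, coefficient, cumGradient);
            System.out.println(java.util.Arrays.toString(cumGradient.values)); // [0.5, 1.0]
        }
    }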
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java
new file mode 100644
index 0000000..a90967a
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.lossfunc;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.linalg.DenseVector;
+
+import java.io.Serializable;
+
+/**
+ * A loss function computes the loss and the gradient for a given coefficient and training data.
+ */
+@Internal
+public interface LossFunc extends Serializable {
+
+    /**
+     * Computes the loss on the given data point.
+     *
+     * @param dataPoint A training data point.
+     * @param coefficient The model parameters.
+     * @return The loss of the input data.
+     */
+    double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient);
+
+    /**
+     * Computes the gradient on the given data point and adds the computed gradient to cumGradient.
+     *
+     * @param dataPoint A training data point.
+     * @param coefficient The model parameters.
+     * @param cumGradient The accumulated gradient.
+     */
+    void computeGradient(
+            LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient);
+}
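
The interface is small enough that adding a new loss is a one-class job; a
hypothetical hinge-loss sketch (not part of this commit; labels in {0, 1} with
the same rescaling as BinaryLogisticLoss):

    import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
    import org.apache.flink.ml.common.lossfunc.LossFunc;
    import org.apache.flink.ml.linalg.BLAS;
    import org.apache.flink.ml.linalg.DenseVector;

    public class HingeLoss implements LossFunc {
        public static final HingeLoss INSTANCE = new HingeLoss();

        private HingeLoss() {}

        @Override
        public double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient) {
            double dot = BLAS.dot(dataPoint.getFeatures(), coefficient);
            double labelScaled = 2 * dataPoint.getLabel() - 1;
            return dataPoint.getWeight() * Math.max(0, 1 - labelScaled * dot);
        }

        @Override
        public void computeGradient(
                LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient) {
            double dot = BLAS.dot(dataPoint.getFeatures(), coefficient);
            double labelScaled = 2 * dataPoint.getLabel() - 1;
            if (labelScaled * dot < 1) {
                // Subgradient -labelScaled * x where the margin is violated.
                BLAS.axpy(
                        -labelScaled * dataPoint.getWeight(),
                        dataPoint.getFeatures(),
                        cumGradient,
                        dataPoint.getFeatures().size());
            }
        }
    }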
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/Optimizer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/Optimizer.java
new file mode 100644
index 0000000..647741d
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/Optimizer.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+
+/**
+ * An optimizer updates the weights of a machine learning model, aiming to find the optimal
+ * parameter configuration for the model. Examples of optimizers include stochastic gradient
+ * descent (SGD) and L-BFGS.
+ */
+@Internal
+public interface Optimizer {
+    /**
+     * Optimizes the given loss function using the initial model data and the bounded training data.
+     *
+     * @param initModelData The initial model data.
+     * @param trainData The training data.
+     * @param lossFunc The loss function to optimize.
+     * @return The fitted model data.
+     */
+    DataStream<DenseVector> optimize(
+            DataStream<DenseVector> initModelData,
+            DataStream<LabeledPointWithWeight> trainData,
+            LossFunc lossFunc);
+}
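
A sketch of wiring an Optimizer, mirroring the call in LogisticRegression.fit()
above (the class name and parameter values are illustrative):

    import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
    import org.apache.flink.ml.common.lossfunc.BinaryLogisticLoss;
    import org.apache.flink.ml.common.optimizer.Optimizer;
    import org.apache.flink.ml.common.optimizer.SGD;
    import org.apache.flink.ml.linalg.DenseVector;
    import org.apache.flink.streaming.api.datastream.DataStream;

    public class OptimizerSketch {
        static DataStream<DenseVector> train(
                DataStream<DenseVector> initModelData,
                DataStream<LabeledPointWithWeight> trainData) {
            // Arguments: maxIter, learningRate, globalBatchSize, tol, reg, elasticNet.
            Optimizer optimizer = new SGD(100, 0.1, 32, 1e-6, 0.0, 0.0);
            return optimizer.optimize(initModelData, trainData, BinaryLogisticLoss.INSTANCE);
        }
    }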
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/RegularizationUtils.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/RegularizationUtils.java
new file mode 100644
index 0000000..3d36d9a
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/RegularizationUtils.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+
+/**
+ * A utility class for algorithms that need to handle regularization. The regularization term is
+ * defined as:
+ *
+ * <p>elasticNet * reg * norm1(coefficient) + (1 - elasticNet) * (reg/2) * (norm2(coefficient))^2
+ *
+ * <p>See https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.ElasticNet.html.
+ */
+@Internal
+class RegularizationUtils {
+
+    /**
+     * Regularizes the model coefficient. The gradient of each dimension can be computed as
+     * {elasticNet * reg * Math.signum(c_i) + (1 - elasticNet) * reg * c_i}, where c_i is the value
+     * of the coefficient at the i-th dimension.
+     *
+     * @param coefficient The model coefficient.
+     * @param reg The reg param.
+     * @param elasticNet The elasticNet param.
+     * @param learningRate The learningRate param.
+     * @return The loss introduced by regularization.
+     */
+    public static double regularize(
+            DenseVector coefficient,
+            final double reg,
+            final double elasticNet,
+            final double learningRate) {
+
+        if (Double.compare(reg, 0) == 0) {
+            return 0;
+        } else if (Double.compare(elasticNet, 0) == 0) {
+            // Only L2 regularization. The loss term is reg / 2 * (norm2(coefficient))^2.
+            double loss = reg / 2 * Math.pow(BLAS.norm2(coefficient), 2);
+            BLAS.scal(1 - learningRate * reg, coefficient);
+            return loss;
+        } else if (Double.compare(elasticNet, 1) == 0) {
+            // Only L1 regularization. The loss term uses the absolute value of each coefficient.
+            double loss = 0;
+            double[] coefficientArray = coefficient.values;
+            for (int i = 0; i < coefficientArray.length; i++) {
+                if (Double.compare(coefficientArray[i], 0) == 0) {
+                    continue;
+                }
+                loss += elasticNet * reg * Math.abs(coefficientArray[i]);
+                coefficientArray[i] -=
+                        learningRate * elasticNet * reg * Math.signum(coefficientArray[i]);
+            }
+            return loss;
+        } else {
+            // Both the L1 and the L2 regularization terms are non-zero.
+            double loss = 0;
+            double[] coefficientArray = coefficient.values;
+            for (int i = 0; i < coefficientArray.length; i++) {
+                loss +=
+                        elasticNet * reg * Math.abs(coefficientArray[i])
+                                + (1 - elasticNet)
+                                        * (reg / 2)
+                                        * coefficientArray[i]
+                                        * coefficientArray[i];
+                coefficientArray[i] -=
+                        (learningRate
+                                * (elasticNet * reg * Math.signum(coefficientArray[i])
+                                        + (1 - elasticNet) * reg * coefficientArray[i]));
+            }
+            return loss;
+        }
+    }
+}
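
As a quick numeric check of the formula above, a hedged sketch of the pure-L2 branch (the values
here are chosen by hand for illustration, not taken from the patch):

    // With reg = 0.1, elasticNet = 0 and learningRate = 0.1, regularize()
    // scales each coefficient by (1 - 0.1 * 0.1) = 0.99 and returns
    // reg / 2 * ||c||^2 = 0.05 * (3^2 + 4^2) = 1.25.
    DenseVector c = Vectors.dense(3.0, 4.0);
    double loss = RegularizationUtils.regularize(c, 0.1, 0.0, 0.1);
    // loss == 1.25, c is now approximately [2.97, 3.96]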
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java
new file mode 100644
index 0000000..2f78004
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java
@@ -0,0 +1,390 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Stochastic Gradient Descent (SGD) is the most widely used optimizer for optimizing machine
+ * learning models. It iteratively makes small adjustments to the model according to the gradient
+ * at each step, so as to decrease the error of the model.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Stochastic_gradient_descent.
+ */
+@Internal
+public class SGD implements Optimizer {
+    /** Params for SGD optimizer. */
+    private final SGDParams params;
+
+    public SGD(
+            int maxIter,
+            double learningRate,
+            int globalBatchSize,
+            double tol,
+            double reg,
+            double elasticNet) {
+        this.params = new SGDParams(maxIter, learningRate, globalBatchSize, tol, reg, elasticNet);
+    }
+
+    @Override
+    public DataStream<DenseVector> optimize(
+            DataStream<DenseVector> initModelData,
+            DataStream<LabeledPointWithWeight> trainData,
+            LossFunc lossFunc) {
+        DataStreamList resultList =
+                Iterations.iterateBoundedStreamsUntilTermination(
+                        DataStreamList.of(
+                                initModelData.broadcast().map(modelVec -> modelVec.values)),
+                        ReplayableDataStreamList.notReplay(trainData.rebalance().map(x -> x)),
+                        IterationConfig.newBuilder().build(),
+                        new TrainIterationBody(lossFunc, params));
+        return resultList.get(0);
+    }
+
+    /** The iteration implementation for the training process. */
+    private static class TrainIterationBody implements IterationBody {
+        private final LossFunc lossFunc;
+        private final SGDParams params;
+
+        public TrainIterationBody(LossFunc lossFunc, SGDParams params) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            // The variable stream at the first iteration is the initialized model data.
+            // In the following iterations, it contains: [the model update, totalWeight, and
+            // totalLoss].
+            DataStream<double[]> variableStream = variableStreams.get(0);
+            DataStream<LabeledPointWithWeight> trainData = dataStreams.get(0);
+            final OutputTag<DenseVector> modelDataOutputTag =
+                    new OutputTag<DenseVector>("MODEL_OUTPUT") {};
+
+            SingleOutputStreamOperator<double[]> modelUpdateAndWeightAndLoss =
+                    trainData
+                            .connect(variableStream)
+                            .transform(
+                                    "CacheDataAndDoTrain",
+                                    
PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
+                                    new CacheDataAndDoTrain(lossFunc, params, 
modelDataOutputTag));
+
+            DataStreamList feedbackVariableStream =
+                    IterationBody.forEachRound(
+                            DataStreamList.of(modelUpdateAndWeightAndLoss),
+                            input -> {
+                                DataStream<double[]> feedback =
+                                        DataStreamUtils.allReduceSum(input.get(0));
+                                return DataStreamList.of(feedback);
+                            });
+
+            DataStream<Integer> terminationCriteria =
+                    feedbackVariableStream
+                            .get(0)
+                            .map(
+                                    reducedUpdateAndWeightAndLoss -> {
+                                        double[] value = (double[]) reducedUpdateAndWeightAndLoss;
+                                        return value[value.length - 1] / value[value.length - 2];
+                                    })
+                            .flatMap(new TerminateOnMaxIterOrTol(params.maxIter, params.tol));
+
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackVariableStream.get(0)),
+                    DataStreamList.of(
+                            modelUpdateAndWeightAndLoss.getSideOutput(modelDataOutputTag)),
+                    terminationCriteria);
+        }
+    }
+
+    /**
+     * A stream operator that caches the training data in the first iteration and updates the model
+     * iteratively. The first input is the training data, and the second input is the initial model
+     * data or the feedback of model update, totalWeight, and totalLoss.
+     */
+    private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]>
+            implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>,
+                    IterationListener<double[]> {
+        /** Optimizer-related parameters. */
+        private final SGDParams params;
+
+        /** The loss function to optimize. */
+        private final LossFunc lossFunc;
+
+        /** The outputTag to output the model data when iteration ends. */
+        private final OutputTag<DenseVector> modelDataOutputTag;
+
+        /** The cached training data. */
+        private List<LabeledPointWithWeight> trainData;
+
+        private ListState<LabeledPointWithWeight> trainDataState;
+
+        /** The start index (offset) of the next mini-batch data for training. */
+        private int nextBatchOffset = 0;
+
+        private ListState<Integer> nextBatchOffsetState;
+
+        /** The model coefficient. */
+        private DenseVector coefficient;
+
+        private ListState<DenseVector> coefficientState;
+
+        /** The dimension of the coefficient. */
+        private int coefficientDim;
+
+        /**
+         * The double array to sync among all workers. For example, when training {@link
+         * LinearRegression}, this double array consists of [modelUpdate, totalWeight, totalLoss].
+         */
+        private double[] feedbackArray;
+
+        private ListState<double[]> feedbackArrayState;
+
+        /** The batch size on this partition. */
+        private int localBatchSize;
+
+        private CacheDataAndDoTrain(
+                LossFunc lossFunc, SGDParams params, OutputTag<DenseVector> modelDataOutputTag) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+            this.modelDataOutputTag = modelDataOutputTag;
+        }
+
+        @Override
+        public void open() {
+            int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
+            int taskId = getRuntimeContext().getIndexOfThisSubtask();
+            localBatchSize = params.globalBatchSize / numTasks;
+            if (params.globalBatchSize % numTasks > taskId) {
+                localBatchSize++;
+            }
+        }
+
+        private double getTotalWeight() {
+            return feedbackArray[coefficientDim];
+        }
+
+        private void setTotalWeight(double totalWeight) {
+            feedbackArray[coefficientDim] = totalWeight;
+        }
+
+        private double getTotalLoss() {
+            return feedbackArray[coefficientDim + 1];
+        }
+
+        private void setTotalLoss(double totalLoss) {
+            feedbackArray[coefficientDim + 1] = totalLoss;
+        }
+
+        private void updateModel() {
+            if (getTotalWeight() > 0) {
+                BLAS.axpy(
+                        -params.learningRate / getTotalWeight(),
+                        new DenseVector(feedbackArray),
+                        coefficient,
+                        coefficientDim);
+                double regLoss =
+                        RegularizationUtils.regularize(
+                                coefficient, params.reg, params.elasticNet, params.learningRate);
+                setTotalLoss(getTotalLoss() + regLoss);
+            }
+        }
+
+        @Override
+        public void onEpochWatermarkIncremented(
+                int epochWatermark, Context context, Collector<double[]> collector)
+                throws Exception {
+            if (epochWatermark == 0) {
+                coefficient = new DenseVector(feedbackArray);
+                coefficientDim = coefficient.size();
+                feedbackArray = new double[coefficient.size() + 2];
+            } else {
+                updateModel();
+            }
+
+            if (trainData == null) {
+                trainData = IteratorUtils.toList(trainDataState.get().iterator());
+            }
+
+            // TODO: support efficient shuffling of the training set on each partition.
+            if (trainData.size() > 0) {
+                List<LabeledPointWithWeight> miniBatchData =
+                        trainData.subList(
+                                nextBatchOffset,
+                                Math.min(nextBatchOffset + localBatchSize, trainData.size()));
+                nextBatchOffset += localBatchSize;
+                nextBatchOffset = nextBatchOffset >= trainData.size() ? 0 : nextBatchOffset;
+
+                // Does the training.
+                Arrays.fill(feedbackArray, 0);
+                double totalLoss = 0;
+                double totalWeight = 0;
+                DenseVector cumGradientsWrapper = new DenseVector(feedbackArray);
+                for (LabeledPointWithWeight dataPoint : miniBatchData) {
+                    totalLoss += lossFunc.computeLoss(dataPoint, coefficient);
+                    lossFunc.computeGradient(dataPoint, coefficient, cumGradientsWrapper);
+                    totalWeight += dataPoint.getWeight();
+                }
+                    totalWeight += dataPoint.getWeight();
+                }
+                setTotalLoss(totalLoss);
+                setTotalWeight(totalWeight);
+
+                collector.collect(feedbackArray);
+            }
+        }
+
+        @Override
+        public void onIterationTerminated(Context context, Collector<double[]> collector) {
+            trainDataState.clear();
+            if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
+                updateModel();
+                context.output(modelDataOutputTag, coefficient);
+            }
+        }
+
+        @Override
+        public void processElement1(StreamRecord<LabeledPointWithWeight> streamRecord)
+                throws Exception {
+            trainDataState.add(streamRecord.getValue());
+        }
+
+        @Override
+        public void processElement2(StreamRecord<double[]> streamRecord) {
+            feedbackArray = streamRecord.getValue();
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            coefficientState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "coefficientState", 
DenseVectorTypeInfo.INSTANCE));
+            OperatorStateUtils.getUniqueElement(coefficientState, 
"coefficientState")
+                    .ifPresent(x -> coefficient = x);
+            if (coefficient != null) {
+                coefficientDim = coefficient.size();
+            }
+
+            feedbackArrayState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "feedbackArrayState",
+                                            PrimitiveArrayTypeInfo
+                                                    .DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO));
+            OperatorStateUtils.getUniqueElement(feedbackArrayState, "feedbackArrayState")
+                    .ifPresent(x -> feedbackArray = x);
+
+            trainDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "trainDataState",
+                                            
TypeInformation.of(LabeledPointWithWeight.class)));
+
+            nextBatchOffsetState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "nextBatchOffsetState", 
BasicTypeInfo.INT_TYPE_INFO));
+            nextBatchOffset =
+                    OperatorStateUtils.getUniqueElement(
+                                    nextBatchOffsetState, 
"nextBatchOffsetState")
+                            .orElse(0);
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            coefficientState.clear();
+            if (coefficient != null) {
+                coefficientState.add(coefficient);
+            }
+
+            feedbackArrayState.clear();
+            if (feedbackArray != null) {
+                feedbackArrayState.add(feedbackArray);
+            }
+
+            nextBatchOffsetState.clear();
+            nextBatchOffsetState.add(nextBatchOffset);
+        }
+    }
+
+    /** Parameters for {@link SGD}. */
+    private static class SGDParams implements Serializable {
+        public final int maxIter;
+        public final double learningRate;
+        public final int globalBatchSize;
+        public final double tol;
+        public final double reg;
+        public final double elasticNet;
+
+        private SGDParams(
+                int maxIter,
+                double learningRate,
+                int globalBatchSize,
+                double tol,
+                double reg,
+                double elasticNet) {
+            this.maxIter = maxIter;
+            this.learningRate = learningRate;
+            this.globalBatchSize = globalBatchSize;
+            this.tol = tol;
+            this.reg = reg;
+            this.elasticNet = elasticNet;
+        }
+    }
+}
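
For readers following the iteration logic, a hedged scalar sketch of the per-epoch model update
performed in updateModel() (the loop form below is illustrative; the code uses BLAS.axpy):

    // The first `coefficientDim` entries of the all-reduced feedbackArray
    // hold the mini-batch gradients summed across all workers.
    for (int i = 0; i < coefficientDim; i++) {
        coefficient.values[i] -= params.learningRate / getTotalWeight() * feedbackArray[i];
    }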
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasElasticNet.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasElasticNet.java
new file mode 100644
index 0000000..467308a
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasElasticNet.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/**
+ * Interface for the shared elasticNet param, which specifies the mixing of the L1 and L2 penalty:
+ *
+ * <ul>
+ *   <li>If the value is zero, it is the L2 penalty.
+ *   <li>If the value is one, it is the L1 penalty.
+ *   <li>For values in (0, 1), it is a combination of the L1 and L2 penalty.
+ * </ul>
+ */
+public interface HasElasticNet<T> extends WithParams<T> {
+    Param<Double> ELASTIC_NET =
+            new DoubleParam(
+                    "elasticNet", "ElasticNet parameter.", 0.0, ParamValidators.inRange(0.0, 1.0));
+
+    default double getElasticNet() {
+        return get(ELASTIC_NET);
+    }
+
+    default T setElasticNet(Double value) {
+        return set(ELASTIC_NET, value);
+    }
+}
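
A hedged usage sketch of the param (the values are arbitrary): elasticNet = 0.5 mixes the L1 and
L2 penalty equally, while the separate reg param controls the overall regularization strength.

    LinearRegression lr = new LinearRegression().setReg(0.1).setElasticNet(0.5);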
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegression.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegression.java
new file mode 100644
index 0000000..7edffc4
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegression.java
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.regression.linearregression;
+
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.lossfunc.LeastSquareLoss;
+import org.apache.flink.ml.common.optimizer.Optimizer;
+import org.apache.flink.ml.common.optimizer.SGD;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the linear regression algorithm.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Linear_regression.
+ */
+public class LinearRegression
+        implements Estimator<LinearRegression, LinearRegressionModel>,
+                LinearRegressionParams<LinearRegression> {
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public LinearRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings({"rawTypes", "ConstantConditions"})
+    public LinearRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<LabeledPointWithWeight> trainData =
+                tEnv.toDataStream(inputs[0])
+                        .map(
+                                dataPoint -> {
+                                    double weight =
+                                            getWeightCol() == null
+                                                    ? 1.0
+                                                    : (Double) dataPoint.getField(getWeightCol());
+                                    double label = (Double) dataPoint.getField(getLabelCol());
+                                    DenseVector features =
+                                            (DenseVector) dataPoint.getField(getFeaturesCol());
+                                    return new LabeledPointWithWeight(features, label, weight);
+                                });
+
+        DataStream<DenseVector> initModelData =
+                DataStreamUtils.reduce(
+                                trainData.map(x -> x.getFeatures().size()),
+                                (ReduceFunction<Integer>)
+                                        (t0, t1) -> {
+                                            Preconditions.checkState(
+                                                    t0.equals(t1),
+                                                    "The training data should all have the same dimension.");
+                                            return t0;
+                                        })
+                        .map(DenseVector::new);
+
+        Optimizer optimizer =
+                new SGD(
+                        getMaxIter(),
+                        getLearningRate(),
+                        getGlobalBatchSize(),
+                        getTol(),
+                        getReg(),
+                        getElasticNet());
+        DataStream<DenseVector> rawModelData =
+                optimizer.optimize(initModelData, trainData, LeastSquareLoss.INSTANCE);
+
+        DataStream<LinearRegressionModelData> modelData =
+                rawModelData.map(LinearRegressionModelData::new);
+        LinearRegressionModel model =
+                new LinearRegressionModel().setModelData(tEnv.fromDataStream(modelData));
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static LinearRegression load(StreamTableEnvironment tEnv, String path)
+            throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+}
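
An end-to-end hedged sketch of the new Estimator (the training table and its column names are
assumptions; any table with a DenseVector "features" column and a Double "label" column works):

    // Train a linear regression model and apply it to the same table.
    LinearRegression lr = new LinearRegression().setLearningRate(0.01).setMaxIter(50);
    LinearRegressionModel model = lr.fit(trainTable);
    Table output = model.transform(trainTable)[0];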
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModel.java
similarity index 68%
copy from flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java
copy to flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModel.java
index 3c30a8d..7fb9019 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModel.java
@@ -16,19 +16,16 @@
  * limitations under the License.
  */
 
-package org.apache.flink.ml.classification.logisticregression;
+package org.apache.flink.ml.regression.linearregression;
 
 import org.apache.flink.api.common.functions.RichMapFunction;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
-import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.typeutils.RowTypeInfo;
 import org.apache.flink.ml.api.Model;
 import org.apache.flink.ml.common.broadcast.BroadcastUtils;
 import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.linalg.BLAS;
 import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
@@ -46,53 +43,19 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 
-/** A Model which classifies data using the model data computed by {@link LogisticRegression}. */
-public class LogisticRegressionModel
-        implements Model<LogisticRegressionModel>,
-                LogisticRegressionModelParams<LogisticRegressionModel> {
+/** A Model which predicts data using the model data computed by {@link LinearRegression}. */
+public class LinearRegressionModel
+        implements Model<LinearRegressionModel>,
+                LinearRegressionModelParams<LinearRegressionModel> {
 
     private final Map<Param<?>, Object> paramMap = new HashMap<>();
 
     private Table modelDataTable;
 
-    public LogisticRegressionModel() {
+    public LinearRegressionModel() {
         ParamUtils.initializeMapWithDefaultValues(paramMap, this);
     }
 
-    @Override
-    public Map<Param<?>, Object> getParamMap() {
-        return paramMap;
-    }
-
-    @Override
-    public void save(String path) throws IOException {
-        ReadWriteUtils.saveMetadata(this, path);
-        ReadWriteUtils.saveModelData(
-                LogisticRegressionModelData.getModelDataStream(modelDataTable),
-                path,
-                new LogisticRegressionModelData.ModelDataEncoder());
-    }
-
-    public static LogisticRegressionModel load(StreamTableEnvironment tEnv, String path)
-            throws IOException {
-        LogisticRegressionModel model = ReadWriteUtils.loadStageParam(path);
-        Table modelDataTable =
-                ReadWriteUtils.loadModelData(
-                        tEnv, path, new LogisticRegressionModelData.ModelDataDecoder());
-        return model.setModelData(modelDataTable);
-    }
-
-    @Override
-    public LogisticRegressionModel setModelData(Table... inputs) {
-        modelDataTable = inputs[0];
-        return this;
-    }
-
-    @Override
-    public Table[] getModelData() {
-        return new Table[] {modelDataTable};
-    }
-
     @Override
     @SuppressWarnings("unchecked")
     public Table[] transform(Table... inputs) {
@@ -101,19 +64,14 @@ public class LogisticRegressionModel
                 (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
         DataStream<Row> inputStream = tEnv.toDataStream(inputs[0]);
         final String broadcastModelKey = "broadcastModelKey";
-        DataStream<LogisticRegressionModelData> modelDataStream =
-                LogisticRegressionModelData.getModelDataStream(modelDataTable);
+        DataStream<LinearRegressionModelData> modelDataStream =
+                LinearRegressionModelData.getModelDataStream(modelDataTable);
         RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
         RowTypeInfo outputTypeInfo =
                 new RowTypeInfo(
                         ArrayUtils.addAll(
-                                inputTypeInfo.getFieldTypes(),
-                                BasicTypeInfo.DOUBLE_TYPE_INFO,
-                                TypeInformation.of(DenseVector.class)),
-                        ArrayUtils.addAll(
-                                inputTypeInfo.getFieldNames(),
-                                getPredictionCol(),
-                                getRawPredictionCol()));
+                                inputTypeInfo.getFieldTypes(), BasicTypeInfo.DOUBLE_TYPE_INFO),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
         DataStream<Row> predictionResult =
                 BroadcastUtils.withBroadcastStream(
                         Collections.singletonList(inputStream),
@@ -127,6 +85,40 @@ public class LogisticRegressionModel
         return new Table[] {tEnv.fromDataStream(predictionResult)};
     }
 
+    @Override
+    public LinearRegressionModel setModelData(Table... inputs) {
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                LinearRegressionModelData.getModelDataStream(modelDataTable),
+                path,
+                new LinearRegressionModelData.ModelDataEncoder());
+    }
+
+    public static LinearRegressionModel load(StreamTableEnvironment tEnv, String path)
+            throws IOException {
+        LinearRegressionModel model = ReadWriteUtils.loadStageParam(path);
+        Table modelDataTable =
+                ReadWriteUtils.loadModelData(
+                        tEnv, path, new LinearRegressionModelData.ModelDataDecoder());
+        return model.setModelData(modelDataTable);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
     /** A utility function used for prediction. */
     private static class PredictLabelFunction extends RichMapFunction<Row, Row> {
 
@@ -144,28 +136,25 @@ public class LogisticRegressionModel
         @Override
         public Row map(Row dataPoint) {
             if (coefficient == null) {
-                LogisticRegressionModelData modelData =
-                        (LogisticRegressionModelData)
+                LinearRegressionModelData modelData =
+                        (LinearRegressionModelData)
                                 getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0);
                 coefficient = modelData.coefficient;
             }
             DenseVector features = (DenseVector) dataPoint.getField(featuresCol);
-            Tuple2<Double, DenseVector> predictionResult = predictRaw(features, coefficient);
-            return Row.join(dataPoint, Row.of(predictionResult.f0, predictionResult.f1));
+            Row predictionResult = predictOneDataPoint(features, coefficient);
+            return Row.join(dataPoint, predictionResult);
         }
     }
 
     /**
-     * The main logic that predicts one input record.
+     * The main logic that predicts one input data point.
      *
      * @param feature The input feature.
      * @param coefficient The model parameters.
-     * @return The prediction label and the raw probabilities.
+     * @return The prediction result.
      */
-    private static Tuple2<Double, DenseVector> predictRaw(
-            DenseVector feature, DenseVector coefficient) {
-        double dotValue = BLAS.dot(feature, coefficient);
-        double prob = 1 - 1.0 / (1.0 + Math.exp(dotValue));
-        return new Tuple2<>(dotValue >= 0 ? 1. : 0., Vectors.dense(1 - prob, prob));
+    private static Row predictOneDataPoint(DenseVector feature, DenseVector coefficient) {
+        return Row.of(BLAS.dot(feature, coefficient));
     }
 }
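
A hedged numeric sketch of the new prediction logic (the feature and coefficient values are made
up for illustration):

    // predictOneDataPoint appends a single Double column:
    // prediction = <features, coefficient> = 1.0 * 0.5 + 2.0 * 0.25 = 1.0
    DenseVector features = Vectors.dense(1.0, 2.0);
    DenseVector coefficient = Vectors.dense(0.5, 0.25);
    double prediction = BLAS.dot(features, coefficient); // 1.0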
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModelData.java
new file mode 100644
index 0000000..278d5b3
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModelData.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.regression.linearregression;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+
+/**
+ * Model data of {@link LinearRegressionModel}.
+ *
+ * <p>This class also provides methods to convert model data from a Table to a DataStream, and
+ * classes to save/load model data.
+ */
+public class LinearRegressionModelData {
+
+    public DenseVector coefficient;
+
+    public LinearRegressionModelData(DenseVector coefficient) {
+        this.coefficient = coefficient;
+    }
+
+    public LinearRegressionModelData() {}
+
+    /**
+     * Converts the table model to a data stream.
+     *
+     * @param modelData The table model data.
+     * @return The data stream model data.
+     */
+    public static DataStream<LinearRegressionModelData> getModelDataStream(Table modelData) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment();
+        return tEnv.toDataStream(modelData)
+                .map(x -> new LinearRegressionModelData((DenseVector) x.getField(0)));
+    }
+
+    /** Data encoder for {@link LinearRegressionModelData}. */
+    public static class ModelDataEncoder implements Encoder<LinearRegressionModelData> {
+
+        @Override
+        public void encode(LinearRegressionModelData modelData, OutputStream outputStream)
+                throws IOException {
+            DenseVectorSerializer.INSTANCE.serialize(
+                    modelData.coefficient, new DataOutputViewStreamWrapper(outputStream));
+        }
+    }
+
+    /** Data decoder for {@link LinearRegressionModelData}. */
+    public static class ModelDataDecoder extends SimpleStreamFormat<LinearRegressionModelData> {
+
+        @Override
+        public Reader<LinearRegressionModelData> createReader(
+                Configuration configuration, FSDataInputStream inputStream) {
+            return new Reader<LinearRegressionModelData>() {
+
+                @Override
+                public LinearRegressionModelData read() throws IOException {
+                    try {
+                        DenseVector coefficient =
+                                DenseVectorSerializer.INSTANCE.deserialize(
+                                        new DataInputViewStreamWrapper(inputStream));
+                        return new LinearRegressionModelData(coefficient);
+                    } catch (EOFException e) {
+                        return null;
+                    }
+                }
+
+                @Override
+                public void close() throws IOException {
+                    inputStream.close();
+                }
+            };
+        }
+
+        @Override
+        public TypeInformation<LinearRegressionModelData> getProducedType() {
+            return TypeInformation.of(LinearRegressionModelData.class);
+        }
+    }
+}
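
A hedged sketch of round-tripping a fitted model through the encoder/decoder via the Model API
(the path below is a placeholder):

    // Persist the fitted model, then restore it in another job.
    model.save("/tmp/linear-regression-model");
    LinearRegressionModel restored =
            LinearRegressionModel.load(tEnv, "/tmp/linear-regression-model");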
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModelParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModelParams.java
new file mode 100644
index 0000000..687fed6
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModelParams.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.regression.linearregression;
+
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+
+/**
+ * Params for {@link LinearRegressionModel}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface LinearRegressionModelParams<T> extends HasFeaturesCol<T>, HasPredictionCol<T> {}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionParams.java
similarity index 83%
copy from flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionParams.java
copy to flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionParams.java
index c9b5919..a5ca6cc 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionParams.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionParams.java
@@ -16,29 +16,29 @@
  * limitations under the License.
  */
 
-package org.apache.flink.ml.classification.logisticregression;
+package org.apache.flink.ml.regression.linearregression;
 
+import org.apache.flink.ml.common.param.HasElasticNet;
 import org.apache.flink.ml.common.param.HasGlobalBatchSize;
 import org.apache.flink.ml.common.param.HasLabelCol;
 import org.apache.flink.ml.common.param.HasLearningRate;
 import org.apache.flink.ml.common.param.HasMaxIter;
-import org.apache.flink.ml.common.param.HasMultiClass;
 import org.apache.flink.ml.common.param.HasReg;
 import org.apache.flink.ml.common.param.HasTol;
 import org.apache.flink.ml.common.param.HasWeightCol;
 
 /**
- * Params for {@link LogisticRegression}.
+ * Params for {@link LinearRegression}.
  *
  * @param <T> The class type of this instance.
  */
-public interface LogisticRegressionParams<T>
+public interface LinearRegressionParams<T>
         extends HasLabelCol<T>,
                 HasWeightCol<T>,
                 HasMaxIter<T>,
                 HasReg<T>,
+                HasElasticNet<T>,
                 HasLearningRate<T>,
                 HasGlobalBatchSize<T>,
                 HasTol<T>,
-                HasMultiClass<T>,
-                LogisticRegressionModelParams<T> {}
+                LinearRegressionModelParams<T> {}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
index 88882c6..8d06613 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
@@ -28,6 +28,7 @@ import org.apache.flink.ml.classification.logisticregression.LogisticRegressionM
 import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelData;
 import org.apache.flink.ml.linalg.DenseVector;
 import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
 import org.apache.flink.ml.util.ReadWriteUtils;
 import org.apache.flink.ml.util.StageTestUtils;
 import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
@@ -37,6 +38,7 @@ import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
 import org.apache.flink.types.Row;
 
 import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.RandomUtils;
 import org.apache.commons.lang3.exception.ExceptionUtils;
 import org.junit.Before;
 import org.junit.Rule;
@@ -49,7 +51,6 @@ import java.util.List;
 
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
@@ -90,7 +91,7 @@ public class LogisticRegressionTest {
                     Row.of(Vectors.dense(15, 2, 3, 4), 1., 5.));
 
     private static final double[] expectedCoefficient =
-            new double[] {0.528, -0.286, -0.429, -0.572};
+            new double[] {0.525, -0.283, -0.425, -0.567};
 
     private static final double TOLERANCE = 1e-7;
 
@@ -114,9 +115,7 @@ public class LogisticRegressionTest {
                                 binomialTrainData,
                                 new RowTypeInfo(
                                         new TypeInformation[] {
-                                            TypeInformation.of(DenseVector.class),
-                                            Types.DOUBLE,
-                                            Types.DOUBLE
+                                            DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE
                                         },
                                         new String[] {"features", "label", "weight"})));
         multinomialDataTable =
@@ -125,14 +124,12 @@ public class LogisticRegressionTest {
                                 multinomialTrainData,
                                 new RowTypeInfo(
                                         new TypeInformation[] {
-                                            TypeInformation.of(DenseVector.class),
-                                            Types.DOUBLE,
-                                            Types.DOUBLE
+                                            DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE
                                         },
                                         new String[] {"features", "label", "weight"})));
     }
 
-    @SuppressWarnings("ConstantConditions")
+    @SuppressWarnings("ConstantConditions, unchecked")
     private void verifyPredictionResult(
             Table output, String featuresCol, String predictionCol, String rawPredictionCol)
             throws Exception {
@@ -154,17 +151,18 @@ public class LogisticRegressionTest {
     @Test
     public void testParam() {
         LogisticRegression logisticRegression = new LogisticRegression();
-        assertEquals(logisticRegression.getLabelCol(), "label");
+        assertEquals("label", logisticRegression.getLabelCol());
         assertNull(logisticRegression.getWeightCol());
-        assertEquals(logisticRegression.getMaxIter(), 20);
-        assertEquals(logisticRegression.getReg(), 0, TOLERANCE);
-        assertEquals(logisticRegression.getLearningRate(), 0.1, TOLERANCE);
-        assertEquals(logisticRegression.getGlobalBatchSize(), 32);
-        assertEquals(logisticRegression.getTol(), 1e-6, TOLERANCE);
-        assertEquals(logisticRegression.getMultiClass(), "auto");
-        assertEquals(logisticRegression.getFeaturesCol(), "features");
-        assertEquals(logisticRegression.getPredictionCol(), "prediction");
-        assertEquals(logisticRegression.getRawPredictionCol(), "rawPrediction");
+        assertEquals(20, logisticRegression.getMaxIter());
+        assertEquals(0, logisticRegression.getReg(), TOLERANCE);
+        assertEquals(0, logisticRegression.getElasticNet(), TOLERANCE);
+        assertEquals(0.1, logisticRegression.getLearningRate(), TOLERANCE);
+        assertEquals(32, logisticRegression.getGlobalBatchSize());
+        assertEquals(1e-6, logisticRegression.getTol(), TOLERANCE);
+        assertEquals("auto", logisticRegression.getMultiClass());
+        assertEquals("features", logisticRegression.getFeaturesCol());
+        assertEquals("prediction", logisticRegression.getPredictionCol());
+        assertEquals("rawPrediction", 
logisticRegression.getRawPredictionCol());
 
         logisticRegression
                 .setFeaturesCol("test_features")
@@ -175,20 +173,22 @@ public class LogisticRegressionTest {
                 .setLearningRate(0.5)
                 .setGlobalBatchSize(1000)
                 .setReg(0.1)
+                .setElasticNet(0.5)
                 .setMultiClass("binomial")
                 .setPredictionCol("test_predictionCol")
                 .setRawPredictionCol("test_rawPredictionCol");
-        assertEquals(logisticRegression.getFeaturesCol(), "test_features");
-        assertEquals(logisticRegression.getLabelCol(), "test_label");
-        assertEquals(logisticRegression.getWeightCol(), "test_weight");
-        assertEquals(logisticRegression.getMaxIter(), 1000);
-        assertEquals(logisticRegression.getTol(), 0.001, TOLERANCE);
-        assertEquals(logisticRegression.getLearningRate(), 0.5, TOLERANCE);
-        assertEquals(logisticRegression.getGlobalBatchSize(), 1000);
-        assertEquals(logisticRegression.getReg(), 0.1, TOLERANCE);
-        assertEquals(logisticRegression.getMultiClass(), "binomial");
-        assertEquals(logisticRegression.getPredictionCol(), "test_predictionCol");
-        assertEquals(logisticRegression.getRawPredictionCol(), "test_rawPredictionCol");
+        assertEquals("test_features", logisticRegression.getFeaturesCol());
+        assertEquals("test_label", logisticRegression.getLabelCol());
+        assertEquals("test_weight", logisticRegression.getWeightCol());
+        assertEquals(1000, logisticRegression.getMaxIter());
+        assertEquals(0.001, logisticRegression.getTol(), TOLERANCE);
+        assertEquals(0.5, logisticRegression.getLearningRate(), TOLERANCE);
+        assertEquals(1000, logisticRegression.getGlobalBatchSize());
+        assertEquals(0.1, logisticRegression.getReg(), TOLERANCE);
+        assertEquals(0.5, logisticRegression.getElasticNet(), TOLERANCE);
+        assertEquals("binomial", logisticRegression.getMultiClass());
+        assertEquals("test_predictionCol", 
logisticRegression.getPredictionCol());
+        assertEquals("test_rawPredictionCol", 
logisticRegression.getRawPredictionCol());
     }
 
     @Test
@@ -243,15 +243,16 @@ public class LogisticRegressionTest {
     }
 
     @Test
+    @SuppressWarnings("unchecked")
     public void testGetModelData() throws Exception {
         LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight");
         LogisticRegressionModel model = logisticRegression.fit(binomialDataTable);
-        LogisticRegressionModelData modelData =
-                LogisticRegressionModelData.getModelDataStream(model.getModelData()[0])
-                        .executeAndCollect()
-                        .next();
-        assertNotNull(modelData);
-        assertArrayEquals(expectedCoefficient, modelData.coefficient.values, 0.1);
+        List<LogisticRegressionModelData> modelData =
+                IteratorUtils.toList(
+                        LogisticRegressionModelData.getModelDataStream(model.getModelData()[0])
+                                .executeAndCollect());
+        assertEquals(1, modelData.size());
+        assertArrayEquals(expectedCoefficient, modelData.get(0).coefficient.values, 0.1);
     }
 
     @Test
@@ -286,7 +287,8 @@ public class LogisticRegressionTest {
     @Test
     public void testMoreSubtaskThanData() throws Exception {
         env.setParallelism(12);
-        LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight");
+        LogisticRegression logisticRegression =
+                new LogisticRegression().setWeightCol("weight").setGlobalBatchSize(128);
         Table output = logisticRegression.fit(binomialDataTable).transform(binomialDataTable)[0];
         verifyPredictionResult(
                 output,
@@ -294,4 +296,29 @@ public class LogisticRegressionTest {
                 logisticRegression.getPredictionCol(),
                 logisticRegression.getRawPredictionCol());
     }
+
+    @Test
+    public void testRegularization() throws Exception {
+        checkRegularization(0, RandomUtils.nextDouble(0, 1), expectedCoefficient);
+        checkRegularization(0.1, 0, new double[] {0.484, -0.258, -0.388, -0.517});
+        checkRegularization(0.1, 1, new double[] {0.417, -0.145, -0.312, -0.480});
+        checkRegularization(0.1, 0.5, new double[] {0.451, -0.203, -0.351, -0.498});
+    }
+
+    @SuppressWarnings("unchecked")
+    private void checkRegularization(double reg, double elasticNet, double[] expectedCoefficient)
+            throws Exception {
+        LogisticRegressionModel model =
+                new LogisticRegression()
+                        .setWeightCol("weight")
+                        .setReg(reg)
+                        .setElasticNet(elasticNet)
+                        .fit(binomialDataTable);
+        List<LogisticRegressionModelData> modelData =
+                IteratorUtils.toList(
+                        LogisticRegressionModelData.getModelDataStream(model.getModelData()[0])
+                                .executeAndCollect());
+        final double errorTol = 1e-3;
+        assertArrayEquals(expectedCoefficient, modelData.get(0).coefficient.values, errorTol);
+    }
 }
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLossTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLossTest.java
new file mode 100644
index 0000000..22ce2de
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLossTest.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.lossfunc;
+
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+/** Tests {@link BinaryLogisticLoss}. */
+public class BinaryLogisticLossTest {
+    private static final LabeledPointWithWeight dataPoint =
+            new LabeledPointWithWeight(Vectors.dense(1.0, 2.0, 3.0), 1.0, 2.0);
+    private static final DenseVector coefficient = Vectors.dense(1.0, 1.0, 1.0);
+    private static final DenseVector cumGradient = Vectors.dense(0.0, 0.0, 0.0);
+    private static final double TOLERANCE = 1e-7;
+
+    @Test
+    public void computeLoss() {
+        double loss = BinaryLogisticLoss.INSTANCE.computeLoss(dataPoint, coefficient);
+        assertEquals(0.0049513, loss, TOLERANCE);
+    }
+
+    @Test
+    public void computeGradient() {
+        BinaryLogisticLoss.INSTANCE.computeGradient(dataPoint, coefficient, cumGradient);
+        assertArrayEquals(
+                new double[] {-0.0049452, -0.0098904, -0.0148357}, cumGradient.values, TOLERANCE);
+        BinaryLogisticLoss.INSTANCE.computeGradient(dataPoint, coefficient, cumGradient);
+        assertArrayEquals(
+                new double[] {-0.0098904, -0.0197809, -0.0296714}, cumGradient.values, TOLERANCE);
+    }
+}
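
For readers checking the constants in BinaryLogisticLossTest above: the expected values follow from margin = coefficient · features = 6 and weight = 2, assuming the standard weighted logistic loss (an editorial sketch, not the library's code):

    double margin = 1.0 * 1.0 + 1.0 * 2.0 + 1.0 * 3.0;           // coefficient · features = 6.0
    double yScaled = 2 * 1.0 - 1;                                 // label 1.0 mapped from {0, 1} to {-1, 1}
    double loss = 2.0 * Math.log1p(Math.exp(-margin * yScaled));  // ≈ 0.0049513, as asserted
    double multiplier = 2.0 * -yScaled / (1 + Math.exp(margin * yScaled));
    // multiplier ≈ -0.0049452; gradient = multiplier * features
    //            ≈ {-0.0049452, -0.0098904, -0.0148357}
    // computeGradient accumulates into cumGradient, hence the doubled array
    // after the second call.
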
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/LeastSquareLossTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/LeastSquareLossTest.java
new file mode 100644
index 0000000..ee2d030
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/LeastSquareLossTest.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.lossfunc;
+
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+/** Tests {@link LeastSquareLoss}. */
+public class LeastSquareLossTest {
+    private static final LabeledPointWithWeight dataPoint =
+            new LabeledPointWithWeight(Vectors.dense(1.0, 2.0, 3.0), 1.0, 2.0);
+    private static final DenseVector coefficient = Vectors.dense(1.0, 1.0, 1.0);
+    private static final DenseVector cumGradient = Vectors.dense(0.0, 0.0, 0.0);
+    private static final double TOLERANCE = 1e-7;
+
+    @Test
+    public void computeLoss() {
+        double loss = LeastSquareLoss.INSTANCE.computeLoss(dataPoint, coefficient);
+        assertEquals(25.0, loss, TOLERANCE);
+    }
+
+    @Test
+    public void computeGradient() {
+        LeastSquareLoss.INSTANCE.computeGradient(dataPoint, coefficient, cumGradient);
+        assertArrayEquals(new double[] {10.0, 20.0, 30.0}, cumGradient.values, TOLERANCE);
+        LeastSquareLoss.INSTANCE.computeGradient(dataPoint, coefficient, cumGradient);
+        assertArrayEquals(new double[] {20.0, 40.0, 60.0}, cumGradient.values, TOLERANCE);
+    }
+}
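
The LeastSquareLossTest expectations are easy to verify by hand, assuming loss = weight * 0.5 * (margin - label)^2 (again a sketch, not the library's code):

    double margin = 1.0 * 1.0 + 1.0 * 2.0 + 1.0 * 3.0;  // coefficient · features = 6.0
    double residual = margin - 1.0;                      // label = 1.0, so residual = 5.0
    double loss = 2.0 * 0.5 * residual * residual;       // weight = 2.0, loss = 25.0
    // gradient = weight * residual * features = {10.0, 20.0, 30.0};
    // cumGradient accumulates, so the second call asserts {20.0, 40.0, 60.0}.
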
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/optimizer/RegularizationUtilsTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/optimizer/RegularizationUtilsTest.java
new file mode 100644
index 0000000..83d6883
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/optimizer/RegularizationUtilsTest.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.ml.linalg.DenseVector;
+
+import org.apache.commons.lang3.RandomUtils;
+import org.junit.Test;
+
+import static org.junit.Assert.assertArrayEquals;
+
+/** Tests {@link RegularizationUtils}. */
+public class RegularizationUtilsTest {
+    private static final double learningRate = 0.1;
+    private static final double TOLERANCE = 1e-7;
+    private static final DenseVector coefficient = new DenseVector(new double[] {1.0, -2.0, 0});
+
+    @Test
+    public void testRegularization() {
+        checkRegularization(0, RandomUtils.nextDouble(0, 1), new double[] {1, -2.0, 0});
+        checkRegularization(0.1, 0, new double[] {0.99, -1.98, 0});
+        checkRegularization(0.1, 1, new double[] {0.99, -1.99, 0});
+        checkRegularization(0.1, 0.1, new double[] {0.99, -1.981, 0});
+    }
+
+    private void checkRegularization(double reg, double elasticNet, double[] expectedCoefficient) {
+        DenseVector clonedCoefficient = coefficient.clone();
+        RegularizationUtils.regularize(clonedCoefficient, reg, elasticNet, learningRate);
+        assertArrayEquals(expectedCoefficient, clonedCoefficient.values, TOLERANCE);
+    }
+}
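
The expected arrays in RegularizationUtilsTest are consistent with a per-step proximal update: an L2 shrink followed by an L1 soft-threshold, both scaled by the learning rate. A hedged sketch (hypothetical helper, not the patch's implementation):

    static double proximalStep(double w, double reg, double elasticNet, double learningRate) {
        double shrunk = w * (1 - learningRate * (1 - elasticNet) * reg);  // L2 shrinkage
        double threshold = learningRate * elasticNet * reg;               // L1 soft-threshold
        return Math.signum(shrunk) * Math.max(0, Math.abs(shrunk) - threshold);
    }
    // proximalStep(1.0, 0.1, 0.1, 0.1)  = 0.990
    // proximalStep(-2.0, 0.1, 0.1, 0.1) = -1.981, matching {0.99, -1.981, 0} above.
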
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java
new file mode 100644
index 0000000..58e89c4
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java
@@ -0,0 +1,255 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.regression;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.ml.regression.linearregression.LinearRegressionModel;
+import org.apache.flink.ml.regression.linearregression.LinearRegressionModelData;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.RandomUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+/** Tests {@link LinearRegression} and {@link LinearRegressionModel}. */
+public class LinearRegressionTest {
+
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private StreamExecutionEnvironment env;
+
+    private StreamTableEnvironment tEnv;
+
+    private static final List<Row> trainData =
+            Arrays.asList(
+                    Row.of(Vectors.dense(2, 1), 4.0, 1.0),
+                    Row.of(Vectors.dense(3, 2), 7.0, 1.0),
+                    Row.of(Vectors.dense(4, 3), 10.0, 1.0),
+                    Row.of(Vectors.dense(2, 4), 10.0, 1.0),
+                    Row.of(Vectors.dense(2, 2), 6.0, 1.0),
+                    Row.of(Vectors.dense(4, 3), 10.0, 1.0),
+                    Row.of(Vectors.dense(1, 2), 5.0, 1.0),
+                    Row.of(Vectors.dense(5, 3), 11.0, 1.0));
+
+    private static final double[] expectedCoefficient = new double[] {1.141, 1.829};
+
+    private static final double TOLERANCE = 1e-7;
+
+    private static final double PREDICTION_TOLERANCE = 0.1;
+
+    private static final double COEFFICIENT_TOLERANCE = 0.1;
+
+    private Table trainDataTable;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        Collections.shuffle(trainData);
+        trainDataTable =
+                tEnv.fromDataStream(
+                        env.fromCollection(
+                                trainData,
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE
+                                        },
+                                        new String[] {"features", "label", "weight"})));
+    }
+
+    @SuppressWarnings("unchecked")
+    private void verifyPredictionResult(Table output, String labelCol, String predictionCol)
+            throws Exception {
+        List<Row> predResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row predictionRow : predResult) {
+            double label = (double) predictionRow.getField(labelCol);
+            double prediction = (double) predictionRow.getField(predictionCol);
+            assertTrue(Math.abs(prediction - label) / label < PREDICTION_TOLERANCE);
+        }
+    }
+
+    @Test
+    public void testParam() {
+        LinearRegression linearRegression = new LinearRegression();
+        assertEquals("label", linearRegression.getLabelCol());
+        assertNull(linearRegression.getWeightCol());
+        assertEquals(20, linearRegression.getMaxIter());
+        assertEquals(0, linearRegression.getReg(), TOLERANCE);
+        assertEquals(0, linearRegression.getElasticNet(), TOLERANCE);
+        assertEquals(0.1, linearRegression.getLearningRate(), TOLERANCE);
+        assertEquals(32, linearRegression.getGlobalBatchSize());
+        assertEquals(1e-6, linearRegression.getTol(), TOLERANCE);
+        assertEquals("features", linearRegression.getFeaturesCol());
+        assertEquals("prediction", linearRegression.getPredictionCol());
+
+        linearRegression
+                .setFeaturesCol("test_features")
+                .setLabelCol("test_label")
+                .setWeightCol("test_weight")
+                .setMaxIter(1000)
+                .setTol(0.001)
+                .setLearningRate(0.5)
+                .setGlobalBatchSize(1000)
+                .setReg(0.1)
+                .setElasticNet(0.5)
+                .setPredictionCol("test_predictionCol");
+        assertEquals("test_features", linearRegression.getFeaturesCol());
+        assertEquals("test_label", linearRegression.getLabelCol());
+        assertEquals("test_weight", linearRegression.getWeightCol());
+        assertEquals(1000, linearRegression.getMaxIter());
+        assertEquals(0.001, linearRegression.getTol(), TOLERANCE);
+        assertEquals(0.5, linearRegression.getLearningRate(), TOLERANCE);
+        assertEquals(1000, linearRegression.getGlobalBatchSize());
+        assertEquals(0.1, linearRegression.getReg(), TOLERANCE);
+        assertEquals(0.5, linearRegression.getElasticNet(), TOLERANCE);
+        assertEquals("test_predictionCol", 
linearRegression.getPredictionCol());
+    }
+
+    @Test
+    public void testOutputSchema() {
+        Table tempTable = trainDataTable.as("test_features", "test_label", "test_weight");
+        LinearRegression linearRegression =
+                new LinearRegression()
+                        .setFeaturesCol("test_features")
+                        .setLabelCol("test_label")
+                        .setWeightCol("test_weight")
+                        .setPredictionCol("test_predictionCol");
+        Table output = linearRegression.fit(trainDataTable).transform(tempTable)[0];
+        assertEquals(
+                Arrays.asList("test_features", "test_label", "test_weight", 
"test_predictionCol"),
+                output.getResolvedSchema().getColumnNames());
+    }
+
+    @Test
+    public void testFitAndPredict() throws Exception {
+        LinearRegression linearRegression = new LinearRegression().setWeightCol("weight");
+        Table output = linearRegression.fit(trainDataTable).transform(trainDataTable)[0];
+        verifyPredictionResult(
+                output, linearRegression.getLabelCol(), linearRegression.getPredictionCol());
+    }
+
+    @Test
+    public void testSaveLoadAndPredict() throws Exception {
+        LinearRegression linearRegression = new LinearRegression().setWeightCol("weight");
+        linearRegression =
+                StageTestUtils.saveAndReload(
+                        tEnv, linearRegression, tempFolder.newFolder().getAbsolutePath());
+        LinearRegressionModel model = linearRegression.fit(trainDataTable);
+        model = StageTestUtils.saveAndReload(tEnv, model, tempFolder.newFolder().getAbsolutePath());
+        assertEquals(
+                Collections.singletonList("coefficient"),
+                model.getModelData()[0].getResolvedSchema().getColumnNames());
+        Table output = model.transform(trainDataTable)[0];
+        verifyPredictionResult(
+                output, linearRegression.getLabelCol(), linearRegression.getPredictionCol());
+    }
+
+    @Test
+    public void testGetModelData() throws Exception {
+        LinearRegression linearRegression = new LinearRegression().setWeightCol("weight");
+        LinearRegressionModel model = linearRegression.fit(trainDataTable);
+        List<LinearRegressionModelData> modelData =
+                IteratorUtils.toList(
+                        LinearRegressionModelData.getModelDataStream(model.getModelData()[0])
+                                .executeAndCollect());
+        assertNotNull(modelData);
+        assertEquals(1, modelData.size());
+        assertArrayEquals(
+                expectedCoefficient, modelData.get(0).coefficient.values, COEFFICIENT_TOLERANCE);
+    }
+
+    @Test
+    public void testSetModelData() throws Exception {
+        LinearRegression linearRegression = new LinearRegression().setWeightCol("weight");
+        LinearRegressionModel model = linearRegression.fit(trainDataTable);
+
+        LinearRegressionModel newModel = new LinearRegressionModel();
+        ReadWriteUtils.updateExistingParams(newModel, model.getParamMap());
+        newModel.setModelData(model.getModelData());
+        Table output = newModel.transform(trainDataTable)[0];
+        verifyPredictionResult(
+                output, linearRegression.getLabelCol(), linearRegression.getPredictionCol());
+    }
+
+    @Test
+    public void testMoreSubtaskThanData() throws Exception {
+        env.setParallelism(12);
+        LinearRegression linearRegression =
+                new LinearRegression().setWeightCol("weight").setGlobalBatchSize(128);
+        Table output = linearRegression.fit(trainDataTable).transform(trainDataTable)[0];
+        verifyPredictionResult(
+                output, linearRegression.getLabelCol(), linearRegression.getPredictionCol());
+    }
+
+    @Test
+    public void testRegularization() throws Exception {
+        checkRegularization(0, RandomUtils.nextDouble(0, 1), expectedCoefficient);
+        checkRegularization(0.1, 0, new double[] {1.165, 1.780});
+        checkRegularization(0.1, 1, new double[] {1.143, 1.812});
+        checkRegularization(0.1, 0.5, new double[] {1.154, 1.796});
+    }
+
+    @SuppressWarnings("unchecked")
+    private void checkRegularization(double reg, double elasticNet, double[] expectedCoefficient)
+            throws Exception {
+        LinearRegressionModel model =
+                new LinearRegression()
+                        .setWeightCol("weight")
+                        .setReg(reg)
+                        .setElasticNet(elasticNet)
+                        .fit(trainDataTable);
+        List<LinearRegressionModelData> modelData =
+                IteratorUtils.toList(
+                        LinearRegressionModelData.getModelDataStream(model.getModelData()[0])
+                                .executeAndCollect());
+        final double errorTol = 1e-3;
+        assertArrayEquals(expectedCoefficient, modelData.get(0).coefficient.values, errorTol);
+    }
+}
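
One observation on the fixture above (editorial, not part of the patch): every training row lies exactly on label = features[0] + 2 * features[1], so the expected coefficient {1.141, 1.829} is the SGD estimate after the default 20 iterations approximating the exact solution {1.0, 2.0}, which is why the test compares with COEFFICIENT_TOLERANCE = 0.1 rather than the tight TOLERANCE. A quick check:

    double[][] features = {{2, 1}, {3, 2}, {4, 3}, {2, 4}, {2, 2}, {4, 3}, {1, 2}, {5, 3}};
    double[] labels = {4, 7, 10, 10, 6, 10, 5, 11};
    for (int i = 0; i < features.length; i++) {
        // every row fits exactly at w = {1.0, 2.0}
        assert labels[i] == 1.0 * features[i][0] + 2.0 * features[i][1];
    }
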
