This is an automated email from the ASF dual-hosted git repository.

lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new 4bd9083  [FLINK-27091] Add Transformer and Estimator of LinearSVC
4bd9083 is described below

commit 4bd908312825e3fe4fd7566b15626ad766ebabc5
Author: Zhipeng Zhang <zhangzhipe...@gmail.com>
AuthorDate: Sat May 7 15:53:35 2022 +0800

    [FLINK-27091] Add Transformer and Estimator of LinearSVC
    
    This closes #93.
---
 .../ml/classification/linearsvc/LinearSVC.java     | 123 ++++++++++++
 .../classification/linearsvc/LinearSVCModel.java   | 174 ++++++++++++++++
 .../linearsvc/LinearSVCModelData.java              | 111 +++++++++++
 .../linearsvc/LinearSVCModelParams.java            |  55 +++++
 .../classification/linearsvc/LinearSVCParams.java  |  44 ++++
 .../apache/flink/ml/common/lossfunc/HingeLoss.java |  58 ++++++
 ...isticRegressionTest.java => LinearSVCTest.java} | 222 ++++++++++-----------
 .../ml/classification/LogisticRegressionTest.java  |   8 +-
 .../flink/ml/common/lossfunc/HingeLossTest.java    |  57 ++++++
 .../flink/ml/regression/LinearRegressionTest.java  |   8 +-
 10 files changed, 731 insertions(+), 129 deletions(-)

diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVC.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVC.java
new file mode 100644
index 0000000..4169d48
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVC.java
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linearsvc;
+
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.lossfunc.HingeLoss;
+import org.apache.flink.ml.common.optimizer.Optimizer;
+import org.apache.flink.ml.common.optimizer.SGD;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * An Estimator which implements linear support vector classification.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Support-vector_machine#Linear_SVM.
+ */
+public class LinearSVC implements Estimator<LinearSVC, LinearSVCModel>, LinearSVCParams<LinearSVC> {
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public LinearSVC() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings({"rawTypes", "ConstantConditions"})
+    public LinearSVCModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<LabeledPointWithWeight> trainData =
+                tEnv.toDataStream(inputs[0])
+                        .map(
+                                dataPoint -> {
+                                    double weight =
+                                            getWeightCol() == null
+                                                    ? 1.0
+                                                    : (Double) dataPoint.getField(getWeightCol());
+                                    double label = (Double) dataPoint.getField(getLabelCol());
+                                    Preconditions.checkState(
+                                            Double.compare(0.0, label) == 0
+                                                    || Double.compare(1.0, label) == 0,
+                                            "LinearSVC only supports binary classification. But detected label: %s.",
+                                            label);
+                                    DenseVector features =
+                                            (DenseVector) dataPoint.getField(getFeaturesCol());
+                                    return new LabeledPointWithWeight(features, label, weight);
+                                });
+
+        DataStream<DenseVector> initModelData =
+                DataStreamUtils.reduce(
+                                trainData.map(x -> x.getFeatures().size()),
+                                (ReduceFunction<Integer>)
+                                        (t0, t1) -> {
+                                            Preconditions.checkState(
+                                                    t0.equals(t1),
+                                                    "The training data should all have same dimensions.");
+                                            return t0;
+                                        })
+                        .map(DenseVector::new);
+
+        Optimizer optimizer =
+                new SGD(
+                        getMaxIter(),
+                        getLearningRate(),
+                        getGlobalBatchSize(),
+                        getTol(),
+                        getReg(),
+                        getElasticNet());
+        DataStream<DenseVector> rawModelData =
+                optimizer.optimize(initModelData, trainData, HingeLoss.INSTANCE);
+
+        DataStream<LinearSVCModelData> modelData = rawModelData.map(LinearSVCModelData::new);
+        LinearSVCModel model = new LinearSVCModel().setModelData(tEnv.fromDataStream(modelData));
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static LinearSVC load(StreamTableEnvironment tEnv, String path) throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+}
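
For orientation, here is a hedged usage sketch (not part of the commit) of the new estimator. It mirrors the setup used by LinearSVCTest later in this diff; the two-row training set and the column names "features", "label" and "weight" are purely illustrative.

    import java.util.Arrays;
    import java.util.List;

    import org.apache.flink.api.common.typeinfo.TypeInformation;
    import org.apache.flink.api.common.typeinfo.Types;
    import org.apache.flink.api.java.typeutils.RowTypeInfo;
    import org.apache.flink.ml.classification.linearsvc.LinearSVC;
    import org.apache.flink.ml.classification.linearsvc.LinearSVCModel;
    import org.apache.flink.ml.linalg.Vectors;
    import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
    import org.apache.flink.streaming.api.datastream.DataStream;
    import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
    import org.apache.flink.table.api.Table;
    import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
    import org.apache.flink.types.Row;

    public class LinearSVCExample {
        public static void main(String[] args) throws Exception {
            StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

            // Two linearly separable points; LinearSVC requires labels to be 0.0 or 1.0.
            List<Row> data =
                    Arrays.asList(
                            Row.of(Vectors.dense(1, 2, 3, 4), 0., 1.),
                            Row.of(Vectors.dense(15, 2, 3, 4), 1., 1.));
            DataStream<Row> dataStream =
                    env.fromCollection(
                            data,
                            new RowTypeInfo(
                                    new TypeInformation[] {
                                        DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE
                                    },
                                    new String[] {"features", "label", "weight"}));
            Table trainTable = tEnv.fromDataStream(dataStream);

            LinearSVC linearSVC = new LinearSVC().setWeightCol("weight");
            LinearSVCModel model = linearSVC.fit(trainTable);
            // The output keeps the input columns and appends "prediction" and "rawPrediction".
            Table predictions = model.transform(trainTable)[0];
            predictions.execute().print();
        }
    }
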
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModel.java
new file mode 100644
index 0000000..253bbcc
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModel.java
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linearsvc;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** A Model which classifies data using the model data computed by {@link LinearSVC}. */
+public class LinearSVCModel implements Model<LinearSVCModel>, LinearSVCModelParams<LinearSVCModel> {
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    private Table modelDataTable;
+
+    public LinearSVCModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> inputStream = tEnv.toDataStream(inputs[0]);
+
+        final String broadcastModelKey = "broadcastModelKey";
+        DataStream<LinearSVCModelData> modelDataStream =
+                LinearSVCModelData.getModelDataStream(modelDataTable);
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(),
+                                BasicTypeInfo.DOUBLE_TYPE_INFO,
+                                DenseVectorTypeInfo.INSTANCE),
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldNames(),
+                                getPredictionCol(),
+                                getRawPredictionCol()));
+
+        DataStream<Row> predictionResult =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(inputStream),
+                        Collections.singletonMap(broadcastModelKey, modelDataStream),
+                        inputList -> {
+                            DataStream inputData = inputList.get(0);
+                            return inputData.map(
+                                    new PredictLabelFunction(
+                                            broadcastModelKey, getFeaturesCol(), getThreshold()),
+                                    outputTypeInfo);
+                        });
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    @Override
+    public LinearSVCModel setModelData(Table... inputs) {
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                LinearSVCModelData.getModelDataStream(modelDataTable),
+                path,
+                new LinearSVCModelData.ModelDataEncoder());
+    }
+
+    public static LinearSVCModel load(StreamTableEnvironment tEnv, String path) throws IOException {
+        LinearSVCModel model = ReadWriteUtils.loadStageParam(path);
+        Table modelDataTable =
+                ReadWriteUtils.loadModelData(tEnv, path, new LinearSVCModelData.ModelDataDecoder());
+        return model.setModelData(modelDataTable);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    /** A utility function used for prediction. */
+    private static class PredictLabelFunction extends RichMapFunction<Row, Row> {
+
+        private final String broadcastModelKey;
+
+        private final String featuresCol;
+
+        private final double threshold;
+
+        private DenseVector coefficient;
+
+        public PredictLabelFunction(
+                String broadcastModelKey, String featuresCol, double threshold) {
+            this.broadcastModelKey = broadcastModelKey;
+            this.featuresCol = featuresCol;
+            this.threshold = threshold;
+        }
+
+        @Override
+        public Row map(Row dataPoint) {
+            if (coefficient == null) {
+                LinearSVCModelData modelData =
+                        (LinearSVCModelData)
+                                getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0);
+                coefficient = modelData.coefficient;
+            }
+            DenseVector features = (DenseVector) dataPoint.getField(featuresCol);
+            Row predictionResult = predictOneDataPoint(features, coefficient, threshold);
+            return Row.join(dataPoint, predictionResult);
+        }
+    }
+
+    /**
+     * The main logic that predicts one input data point.
+     *
+     * @param feature The input feature.
+     * @param coefficient The model parameters.
+     * @param threshold The threshold for prediction.
+     * @return The prediction label and the raw predictions.
+     */
+    private static Row predictOneDataPoint(
+            DenseVector feature, DenseVector coefficient, double threshold) {
+        double dotValue = BLAS.dot(feature, coefficient);
+        return Row.of(dotValue >= threshold ? 1.0 : 0.0, Vectors.dense(dotValue, -dotValue));
+    }
+}
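
As a side note, the decision rule in predictOneDataPoint above can be traced by hand. The following hedged sketch (not part of the commit) uses made-up numbers: the raw prediction is the signed margin and its negation, and the label is 1.0 exactly when the margin reaches the threshold.

    import org.apache.flink.ml.linalg.BLAS;
    import org.apache.flink.ml.linalg.DenseVector;
    import org.apache.flink.ml.linalg.Vectors;

    public class DecisionRuleSketch {
        public static void main(String[] args) {
            DenseVector coefficient = Vectors.dense(0.5, -0.25); // illustrative model parameters
            DenseVector features = Vectors.dense(4.0, 2.0);      // illustrative input point
            double threshold = 0.0;                              // the default THRESHOLD value

            double margin = BLAS.dot(features, coefficient);     // 0.5 * 4.0 + (-0.25) * 2.0 = 1.5
            double prediction = margin >= threshold ? 1.0 : 0.0; // 1.0, since 1.5 >= 0.0
            double[] rawPrediction = {margin, -margin};          // {1.5, -1.5}

            System.out.println(prediction + " " + rawPrediction[0] + " " + rawPrediction[1]);
        }
    }
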
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModelData.java
new file mode 100644
index 0000000..96e8a27
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModelData.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linearsvc;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+
+/**
+ * Model data of {@link LinearSVCModel}.
+ *
+ * <p>This class also provides methods to convert model data from Table to DataStream, and classes
+ * to save/load model data.
+ */
+public class LinearSVCModelData {
+
+    public DenseVector coefficient;
+
+    public LinearSVCModelData(DenseVector coefficient) {
+        this.coefficient = coefficient;
+    }
+
+    public LinearSVCModelData() {}
+
+    /**
+     * Converts the table model to a data stream.
+     *
+     * @param modelData The table model data.
+     * @return The data stream model data.
+     */
+    public static DataStream<LinearSVCModelData> getModelDataStream(Table modelData) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment();
+        return tEnv.toDataStream(modelData)
+                .map(x -> new LinearSVCModelData((DenseVector) x.getField(0)));
+    }
+
+    /** Data encoder for {@link LinearSVCModel}. */
+    public static class ModelDataEncoder implements Encoder<LinearSVCModelData> {
+
+        @Override
+        public void encode(LinearSVCModelData modelData, OutputStream outputStream)
+                throws IOException {
+            DenseVectorSerializer.INSTANCE.serialize(
+                    modelData.coefficient, new DataOutputViewStreamWrapper(outputStream));
+        }
+    }
+
+    /** Data decoder for {@link LinearSVCModel}. */
+    public static class ModelDataDecoder extends SimpleStreamFormat<LinearSVCModelData> {
+
+        @Override
+        public Reader<LinearSVCModelData> createReader(
+                Configuration configuration, FSDataInputStream inputStream) {
+            return new Reader<LinearSVCModelData>() {
+
+                @Override
+                public LinearSVCModelData read() throws IOException {
+                    try {
+                        DenseVector coefficient =
+                                DenseVectorSerializer.INSTANCE.deserialize(
+                                        new DataInputViewStreamWrapper(inputStream));
+                        return new LinearSVCModelData(coefficient);
+                    } catch (EOFException e) {
+                        return null;
+                    }
+                }
+
+                @Override
+                public void close() throws IOException {
+                    inputStream.close();
+                }
+            };
+        }
+
+        @Override
+        public TypeInformation<LinearSVCModelData> getProducedType() {
+            return TypeInformation.of(LinearSVCModelData.class);
+        }
+    }
+}
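
A hedged sketch (not part of the commit) of how getModelDataStream above is typically consumed, following testGetModelData later in this diff; the fitted LinearSVCModel passed in is assumed to come from a fit() call as in the earlier sketch.

    import java.util.List;

    import org.apache.commons.collections.IteratorUtils;
    import org.apache.flink.ml.classification.linearsvc.LinearSVCModel;
    import org.apache.flink.ml.classification.linearsvc.LinearSVCModelData;
    import org.apache.flink.ml.linalg.DenseVector;

    public class ModelDataInspection {
        /** Collects the single coefficient vector out of a fitted model's model-data table. */
        @SuppressWarnings("unchecked")
        static DenseVector collectCoefficient(LinearSVCModel model) throws Exception {
            List<LinearSVCModelData> modelData =
                    IteratorUtils.toList(
                            LinearSVCModelData.getModelDataStream(model.getModelData()[0])
                                    .executeAndCollect());
            return modelData.get(0).coefficient;
        }
    }
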
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModelParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModelParams.java
new file mode 100644
index 0000000..9e02233
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCModelParams.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linearsvc;
+
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.common.param.HasRawPredictionCol;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params for {@link LinearSVCModel}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface LinearSVCModelParams<T>
+        extends HasFeaturesCol<T>, HasPredictionCol<T>, HasRawPredictionCol<T> {
+    /**
+     * Param for threshold in linear support vector classifier. It applies to the rawPrediction and
+     * can be any real number, where Inf makes all predictions 0.0 and -Inf makes all predictions
+     * 1.0.
+     */
+    Param<Double> THRESHOLD =
+            new DoubleParam(
+                    "threshold",
+                    "Threshold in binary classification prediction applied to rawPrediction.",
+                    0.0,
+                    ParamValidators.notNull());
+
+    default Double getThreshold() {
+        return get(THRESHOLD);
+    }
+
+    default T setThreshold(Double value) {
+        set(THRESHOLD, value);
+        return (T) this;
+    }
+}
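
The Inf/-Inf remark in the THRESHOLD javadoc follows directly from the decision rule in LinearSVCModel (the label is 1.0 iff the raw margin is >= threshold). A tiny hedged illustration, not part of the commit:

    public class ThresholdSemantics {
        public static void main(String[] args) {
            double margin = 1.5; // some finite raw prediction value

            // Any margin clears a threshold of -Infinity, so every row is predicted 1.0 ...
            System.out.println(margin >= Double.NEGATIVE_INFINITY ? 1.0 : 0.0); // 1.0
            // ... and no finite margin clears +Infinity, so every row is predicted 0.0.
            System.out.println(margin >= Double.POSITIVE_INFINITY ? 1.0 : 0.0); // 0.0
        }
    }

testThreshold in LinearSVCTest below checks the same behaviour using -Double.MAX_VALUE and Double.MAX_VALUE.
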
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCParams.java
new file mode 100644
index 0000000..7754e89
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linearsvc/LinearSVCParams.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linearsvc;
+
+import org.apache.flink.ml.common.param.HasElasticNet;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasLearningRate;
+import org.apache.flink.ml.common.param.HasMaxIter;
+import org.apache.flink.ml.common.param.HasReg;
+import org.apache.flink.ml.common.param.HasTol;
+import org.apache.flink.ml.common.param.HasWeightCol;
+
+/**
+ * Params for {@link LinearSVC}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface LinearSVCParams<T>
+        extends HasLabelCol<T>,
+                HasWeightCol<T>,
+                HasMaxIter<T>,
+                HasReg<T>,
+                HasElasticNet<T>,
+                HasLearningRate<T>,
+                HasGlobalBatchSize<T>,
+                HasTol<T>,
+                LinearSVCModelParams<T> {}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/HingeLoss.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/HingeLoss.java
new file mode 100644
index 0000000..eb0f3bf
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/HingeLoss.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.lossfunc;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.classification.linearsvc.LinearSVC;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+
+/**
+ * The loss function for hinge loss. See {@link LinearSVC} for example.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Hinge_loss.
+ */
+@Internal
+public class HingeLoss implements LossFunc {
+    public static final HingeLoss INSTANCE = new HingeLoss();
+
+    private HingeLoss() {}
+
+    @Override
+    public double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient) {
+        double dot = BLAS.dot(dataPoint.getFeatures(), coefficient);
+        double labelScaled = 2 * dataPoint.getLabel() - 1;
+        return dataPoint.getWeight() * Math.max(0, 1 - labelScaled * dot);
+    }
+
+    @Override
+    public void computeGradient(
+            LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient) {
+        double dot = BLAS.dot(dataPoint.getFeatures(), coefficient);
+        double labelScaled = 2 * dataPoint.getLabel() - 1;
+        if (1 - labelScaled * dot > 0) {
+            BLAS.axpy(
+                    -labelScaled * dataPoint.getWeight(),
+                    dataPoint.getFeatures(),
+                    cumGradient,
+                    dataPoint.getFeatures().size());
+        }
+    }
+}
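
To make the loss and gradient above concrete, here is a hedged worked example (not part of the commit) that traces HingeLoss by hand with the same numbers used in HingeLossTest later in this diff: features (1, -1, -1), label 1, weight 2, coefficient (1, 1, 1).

    public class HingeLossByHand {
        public static void main(String[] args) {
            double[] features = {1.0, -1.0, -1.0};
            double[] coefficient = {1.0, 1.0, 1.0};
            double label = 1.0;
            double weight = 2.0;

            double dot = 0;
            for (int i = 0; i < features.length; i++) {
                dot += features[i] * coefficient[i]; // 1 - 1 - 1 = -1
            }
            double labelScaled = 2 * label - 1; // labels {0, 1} are mapped to {-1, +1}
            double loss = weight * Math.max(0, 1 - labelScaled * dot); // 2 * max(0, 2) = 4.0

            // The point violates the margin (1 - labelScaled * dot > 0), so its gradient
            // contribution is -labelScaled * weight * features = (-2, 2, 2).
            double[] gradient = new double[features.length];
            for (int i = 0; i < features.length; i++) {
                gradient[i] = -labelScaled * weight * features[i];
            }
            System.out.println(loss + " " + java.util.Arrays.toString(gradient));
        }
    }

These are exactly the values asserted in HingeLossTest below (loss 4.0, cumulative gradient {-2.0, 2.0, 2.0}).
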
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LinearSVCTest.java
similarity index 52%
copy from flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
copy to flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LinearSVCTest.java
index 8d06613..156244e 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LinearSVCTest.java
@@ -23,9 +23,9 @@ import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.common.typeinfo.Types;
 import org.apache.flink.api.java.typeutils.RowTypeInfo;
 import org.apache.flink.configuration.Configuration;
-import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
-import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel;
-import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelData;
+import org.apache.flink.ml.classification.linearsvc.LinearSVC;
+import org.apache.flink.ml.classification.linearsvc.LinearSVCModel;
+import org.apache.flink.ml.classification.linearsvc.LinearSVCModelData;
 import org.apache.flink.ml.linalg.DenseVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
@@ -39,7 +39,6 @@ import org.apache.flink.types.Row;
 
 import org.apache.commons.collections.IteratorUtils;
 import org.apache.commons.lang3.RandomUtils;
-import org.apache.commons.lang3.exception.ExceptionUtils;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
@@ -53,10 +52,9 @@ import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
 
-/** Tests {@link LogisticRegression} and {@link LogisticRegressionModel}. */
-public class LogisticRegressionTest {
+/** Tests {@link LinearSVC} and {@link LinearSVCModel}. */
+public class LinearSVCTest {
 
     @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
 
@@ -64,7 +62,7 @@ public class LogisticRegressionTest {
 
     private StreamTableEnvironment tEnv;
 
-    private static final List<Row> binomialTrainData =
+    private static final List<Row> trainData =
             Arrays.asList(
                     Row.of(Vectors.dense(1, 2, 3, 4), 0., 1.),
                     Row.of(Vectors.dense(2, 2, 3, 4), 0., 2.),
@@ -77,27 +75,12 @@ public class LogisticRegressionTest {
                     Row.of(Vectors.dense(14, 2, 3, 4), 1., 4.),
                     Row.of(Vectors.dense(15, 2, 3, 4), 1., 5.));
 
-    private static final List<Row> multinomialTrainData =
-            Arrays.asList(
-                    Row.of(Vectors.dense(1, 2, 3, 4), 0., 1.),
-                    Row.of(Vectors.dense(2, 2, 3, 4), 0., 2.),
-                    Row.of(Vectors.dense(3, 2, 3, 4), 2., 3.),
-                    Row.of(Vectors.dense(4, 2, 3, 4), 2., 4.),
-                    Row.of(Vectors.dense(5, 2, 3, 4), 2., 5.),
-                    Row.of(Vectors.dense(11, 2, 3, 4), 1., 1.),
-                    Row.of(Vectors.dense(12, 2, 3, 4), 1., 2.),
-                    Row.of(Vectors.dense(13, 2, 3, 4), 1., 3.),
-                    Row.of(Vectors.dense(14, 2, 3, 4), 1., 4.),
-                    Row.of(Vectors.dense(15, 2, 3, 4), 1., 5.));
-
     private static final double[] expectedCoefficient =
-            new double[] {0.525, -0.283, -0.425, -0.567};
+            new double[] {0.470, -0.273, -0.410, -0.546};
 
     private static final double TOLERANCE = 1e-7;
 
-    private Table binomialDataTable;
-
-    private Table multinomialDataTable;
+    private Table trainDataTable;
 
     @Before
     public void before() {
@@ -108,20 +91,11 @@ public class LogisticRegressionTest {
         env.enableCheckpointing(100);
         env.setRestartStrategy(RestartStrategies.noRestart());
         tEnv = StreamTableEnvironment.create(env);
-        Collections.shuffle(binomialTrainData);
-        binomialDataTable =
-                tEnv.fromDataStream(
-                        env.fromCollection(
-                                binomialTrainData,
-                                new RowTypeInfo(
-                                        new TypeInformation[] {
-                                            DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE
-                                        },
-                                        new String[] {"features", "label", "weight"})));
-        multinomialDataTable =
+        Collections.shuffle(trainData);
+        trainDataTable =
                 tEnv.fromDataStream(
                         env.fromCollection(
-                                multinomialTrainData,
+                                trainData,
                                 new RowTypeInfo(
                                         new TypeInformation[] {
                                             DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE
@@ -140,31 +114,31 @@ public class LogisticRegressionTest {
             DenseVector rawPrediction = (DenseVector) predictionRow.getField(rawPredictionCol);
             if (feature.get(0) <= 5) {
                 assertEquals(0, prediction, TOLERANCE);
-                assertTrue(rawPrediction.get(0) > 0.5);
+                assertTrue(rawPrediction.get(0) < 0);
             } else {
                 assertEquals(1, prediction, TOLERANCE);
-                assertTrue(rawPrediction.get(0) < 0.5);
+                assertTrue(rawPrediction.get(0) > 0);
             }
         }
     }
 
     @Test
     public void testParam() {
-        LogisticRegression logisticRegression = new LogisticRegression();
-        assertEquals("label", logisticRegression.getLabelCol());
-        assertNull(logisticRegression.getWeightCol());
-        assertEquals(20, logisticRegression.getMaxIter());
-        assertEquals(0, logisticRegression.getReg(), TOLERANCE);
-        assertEquals(0, logisticRegression.getElasticNet(), TOLERANCE);
-        assertEquals(0.1, logisticRegression.getLearningRate(), TOLERANCE);
-        assertEquals(32, logisticRegression.getGlobalBatchSize());
-        assertEquals(1e-6, logisticRegression.getTol(), TOLERANCE);
-        assertEquals("auto", logisticRegression.getMultiClass());
-        assertEquals("features", logisticRegression.getFeaturesCol());
-        assertEquals("prediction", logisticRegression.getPredictionCol());
-        assertEquals("rawPrediction", logisticRegression.getRawPredictionCol());
+        LinearSVC linearSVC = new LinearSVC();
+        assertEquals("features", linearSVC.getFeaturesCol());
+        assertEquals("label", linearSVC.getLabelCol());
+        assertNull(linearSVC.getWeightCol());
+        assertEquals(20, linearSVC.getMaxIter());
+        assertEquals(1e-6, linearSVC.getTol(), TOLERANCE);
+        assertEquals(0.1, linearSVC.getLearningRate(), TOLERANCE);
+        assertEquals(32, linearSVC.getGlobalBatchSize());
+        assertEquals(0, linearSVC.getReg(), TOLERANCE);
+        assertEquals(0, linearSVC.getElasticNet(), TOLERANCE);
+        assertEquals(0.0, linearSVC.getThreshold(), TOLERANCE);
+        assertEquals("prediction", linearSVC.getPredictionCol());
+        assertEquals("rawPrediction", linearSVC.getRawPredictionCol());
 
-        logisticRegression
+        linearSVC
                 .setFeaturesCol("test_features")
                 .setLabelCol("test_label")
                 .setWeightCol("test_weight")
@@ -174,34 +148,34 @@ public class LogisticRegressionTest {
                 .setGlobalBatchSize(1000)
                 .setReg(0.1)
                 .setElasticNet(0.5)
-                .setMultiClass("binomial")
+                .setThreshold(0.5)
                 .setPredictionCol("test_predictionCol")
                 .setRawPredictionCol("test_rawPredictionCol");
-        assertEquals("test_features", logisticRegression.getFeaturesCol());
-        assertEquals("test_label", logisticRegression.getLabelCol());
-        assertEquals("test_weight", logisticRegression.getWeightCol());
-        assertEquals(1000, logisticRegression.getMaxIter());
-        assertEquals(0.001, logisticRegression.getTol(), TOLERANCE);
-        assertEquals(0.5, logisticRegression.getLearningRate(), TOLERANCE);
-        assertEquals(1000, logisticRegression.getGlobalBatchSize());
-        assertEquals(0.1, logisticRegression.getReg(), TOLERANCE);
-        assertEquals(0.5, logisticRegression.getElasticNet(), TOLERANCE);
-        assertEquals("binomial", logisticRegression.getMultiClass());
-        assertEquals("test_predictionCol", logisticRegression.getPredictionCol());
-        assertEquals("test_rawPredictionCol", logisticRegression.getRawPredictionCol());
+        assertEquals("test_features", linearSVC.getFeaturesCol());
+        assertEquals("test_label", linearSVC.getLabelCol());
+        assertEquals("test_weight", linearSVC.getWeightCol());
+        assertEquals(1000, linearSVC.getMaxIter());
+        assertEquals(0.001, linearSVC.getTol(), TOLERANCE);
+        assertEquals(0.5, linearSVC.getLearningRate(), TOLERANCE);
+        assertEquals(1000, linearSVC.getGlobalBatchSize());
+        assertEquals(0.1, linearSVC.getReg(), TOLERANCE);
+        assertEquals(0.5, linearSVC.getElasticNet(), TOLERANCE);
+        assertEquals(0.5, linearSVC.getThreshold(), TOLERANCE);
+        assertEquals("test_predictionCol", linearSVC.getPredictionCol());
+        assertEquals("test_rawPredictionCol", linearSVC.getRawPredictionCol());
     }
 
     @Test
     public void testOutputSchema() {
-        Table tempTable = binomialDataTable.as("test_features", "test_label", "test_weight");
-        LogisticRegression logisticRegression =
-                new LogisticRegression()
+        Table tempTable = trainDataTable.as("test_features", "test_label", "test_weight");
+        LinearSVC linearSVC =
+                new LinearSVC()
                         .setFeaturesCol("test_features")
                         .setLabelCol("test_label")
                         .setWeightCol("test_weight")
                         .setPredictionCol("test_predictionCol")
                         .setRawPredictionCol("test_rawPredictionCol");
-        Table output = logisticRegression.fit(binomialDataTable).transform(tempTable)[0];
+        Table output = linearSVC.fit(trainDataTable).transform(tempTable)[0];
         assertEquals(
                 Arrays.asList(
                         "test_features",
@@ -214,42 +188,42 @@ public class LogisticRegressionTest {
 
     @Test
     public void testFitAndPredict() throws Exception {
-        LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight");
-        Table output = logisticRegression.fit(binomialDataTable).transform(binomialDataTable)[0];
+        LinearSVC linearSVC = new LinearSVC().setWeightCol("weight");
+        Table output = linearSVC.fit(trainDataTable).transform(trainDataTable)[0];
         verifyPredictionResult(
                 output,
-                logisticRegression.getFeaturesCol(),
-                logisticRegression.getPredictionCol(),
-                logisticRegression.getRawPredictionCol());
+                linearSVC.getFeaturesCol(),
+                linearSVC.getPredictionCol(),
+                linearSVC.getRawPredictionCol());
     }
 
     @Test
     public void testSaveLoadAndPredict() throws Exception {
-        LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight");
-        logisticRegression =
+        LinearSVC linearSVC = new LinearSVC().setWeightCol("weight");
+        linearSVC =
                 StageTestUtils.saveAndReload(
-                        tEnv, logisticRegression, tempFolder.newFolder().getAbsolutePath());
-        LogisticRegressionModel model = logisticRegression.fit(binomialDataTable);
+                        tEnv, linearSVC, tempFolder.newFolder().getAbsolutePath());
+        LinearSVCModel model = linearSVC.fit(trainDataTable);
         model = StageTestUtils.saveAndReload(tEnv, model, tempFolder.newFolder().getAbsolutePath());
         assertEquals(
                 Collections.singletonList("coefficient"),
                 model.getModelData()[0].getResolvedSchema().getColumnNames());
-        Table output = model.transform(binomialDataTable)[0];
+        Table output = model.transform(trainDataTable)[0];
         verifyPredictionResult(
                 output,
-                logisticRegression.getFeaturesCol(),
-                logisticRegression.getPredictionCol(),
-                logisticRegression.getRawPredictionCol());
+                linearSVC.getFeaturesCol(),
+                linearSVC.getPredictionCol(),
+                linearSVC.getRawPredictionCol());
     }
 
     @Test
     @SuppressWarnings("unchecked")
     public void testGetModelData() throws Exception {
-        LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight");
-        LogisticRegressionModel model = logisticRegression.fit(binomialDataTable);
-        List<LogisticRegressionModelData> modelData =
+        LinearSVC linearSVC = new LinearSVC().setWeightCol("weight");
+        LinearSVCModel model = linearSVC.fit(trainDataTable);
+        List<LinearSVCModelData> modelData =
                 IteratorUtils.toList(
-                        LogisticRegressionModelData.getModelDataStream(model.getModelData()[0])
+                        LinearSVCModelData.getModelDataStream(model.getModelData()[0])
                                 .executeAndCollect());
         assertEquals(1, modelData.size());
         assertArrayEquals(expectedCoefficient, modelData.get(0).coefficient.values, 0.1);
@@ -257,68 +231,74 @@ public class LogisticRegressionTest {
 
     @Test
     public void testSetModelData() throws Exception {
-        LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight");
-        LogisticRegressionModel model = logisticRegression.fit(binomialDataTable);
+        LinearSVC linearSVC = new LinearSVC().setWeightCol("weight");
+        LinearSVCModel model = linearSVC.fit(trainDataTable);
 
-        LogisticRegressionModel newModel = new LogisticRegressionModel();
+        LinearSVCModel newModel = new LinearSVCModel();
         ReadWriteUtils.updateExistingParams(newModel, model.getParamMap());
         newModel.setModelData(model.getModelData());
-        Table output = newModel.transform(binomialDataTable)[0];
+        Table output = newModel.transform(trainDataTable)[0];
         verifyPredictionResult(
                 output,
-                logisticRegression.getFeaturesCol(),
-                logisticRegression.getPredictionCol(),
-                logisticRegression.getRawPredictionCol());
-    }
-
-    @Test
-    public void testMultinomialFit() {
-        try {
-            new LogisticRegression().fit(multinomialDataTable);
-            env.execute();
-            fail();
-        } catch (Throwable e) {
-            assertEquals(
-                    "Multinomial classification is not supported yet. Supported options: [auto, binomial].",
-                    ExceptionUtils.getRootCause(e).getMessage());
-        }
+                linearSVC.getFeaturesCol(),
+                linearSVC.getPredictionCol(),
+                linearSVC.getRawPredictionCol());
     }
 
     @Test
     public void testMoreSubtaskThanData() throws Exception {
         env.setParallelism(12);
-        LogisticRegression logisticRegression =
-                new LogisticRegression().setWeightCol("weight").setGlobalBatchSize(128);
-        Table output = logisticRegression.fit(binomialDataTable).transform(binomialDataTable)[0];
+        LinearSVC linearSVC = new LinearSVC().setWeightCol("weight").setGlobalBatchSize(128);
+        Table output = linearSVC.fit(trainDataTable).transform(trainDataTable)[0];
         verifyPredictionResult(
                 output,
-                logisticRegression.getFeaturesCol(),
-                logisticRegression.getPredictionCol(),
-                logisticRegression.getRawPredictionCol());
+                linearSVC.getFeaturesCol(),
+                linearSVC.getPredictionCol(),
+                linearSVC.getRawPredictionCol());
     }
 
     @Test
     public void testRegularization() throws Exception {
         checkRegularization(0, RandomUtils.nextDouble(0, 1), expectedCoefficient);
-        checkRegularization(0.1, 0, new double[] {0.484, -0.258, -0.388, -0.517});
-        checkRegularization(0.1, 1, new double[] {0.417, -0.145, -0.312, -0.480});
-        checkRegularization(0.1, 0.5, new double[] {0.451, -0.203, -0.351, -0.498});
+        checkRegularization(0.1, 0, new double[] {0.437, -0.262, -0.393, -0.524});
+        checkRegularization(0.1, 1, new double[] {0.426, -0.197, -0.329, -0.463});
+        checkRegularization(0.1, 0.5, new double[] {0.419, -0.238, -0.372, -0.505});
+    }
+
+    @Test
+    public void testThreshold() throws Exception {
+        checkThreshold(-Double.MAX_VALUE, 1);
+        checkThreshold(Double.MAX_VALUE, 0);
     }
 
     @SuppressWarnings("unchecked")
     private void checkRegularization(double reg, double elasticNet, double[] expectedCoefficient)
             throws Exception {
-        LogisticRegressionModel model =
-                new LogisticRegression()
+        LinearSVCModel model =
+                new LinearSVC()
                         .setWeightCol("weight")
                         .setReg(reg)
                         .setElasticNet(elasticNet)
-                        .fit(binomialDataTable);
-        List<LogisticRegressionModelData> modelData =
+                        .fit(trainDataTable);
+        List<LinearSVCModelData> modelData =
                 IteratorUtils.toList(
-                        LogisticRegressionModelData.getModelDataStream(model.getModelData()[0])
+                        LinearSVCModelData.getModelDataStream(model.getModelData()[0])
                                 .executeAndCollect());
         final double errorTol = 1e-3;
         assertArrayEquals(expectedCoefficient, modelData.get(0).coefficient.values, errorTol);
     }
+
+    @SuppressWarnings("unchecked")
+    private void checkThreshold(double threshold, double target) throws Exception {
+        LinearSVC linearSVC = new LinearSVC().setWeightCol("weight");
+
+        Table predictions =
+                linearSVC.setThreshold(threshold).fit(trainDataTable).transform(trainDataTable)[0];
+
+        List<Row> predResult =
+                IteratorUtils.toList(tEnv.toDataStream(predictions).executeAndCollect());
+        for (Row r : predResult) {
+            assertEquals(target, r.getField(linearSVC.getPredictionCol()));
+        }
+    }
 }
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
index 8d06613..4bdab6b 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java
@@ -151,16 +151,16 @@ public class LogisticRegressionTest {
     @Test
     public void testParam() {
         LogisticRegression logisticRegression = new LogisticRegression();
+        assertEquals("features", logisticRegression.getFeaturesCol());
         assertEquals("label", logisticRegression.getLabelCol());
         assertNull(logisticRegression.getWeightCol());
         assertEquals(20, logisticRegression.getMaxIter());
-        assertEquals(0, logisticRegression.getReg(), TOLERANCE);
-        assertEquals(0, logisticRegression.getElasticNet(), TOLERANCE);
+        assertEquals(1e-6, logisticRegression.getTol(), TOLERANCE);
         assertEquals(0.1, logisticRegression.getLearningRate(), TOLERANCE);
         assertEquals(32, logisticRegression.getGlobalBatchSize());
-        assertEquals(1e-6, logisticRegression.getTol(), TOLERANCE);
+        assertEquals(0, logisticRegression.getReg(), TOLERANCE);
+        assertEquals(0, logisticRegression.getElasticNet(), TOLERANCE);
         assertEquals("auto", logisticRegression.getMultiClass());
-        assertEquals("features", logisticRegression.getFeaturesCol());
         assertEquals("prediction", logisticRegression.getPredictionCol());
         assertEquals("rawPrediction", 
logisticRegression.getRawPredictionCol());
 
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/HingeLossTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/HingeLossTest.java
new file mode 100644
index 0000000..1cd165e
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/lossfunc/HingeLossTest.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.lossfunc;
+
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+/** Tests {@link HingeLoss}. */
+public class HingeLossTest {
+    private static final LabeledPointWithWeight dataPoint1 =
+            new LabeledPointWithWeight(Vectors.dense(1.0, -1.0, -1.0), 1.0, 2.0);
+    private static final LabeledPointWithWeight dataPoint2 =
+            new LabeledPointWithWeight(Vectors.dense(1.0, -1.0, 1.0), 1.0, 2.0);
+    private static final DenseVector coefficient = Vectors.dense(1.0, 1.0, 1.0);
+    private static final DenseVector cumGradient = Vectors.dense(0.0, 0.0, 0.0);
+    private static final double TOLERANCE = 1e-7;
+
+    @Test
+    public void computeLoss() {
+        double loss = HingeLoss.INSTANCE.computeLoss(dataPoint1, coefficient);
+        assertEquals(4.0, loss, TOLERANCE);
+
+        loss = HingeLoss.INSTANCE.computeLoss(dataPoint2, coefficient);
+        assertEquals(0.0, loss, TOLERANCE);
+    }
+
+    @Test
+    public void computeGradient() {
+        HingeLoss.INSTANCE.computeGradient(dataPoint1, coefficient, cumGradient);
+        assertArrayEquals(new double[] {-2.0, 2.0, 2.0}, cumGradient.values, TOLERANCE);
+
+        HingeLoss.INSTANCE.computeGradient(dataPoint2, coefficient, cumGradient);
+        assertArrayEquals(new double[] {-2.0, 2.0, 2.0}, cumGradient.values, TOLERANCE);
+    }
+}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java
index 58e89c4..3ea99d3 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java
@@ -118,15 +118,15 @@ public class LinearRegressionTest {
     @Test
     public void testParam() {
         LinearRegression linearRegression = new LinearRegression();
+        assertEquals("features", linearRegression.getFeaturesCol());
         assertEquals("label", linearRegression.getLabelCol());
         assertNull(linearRegression.getWeightCol());
         assertEquals(20, linearRegression.getMaxIter());
-        assertEquals(0, linearRegression.getReg(), TOLERANCE);
-        assertEquals(0, linearRegression.getElasticNet(), TOLERANCE);
+        assertEquals(1e-6, linearRegression.getTol(), TOLERANCE);
         assertEquals(0.1, linearRegression.getLearningRate(), TOLERANCE);
         assertEquals(32, linearRegression.getGlobalBatchSize());
-        assertEquals(1e-6, linearRegression.getTol(), TOLERANCE);
-        assertEquals("features", linearRegression.getFeaturesCol());
+        assertEquals(0, linearRegression.getReg(), TOLERANCE);
+        assertEquals(0, linearRegression.getElasticNet(), TOLERANCE);
         assertEquals("prediction", linearRegression.getPredictionCol());
 
         linearRegression
