[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression

2021-12-17 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r771431959



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java
##
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.feature;
+
+import org.apache.flink.ml.linalg.DenseVector;
+
+/** Utility class to represent a data point that contains features, label and weight. */
+public class LabeledPointWithWeight {
+
+private DenseVector features;
+
+private double label;
+
+private double weight;
+
+public LabeledPointWithWeight(DenseVector features, double label, double weight) {
+this.features = features;
+this.label = label;
+this.weight = weight;
+}
+
+public LabeledPointWithWeight() {}
+
+public DenseVector getFeatures() {

Review comment:
   Thanks for the suggestion. I have updated the class by removing all
getters/setters and making these fields public.
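
   A minimal sketch of the resulting shape, assuming the constructors from the quoted diff are kept. To stay self-contained, `double[]` stands in for `DenseVector`, and the wrapper class `LabeledPointSketch` is hypothetical; this is not the committed code.

```java
/** Sketch only: double[] stands in for Flink ML's DenseVector. */
class LabeledPointSketch {

    /** Public fields replace the former getters/setters. */
    public static class LabeledPointWithWeight {
        public double[] features;
        public double label;
        public double weight;

        /** Public no-argument constructor, as in the quoted diff. */
        public LabeledPointWithWeight() {}

        public LabeledPointWithWeight(double[] features, double label, double weight) {
            this.features = features;
            this.label = label;
            this.weight = weight;
        }
    }

    public static void main(String[] args) {
        LabeledPointWithWeight point =
                new LabeledPointWithWeight(new double[] {1.0, 2.0}, 1.0, 0.5);
        // Fields are read and written directly now that the accessors are gone.
        point.weight = 2.0;
        System.out.println(point.label + " " + point.weight);
    }
}
```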




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression

2021-12-17 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r771442195



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java
##
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+
+/**
+ * Model data of {@link LogisticRegressionModel}.
+ *
+ * This class also provides methods to convert model data from Table to Datastream, and classes
+ * to save/load model data.
+ */
+public class LogisticRegressionModelData {
+
+public final DenseVector coefficient;
+
+public LogisticRegressionModelData(DenseVector coefficient) {
+this.coefficient = coefficient;
+}
+
+/**
+ * Converts the table model to a data stream.
+ *
+ * @param modelData The table model data.
+ * @return The data stream model data.
+ */
+public static DataStream<LogisticRegressionModelData> getModelDataStream(Table modelData) {

Review comment:
   An example could be: `TerminateOnMaxIter` accepts a generic type as input and
outputs an Integer.
   
   As far as I can see, the generic input is going to be serialized with Kryo. (Correct me if I am wrong.)








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression

2021-12-17 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r771442195



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java
##
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+
+/**
+ * Model data of {@link LogisticRegressionModel}.
+ *
+ * This class also provides methods to convert model data from Table to Datastream, and classes
+ * to save/load model data.
+ */
+public class LogisticRegressionModelData {
+
+public final DenseVector coefficient;
+
+public LogisticRegressionModelData(DenseVector coefficient) {
+this.coefficient = coefficient;
+}
+
+/**
+ * Converts the table model to a data stream.
+ *
+ * @param modelData The table model data.
+ * @return The data stream model data.
+ */
+public static DataStream<LogisticRegressionModelData> getModelDataStream(Table modelData) {

Review comment:
   `TerminateOnMaxIter` accepts a generic type as input and outputs an Integer.
As far as I can see, the generic input is going to be serialized with Kryo.








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression

2021-12-17 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r771431959



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java
##
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.feature;
+
+import org.apache.flink.ml.linalg.DenseVector;
+
+/** Utility class to represent a data point that contains features, label and weight. */
+public class LabeledPointWithWeight {
+
+private DenseVector features;
+
+private double label;
+
+private double weight;
+
+public LabeledPointWithWeight(DenseVector features, double label, double weight) {
+this.features = features;
+this.label = label;
+this.weight = weight;
+}
+
+public LabeledPointWithWeight() {}
+
+public DenseVector getFeatures() {

Review comment:
   These getter/setter methods are expected to be called when doing 
(de)serialization.








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression

2021-12-17 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r771240813



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java
##
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+
+/**
+ * Model data of {@link LogisticRegressionModel}.
+ *
+ * This class also provides methods to convert model data from Table to Datastream, and classes
+ * to save/load model data.
+ */
+public class LogisticRegressionModelData {
+
+public final DenseVector coefficient;
+
+public LogisticRegressionModelData(DenseVector coefficient) {
+this.coefficient = coefficient;
+}
+
+/**
+ * Converts the table model to a data stream.
+ *
+ * @param modelData The table model data.
+ * @return The data stream model data.
+ */
+public static DataStream<LogisticRegressionModelData> getModelDataStream(Table modelData) {

Review comment:
   The model data classes of KMeans, LR, and NB have been changed to POJOs in the
latest commit.
   
   > BTW, in order to make sure that we don't accidentally use kryo serializer, how about we set env.getConfig().disableGenericTypes() in every algorithm's test?
   
   For small objects, I think we should still allow Kryo?








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression

2021-12-17 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r771240813



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java
##
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+
+/**
+ * Model data of {@link LogisticRegressionModel}.
+ *
+ * This class also provides methods to convert model data from Table to Datastream, and classes
+ * to save/load model data.
+ */
+public class LogisticRegressionModelData {
+
+public final DenseVector coefficient;
+
+public LogisticRegressionModelData(DenseVector coefficient) {
+this.coefficient = coefficient;
+}
+
+/**
+ * Converts the table model to a data stream.
+ *
+ * @param modelData The table model data.
+ * @return The data stream model data.
+ */
+public static DataStream<LogisticRegressionModelData> getModelDataStream(Table modelData) {

Review comment:
   The model data classes of KMeans, LR, and NB have been changed to POJOs in the
latest commit.
   
   > BTW, in order to make sure that we don't accidentally use kryo serializer, how about we set env.getConfig().disableGenericTypes() in every algorithm's test?
   
   Do you mean that we should add this in the unit tests?








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression

2021-12-17 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r771192152



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java
##
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** A Model which classifies data using the model data computed by {@link LogisticRegression}. */
+public class LogisticRegressionModel
+implements Model<LogisticRegressionModel>,
+LogisticRegressionModelParams<LogisticRegressionModel> {
+
+private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+private Table modelDataTable;
+
+public LogisticRegressionModel() {
+ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+}
+
+@Override
+public Map<Param<?>, Object> getParamMap() {
+return paramMap;
+}
+
+@Override
+public void save(String path) throws IOException {
+ReadWriteUtils.saveMetadata(this, path);
+ReadWriteUtils.saveModelData(
+LogisticRegressionModelData.getModelDataStream(modelDataTable),
+path,
+new LogisticRegressionModelData.ModelDataEncoder());
+}
+
+public static LogisticRegressionModel load(StreamExecutionEnvironment env, String path)
+throws IOException {
+StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+LogisticRegressionModel model = ReadWriteUtils.loadStageParam(path);
+DataStream<LogisticRegressionModelData> modelData =
+ReadWriteUtils.loadModelData(
+env, path, new LogisticRegressionModelData.ModelDataDecoder());
+return model.setModelData(tEnv.fromDataStream(modelData));
+}
+
+@Override
+public LogisticRegressionModel setModelData(Table... inputs) {
+modelDataTable = inputs[0];
+return this;
+}
+
+@Override
+public Table[] getModelData() {
+return new Table[] {modelDataTable};
+}
+
+@Override
+@SuppressWarnings("unchecked")
+public Table[] transform(Table... inputs) {
+Preconditions.checkArgument(inputs.length == 1);
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+DataStream<Row> inputStream = tEnv.toDataStream(inputs[0]);
+final String broadcastModelKey = "broadcastModelKey";
+DataStream<LogisticRegressionModelData> modelDataStream =
+LogisticRegressionModelData.getModelDataStream(modelDataTable);
+RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+RowTypeInfo outputTypeInfo =
+new RowTypeInfo(
+ArrayUtils.addAll(
+inputTypeInfo.getFieldTypes(),
+BasicTypeInfo.DOUBLE_TYPE_INFO,
+ 

[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-13 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r767719555



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java
##
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+
+/**
+ * Model data of {@link LogisticRegressionModel}.
+ *
+ * This class also provides methods to convert model data from Table to Datastream, and classes
+ * to save/load model data.
+ */
+public class LogisticRegressionModelData {
+
+public final DenseVector coefficient;
+
+public LogisticRegressionModelData(DenseVector coefficient) {
+this.coefficient = coefficient;
+}
+
+/**
+ * Converts the table model to a data stream.
+ *
+ * @param modelData The table model data.
+ * @return The data stream model data.
+ */
+public static DataStream<LogisticRegressionModelData> getModelDataStream(Table modelData) {

Review comment:
   Good observation here~
   
   In my understanding, the first answer is yes and second answer is no.
   
   As `LogisticRegressionModelData` is neither serializable nor a POJO here, Flink
is supposed to treat `LogisticRegressionModelData` as a generic type and use the default
serializer (e.g., KryoSerializer) to do the (de)serialization.
   
   If I am correct, we may need to create a type serializer for it, or we can
simply make it a POJO?
   
   Let's see what @gaoyunhaii thinks.
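
   To illustrate the "make it a POJO" option, here is a rough, simplified sketch of the kind of structural check involved. It is not Flink's actual type extraction logic: it only tests for a public class, a public no-argument constructor, and public non-static, non-transient fields, and it ignores the getter/setter alternative and any field-type checks. The names `PojoRuleSketch`, `ModelDataAsQuoted`, `ModelDataAsPojo`, and `looksLikePojo` are hypothetical, and `double[]` stands in for `DenseVector`.

```java
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;

/** Illustrative sketch only; this is not Flink's TypeExtractor. */
class PojoRuleSketch {

    /** Same shape as the quoted class: no public no-argument constructor. */
    public static class ModelDataAsQuoted {
        public final double[] coefficient; // double[] stands in for DenseVector

        public ModelDataAsQuoted(double[] coefficient) {
            this.coefficient = coefficient;
        }
    }

    /** Hypothetical POJO-style variant with a public no-argument constructor. */
    public static class ModelDataAsPojo {
        public double[] coefficient;

        public ModelDataAsPojo() {}

        public ModelDataAsPojo(double[] coefficient) {
            this.coefficient = coefficient;
        }
    }

    /** Simplified structural check; assumption-laden, see the lead-in. */
    public static boolean looksLikePojo(Class<?> clazz) {
        if (!Modifier.isPublic(clazz.getModifiers())) {
            return false;
        }
        try {
            // getConstructor() only finds public constructors.
            clazz.getConstructor();
        } catch (NoSuchMethodException e) {
            return false;
        }
        for (Field field : clazz.getDeclaredFields()) {
            int modifiers = field.getModifiers();
            if (Modifier.isStatic(modifiers) || Modifier.isTransient(modifiers)) {
                continue;
            }
            if (!Modifier.isPublic(modifiers)) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        System.out.println(looksLikePojo(ModelDataAsQuoted.class)); // prints false
        System.out.println(looksLikePojo(ModelDataAsPojo.class)); // prints true
    }
}
```

   In this sketch the quoted class shape fails only because it lacks a public no-argument constructor; whether the real class needs further changes depends on Flink's full rules.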








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-13 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r767719555



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java
##
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+
+/**
+ * Model data of {@link LogisticRegressionModel}.
+ *
+ * This class also provides methods to convert model data from Table to Datastream, and classes
+ * to save/load model data.
+ */
+public class LogisticRegressionModelData {
+
+public final DenseVector coefficient;
+
+public LogisticRegressionModelData(DenseVector coefficient) {
+this.coefficient = coefficient;
+}
+
+/**
+ * Converts the table model to a data stream.
+ *
+ * @param modelData The table model data.
+ * @return The data stream model data.
+ */
+public static DataStream<LogisticRegressionModelData> getModelDataStream(Table modelData) {

Review comment:
   Good observation here~
   
   In my understanding, the first answer is yes and second answer is no.
   
   As `LogisticRegressionModelData` is not serializable here, Flink is supposed
to treat `LogisticRegressionModelData` as a generic type and use the default serializer
(e.g., KryoSerializer) to do the (de)serialization.
   
   If I am correct, we may need to create a type serializer for it, or we can
simply make it a POJO?
   
   Let's see what @gaoyunhaii thinks.








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-12 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r767459284



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
##
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.datastream;
+
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.TypeExtractor;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TimestampedCollector;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+/** Provides utility functions for {@link DataStream}. */
+public class DataStreamUtils {
+/**
+ * Applies allReduceSum on the input data stream. The input data stream is supposed to contain
+ * one double array in each partition. The result data stream has the same parallelism as the
+ * input, where each partition contains one double array that sums all of the double arrays in
+ * the input data stream.
+ *
+ * Note that we throw exception when one of the following two cases happen:
+ * There exists one partition that contains more than one double array.
+ * The length of the double array is not consistent among all partitions.
+ *
+ * @param input The input data stream.
+ * @return The result data stream.
+ */
+public static DataStream allReduceSum(DataStream 
input) {
+return AllReduceImpl.allReduceSum(input);
+}
+
+/**
+ * Applies a {@link MapPartitionFunction} on a bounded data stream.
+ *
+ * @param input The input data stream.
+ * @param func The user defined mapPartition function.
+ * @param  The class type of the input element.
+ * @param  The class type of output element.
+ * @return The result data stream.
+ */
+public static  DataStream mapPartition(
+DataStream input, MapPartitionFunction func) {
+TypeInformation resultType =
+TypeExtractor.getMapPartitionReturnTypes(func, input.getType(), null, true);
+return input.transform("mapPartition", resultType, new MapPartitionOperator<>(func))
+.setParallelism(input.getParallelism());
+}
+
+/**
+ * Applies a {@link MapPartitionFunction} on a bounded data stream.
+ *
+ * @param input The input data stream.
+ * @param func The user defined mapPartition function.
+ * @param outputType The type information of the output element.
+ * @param  The class type of the input element.
+ * @param  The class type of output element.
+ * @return The result data stream.
+ */
+public static  DataStream mapPartition(

Review comment:
   After some offline discussion, we agreed to remove this method for now 
and add it back when needed.








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-12 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r767446298



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
##
@@ -329,33 +342,42 @@ public static void updateExistingParams(Stage stage, Map, Object> pa
  * @param model The model data stream.
  * @param path The parent directory of the model data file.
  * @param modelEncoder The encoder to encode the model data.
+ * @param modelIndex The index of the table to save.
  * @param  The class type of the model data.
  */
 public static  void saveModelData(
-DataStream model, String path, Encoder modelEncoder) {
+DataStream model, String path, Encoder modelEncoder, int modelIndex) {
 FileSink sink =
 FileSink.forRowFormat(
-new org.apache.flink.core.fs.Path(getDataPath(path)), modelEncoder)
+new org.apache.flink.core.fs.Path(getDataPath(path, modelIndex)),
+modelEncoder)
 .withRollingPolicy(OnCheckpointRollingPolicy.build())
 .withBucketAssigner(new BasePathBucketAssigner<>())
 .build();
 model.sinkTo(sink);
 }
 
 /**
- * Loads the model data from the given path using the model decoder.
+ * Loads the model table with index `modelIndex` from the given path using the model decoder.
  *
  * @param env A StreamExecutionEnvironment instance.
  * @param path The parent directory of the model data file.
  * @param modelDecoder The decoder used to decode the model data.
+ * @param modelIndex The index of the table to load.
  * @param  The class type of the model data.
  * @return The loaded model data.
  */
-public static  DataStream loadModelData(
-StreamExecutionEnvironment env, String path, SimpleStreamFormat modelDecoder) {
+public static  Table loadModelData(

Review comment:
   I thought there would be cases where the model data contains multiple 
tables. After some offline discussion, we agreed to remove `modelIndex` for now 
and support multiple tables when a real use case comes up.
   
   We also agreed to let `loadModelData` return a `DataStream` for two 
reasons: (1) save and load are then symmetric; (2) `saveModelData` and 
`loadModelData` are utility functions for developers, who often work with 
`DataStream`s.
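
As an illustration of the symmetry argument only, here is a minimal plain-Java sketch with no Flink dependency. `RecordEncoder`, `RecordDecoder` and `SymmetricModelIo` are hypothetical names invented for this example and are not part of `ReadWriteUtils`; the point is just that `load` hands back the same element type that `save` accepted.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;

/** Hypothetical stand-in for an encoder; not a Flink interface. */
interface RecordEncoder<T> {
    void encode(T record, DataOutputStream out) throws IOException;
}

/** Hypothetical stand-in for a decoder; not a Flink interface. */
interface RecordDecoder<T> {
    T decode(DataInputStream in) throws IOException;
}

final class SymmetricModelIo {
    private SymmetricModelIo() {}

    /** Encodes every record in order and returns the raw bytes. */
    static <T> byte[] save(List<T> records, RecordEncoder<T> encoder) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            for (T record : records) {
                encoder.encode(record, out);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    /** Decodes records until the input is exhausted; the result has the element type save() took. */
    static <T> List<T> load(byte[] data, RecordDecoder<T> decoder) {
        List<T> records = new ArrayList<>();
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(data))) {
            while (in.available() > 0) {
                records.add(decoder.decode(in));
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return records;
    }
}
```

With this shape, a caller that saved a list of `double[]` coefficients gets a list of `double[]` back, without an intermediate conversion step on either side.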








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-12 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r76701



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
##
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.datastream;
+
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.TypeExtractor;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TimestampedCollector;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+/** Provides utility functions for {@link DataStream}. */
+public class DataStreamUtils {
+/**
+ * Applies allReduceSum on the input data stream. The input data stream is supposed to contain
+ * one double array in each partition. The result data stream has the same parallelism as the
+ * input, where each partition contains one double array that sums all of the double arrays in
+ * the input data stream.
+ *
+ * Note that we throw exception when one of the following two cases happen:
+ * There exists one partition that contains more than one double array.
+ * The length of the double array is not consistent among all partitions.
+ *
+ * @param input The input data stream.
+ * @return The result data stream.
+ */
+public static DataStream allReduceSum(DataStream input) {
+return AllReduceImpl.allReduceSum(input);
+}
+
+/**
+ * Applies a {@link MapPartitionFunction} on a bounded data stream.
+ *
+ * @param input The input data stream.
+ * @param func The user defined mapPartition function.
+ * @param  The class type of the input element.
+ * @param  The class type of output element.
+ * @return The result data stream.
+ */
+public static  DataStream mapPartition(
+DataStream input, MapPartitionFunction func) {
+TypeInformation resultType =
+TypeExtractor.getMapPartitionReturnTypes(func, input.getType(), null, true);
+return input.transform("mapPartition", resultType, new MapPartitionOperator<>(func))
+.setParallelism(input.getParallelism());
+}
+
+/**
+ * Applies a {@link MapPartitionFunction} on a bounded data stream.
+ *
+ * @param input The input data stream.
+ * @param func The user defined mapPartition function.
+ * @param outputType The type information of the output element.
+ * @param  The class type of the input element.
+ * @param  The class type of output element.
+ * @return The result data stream.
+ */
+public static  DataStream mapPartition(

Review comment:
   We let users explicitly provide the `outputType` here because 
`TypeExtractor` may not be able to infer the output type in some cases (e.g., 
when the output type is a `Row`, we cannot infer the type of each field of the row).
   
   Note that this design pattern is already used in `DataStream`. Please check out 
the code example for `DataStream#map`.
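
To make the inference limitation concrete, here is a small plain-Java sketch that does not use Flink at all. `SimpleRow` and `OutputTypeInference` are hypothetical names invented for this example, and the reflection shown is only an analogy for what a type extractor can see, not how `TypeExtractor` is implemented. It shows two things: reflection on a lambda yields no output type, and even when the output class is visible (`SimpleRow`), the types of the individual fields are not part of it, so the caller has to state them.

```java
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.Optional;
import java.util.function.Function;

/** Hypothetical row container: field types are runtime values, not part of the Java type. */
final class SimpleRow {
    final Object[] fields;

    SimpleRow(Object... fields) {
        this.fields = fields;
    }
}

final class OutputTypeInference {
    private OutputTypeInference() {}

    /**
     * Tries to read the output type R of a Function<T, R> from the generic
     * signature of the implementing class. Returns empty when the class does
     * not carry that signature, which is the case for lambdas.
     */
    static Optional<Type> inferOutputType(Function<?, ?> func) {
        for (Type iface : func.getClass().getGenericInterfaces()) {
            if (iface instanceof ParameterizedType
                    && ((ParameterizedType) iface).getRawType() == Function.class) {
                return Optional.of(((ParameterizedType) iface).getActualTypeArguments()[1]);
            }
        }
        return Optional.empty();
    }
}
```

An anonymous `Function<String, SimpleRow>` exposes `SimpleRow` here, but nothing says whether its fields are strings, doubles or vectors. That missing piece is what an explicit `outputType` argument supplies.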
   








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-08 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r765474898



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModelData.java
##
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+
+/** Model data of {@link LogisticRegressionModel}. */
+public class LogisticRegressionModelData {
+
+public final DenseVector coefficient;
+
+public LogisticRegressionModelData(DenseVector coefficient) {
+this.coefficient = coefficient;
+}
+
+/**
+ * Converts the table model to a data stream.
+ *
+ * @param modelData The table model data.
+ * @return The data stream model data.
+ */
+public static DataStream getModelDataStream(Table modelData) {
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment();
+return tEnv.toDataStream(modelData).map(x -> (LogisticRegressionModelData) x.getField(0));
+}
+
+/** Data encoder for {@link LogisticRegressionModel}. */
+public static class ModelDataEncoder implements Encoder {
+
+@Override
+public void encode(LogisticRegressionModelData modelData, OutputStream outputStream)
+throws IOException {
+DenseVectorSerializer serializer = new DenseVectorSerializer();
+serializer.serialize(
+modelData.coefficient, new DataOutputViewStreamWrapper(outputStream));
+}
+}
+
+/** Data decoder for {@link LogisticRegressionModel}. */
+public static class ModelDataDecoder extends SimpleStreamFormat {

Review comment:
   `KmeansModelData` is already updated. For NaiveBayes, how about we do 
the update in the NaiveBayes PR, given that it is not merged yet?
   








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-08 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r765480782



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModelData.java
##
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+
+/** Model data of {@link LogisticRegressionModel}. */
+public class LogisticRegressionModelData {
+
+public final DenseVector coefficient;
+
+public LogisticRegressionModelData(DenseVector coefficient) {
+this.coefficient = coefficient;
+}
+
+/**
+ * Converts the table model to a data stream.
+ *
+ * @param modelData The table model data.
+ * @return The data stream model data.
+ */
+public static DataStream getModelDataStream(Table modelData) {
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment();
+return tEnv.toDataStream(modelData).map(x -> (LogisticRegressionModelData) x.getField(0));
+}
+
+/** Data encoder for {@link LogisticRegressionModel}. */
+public static class ModelDataEncoder implements Encoder {
+
+@Override
+public void encode(LogisticRegressionModelData modelData, OutputStream outputStream)
+throws IOException {
+DenseVectorSerializer serializer = new DenseVectorSerializer();
+serializer.serialize(
+modelData.coefficient, new DataOutputViewStreamWrapper(outputStream));
+}
+}
+
+/** Data decoder for {@link LogisticRegressionModel}. */
+public static class ModelDataDecoder extends SimpleStreamFormat {
+
+@Override
+public Reader createReader(
+Configuration configuration, FSDataInputStream inputStream) {
+return new Reader() {
+
+@Override
+public LogisticRegressionModelData read() throws IOException {
+DenseVectorSerializer serializer = new DenseVectorSerializer();

Review comment:
   Thanks for pointing this out. 
   I found that `Serializer` classes in Flink often make the class member 
`INSTANCE` public. How about we do the same for `DenseVectorSerializer` so 
that we can avoid creating new objects here?
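
A minimal sketch of the shared-instance pattern being proposed, written in plain Java so it runs without Flink. `DoubleArraySerializer` is a hypothetical class invented for this example; it is not `DenseVectorSerializer`, and its wire format (length followed by the values) is an assumption made only to keep the sketch complete.

```java
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

/** Hypothetical stateless serializer for double arrays. */
final class DoubleArraySerializer {
    /** Shared instance; safe to reuse because the class holds no mutable state. */
    public static final DoubleArraySerializer INSTANCE = new DoubleArraySerializer();

    // Private here so that callers go through INSTANCE; a real class may keep it public.
    private DoubleArraySerializer() {}

    /** Writes the array length followed by each value. */
    public void serialize(double[] values, DataOutputStream out) throws IOException {
        out.writeInt(values.length);
        for (double value : values) {
            out.writeDouble(value);
        }
    }

    /** Reads back an array written by serialize(). */
    public double[] deserialize(DataInputStream in) throws IOException {
        int length = in.readInt();
        double[] values = new double[length];
        for (int i = 0; i < length; i++) {
            values[i] = in.readDouble();
        }
        return values;
    }
}
```

A per-record method such as `read()` can then call `DoubleArraySerializer.INSTANCE.deserialize(in)` instead of allocating a new serializer for every record.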








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-07 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r764491505



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModel.java
##
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** This class implements {@link Model} for {@link LogisticRegression}. */
+public class LogisticRegressionModel
+implements Model,
+LogisticRegressionModelParams {
+
+private Map, Object> paramMap = new HashMap<>();
+
+private Table modelData;
+
+public LogisticRegressionModel() {
+ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+}
+
+@Override
+public Map, Object> getParamMap() {
+return paramMap;
+}
+
+@Override
+public void save(String path) throws IOException {
+ReadWriteUtils.saveMetadata(this, path);
+ReadWriteUtils.saveModelData(
+LogisticRegressionModelData.getModelDataStream(modelData),
+path,
+new LogisticRegressionModelData.ModelDataEncoder());
+}
+
+public static LogisticRegressionModel load(StreamExecutionEnvironment env, String path)
+throws IOException {
+LogisticRegressionModel model = ReadWriteUtils.loadStageParam(path);
+Table modelData =
+ReadWriteUtils.loadModelData(
+env, path, new LogisticRegressionModelData.ModelDataDecoder());
+return model.setModelData(modelData);
+}
+
+@Override
+public LogisticRegressionModel setModelData(Table... inputs) {
+modelData = inputs[0];
+return this;
+}
+
+@Override
+public Table[] getModelData() {
+return new Table[] {modelData};
+}
+
+@Override
+@SuppressWarnings("unchecked")
+public Table[] transform(Table... inputs) {
+Preconditions.checkArgument(inputs.length == 1);
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+DataStream inputStream = tEnv.toDataStream(inputs[0]);
+final String broadcastModelKey = "broadcastModelKey";
+DataStream modelData =
+LogisticRegressionModelData.getModelDataStream(this.modelData);
+RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+RowTypeInfo outputTypeInfo =
+new RowTypeInfo(
+ArrayUtils.addAll(
+inputTypeInfo.getFieldTypes(),
+BasicTypeI

[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-07 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r762762819



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModel.java
##
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** This class implements {@link Model} for {@link LogisticRegression}. */
+public class LogisticRegressionModel
+implements Model,
+LogisticRegressionModelParams {
+
+private Map, Object> paramMap = new HashMap<>();
+
+private Table modelData;
+
+public LogisticRegressionModel() {
+ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+}
+
+@Override
+public Map, Object> getParamMap() {
+return paramMap;
+}
+
+@Override
+public void save(String path) throws IOException {
+ReadWriteUtils.saveMetadata(this, path);
+ReadWriteUtils.saveModelData(
+LogisticRegressionModelData.getModelDataStream(modelData),
+path,
+new LogisticRegressionModelData.ModelDataEncoder());
+}
+
+public static LogisticRegressionModel load(StreamExecutionEnvironment env, String path)
+throws IOException {
+LogisticRegressionModel model = ReadWriteUtils.loadStageParam(path);
+Table modelData =
+ReadWriteUtils.loadModelData(
+env, path, new LogisticRegressionModelData.ModelDataDecoder());
+return model.setModelData(modelData);
+}
+
+@Override
+public LogisticRegressionModel setModelData(Table... inputs) {
+modelData = inputs[0];
+return this;
+}
+
+@Override
+public Table[] getModelData() {
+return new Table[] {modelData};
+}
+
+@Override
+@SuppressWarnings("unchecked")
+public Table[] transform(Table... inputs) {
+Preconditions.checkArgument(inputs.length == 1);
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+DataStream inputStream = tEnv.toDataStream(inputs[0]);
+final String broadcastModelKey = "broadcastModelKey";
+DataStream modelData =
+LogisticRegressionModelData.getModelDataStream(this.modelData);
+RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+RowTypeInfo outputTypeInfo =
+new RowTypeInfo(
+ArrayUtils.addAll(
+inputTypeInfo.getFieldTypes(),
+BasicTypeI

[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-07 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r764491505



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModel.java
##
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** This class implements {@link Model} for {@link LogisticRegression}. */
+public class LogisticRegressionModel
+implements Model<LogisticRegressionModel>,
+LogisticRegressionModelParams<LogisticRegressionModel> {
+
+private Map<Param<?>, Object> paramMap = new HashMap<>();
+
+private Table modelData;
+
+public LogisticRegressionModel() {
+ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+}
+
+@Override
+public Map<Param<?>, Object> getParamMap() {
+return paramMap;
+}
+
+@Override
+public void save(String path) throws IOException {
+ReadWriteUtils.saveMetadata(this, path);
+ReadWriteUtils.saveModelData(
+LogisticRegressionModelData.getModelDataStream(modelData),
+path,
+new LogisticRegressionModelData.ModelDataEncoder());
+}
+
+public static LogisticRegressionModel load(StreamExecutionEnvironment env, String path)
+throws IOException {
+LogisticRegressionModel model = ReadWriteUtils.loadStageParam(path);
+Table modelData =
+ReadWriteUtils.loadModelData(
+env, path, new LogisticRegressionModelData.ModelDataDecoder());
+return model.setModelData(modelData);
+}
+
+@Override
+public LogisticRegressionModel setModelData(Table... inputs) {
+modelData = inputs[0];
+return this;
+}
+
+@Override
+public Table[] getModelData() {
+return new Table[] {modelData};
+}
+
+@Override
+@SuppressWarnings("unchecked")
+public Table[] transform(Table... inputs) {
+Preconditions.checkArgument(inputs.length == 1);
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+DataStream<Row> inputStream = tEnv.toDataStream(inputs[0]);
+final String broadcastModelKey = "broadcastModelKey";
+DataStream<LogisticRegressionModelData> modelData =
+LogisticRegressionModelData.getModelDataStream(this.modelData);
+RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+RowTypeInfo outputTypeInfo =
+new RowTypeInfo(
+ArrayUtils.addAll(
+inputTypeInfo.getFieldTypes(),
+BasicTypeI

[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-06 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r762780423



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/common/iteration/TerminateOnMaxIterOrTol.java
##
@@ -16,29 +16,53 @@
  * limitations under the License.
  */
 
-package org.apache.flink.test.iteration.operators;
+package org.apache.flink.ml.common.iteration;
 
 import org.apache.flink.api.common.functions.FlatMapFunction;
 import org.apache.flink.iteration.IterationListener;
 import org.apache.flink.util.Collector;
 
-/** An termination criteria function that asks to stop after the specialized round. */
-public class RoundBasedTerminationCriteria
-implements FlatMapFunction<EpochRecord, Integer>, IterationListener<Integer> {
+/**
+ * A FlatMapFunction that emits values iff the iteration's epochWatermark does not exceed a certain
+ * threshold and the loss exceeds a certain tolerance.
+ *
+ * When the output of this FlatMapFunction is used as the termination criteria of an iteration
+ * body, the iteration will be executed for at most the given `maxIter` iterations. And the
+ * iteration will terminate once any input value is smaller than or equal to the given `tol`.
+ */
+public class TerminateOnMaxIterOrTol
+implements IterationListener<Integer>, FlatMapFunction<Double, Integer> {
+
+private final int maxIter;
 
-private final int maxRound;
+private final double tol;
 
-public RoundBasedTerminationCriteria(int maxRound) {
-this.maxRound = maxRound;
+private double loss = Double.NEGATIVE_INFINITY;
+
+public TerminateOnMaxIterOrTol(int maxIter, double tol) {
+this.maxIter = maxIter;
+this.tol = tol;
+}
+
+public TerminateOnMaxIterOrTol(int maxIter) {

Review comment:
   Sure, let's remove this constructor here.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-06 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r762777943



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/common/iteration/TerminateOnMaxIterOrTol.java
##
@@ -16,29 +16,53 @@
  * limitations under the License.
  */
 
-package org.apache.flink.test.iteration.operators;
+package org.apache.flink.ml.common.iteration;
 
 import org.apache.flink.api.common.functions.FlatMapFunction;
 import org.apache.flink.iteration.IterationListener;
 import org.apache.flink.util.Collector;
 
-/** An termination criteria function that asks to stop after the specialized round. */
-public class RoundBasedTerminationCriteria
-implements FlatMapFunction<EpochRecord, Integer>, IterationListener<Integer> {
+/**
+ * A FlatMapFunction that emits values iff the iteration's epochWatermark does not exceed a certain
+ * threshold and the loss exceeds a certain tolerance.
+ *
+ * When the output of this FlatMapFunction is used as the termination criteria of an iteration
+ * body, the iteration will be executed for at most the given `maxIter` iterations. And the
+ * iteration will terminate once any input value is smaller than or equal to the given `tol`.
+ */
+public class TerminateOnMaxIterOrTol
+implements IterationListener<Integer>, FlatMapFunction<Double, Integer> {
+
+private final int maxIter;
 
-private final int maxRound;
+private final double tol;
 
-public RoundBasedTerminationCriteria(int maxRound) {
-this.maxRound = maxRound;
+private double loss = Double.NEGATIVE_INFINITY;
+
+public TerminateOnMaxIterOrTol(int maxIter, double tol) {
+this.maxIter = maxIter;
+this.tol = tol;
+}
+
+public TerminateOnMaxIterOrTol(int maxIter) {
+this.maxIter = maxIter;
+this.tol = Double.NEGATIVE_INFINITY;
+}
+
+public TerminateOnMaxIterOrTol(double tol) {
+this.maxIter = Integer.MAX_VALUE;
+this.tol = tol;
 }
 
 @Override
-public void flatMap(EpochRecord integer, Collector<Integer> collector) throws Exception {}
+public void flatMap(Double value, Collector<Integer> out) {
+this.loss = value;

Review comment:
   I think throwing an exception to enforce only one input in each epoch is a 
viable solution.
   
   Users are not supposed to use `TerminateOnMaxIterOrTol` or 
`TerminateOnMaxIter` for asynchronous iterations; in that setting, each worker 
is expected to decide its own termination based on its own loss.
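
The one-input-per-epoch check could look roughly like the sketch below. It is a plain-Java stand-in, not the Flink ML class: `OneLossPerEpochCheck`, `onLoss`, and `onEpochEnd` are hypothetical names, and the `(epoch + 1) < maxIter` bound is an assumption about how `maxIter` is counted.

```java
/** Hypothetical stand-in for a synchronous termination check; not the Flink ML class. */
final class OneLossPerEpochCheck {
    private final int maxIter;
    private final double tol;
    // Loss reported in the current epoch; null until a value arrives.
    private Double loss = null;

    public OneLossPerEpochCheck(int maxIter, double tol) {
        this.maxIter = maxIter;
        this.tol = tol;
    }

    /** Records the loss for the current epoch and rejects a second value. */
    public void onLoss(double value) {
        if (loss != null) {
            throw new IllegalStateException("Expected exactly one loss value per epoch.");
        }
        loss = value;
    }

    /** Called when an epoch ends; returns true if another epoch should run. */
    public boolean onEpochEnd(int epoch) {
        if (loss == null) {
            throw new IllegalStateException("No loss value received in epoch " + epoch + ".");
        }
        // Assumption: epochs are counted from 0, so epoch + 1 epochs have completed.
        boolean keepGoing = (epoch + 1) < maxIter && loss > tol;
        loss = null; // the next epoch must report its own loss
        return keepGoing;
    }

    public static void main(String[] args) {
        OneLossPerEpochCheck check = new OneLossPerEpochCheck(3, 0.1);
        check.onLoss(0.5);
        System.out.println(check.onEpochEnd(0)); // true: loss is still above tol
        check.onLoss(0.05);
        System.out.println(check.onEpochEnd(1)); // false: loss is no longer above tol
    }
}
```

Resetting `loss` at the end of each epoch is what makes a missing or duplicated value detectable in the following epoch.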








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-06 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r762777943



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/common/iteration/TerminateOnMaxIterOrTol.java
##
@@ -16,29 +16,53 @@
  * limitations under the License.
  */
 
-package org.apache.flink.test.iteration.operators;
+package org.apache.flink.ml.common.iteration;
 
 import org.apache.flink.api.common.functions.FlatMapFunction;
 import org.apache.flink.iteration.IterationListener;
 import org.apache.flink.util.Collector;
 
-/** An termination criteria function that asks to stop after the specialized round. */
-public class RoundBasedTerminationCriteria
-implements FlatMapFunction<EpochRecord, Integer>, IterationListener<Integer> {
+/**
+ * A FlatMapFunction that emits values iff the iteration's epochWatermark does not exceed a certain
+ * threshold and the loss exceeds a certain tolerance.
+ *
+ * When the output of this FlatMapFunction is used as the termination criteria of an iteration
+ * body, the iteration will be executed for at most the given `maxIter` iterations. And the
+ * iteration will terminate once any input value is smaller than or equal to the given `tol`.
+ */
+public class TerminateOnMaxIterOrTol
+implements IterationListener<Integer>, FlatMapFunction<Double, Integer> {
+
+private final int maxIter;
 
-private final int maxRound;
+private final double tol;
 
-public RoundBasedTerminationCriteria(int maxRound) {
-this.maxRound = maxRound;
+private double loss = Double.NEGATIVE_INFINITY;
+
+public TerminateOnMaxIterOrTol(int maxIter, double tol) {
+this.maxIter = maxIter;
+this.tol = tol;
+}
+
+public TerminateOnMaxIterOrTol(int maxIter) {
+this.maxIter = maxIter;
+this.tol = Double.NEGATIVE_INFINITY;
+}
+
+public TerminateOnMaxIterOrTol(double tol) {
+this.maxIter = Integer.MAX_VALUE;
+this.tol = tol;
 }
 
 @Override
-public void flatMap(EpochRecord integer, Collector<Integer> collector) throws Exception {}
+public void flatMap(Double value, Collector<Integer> out) {
+this.loss = value;

Review comment:
   I think throwing an exception to enforce only one input in each epoch is a 
viable solution.
   
   Users are not supposed to use `TerminateOnMaxIterOrTol` or 
`TerminateOnMaxIter` for asynchronous iterations; in that setting, users are 
expected to decide termination based on their own loss.








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-05 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r762762819



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModel.java
##
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** This class implements {@link Model} for {@link LogisticRegression}. */
+public class LogisticRegressionModel
+implements Model<LogisticRegressionModel>,
+LogisticRegressionModelParams<LogisticRegressionModel> {
+
+private Map<Param<?>, Object> paramMap = new HashMap<>();
+
+private Table modelData;
+
+public LogisticRegressionModel() {
+ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+}
+
+@Override
+public Map<Param<?>, Object> getParamMap() {
+return paramMap;
+}
+
+@Override
+public void save(String path) throws IOException {
+ReadWriteUtils.saveMetadata(this, path);
+ReadWriteUtils.saveModelData(
+LogisticRegressionModelData.getModelDataStream(modelData),
+path,
+new LogisticRegressionModelData.ModelDataEncoder());
+}
+
+public static LogisticRegressionModel load(StreamExecutionEnvironment env, String path)
+throws IOException {
+LogisticRegressionModel model = ReadWriteUtils.loadStageParam(path);
+Table modelData =
+ReadWriteUtils.loadModelData(
+env, path, new LogisticRegressionModelData.ModelDataDecoder());
+return model.setModelData(modelData);
+}
+
+@Override
+public LogisticRegressionModel setModelData(Table... inputs) {
+modelData = inputs[0];
+return this;
+}
+
+@Override
+public Table[] getModelData() {
+return new Table[] {modelData};
+}
+
+@Override
+@SuppressWarnings("unchecked")
+public Table[] transform(Table... inputs) {
+Preconditions.checkArgument(inputs.length == 1);
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+DataStream<Row> inputStream = tEnv.toDataStream(inputs[0]);
+final String broadcastModelKey = "broadcastModelKey";
+DataStream<LogisticRegressionModelData> modelData =
+LogisticRegressionModelData.getModelDataStream(this.modelData);
+RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+RowTypeInfo outputTypeInfo =
+new RowTypeInfo(
+ArrayUtils.addAll(
+inputTypeInfo.getFieldTypes(),
+BasicTypeI

[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-05 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r762733954



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/common/iteration/TerminateOnMaxIterOrTol.java
##
@@ -16,29 +16,53 @@
  * limitations under the License.
  */
 
-package org.apache.flink.test.iteration.operators;
+package org.apache.flink.ml.common.iteration;
 
 import org.apache.flink.api.common.functions.FlatMapFunction;
 import org.apache.flink.iteration.IterationListener;
 import org.apache.flink.util.Collector;
 
-/** An termination criteria function that asks to stop after the specialized round. */
-public class RoundBasedTerminationCriteria
-implements FlatMapFunction<EpochRecord, Integer>, IterationListener<Integer> {
+/**
+ * A FlatMapFunction that emits values iff the iteration's epochWatermark does not exceed a certain
+ * threshold and the loss exceeds a certain tolerance.
+ *
+ * When the output of this FlatMapFunction is used as the termination criteria of an iteration
+ * body, the iteration will be executed for at most the given `maxIter` iterations. And the
+ * iteration will terminate once any input value is smaller than or equal to the given `tol`.
+ */
+public class TerminateOnMaxIterOrTol
+implements IterationListener<Integer>, FlatMapFunction<Double, Integer> {
+
+private final int maxIter;
 
-private final int maxRound;
+private final double tol;
 
-public RoundBasedTerminationCriteria(int maxRound) {
-this.maxRound = maxRound;
+private double loss = Double.NEGATIVE_INFINITY;

Review comment:
   There should always be one loss value in each epoch (including the first 
one).
   
   Thanks for pointing this out. I think I made a mistake here: the initial 
value of `loss` should be `Double.MAX_VALUE`.

##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/common/iteration/TerminateOnMaxIterOrTol.java
##
@@ -16,29 +16,53 @@
  * limitations under the License.
  */
 
-package org.apache.flink.test.iteration.operators;
+package org.apache.flink.ml.common.iteration;
 
 import org.apache.flink.api.common.functions.FlatMapFunction;
 import org.apache.flink.iteration.IterationListener;
 import org.apache.flink.util.Collector;
 
-/** An termination criteria function that asks to stop after the specialized round. */
-public class RoundBasedTerminationCriteria
-implements FlatMapFunction<EpochRecord, Integer>, IterationListener<Integer> {
+/**
+ * A FlatMapFunction that emits values iff the iteration's epochWatermark does not exceed a certain
+ * threshold and the loss exceeds a certain tolerance.
+ *
+ * When the output of this FlatMapFunction is used as the termination criteria of an iteration
+ * body, the iteration will be executed for at most the given `maxIter` iterations. And the
+ * iteration will terminate once any input value is smaller than or equal to the given `tol`.
+ */
+public class TerminateOnMaxIterOrTol
+implements IterationListener<Integer>, FlatMapFunction<Double, Integer> {
+
+private final int maxIter;
 
-private final int maxRound;
+private final double tol;
 
-public RoundBasedTerminationCriteria(int maxRound) {
-this.maxRound = maxRound;
+private double loss = Double.NEGATIVE_INFINITY;

Review comment:
   There should always be one loss value in each epoch (including the first 
one).
   
   Thanks for pointing this out. I made a mistake here: the initial value of 
`loss` should be `Double.MAX_VALUE`.








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-05 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r762726433



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModel.java
##
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** This class implements {@link Model} for {@link LogisticRegression}. */
+public class LogisticRegressionModel
+implements Model<LogisticRegressionModel>,
+LogisticRegressionModelParams<LogisticRegressionModel> {
+
+private Map<Param<?>, Object> paramMap = new HashMap<>();
+
+private Table modelData;
+
+public LogisticRegressionModel() {
+ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+}
+
+@Override
+public Map<Param<?>, Object> getParamMap() {
+return paramMap;
+}
+
+@Override
+public void save(String path) throws IOException {
+ReadWriteUtils.saveMetadata(this, path);
+ReadWriteUtils.saveModelData(
+LogisticRegressionModelData.getModelDataStream(modelData),
+path,
+new LogisticRegressionModelData.ModelDataEncoder());
+}
+
+public static LogisticRegressionModel load(StreamExecutionEnvironment env, String path)
+throws IOException {
+LogisticRegressionModel model = ReadWriteUtils.loadStageParam(path);
+Table modelData =
+ReadWriteUtils.loadModelData(
+env, path, new LogisticRegressionModelData.ModelDataDecoder());
+return model.setModelData(modelData);
+}
+
+@Override
+public LogisticRegressionModel setModelData(Table... inputs) {
+modelData = inputs[0];
+return this;
+}
+
+@Override
+public Table[] getModelData() {
+return new Table[] {modelData};
+}
+
+@Override
+@SuppressWarnings("unchecked")
+public Table[] transform(Table... inputs) {
+Preconditions.checkArgument(inputs.length == 1);
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+DataStream<Row> inputStream = tEnv.toDataStream(inputs[0]);
+final String broadcastModelKey = "broadcastModelKey";
+DataStream<LogisticRegressionModelData> modelData =
+LogisticRegressionModelData.getModelDataStream(this.modelData);
+RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+RowTypeInfo outputTypeInfo =
+new RowTypeInfo(
+ArrayUtils.addAll(
+inputTypeInfo.getFieldTypes(),
+BasicTypeI

[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-05 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r762724172



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/common/iteration/TerminateOnMaxIterOrTol.java
##
@@ -16,29 +16,53 @@
  * limitations under the License.
  */
 
-package org.apache.flink.test.iteration.operators;
+package org.apache.flink.ml.common.iteration;
 
 import org.apache.flink.api.common.functions.FlatMapFunction;
 import org.apache.flink.iteration.IterationListener;
 import org.apache.flink.util.Collector;
 
-/** An termination criteria function that asks to stop after the specialized round. */
-public class RoundBasedTerminationCriteria
-implements FlatMapFunction<EpochRecord, Integer>, IterationListener<Integer> {
+/**
+ * A FlatMapFunction that emits values iff the iteration's epochWatermark does not exceed a certain
+ * threshold and the loss exceeds a certain tolerance.
+ *
+ * When the output of this FlatMapFunction is used as the termination criteria of an iteration
+ * body, the iteration will be executed for at most the given `maxIter` iterations. And the
+ * iteration will terminate once any input value is smaller than or equal to the given `tol`.
+ */
+public class TerminateOnMaxIterOrTol
+implements IterationListener<Integer>, FlatMapFunction<Double, Integer> {
+
+private final int maxIter;
 
-private final int maxRound;
+private final double tol;
 
-public RoundBasedTerminationCriteria(int maxRound) {
-this.maxRound = maxRound;
+private double loss = Double.NEGATIVE_INFINITY;

Review comment:
   Ah, thanks for pointing this out. The initial value of `loss` should be 
`Double.MAX_VALUE`.
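
One way to read this fix, assuming the continuation rule is `loss > tol` as the Javadoc describes: a starting value of `Double.NEGATIVE_INFINITY` can never exceed the tolerance, so a check made before any loss arrives would stop the iteration, whereas `Double.MAX_VALUE` keeps it running. The class and method names below are illustrative only and are not part of the pull request.

```java
/** Illustrative only: compares the two candidate initial values of the loss field. */
final class InitialLossDemo {
    /** Mirrors the Javadoc rule: continue only while the loss exceeds the tolerance. */
    public static boolean shouldContinue(double loss, double tol) {
        return loss > tol;
    }

    public static void main(String[] args) {
        double tol = 1e-6;
        // Negative infinity never exceeds a tolerance, so the check fails immediately.
        System.out.println(shouldContinue(Double.NEGATIVE_INFINITY, tol)); // false
        // Double.MAX_VALUE exceeds any smaller tolerance, so the iteration continues.
        System.out.println(shouldContinue(Double.MAX_VALUE, tol)); // true
    }
}
```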








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-05 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r762723908



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/common/iteration/TerminateOnMaxIterOrTol.java
##
@@ -16,29 +16,53 @@
  * limitations under the License.
  */
 
-package org.apache.flink.test.iteration.operators;
+package org.apache.flink.ml.common.iteration;
 
 import org.apache.flink.api.common.functions.FlatMapFunction;
 import org.apache.flink.iteration.IterationListener;
 import org.apache.flink.util.Collector;
 
-/** An termination criteria function that asks to stop after the specialized round. */
-public class RoundBasedTerminationCriteria
-implements FlatMapFunction<EpochRecord, Integer>, IterationListener<Integer> {
+/**
+ * A FlatMapFunction that emits values iff the iteration's epochWatermark does not exceed a certain
+ * threshold and the loss exceeds a certain tolerance.
+ *
+ * When the output of this FlatMapFunction is used as the termination criteria of an iteration
+ * body, the iteration will be executed for at most the given `maxIter` iterations. And the
+ * iteration will terminate once any input value is smaller than or equal to the given `tol`.
+ */
+public class TerminateOnMaxIterOrTol
+implements IterationListener<Integer>, FlatMapFunction<Double, Integer> {
+
+private final int maxIter;
 
-private final int maxRound;
+private final double tol;
 
-public RoundBasedTerminationCriteria(int maxRound) {
-this.maxRound = maxRound;
+private double loss = Double.NEGATIVE_INFINITY;
+
+public TerminateOnMaxIterOrTol(int maxIter, double tol) {
+this.maxIter = maxIter;
+this.tol = tol;
+}
+
+public TerminateOnMaxIterOrTol(int maxIter) {

Review comment:
   I kept it here because it is more intuitive for users who first terminate 
the iteration by `loss` and later switch to `maxIter`.
   
   If we remove this constructor, the set of constructors feels incomplete to 
me. What do you think?








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-05 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r762722161



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/common/iteration/TerminateOnMaxIterOrTol.java
##
@@ -16,29 +16,53 @@
  * limitations under the License.
  */
 
-package org.apache.flink.test.iteration.operators;
+package org.apache.flink.ml.common.iteration;
 
 import org.apache.flink.api.common.functions.FlatMapFunction;
 import org.apache.flink.iteration.IterationListener;
 import org.apache.flink.util.Collector;
 
-/** An termination criteria function that asks to stop after the specialized round. */
-public class RoundBasedTerminationCriteria
-implements FlatMapFunction<EpochRecord, Integer>, IterationListener<Integer> {
+/**
+ * A FlatMapFunction that emits values iff the iteration's epochWatermark does not exceed a certain
+ * threshold and the loss exceeds a certain tolerance.
+ *
+ * When the output of this FlatMapFunction is used as the termination criteria of an iteration
+ * body, the iteration will be executed for at most the given `maxIter` iterations. And the
+ * iteration will terminate once any input value is smaller than or equal to the given `tol`.
+ */
+public class TerminateOnMaxIterOrTol
+implements IterationListener<Integer>, FlatMapFunction<Double, Integer> {
+
+private final int maxIter;
 
-private final int maxRound;
+private final double tol;
 
-public RoundBasedTerminationCriteria(int maxRound) {
-this.maxRound = maxRound;
+private double loss = Double.NEGATIVE_INFINITY;
+
+public TerminateOnMaxIterOrTol(int maxIter, double tol) {
+this.maxIter = maxIter;
+this.tol = tol;
+}
+
+public TerminateOnMaxIterOrTol(int maxIter) {
+this.maxIter = maxIter;
+this.tol = Double.NEGATIVE_INFINITY;
+}
+
+public TerminateOnMaxIterOrTol(double tol) {
+this.maxIter = Integer.MAX_VALUE;
+this.tol = tol;
 }
 
 @Override
-public void flatMap(EpochRecord integer, Collector<Integer> collector) throws Exception {}
+public void flatMap(Double value, Collector<Integer> out) {
+this.loss = value;

Review comment:
   Hmm, there should be only one loss value in each epoch. 




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-05 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r762671142



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModelData.java
##
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.io.Serializable;
+
+/** Model data of {@link LogisticRegressionModel}. */
+public class LogisticRegressionModelData implements Serializable {
+
+public final DenseVector coefficient;
+
+public LogisticRegressionModelData(DenseVector coefficient) {
+this.coefficient = coefficient;
+}
+
+/**
+ * Converts the table model to a data stream.
+ *
+ * @param modelData The table model data.
+ * @return The data stream model data.
+ */
+public static DataStream<LogisticRegressionModelData> getModelDataStream(Table modelData) {
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment();
+return tEnv.toDataStream(modelData).map(x -> (LogisticRegressionModelData) x.getField(0));
+}
+
+/**
+ * Gets the data encoder for {@link LogisticRegressionModelData}.
+ *
+ * @return The data encoder for {@link LogisticRegressionModelData}.
+ */
+public static ModelDataEncoder getModelDataEncoder() {
+return new ModelDataEncoder();
+}
+
+/**
+ * Gets the data decoder for {@link LogisticRegressionModelData}.
+ *
+ * @return The data decoder for {@link LogisticRegressionModelData}.
+ */
+public static ModelDataDecoder getModelDataDecoder() {
+return new ModelDataDecoder();
+}
+
+/** Data encoder for {@link LogisticRegressionModel}. */
+private static class ModelDataEncoder implements Encoder<LogisticRegressionModelData> {
+
+@Override
+public void encode(LogisticRegressionModelData modelData, OutputStream stream) {

Review comment:
   Sounds good. I have updated the encoder and decoder for both LR and 
KMeans.
   
   I did not update the other parts since some of the KMeans code seems to have 
been updated in https://github.com/apache/flink-ml/pull/32.








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-05 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r762403551



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModelData.java
##
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.io.Serializable;
+
+/** Model data of {@link LogisticRegressionModel}. */
+public class LogisticRegressionModelData implements Serializable {
+
+public final DenseVector coefficient;
+
+public LogisticRegressionModelData(DenseVector coefficient) {
+this.coefficient = coefficient;
+}
+
+/**
+ * Converts the table model to a data stream.
+ *
+ * @param modelData The table model data.
+ * @return The data stream model data.
+ */
+public static DataStream<LogisticRegressionModelData> getModelDataStream(Table modelData) {
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment();
+return tEnv.toDataStream(modelData).map(x -> (LogisticRegressionModelData) x.getField(0));
+}
+
+/**
+ * Gets the data encoder for {@link LogisticRegressionModelData}.
+ *
+ * @return The data encoder for {@link LogisticRegressionModelData}.
+ */
+public static ModelDataEncoder getModelDataEncoder() {
+return new ModelDataEncoder();
+}
+
+/**
+ * Gets the data decoder for {@link LogisticRegressionModelData}.
+ *
+ * @return The data decoder for {@link LogisticRegressionModelData}.
+ */
+public static ModelDataDecoder getModelDataDecoder() {
+return new ModelDataDecoder();
+}
+
+/** Data encoder for {@link LogisticRegressionModel}. */
+private static class ModelDataEncoder implements Encoder<LogisticRegressionModelData> {
+
+@Override
+public void encode(LogisticRegressionModelData modelData, OutputStream stream) {

Review comment:
   I tried to reuse `DenseVectorSerializer` but encountered the following 
exception:
   ```
   Caused by: java.io.EOFException
       at java.io.DataInputStream.readInt(DataInputStream.java:392)
       at org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer.deserialize(DenseVectorSerializer.java:85)
       at org.apache.flink.ml.classification.linear.LogisticRegressionModelData$ModelDataDecoder$1.read(LogisticRegressionModelData.java:100)
       at org.apache.flink.ml.classification.linear.LogisticRegressionModelData$ModelDataDecoder$1.read(LogisticRegressionModelData.java:83)
   ```
   I guess it is because I did not check whether the input stream has reached 
the end. However, I did not find an API that allows me to check for the end of 
an `InputStream` without reading one byte. Do you have a solution here?
   
   The code snippet is as follows:
   
   ```
   new Reader<LogisticRegressionModelData>() {
       private final DenseVectorSerializer serializer = new DenseVectorSerializer();
   
       @Override
       public LogisticRegressionModelData read() throws IOException {
           // No end-of-stream check before deserializing (suspected cause of the EOFException).
           DenseVector coefficient =
                   serializer.deserialize(new DataInputViewStreamWrapper(stream));
           return new LogisticRegressionModelData(coefficient);
       }
   };
   ```
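
One possible way to test for the end of a stream before deserializing is `java.io.PushbackInputStream`: it still reads one byte, but pushes it back so the next read sees it again. The following is a minimal sketch in plain Java, not the code from this pull request; the class `EofPeekSketch`, its methods `hasMore`, `write` and `readAll`, and the length-prefixed record layout are assumptions made for illustration only.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.PushbackInputStream;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch: detect end of stream by reading one byte and pushing it back. */
final class EofPeekSketch {

    /** Returns true if at least one more byte can be read; the byte is pushed back. */
    static boolean hasMore(PushbackInputStream in) throws IOException {
        int b = in.read();
        if (b == -1) {
            return false; // end of stream reached
        }
        in.unread(b); // restore the byte so the next read sees it again
        return true;
    }

    /** Writes each record as an int length followed by that many doubles (assumed layout). */
    static byte[] write(double[][] records) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            for (double[] record : records) {
                out.writeInt(record.length);
                for (double v : record) {
                    out.writeDouble(v);
                }
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    /** Reads records until the stream is exhausted, checking for EOF before each record. */
    static List<double[]> readAll(byte[] bytes) {
        List<double[]> records = new ArrayList<>();
        PushbackInputStream pushback = new PushbackInputStream(new ByteArrayInputStream(bytes));
        try (DataInputStream in = new DataInputStream(pushback)) {
            while (hasMore(pushback)) {
                double[] values = new double[in.readInt()];
                for (int i = 0; i < values.length; i++) {
                    values[i] = in.readDouble();
                }
                records.add(values);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return records;
    }

    public static void main(String[] args) {
        List<double[]> records = readAll(write(new double[][] {{1.0, 2.0}, {3.0}}));
        System.out.println(records.size()); // prints 2
    }
}
```

An alternative with the same effect would be to catch `EOFException` on the first read of a record and treat it as "no more records"; which option fits better depends on how the surrounding reader is expected to signal the end of input.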
   
   





[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-04 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r762403691



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModel.java
##
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** This class implements {@link Model} for {@link LogisticRegression}. */
+public class LogisticRegressionModel
+implements Model,
+LogisticRegressionModelParams {
+
+private Map<Param<?>, Object> paramMap = new HashMap<>();
+
+private Table modelData;
+
+public LogisticRegressionModel() {
+ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+}
+
+@Override
+public Map<Param<?>, Object> getParamMap() {
+return paramMap;
+}
+
+@Override
+public void save(String path) throws IOException {
+ReadWriteUtils.saveMetadata(this, path);
+ReadWriteUtils.saveModelData(
+LogisticRegressionModelData.getModelDataStream(modelData),
+path,
+LogisticRegressionModelData.getModelDataEncoder());
+}
+
+public static LogisticRegressionModel load(StreamExecutionEnvironment env, String path)
+throws IOException {
+LogisticRegressionModel model = ReadWriteUtils.loadStageParam(path);
+Table modelData =
+ReadWriteUtils.loadModelData(
+env, path, LogisticRegressionModelData.getModelDataDecoder());
+return model.setModelData(modelData);
+}
+
+@Override
+public LogisticRegressionModel setModelData(Table... inputs) {
+modelData = inputs[0];
+return this;
+}
+
+@Override
+public Table[] getModelData() {
+return new Table[] {modelData};
+}
+
+@Override
+@SuppressWarnings("unchecked")
+public Table[] transform(Table... inputs) {
+Preconditions.checkArgument(inputs.length == 1);
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+DataStream inputStream = tEnv.toDataStream(inputs[0]);
+final String broadcastModelKey = "broadcastModelKey";
+DataStream modelData =
+LogisticRegressionModelData.getModelDataStream(this.modelData);
+RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+RowTypeInfo outputTypeInfo =
+new RowTypeInfo(
+ArrayUtils.addAll(
+inputTypeInfo.getFieldTypes(),
+BasicTypeInf

[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-04 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r762403551



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModelData.java
##
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.io.Serializable;
+
+/** Model data of {@link LogisticRegressionModel}. */
+public class LogisticRegressionModelData implements Serializable {
+
+public final DenseVector coefficient;
+
+public LogisticRegressionModelData(DenseVector coefficient) {
+this.coefficient = coefficient;
+}
+
+/**
+ * Converts the table model to a data stream.
+ *
+ * @param modelData The table model data.
+ * @return The data stream model data.
+ */
+public static DataStream<LogisticRegressionModelData> getModelDataStream(Table modelData) {
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment();
+return tEnv.toDataStream(modelData).map(x -> (LogisticRegressionModelData) x.getField(0));
+}
+
+/**
+ * Gets the data encoder for {@link LogisticRegressionModelData}.
+ *
+ * @return The data encoder for {@link LogisticRegressionModelData}.
+ */
+public static ModelDataEncoder getModelDataEncoder() {
+return new ModelDataEncoder();
+}
+
+/**
+ * Gets the data decoder for {@link LogisticRegressionModelData}.
+ *
+ * @return The data decoder for {@link LogisticRegressionModelData}.
+ */
+public static ModelDataDecoder getModelDataDecoder() {
+return new ModelDataDecoder();
+}
+
+/** Data encoder for {@link LogisticRegressionModel}. */
+private static class ModelDataEncoder implements Encoder<LogisticRegressionModelData> {
+
+@Override
+public void encode(LogisticRegressionModelData modelData, OutputStream stream) {

Review comment:
   I tried to reuse `DenseVectorSerializer` but encountered the following 
exception:
   
   Caused by: java.io.EOFException
       at java.io.DataInputStream.readInt(DataInputStream.java:392)
       at org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer.deserialize(DenseVectorSerializer.java:85)
       at org.apache.flink.ml.classification.linear.LogisticRegressionModelData$ModelDataDecoder$1.read(LogisticRegressionModelData.java:100)
       at org.apache.flink.ml.classification.linear.LogisticRegressionModelData$ModelDataDecoder$1.read(LogisticRegressionModelData.java:83)
   
   I guess it is because I did not check whether the input stream has reached 
the end. However, I did not find an API that allows me to check for the end of 
an `InputStream` without reading one byte. Do you have a solution here?
   
   The code snippet is as follows:
   
   ```
   new Reader<LogisticRegressionModelData>() {
       private final DenseVectorSerializer serializer = new DenseVectorSerializer();
   
       @Override
       public LogisticRegressionModelData read() throws IOException {
           // No end-of-stream check before deserializing (suspected cause of the EOFException).
           DenseVector coefficient =
                   serializer.deserialize(new DataInputViewStreamWrapper(stream));
           return new LogisticRegressionModelData(coefficient);
       }
   };
   ```
   
   





[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-04 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r762399805



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModelData.java
##
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.io.Serializable;
+
+/** Model data of {@link LogisticRegressionModel}. */
+public class LogisticRegressionModelData implements Serializable {

Review comment:
   Just removed `implements Serializable`.








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-04 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r762399576



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModelData.java
##
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.io.Serializable;
+
+/** Model data of {@link LogisticRegressionModel}. */
+public class LogisticRegressionModelData implements Serializable {
+
+    public final DenseVector coefficient;
+
+    public LogisticRegressionModelData(DenseVector coefficient) {
+        this.coefficient = coefficient;
+    }
+
+    /**
+     * Converts the table model to a data stream.
+     *
+     * @param modelData The table model data.
+     * @return The data stream model data.
+     */
+    public static DataStream<LogisticRegressionModelData> getModelDataStream(Table modelData) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment();
+        return tEnv.toDataStream(modelData).map(x -> (LogisticRegressionModelData) x.getField(0));
+    }
+
+    /**
+     * Gets the data encoder for {@link LogisticRegressionModelData}.
+     *
+     * @return The data encoder for {@link LogisticRegressionModelData}.
+     */
+    public static ModelDataEncoder getModelDataEncoder() {

Review comment:
   I am not sure about the design pattern, so let's keep it simple for now. I have removed `getModelDataEncoder` and `getModelDataDecoder`.








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-04 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r762399279



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegression.java
##
@@ -0,0 +1,460 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * An Estimator which implements the logistic regression algorithm.
+ *
+ * See https://en.wikipedia.org/wiki/Logistic_regression.
+ */
+public class LogisticRegression
+        implements Estimator<LogisticRegression, LogisticRegressionModel>,
+                LogisticRegressionParams<LogisticRegression> {
+
+    private Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public LogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static LogisticRegression load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    @SuppressWarnings("rawTypes")
+    public LogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        String classificationType = getMultiClass();
+        Preconditions.checkArgument(
+                "auto".equals(classificationType) || "binomial".equals(classificationType),
+                "Multinomial classification is not supported yet. Supported options: [auto, binomial].");
+StreamTableEnvironment tEnv =
+(

[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-03 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r762398015



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
##
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.datastream;
+
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeComparator;
+import org.apache.flink.api.java.typeutils.TypeExtractor;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TimestampedCollector;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+/** Provides utility functions for {@link DataStream}. */
+public class DataStreamUtils {
+    /**
+     * Applies allReduceSum on the input data stream. The input data stream is supposed to contain
+     * one double array in each partition. The result data stream has the same parallelism as the
+     * input, where each partition contains one double array that sums all of the double arrays in
+     * the input data stream.
+     *
+     * <p>Note that we throw exception when one of the following two cases happen:
+     * <li>There exists one partition that contains more than one double array.
+     * <li>The length of the double array is not consistent among all partitions.
+     *
+     * @param input The input data stream.
+     * @return The result data stream.
+     */
+    public static DataStream<double[]> allReduceSum(DataStream<double[]> input) {
+        return AllReduceImpl.allReduceSum(input);
+    }
+
+    /**
+     * Collects distinct values in a bounded data stream. The parallelism of the output stream is 1.
+     *
+     * @param <T> The class type of the input data stream.
+     * @param input The bounded input data stream.
+     * @return The result data stream that contains all the distinct values.
+     */
+    public static <T> DataStream<T> distinct(DataStream<T> input) {
+        return input.transform(
+                        "distinctInEachPartition",
+                        input.getType(),
+                        new DistinctPartitionOperator<>())
+                .setParallelism(input.getParallelism())
+                .transform(
+                        "distinctInFinalPartition",
+                        input.getType(),
+                        new DistinctPartitionOperator<>())
+                .setParallelism(1);
+    }
+
+    /**
+     * Applies a {@link MapPartitionFunction} on a bounded data stream.
+     *
+     * @param input The input data stream.
+     * @param func The user defined mapPartition function.
+     * @param <IN> The class type of the input element.
+     * @param <OUT> The class type of output element.
+     * @return The result data stream.
+     */
+    public static <IN, OUT> DataStream<OUT> mapPartition(
+            DataStream<IN> input, MapPartitionFunction<IN, OUT> func) {
+        TypeInformation<OUT> resultType =
+                TypeExtractor.getMapPartitionReturnTypes(func, input.getType(), null, true);
+        return input.transform("mapPartition", resultType, new MapPartitionOperator<>(func))
+                .setParallelism(input.getParallelism());
+    }
+
+/**
+ * Sorts the elements in each partition of the input bounded data stream.
+ *
+ * @param input The input data stream.
+ * @param comparator The comparator used to sort the elements.
+ * @param  The class type of input element.
+ * @return The sorted data stream.
+ */
+p
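
The `allReduceSum` contract described in the Javadoc of the quoted diff can be illustrated without Flink. The following is only a minimal plain-Java sketch, assuming each partition is modeled as an in-memory list. The class `AllReduceSumSketch` and its method are hypothetical names for illustration and are not the `AllReduceImpl` from the pull request. Rejecting empty partitions is an assumption of the sketch; the Javadoc only mentions the "more than one array" and "inconsistent length" cases.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical, Flink-free illustration of the allReduceSum contract. */
class AllReduceSumSketch {

    /**
     * Each inner list models one partition and must hold exactly one double array
     * (an assumption of this sketch). Returns one copy of the element-wise sum
     * per input partition, so the "parallelism" of the output matches the input.
     */
    static List<double[]> allReduceSum(List<List<double[]>> partitions) {
        double[] sum = null;
        for (List<double[]> partition : partitions) {
            if (partition.size() != 1) {
                throw new IllegalArgumentException(
                        "Each partition must contain exactly one double array.");
            }
            double[] values = partition.get(0);
            if (sum == null) {
                sum = new double[values.length];
            } else if (sum.length != values.length) {
                throw new IllegalArgumentException(
                        "The array length is not consistent among partitions.");
            }
            // Accumulate this partition's array into the running sum.
            for (int i = 0; i < values.length; i++) {
                sum[i] += values[i];
            }
        }
        // Every partition receives its own copy of the global sum.
        List<double[]> result = new ArrayList<>();
        for (int p = 0; p < partitions.size(); p++) {
            result.add(sum.clone());
        }
        return result;
    }
}
```

With two partitions holding `{1.0, 2.0}` and `{3.0, 4.0}`, the sketch returns two arrays, each equal to `{4.0, 6.0}`. Passing arrays of different lengths raises an `IllegalArgumentException`.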

[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-02 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r761677853



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegression.java
##
@@ -0,0 +1,488 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationConfig.OperatorLifeCycle;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * This class implements methods to train a logistic regression model. For details, see
+ * https://en.wikipedia.org/wiki/Logistic_regression.
+ */
+public class LogisticRegression
+        implements Estimator<LogisticRegression, LogisticRegressionModel>,
+                LogisticRegressionParams<LogisticRegression> {
+
+    private Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    private static final OutputTag MODEL_OUTPUT =

Review comment:
   Thanks for the comment here. I misunderstood the cost of creating an output tag. I have made it a local variable.








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-02 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r761636447



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
##
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.datastream;
+
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeComparator;
+import org.apache.flink.api.java.typeutils.TypeExtractor;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TimestampedCollector;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+/** Provides utility functions for {@link DataStream}. */
+public class DataStreamUtils {
+    /**
+     * Applies allReduceSum on the input data stream. The input data stream is supposed to contain
+     * one double array in each partition. The result data stream has the same parallelism as the
+     * input, where each partition contains one double array that sums all of the double arrays in
+     * the input data stream.
+     *
+     * <p>Note that we throw exception when one of the following two cases happen:
+     * <li>There exists one partition that contains more than one double array.
+     * <li>The length of the double array is not consistent among all partitions.
+     *
+     * @param input The input data stream.
+     * @return The result data stream.
+     */
+    public static DataStream<double[]> allReduceSum(DataStream<double[]> input) {
+        return AllReduceImpl.allReduceSum(input);
+    }
+
+    /**
+     * Collects distinct values in a bounded data stream. The parallelism of the output stream is 1.
+     *
+     * @param <T> The class type of the input data stream.
+     * @param input The bounded input data stream.
+     * @return The result data stream that contains all the distinct values.
+     */
+    public static <T> DataStream<T> distinct(DataStream<T> input) {
+        return input.transform(
+                        "distinctInEachPartition",
+                        input.getType(),
+                        new DistinctPartitionOperator<>())
+                .setParallelism(input.getParallelism())
+                .transform(
+                        "distinctInFinalPartition",
+                        input.getType(),
+                        new DistinctPartitionOperator<>())
+                .setParallelism(1);
+    }
+
+    /**
+     * Applies a {@link MapPartitionFunction} on a bounded data stream.
+     *
+     * @param input The input data stream.
+     * @param func The user defined mapPartition function.
+     * @param <IN> The class type of the input element.
+     * @param <OUT> The class type of output element.
+     * @return The result data stream.
+     */
+    public static <IN, OUT> DataStream<OUT> mapPartition(
+            DataStream<IN> input, MapPartitionFunction<IN, OUT> func) {
+        TypeInformation<OUT> resultType =
+                TypeExtractor.getMapPartitionReturnTypes(func, input.getType(), null, true);
+        return input.transform("mapPartition", resultType, new MapPartitionOperator<>(func))
+                .setParallelism(input.getParallelism());
+    }
+
+/**
+ * Sorts the elements in each partition of the input bounded data stream.
+ *
+ * @param input The input data stream.
+ * @param comparator The comparator used to sort the elements.
+ * @param  The class type of input element.
+ * @return The sorted data stream.
+ */
+p
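
The two-stage shape of `distinct` in the quoted diff (deduplicate within each partition, then deduplicate once more in a single final partition) can be sketched in plain Java. This is only an illustration under stated assumptions: partitions are modeled as in-memory lists, `TwoStageDistinctSketch` and its methods are hypothetical names, and a `LinkedHashSet` is used to keep the output order deterministic, which the operator in the pull request does not necessarily do.

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;

/** Hypothetical, Flink-free illustration of a two-stage distinct. */
class TwoStageDistinctSketch {

    /** Removes duplicates from one partition, keeping first-seen order. */
    static <T> List<T> distinctInPartition(List<T> partition) {
        return new ArrayList<>(new LinkedHashSet<>(partition));
    }

    /**
     * Stage 1 deduplicates each partition on its own, which shrinks the data
     * before it is gathered. Stage 2 deduplicates the gathered values in one
     * place, which corresponds to the final operator with parallelism 1.
     */
    static <T> List<T> distinct(List<List<T>> partitions) {
        List<T> gathered = new ArrayList<>();
        for (List<T> partition : partitions) {
            gathered.addAll(distinctInPartition(partition));
        }
        return distinctInPartition(gathered);
    }
}
```

The first stage exists only to reduce the amount of data sent to the single final partition; the second stage alone would already produce the correct result.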

[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-02 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r761633170



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
##
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.datastream;
+
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeComparator;
+import org.apache.flink.api.java.typeutils.TypeExtractor;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TimestampedCollector;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+/** Provides utility functions for {@link DataStream}. */
+public class DataStreamUtils {
+/**
+ * Applies allReduceSum on the input data stream. The input data stream is supposed to contain
+ * one double array in each partition. The result data stream has the same parallelism as the
+ * input, where each partition contains one double array that sums all of the double arrays in
+ * the input data stream.
+ *
+ * Note that we throw exception when one of the following two cases happen:
+ * There exists one partition that contains more than one double array.
+ * The length of the double array is not consistent among all partitions.
+ *
+ * @param input The input data stream.
+ * @return The result data stream.
+ */
+public static DataStream<double[]> allReduceSum(DataStream<double[]> input) {
+return AllReduceImpl.allReduceSum(input);
+}
+
+/**
+ * Collects distinct values in a bounded data stream. The parallelism of the output stream is 1.
+ *
+ * @param <T> The class type of the input data stream.
+ * @param input The bounded input data stream.
+ * @return The result data stream that contains all the distinct values.
+ */
+public static <T> DataStream<T> distinct(DataStream<T> input) {
+return input.transform(
+"distinctInEachPartition",
+input.getType(),
+new DistinctPartitionOperator<>())
+.setParallelism(input.getParallelism())
+.transform(
+"distinctInFinalPartition",
+input.getType(),
+new DistinctPartitionOperator<>())
+.setParallelism(1);
+}
+
+/**
+ * Applies a {@link MapPartitionFunction} on a bounded data stream.
+ *
+ * @param input The input data stream.
+ * @param func The user defined mapPartition function.
+ * @param <IN> The class type of the input element.
+ * @param <OUT> The class type of output element.
+ * @return The result data stream.
+ */
+public static <IN, OUT> DataStream<OUT> mapPartition(
+DataStream<IN> input, MapPartitionFunction<IN, OUT> func) {
+TypeInformation<OUT> resultType =
+TypeExtractor.getMapPartitionReturnTypes(func, input.getType(), null, true);
+return input.transform("mapPartition", resultType, new MapPartitionOperator<>(func))
+.setParallelism(input.getParallelism());
+}
+
+/**
+ * Sorts the elements in each partition of the input bounded data stream.
+ *
+ * @param input The input data stream.
+ * @param comparator The comparator used to sort the elements.
+ * @param <IN> The class type of input element.
+ * @return The sorted data stream.
+ */
+p
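
The contract described in the `allReduceSum` Javadoc above can be sketched in plain Java, without Flink. This is only an illustration of the documented semantics (every partition ends up with the element-wise sum, and inconsistent lengths are rejected); the class name `AllReduceSumSketch` and the list-of-arrays representation of partitions are assumptions made for the example, not the `AllReduceImpl` code from the pull request. Because the sketch models exactly one array per partition, the "more than one array in a partition" error case is not represented.

```java
import java.util.ArrayList;
import java.util.List;

/** Plain-Java illustration of the all-reduce-sum contract; not the Flink ML implementation. */
final class AllReduceSumSketch {

    /**
     * Takes one array per partition and returns one array per partition,
     * each holding the element-wise sum of all input arrays.
     */
    static List<double[]> allReduceSum(List<double[]> perPartition) {
        if (perPartition.isEmpty()) {
            return new ArrayList<>();
        }
        int length = perPartition.get(0).length;
        double[] sum = new double[length];
        for (double[] values : perPartition) {
            // Mirrors the documented failure case: lengths must agree across partitions.
            if (values.length != length) {
                throw new IllegalArgumentException("Array lengths differ across partitions");
            }
            for (int i = 0; i < length; i++) {
                sum[i] += values[i];
            }
        }
        List<double[]> result = new ArrayList<>(perPartition.size());
        for (int p = 0; p < perPartition.size(); p++) {
            result.add(sum.clone()); // each partition receives its own copy of the total
        }
        return result;
    }
}
```

The output keeps the input's "parallelism" (one array per partition), which is the property that distinguishes an all-reduce from a plain reduce to a single result.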

[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-02 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r760945867



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModel.java
##
@@ -0,0 +1,189 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** This class implements {@link Model} for {@link LogisticRegression}. */
+public class LogisticRegressionModel
+implements Model<LogisticRegressionModel>,
+LogisticRegressionModelParams<LogisticRegressionModel> {
+
+private Map<Param<?>, Object> paramMap = new HashMap<>();
+
+private Table model;
+
+public LogisticRegressionModel() {
+ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+}
+
+@Override
+public Map<Param<?>, Object> getParamMap() {
+return paramMap;
+}
+
+@Override
+public void save(String path) throws IOException {
+ReadWriteUtils.saveMetadata(this, path);
+ReadWriteUtils.saveModelData(
+LogisticRegressionModelData.getModelDataStream(model),
+path,
+LogisticRegressionModelData.getModelDataEncoder());
+}
+
+public static LogisticRegressionModel load(StreamExecutionEnvironment env, String path)
+throws IOException {
+LogisticRegressionModel model = ReadWriteUtils.loadStageParam(path);
+Table modelData =
+ReadWriteUtils.loadModelData(
+env, path, LogisticRegressionModelData.getModelDataDecoder());
+return model.setModelData(modelData);
+}
+
+@Override
+public LogisticRegressionModel setModelData(Table... inputs) {
+model = inputs[0];
+return this;
+}
+
+@Override
+public Table[] getModelData() {
+return new Table[] {model};
+}
+
+@Override
+@SuppressWarnings("unchecked")
+public Table[] transform(Table... inputs) {
+Preconditions.checkArgument(inputs.length == 1);
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+DataStream data = tEnv.toDataStream(inputs[0]);
+final String broadcastModelKey = "broadcastModel";
+DataStream modelData =
+LogisticRegressionModelData.getModelDataStream(model);
+DataStream predictResult =
+BroadcastUtils.withBroadcastStream(
+Collections.singletonList(data),
+Collections.singletonMap(broadcastModelKey, modelData),
+inputList -> {
+DataStream inputData = inputList.get(0);
+return inputData.transform(
+"doPredic

[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-02 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r760959166



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegression.java
##
@@ -0,0 +1,488 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationConfig.OperatorLifeCycle;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * This class implements methods to train a logistic regression model. For details, see
+ * https://en.wikipedia.org/wiki/Logistic_regression.
+ */
+public class LogisticRegression
+implements Estimator<LogisticRegression, LogisticRegressionModel>,
+LogisticRegressionParams<LogisticRegression> {
+
+private Map<Param<?>, Object> paramMap = new HashMap<>();
+
+private static final OutputTag MODEL_OUTPUT =

Review comment:
   I agree with you on the general solution.
   
   But `OutputTag` is kept as a global static variable for performance reasons, 
because creating type information is expensive. (From @gaoyunhaii 
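
The reasoning in this comment (build the expensive object once and share it through a static field) can be illustrated with a small plain-Java sketch. `Tag` is a hypothetical stand-in whose constructor merely counts invocations; it is not Flink's `OutputTag`, and the sketch makes no claim about how costly the real constructor is.

```java
import java.util.concurrent.atomic.AtomicInteger;

/** Illustrates reusing one expensive-to-build tag through a static final field. */
final class StaticTagSketch {

    /** Hypothetical stand-in for a tag whose constructor performs costly type analysis. */
    static final class Tag {
        static final AtomicInteger CREATED = new AtomicInteger();
        final String id;

        Tag(String id) {
            this.id = id;
            CREATED.incrementAndGet(); // counts how often the "expensive" setup runs
        }
    }

    // Built once when the class is initialized and shared by every caller.
    private static final Tag MODEL_OUTPUT = new Tag("MODEL_OUTPUT");

    /** Creates a fresh tag on every call, paying the construction cost each time. */
    static Tag perCallTag() {
        return new Tag("MODEL_OUTPUT");
    }

    /** Returns the shared tag without constructing anything. */
    static Tag sharedTag() {
        return MODEL_OUTPUT;
    }
}
```

Calling `sharedTag()` repeatedly leaves `Tag.CREATED` unchanged, while each `perCallTag()` call increments it, which is the cost the static field avoids.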








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-02 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r760928810



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMultiClass.java
##
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+import org.apache.flink.ml.param.WithParams;
+
+/**
+ * Interface for the shared multi-class param.
+ *
+ * Supported options:
+ * auto: select the version based on the number of classes: If numClasses is one or two, set to

Review comment:
   In the long run, we could add `ClassificationModel` and `Classifier`. 
But we should probably do it later, since it would not introduce a lot of 
refactoring here.
   
   Deciding `auto` has nothing to do with `ClassificationModel::numClasses`: 
it is computed from the histogram of labels. Please feel free to check out the 
code at `LogisticRegression#L134` and Spark's implementation. [1]
   
   [1] 
https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala#L519
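
As a rough illustration of resolving `auto` from a histogram of labels, the following plain-Java sketch counts label occurrences and picks a mode from the number of distinct labels. The class and method names are hypothetical, and the rule used (one or two classes gives `binomial`, otherwise `multinomial`) follows Spark's documented behavior for its `family` parameter; it is assumed here rather than taken from the code at `LogisticRegression#L134`.

```java
import java.util.HashMap;
import java.util.Map;

/** Sketch of resolving the "auto" multi-class option from observed labels. */
final class MultiClassAutoSketch {

    /** Builds a histogram: label value -> number of occurrences. */
    static Map<Double, Long> labelHistogram(double[] labels) {
        Map<Double, Long> histogram = new HashMap<>();
        for (double label : labels) {
            histogram.merge(label, 1L, Long::sum);
        }
        return histogram;
    }

    /** Assumed rule: one or two distinct labels means binomial, more means multinomial. */
    static String resolveAuto(double[] labels) {
        int numClasses = labelHistogram(labels).size();
        if (numClasses == 0) {
            throw new IllegalArgumentException("No labels to infer the number of classes from");
        }
        return numClasses <= 2 ? "binomial" : "multinomial";
    }
}
```

The point of the comment is visible in the signature: the decision needs only the training labels, not a `numClasses` accessor on a model class.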
   








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-02 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r760918165



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
##
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.datastream;
+
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeComparator;
+import org.apache.flink.api.java.typeutils.TypeExtractor;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TimestampedCollector;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+/** Provides utility functions for {@link DataStream}. */
+public class DataStreamUtils {
+/**
+ * Applies allReduceSum on the input data stream. The input data stream is supposed to contain
+ * one double array in each partition. The result data stream has the same parallelism as the
+ * input, where each partition contains one double array that sums all of the double arrays in
+ * the input data stream.
+ *
+ * Note that we throw exception when one of the following two cases happen:
+ * There exists one partition that contains more than one double array.
+ * The length of the double array is not consistent among all partitions.
+ *
+ * @param input The input data stream.
+ * @return The result data stream.
+ */
+public static DataStream<double[]> allReduceSum(DataStream<double[]> input) {
+return AllReduceImpl.allReduceSum(input);
+}
+
+/**
+ * Collects distinct values in a bounded data stream. The parallelism of the output stream is 1.
+ *
+ * @param <T> The class type of the input data stream.
+ * @param input The bounded input data stream.
+ * @return The result data stream that contains all the distinct values.
+ */
+public static <T> DataStream<T> distinct(DataStream<T> input) {
+return input.transform(
+"distinctInEachPartition",
+input.getType(),
+new DistinctPartitionOperator<>())
+.setParallelism(input.getParallelism())
+.transform(
+"distinctInFinalPartition",
+input.getType(),
+new DistinctPartitionOperator<>())
+.setParallelism(1);
+}
+
+/**
+ * Applies a {@link MapPartitionFunction} on a bounded data stream.
+ *
+ * @param input The input data stream.
+ * @param func The user defined mapPartition function.
+ * @param <IN> The class type of the input element.
+ * @param <OUT> The class type of output element.
+ * @return The result data stream.
+ */
+public static <IN, OUT> DataStream<OUT> mapPartition(
+DataStream<IN> input, MapPartitionFunction<IN, OUT> func) {
+TypeInformation<OUT> resultType =
+TypeExtractor.getMapPartitionReturnTypes(func, input.getType(), null, true);
+return input.transform("mapPartition", resultType, new MapPartitionOperator<>(func))
+.setParallelism(input.getParallelism());
+}
+
+/**
+ * Sorts the elements in each partition of the input bounded data stream.
+ *
+ * @param input The input data stream.
+ * @param comparator The comparator used to sort the elements.
+ * @param <IN> The class type of input element.
+ * @return The sorted data stream.
+ */
+p
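
The `distinct` method quoted above chains two operators: one that deduplicates within each partition at the input parallelism, and one that deduplicates the partial results at parallelism 1. A plain-Java sketch of that two-stage idea follows; the class name and the list-of-lists representation of partitions are assumptions for illustration, and the real `DistinctPartitionOperator` is not shown in the quoted diff.

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

/** Sketch of two-stage distinct: deduplicate per partition, then once more globally. */
final class TwoStageDistinctSketch {

    /** Stage applied to a single partition: keeps the first occurrence of each value. */
    static <T> Set<T> distinctInPartition(Iterable<T> partition) {
        Set<T> seen = new LinkedHashSet<>();
        for (T value : partition) {
            seen.add(value);
        }
        return seen;
    }

    /** Runs the per-partition stage on every partition, then the same stage on the union. */
    static <T> Set<T> distinct(List<List<T>> partitions) {
        List<T> partialResults = new ArrayList<>();
        for (List<T> partition : partitions) {
            partialResults.addAll(distinctInPartition(partition));
        }
        // Final stage corresponds to the operator running with parallelism 1.
        return distinctInPartition(partialResults);
    }
}
```

The first stage shrinks the data that has to reach the single final task; the second stage is still needed because the same value may appear in several partitions.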

[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-02 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r760959166



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegression.java
##
@@ -0,0 +1,488 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationConfig.OperatorLifeCycle;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * This class implements methods to train a logistic regression model. For details, see
+ * https://en.wikipedia.org/wiki/Logistic_regression.
+ */
+public class LogisticRegression
+implements Estimator<LogisticRegression, LogisticRegressionModel>,
+LogisticRegressionParams<LogisticRegression> {
+
+private Map<Param<?>, Object> paramMap = new HashMap<>();
+
+private static final OutputTag MODEL_OUTPUT =

Review comment:
   I agree with you on the general solution.
   
   But `OutputTag` is kept as a global static variable for performance 
reasons. 








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-02 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r760945867



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModel.java
##
@@ -0,0 +1,189 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** This class implements {@link Model} for {@link LogisticRegression}. */
+public class LogisticRegressionModel
+implements Model,
+LogisticRegressionModelParams {
+
+private Map, Object> paramMap = new HashMap<>();
+
+private Table model;
+
+public LogisticRegressionModel() {
+ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+}
+
+@Override
+public Map, Object> getParamMap() {
+return paramMap;
+}
+
+@Override
+public void save(String path) throws IOException {
+ReadWriteUtils.saveMetadata(this, path);
+ReadWriteUtils.saveModelData(
+LogisticRegressionModelData.getModelDataStream(model),
+path,
+LogisticRegressionModelData.getModelDataEncoder());
+}
+
+public static LogisticRegressionModel load(StreamExecutionEnvironment env, 
String path)
+throws IOException {
+LogisticRegressionModel model = ReadWriteUtils.loadStageParam(path);
+Table modelData =
+ReadWriteUtils.loadModelData(
+env, path, 
LogisticRegressionModelData.getModelDataDecoder());
+return model.setModelData(modelData);
+}
+
+@Override
+public LogisticRegressionModel setModelData(Table... inputs) {
+model = inputs[0];
+return this;
+}
+
+@Override
+public Table[] getModelData() {
+return new Table[] {model};
+}
+
+@Override
+@SuppressWarnings("unchecked")
+public Table[] transform(Table... inputs) {
+Preconditions.checkArgument(inputs.length == 1);
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+DataStream<Row> data = tEnv.toDataStream(inputs[0]);
+final String broadcastModelKey = "broadcastModel";
+DataStream modelData =
+LogisticRegressionModelData.getModelDataStream(model);
+DataStream predictResult =
+BroadcastUtils.withBroadcastStream(
+Collections.singletonList(data),
+Collections.singletonMap(broadcastModelKey, modelData),
+inputList -> {
+DataStream inputData = inputList.get(0);
+return inputData.transform(
+"doPredic

[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-02 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r760945867



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModel.java
##
@@ -0,0 +1,189 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** This class implements {@link Model} for {@link LogisticRegression}. */
+public class LogisticRegressionModel
+implements Model<LogisticRegressionModel>,
+LogisticRegressionModelParams<LogisticRegressionModel> {
+
+private Map<Param<?>, Object> paramMap = new HashMap<>();
+
+private Table model;
+
+public LogisticRegressionModel() {
+ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+}
+
+@Override
+public Map<Param<?>, Object> getParamMap() {
+return paramMap;
+}
+
+@Override
+public void save(String path) throws IOException {
+ReadWriteUtils.saveMetadata(this, path);
+ReadWriteUtils.saveModelData(
+LogisticRegressionModelData.getModelDataStream(model),
+path,
+LogisticRegressionModelData.getModelDataEncoder());
+}
+
+public static LogisticRegressionModel load(StreamExecutionEnvironment env, String path)
+throws IOException {
+LogisticRegressionModel model = ReadWriteUtils.loadStageParam(path);
+Table modelData =
+ReadWriteUtils.loadModelData(
+env, path, LogisticRegressionModelData.getModelDataDecoder());
+return model.setModelData(modelData);
+}
+
+@Override
+public LogisticRegressionModel setModelData(Table... inputs) {
+model = inputs[0];
+return this;
+}
+
+@Override
+public Table[] getModelData() {
+return new Table[] {model};
+}
+
+@Override
+@SuppressWarnings("unchecked")
+public Table[] transform(Table... inputs) {
+Preconditions.checkArgument(inputs.length == 1);
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+DataStream<Row> data = tEnv.toDataStream(inputs[0]);
+final String broadcastModelKey = "broadcastModel";
+DataStream modelData =
+LogisticRegressionModelData.getModelDataStream(model);
+DataStream predictResult =
+BroadcastUtils.withBroadcastStream(
+Collections.singletonList(data),
+Collections.singletonMap(broadcastModelKey, modelData),
+inputList -> {
+DataStream inputData = inputList.get(0);
+return inputData.transform(
+"doPredic

[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-02 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r760918165



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
##
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.datastream;
+
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeComparator;
+import org.apache.flink.api.java.typeutils.TypeExtractor;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TimestampedCollector;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+/** Provides utility functions for {@link DataStream}. */
+public class DataStreamUtils {
+/**
+ * Applies allReduceSum on the input data stream. The input data stream is supposed to contain
+ * one double array in each partition. The result data stream has the same parallelism as the
+ * input, where each partition contains one double array that sums all of the double arrays in
+ * the input data stream.
+ *
+ * Note that we throw exception when one of the following two cases happen:
+ * There exists one partition that contains more than one double array.
+ * The length of the double array is not consistent among all partitions.
+ *
+ * @param input The input data stream.
+ * @return The result data stream.
+ */
+public static DataStream<double[]> allReduceSum(DataStream<double[]> input) {
+return AllReduceImpl.allReduceSum(input);
+}
+
+/**
+ * Collects distinct values in a bounded data stream. The parallelism of the output stream is 1.
+ *
+ * @param <T> The class type of the input data stream.
+ * @param input The bounded input data stream.
+ * @return The result data stream that contains all the distinct values.
+ */
+public static <T> DataStream<T> distinct(DataStream<T> input) {
+return input.transform(
+"distinctInEachPartition",
+input.getType(),
+new DistinctPartitionOperator<>())
+.setParallelism(input.getParallelism())
+.transform(
+"distinctInFinalPartition",
+input.getType(),
+new DistinctPartitionOperator<>())
+.setParallelism(1);
+}
+
+/**
+ * Applies a {@link MapPartitionFunction} on a bounded data stream.
+ *
+ * @param input The input data stream.
+ * @param func The user defined mapPartition function.
+ * @param <IN> The class type of the input element.
+ * @param <OUT> The class type of output element.
+ * @return The result data stream.
+ */
+public static <IN, OUT> DataStream<OUT> mapPartition(
+DataStream<IN> input, MapPartitionFunction<IN, OUT> func) {
+TypeInformation<OUT> resultType =
+TypeExtractor.getMapPartitionReturnTypes(func, input.getType(), null, true);
+return input.transform("mapPartition", resultType, new MapPartitionOperator<>(func))
+.setParallelism(input.getParallelism());
+}
+
+/**
+ * Sorts the elements in each partition of the input bounded data stream.
+ *
+ * @param input The input data stream.
+ * @param comparator The comparator used to sort the elements.
+ * @param  The class type of input element.
+ * @return The sorted data stream.
+ */
+p

[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-02 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r760931345



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/common/iteration/TerminateOnMaxIter.java
##
@@ -27,16 +27,17 @@
  * threshold.
  *
  * When the output of this FlatMapFunction is used as the termination criteria of an iteration
- * body, the iteration will be executed for at most the given `numRounds` rounds.
+ * body, the iteration will be executed for at most the given `maxIter` rounds.
  *
  * @param <T> The class type of the input element.
  */
-public class TerminateOnMaxIterationNum<T>
+public class TerminateOnMaxIter<T>

Review comment:
   I would stick with `TerminateOnMaxIter` because
`RoundBasedTerminationCriteria` is not in the source code.
   
   Moreover, it is not consistent with `TerminateOnMaxIter` in the
following aspects:
   (1) It does not accept a generic type as input.
   (2) It does not run the same number of iterations as `TerminateOnMaxIter`.
   (3) It does not specify a tolerance-based termination criterion.
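
   To make the "at most `maxIter` rounds" contract from the quoted Javadoc concrete, here is a minimal sketch in plain Java. It does not use the Flink iteration API; `MaxIterSketch` and its methods are hypothetical names, and the rule "emit a continue signal after epoch e only while e + 1 < maxIter" is an assumption about how such a criterion behaves, not the code in this pull request. It also assumes `maxIter >= 1`.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical stand-in for a max-iteration termination criterion (not the Flink ML class). */
final class MaxIterSketch {

    /**
     * Returns the epochs after which a "continue" signal is emitted.
     * Assumed rule: emit after epoch e only while e + 1 < maxIter.
     */
    static List<Integer> epochsThatContinue(int maxIter) {
        List<Integer> emitted = new ArrayList<>();
        int epoch = 0;
        while (true) {
            // The iteration body for round `epoch` would run here.
            boolean keepGoing = epoch + 1 < maxIter;
            if (!keepGoing) {
                break; // No output from the criterion, so the iteration stops.
            }
            emitted.add(epoch);
            epoch++;
        }
        return emitted;
    }

    /** Number of rounds the body runs: one per emitted signal, plus the final round. */
    static int roundsExecuted(int maxIter) {
        return epochsThatContinue(maxIter).size() + 1;
    }

    public static void main(String[] args) {
        System.out.println(epochsThatContinue(3)); // signals after epochs 0 and 1
        System.out.println(roundsExecuted(3));     // three rounds in total
    }
}
```

   Under these assumptions the body never runs more than `maxIter` times, which is the property a replacement criterion would have to preserve to be interchangeable.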




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-02 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r760928810



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMultiClass.java
##
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+import org.apache.flink.ml.param.WithParams;
+
+/**
+ * Interface for the shared multi-class param.
+ *
+ * Supported options:
+ * auto: select the version based on the number of classes: If numClasses is one or two, set to

Review comment:
   In the long run, we could add `ClassificationModel` and `Classifier`.
But we should probably do it later since it would not introduce a lot of
refactoring here.
   
   Deciding `auto` has nothing to do with `ClassificationModel::numClasses`;
it is computed from the histogram of labels. Please feel free to check out the
code at `LogisticRegression#L134` and Spark's implementation. [1]
   
   [1] 
https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala#L519
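
   As an illustration of resolving `auto` from a label histogram, here is a minimal sketch in plain Java. `MultiClassAutoSketch` and its methods are hypothetical names. Two things are assumptions: the number of classes is taken to be the number of distinct label values, and the option names `binomial` and `multinomial` are borrowed from Spark's parameter documentation, since the quoted Javadoc is cut off before naming them.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical helper that resolves the `auto` option from observed labels. */
final class MultiClassAutoSketch {

    /** Counts how often each label value occurs. */
    static Map<Double, Long> labelHistogram(double[] labels) {
        Map<Double, Long> histogram = new HashMap<>();
        for (double label : labels) {
            histogram.merge(label, 1L, Long::sum);
        }
        return histogram;
    }

    /** Assumed rule: one or two distinct labels means binomial, otherwise multinomial. */
    static String resolveAuto(double[] labels) {
        int numClasses = labelHistogram(labels).size();
        return numClasses <= 2 ? "binomial" : "multinomial";
    }

    public static void main(String[] args) {
        System.out.println(resolveAuto(new double[] {0.0, 1.0, 1.0, 0.0}));
        System.out.println(resolveAuto(new double[] {0.0, 1.0, 2.0}));
    }
}
```

   The point of the sketch is that the decision depends only on the training labels, so it needs no `numClasses` accessor on a model class.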
   








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-02 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r760913965



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
##
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.datastream;
+
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeComparator;
+import org.apache.flink.api.java.typeutils.TypeExtractor;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TimestampedCollector;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+/** Provides utility functions for {@link DataStream}. */
+public class DataStreamUtils {
+/**
+ * Applies allReduceSum on the input data stream. The input data stream is supposed to contain
+ * one double array in each partition. The result data stream has the same parallelism as the
+ * input, where each partition contains one double array that sums all of the double arrays in
+ * the input data stream.
+ *
+ * Note that we throw exception when one of the following two cases happen:
+ * There exists one partition that contains more than one double array.
+ * The length of the double array is not consistent among all partitions.
+ *
+ * @param input The input data stream.
+ * @return The result data stream.
+ */
+public static DataStream<double[]> allReduceSum(DataStream<double[]> input) {
+return AllReduceImpl.allReduceSum(input);
+}
+
+/**
+ * Collects distinct values in a bounded data stream. The parallelism of the output stream is 1.
+ *
+ * @param <T> The class type of the input data stream.
+ * @param input The bounded input data stream.
+ * @return The result data stream that contains all the distinct values.
+ */
+public static <T> DataStream<T> distinct(DataStream<T> input) {
+return input.transform(
+"distinctInEachPartition",
+input.getType(),
+new DistinctPartitionOperator<>())
+.setParallelism(input.getParallelism())
+.transform(
+"distinctInFinalPartition",
+input.getType(),
+new DistinctPartitionOperator<>())
+.setParallelism(1);
+}
+
+/**
+ * Applies a {@link MapPartitionFunction} on a bounded data stream.
+ *
+ * @param input The input data stream.
+ * @param func The user defined mapPartition function.
+ * @param <IN> The class type of the input element.
+ * @param <OUT> The class type of output element.
+ * @return The result data stream.
+ */
+public static <IN, OUT> DataStream<OUT> mapPartition(
+DataStream<IN> input, MapPartitionFunction<IN, OUT> func) {
+TypeInformation<OUT> resultType =
+TypeExtractor.getMapPartitionReturnTypes(func, input.getType(), null, true);
+return input.transform("mapPartition", resultType, new MapPartitionOperator<>(func))
+.setParallelism(input.getParallelism());
+}
+
+/**
+ * Sorts the elements in each partition of the input bounded data stream.
+ *
+ * @param input The input data stream.
+ * @param comparator The comparator used to sort the elements.
+ * @param  The class type of input element.
+ * @return The sorted data stream.
+ */
+p

[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-02 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r760908183



##
File path: 
flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
##
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.datastream;
+
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.common.typeutils.base.LongComparator;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.NumberSequenceIterator;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+
+/** Tests the {@link DataStreamUtils}. */
+public class DataStreamUtilsTest {
+private StreamExecutionEnvironment env;
+
+@Before
+public void before() {
+Configuration config = new Configuration();
+config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+env.setParallelism(4);
+env.enableCheckpointing(100);
+env.setRestartStrategy(RestartStrategies.noRestart());
+}
+
+@Test
+@SuppressWarnings("unchecked")
+public void testDistinct() throws Exception {
+DataStream<Long> dataStream =
+env.fromParallelCollection(new NumberSequenceIterator(1L, 10L), Types.LONG)

Review comment:
   I have removed `distinct()` and `sortPartition()` for now.
   
   BTW, the input is not unique, as I divide by two in the next line of code.
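
   A small plain-Java sketch of why the test input contains duplicates. The mapping itself is not visible in the quoted diff, so integer division of each `long` by two is an assumption here, and `DistinctInputSketch` is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

/** Hypothetical illustration of the test data: the sequence 1..10 after dividing by two. */
final class DistinctInputSketch {

    /** Maps every value in [from, to] to value / 2 using long integer division. */
    static List<Long> halved(long from, long to) {
        List<Long> values = new ArrayList<>();
        for (long i = from; i <= to; i++) {
            values.add(i / 2);
        }
        return values;
    }

    /** Removes duplicates while keeping first-seen order. */
    static Set<Long> distinct(List<Long> values) {
        return new LinkedHashSet<>(values);
    }

    public static void main(String[] args) {
        List<Long> input = halved(1L, 10L);
        System.out.println(input);           // ten values, several repeated
        System.out.println(distinct(input)); // six distinct values remain
    }
}
```

   Ten inputs collapse to six distinct values under this assumption, so a distinct operator has real work to do on such data.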








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-02 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r760907345



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
##
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.datastream;
+
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeComparator;
+import org.apache.flink.api.java.typeutils.TypeExtractor;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TimestampedCollector;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+/** Provides utility functions for {@link DataStream}. */
+public class DataStreamUtils {
+/**
+ * Applies allReduceSum on the input data stream. The input data stream is supposed to contain
+ * one double array in each partition. The result data stream has the same parallelism as the
+ * input, where each partition contains one double array that sums all of the double arrays in
+ * the input data stream.
+ *
+ * Note that we throw exception when one of the following two cases happen:
+ * There exists one partition that contains more than one double array.
+ * The length of the double array is not consistent among all partitions.
+ *
+ * @param input The input data stream.
+ * @return The result data stream.
+ */
+public static DataStream<double[]> allReduceSum(DataStream<double[]> input) {
+return AllReduceImpl.allReduceSum(input);
+}
+
+/**
+ * Collects distinct values in a bounded data stream. The parallelism of the output stream is 1.
+ *
+ * @param <T> The class type of the input data stream.
+ * @param input The bounded input data stream.
+ * @return The result data stream that contains all the distinct values.
+ */
+public static <T> DataStream<T> distinct(DataStream<T> input) {

Review comment:
   I have removed `distinct()` and `sortPartition()` for now.








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-02 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r760903491



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/common/iteration/TerminateOnMaxIterOrTol.java
##
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.iteration;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.util.Collector;
+
+/**
+ * A FlatMapFunction that emits values iff the iteration's epochWatermark does not exceed a certain
+ * threshold and the loss does not exceed a certain tolerance.
+ *
+ * When the output of this FlatMapFunction is used as the termination criteria of an iteration

Review comment:
   Thanks for pointing this out. I have added `or equal to` here.
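
   To make the boundary case concrete, here is a minimal, self-contained sketch of such a termination check. It is not the code from this PR: the class `TerminationSketch`, the method `shouldContinue`, and the exact inequalities are assumptions for illustration, chosen so that iteration stops once the loss is less than or equal to the tolerance or the maximum number of rounds has been reached.

```java
/** Hypothetical helper for illustration only; not part of flink-ml. */
final class TerminationSketch {

    private TerminationSketch() {}

    /**
     * Returns true if another training round should run.
     *
     * Assumptions: epochWatermark is the zero-based index of the round that
     * just finished, and training stops once loss <= tol.
     */
    static boolean shouldContinue(int epochWatermark, int maxIter, double loss, double tol) {
        // (epochWatermark + 1) rounds have completed so far.
        boolean roundsLeft = (epochWatermark + 1) < maxIter;
        // Strictly greater: a loss equal to the tolerance already counts as converged.
        boolean notConverged = loss > tol;
        return roundsLeft && notConverged;
    }

    public static void main(String[] args) {
        System.out.println(shouldContinue(0, 10, 0.5, 0.1)); // true: rounds left, loss above tol
        System.out.println(shouldContinue(0, 10, 0.1, 0.1)); // false: loss equal to tol
        System.out.println(shouldContinue(9, 10, 0.5, 0.1)); // false: maxIter rounds completed
    }
}
```

   In an iteration body, a check like this would typically decide whether a record is emitted to the termination-criteria stream; an empty stream then ends the iteration.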




[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-02 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r760892058



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMultiClass.java
##
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+import org.apache.flink.ml.param.WithParams;
+
+/**
+ * Interface for the shared multi-class param.
+ *
+ * Supported options:
+ * auto: select the version based on the number of classes: If numClasses is one or two, set to
+ * "binomial". Otherwise, set to "multinomial".
+ * binomial: Binary logistic regression.
+ * multinomial: Multinomial logistic regression.
+ */
+public interface HasMultiClass extends WithParams {
+Param MULTI_CLASS =
+new StringParam(
+"multiClass",
+"Type of classification.",

Review comment:
   Thanks. I have changed the Javadoc here to `Classification type` and 
`HasDistanceMeasure`'s doc to `Distance measure`, to be consistent with the 
descriptions of the existing params.
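
   For reference, the three options listed in the Javadoc above can be sketched as a small standalone resolver. This is only an illustration under stated assumptions: `MultiClassSketch` and `resolve` are hypothetical names, not part of flink-ml, and the rejection of a non-positive class count is my own addition.

```java
/** Hypothetical helper for illustration only; not part of flink-ml. */
final class MultiClassSketch {

    private MultiClassSketch() {}

    /**
     * Maps the "multiClass" option to the concrete variant to train.
     * "auto" picks "binomial" for one or two classes and "multinomial" otherwise.
     */
    static String resolve(String multiClass, int numClasses) {
        if (numClasses < 1) {
            // Assumption: a non-positive class count is treated as invalid input.
            throw new IllegalArgumentException("numClasses must be positive: " + numClasses);
        }
        switch (multiClass) {
            case "auto":
                return numClasses <= 2 ? "binomial" : "multinomial";
            case "binomial":
            case "multinomial":
                // Explicit choices are passed through unchanged.
                return multiClass;
            default:
                throw new IllegalArgumentException("Unsupported multiClass: " + multiClass);
        }
    }

    public static void main(String[] args) {
        System.out.println(resolve("auto", 2)); // binomial
        System.out.println(resolve("auto", 5)); // multinomial
    }
}
```

   Resolving `auto` in one place keeps the training code free of repeated class-count checks.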




[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-12-01 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r759953240



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegression.java
##
@@ -0,0 +1,653 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.base.DoubleComparator;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationConfig.OperatorLifeCycle;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * This class implements methods to train a logistic regression model. For details, see
+ * https://en.wikipedia.org/wiki/Logistic_regression.
+ */
+public class LogisticRegression
+implements Estimator,
+LogisticRegressionParams {
+
+private Map, Object> paramMap = new HashMap<>();
+
+private static final OutputTag> MODEL_OUTPUT =
+new OutputTag>("MODEL_OUTPUT") {};
+
+public LogisticRegression() {
+ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+}
+
+@Override
+public Map, Object> getParamMap() {
+return paramMap;
+}
+
+@Override
+public void save(String path) throws IOException {
+ReadWriteUtils.saveMetadata(this, path);
+}
+
+public static LogisticRegression load(StreamExecutionEnvironment env, String path)
+throws IOException {
+return ReadWriteUtils.load

[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-30 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r759828179



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegression.java
##
@@ -0,0 +1,653 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.base.DoubleComparator;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationConfig.OperatorLifeCycle;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * This class implements methods to train a logistic regression model. For details, see
+ * https://en.wikipedia.org/wiki/Logistic_regression.
+ */
+public class LogisticRegression
+implements Estimator,
+LogisticRegressionParams {
+
+private Map, Object> paramMap = new HashMap<>();
+
+private static final OutputTag> MODEL_OUTPUT =
+new OutputTag>("MODEL_OUTPUT") {};
+
+public LogisticRegression() {
+ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+}
+
+@Override
+public Map, Object> getParamMap() {
+return paramMap;
+}
+
+@Override
+public void save(String path) throws IOException {
+ReadWriteUtils.saveMetadata(this, path);
+}
+
+public static LogisticRegression load(StreamExecutionEnvironment env, String path)
+throws IOException {
+return ReadWriteUtils.load

[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-30 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r759818158



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegression.java
##
@@ -0,0 +1,653 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.base.DoubleComparator;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationConfig.OperatorLifeCycle;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * This class implements methods to train a logistic regression model. For details, see
+ * https://en.wikipedia.org/wiki/Logistic_regression.
+ */
+public class LogisticRegression
+implements Estimator,
+LogisticRegressionParams {
+
+private Map, Object> paramMap = new HashMap<>();
+
+private static final OutputTag> MODEL_OUTPUT =
+new OutputTag>("MODEL_OUTPUT") {};
+
+public LogisticRegression() {
+ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+}
+
+@Override
+public Map, Object> getParamMap() {
+return paramMap;
+}
+
+@Override
+public void save(String path) throws IOException {
+ReadWriteUtils.saveMetadata(this, path);
+}
+
+public static LogisticRegression load(StreamExecutionEnvironment env, String path)
+throws IOException {
+return ReadWriteUtils.load

[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-30 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r759817519



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegression.java
##
@@ -0,0 +1,653 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.base.DoubleComparator;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationConfig.OperatorLifeCycle;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * This class implements methods to train a logistic regression model. For details, see
+ * https://en.wikipedia.org/wiki/Logistic_regression.
+ */
+public class LogisticRegression
+implements Estimator,
+LogisticRegressionParams {
+
+private Map, Object> paramMap = new HashMap<>();
+
+private static final OutputTag> MODEL_OUTPUT =
+new OutputTag>("MODEL_OUTPUT") {};
+
+public LogisticRegression() {
+ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+}
+
+@Override
+public Map, Object> getParamMap() {
+return paramMap;
+}
+
+@Override
+public void save(String path) throws IOException {
+ReadWriteUtils.saveMetadata(this, path);
+}
+
+public static LogisticRegression load(StreamExecutionEnvironment env, String path)
+throws IOException {
+return ReadWriteUtils.load

[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-29 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r759012504



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegression.java
##
@@ -0,0 +1,653 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.base.DoubleComparator;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationConfig.OperatorLifeCycle;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * This class implements methods to train a logistic regression model. For details, see
+ * https://en.wikipedia.org/wiki/Logistic_regression.
+ */
+public class LogisticRegression
+implements Estimator<LogisticRegression, LogisticRegressionModel>,
+LogisticRegressionParams<LogisticRegression> {
+
+private Map<Param<?>, Object> paramMap = new HashMap<>();
+
+private static final OutputTag> MODEL_OUTPUT =
+new OutputTag>("MODEL_OUTPUT") {};
+
+public LogisticRegression() {
+ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+}
+
+@Override
+public Map<Param<?>, Object> getParamMap() {
+return paramMap;
+}
+
+@Override
+public void save(String path) throws IOException {
+ReadWriteUtils.saveMetadata(this, path);
+}
+
+public static LogisticRegression load(StreamExecutionEnvironment env, String path)
+throws IOException {
+return ReadWriteUtils.load

[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-29 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r759008827



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegression.java
##
@@ -0,0 +1,653 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.base.DoubleComparator;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationConfig.OperatorLifeCycle;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * This class implements methods to train a logistic regression model. For details, see
+ * https://en.wikipedia.org/wiki/Logistic_regression.
+ */
+public class LogisticRegression
+implements Estimator<LogisticRegression, LogisticRegressionModel>,
+LogisticRegressionParams<LogisticRegression> {
+
+private Map<Param<?>, Object> paramMap = new HashMap<>();
+
+private static final OutputTag> MODEL_OUTPUT =
+new OutputTag>("MODEL_OUTPUT") {};
+
+public LogisticRegression() {
+ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+}
+
+@Override
+public Map<Param<?>, Object> getParamMap() {
+return paramMap;
+}
+
+@Override
+public void save(String path) throws IOException {
+ReadWriteUtils.saveMetadata(this, path);
+}
+
+public static LogisticRegression load(StreamExecutionEnvironment env, String path)
+throws IOException {
+return ReadWriteUtils.load

[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-29 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r759007852



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegression.java
##
@@ -0,0 +1,653 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.base.DoubleComparator;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationConfig.OperatorLifeCycle;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * This class implements methods to train a logistic regression model. For details, see
+ * https://en.wikipedia.org/wiki/Logistic_regression.
+ */
+public class LogisticRegression
+implements Estimator<LogisticRegression, LogisticRegressionModel>,
+LogisticRegressionParams<LogisticRegression> {
+
+private Map<Param<?>, Object> paramMap = new HashMap<>();
+
+private static final OutputTag> MODEL_OUTPUT =
+new OutputTag>("MODEL_OUTPUT") {};
+
+public LogisticRegression() {
+ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+}
+
+@Override
+public Map<Param<?>, Object> getParamMap() {
+return paramMap;
+}
+
+@Override
+public void save(String path) throws IOException {
+ReadWriteUtils.saveMetadata(this, path);
+}
+
+public static LogisticRegression load(StreamExecutionEnvironment env, String path)
+throws IOException {
+return ReadWriteUtils.load

[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-29 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r759003838



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegression.java
##
@@ -0,0 +1,653 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.base.DoubleComparator;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationConfig.OperatorLifeCycle;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * This class implements methods to train a logistic regression model. For details, see
+ * https://en.wikipedia.org/wiki/Logistic_regression.
+ */
+public class LogisticRegression
+implements Estimator<LogisticRegression, LogisticRegressionModel>,
+LogisticRegressionParams<LogisticRegression> {
+
+private Map<Param<?>, Object> paramMap = new HashMap<>();
+
+private static final OutputTag> MODEL_OUTPUT =
+new OutputTag>("MODEL_OUTPUT") {};
+
+public LogisticRegression() {
+ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+}
+
+@Override
+public Map<Param<?>, Object> getParamMap() {
+return paramMap;
+}
+
+@Override
+public void save(String path) throws IOException {
+ReadWriteUtils.saveMetadata(this, path);
+}
+
+public static LogisticRegression load(StreamExecutionEnvironment env, String path)
+throws IOException {
+return ReadWriteUtils.load

[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-29 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r758973297



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegression.java
##
@@ -0,0 +1,653 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.base.DoubleComparator;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationConfig.OperatorLifeCycle;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * This class implements methods to train a logistic regression model. For details, see
+ * https://en.wikipedia.org/wiki/Logistic_regression.
+ */
+public class LogisticRegression
+implements Estimator<LogisticRegression, LogisticRegressionModel>,
+LogisticRegressionParams<LogisticRegression> {
+
+private Map<Param<?>, Object> paramMap = new HashMap<>();
+
+private static final OutputTag> MODEL_OUTPUT =
+new OutputTag>("MODEL_OUTPUT") {};
+
+public LogisticRegression() {
+ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+}
+
+@Override
+public Map<Param<?>, Object> getParamMap() {
+return paramMap;
+}
+
+@Override
+public void save(String path) throws IOException {
+ReadWriteUtils.saveMetadata(this, path);
+}
+
+public static LogisticRegression load(StreamExecutionEnvironment env, String path)
+throws IOException {
+return ReadWriteUtils.load

[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-29 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r758973059



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegression.java
##
@@ -0,0 +1,653 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.base.DoubleComparator;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationConfig.OperatorLifeCycle;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * This class implements methods to train a logistic regression model. For details, see
+ * https://en.wikipedia.org/wiki/Logistic_regression.
+ */
+public class LogisticRegression
+implements Estimator<LogisticRegression, LogisticRegressionModel>,
+LogisticRegressionParams<LogisticRegression> {
+
+private Map<Param<?>, Object> paramMap = new HashMap<>();
+
+private static final OutputTag> MODEL_OUTPUT =
+new OutputTag>("MODEL_OUTPUT") {};
+
+public LogisticRegression() {
+ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+}
+
+@Override
+public Map<Param<?>, Object> getParamMap() {
+return paramMap;
+}
+
+@Override
+public void save(String path) throws IOException {
+ReadWriteUtils.saveMetadata(this, path);
+}
+
+public static LogisticRegression load(StreamExecutionEnvironment env, String path)
+throws IOException {
+return ReadWriteUtils.load

[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-29 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r758968020



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegression.java
##
@@ -0,0 +1,653 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.base.DoubleComparator;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationConfig.OperatorLifeCycle;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * This class implements methods to train a logistic regression model. For details, see
+ * https://en.wikipedia.org/wiki/Logistic_regression.
+ */
+public class LogisticRegression
+implements Estimator<LogisticRegression, LogisticRegressionModel>,
+LogisticRegressionParams<LogisticRegression> {
+
+private Map<Param<?>, Object> paramMap = new HashMap<>();
+
+private static final OutputTag> MODEL_OUTPUT =
+new OutputTag>("MODEL_OUTPUT") {};
+
+public LogisticRegression() {
+ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+}
+
+@Override
+public Map<Param<?>, Object> getParamMap() {
+return paramMap;
+}
+
+@Override
+public void save(String path) throws IOException {
+ReadWriteUtils.saveMetadata(this, path);
+}
+
+public static LogisticRegression load(StreamExecutionEnvironment env, String path)
+throws IOException {
+return ReadWriteUtils.load

[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-29 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r758901177



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
##
@@ -0,0 +1,244 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.datastream;
+
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeComparator;
+import org.apache.flink.api.java.typeutils.TypeExtractor;
+import org.apache.flink.ml.common.utils.ComparatorAdapter;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TimestampedCollector;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+/** Provides utility functions for {@link DataStream}. */
+public class DataStreamUtils {
+/**
+ * Applies allReduceSum on the input data stream. The input data stream is supposed to contain
+ * one double array in each partition. The result data stream has the same parallelism as the
+ * input, where each partition contains one double array that sums all of the double arrays in
+ * the input data stream.
+ *
+ * Note that we throw exception when one of the following two cases happen:
+ * There exists one partition that contains more than one double array.
+ * The length of the double array is not consistent among all partitions.
+ *
+ * @param input The input data stream.
+ * @return The result data stream.
+ */
+public static DataStream<double[]> allReduceSum(DataStream<double[]> input) {
+return AllReduceImpl.allReduceSum(input);
+}
+
+/**
+ * Collects distinct values in a bounded data stream. The parallelism of the output stream is 1.
+ *
+ * @param <T> The class type of the input data stream.
+ * @param input The bounded input data stream.
+ * @return The result data stream that contains all the distinct values.
+ */
+public static <T> DataStream<T> distinct(DataStream<T> input) {
+return input.transform(
+"distinctInEachPartition",
+input.getType(),
+new DistinctPartitionOperator<>())
+.setParallelism(input.getParallelism())
+.transform(
+"distinctInFinalPartition",
+input.getType(),
+new DistinctPartitionOperator<>())
+.setParallelism(1);
+}
+
+/**
+ * Applies a {@link MapPartitionFunction} on a bounded data stream.
+ *
+ * @param input The input data stream.
+ * @param func The user defined mapPartition function.
+ * @param <IN> The class type of the input element.
+ * @param <OUT> The class type of output element.
+ * @return The result data stream.
+ */
+public static <IN, OUT> DataStream<OUT> mapPartition(
+DataStream<IN> input, MapPartitionFunction<IN, OUT> func) {
+TypeInformation<OUT> resultType =
+TypeExtractor.getMapPartitionReturnTypes(func, input.getType(), null, true);
+return input.transform("mapPartition", resultType, new MapPartitionOperator<>(func))
+.setParallelism(input.getParallelism());
+}
+
+/**
+ * Sorts the elements in each partition of the input bounded data stream.
+ *
+ * @param input The input data stream.
+ * @param comparator The comparator used to sort the elements.
+ * @param  The class type of input elem

[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-29 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r758899907



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegression.java
##
@@ -0,0 +1,653 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.base.DoubleComparator;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationConfig.OperatorLifeCycle;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * This class implements methods to train a logistic regression model. For details, see
+ * https://en.wikipedia.org/wiki/Logistic_regression.
+ */
+public class LogisticRegression
+implements Estimator<LogisticRegression, LogisticRegressionModel>,
+LogisticRegressionParams<LogisticRegression> {
+
+private Map<Param<?>, Object> paramMap = new HashMap<>();
+
+private static final OutputTag> MODEL_OUTPUT =
+new OutputTag>("MODEL_OUTPUT") {};
+
+public LogisticRegression() {
+ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+}
+
+@Override
+public Map<Param<?>, Object> getParamMap() {
+return paramMap;
+}
+
+@Override
+public void save(String path) throws IOException {
+ReadWriteUtils.saveMetadata(this, path);
+}
+
+public static LogisticRegression load(StreamExecutionEnvironment env, String path)
+throws IOException {
+return ReadWriteUtils.load

[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-29 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r758873522



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/common/utils/ComparatorAdapter.java
##
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.utils;
+
+import org.apache.flink.api.common.typeutils.TypeComparator;
+
+import java.util.Comparator;
+
+/** Utility class to convert {@link TypeComparator} to a {@link Comparator}. */
+public class ComparatorAdapter {
+
+/**
+ * Adapts a {@link TypeComparator} to a {@link Comparator}.
+ *
+ * @param typeComparator The input typeComparator.
+ * @param <T> The class type of the input element.
+ * @return The converted comparator.
+ */
+public static <T> Comparator<T> getComparator(TypeComparator<T> typeComparator) {

Review comment:
   Thanks for the comments. I have removed this class.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-25 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r757248612



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/datastream/SortPartitionImpl.java
##
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.datastream;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeutils.TypeComparator;
+import org.apache.flink.ml.common.utils.ComparatorAdapter;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.List;
+
+/** Applies sortPartition to a bounded data stream. */
+class SortPartitionImpl {

Review comment:
   Sure, let's follow the existing practice. 








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-25 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r757248185



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
##
@@ -146,7 +145,7 @@ public IterationBodyResult process(
 DataStream points = dataStreams.get(0);
 
 DataStream terminationCriteria =
-centroids.flatMap(new TerminateOnMaxIterationNum<>(maxIterationNum));
+centroids.map(x -> 0.).flatMap(new TerminationCriteria(maxIterationNum));

Review comment:
   Having separate classes for different termination conditions is fine with me, and it is also less likely to confuse users. We would then have the following three classes:
   - `TerminateOnMaxIterationNum`
   - `TerminateOnToleranceThreshold`
   - `TerminateOnMaxIterationNumOrToleranceThreshold`
   
   Do you think it is a viable solution?








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-25 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r757229351



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticGradient.java
##
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.common.linalg.BLAS;
+
+import java.io.Serializable;
+
+/** Utility class to compute gradient and loss for logistic loss. */

Review comment:
   I am not sure we need such a detailed explanation here, since the information is already available on the website.
   
   I just added a URL there. What do you think?








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-25 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r757227966



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModelData.java
##
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.List;
+
+/** Model data of {@link LogisticRegressionModel}. */
+public class LogisticRegressionModelData {
+
+public final double[] coefficient;
+

Review comment:
   Let's follow the common practice in Flink.








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-25 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r757227966



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModelData.java
##
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.List;
+
+/** Model data of {@link LogisticRegressionModel}. */
+public class LogisticRegressionModelData {
+
+public final double[] coefficient;
+

Review comment:
   Let's see the common practice in Flink.








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-25 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r757227812



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModel.java
##
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.core.Model;
+import org.apache.flink.ml.classification.linear.LogisticRegressionModelData.LogisticRegressionModelDataEncoder;
+import org.apache.flink.ml.classification.linear.LogisticRegressionModelData.LogisticRegressionModelDataStreamFormat;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.linalg.BLAS;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** This class implements {@link Model} for {@link LogisticRegression}. */
+public class LogisticRegressionModel
+implements Model<LogisticRegressionModel>,
+LogisticRegressionModelParams<LogisticRegressionModel> {
+
+private Map<Param<?>, Object> paramMap = new HashMap<>();
+
+private Table model;
+
+public LogisticRegressionModel() {
+ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+}
+
+@Override
+public Map<Param<?>, Object> getParamMap() {
+return paramMap;
+}
+
+@Override
+public void save(String path) throws IOException {
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) model).getTableEnvironment();
+String dataPath = ReadWriteUtils.getDataPath(path);
+FileSink<LogisticRegressionModelData> sink =
+FileSink.forRowFormat(new Path(dataPath), new LogisticRegressionModelDataEncoder())
+.withRollingPolicy(OnCheckpointRollingPolicy.build())
+.withBucketAssigner(new BasePathBucketAssigner<>())
+.build();
+ReadWriteUtils.saveMetadata(this, path);
+tEnv.toDataStream(model)
+.map(x -> (LogisticRegressionModelData) x.getField(0))
+.sinkTo(sink)
+.setParallelism(1);
+}
+
+public static LogisticRegressionModel load(StreamExecutionEnvironment env, String path)
+throws IOException {
+StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+Source source =
+FileSource.forRecordStreamFormat(
+new LogisticRegressionModelDataStreamFormat(),
+ReadWriteUtils.getDataPaths(path))
+.build();
+LogisticRegressionModel model 

[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-25 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r757203755



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
##
@@ -146,7 +145,7 @@ public IterationBodyResult process(
 DataStream points = dataStreams.get(0);
 
 DataStream terminationCriteria =
-centroids.flatMap(new TerminateOnMaxIterationNum<>(maxIterationNum));
+centroids.map(x -> 0.).flatMap(new TerminationCriteria(maxIterationNum));

Review comment:
   During ML training, a common termination criterion is to stop when the number of iterations exceeds the specified maximum, or when the loss falls below a given tolerance (the tolerance is a `double`).
   
   If we want to allow termination on a data stream with an arbitrary input type, we probably need to add a new class (i.e., `TerminateOnMaxIterationNum` here).
   
   IMO, maintaining two utility classes for termination may confuse developers. What do you think?
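
   To make the combined criterion concrete, here is a minimal sketch in plain Java, without the Flink iteration API. The class name `TerminationSketch`, the method `shouldTerminate`, and the choice to stop once the iteration count reaches `maxIterationNum` are assumptions made for illustration; they are not taken from this PR.

```java
/** Hypothetical helper, not part of Flink ML: the combined stopping rule in plain Java. */
final class TerminationSketch {

    private TerminationSketch() {}

    /**
     * Returns true when training should stop.
     *
     * @param iteration number of iterations completed so far (assumed to start at 0)
     * @param loss loss observed after the latest iteration
     * @param maxIterationNum upper bound on the number of iterations
     * @param tolerance threshold below which the loss counts as converged
     */
    static boolean shouldTerminate(
            int iteration, double loss, int maxIterationNum, double tolerance) {
        // Assumption: stop as soon as the iteration budget is used up.
        boolean reachedMaxIterations = iteration >= maxIterationNum;
        // Stop early when the loss is already below the tolerance.
        boolean converged = loss < tolerance;
        return reachedMaxIterations || converged;
    }
}
```

   An algorithm that has no loss to report can pass a constant that never falls below the tolerance, so that only the iteration bound applies.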








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-25 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r757200422



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/datastream/SortPartitionImpl.java
##
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.datastream;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeutils.TypeComparator;
+import org.apache.flink.ml.common.utils.ComparatorAdapter;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.List;
+
+/** Applies sortPartition to a bounded data stream. */
+class SortPartitionImpl {

Review comment:
   If we only had three utility functions (i.e., `SortPartitionImpl/DistinctImpl/MapPartitionImpl`), I would be fine with moving them into `DataStreamUtils`.
   
   The concern is that we will clearly have more utility functions for DataStream in the future. Adding all of them to `DataStreamUtils` may make it too complex.
   
   An alternative is to add a new package (`org.apache.flink.ml.common.datastream.impl`) and put all of these implementations under the `impl` package. What do you think?
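
   As a rough illustration of that alternative, the sketch below keeps one small implementation class per operation and a thin facade that only forwards calls. It uses `java.util.List` as a stand-in for a bounded partition instead of `DataStream`, and every class name ending in `Sketch` is hypothetical; the classes share one file here, whereas the proposal would place the implementation classes in a separate `impl` package.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.LinkedHashSet;
import java.util.List;

/** Stand-in for an implementation class such as SortPartitionImpl. */
final class SortPartitionSketch {
    private SortPartitionSketch() {}

    static <T> List<T> sort(List<T> partition, Comparator<? super T> comparator) {
        // Copy first so the caller's list is left untouched.
        List<T> sorted = new ArrayList<>(partition);
        sorted.sort(comparator);
        return sorted;
    }
}

/** Stand-in for an implementation class such as DistinctImpl. */
final class DistinctSketch {
    private DistinctSketch() {}

    static <T> List<T> distinct(List<T> partition) {
        // LinkedHashSet drops duplicates while keeping first-seen order.
        return new ArrayList<>(new LinkedHashSet<>(partition));
    }
}

/** Facade: the only class callers use; it contains no logic of its own. */
final class DataStreamUtilsSketch {
    private DataStreamUtilsSketch() {}

    static <T> List<T> sortPartition(List<T> partition, Comparator<? super T> comparator) {
        return SortPartitionSketch.sort(partition, comparator);
    }

    static <T> List<T> distinct(List<T> partition) {
        return DistinctSketch.distinct(partition);
    }
}
```

   With this layout the facade grows by one forwarding method per new utility, while the operator logic stays in its own class.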
   
   








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-25 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r757198840



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMaxIter.java
##
@@ -26,7 +26,7 @@
 /** Interface for the shared maxIter param. */
 public interface HasMaxIter<T> extends WithParams<T> {
 Param<Integer> MAX_ITER =
-new IntParam("maxIter", "Maximum number of iterations.", 20, ParamValidators.gtEq(0));
+new IntParam("maxIter", "Maximum number of iterations.", 20, ParamValidators.gt(0));

Review comment:
   I would prefer not to. Is there a real use case for training a model 
iteratively while setting `maxIter` to zero?
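
   To make the difference between the two validators concrete, here is a 
minimal sketch. `gt` and `gtEq` below are hypothetical stand-ins built on 
`java.util.function.IntPredicate`; they only mirror the names used in the diff 
and are not the Flink ML `ParamValidators` implementation.

```java
import java.util.function.IntPredicate;

/** Hypothetical stand-in for parameter validators; not the Flink ML API. */
final class MaxIterValidatorSketch {

    private MaxIterValidatorSketch() {}

    /** Accepts values strictly greater than the bound. */
    static IntPredicate gt(int lowerBound) {
        return value -> value > lowerBound;
    }

    /** Accepts values greater than or equal to the bound. */
    static IntPredicate gtEq(int lowerBound) {
        return value -> value >= lowerBound;
    }

    public static void main(String[] args) {
        System.out.println("gtEq(0) accepts 0: " + gtEq(0).test(0));
        System.out.println("gt(0) accepts 0: " + gt(0).test(0));
        System.out.println("gt(0) accepts 20: " + gt(0).test(20));
    }
}
```

   The only behavioral difference is at the boundary: with `gt(0)`, a value of 
zero is rejected when the parameter is validated, while the default of 20 is 
accepted by both.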








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-24 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r755792254



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/linear/HasPredictionDetailCol.java
##
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param.linear;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.StringParam;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared prediction detail param. */
+public interface HasPredictionDetailCol<T> extends WithParams<T> {

Review comment:
   `HasProbabilityCol` may not be a very good fit here. I checked Spark and 
reused its `HasRawPredictionCol`.
   
   What do you think?
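
   For readers unfamiliar with the distinction, here is a minimal sketch of a 
raw prediction versus a probability for binary logistic regression. It assumes 
the raw prediction is the untransformed linear score and the probability is 
its sigmoid. The class and method names are hypothetical, and this is not the 
code in this PR or in Spark.

```java
/** Hypothetical illustration of raw prediction vs. probability; not the PR's code. */
final class RawPredictionSketch {

    private RawPredictionSketch() {}

    /** Raw prediction: the linear score (dot product of coefficients and features). */
    static double rawPrediction(double[] coefficients, double[] features) {
        if (coefficients.length != features.length) {
            throw new IllegalArgumentException("Dimension mismatch");
        }
        double score = 0.0;
        for (int i = 0; i < coefficients.length; i++) {
            score += coefficients[i] * features[i];
        }
        return score;
    }

    /** Probability of the positive class: the sigmoid of the raw prediction. */
    static double probability(double rawPrediction) {
        return 1.0 / (1.0 + Math.exp(-rawPrediction));
    }

    public static void main(String[] args) {
        double raw = rawPrediction(new double[] {1.0, -2.0}, new double[] {2.0, 1.0});
        System.out.println("raw prediction: " + raw);
        System.out.println("probability: " + probability(raw));
    }
}
```

   A column named after the raw prediction makes no claim that its values lie 
in [0, 1], which is one reason it can be a more general name than a 
probability column.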








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-23 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r755088231



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasL2.java
##
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared L2 regularization param. */
+public interface HasL2<T> extends WithParams<T> {

Review comment:
   Good observation! I prefer using `HasReg` and `HasElasticNet` to 
specify the three modes, similar to Spark.
   
   I have renamed `HasL2` to `HasReg`. Given that we are not using 
`HasElasticNet` in LogisticRegression, we can add it later.
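
   For context, the three modes (L1, L2, and elastic net) can be expressed 
with two parameters. The sketch below follows the penalty documented for Spark 
ML, `reg * (elasticNet * ||w||_1 + (1 - elasticNet) / 2 * ||w||_2^2)`. The 
class and method names are hypothetical, and this is not the code in this PR.

```java
/** Hypothetical illustration of a two-parameter regularization penalty. */
final class RegularizationSketch {

    private RegularizationSketch() {}

    /**
     * Elastic-net penalty: reg * (elasticNet * L1 + (1 - elasticNet) / 2 * L2^2).
     * elasticNet = 0 gives a pure L2 penalty and elasticNet = 1 gives a pure L1 penalty.
     */
    static double penalty(double[] weights, double reg, double elasticNet) {
        double l1 = 0.0;
        double l2Squared = 0.0;
        for (double w : weights) {
            l1 += Math.abs(w);
            l2Squared += w * w;
        }
        return reg * (elasticNet * l1 + (1.0 - elasticNet) / 2.0 * l2Squared);
    }

    public static void main(String[] args) {
        double[] w = {3.0, -4.0};
        System.out.println("L2 only: " + penalty(w, 0.1, 0.0));
        System.out.println("L1 only: " + penalty(w, 0.1, 1.0));
        System.out.println("elastic net: " + penalty(w, 0.1, 0.5));
    }
}
```

   Under this parameterization, a `HasReg`-style parameter alone covers the 
L2 mode, and an additional `HasElasticNet`-style parameter is only needed once 
L1 or mixed penalties are supported.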








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-23 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r755067122



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasBatchSize.java
##
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared batchSize param. */
+public interface HasBatchSize<T> extends WithParams<T> {
+
+Param<Integer> BATCH_SIZE =
+new IntParam(
+"batchSize", "Batch size of training algorithms.", 100, ParamValidators.gt(0));

Review comment:
   Hi Dong, thanks for the comments. The doc you mentioned above is about 
setting the batch size for a single machine with multiple cores (since the 
number of cores is usually a power of 2).
   
   However, doing distributed machine learning on multiple workers is a bit 
different from doing machine learning on a single machine: we would have an 
arbitrary number of parallel instances.
   
   I have no preference on the default value of batch size. If you prefer using 
32 as the default value, please let me know.
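
   As an illustration of why a power-of-two default matters less in the 
distributed case, here is a minimal sketch that splits a global batch across an 
arbitrary number of parallel instances. It assumes the global batch is divided 
as evenly as possible among subtasks. That splitting rule and all names are 
assumptions for illustration and are not taken from this PR.

```java
/** Hypothetical illustration of splitting a global batch across parallel subtasks. */
final class BatchSplitSketch {

    private BatchSplitSketch() {}

    /** Returns the local batch size of one subtask; the sizes sum to globalBatchSize. */
    static int localBatchSize(int globalBatchSize, int parallelism, int subtaskIndex) {
        int base = globalBatchSize / parallelism;
        int remainder = globalBatchSize % parallelism;
        // The first `remainder` subtasks each take one extra record.
        return subtaskIndex < remainder ? base + 1 : base;
    }

    public static void main(String[] args) {
        int globalBatchSize = 100;
        int parallelism = 3; // deliberately not a power of two
        for (int i = 0; i < parallelism; i++) {
            System.out.println(
                    "subtask " + i + ": " + localBatchSize(globalBatchSize, parallelism, i));
        }
    }
}
```

   With a parallelism that is not a power of two, the per-instance batch size 
is generally not a power of two either, whichever global default is chosen.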








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-23 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r755067122



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasBatchSize.java
##
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared batchSize param. */
+public interface HasBatchSize<T> extends WithParams<T> {
+
+Param<Integer> BATCH_SIZE =
+new IntParam(
+"batchSize", "Batch size of training algorithms.", 100, ParamValidators.gt(0));

Review comment:
   Hi Dong, thanks for the comments. The doc you mentioned above is about 
setting the batch size for a single machine with multiple cores (since the 
number of cores is usually a power of 2).
   
   However, doing machine learning on big data engines like Spark/Flink is a 
bit different from doing machine learning on a single machine: we would have an 
arbitrary number of parallel instances.
   
   I have no preference on the default value of batch size. If you prefer using 
32 as the default value, please let me know.








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-22 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r754834016



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModel.java
##
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.core.Model;
+import org.apache.flink.ml.classification.linear.LogisticRegressionModelData.LogisticRegressionModelDataEncoder;
+import org.apache.flink.ml.classification.linear.LogisticRegressionModelData.LogisticRegressionModelDataStreamFormat;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** This class implements {@link Model} for {@link LogisticRegression}. */
+public class LogisticRegressionModel
+implements Model<LogisticRegressionModel>,
+LogisticRegressionModelParams<LogisticRegressionModel> {
+
+private Map<Param<?>, Object> paramMap;
+
+private Table model;
+
+public LogisticRegressionModel(Map<Param<?>, Object> paramMap) {
+this.paramMap = paramMap;
+ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+}
+
+public LogisticRegressionModel() {
+this(new HashMap<>());
+}
+
+@Override
+public Map<Param<?>, Object> getParamMap() {
+return paramMap;
+}
+
+@Override
+public void save(String path) throws IOException {
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) model).getTableEnvironment();
+String dataPath = ReadWriteUtils.getDataPath(path);
+FileSink sink =
+FileSink.forRowFormat(new Path(dataPath), new LogisticRegressionModelDataEncoder())
+.withRollingPolicy(OnCheckpointRollingPolicy.build())
+.withBucketAssigner(new BasePathBucketAssigner<>())
+.build();
+ReadWriteUtils.saveMetadata(this, path);
+tEnv.toDataStream(model)
+.map(x -> (LogisticRegressionModelData) x.getField(0))
+.sinkTo(sink)
+.setParallelism(1);
+}
+
+public static LogisticRegressionModel load(StreamExecutionEnvironment env, String path)
+throws IOException {
+StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+Source source =
+FileSource.forRecordStreamFormat(
+new LogisticRegressionModelDataStreamFormat(),
+ReadWriteUtils.getDataPaths(path))
+  

[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-22 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r753980720



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModel.java
##
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.core.Model;
+import org.apache.flink.ml.classification.linear.LogisticRegressionModelData.LogisticRegressionModelDataEncoder;
+import org.apache.flink.ml.classification.linear.LogisticRegressionModelData.LogisticRegressionModelDataStreamFormat;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** This class implements {@link Model} for {@link LogisticRegression}. */
+public class LogisticRegressionModel
+implements Model<LogisticRegressionModel>,
+LogisticRegressionModelParams<LogisticRegressionModel> {
+
+private Map<Param<?>, Object> paramMap;
+
+private Table model;
+
+public LogisticRegressionModel(Map<Param<?>, Object> paramMap) {
+this.paramMap = paramMap;
+ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+}
+
+public LogisticRegressionModel() {
+this(new HashMap<>());

Review comment:
   It is called in `LogisticRegressionModel(Map xxx)`
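
   The pattern referred to here is constructor chaining: the no-argument 
constructor delegates to the map-based constructor, so any initialization done 
there also applies to the default case. A minimal sketch follows. The class 
name, the `String` keys, and the `maxIter` default are hypothetical 
simplifications, and `putIfAbsent` merely stands in for a call such as 
`ParamUtils.initializeMapWithDefaultValues`.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical illustration of constructor chaining with default initialization. */
final class ConstructorChainingSketch {

    private final Map<String, Object> paramMap;

    ConstructorChainingSketch(Map<String, Object> paramMap) {
        this.paramMap = paramMap;
        // Defaults are filled in here, so every constructor that delegates
        // to this one gets them as well. Existing entries are kept.
        this.paramMap.putIfAbsent("maxIter", 20);
    }

    ConstructorChainingSketch() {
        // Delegates to the map-based constructor above.
        this(new HashMap<>());
    }

    Map<String, Object> getParamMap() {
        return paramMap;
    }

    public static void main(String[] args) {
        System.out.println(new ConstructorChainingSketch().getParamMap());
    }
}
```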








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-21 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r753651488



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegression.java
##
@@ -0,0 +1,594 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationConfig.OperatorLifeCycle;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.core.Estimator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.AllReduceUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.linalg.BLAS;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/** This class implements methods to train a logistic regression model. */
+public class LogisticRegression
+implements Estimator,
+LogisticRegressionParams {
+
+Map<Param<?>, Object> paramMap;
+
+private static final OutputTag> MODEL_OUTPUT =
+new OutputTag>("MODEL_OUTPUT") {};
+
+public LogisticRegression(Map<Param<?>, Object> paramMap) {
+this.paramMap = paramMap;
+ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+}
+
+public LogisticRegression() {
+this(new HashMap<>());
+}
+
+@Override
+public Map<Param<?>, Object> getParamMap() {
+return paramMap;
+}
+
+@Override
+public void save(String path) throws IOException {
+ReadWriteUtils.saveMetadata(this, path);
+}
+
+public static LogisticRegression load(StreamExecutionEnvironment env, String path)
+throws IOException {
+r

[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-20 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r753653282



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/linear/HasPredictionDetailCol.java
##
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param.linear;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.StringParam;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared prediction detail param. */
+public interface HasPredictionDetailCol extends WithParams {

Review comment:
   Logistic regression predicts a label for each data point. Note that
`predictionCol` reports the predicted label for each data point, while
`PredictionDetailCol` outputs the probability of assigning this data point to
each class.

   The output should be a double array. For binary classification here, it is a
double array with two elements.
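
   To make the distinction between the two columns concrete, here is a minimal
sketch for the binary case. The class and method names are hypothetical and not
part of Flink ML, and the 0/1 label encoding and the ordering of the two array
elements ({P(label = 0), P(label = 1)}) are assumptions made for illustration.

```java
/** Hypothetical helper (not a Flink ML class) illustrating the two output columns. */
final class PredictionDetailSketch {

    /**
     * Returns the "prediction detail" for one data point: a two-element array
     * {P(label = 0), P(label = 1)} under a binary logistic model.
     * The element ordering is an assumption of this sketch.
     */
    static double[] predictionDetail(double[] features, double[] coefficient) {
        double dot = 0.0;
        for (int i = 0; i < features.length; i++) {
            dot += features[i] * coefficient[i];
        }
        // Sigmoid of the linear score gives the probability of label 1.
        double probOfOne = 1.0 / (1.0 + Math.exp(-dot));
        return new double[] {1.0 - probOfOne, probOfOne};
    }

    /** Returns the "prediction": the more probable label (ties go to label 0 here). */
    static double prediction(double[] detail) {
        return detail[1] > detail[0] ? 1.0 : 0.0;
    }
}
```

   In this sketch, `prediction(...)` would fill `predictionCol`, while the array
returned by `predictionDetail(...)` would fill the detail column.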




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-20 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r753652954



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/linear/HasVectorCol.java
##
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param.linear;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared vector column param. */
+public interface HasVectorCol extends WithParams {

Review comment:
   I have removed this param and reused `HasFeaturesCol`.








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-20 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r753652891



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/linear/HasMaxIter.java
##
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param.linear;
+
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared maxIteration param. */
+public interface HasMaxIter extends WithParams {

Review comment:
   Thanks for the updates. I have moved all of them to param.**








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-20 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r753652504



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/linear/HasBatchSize.java
##
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param.linear;
+
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared batchSize param. */
+public interface HasBatchSize extends WithParams {

Review comment:
   Batch size is a standard term in machine learning. Please also refer to
Sklearn/TensorFlow/PyTorch.

   Spark used `miniBatchFraction` because it is expensive to sample a fixed
number of data points if you can only access the data once. It is a compromise
for performance.

   But in FlinkML, we have the opportunity to cache the data ourselves
efficiently. Thus we should probably go back to `batch size`.
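
   The trade-off can be sketched as follows. This is only an illustration under
stated assumptions: `MiniBatchSketch` and both method names are hypothetical,
and "cached data" is modelled as a plain in-memory `List` rather than any Flink
ML caching mechanism.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/** Hypothetical sketch (not Flink ML code) contrasting the two batching styles. */
final class MiniBatchSketch {

    /**
     * Fraction-based sampling: a single pass over the data, keeping each point
     * with probability `fraction`. Cheap for one-pass access, but the resulting
     * batch size varies from call to call.
     */
    static <T> List<T> sampleByFraction(Iterable<T> data, double fraction, Random random) {
        List<T> batch = new ArrayList<>();
        for (T point : data) {
            if (random.nextDouble() < fraction) {
                batch.add(point);
            }
        }
        return batch;
    }

    /**
     * Fixed-size batching over cached data: returns exactly `batchSize` points
     * starting at `offset` (assumed non-negative), wrapping around the end.
     * This relies on random access to the cached data.
     */
    static <T> List<T> nextBatch(List<T> cachedData, int offset, int batchSize) {
        List<T> batch = new ArrayList<>(batchSize);
        if (cachedData.isEmpty()) {
            return batch;
        }
        for (int i = 0; i < batchSize; i++) {
            batch.add(cachedData.get((offset + i) % cachedData.size()));
        }
        return batch;
    }
}
```

   With the data cached, the caller controls the batch size directly; with
fraction-based sampling it can only control the expected size.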








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-20 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r753652276



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/linear/HasLearningRate.java
##
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param.linear;
+
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared learning rate param. */
+public interface HasLearningRate extends WithParams {

Review comment:
   Learning rate is a standard term in machine learning; please refer to
Sklearn[1] or TensorFlow[2].

   Spark used `stepSize`, but that is not consistent with other ML libraries.

   [1] https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.SGDClassifier.html
   [2] https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/SGD
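
   For readers less familiar with the term, the role of the learning rate in a
plain gradient-descent update can be sketched as below. `SgdStepSketch` is a
hypothetical name and this is not the optimizer used in this pull request; it
only shows where the parameter enters the update.

```java
/** Hypothetical sketch (not Flink ML code) of a single gradient-descent update. */
final class SgdStepSketch {

    /**
     * Updates `coefficient` in place:
     * coefficient[i] <- coefficient[i] - learningRate * gradient[i].
     * The learning rate scales how far each update moves along the gradient.
     */
    static void step(double[] coefficient, double[] gradient, double learningRate) {
        for (int i = 0; i < coefficient.length; i++) {
            coefficient[i] -= learningRate * gradient[i];
        }
    }
}
```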








[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #28: [Flink-24556] Add Estimator and Transformer for logistic regression

2021-11-20 Thread GitBox


zhipeng93 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r753651901



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticGradient.java
##
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.common.linalg.BLAS;
+
+import java.io.Serializable;
+
+/** Logistic gradient. */
+public class LogisticGradient implements Serializable {
+private static final long serialVersionUID = 1178693053439209380L;
+
+/** L1 regularization term. */
+private final double l1;
+
+/** L2 regularization term. */
+private final double l2;
+
+public LogisticGradient(double l1, double l2) {
+this.l1 = l1;
+this.l2 = l2;
+}
+
+/**
+ * Computes loss and weightSum on a set of samples.
+ *
+ * @param labeledData a sample set of train data.
+ * @param coefficient model parameters.
+ * @return loss and weightSum.
+ */
+public final Tuple2 computeLoss(
+Iterable> labeledData, double[] 
coefficient) {
+double weightSum = 0.0;
+double lossSum = 0.0;
+double loss;
+for (Tuple3 dataPoint : labeledData) {
+loss = computeLoss(dataPoint, coefficient);
+lossSum += loss * dataPoint.f0;
+weightSum += dataPoint.f0;
+}
+if (Double.compare(0, Math.abs(l1)) != 0) {

Review comment:
   Thanks for the comments. L1 is removed.
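
   As a rough illustration of what a weighted loss looks like once only the L2
term remains, consider the sketch below. Everything in it is an assumption made
for illustration and not the code in this pull request: the class name, the row
layout {weight, label, x_1, ..., x_n}, the 0/1 label encoding, and the
0.5 * l2 * ||coefficient||^2 form of the penalty.

```java
/** Hypothetical sketch (not the PR's LogisticGradient) of a weighted loss with L2 only. */
final class WeightedLogisticLossSketch {

    /**
     * Computes {lossSum, weightSum} over the given rows.
     * Each row is assumed to be {weight, label, x_1, ..., x_n} with label in {0, 1}.
     */
    static double[] computeLoss(double[][] data, double[] coefficient, double l2) {
        double lossSum = 0.0;
        double weightSum = 0.0;
        for (double[] row : data) {
            double weight = row[0];
            double label = row[1];
            double dot = 0.0;
            for (int i = 0; i < coefficient.length; i++) {
                dot += row[i + 2] * coefficient[i];
            }
            // Negative log-likelihood for a 0/1 label: log(1 + e^dot) - label * dot.
            // Not overflow-safe for very large dot; acceptable for a sketch.
            double loss = Math.log1p(Math.exp(dot)) - label * dot;
            lossSum += weight * loss;
            weightSum += weight;
        }
        if (l2 != 0.0) {
            double squaredNorm = 0.0;
            for (double c : coefficient) {
                squaredNorm += c * c;
            }
            // Assumed penalty form: 0.5 * l2 * ||coefficient||^2.
            lossSum += 0.5 * l2 * squaredNorm;
        }
        return new double[] {lossSum, weightSum};
    }
}
```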







