[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r771436858

## File path: flink-ml-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java

Diff context (new file, +63 lines):

```java
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.ml.common.feature;

import org.apache.flink.ml.linalg.DenseVector;

/** Utility class to represent a data point that contains features, label and weight. */
public class LabeledPointWithWeight {

    private DenseVector features;

    private double label;

    private double weight;

    public LabeledPointWithWeight(DenseVector features, double label, double weight) {
        this.features = features;
        this.label = label;
        this.weight = weight;
    }

    public LabeledPointWithWeight() {}

    public DenseVector getFeatures() {
```

Review comment: According to the POJO rules [1], we don't have to use setters/getters for these variables. We can just make the variables public.

[1] https://nightlies.apache.org/flink/flink-docs-release-1.10/dev/types_serialization.html#rules-for-pojo-types

--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at: us...@infra.apache.org
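Flink's POJO rules (see the linked documentation) require a public class, a public no-argument constructor, and fields that are either public or reachable through getter/setter pairs. A minimal, Flink-independent sketch of the public-field style the reviewer suggests — class and field names are illustrative, and a plain `double[]` stands in for `DenseVector` to keep the sketch dependency-free:

```java
/** A data point with features, label and weight, written in the Flink-style public-field POJO form. */
public class LabeledPoint {

    // These fields' values are not expected to be updated after construction.
    // They are not final so the class still satisfies Flink's POJO rules,
    // which need a public no-argument constructor and settable fields.
    public double[] features;
    public double label;
    public double weight;

    public LabeledPoint() {}

    public LabeledPoint(double[] features, double label, double weight) {
        this.features = features;
        this.label = label;
        this.weight = weight;
    }
}
```

With public fields, Flink's POJO type extraction can serialize the class field by field, so no verbose getter/setter boilerplate is needed and the type does not fall back to generic serialization.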
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r771390383

## File path: flink-ml-core/src/main/java/org/apache/flink/ml/common/feature/LabeledPointWithWeight.java

[Diff context identical to the previous comment's, ending at `public DenseVector getFeatures() {`.]

Review comment: The addition of these `getXXX(...)` and `setXXX(...)` methods seems very verbose. Since we don't expect these `setXXX(...)` methods to be called, would it be simpler to just make these fields `public`? We can add a comment above them saying `These fields' values are not expected to be updated. They are not final in order to make this class a POJO`. Same for the other `XXXModelData` classes.
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r771379887

## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java

Diff context (new file, +108 lines; Apache license header as in the first comment):

```java
package org.apache.flink.ml.classification.logisticregression;

import org.apache.flink.api.common.serialization.Encoder;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
import org.apache.flink.core.fs.FSDataInputStream;
import org.apache.flink.core.memory.DataInputViewStreamWrapper;
import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;

import java.io.EOFException;
import java.io.IOException;
import java.io.OutputStream;

/**
 * Model data of {@link LogisticRegressionModel}.
 *
 * This class also provides methods to convert model data from Table to Datastream, and classes
 * to save/load model data.
 */
public class LogisticRegressionModelData {

    public final DenseVector coefficient;

    public LogisticRegressionModelData(DenseVector coefficient) {
        this.coefficient = coefficient;
    }

    /**
     * Converts the table model to a data stream.
     *
     * @param modelData The table model data.
     * @return The data stream model data.
     */
    public static DataStream<LogisticRegressionModelData> getModelDataStream(Table modelData) {
```

Review comment:

> For small objects, I think we should still allow kryo?

Could you be specific about which small object requires kryo? If we do have a reasonable use case for kryo, then it is OK not to make this change in unit tests.
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r771203645

## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java

[Diff context identical to the previous comment's, ending at `public static DataStream<LogisticRegressionModelData> getModelDataStream(Table modelData) {`.]

Review comment: I see. Could we make all `XXXModelData` classes POJOs? I think it is possible for model data to have a large size.

BTW, in order to make sure that we don't accidentally use the kryo serializer, how about we set `env.getConfig().disableGenericTypes()` in every algorithm's test?
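The guard suggested above can be applied in a test's environment setup: `ExecutionConfig#disableGenericTypes()` makes Flink fail at job-construction time whenever a type would fall back to the generic (Kryo) serializer, instead of silently degrading performance. A sketch of how a test setup method might apply it (a configuration fragment, not a runnable program; it assumes a Flink dependency on the classpath):

```java
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;

// In an algorithm test's @Before method: fail fast if any data type
// would be serialized with Kryo rather than a POJO or Flink serializer.
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.getConfig().disableGenericTypes();
```

With this set, a test exercising a non-POJO `XXXModelData` class throws an exception up front, which is exactly the accidental-Kryo detection the reviewer asks for.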
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r768460063

## File path: flink-ml-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java

Diff context (new file, +64 lines; Apache license header as in the first comment):

```java
package org.apache.flink.ml.linalg;

import org.junit.Test;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;

/** Tests the {@link BLAS}. */
public class BLASTest {

    private static final double TOLERANCE = 1e-7;

    private DenseVector inputDenseVec = Vectors.dense(1, -2, 3, 4, -5);
```

Review comment: Could this be `final`?
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r767506178

## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java

[Diff context identical to the comments above, ending at `public static DataStream<LogisticRegressionModelData> getModelDataStream(Table modelData) {`.]

Review comment: @gaoyunhaii Is it OK to use `LogisticRegressionModelData` as the DataStream element type when `LogisticRegressionModelData` is not serializable? If we use `LogisticRegressionModelData` as the DataStream element type, would Flink automatically use `DenseVectorSerializer` to serialize/de-serialize `coefficient`?

## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java

Diff context (new file, +173 lines; Apache license header as in the first comment; the quoted diff and its review comment are cut off in this archive):

```java
package org.apache.flink.ml.classification.logisticregression;

import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.api.Model;
import org.apache.flink.ml.common.broadcast.BroadcastUtils;
import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.a
```

[message truncated]
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r766361133

## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/classification/linear/LogisticRegressionTest.java

Diff context (new file, +280 lines; Apache license header as in the first comment; the quoted diff and its review comment are cut off in this archive):

```java
package org.apache.flink.ml.classification.linear;

import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;

import org.apache.commons.collections.IteratorUtils;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

/** Tests {@link LogisticRegression} and {@link LogisticRegressionModel}. */
public class LogisticRegressionTest {

    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();

    private StreamExecutionEnvironment env;

    private StreamTableEnvironment tEnv;

    private static final List<Row> binomialTrainData =
            Arrays.asList(
                    Row.of(Vectors.dense(1, 2, 3, 4), 0., 1.),
                    Row.of(Vectors.dense(2, 2, 3, 4), 0., 2.),
                    Row.of(Vectors.dense(3, 2, 3, 4), 0., 3.),
                    Row.of(Vectors.dense(4, 2, 3, 4), 0., 4.),
                    Row.of(Vectors.dense(5, 2, 3, 4), 0., 5.),
                    Row.of(Vectors.dense(11, 2, 3, 4), 1., 1.),
                    Row.of(Vectors.dense(12, 2, 3, 4), 1., 2.),
                    Row.of(Vectors.dense(13, 2, 3, 4), 1., 3.),
                    Row.of(Vectors.dense(14, 2, 3, 4), 1., 4.),
                    Row.of(Vectors.dense(15, 2, 3, 4), 1., 5.));

    private static final List<Row> multinomialTrainData =
            Arrays.asList(
                    Row.of(Vectors.dense(1, 2, 3, 4), 0., 1.),
                    Row.of(Vectors.dense(2, 2, 3, 4), 0., 2.),
                    Row.of(Vectors.dense(3, 2, 3, 4), 2., 3.),
                    Row.of(Vectors.dense(4, 2, 3, 4), 2., 4.),
                    Row.of(Vectors.dense(5, 2, 3, 4), 2., 5.),
                    Row.of(Vectors.dense(11, 2, 3, 4), 1., 1.),
                    Row.of(Vectors.dense(12, 2, 3, 4), 1., 2.),
                    Row.of(Vectors.dense(13, 2, 3, 4), 1., 3.),
                    Row.of(Vectors.dense(14, 2, 3, 4), 1., 4.),
                    Row.of(Vectors.dense(15, 2, 3, 4), 1., 5.));

    private static final double[] expectedCoefficient =
            new double[] {0.528, -0.286, -0.429, -0.572};

    private static final double TOLERANCE = 1e-7;

    private Table binomialDataTable;

    private Table multinomialDataTable;

    @Before
    public void before() {
        Configuration config = new Configuration();
        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
        env.setParallelism(4);
        env.enableCheckpointing(100);
        env.setRestartStrategy(RestartStrategies.noRestart());
        tEnv = StreamTableEnvironment.create(env);
        Collections.shuffle(binomialTrainData);
        binomialDataTable =
                tEnv.fromDataStream(
                        env.fromCollection(
```

[message truncated]
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r765456269

## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModelData.java

Diff context (new file, +104 lines; Apache license header and imports as in the earlier comments on this class):

```java
package org.apache.flink.ml.classification.linear;

/** Model data of {@link LogisticRegressionModel}. */
public class LogisticRegressionModelData {

    public final DenseVector coefficient;

    public LogisticRegressionModelData(DenseVector coefficient) {
        this.coefficient = coefficient;
    }

    /**
     * Converts the table model to a data stream.
     *
     * @param modelData The table model data.
     * @return The data stream model data.
     */
    public static DataStream<LogisticRegressionModelData> getModelDataStream(Table modelData) {
        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment();
        return tEnv.toDataStream(modelData).map(x -> (LogisticRegressionModelData) x.getField(0));
    }

    /** Data encoder for {@link LogisticRegressionModel}. */
    public static class ModelDataEncoder implements Encoder<LogisticRegressionModelData> {

        @Override
        public void encode(LogisticRegressionModelData modelData, OutputStream outputStream)
                throws IOException {
            DenseVectorSerializer serializer = new DenseVectorSerializer();
            serializer.serialize(
                    modelData.coefficient, new DataOutputViewStreamWrapper(outputStream));
        }
    }

    /** Data decoder for {@link LogisticRegressionModel}. */
    public static class ModelDataDecoder extends SimpleStreamFormat<LogisticRegressionModelData> {

        @Override
        public Reader<LogisticRegressionModelData> createReader(
                Configuration configuration, FSDataInputStream inputStream) {
            return new Reader<LogisticRegressionModelData>() {

                @Override
                public LogisticRegressionModelData read() throws IOException {
                    DenseVectorSerializer serializer = new DenseVectorSerializer();
```

Review comment: Could we make `serializer` a private member field of the `Reader` so that we don't instantiate `serializer` repeatedly on every invocation of `read()`? Same for `ModelDataEncoder::encode()`.

[A second quoted diff on the same file follows in the original message but is cut off in this archive.]
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r764646973

## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModel.java

Diff context (new file, +185 lines; Apache license header as in the first comment; the quoted diff and its review comment are cut off in this archive):

```java
package org.apache.flink.ml.classification.linear;

import org.apache.flink.api.common.functions.AbstractRichFunction;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.api.Model;
import org.apache.flink.ml.common.broadcast.BroadcastUtils;
import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;

import org.apache.commons.lang3.ArrayUtils;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

/** This class implements {@link Model} for {@link LogisticRegression}. */
public class LogisticRegressionModel
        implements Model<LogisticRegressionModel>,
                LogisticRegressionModelParams<LogisticRegressionModel> {

    private Map<Param<?>, Object> paramMap = new HashMap<>();

    private Table modelData;

    public LogisticRegressionModel() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
    }

    @Override
    public Map<Param<?>, Object> getParamMap() {
        return paramMap;
    }

    @Override
    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata(this, path);
        ReadWriteUtils.saveModelData(
                LogisticRegressionModelData.getModelDataStream(modelData),
                path,
                new LogisticRegressionModelData.ModelDataEncoder());
    }

    public static LogisticRegressionModel load(StreamExecutionEnvironment env, String path)
            throws IOException {
        LogisticRegressionModel model = ReadWriteUtils.loadStageParam(path);
        Table modelData =
                ReadWriteUtils.loadModelData(
                        env, path, new LogisticRegressionModelData.ModelDataDecoder());
        return model.setModelData(modelData);
    }

    @Override
    public LogisticRegressionModel setModelData(Table... inputs) {
        modelData = inputs[0];
        return this;
    }

    @Override
    public Table[] getModelData() {
        return new Table[] {modelData};
    }

    @Override
    @SuppressWarnings("unchecked")
    public Table[] transform(Table... inputs) {
        Preconditions.checkArgument(inputs.length == 1);
        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
        DataStream<Row> inputStream = tEnv.toDataStream(inputs[0]);
        final String broadcastModelKey = "broadcastModelKey";
        DataStream<LogisticRegressionModelData> modelData =
                LogisticRegressionModelData.getModelDataStream(this.modelData);
        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
        RowTypeInfo outputTypeInfo =
                new RowTypeInfo(
                        ArrayUtils.addAll(
                                inputTypeInfo.getFieldTypes(),
                                BasicTypeI
```

[message truncated]
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r764646164 ## File path: flink-ml-core/src/main/java/org/apache/flink/ml/common/iteration/TerminateOnMaxIterOrTol.java ## @@ -16,29 +16,53 @@ * limitations under the License. */ -package org.apache.flink.test.iteration.operators; +package org.apache.flink.ml.common.iteration; import org.apache.flink.api.common.functions.FlatMapFunction; import org.apache.flink.iteration.IterationListener; import org.apache.flink.util.Collector; -/** An termination criteria function that asks to stop after the specialized round. */ -public class RoundBasedTerminationCriteria -implements FlatMapFunction, IterationListener { +/** + * A FlatMapFunction that emits values iff the iteration's epochWatermark does not exceed a certain + * threshold and the loss exceeds a certain tolerance. + * + * When the output of this FlatMapFunction is used as the termination criteria of an iteration + * body, the iteration will be executed for at most the given `maxIter` iterations. And the + * iteration will terminate once any input value is smaller than or equal to the given `tol`. + */ +public class TerminateOnMaxIterOrTol +implements IterationListener, FlatMapFunction { + +private final int maxIter; -private final int maxRound; +private final double tol; -public RoundBasedTerminationCriteria(int maxRound) { -this.maxRound = maxRound; +private double loss = Double.NEGATIVE_INFINITY; + +public TerminateOnMaxIterOrTol(int maxIter, double tol) { +this.maxIter = maxIter; +this.tol = tol; +} + +public TerminateOnMaxIterOrTol(int maxIter) { +this.maxIter = maxIter; +this.tol = Double.NEGATIVE_INFINITY; +} + +public TerminateOnMaxIterOrTol(double tol) { +this.maxIter = Integer.MAX_VALUE; +this.tol = tol; } @Override -public void flatMap(EpochRecord integer, Collector collector) throws Exception {} +public void flatMap(Double value, Collector out) { +this.loss = value; Review comment: Sounds good. 
Let's throw an exception in this case. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org
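The behavior agreed in this thread — at most one loss value per epoch (throw otherwise), continue only while the epoch count is below `maxIter` and the loss stays above `tol` — can be modeled with a small self-contained sketch. The class and method names below are illustrative only; the real flink-ml class implements Flink's `IterationListener` and `FlatMapFunction` interfaces:

```java
// Simplified, self-contained model of the termination criteria discussed
// above. Names are illustrative; the real operator lives in flink-ml.
class TerminationSketch {
    private final int maxIter;
    private final double tol;
    private int lossCountThisEpoch = 0;
    private double loss = Double.NaN;

    TerminationSketch(int maxIter, double tol) {
        this.maxIter = maxIter;
        this.tol = tol;
    }

    /** Records the loss for the current epoch; throws if called more than once per epoch. */
    void recordLoss(double value) {
        if (++lossCountThisEpoch > 1) {
            throw new IllegalStateException("Expected at most one loss value per epoch.");
        }
        loss = value;
    }

    /** Called when an epoch ends; returns true iff the iteration should continue. */
    boolean onEpochEnd(int epochWatermark) {
        boolean continueIteration = epochWatermark + 1 < maxIter && loss > tol;
        lossCountThisEpoch = 0; // reset the per-epoch counter for the next epoch
        return continueIteration;
    }
}
```

In the actual operator, `onEpochEnd` corresponds to `onEpochWatermarkIncremented`, and "continue" means emitting a record into the termination-criteria feedback stream.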
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r762744248 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModel.java ## @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification.linear; + +import org.apache.flink.api.common.functions.AbstractRichFunction; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** This class implements {@link Model} for {@link LogisticRegression}. 
*/ +public class LogisticRegressionModel +implements Model, +LogisticRegressionModelParams { + +private Map, Object> paramMap = new HashMap<>(); + +private Table modelData; + +public LogisticRegressionModel() { +ParamUtils.initializeMapWithDefaultValues(this.paramMap, this); +} + +@Override +public Map, Object> getParamMap() { +return paramMap; +} + +@Override +public void save(String path) throws IOException { +ReadWriteUtils.saveMetadata(this, path); +ReadWriteUtils.saveModelData( +LogisticRegressionModelData.getModelDataStream(modelData), +path, +new LogisticRegressionModelData.ModelDataEncoder()); +} + +public static LogisticRegressionModel load(StreamExecutionEnvironment env, String path) +throws IOException { +LogisticRegressionModel model = ReadWriteUtils.loadStageParam(path); +Table modelData = +ReadWriteUtils.loadModelData( +env, path, new LogisticRegressionModelData.ModelDataDecoder()); +return model.setModelData(modelData); +} + +@Override +public LogisticRegressionModel setModelData(Table... inputs) { +modelData = inputs[0]; +return this; +} + +@Override +public Table[] getModelData() { +return new Table[] {modelData}; +} + +@Override +@SuppressWarnings("unchecked") +public Table[] transform(Table... inputs) { +Preconditions.checkArgument(inputs.length == 1); +StreamTableEnvironment tEnv = +(StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); +DataStream inputStream = tEnv.toDataStream(inputs[0]); +final String broadcastModelKey = "broadcastModelKey"; +DataStream modelData = +LogisticRegressionModelData.getModelDataStream(this.modelData); +RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); +RowTypeInfo outputTypeInfo = +new RowTypeInfo( +ArrayUtils.addAll( +inputTypeInfo.getFieldTypes(), +BasicTypeI
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r762740871 ## File path: flink-ml-core/src/main/java/org/apache/flink/ml/common/iteration/TerminateOnMaxIterOrTol.java ## @@ -16,29 +16,53 @@ * limitations under the License. */ -package org.apache.flink.test.iteration.operators; +package org.apache.flink.ml.common.iteration; import org.apache.flink.api.common.functions.FlatMapFunction; import org.apache.flink.iteration.IterationListener; import org.apache.flink.util.Collector; -/** An termination criteria function that asks to stop after the specialized round. */ -public class RoundBasedTerminationCriteria -implements FlatMapFunction, IterationListener { +/** + * A FlatMapFunction that emits values iff the iteration's epochWatermark does not exceed a certain + * threshold and the loss exceeds a certain tolerance. + * + * When the output of this FlatMapFunction is used as the termination criteria of an iteration + * body, the iteration will be executed for at most the given `maxIter` iterations. And the + * iteration will terminate once any input value is smaller than or equal to the given `tol`. 
+ */ +public class TerminateOnMaxIterOrTol +implements IterationListener, FlatMapFunction { + +private final int maxIter; -private final int maxRound; +private final double tol; -public RoundBasedTerminationCriteria(int maxRound) { -this.maxRound = maxRound; +private double loss = Double.NEGATIVE_INFINITY; + +public TerminateOnMaxIterOrTol(int maxIter, double tol) { +this.maxIter = maxIter; +this.tol = tol; +} + +public TerminateOnMaxIterOrTol(int maxIter) { +this.maxIter = maxIter; +this.tol = Double.NEGATIVE_INFINITY; +} + +public TerminateOnMaxIterOrTol(double tol) { +this.maxIter = Integer.MAX_VALUE; +this.tol = tol; } @Override -public void flatMap(EpochRecord integer, Collector collector) throws Exception {} +public void flatMap(Double value, Collector out) { +this.loss = value; Review comment: The issue here is that there is no rule forcing `only one loss value in each epoch`. If there should indeed be `only one loss value in each epoch`, could we throw exception in `flatMap()` to enforce it? I think it is better not to limit the number of input values in each epoch so that this operator can also be used in asynchronous iteration. And if we agree to do so, we probably need to use the minimum value here. ## File path: flink-ml-core/src/main/java/org/apache/flink/ml/common/iteration/TerminateOnMaxIterOrTol.java ## @@ -16,29 +16,53 @@ * limitations under the License. */ -package org.apache.flink.test.iteration.operators; +package org.apache.flink.ml.common.iteration; import org.apache.flink.api.common.functions.FlatMapFunction; import org.apache.flink.iteration.IterationListener; import org.apache.flink.util.Collector; -/** An termination criteria function that asks to stop after the specialized round. */ -public class RoundBasedTerminationCriteria -implements FlatMapFunction, IterationListener { +/** + * A FlatMapFunction that emits values iff the iteration's epochWatermark does not exceed a certain + * threshold and the loss exceeds a certain tolerance. 
+ * + * When the output of this FlatMapFunction is used as the termination criteria of an iteration + * body, the iteration will be executed for at most the given `maxIter` iterations. And the + * iteration will terminate once any input value is smaller than or equal to the given `tol`. + */ +public class TerminateOnMaxIterOrTol +implements IterationListener, FlatMapFunction { + +private final int maxIter; -private final int maxRound; +private final double tol; -public RoundBasedTerminationCriteria(int maxRound) { -this.maxRound = maxRound; +private double loss = Double.NEGATIVE_INFINITY; Review comment: Sounds good. Thanks.
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r762739857 ## File path: flink-ml-core/src/main/java/org/apache/flink/ml/common/iteration/TerminateOnMaxIterOrTol.java ## @@ -16,29 +16,53 @@ * limitations under the License. */ -package org.apache.flink.test.iteration.operators; +package org.apache.flink.ml.common.iteration; import org.apache.flink.api.common.functions.FlatMapFunction; import org.apache.flink.iteration.IterationListener; import org.apache.flink.util.Collector; -/** An termination criteria function that asks to stop after the specialized round. */ -public class RoundBasedTerminationCriteria -implements FlatMapFunction, IterationListener { +/** + * A FlatMapFunction that emits values iff the iteration's epochWatermark does not exceed a certain + * threshold and the loss exceeds a certain tolerance. + * + * When the output of this FlatMapFunction is used as the termination criteria of an iteration + * body, the iteration will be executed for at most the given `maxIter` iterations. And the + * iteration will terminate once any input value is smaller than or equal to the given `tol`. + */ +public class TerminateOnMaxIterOrTol +implements IterationListener, FlatMapFunction { + +private final int maxIter; -private final int maxRound; +private final double tol; -public RoundBasedTerminationCriteria(int maxRound) { -this.maxRound = maxRound; +private double loss = Double.NEGATIVE_INFINITY; + +public TerminateOnMaxIterOrTol(int maxIter, double tol) { +this.maxIter = maxIter; +this.tol = tol; +} + +public TerminateOnMaxIterOrTol(int maxIter) { Review comment: Hmmm.. It is not clear to me why `it is not a complete one` if we require the user to always provider `loss` here. Could you clarify? From caller's perspective, caller can just use `TerminateOnMaxIter` if he/she does not intend to use `loss` here. Is there concrete benefit of adding this constructor? 
The benefit of removing this constructor is to avoid having two ways of doing the same thing and reduce the overall complexity of the code.
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r762683610 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModel.java ## @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification.linear; + +import org.apache.flink.api.common.functions.AbstractRichFunction; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** This class implements {@link Model} for {@link LogisticRegression}. 
*/ +public class LogisticRegressionModel +implements Model, +LogisticRegressionModelParams { + +private Map, Object> paramMap = new HashMap<>(); + +private Table modelData; + +public LogisticRegressionModel() { +ParamUtils.initializeMapWithDefaultValues(this.paramMap, this); +} + +@Override +public Map, Object> getParamMap() { +return paramMap; +} + +@Override +public void save(String path) throws IOException { +ReadWriteUtils.saveMetadata(this, path); +ReadWriteUtils.saveModelData( +LogisticRegressionModelData.getModelDataStream(modelData), +path, +LogisticRegressionModelData.getModelDataEncoder()); +} + +public static LogisticRegressionModel load(StreamExecutionEnvironment env, String path) +throws IOException { +LogisticRegressionModel model = ReadWriteUtils.loadStageParam(path); +Table modelData = +ReadWriteUtils.loadModelData( +env, path, LogisticRegressionModelData.getModelDataDecoder()); +return model.setModelData(modelData); +} + +@Override +public LogisticRegressionModel setModelData(Table... inputs) { +modelData = inputs[0]; +return this; +} + +@Override +public Table[] getModelData() { +return new Table[] {modelData}; +} + +@Override +@SuppressWarnings("unchecked") +public Table[] transform(Table... inputs) { +Preconditions.checkArgument(inputs.length == 1); +StreamTableEnvironment tEnv = +(StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); +DataStream inputStream = tEnv.toDataStream(inputs[0]); +final String broadcastModelKey = "broadcastModelKey"; +DataStream modelData = +LogisticRegressionModelData.getModelDataStream(this.modelData); +RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); +RowTypeInfo outputTypeInfo = +new RowTypeInfo( +ArrayUtils.addAll( +inputTypeInfo.getFieldTypes(), +BasicTypeInf
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r762673330 ## File path: flink-ml-core/src/main/java/org/apache/flink/ml/common/iteration/TerminateOnMaxIterOrTol.java ## @@ -16,29 +16,53 @@ * limitations under the License. */ -package org.apache.flink.test.iteration.operators; +package org.apache.flink.ml.common.iteration; import org.apache.flink.api.common.functions.FlatMapFunction; import org.apache.flink.iteration.IterationListener; import org.apache.flink.util.Collector; -/** An termination criteria function that asks to stop after the specialized round. */ -public class RoundBasedTerminationCriteria -implements FlatMapFunction, IterationListener { +/** + * A FlatMapFunction that emits values iff the iteration's epochWatermark does not exceed a certain + * threshold and the loss exceeds a certain tolerance. + * + * When the output of this FlatMapFunction is used as the termination criteria of an iteration + * body, the iteration will be executed for at most the given `maxIter` iterations. And the + * iteration will terminate once any input value is smaller than or equal to the given `tol`. + */ +public class TerminateOnMaxIterOrTol +implements IterationListener, FlatMapFunction { + +private final int maxIter; -private final int maxRound; +private final double tol; -public RoundBasedTerminationCriteria(int maxRound) { -this.maxRound = maxRound; +private double loss = Double.NEGATIVE_INFINITY; + +public TerminateOnMaxIterOrTol(int maxIter, double tol) { +this.maxIter = maxIter; +this.tol = tol; +} + +public TerminateOnMaxIterOrTol(int maxIter) { Review comment: nits: It seems simpler to remove this constructor and let user use `TerminateOnMaxIter` instead. ## File path: flink-ml-core/src/main/java/org/apache/flink/ml/common/iteration/TerminateOnMaxIterOrTol.java ## @@ -16,29 +16,53 @@ * limitations under the License. 
*/ -package org.apache.flink.test.iteration.operators; +package org.apache.flink.ml.common.iteration; import org.apache.flink.api.common.functions.FlatMapFunction; import org.apache.flink.iteration.IterationListener; import org.apache.flink.util.Collector; -/** An termination criteria function that asks to stop after the specialized round. */ -public class RoundBasedTerminationCriteria -implements FlatMapFunction, IterationListener { +/** + * A FlatMapFunction that emits values iff the iteration's epochWatermark does not exceed a certain + * threshold and the loss exceeds a certain tolerance. + * + * When the output of this FlatMapFunction is used as the termination criteria of an iteration + * body, the iteration will be executed for at most the given `maxIter` iterations. And the + * iteration will terminate once any input value is smaller than or equal to the given `tol`. + */ +public class TerminateOnMaxIterOrTol +implements IterationListener, FlatMapFunction { + +private final int maxIter; -private final int maxRound; +private final double tol; -public RoundBasedTerminationCriteria(int maxRound) { -this.maxRound = maxRound; +private double loss = Double.NEGATIVE_INFINITY; Review comment: What happens if there are no input values in the first epoch? Do we expect the iteration to terminate (which seems to be the case with `loss = Double.NEGATIVE_INFINITY`)? ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModel.java ## @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.linear; + +import org.apache.flink.api.common.functions.AbstractRichFunction; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUti
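The continue/terminate decision that the `TerminateOnMaxIterOrTol` discussion above revolves around can be sketched in plain Java. This is a hypothetical standalone sketch, not the PR's Flink operator: the class name `TerminationSketch` and method `shouldContinue` are invented for illustration, and it assumes the iteration continues only while the epoch count stays below `maxIter` and the observed loss stays above `tol`.

```java
/**
 * Hypothetical sketch (no Flink dependency) of the continue/terminate rule:
 * the criteria stream emits a record, i.e. the iteration runs another epoch,
 * only while the epoch count is below maxIter AND the loss is above tol.
 */
class TerminationSketch {

    static boolean shouldContinue(int epochWatermark, int maxIter, double loss, double tol) {
        // Both conditions must hold for the iteration to keep going.
        return (epochWatermark + 1) < maxIter && loss > tol;
    }

    public static void main(String[] args) {
        // With loss initialized to Double.NEGATIVE_INFINITY (no input seen yet),
        // loss > tol is false, so the iteration terminates -- the behavior the
        // review comment asks about.
        System.out.println(shouldContinue(0, 10, Double.NEGATIVE_INFINITY, 0.01)); // prints false
        System.out.println(shouldContinue(0, 10, 1.5, 0.01)); // prints true
        System.out.println(shouldContinue(9, 10, 1.5, 0.01)); // prints false: max epochs reached
    }
}
```

Under this assumed rule, an empty first epoch does terminate the iteration, which is what the `Double.NEGATIVE_INFINITY` initial value suggests.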
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r761945296 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMultiClass.java ## @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.param; + +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringParam; +import org.apache.flink.ml.param.WithParams; + +/** + * Interface for the shared multi-class param. + * + * Supported options: + * auto: selects the classification type based on the number of classes: If numClasses is one or Review comment: nits: since we don't have any API or variable named `numClasses`, it may be slightly better to say `number of classes` and explain how it is derived. How about using comments like this: ``` auto: selects the classification type based on the number of classes: If the number of unique label values from the input data is one or two, set to "binomial". Otherwise, set to "multinomial".
``` ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModelData.java ## @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.linear; + +import org.apache.flink.api.common.serialization.Encoder; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.connector.file.src.reader.SimpleStreamFormat; +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; + +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + +import java.io.IOException; +import java.io.OutputStream; +import java.io.Serializable; + +/** Model data of {@link LogisticRegressionModel}. */ +public class LogisticRegressionModelData implements Serializable { Review comment: Could this `implements Serializable` be removed? 
Passing model data as `DataStream` seems to be more efficient and straightforward than passing it as `DataStream`. The meaning of streams can be specified as variable name. ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticGradient.java ## @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific langu
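The message above ends at the `LogisticGradient` file, which computes the loss and gradient for weighted labeled points. A plain-Java sketch of a weighted per-sample logistic loss/gradient is shown below; it is an illustration, not the PR's actual implementation: the class `LogisticGradientSketch`, the method `addSample`, and the assumption that labels are in {0, 1} are all hypothetical.

```java
/**
 * Hypothetical sketch of a weighted per-sample logistic loss/gradient,
 * assuming labels in {0, 1}. The sample's weighted gradient is accumulated
 * into cumGradient and its weighted log-loss contribution is returned.
 */
class LogisticGradientSketch {

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    static double dot(double[] a, double[] b) {
        double s = 0.0;
        for (int i = 0; i < a.length; i++) s += a[i] * b[i];
        return s;
    }

    static double addSample(double[] coefficients, double[] features, double label,
                            double weight, double[] cumGradient) {
        double p = sigmoid(dot(coefficients, features));
        // d(logloss)/d(coefficients[i]) = (p - label) * features[i]
        double error = p - label;
        for (int i = 0; i < coefficients.length; i++) {
            cumGradient[i] += weight * error * features[i];
        }
        // Small epsilon guards against log(0) at saturated predictions.
        double eps = 1e-15;
        return -weight * (label * Math.log(p + eps) + (1.0 - label) * Math.log(1.0 - p + eps));
    }
}
```

With all-zero coefficients the prediction is 0.5 regardless of features, so the per-sample loss reduces to `log(2)` and the gradient to `(0.5 - label) * features[i]`, which makes the sketch easy to sanity-check.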
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r761940006 ## File path: flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java ## @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.datastream; + +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeComparator; +import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.TimestampedCollector; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +import org.apache.commons.collections.IteratorUtils; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** Provides utility functions for {@link DataStream}. */ +public class DataStreamUtils { +/** + * Applies allReduceSum on the input data stream. The input data stream is supposed to contain + * one double array in each partition. The result data stream has the same parallelism as the + * input, where each partition contains one double array that sums all of the double arrays in + * the input data stream. + * + * Note that we throw exception when one of the following two cases happen: + * There exists one partition that contains more than one double array. + * The length of the double array is not consistent among all partitions. + * + * @param input The input data stream. + * @return The result data stream. + */ +public static DataStream allReduceSum(DataStream input) { +return AllReduceImpl.allReduceSum(input); +} + +/** + * Collects distinct values in a bounded data stream. 
The parallelism of the output stream is 1. + * + * @param The class type of the input data stream. + * @param input The bounded input data stream. + * @return The result data stream that contains all the distinct values. + */ +public static DataStream distinct(DataStream input) { +return input.transform( +"distinctInEachPartition", +input.getType(), +new DistinctPartitionOperator<>()) +.setParallelism(input.getParallelism()) +.transform( +"distinctInFinalPartition", +input.getType(), +new DistinctPartitionOperator<>()) +.setParallelism(1); +} + +/** + * Applies a {@link MapPartitionFunction} on a bounded data stream. + * + * @param input The input data stream. + * @param func The user defined mapPartition function. + * @param The class type of the input element. + * @param The class type of output element. + * @return The result data stream. + */ +public static DataStream mapPartition( +DataStream input, MapPartitionFunction func) { +TypeInformation resultType = +TypeExtractor.getMapPartitionReturnTypes(func, input.getType(), null, true); +return input.transform("mapPartition", resultType, new MapPartitionOperator<>(func)) +.setParallelism(input.getParallelism()); +} + +/** + * Sorts the elements in each partition of the input bounded data stream. + * + * @param input The input data stream. + * @param comparator The comparator used to sort the elements. + * @param The class type of input element. + * @return The sorted data stream. + */ +p
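The `allReduceSum` contract documented above (each partition holds exactly one double array, all arrays equal in length, result is the element-wise sum) can be illustrated with a plain-Java sketch. The class `AllReduceSketch` is hypothetical and stands in for the distributed `AllReduceImpl`, with a `List` of arrays modeling the per-partition inputs.

```java
import java.util.List;

/**
 * Hypothetical single-process sketch of the allReduceSum contract: each
 * partition contributes exactly one double[], all arrays must have the same
 * length, and the result sums them element-wise. The real implementation
 * runs distributed; a List of arrays models the partitions here.
 */
class AllReduceSketch {

    static double[] allReduceSum(List<double[]> partitionArrays) {
        if (partitionArrays.isEmpty()) {
            throw new IllegalArgumentException("Expected at least one partition array.");
        }
        int len = partitionArrays.get(0).length;
        double[] sum = new double[len];
        for (double[] arr : partitionArrays) {
            // Mirrors the documented error case: inconsistent array lengths.
            if (arr.length != len) {
                throw new IllegalStateException("Inconsistent array length across partitions.");
            }
            for (int i = 0; i < len; i++) {
                sum[i] += arr[i];
            }
        }
        return sum;
    }

    public static void main(String[] args) {
        double[] result = allReduceSum(List.of(new double[]{1.0, 2.0}, new double[]{3.0, 4.0}));
        System.out.println(result[0] + ", " + result[1]); // prints 4.0, 6.0
    }
}
```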
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r761937399 ## File path: flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java ## @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.datastream; + +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeComparator; +import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.TimestampedCollector; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +import org.apache.commons.collections.IteratorUtils; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** Provides utility functions for {@link DataStream}. */ +public class DataStreamUtils { +/** + * Applies allReduceSum on the input data stream. The input data stream is supposed to contain + * one double array in each partition. The result data stream has the same parallelism as the + * input, where each partition contains one double array that sums all of the double arrays in + * the input data stream. + * + * Note that we throw exception when one of the following two cases happen: + * There exists one partition that contains more than one double array. + * The length of the double array is not consistent among all partitions. + * + * @param input The input data stream. + * @return The result data stream. + */ +public static DataStream allReduceSum(DataStream input) { +return AllReduceImpl.allReduceSum(input); +} + +/** + * Collects distinct values in a bounded data stream. 
The parallelism of the output stream is 1. + * + * @param The class type of the input data stream. + * @param input The bounded input data stream. + * @return The result data stream that contains all the distinct values. + */ +public static DataStream distinct(DataStream input) { +return input.transform( +"distinctInEachPartition", +input.getType(), +new DistinctPartitionOperator<>()) +.setParallelism(input.getParallelism()) +.transform( +"distinctInFinalPartition", +input.getType(), +new DistinctPartitionOperator<>()) +.setParallelism(1); +} + +/** + * Applies a {@link MapPartitionFunction} on a bounded data stream. + * + * @param input The input data stream. + * @param func The user defined mapPartition function. + * @param The class type of the input element. + * @param The class type of output element. + * @return The result data stream. + */ +public static DataStream mapPartition( +DataStream input, MapPartitionFunction func) { +TypeInformation resultType = +TypeExtractor.getMapPartitionReturnTypes(func, input.getType(), null, true); +return input.transform("mapPartition", resultType, new MapPartitionOperator<>(func)) +.setParallelism(input.getParallelism()); +} + +/** + * Sorts the elements in each partition of the input bounded data stream. + * + * @param input The input data stream. + * @param comparator The comparator used to sort the elements. + * @param The class type of input element. + * @return The sorted data stream. + */ +p
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r761937754 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegression.java ## @@ -0,0 +1,488 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification.linear; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.IterationBody; +import org.apache.flink.iteration.IterationBodyResult; +import org.apache.flink.iteration.IterationConfig; +import org.apache.flink.iteration.IterationConfig.OperatorLifeCycle; +import org.apache.flink.iteration.IterationListener; +import org.apache.flink.iteration.Iterations; +import org.apache.flink.iteration.ReplayableDataStreamList; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.common.feature.LabeledPointWithWeight; +import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import 
org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.util.Collector; +import org.apache.flink.util.OutputTag; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.collections.IteratorUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; + +/** + * This class implements methods to train a logistic regression model. For details, see + * https://en.wikipedia.org/wiki/Logistic_regression. + */ +public class LogisticRegression +implements Estimator, +LogisticRegressionParams { + +private Map, Object> paramMap = new HashMap<>(); + +private static final OutputTag MODEL_OUTPUT = Review comment: Sounds good. Thanks for the update. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org
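The `LogisticRegression` estimator above trains via an iteration body that repeatedly computes gradients and updates coefficients. Stripped of the Flink iteration machinery, the underlying optimization step can be sketched as plain full-batch gradient descent. This is a hypothetical standalone sketch, not the PR's implementation: the class `GdTrainerSketch`, the unregularized loss, and labels in {0, 1} are all assumptions.

```java
/**
 * Hypothetical full-batch gradient-descent trainer for unregularized
 * logistic regression, labels assumed in {0, 1}. Each iteration computes the
 * average gradient over all samples, then takes one step against it.
 */
class GdTrainerSketch {

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    static double[] train(double[][] xs, double[] ys, int maxIter, double learningRate) {
        int dim = xs[0].length;
        double[] w = new double[dim];
        for (int iter = 0; iter < maxIter; iter++) {
            double[] grad = new double[dim];
            for (int n = 0; n < xs.length; n++) {
                double margin = 0.0;
                for (int i = 0; i < dim; i++) margin += w[i] * xs[n][i];
                double error = sigmoid(margin) - ys[n];
                for (int i = 0; i < dim; i++) grad[i] += error * xs[n][i];
            }
            // Average the gradient over the batch, then step against it.
            for (int i = 0; i < dim; i++) {
                w[i] -= learningRate * grad[i] / xs.length;
            }
        }
        return w;
    }
}
```

On a 1-D dataset where positive labels sit at positive feature values, the learned coefficient ends up positive, which is a quick way to check the update direction.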
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r761907459 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMultiClass.java ## @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.param; + +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringParam; +import org.apache.flink.ml.param.WithParams; + +/** + * Interface for the shared multi-class param. + * + * Supported options: + * auto: select the version based on the number of classes: If numClasses is one or two, set to Review comment: Thanks for the explanation. Sounds good. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org
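The "auto" option discussed for `HasMultiClass` can be reduced to a one-line rule: one or two distinct label values means "binomial", otherwise "multinomial". The sketch below illustrates that rule; the class `MultiClassAutoSketch` and the method `resolve` are invented for illustration and are not part of the PR.

```java
/**
 * Hypothetical sketch of the "auto" multi-class resolution rule from the
 * review thread: if the number of unique label values in the input data is
 * one or two, use "binomial"; otherwise use "multinomial". Explicit settings
 * pass through unchanged.
 */
class MultiClassAutoSketch {

    static String resolve(String multiClass, long numDistinctLabels) {
        if (!"auto".equals(multiClass)) {
            return multiClass; // user chose explicitly
        }
        return numDistinctLabels <= 2 ? "binomial" : "multinomial";
    }

    public static void main(String[] args) {
        System.out.println(resolve("auto", 2)); // prints binomial
        System.out.println(resolve("auto", 5)); // prints multinomial
    }
}
```

In the actual estimator, `numDistinctLabels` would come from something like the `distinct` utility in `DataStreamUtils` applied to the label column.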
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r761597739 ## File path: flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java ## @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.datastream; + +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeComparator; +import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.TimestampedCollector; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +import org.apache.commons.collections.IteratorUtils; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** Provides utility functions for {@link DataStream}. */ +public class DataStreamUtils { +/** + * Applies allReduceSum on the input data stream. The input data stream is supposed to contain + * one double array in each partition. The result data stream has the same parallelism as the + * input, where each partition contains one double array that sums all of the double arrays in + * the input data stream. + * + * Note that we throw exception when one of the following two cases happen: + * There exists one partition that contains more than one double array. + * The length of the double array is not consistent among all partitions. + * + * @param input The input data stream. + * @return The result data stream. + */ +public static DataStream allReduceSum(DataStream input) { +return AllReduceImpl.allReduceSum(input); +} + +/** + * Collects distinct values in a bounded data stream. 
The parallelism of the output stream is 1. + * + * @param The class type of the input data stream. + * @param input The bounded input data stream. + * @return The result data stream that contains all the distinct values. + */ +public static DataStream distinct(DataStream input) { +return input.transform( +"distinctInEachPartition", +input.getType(), +new DistinctPartitionOperator<>()) +.setParallelism(input.getParallelism()) +.transform( +"distinctInFinalPartition", +input.getType(), +new DistinctPartitionOperator<>()) +.setParallelism(1); +} + +/** + * Applies a {@link MapPartitionFunction} on a bounded data stream. + * + * @param input The input data stream. + * @param func The user defined mapPartition function. + * @param The class type of the input element. + * @param The class type of output element. + * @return The result data stream. + */ +public static DataStream mapPartition( +DataStream input, MapPartitionFunction func) { +TypeInformation resultType = +TypeExtractor.getMapPartitionReturnTypes(func, input.getType(), null, true); +return input.transform("mapPartition", resultType, new MapPartitionOperator<>(func)) +.setParallelism(input.getParallelism()); +} + +/** + * Sorts the elements in each partition of the input bounded data stream. + * + * @param input The input data stream. + * @param comparator The comparator used to sort the elements. + * @param The class type of input element. + * @return The sorted data stream. + */ +p
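The `allReduceSum` contract quoted above (each partition holds exactly one double array, all arrays share one length, and every partition receives the element-wise sum) can be modeled with a plain-Java sketch. This is not the Flink `AllReduceImpl` code; the class name `AllReduceSumSketch` and the list-of-partitions representation are illustrative assumptions.

```java
import java.util.Arrays;
import java.util.List;

public class AllReduceSumSketch {
    // Each partition is modeled as one double array; lengths must be consistent.
    public static double[] allReduceSum(List<double[]> partitions) {
        if (partitions.isEmpty()) {
            throw new IllegalStateException("Each partition must contain exactly one array.");
        }
        int len = partitions.get(0).length;
        double[] sum = new double[len];
        for (double[] arr : partitions) {
            if (arr.length != len) {
                throw new IllegalStateException(
                        "Array length is not consistent across partitions.");
            }
            for (int i = 0; i < len; i++) {
                sum[i] += arr[i]; // element-wise sum, as every partition would receive it
            }
        }
        return sum;
    }

    public static void main(String[] args) {
        double[] result = allReduceSum(
                Arrays.asList(new double[] {1.0, 2.0}, new double[] {3.0, 4.0}));
        System.out.println(Arrays.toString(result)); // prints [4.0, 6.0]
    }
}
```

In the real operator the summed array is broadcast back so every parallel task sees the same result; here that is implied by returning a single array.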
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r761600036 ## File path: flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java ## @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.datastream; + +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.common.typeutils.base.LongComparator; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.util.Collector; +import org.apache.flink.util.NumberSequenceIterator; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Before; +import org.junit.Test; + +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertArrayEquals; + +/** Tests the {@link DataStreamUtils}. 
*/ +public class DataStreamUtilsTest { +private StreamExecutionEnvironment env; + +@Before +public void before() { +Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); +env = StreamExecutionEnvironment.getExecutionEnvironment(config); +env.setParallelism(4); +env.enableCheckpointing(100); +env.setRestartStrategy(RestartStrategies.noRestart()); +} + +@Test +@SuppressWarnings("unchecked") +public void testDistinct() throws Exception { +DataStream dataStream = +env.fromParallelCollection(new NumberSequenceIterator(1L, 10L), Types.LONG) Review comment: I see. Sounds good.
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r761599850 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegression.java ## @@ -0,0 +1,488 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification.linear; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.IterationBody; +import org.apache.flink.iteration.IterationBodyResult; +import org.apache.flink.iteration.IterationConfig; +import org.apache.flink.iteration.IterationConfig.OperatorLifeCycle; +import org.apache.flink.iteration.IterationListener; +import org.apache.flink.iteration.Iterations; +import org.apache.flink.iteration.ReplayableDataStreamList; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.common.feature.LabeledPointWithWeight; +import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import 
org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.util.Collector; +import org.apache.flink.util.OutputTag; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.collections.IteratorUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; + +/** + * This class implements methods to train a logistic regression model. For details, see + * https://en.wikipedia.org/wiki/Logistic_regression. + */ +public class LogisticRegression +implements Estimator, +LogisticRegressionParams { + +private Map<Param<?>, Object> paramMap = new HashMap<>(); + +private static final OutputTag MODEL_OUTPUT = Review comment: Hmm... it looks like the operators in the iteration body use the `ALL_ROUND` lifecycle, and therefore the `MODEL_OUTPUT` would be created only once even if we passed it as a local variable. Thus making it a global static variable would not improve performance. Did I miss something here?
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r761597962 ## File path: flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java ## @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.datastream; + +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeComparator; +import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.TimestampedCollector; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +import org.apache.commons.collections.IteratorUtils; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** Provides utility functions for {@link DataStream}. */ +public class DataStreamUtils { +/** + * Applies allReduceSum on the input data stream. The input data stream is supposed to contain + * one double array in each partition. The result data stream has the same parallelism as the + * input, where each partition contains one double array that sums all of the double arrays in + * the input data stream. + * + * Note that we throw exception when one of the following two cases happen: + * There exists one partition that contains more than one double array. + * The length of the double array is not consistent among all partitions. + * + * @param input The input data stream. + * @return The result data stream. + */ +public static DataStream allReduceSum(DataStream input) { +return AllReduceImpl.allReduceSum(input); +} + +/** + * Collects distinct values in a bounded data stream. 
The parallelism of the output stream is 1. + * + * @param The class type of the input data stream. + * @param input The bounded input data stream. + * @return The result data stream that contains all the distinct values. + */ +public static DataStream distinct(DataStream input) { Review comment: Thanks. Let's close this comment then. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org
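The `distinct` method quoted above works in two phases: a partition-local deduplication at the input parallelism, followed by a global deduplication at parallelism 1. A rough plain-Java model of that two-phase shape (the class and method names are hypothetical, and partitions are modeled as plain lists rather than stream tasks):

```java
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

public class DistinctSketch {
    // Phase 1: deduplicate within a single partition (runs at input parallelism).
    static Set<Long> distinctInPartition(List<Long> partition) {
        return new LinkedHashSet<>(partition);
    }

    // Phase 2: a single final task merges the per-partition sets
    // (mirroring the second transform that runs at parallelism 1).
    public static Set<Long> distinct(List<List<Long>> partitions) {
        Set<Long> merged = new LinkedHashSet<>();
        for (List<Long> partition : partitions) {
            merged.addAll(distinctInPartition(partition));
        }
        return merged;
    }

    public static void main(String[] args) {
        Set<Long> result = distinct(Arrays.asList(
                Arrays.asList(1L, 2L, 2L), Arrays.asList(2L, 3L)));
        System.out.println(result); // prints [1, 2, 3]
    }
}
```

The phase-1 step shrinks each partition before the data is funneled to the single phase-2 task, which is why the operator is applied twice in the quoted code.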
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r761597164 ## File path: flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java ## @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.datastream; + +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeComparator; +import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.TimestampedCollector; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +import org.apache.commons.collections.IteratorUtils; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** Provides utility functions for {@link DataStream}. */ +public class DataStreamUtils { +/** + * Applies allReduceSum on the input data stream. The input data stream is supposed to contain + * one double array in each partition. The result data stream has the same parallelism as the + * input, where each partition contains one double array that sums all of the double arrays in + * the input data stream. + * + * Note that we throw exception when one of the following two cases happen: + * There exists one partition that contains more than one double array. + * The length of the double array is not consistent among all partitions. + * + * @param input The input data stream. + * @return The result data stream. + */ +public static DataStream allReduceSum(DataStream input) { +return AllReduceImpl.allReduceSum(input); +} + +/** + * Collects distinct values in a bounded data stream. 
The parallelism of the output stream is 1. + * + * @param The class type of the input data stream. + * @param input The bounded input data stream. + * @return The result data stream that contains all the distinct values. + */ +public static DataStream distinct(DataStream input) { +return input.transform( +"distinctInEachPartition", +input.getType(), +new DistinctPartitionOperator<>()) +.setParallelism(input.getParallelism()) +.transform( +"distinctInFinalPartition", +input.getType(), +new DistinctPartitionOperator<>()) +.setParallelism(1); +} + +/** + * Applies a {@link MapPartitionFunction} on a bounded data stream. + * + * @param input The input data stream. + * @param func The user defined mapPartition function. + * @param The class type of the input element. + * @param The class type of output element. + * @return The result data stream. + */ +public static DataStream mapPartition( +DataStream input, MapPartitionFunction func) { +TypeInformation resultType = +TypeExtractor.getMapPartitionReturnTypes(func, input.getType(), null, true); +return input.transform("mapPartition", resultType, new MapPartitionOperator<>(func)) +.setParallelism(input.getParallelism()); +} + +/** + * Sorts the elements in each partition of the input bounded data stream. + * + * @param input The input data stream. + * @param comparator The comparator used to sort the elements. + * @param The class type of input element. + * @return The sorted data stream. + */ +p
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r761595720 ## File path: flink-ml-core/src/main/java/org/apache/flink/ml/common/iteration/TerminateOnMaxIter.java ## @@ -27,16 +27,17 @@ * threshold. * * When the output of this FlatMapFunction is used as the termination criteria of an iteration - * body, the iteration will be executed for at most the given `numRounds` rounds. + * body, the iteration will be executed for at most the given `maxIter` rounds. * * @param The class type of the input element. */ -public class TerminateOnMaxIterationNum +public class TerminateOnMaxIter Review comment: Can we remove the `RoundBasedTerminationCriteria` entirely, and use `TerminateOnMaxIter` in the test code instead? "replace A with B" means that we can remove `A` and use `B` where `A` was needed.
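For context on the rename discussed above: `TerminateOnMaxIter` feeds the termination criterion of an iteration, which keeps running only while the criterion stream still produces records, so the loop executes for at most `maxIter` rounds. A plain-Java sketch of that round-capping logic (the class `TerminateOnMaxIterSketch` and its driver loop are illustrative, not the Flink FlatMapFunction itself):

```java
public class TerminateOnMaxIterSketch {
    // The criterion "emits a record" only while the round count is below maxIter;
    // once it emits nothing, the iteration terminates.
    static boolean shouldContinue(int round, int maxIter) {
        return round < maxIter;
    }

    // Drives a loop with the criterion and returns how many rounds actually ran.
    public static int runToTermination(int maxIter) {
        int round = 0;
        while (shouldContinue(round, maxIter)) {
            round++; // one pass of the (hypothetical) iteration body
        }
        return round;
    }

    public static void main(String[] args) {
        System.out.println(runToTermination(5)); // prints 5
    }
}
```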
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r761594701 ## File path: flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java ## @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.datastream; + +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeComparator; +import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.TimestampedCollector; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +import org.apache.commons.collections.IteratorUtils; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** Provides utility functions for {@link DataStream}. */ +public class DataStreamUtils { +/** + * Applies allReduceSum on the input data stream. The input data stream is supposed to contain + * one double array in each partition. The result data stream has the same parallelism as the + * input, where each partition contains one double array that sums all of the double arrays in + * the input data stream. + * + * Note that we throw exception when one of the following two cases happen: + * There exists one partition that contains more than one double array. + * The length of the double array is not consistent among all partitions. + * + * @param input The input data stream. + * @return The result data stream. + */ +public static DataStream allReduceSum(DataStream input) { +return AllReduceImpl.allReduceSum(input); +} + +/** + * Collects distinct values in a bounded data stream. 
The parallelism of the output stream is 1. + * + * @param &lt;T&gt; The class type of the input data stream. + * @param input The bounded input data stream. + * @return The result data stream that contains all the distinct values. + */ +public static &lt;T&gt; DataStream&lt;T&gt; distinct(DataStream&lt;T&gt; input) { Review comment: Thanks. Sounds good. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org
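The allReduceSum contract described in the javadoc above (each partition contributes exactly one double array, all arrays must have the same length, and every partition receives the element-wise sum) can be sketched in plain Java without the Flink runtime. The class and method names below are illustrative only, not part of the flink-ml API:

```java
import java.util.Arrays;
import java.util.List;

// Illustrative sketch of the allReduceSum semantics: every partition
// contributes one double[] of identical length, and each partition
// then observes the element-wise sum of all contributions.
public class AllReduceSumSketch {

    static double[] allReduceSum(List<double[]> partitions) {
        int length = partitions.get(0).length;
        double[] sum = new double[length];
        for (double[] partition : partitions) {
            if (partition.length != length) {
                // Mirrors the documented error case: inconsistent array lengths.
                throw new IllegalArgumentException(
                        "Array lengths differ across partitions");
            }
            for (int i = 0; i < length; i++) {
                sum[i] += partition[i];
            }
        }
        // In the real operator, this same array is delivered to every partition.
        return sum;
    }

    public static void main(String[] args) {
        double[] result = allReduceSum(Arrays.asList(
                new double[] {1.0, 2.0}, new double[] {3.0, 4.0}));
        System.out.println(Arrays.toString(result)); // [4.0, 6.0]
    }
}
```

The flink-ml version performs the same reduction over network partitions inside a streaming operator; this sketch only pins down the input/output contract.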
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r760717691 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegression.java ## @@ -0,0 +1,488 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification.linear; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.IterationBody; +import org.apache.flink.iteration.IterationBodyResult; +import org.apache.flink.iteration.IterationConfig; +import org.apache.flink.iteration.IterationConfig.OperatorLifeCycle; +import org.apache.flink.iteration.IterationListener; +import org.apache.flink.iteration.Iterations; +import org.apache.flink.iteration.ReplayableDataStreamList; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.common.feature.LabeledPointWithWeight; +import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import 
org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.util.Collector; +import org.apache.flink.util.OutputTag; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.collections.IteratorUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; + +/** + * This class implements methods to train a logistic regression model. For details, see + * https://en.wikipedia.org/wiki/Logistic_regression. + */ +public class LogisticRegression +implements Estimator, +LogisticRegressionParams { + +private Map, Object> paramMap = new HashMap<>(); + +private static final OutputTag MODEL_OUTPUT = +new OutputTag("MODEL_OUTPUT") {}; + +private static final OutputTag LOSS_OUTPUT = +new OutputTag("LOSS_OUTPUT") {}; + +public LogisticRegression() { +ParamUtils.initializeMapWithDefaultValues(this.paramMap, this); +} + +@Override +public Map, Object> getParamMap() { +return paramMap; +} + +@Override +public void save(String path) throws IOException { +ReadWriteUtils.saveMetadata(this, path); +} + +public static LogisticRegression load(StreamExecutionEnvironment env, String path) +throws IOException { +return ReadWriteUtils.loadStageParam(path); +} + +@Override +@SuppressWarnings("rawTypes") +public LogisticRegressionModel fit(Table... inputs) { +Preconditions.checkArgument(inputs.length == 1); +Preconditions.checkArgument( +"auto".equals(g
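The class above trains a logistic regression model iteratively. Independent of the Flink plumbing, the per-sample math is the standard weighted gradient step for log loss (see the Wikipedia article linked in the javadoc). A minimal sketch, assuming labels in {0, 1}; the class and method names are hypothetical, not part of flink-ml:

```java
// Hedged sketch of one weighted gradient-descent step for logistic
// regression with labels in {0, 1}. flink-ml distributes and batches
// this computation; the per-sample update itself is the same.
public class LogisticGradientSketch {

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Updates the coefficient vector in place from one
    // (features, label, weight) sample.
    static void sgdStep(double[] coeffs, double[] features, double label,
                        double weight, double learningRate) {
        double dot = 0.0;
        for (int i = 0; i < coeffs.length; i++) {
            dot += coeffs[i] * features[i];
        }
        // Gradient of the log loss w.r.t. the linear score.
        double error = sigmoid(dot) - label;
        for (int i = 0; i < coeffs.length; i++) {
            coeffs[i] -= learningRate * weight * error * features[i];
        }
    }

    public static void main(String[] args) {
        double[] coeffs = new double[] {0.0, 0.0};
        sgdStep(coeffs, new double[] {1.0, 2.0}, 1.0, 1.0, 0.1);
        // With zero coefficients, sigmoid(0) = 0.5, so error = -0.5
        // and the coefficients move to [0.05, 0.1].
        System.out.println(coeffs[0] + " " + coeffs[1]);
    }
}
```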
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r759854126 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegression.java ## @@ -0,0 +1,653 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification.linear; + +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.base.DoubleComparator; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.TupleTypeInfo; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.IterationBody; +import org.apache.flink.iteration.IterationBodyResult; +import org.apache.flink.iteration.IterationConfig; +import org.apache.flink.iteration.IterationConfig.OperatorLifeCycle; +import org.apache.flink.iteration.IterationListener; +import org.apache.flink.iteration.Iterations; +import org.apache.flink.iteration.ReplayableDataStreamList; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; 
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedMultiInput; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.util.Collector; +import org.apache.flink.util.OutputTag; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.collections.IteratorUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; + +/** + * This class implements methods to train a logistic regression model. For details, see + * https://en.wikipedia.org/wiki/Logistic_regression. + */ +public class LogisticRegression +implements Estimator, +LogisticRegressionParams { + +private Map, Object> paramMap = new HashMap<>(); + +private static final OutputTag> MODEL_OUTPUT = +new OutputTag>("MODEL_OUTPUT") {}; + +public LogisticRegression() { +ParamUtils.initializeMapWithDefaultValues(this.paramMap, this); +} + +@Override +public Map, Object> getParamMap() { +return paramMap; +} + +@Override +public void save(String path) throws IOException { +ReadWriteUtils.saveMetadata(this, path); +} + +public static LogisticRegression load(StreamExecutionEnvironment env, String path) +throws IOException { +return ReadWriteUtils.load
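The iteration body above terminates via TerminateOnMaxIterOrTol. One plausible reading of that name (an assumption, since the comment thread does not show the class body) is: stop once the iteration count reaches maxIter, or once the loss reaches the tolerance, whichever comes first. A self-contained sketch of that predicate, with illustrative names:

```java
// Hedged sketch of the termination logic suggested by the name
// TerminateOnMaxIterOrTol. The real flink-ml class evaluates its
// criterion inside the iteration framework rather than as a plain
// static method.
public class TerminationSketch {

    // Terminate when either the iteration budget is exhausted
    // or the loss has dropped to the tolerance.
    static boolean shouldTerminate(int iter, double loss, int maxIter, double tol) {
        return iter >= maxIter || loss <= tol;
    }

    public static void main(String[] args) {
        System.out.println(shouldTerminate(100, 0.5, 100, 1e-3)); // true: hit maxIter
        System.out.println(shouldTerminate(10, 1e-4, 100, 1e-3)); // true: below tol
        System.out.println(shouldTerminate(10, 0.5, 100, 1e-3));  // false: keep going
    }
}
```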
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r759853261 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegression.java
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r759035730 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegression.java
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r759030707 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegression.java
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r759029510 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegression.java
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r759028197 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegression.java ## @@ -0,0 +1,653 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification.linear; + +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.base.DoubleComparator; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.TupleTypeInfo; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.IterationBody; +import org.apache.flink.iteration.IterationBodyResult; +import org.apache.flink.iteration.IterationConfig; +import org.apache.flink.iteration.IterationConfig.OperatorLifeCycle; +import org.apache.flink.iteration.IterationListener; +import org.apache.flink.iteration.Iterations; +import org.apache.flink.iteration.ReplayableDataStreamList; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; 
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedMultiInput; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.util.Collector; +import org.apache.flink.util.OutputTag; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.collections.IteratorUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; + +/** + * This class implements methods to train a logistic regression model. For details, see + * https://en.wikipedia.org/wiki/Logistic_regression. + */ +public class LogisticRegression +implements Estimator<LogisticRegression, LogisticRegressionModel>, +LogisticRegressionParams<LogisticRegression> { + +private Map<Param<?>, Object> paramMap = new HashMap<>(); + +private static final OutputTag> MODEL_OUTPUT = +new OutputTag>("MODEL_OUTPUT") {}; + +public LogisticRegression() { +ParamUtils.initializeMapWithDefaultValues(this.paramMap, this); +} + +@Override +public Map<Param<?>, Object> getParamMap() { +return paramMap; +} + +@Override +public void save(String path) throws IOException { +ReadWriteUtils.saveMetadata(this, path); +} + +public static LogisticRegression load(StreamExecutionEnvironment env, String path) +throws IOException { +return ReadWriteUtils.load
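The quoted class trains a logistic regression model (see the Wikipedia link in its Javadoc). As a plain-Java illustration of the prediction rule such a model learns — a hedged sketch, not code from the PR, and independent of the Flink operators above — the sigmoid of the dot product between the model coefficients and the feature vector gives the positive-class probability:

```java
// Minimal sketch of the logistic regression prediction rule the quoted
// class trains. Plain Java; the Flink version operates on DenseVector
// and DataStream instead of raw arrays.
public class LogisticSketch {

    // sigmoid(w . x): probability of the positive class.
    static double predictProbability(double[] coefficients, double[] features) {
        double dot = 0.0;
        for (int i = 0; i < coefficients.length; i++) {
            dot += coefficients[i] * features[i];
        }
        return 1.0 / (1.0 + Math.exp(-dot));
    }

    public static void main(String[] args) {
        double[] w = {1.0, -2.0};
        double[] x = {0.5, 0.25};
        // dot = 0.5 - 0.5 = 0, sigmoid(0) = 0.5
        System.out.println(predictProbability(w, x));
    }
}
```

Training searches for the coefficients that minimize the logistic loss over the labeled input; the iteration machinery imported above (Iterations, IterationBody, allReduceSum) distributes that search across partitions.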
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r758920274 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegression.java ## @@ -0,0 +1,653 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification.linear; + +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.base.DoubleComparator; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.TupleTypeInfo; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.IterationBody; +import org.apache.flink.iteration.IterationBodyResult; +import org.apache.flink.iteration.IterationConfig; +import org.apache.flink.iteration.IterationConfig.OperatorLifeCycle; +import org.apache.flink.iteration.IterationListener; +import org.apache.flink.iteration.Iterations; +import org.apache.flink.iteration.ReplayableDataStreamList; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; 
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedMultiInput; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.util.Collector; +import org.apache.flink.util.OutputTag; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.collections.IteratorUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; + +/** + * This class implements methods to train a logistic regression model. For details, see + * https://en.wikipedia.org/wiki/Logistic_regression. + */ +public class LogisticRegression +implements Estimator<LogisticRegression, LogisticRegressionModel>, +LogisticRegressionParams<LogisticRegression> { + +private Map<Param<?>, Object> paramMap = new HashMap<>(); + +private static final OutputTag> MODEL_OUTPUT = +new OutputTag>("MODEL_OUTPUT") {}; + +public LogisticRegression() { +ParamUtils.initializeMapWithDefaultValues(this.paramMap, this); +} + +@Override +public Map<Param<?>, Object> getParamMap() { +return paramMap; +} + +@Override +public void save(String path) throws IOException { +ReadWriteUtils.saveMetadata(this, path); +} + +public static LogisticRegression load(StreamExecutionEnvironment env, String path) +throws IOException { +return ReadWriteUtils.load
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r758946526 ## File path: flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java ## @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.datastream; + +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeComparator; +import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.ml.common.utils.ComparatorAdapter; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.TimestampedCollector; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +import org.apache.commons.collections.IteratorUtils; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** Provides utility functions for {@link DataStream}. */ +public class DataStreamUtils { +/** + * Applies allReduceSum on the input data stream. The input data stream is supposed to contain + * one double array in each partition. The result data stream has the same parallelism as the + * input, where each partition contains one double array that sums all of the double arrays in + * the input data stream. + * + * Note that we throw exception when one of the following two cases happen: + * There exists one partition that contains more than one double array. + * The length of the double array is not consistent among all partitions. + * + * @param input The input data stream. + * @return The result data stream. 
+ */ +public static DataStream<double[]> allReduceSum(DataStream<double[]> input) { +return AllReduceImpl.allReduceSum(input); +} + +/** + * Collects distinct values in a bounded data stream. The parallelism of the output stream is 1. + * + * @param <T> The class type of the input data stream. + * @param input The bounded input data stream. + * @return The result data stream that contains all the distinct values. + */ +public static <T> DataStream<T> distinct(DataStream<T> input) { +return input.transform( +"distinctInEachPartition", +input.getType(), +new DistinctPartitionOperator<>()) +.setParallelism(input.getParallelism()) +.transform( +"distinctInFinalPartition", +input.getType(), +new DistinctPartitionOperator<>()) +.setParallelism(1); +} + +/** + * Applies a {@link MapPartitionFunction} on a bounded data stream. + * + * @param input The input data stream. + * @param func The user defined mapPartition function. + * @param <IN> The class type of the input element. + * @param <OUT> The class type of output element. + * @return The result data stream. + */ +public static <IN, OUT> DataStream<OUT> mapPartition( +DataStream<IN> input, MapPartitionFunction<IN, OUT> func) { +TypeInformation<OUT> resultType = +TypeExtractor.getMapPartitionReturnTypes(func, input.getType(), null, true); +return input.transform("mapPartition", resultType, new MapPartitionOperator<>(func)) +.setParallelism(input.getParallelism()); +} + +/** + * Sorts the elements in each partition of the input bounded data stream. + * + * @param input The input data stream. + * @param comparator The comparator used to sort the elements. + * @param The class type of input elem
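The `allReduceSum` Javadoc above fully specifies its contract: each partition holds exactly one double array, all arrays must have the same length, and every partition receives the element-wise sum. A plain-Java sketch of that contract (the semantics only, not the PR's `AllReduceImpl`, which runs distributed over a DataStream):

```java
import java.util.Arrays;
import java.util.List;

// Plain-Java model of the allReduceSum semantics described in the Javadoc:
// element-wise sum of one double array per partition, with the length
// consistency check the Javadoc promises.
public class AllReduceSketch {

    static double[] allReduceSum(List<double[]> partitions) {
        int len = partitions.get(0).length;
        double[] sum = new double[len];
        for (double[] partition : partitions) {
            if (partition.length != len) {
                // The Javadoc states that inconsistent lengths raise an exception.
                throw new IllegalStateException("Array lengths differ across partitions");
            }
            for (int i = 0; i < len; i++) {
                sum[i] += partition[i];
            }
        }
        // In the Flink version, this sum is delivered back to every partition,
        // which is what makes it usable for synchronizing gradients.
        return sum;
    }

    public static void main(String[] args) {
        List<double[]> parts = Arrays.asList(new double[] {1, 2}, new double[] {3, 4});
        System.out.println(Arrays.toString(allReduceSum(parts))); // [4.0, 6.0]
    }
}
```

This is the primitive `LogisticRegression` relies on inside its iteration body: each partition computes a local gradient array, and allReduceSum turns the local arrays into the global gradient on every worker.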
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r758922576 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegression.java ## @@ -0,0 +1,653 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification.linear; + +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.base.DoubleComparator; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.TupleTypeInfo; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.IterationBody; +import org.apache.flink.iteration.IterationBodyResult; +import org.apache.flink.iteration.IterationConfig; +import org.apache.flink.iteration.IterationConfig.OperatorLifeCycle; +import org.apache.flink.iteration.IterationListener; +import org.apache.flink.iteration.Iterations; +import org.apache.flink.iteration.ReplayableDataStreamList; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; 
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedMultiInput; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.util.Collector; +import org.apache.flink.util.OutputTag; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.collections.IteratorUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; + +/** + * This class implements methods to train a logistic regression model. For details, see + * https://en.wikipedia.org/wiki/Logistic_regression. + */ +public class LogisticRegression +implements Estimator<LogisticRegression, LogisticRegressionModel>, +LogisticRegressionParams<LogisticRegression> { + +private Map<Param<?>, Object> paramMap = new HashMap<>(); + +private static final OutputTag> MODEL_OUTPUT = +new OutputTag>("MODEL_OUTPUT") {}; + +public LogisticRegression() { +ParamUtils.initializeMapWithDefaultValues(this.paramMap, this); +} + +@Override +public Map<Param<?>, Object> getParamMap() { +return paramMap; +} + +@Override +public void save(String path) throws IOException { +ReadWriteUtils.saveMetadata(this, path); +} + +public static LogisticRegression load(StreamExecutionEnvironment env, String path) +throws IOException { +return ReadWriteUtils.load
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r758909062 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegression.java ## @@ -0,0 +1,653 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification.linear; + +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.base.DoubleComparator; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.TupleTypeInfo; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.IterationBody; +import org.apache.flink.iteration.IterationBodyResult; +import org.apache.flink.iteration.IterationConfig; +import org.apache.flink.iteration.IterationConfig.OperatorLifeCycle; +import org.apache.flink.iteration.IterationListener; +import org.apache.flink.iteration.Iterations; +import org.apache.flink.iteration.ReplayableDataStreamList; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; 
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedMultiInput; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.util.Collector; +import org.apache.flink.util.OutputTag; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.collections.IteratorUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; + +/** + * This class implements methods to train a logistic regression model. For details, see + * https://en.wikipedia.org/wiki/Logistic_regression. + */ +public class LogisticRegression +implements Estimator<LogisticRegression, LogisticRegressionModel>, +LogisticRegressionParams<LogisticRegression> { + +private Map<Param<?>, Object> paramMap = new HashMap<>(); + +private static final OutputTag> MODEL_OUTPUT = +new OutputTag>("MODEL_OUTPUT") {}; + +public LogisticRegression() { +ParamUtils.initializeMapWithDefaultValues(this.paramMap, this); +} + +@Override +public Map<Param<?>, Object> getParamMap() { +return paramMap; +} + +@Override +public void save(String path) throws IOException { +ReadWriteUtils.saveMetadata(this, path); +} + +public static LogisticRegression load(StreamExecutionEnvironment env, String path) +throws IOException { +return ReadWriteUtils.load
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r758878889 ## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/classification/linear/LogisticRegressionTest.java ## @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification.linear; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.sink.SinkFunction; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; + +import org.junit.Before; +import org.junit.Test; + +import java.nio.file.Files; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +/** Tests {@link LogisticRegression} and {@link LogisticRegressionModel}. */ +public class LogisticRegressionTest { + +private StreamExecutionEnvironment env; + +private StreamTableEnvironment tEnv; + +private static List trainData = +Arrays.asList( +Row.of(new double[] {1, 2, 3, 4}, -1., 1.), Review comment: Since `LogisticRegression` supports `weights`, would it be better to use different weights in the test data to provide more test coverage? ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegression.java ## @@ -0,0 +1,653 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.linear; + +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.base.DoubleComparator; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.TupleTypeInfo; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.IterationBody; +import org.apache.flink.iteration.IterationBodyResult; +import org.apache.flink.iteration.IterationConfig; +import org.apache.flink.iteration.IterationConfig.OperatorLifeCycle; +import org.apache.flink.iteration.IterationListener; +import org.apache.flink.iteration.Iterations; +import org.apache.flink.iteration.ReplayableDataStreamList; +import org.apache.flink.iteration.operator.OperatorStateUt
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r758414290 ## File path: flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java ## @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.linalg; + +import org.apache.flink.util.Preconditions; + +/** A utility class that provides BLAS routines over matrices and vectors. */ +public class BLAS { + +private static final dev.ludovic.netlib.BLAS NATIVE_BLAS = +dev.ludovic.netlib.BLAS.getInstance(); + +/** + * \sum_i |x_i| . + * + * @param x x + * @return + */ +public static double asum(double[] x) { +return NATIVE_BLAS.dasum(x.length, x, 0, 1); +} + +/** + * y += a * x . + * + * @param a a + * @param x x + * @param y y + */ +public static void axpy(double a, double[] x, double[] y) { Review comment: In the future will we need to support `axpy(a: Double, x: SparseVector, y: DenseVector)`? If so, should we make this `axpy(a: Double, x: DenseVector, y: DenseVector)`? Same for the methods in BLAS. 
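The vector-typed overloads suggested in the comment are easy to sketch. Below is a minimal, self-contained illustration of what `axpy(double, DenseVector, DenseVector)` and a future sparse-dense variant could look like; `DenseVector` and `SparseVector` here are simplified stand-ins for the flink-ml classes, not the real API.

```java
import java.util.Arrays;

// Simplified stand-ins for the flink-ml vector types; the real classes live
// in org.apache.flink.ml.linalg and carry more state and validation.
class DenseVector {
    final double[] values;

    DenseVector(double[] values) {
        this.values = values;
    }
}

class SparseVector {
    final int size;
    final int[] indices;   // positions of the non-zero entries, sorted
    final double[] values; // non-zero entries, aligned with indices

    SparseVector(int size, int[] indices, double[] values) {
        this.size = size;
        this.indices = indices;
        this.values = values;
    }
}

public class BlasSketch {

    /** y += a * x, dense-dense: a straight loop over every slot. */
    static void axpy(double a, DenseVector x, DenseVector y) {
        for (int i = 0; i < x.values.length; i++) {
            y.values[i] += a * x.values[i];
        }
    }

    /** y += a * x, sparse-dense: only touches the non-zero entries of x. */
    static void axpy(double a, SparseVector x, DenseVector y) {
        for (int i = 0; i < x.indices.length; i++) {
            y.values[x.indices[i]] += a * x.values[i];
        }
    }

    public static void main(String[] args) {
        DenseVector y = new DenseVector(new double[] {1, 1, 1, 1});
        axpy(2.0, new DenseVector(new double[] {1, 0, 0, 1}), y);
        System.out.println(Arrays.toString(y.values)); // [3.0, 1.0, 1.0, 3.0]

        axpy(1.0, new SparseVector(4, new int[] {2}, new double[] {5}), y);
        System.out.println(Arrays.toString(y.values)); // [3.0, 1.0, 6.0, 3.0]
    }
}
```

Overloading on the vector types keeps call sites unchanged if sparse inputs are added later, which is the upside of the signature change suggested above over the bare `double[]` version.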
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r758296725 ## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/classification/linear/LogisticRegressionTest.java ## @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification.linear; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.sink.SinkFunction; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; + +import org.junit.Before; +import org.junit.Test; + +import java.nio.file.Files; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +/** Tests {@link LogisticRegression} and {@link LogisticRegressionModel}. 
*/ +public class LogisticRegressionTest { + +private StreamExecutionEnvironment env; + +private StreamTableEnvironment tEnv; + +private static List trainData = +Arrays.asList( +Row.of(new double[] {1, 2, 3, 4}, -1., 1.), +Row.of(new double[] {1, 2, 3, 4}, -1., 1.), +Row.of(new double[] {1, 2, 3, 4}, -1., 1.), +Row.of(new double[] {1, 2, 3, 4}, -1., 1.), +Row.of(new double[] {1, 2, 3, 4}, -1., 1.), +Row.of(new double[] {10, 2, 3, 4}, 1., 1.), +Row.of(new double[] {10, 2, 3, 4}, 1., 1.), +Row.of(new double[] {10, 2, 3, 4}, 1., 1.), +Row.of(new double[] {10, 2, 3, 4}, 1., 1.), +Row.of(new double[] {10, 2, 3, 4}, 1., 1.)); + +private static double[] expectedCoefficient = new double[] {-0.67, 0.21, 0.32, 0.43}; Review comment: nits: `private static final` ## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/classification/linear/LogisticRegressionTest.java ## @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification.linear; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r757255599 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java ## @@ -146,7 +145,7 @@ public IterationBodyResult process( DataStream points = dataStreams.get(0); DataStream terminationCriteria = -centroids.flatMap(new TerminateOnMaxIterationNum<>(maxIterationNum)); +centroids.map(x -> 0.).flatMap(new TerminationCriteria(maxIterationNum)); Review comment: How about we have `TerminateOnMaxIter` and `TerminateOnMaxIterOrTol`? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r757247408 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModel.java ## @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification.linear; + +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.connector.source.Source; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.connector.file.sink.FileSink; +import org.apache.flink.connector.file.src.FileSource; +import org.apache.flink.core.fs.Path; +import org.apache.flink.ml.api.core.Model; +import org.apache.flink.ml.classification.linear.LogisticRegressionModelData.LogisticRegressionModelDataEncoder; +import org.apache.flink.ml.classification.linear.LogisticRegressionModelData.LogisticRegressionModelDataStreamFormat; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.linalg.BLAS; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner; +import org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy; +import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; + +import 
java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** This class implements {@link Model} for {@link LogisticRegression}. */ +public class LogisticRegressionModel +implements Model<LogisticRegressionModel>, +LogisticRegressionModelParams<LogisticRegressionModel> { + +private Map<Param<?>, Object> paramMap = new HashMap<>(); + +private Table model; + +public LogisticRegressionModel() { +ParamUtils.initializeMapWithDefaultValues(this.paramMap, this); +} + +@Override +public Map<Param<?>, Object> getParamMap() { +return paramMap; +} + +@Override +public void save(String path) throws IOException { +StreamTableEnvironment tEnv = +(StreamTableEnvironment) ((TableImpl) model).getTableEnvironment(); +String dataPath = ReadWriteUtils.getDataPath(path); +FileSink<LogisticRegressionModelData> sink = +FileSink.forRowFormat(new Path(dataPath), new LogisticRegressionModelDataEncoder()) +.withRollingPolicy(OnCheckpointRollingPolicy.build()) +.withBucketAssigner(new BasePathBucketAssigner<>()) +.build(); +ReadWriteUtils.saveMetadata(this, path); +tEnv.toDataStream(model) +.map(x -> (LogisticRegressionModelData) x.getField(0)) +.sinkTo(sink) +.setParallelism(1); +} + +public static LogisticRegressionModel load(StreamExecutionEnvironment env, String path) +throws IOException { +StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); +Source source = +FileSource.forRecordStreamFormat( +new LogisticRegressionModelDataStreamFormat(), +ReadWriteUtils.getDataPaths(path)) +.build(); +LogisticRegressionModel model
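The `save()` above splits a model into metadata plus a data file written through a `FileSink`. As a rough standalone sketch of that on-disk layout: one small metadata file next to a binary coefficients file. The file names and formats here are illustrative only, not what `ReadWriteUtils` and the `FileSink` actually produce.

```java
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.nio.file.Path;

// Standalone sketch of the metadata/data split used by save() and load():
// metadata in one small file, model coefficients in a separate data file.
public class ModelIoSketch {

    static void save(Path dir, double[] coefficients) throws IOException {
        Files.createDirectories(dir);
        // Hypothetical metadata payload; the real metadata is written by
        // ReadWriteUtils.saveMetadata and contains more fields.
        Files.writeString(dir.resolve("metadata.json"),
                "{\"className\":\"LogisticRegressionModel\"}");
        ByteBuffer buf = ByteBuffer.allocate(coefficients.length * Double.BYTES);
        for (double c : coefficients) {
            buf.putDouble(c);
        }
        Files.write(dir.resolve("data"), buf.array());
    }

    static double[] load(Path dir) throws IOException {
        ByteBuffer buf = ByteBuffer.wrap(Files.readAllBytes(dir.resolve("data")));
        double[] coefficients = new double[buf.remaining() / Double.BYTES];
        for (int i = 0; i < coefficients.length; i++) {
            coefficients[i] = buf.getDouble();
        }
        return coefficients;
    }

    public static void main(String[] args) throws IOException {
        Path dir = Files.createTempDirectory("lr-model");
        save(dir, new double[] {-0.67, 0.21, 0.32, 0.43});
        System.out.println(load(dir).length); // 4
    }
}
```

Keeping metadata separate from the streamed model data is what lets `load()` rebuild the stage parameters before the data stream is wired up.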
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r757252256 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticGradient.java ## @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.linear; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.ml.common.linalg.BLAS; + +import java.io.Serializable; + +/** Utility class to compute gradient and loss for logistic loss. */ Review comment: The link provided by Spark links to slides, which is not very easy to glance through. It would be sufficient if we can have a link to Wikipedia with the necessary mathematical formulas. I am OK to just keep the slideshare link for now.
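For reference, the formulas behind `LogisticGradient` can be written down directly. A minimal sketch of the weighted logistic loss and its gradient for labels in {-1, +1}, using plain `double[]` arrays instead of the flink-ml vector types:

```java
// Sketch of the weighted logistic loss and its gradient for labels in
// {-1, +1}. Plain arrays stand in for the flink-ml vector types.
public class LogisticGradientSketch {

    static double dot(double[] a, double[] b) {
        double s = 0;
        for (int i = 0; i < a.length; i++) {
            s += a[i] * b[i];
        }
        return s;
    }

    /** loss = weight * log(1 + exp(-label * <coef, x>)) */
    static double loss(double[] coef, double[] x, double label, double weight) {
        return weight * Math.log(1 + Math.exp(-label * dot(coef, x)));
    }

    /** Adds this point's gradient contribution into cumGradient. */
    static void gradient(
            double[] coef, double[] x, double label, double weight, double[] cumGradient) {
        // d/dw log(1 + exp(-y * <w, x>)) = -y * sigmoid(-y * <w, x>) * x
        double multiplier = weight * (-label) / (1 + Math.exp(label * dot(coef, x)));
        for (int i = 0; i < x.length; i++) {
            cumGradient[i] += multiplier * x[i];
        }
    }

    public static void main(String[] args) {
        double[] coef = new double[] {0, 0, 0, 0};
        double[] x = new double[] {1, 2, 3, 4};
        // With a zero coefficient vector the loss is weight * log(2).
        System.out.println(loss(coef, x, 1.0, 1.0));
    }
}
```

This is exactly the kind of derivation a Wikipedia link would cover; the slideshare deck makes it harder to cross-check the signs and the weight factor.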
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r757245219 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java ## @@ -146,7 +145,7 @@ public IterationBodyResult process( DataStream points = dataStreams.get(0); DataStream terminationCriteria = -centroids.flatMap(new TerminateOnMaxIterationNum<>(maxIterationNum)); +centroids.map(x -> 0.).flatMap(new TerminationCriteria(maxIterationNum)); Review comment: IMO, for users who want to terminate the iteration based on the round number, asking them to map the input to double-typed values seems unnecessary and might confuse them. And if we have two separate classes such as `TerminateOnMaxIterationNum` and `TerminateOnToleranceThreshold`, given that users already have the concepts of these two termination patterns, the class names seem to be pretty self-explanatory and intuitive. What do you think? Or maybe we can wait for @gaoyunhaii's comment?
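The two termination patterns under discussion can be sketched without any Flink machinery. The names below mirror the suggested `TerminateOnMaxIter`/`TerminateOnMaxIterOrTol` split, but this is plain Java, not the flink-ml-iteration operators:

```java
// Sketch of the two termination patterns: stop after a fixed number of
// rounds, or stop early when the loss improvement drops below a tolerance.
public class TerminationSketch {

    interface Criterion {
        boolean shouldContinue(int epoch, double loss);
    }

    /** Stops purely on the round number; the loss argument is ignored. */
    static Criterion terminateOnMaxIter(int maxIter) {
        return (epoch, loss) -> epoch < maxIter;
    }

    /** Stops on the round number or when the loss improvement falls below tol. */
    static Criterion terminateOnMaxIterOrTol(int maxIter, double tol) {
        return new Criterion() {
            private double lastLoss = Double.MAX_VALUE;

            @Override
            public boolean shouldContinue(int epoch, double loss) {
                boolean converged = lastLoss - loss < tol;
                lastLoss = loss;
                return epoch < maxIter && !converged;
            }
        };
    }

    public static void main(String[] args) {
        Criterion c = terminateOnMaxIterOrTol(100, 0.1);
        System.out.println(c.shouldContinue(0, 10.0)); // true: large improvement
        System.out.println(c.shouldContinue(1, 9.99)); // false: improved by < tol
    }
}
```

The round-number-only variant never looks at the feedback values, which is why forcing callers to map their input to doubles first (as in the diff above) buys nothing for that case.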
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r757242053 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/datastream/SortPartitionImpl.java ## @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.datastream; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeutils.TypeComparator; +import org.apache.flink.ml.common.utils.ComparatorAdapter; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +import org.apache.commons.collections.IteratorUtils; + +import java.util.List; + +/** Applies sortPartition to a bounded data stream. 
*/ +class SortPartitionImpl { Review comment: The problem you described here is common to most projects (e.g. Flink). Can we follow the existing practice in e.g. Flink instead of inventing new patterns? Is there any util class in popular projects like Flink/Kafka/Spark that explicitly moves the implementation of its static methods to individual `*impl` classes?
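Independent of where the class ends up living, the sortPartition pattern itself is simple: buffer the bounded partition, then sort and emit once the input ends. A plain-Java sketch of that shape (the real operator implements `BoundedOneInput` and keeps the buffer in Flink `ListState` so it survives checkpoints):

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Plain-Java sketch of sortPartition over a bounded stream: buffer every
// element, then sort once when the input is exhausted.
public class SortPartitionSketch<T> {

    private final List<T> buffer = new ArrayList<>();
    private final Comparator<T> comparator;

    SortPartitionSketch(Comparator<T> comparator) {
        this.comparator = comparator;
    }

    /** Analogous to OneInputStreamOperator#processElement: just buffer. */
    void processElement(T value) {
        buffer.add(value);
    }

    /** Analogous to BoundedOneInput#endInput: sort and emit in one pass. */
    List<T> endInput() {
        buffer.sort(comparator);
        return buffer;
    }

    public static void main(String[] args) {
        SortPartitionSketch<Integer> op =
                new SortPartitionSketch<>(Comparator.naturalOrder());
        op.processElement(3);
        op.processElement(1);
        op.processElement(2);
        System.out.println(op.endInput()); // [1, 2, 3]
    }
}
```

Because the whole body fits in a couple of methods, folding it into `DataStreamUtils` next to its public entry point (as the comment suggests) costs little.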
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r757240950 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMaxIter.java ## @@ -26,7 +26,7 @@ /** Interface for the shared maxIter param. */ public interface HasMaxIter<T> extends WithParams<T> { Param<Integer> MAX_ITER = -new IntParam("maxIter", "Maximum number of iterations.", 20, ParamValidators.gtEq(0)); +new IntParam("maxIter", "Maximum number of iterations.", 20, ParamValidators.gt(0)); Review comment: OK. Then let's keep it as is.
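The behavioral difference between `gt(0)` and `gtEq(0)` is small but user-visible. A standalone sketch of the two validators (similar in spirit to, but not the actual, `ParamValidators` API):

```java
import java.util.function.Predicate;

// Standalone sketch of the gt/gtEq validators; flink-ml's ParamValidators
// exposes similarly named factories, but this is not its actual API.
public class ValidatorSketch {

    static Predicate<Integer> gt(int bound) {
        return value -> value > bound;
    }

    static Predicate<Integer> gtEq(int bound) {
        return value -> value >= bound;
    }

    public static void main(String[] args) {
        // gt(0) rejects maxIter = 0 while gtEq(0) accepts it -- exactly the
        // behavioral difference this thread is deciding on.
        System.out.println(gt(0).test(0));   // false
        System.out.println(gtEq(0).test(0)); // true
    }
}
```

With `gtEq(0)`, every algorithm would have to define what a zero-iteration fit returns, which is why the thread settles on `gt(0)`.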
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r757008784 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMaxIter.java ## @@ -26,7 +26,7 @@ /** Interface for the shared maxIter param. */ public interface HasMaxIter<T> extends WithParams<T> { Param<Integer> MAX_ITER = -new IntParam("maxIter", "Maximum number of iterations.", 20, ParamValidators.gtEq(0)); +new IntParam("maxIter", "Maximum number of iterations.", 20, ParamValidators.gt(0)); Review comment: Could we just update all algorithms to support `maxIter=0` instead of disallowing it? ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/datastream/SortPartitionImpl.java ## @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.flink.ml.common.datastream; + +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeutils.TypeComparator; +import org.apache.flink.ml.common.utils.ComparatorAdapter; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +import org.apache.commons.collections.IteratorUtils; + +import java.util.List; + +/** Applies sortPartition to a bounded data stream. */ +class SortPartitionImpl { Review comment: Having one dedicated file for a package-private static class seems a bit overkill. This is a rare pattern in Flink. The typical approach is to put such classes in the same file as the public static method that uses them (e.g. DataStreamUtil). It might make sense to have a dedicated file if the number of static classes/methods is more than 4 (e.g. AllReduce). Maybe it is simpler to move `SortPartitionImpl/DistinctImpl/MapPartitionImpl` to DataStreamUtils. ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java ## @@ -146,7 +145,7 @@ public IterationBodyResult process( DataStream points = dataStreams.get(0); DataStream terminationCriteria = -centroids.flatMap(new TerminateOnMaxIterationNum<>(maxIterationNum)); +centroids.map(x -> 0.).flatMap(new TerminationCriteria(maxIterationNum)); Review comment: Would it be better to still allow `TerminationCriteria` to accept an arbitrary input type so that the caller code does not have to explicitly convert the input type to Integer?
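The suggestion that `TerminationCriteria` accept an arbitrary input type can be sketched with a type parameter the criterion never inspects, so callers need no `map(x -> 0.)` conversion. This is a plain-Java illustration with hypothetical names, not the Flink iteration operator:

```java
import java.util.Optional;

/**
 * Sketch of a termination criterion that is generic in the element type, so any
 * stream can feed it directly. Hypothetical class; the real operator differs.
 */
class TerminateOnMaxIter<T> {
    private final int maxIterationNum;
    private int round = 0;

    TerminateOnMaxIter(int maxIterationNum) {
        this.maxIterationNum = maxIterationNum;
    }

    /** Ignores the element's value entirely; only the per-round call count matters. */
    Optional<Integer> onEpochWatermark(T ignoredElement) {
        round++;
        // Emit a feedback record (here, the round number) while more rounds remain.
        return round < maxIterationNum ? Optional.of(round) : Optional.empty();
    }
}
```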
## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/datastream/EndOfStreamWindows.java ## @@ -37,6 +37,8 @@ private static final EndOfStreamWindows INSTANCE = new EndOfStreamWindows(); +private static final TimeWindow FOREVER_WINDOW = new TimeWindow(Long.MIN_VALUE, Long.MAX_VALUE); Review comment: The word `forever` seems a bit `rich` and it is rare to use this word as a variable name. Can we use a more neutral word such as `TIME_WINDOW_INSTANCE`? The variable name does not need to be very descriptive here since its role is obvious from its constructor parameter values. ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticGradient.java ## @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r757005360 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasBatchSize.java ## @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.param; + +import org.apache.flink.ml.param.IntParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.WithParams; + +/** Interface for the shared batchSize param. */ +public interface HasBatchSize extends WithParams { + +Param BATCH_SIZE = +new IntParam( +"batchSize", "Batch size of training algorithms.", 100, ParamValidators.gt(0)); Review comment: Hmm... in addition to renaming this class, could the parameter name also be renamed to `globalBatchSize`?
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r756004235 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/linear/HasPredictionDetailCol.java ## @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.param.linear; + +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.StringParam; +import org.apache.flink.ml.param.WithParams; + +/** Interface for the shared prediction detail param. */ +public interface HasPredictionDetailCol extends WithParams { Review comment: Sounds good!
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r756012363 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasL2.java ## @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.param; + +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.WithParams; + +/** Interface for the shared L2 regularization param. */ +public interface HasL2 extends WithParams { Review comment: Sounds good. Thanks!
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r756011254 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasBatchSize.java ## @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.param; + +import org.apache.flink.ml.param.IntParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.WithParams; + +/** Interface for the shared batchSize param. */ +public interface HasBatchSize extends WithParams { + +Param BATCH_SIZE = +new IntParam( +"batchSize", "Batch size of training algorithms.", 100, ParamValidators.gt(0)); Review comment: Hmm.. I think even if we run a distributed machine learning algorithm on multiple workers, each worker will still process one batch of records at a time, right? If so, we probably still want the batch size to be a multiple of the number of CPU cores on a single machine. And yes, I think batchSize=32 is probably a good default value, based on the discussion in the StackExchange link. Thanks!
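If the parameter is renamed to `globalBatchSize` as suggested above, one natural reading is that each parallel worker processes roughly `globalBatchSize / parallelism` records per step. A sketch of that split (the remainder-distribution scheme is an assumption for illustration, not the PR's actual code):

```java
/** Hypothetical helper illustrating one way to split a global batch size across workers. */
class BatchSizeSplit {
    /**
     * Splits a global batch size across parallel workers, spreading the remainder
     * over the first workers so the per-worker sizes sum to globalBatchSize.
     */
    static int localBatchSize(int globalBatchSize, int parallelism, int workerIndex) {
        int base = globalBatchSize / parallelism;
        int remainder = globalBatchSize % parallelism;
        return workerIndex < remainder ? base + 1 : base;
    }
}
```

With globalBatchSize=32 and parallelism=4, every worker gets 8 records; with globalBatchSize=10, the first two workers get 3 and the rest get 2.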
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r756004235 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/linear/HasPredictionDetailCol.java ## @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.param.linear; + +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.StringParam; +import org.apache.flink.ml.param.WithParams; + +/** Interface for the shared prediction detail param. */ +public interface HasPredictionDetailCol extends WithParams { Review comment: Sounds good. Let's use `HasRawPredictionCol`.
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r754831435 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasEpsilon.java ## @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.param; + +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.WithParams; + +/** Interface for the shared epsilon param. */ +public interface HasEpsilon extends WithParams { Review comment: Would it be better to change the name to `HasTol` here? Spark and Scikit-learn [1] use HasTol for this purpose. The Logistic Regression wiki [2] mentions tolerance instead of epsilon. I searched on Google for words that are commonly used for determining the "termination criteria". It looks like tolerance is much more popular than epsilon in the machine learning domain (e.g. [3]).
[1] https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.SGDClassifier.html [2] https://en.wikipedia.org/wiki/Logistic_regression [3] https://support.minitab.com/en-us/minitab/18/help-and-how-to/modeling-statistics/regression/how-to/nonlinear-regression/interpret-the-results/all-statistics-and-graphs/methods-and-starting-values/
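For readers following the `HasTol` discussion: a tolerance parameter conventionally bounds the change of the objective between iterations, complementing `maxIter` as a termination criterion. A minimal plain-Java sketch of that convention (illustrative only, not the Flink ML optimizer):

```java
/** Hypothetical helper showing tolerance-based termination alongside a maxIter cap. */
class TolLoop {
    /**
     * Iterates a simple loss decay (a stand-in for an optimization step) until the
     * change in loss drops below tol or maxIter is reached; returns iterations run.
     */
    static int iterateUntilConverged(double initialLoss, double decay, double tol, int maxIter) {
        double loss = initialLoss;
        int iter = 0;
        while (iter < maxIter) {
            double next = loss * decay; // stand-in for one optimization step
            iter++;
            boolean converged = Math.abs(loss - next) < tol; // tolerance criterion
            loss = next;
            if (converged) {
                break;
            }
        }
        return iter;
    }
}
```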
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r754839066 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasL2.java ## @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.param; + +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.WithParams; + +/** Interface for the shared L2 regularization param. */ +public interface HasL2 extends WithParams { Review comment: Ideally we can make the parameter future proof. According to the scikit-learn doc [1], it looks like the regularization can be one of `l1`, `l2` and `elasticnet`. And scikit-learn supports all choices. Though Spark provides only the `HasElasticNetParam` without explicit `l1` or `l2` choices, the parameter doc suggests that `l1` or `l2` regularization is effectively used if user sets the parameter value to be `1` or `2`. So both scikit-learn and Spark support all three modes. I guess we also want to be able to support these three modes in Flink ML, even if we support only one for now. 
If we add `HasL2` here, how do we expect users to specify `L1` and `elasticnet` mode in the future? Should we use a double-valued `HasElasticNetParam` like Spark, or use a string-valued `HasPenalty` similar to Scikit-learn? [1] https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html
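For reference, Spark's documented elastic-net convention combines the two penalties as reg * (alpha * ||w||_1 + (1 - alpha)/2 * ||w||_2^2), so alpha=1 reduces to pure L1 and alpha=0 to pure L2. A plain-Java transcription of that formula (my reading of Spark's documentation, not Flink ML code):

```java
/** Hypothetical helper computing the elastic-net penalty per Spark's convention. */
class ElasticNet {
    /**
     * reg * (alpha * ||w||_1 + (1 - alpha)/2 * ||w||_2^2).
     * alpha = 1 gives pure L1, alpha = 0 gives pure L2.
     */
    static double penalty(double[] w, double reg, double alpha) {
        double l1 = 0.0;
        double l2sq = 0.0;
        for (double wi : w) {
            l1 += Math.abs(wi);   // L1 norm term
            l2sq += wi * wi;      // squared L2 norm term
        }
        return reg * (alpha * l1 + (1.0 - alpha) / 2.0 * l2sq);
    }
}
```

Under this convention a single double-valued parameter covers all three modes, which is one argument for Spark's `HasElasticNetParam` over a string-valued penalty choice.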
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r754833974 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/linear/HasPredictionDetailCol.java ## @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.param.linear; + +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.StringParam; +import org.apache.flink.ml.param.WithParams; + +/** Interface for the shared prediction detail param. */ +public interface HasPredictionDetailCol extends WithParams { Review comment: I see. Since the semantic of this column is the probability of the prediction result, would it be more intuitive to use `HasProbabilityCol` here? The word `detail` is much broader than `probability` and does not provide much information about what is in this column.
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r754831435 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasEpsilon.java ## @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.param; + +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.WithParams; + +/** Interface for the shared epsilon param. */ +public interface HasEpsilon extends WithParams { Review comment: I have similar thoughts as @yunfengzhou-hub. Here are my findings that may be useful to consider here. Spark and Scikit-learn [1] use HasTol for this purpose. The Logistic Regression wiki [2] mentions tolerance instead of epsilon. I searched on Google for words that are commonly used for determining the "termination criteria". It looks like tolerance is much more popular than epsilon in the machine learning domain (e.g. [3]).
[1] https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.SGDClassifier.html [2] https://en.wikipedia.org/wiki/Logistic_regression [3] https://support.minitab.com/en-us/minitab/18/help-and-how-to/modeling-statistics/regression/how-to/nonlinear-regression/interpret-the-results/all-statistics-and-graphs/methods-and-starting-values/
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r754825324 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasEpsilon.java ## @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.param; + +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.WithParams; + +/** Interface for the shared epsilon param. */ +public interface HasEpsilon extends WithParams { + +Param EPSILON = +new DoubleParam( +"epsilon", +"Convergence tolerance for iterative algorithms. The default value is 0.1", +0.1, Review comment: I have similar thoughts as @yunfengzhou-hub. Here are my findings that may be useful to consider here. Spark and Scikit-learn [1] use HasTol for this purpose. The Logistic Regression wiki [2] mentions tolerance instead of epsilon. I searched on Google for words that are commonly used for determining the "termination criteria". It looks like tolerance is much more popular than epsilon in the machine learning domain (e.g. [3]).
[1] https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.SGDClassifier.html [2] https://en.wikipedia.org/wiki/Logistic_regression [3] https://support.minitab.com/en-us/minitab/18/help-and-how-to/modeling-statistics/regression/how-to/nonlinear-regression/interpret-the-results/all-statistics-and-graphs/methods-and-starting-values/
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r754820934 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasEpsilon.java ## @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.param; + +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.WithParams; + +/** Interface for the shared epsilon param. */ +public interface HasEpsilon extends WithParams { + +Param EPSILON = Review comment: Spark and Scikit-learn [1] use HasTol for this purpose. The Logistic Regression wiki [2] mentions `tolerance` instead of `epsilon`. I searched on Google for words that are commonly used for determining the "termination criteria". It looks like `tolerance` is much more popular than epsilon in the machine learning domain (e.g. [3]). How about we use the same `HasTol` as Spark? BTW, @yunfengzhou-hub asked a similar question in a previous comment. That comment was closed without reply. Can we wait for confirmation from reviewers before resolving such comments?
[1] https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.SGDClassifier.html [2] https://en.wikipedia.org/wiki/Logistic_regression [3] https://support.minitab.com/en-us/minitab/18/help-and-how-to/modeling-statistics/regression/how-to/nonlinear-regression/interpret-the-results/all-statistics-and-graphs/methods-and-starting-values/ ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFeaturesCol.java ## @@ -27,7 +27,10 @@ public interface HasFeaturesCol extends WithParams { Param FEATURES_COL = new StringParam( -"featuresCol", "Features column name.", "features", ParamValidators.notNull()); +"featuresCol", +"Name of the features column name.", Review comment: According to Google results (google `name of column or column name`), it seems that the original `...column name` is more widely used than `name of column...`. So it seems simpler to use the original `Features column name`? And the 2nd `name` seems to be redundant here. ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasBatchSize.java ## @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.param; + +import org.apache.flink.ml.param.IntParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.WithParams; + +/** Interface for the shared batchSize param. */ +public interface HasBatchSize extends WithParams { + +Param BATCH_SIZE = +new IntParam( +"batchSize", "Batch size of training algorithms.", 100, ParamValidators.gt(0)); Review comment: How about setting the default value here to be 32? As explained in [1], batch size is typically a power of 2. And according to [2], batchSize=32 could be a good starting point. [1] https://datascience.stackexchange.com/questions/20179/what-is-the-advantage-of-keeping-batch-size-a-power-of-2
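As a side note on what `batchSize` controls, a minimal sketch (hypothetical helper, not code from this PR) of splitting n samples into mini-batches, where the suggested default of 32 yields ceil(n / 32) batches per pass:

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper (not PR code) illustrating what batchSize controls:
// n samples are processed in ceil(n / batchSize) mini-batches per pass,
// with the last batch possibly smaller than batchSize.
public class BatchSizeExample {

    static <T> List<List<T>> toBatches(List<T> data, int batchSize) {
        List<List<T>> batches = new ArrayList<>();
        for (int start = 0; start < data.size(); start += batchSize) {
            int end = Math.min(start + batchSize, data.size());
            batches.add(new ArrayList<>(data.subList(start, end)));
        }
        return batches;
    }
}
```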
[GitHub] [flink-ml] lindong28 commented on a change in pull request #28: [FLINK-24556] Add Estimator and Transformer for logistic regression
lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r753633697 ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/linalg/BLAS.java ## @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.linalg; + +import org.apache.flink.util.Preconditions; + +/** A utility class that provides BLAS routines over matrices and vectors. */ +public class BLAS { + +/** For level-1 routines, we use Java implementation. */ +private static final com.github.fommil.netlib.BLAS NATIVE_BLAS = Review comment: Spark has updated its BLAS library dependency to use e.g. `dev.ludovic.netlib.BLAS`. We can find the explanation in its commit message and JIRA description. The Naive Bayes PR follows Spark's approach. Maybe we can do the same here? ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/linear/HasMaxIter.java ## @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.param.linear; + +import org.apache.flink.ml.param.IntParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.WithParams; + +/** Interface for the shared maxIteration param. */ +public interface HasMaxIter extends WithParams { Review comment: Can we re-use the existing `HasMaxIter` in the package `org.apache.flink.ml.common.param`? Is there any reason we need to put these params in the package `*.linear`? It looks like most (if not all) of these parameters are general to all types of algorithms. ## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticGradient.java ## @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.linear; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.ml.common.linalg.BLAS; + +import java.io.Serializable; + +/** Logistic gradient. */ +public class LogisticGradient implements Serializable { +private static final long serialVersionUID = 1178693053439209380L; + +/** L1 regularization term. */ +private final double l1; + +/** L2 regularization term. */ +private final double l2; + +public LogisticGradient(double l1, double l2) { +this.l1 = l1; +this.l2 = l2; +} + +/** + * Computes loss and weightSum on a set of samples. + * + * @param labeledData a sample set of train data. + *
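For context on what the truncated `LogisticGradient` javadoc above is describing, here is a self-contained sketch (hypothetical; it assumes labels in {-1, +1} and a plain dot product in place of BLAS calls, and may differ from the PR's actual implementation) of the per-sample logistic loss and its gradient:

```java
// Sketch of per-sample logistic loss and gradient (hypothetical, not PR code;
// assumes labels in {-1, +1} and a plain dot product in place of BLAS).
public class LogisticLossExample {

    static double dot(double[] x, double[] y) {
        double s = 0.0;
        for (int i = 0; i < x.length; i++) {
            s += x[i] * y[i];
        }
        return s;
    }

    /** Loss of one sample: log(1 + exp(-label * (w . x))). */
    static double loss(double[] weights, double[] features, double label) {
        return Math.log(1.0 + Math.exp(-label * dot(weights, features)));
    }

    /** Accumulates the gradient of one sample into cumGradient. */
    static void gradient(double[] weights, double[] features, double label, double[] cumGradient) {
        // dL/dw = -label * x / (1 + exp(label * (w . x)))
        double multiplier = -label / (1.0 + Math.exp(label * dot(weights, features)));
        for (int i = 0; i < features.length; i++) {
            cumGradient[i] += multiplier * features[i];
        }
    }
}
```

The `l1` and `l2` regularization terms held by the class would be added on top of this data term; they are omitted here to keep the sketch minimal.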