[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-08 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r764665262



##
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##
@@ -0,0 +1,277 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+
+import static org.junit.Assert.assertEquals;
+
+/** Knn algorithm test. */
+public class KnnTest {
+private StreamExecutionEnvironment env;
+private StreamTableEnvironment tEnv;
+private Table trainData;
+private static final String LABEL_COL = "label";
+private static final String PRED_COL = "pred";
+private static final String VEC_COL = "vec";
+List trainArray =
+new ArrayList<>(
+Arrays.asList(
+Row.of(1, Vectors.dense(2.0, 3.0)),
+Row.of(1, Vectors.dense(2.1, 3.1)),
+Row.of(2, Vectors.dense(200.1, 300.1)),
+Row.of(2, Vectors.dense(200.2, 300.2)),
+Row.of(2, Vectors.dense(200.3, 300.3)),
+Row.of(2, Vectors.dense(200.4, 300.4)),
+Row.of(2, Vectors.dense(200.4, 300.4)),
+Row.of(2, Vectors.dense(200.6, 300.6)),
+Row.of(1, Vectors.dense(2.1, 3.1)),
+Row.of(1, Vectors.dense(2.1, 3.1)),
+Row.of(1, Vectors.dense(2.1, 3.1)),
+Row.of(1, Vectors.dense(2.1, 3.1)),
+Row.of(1, Vectors.dense(2.3, 3.2)),
+Row.of(1, Vectors.dense(2.3, 3.2)),
+Row.of(3, Vectors.dense(2.8, 3.2)),
+Row.of(4, Vectors.dense(300., 3.2)),
+Row.of(1, Vectors.dense(2.2, 3.2)),
+Row.of(5, Vectors.dense(2.4, 3.2)),
+Row.of(5, Vectors.dense(2.5, 3.2)),
+Row.of(5, Vectors.dense(2.5, 3.2)),
+Row.of(1, Vectors.dense(2.1, 3.1))));
+
+List testArray =
+new ArrayList<>(
+Arrays.asList(
+Row.of(5, Vectors.dense(4.0, 4.1)), Row.of(2, 
Vectors.dense(300, 42))));
+private Table testData;
+
+@Before
+public void before() {
+Configuration config = new Configuration();
+
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, 
true);
+env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+env.setParallelism(4);
+env.enableCheckpointing(100);
+env.setRestartStrategy(RestartStrategies.noRestart());
+tEnv = StreamTableEnvironment.create(env);
+
+Schema schema =
+Schema.newBuilder()
+.column("f0", DataTypes.INT())
+.column("f1", DataTypes.of(DenseVector.class))
+.build();
+
+DataStream dataStream = env.fromCollection(trainArray);
+trainData = tEnv.fromDataStream(dataStream, schema).as(LABEL_COL + "," 
+ VEC_COL);
+
+DataStream predDataStream = env.fromCollection(testArray);
+testData = tEnv.fromDataStream(predDataStream, schema).as(LABEL_COL + 
"," + VEC_COL);
+}
+
+// Executes the graph and returns a list which has true label and predict 
label.
+private static List> executeAndCollect(Table 
output) throws Exception {
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) 
output).getTableEnvironment();
+
+DataStream> stream =
+tEnv.toDataStream(output)
+   

[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-08 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r764663947



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##
@@ -0,0 +1,255 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.api.core.Estimator;
+import org.apache.flink.ml.common.MapPartitionFunctionWrapper;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * KNN is to classify unlabeled observations by assigning them to the class of 
the most similar
+ * labeled examples.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+private static final long serialVersionUID = 5292477422193301398L;
+private static final int ROW_SIZE = 2;
+private static final int FASTDISTANCE_TYPE_INDEX = 0;
+private static final int DATA_INDEX = 1;
+
+protected Map<Param<?>, Object> params = new HashMap<>();
+
+/** constructor. */
+public Knn() {
+ParamUtils.initializeMapWithDefaultValues(params, this);
+}
+
+/**
+ * constructor.
+ *
+ * @param params parameters for algorithm.
+ */
+public Knn(Map<Param<?>, Object> params) {
+this.params = params;

Review comment:
   done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-08 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r764661446



##
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##
@@ -0,0 +1,273 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.api.Stage;
+import org.apache.flink.ml.builder.Pipeline;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasK;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+/** knn algorithm test. */

Review comment:
   done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-08 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r764661044



##
File path: flink-ml-core/pom.xml
##
@@ -59,11 +59,11 @@ under the License.
   ${flink.version}
   test
 
-
 
-  org.apache.flink
-  flink-shaded-jackson
-  provided
+  com.google.code.gson

Review comment:
   done

##
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##
@@ -0,0 +1,273 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.api.Stage;
+import org.apache.flink.ml.builder.Pipeline;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasK;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+/** knn algorithm test. */
+public class KnnTest {
+private StreamExecutionEnvironment env;
+private StreamTableEnvironment tEnv;
+private Table trainData;
+
+List trainArray =
+new ArrayList<>(
+Arrays.asList(
+Row.of("f", "2.0 3.0", 1, 0, 1.47),
+Row.of("f", "2.1 3.1", 1, 0, 1.5),
+Row.of("m", "200.1 300.1", 1, 0, 1.5),
+Row.of("m", "200.2 300.2", 1, 0, 2.59),
+Row.of("m", "200.3 300.3", 1, 0, 2.55),
+Row.of("m", "200.4 300.4", 1, 0, 2.53),
+Row.of("m", "200.4 300.4", 1, 0, 2.52),
+Row.of("m", "200.6 300.6", 1, 0, 2.5),
+Row.of("f", "2.1 3.1", 1, 0, 1.5),
+Row.of("f", "2.1 3.1", 1, 0, 1.56),
+Row.of("f", "2.1 3.1", 1, 0, 1.51),
+Row.of("f", "2.1 3.1", 1, 0, 1.52),
+Row.of("f", "2.3 3.2", 1, 0, 1.53),
+Row.of("f", "2.3 3.2", 1, 0, 1.54),
+Row.of("c", "2.8 3.2", 3, 0, 1.6),
+Row.of("d", "300. 3.2", 5, 0, 1.5),
+Row.of("f", "2.2 3.2", 1, 0, 1.5),
+Row.of("e", "2.4 3.2", 2, 0, 1.3),
+Row.of("e", "2.5 3.2", 2, 0, 1.4),
+Row.of("e", "2.5 3.2", 2, 0, 1.5),
+Row.of("f", "2.1 3.1", 1, 0, 1.6)));
+
+List testArray =
+new ArrayList<>(
+Arrays.asList(
+Row.of("e", "4.0 4.1", 2, 0, 1.5), Row.of("m", 
"300 42", 1, 0, 2.59)));
+
+private Table testData;
+
+@Before
+public void before() {
+Configuration config = new Configuration();
+
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, 
true);
+env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+env.setParallelism(4);
+env.enableCheckpointing(100);
+env.setRestartStrategy(RestartStrategies.noRestart());
+tEnv = StreamTableEnvironment.create(env);
+
+DataStream dataStream =
+env.fromCollection(
+trainArray,
+new RowTypeInfo(
+new TypeInformation[] {
+Types.STRING, Types.STRING, Types.INT, 
Types.INT, Types.DOUBLE
+},
+new String[] {"label", "vec", "f0", "f1", 
"f2"}));
+trainData = tEnv.fromDataStream(dataStream);
+
+DataStream dataStreamStr =
+env.fromCollection(
+testArray,
+new RowTypeInfo(
+new TypeInformation[] {
+  

[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-08 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r764660930



##
File path: flink-ml-lib/pom.xml
##
@@ -65,6 +65,12 @@ under the License.
   test
 
 
+
+  com.github.fommil.netlib

Review comment:
   done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-08 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r764660309



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/FastDistanceVectorData.java
##
@@ -0,0 +1,28 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.Serializable;
+
+/** Save the data for calculating distance fast. The FastDistanceMatrixData */
+public class FastDistanceVectorData implements Serializable, Cloneable {

Review comment:
   done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-08 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r764659636



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/DenseMatrix.java
##
@@ -0,0 +1,47 @@
+package org.apache.flink.ml.linalg;
+
+import java.io.Serializable;
+
+/**
+ * DenseMatrix stores dense matrix data and provides some methods to operate 
on the matrix it
+ * represents.
+ */
+public class DenseMatrix implements Serializable {
+
+/** Row dimension. */
+public int numRows;
+
+/** Column dimension. */
+public int numCols;
+
+/**
+ * Array for internal storage of elements.
+ *
+ * The matrix data is stored in column major format internally.
+ */
+public double[] values;
+
+/**
+ * Constructs an m-by-n matrix of zeros.
+ *
+ * @param numRows Number of rows.
+ * @param numCols Number of columns.
+ */
+public DenseMatrix(int numRows, int numCols) {
+this(numRows, numCols, new double[numRows * numCols]);
+}
+
+/**
+ * Constructs a matrix from a 1-D array. The data in the array should 
organize in column major.
+ *
+ * @param numRows Number of rows.
+ * @param numCols Number of cols.
+ * @param values One-dimensional array of doubles.
+ */
+public DenseMatrix(int numRows, int numCols, double[] values) {
+assert (values.length == numRows * numCols);

Review comment:
   done

##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/DenseMatrix.java
##
@@ -0,0 +1,47 @@
+package org.apache.flink.ml.linalg;
+
+import java.io.Serializable;
+
+/**
+ * DenseMatrix stores dense matrix data and provides some methods to operate 
on the matrix it
+ * represents.
+ */
+public class DenseMatrix implements Serializable {
+
+/** Row dimension. */
+public int numRows;
+
+/** Column dimension. */
+public int numCols;
+
+/**
+ * Array for internal storage of elements.
+ *
+ * The matrix data is stored in column major format internally.
+ */
+public double[] values;
+
+/**
+ * Constructs an m-by-n matrix of zeros.
+ *
+ * @param numRows Number of rows.
+ * @param numCols Number of columns.
+ */
+public DenseMatrix(int numRows, int numCols) {
+this(numRows, numCols, new double[numRows * numCols]);
+}
+
+/**
+ * Constructs a matrix from a 1-D array. The data in the array should 
organize in column major.
+ *
+ * @param numRows Number of rows.
+ * @param numCols Number of cols.
+ * @param values One-dimensional array of doubles.
+ */
+public DenseMatrix(int numRows, int numCols, double[] values) {
+assert (values.length == numRows * numCols);

Review comment:
   done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-08 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r764659454



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java
##
@@ -38,4 +43,28 @@ public static RowTypeInfo getRowTypeInfo(ResolvedSchema 
schema) {
 }
 return new RowTypeInfo(types, names);
 }
+
+/** Constructs a RowTypeInfo from the given schema. */
+public static RowTypeInfo getRowTypeInfo(Schema schema) {

Review comment:
   done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-08 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r764659333



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##
@@ -0,0 +1,412 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.utils.LogicalTypeParser;
+import org.apache.flink.table.types.utils.LogicalTypeDataTypeConverter;
+import org.apache.flink.types.Row;
+
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
+
+import dev.ludovic.netlib.blas.F2jBLAS;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.PriorityQueue;
+import java.util.function.Function;
+
+/** Knn classification model fitted by estimator. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+protected Map<Param<?>, Object> params = new HashMap<>();
+private Table[] modelData;
+
+/** Constructor. */
+public KnnModel() {
+ParamUtils.initializeMapWithDefaultValues(params, this);
+}
+
+/**
+ * Sets model data for knn prediction.
+ *
+ * @param modelData knn model.
+ * @return knn model.
+ */
+@Override
+public KnnModel setModelData(Table... modelData) {
+this.modelData = modelData;
+return this;
+}
+
+/**
+ * Gets model data.
+ *
+ * @return table array.
+ */
+@Override
+public Table[] getModelData() {
+return modelData;
+}
+
+/**
+ * Predicts label with knn model.
+ *
+ * @param inputs a list of tables.
+ * @return result.
+ */
+@Override
+public Table[] transform(Table... inputs) {
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+DataStream input = tEnv.toDataStream(inputs[0]);
+DataStream model = tEnv.toDataStream(modelData[0]);
+final String broadcastKey = "broadcastModelKey";
+Map<String, DataStream<?>> broadcastMap = new HashMap<>(1);
+broadcastMap.put(broadcastKey, model);
+ResolvedSchema modelSchema = modelData[0].getResolvedSchema();
+
+DataType idType =
+
modelSchema.getColumnDataTypes().get(modelSchema.getColumnNames().size() - 1);
+String[] resultCols = new String[] {(String) 
params.get(KnnModelParams.PREDICTION_COL)};
+DataType[] resultTypes = new DataType[] {idType};
+
+ResolvedSchema outputSchema =
+TableUtils.getOutputSchema(inputs[0].getResolvedSchema(), 
resultCols, resultTypes);
+
+Function<List<DataStream<?>>, DataStream<Row>> function =
+dataStreams -> {
+DataStream stream = dataStreams.get(0);
+return stream.transform(
+"mapFunc",
+TableUtils.getRowTypeInfo(outputSchema),
+new PredictOperator(
+inputs[0].getResolvedSchema(),
+broadcastKey,
+getK(),
+getFeaturesCol()));
+ 

[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-08 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r764659169



##
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##
@@ -0,0 +1,278 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+
+import static org.junit.Assert.assertEquals;
+
+/** Knn algorithm test. */
+public class KnnTest {
+private StreamExecutionEnvironment env;
+private StreamTableEnvironment tEnv;
+private Table trainData;
+private static final String LABEL_COL = "label";
+private static final String PRED_COL = "pred";
+private static final String VEC_COL = "vec";
+List trainArray =
+new ArrayList<>(
+Arrays.asList(
+Row.of("f", Vectors.dense(2.0, 3.0)),
+Row.of("f", Vectors.dense(2.1, 3.1)),
+Row.of("m", Vectors.dense(200.1, 300.1)),
+Row.of("m", Vectors.dense(200.2, 300.2)),
+Row.of("m", Vectors.dense(200.3, 300.3)),
+Row.of("m", Vectors.dense(200.4, 300.4)),
+Row.of("m", Vectors.dense(200.4, 300.4)),
+Row.of("m", Vectors.dense(200.6, 300.6)),
+Row.of("f", Vectors.dense(2.1, 3.1)),
+Row.of("f", Vectors.dense(2.1, 3.1)),
+Row.of("f", Vectors.dense(2.1, 3.1)),
+Row.of("f", Vectors.dense(2.1, 3.1)),
+Row.of("f", Vectors.dense(2.3, 3.2)),
+Row.of("f", Vectors.dense(2.3, 3.2)),
+Row.of("c", Vectors.dense(2.8, 3.2)),
+Row.of("d", Vectors.dense(300., 3.2)),
+Row.of("f", Vectors.dense(2.2, 3.2)),
+Row.of("e", Vectors.dense(2.4, 3.2)),
+Row.of("e", Vectors.dense(2.5, 3.2)),
+Row.of("e", Vectors.dense(2.5, 3.2)),
+Row.of("f", Vectors.dense(2.1, 3.1))));
+
+List testArray =
+new ArrayList<>(
+Arrays.asList(
+Row.of("e", Vectors.dense(4.0, 4.1)),
+Row.of("m", Vectors.dense(300, 42))));
+private Table testData;
+
+@Before
+public void before() {
+Configuration config = new Configuration();
+
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, 
true);
+env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+env.setParallelism(4);
+env.enableCheckpointing(100);
+env.setRestartStrategy(RestartStrategies.noRestart());
+tEnv = StreamTableEnvironment.create(env);
+
+Schema schema =
+Schema.newBuilder()
+.column("f0", DataTypes.STRING())
+.column("f1", DataTypes.of(DenseVector.class))
+.build();
+
+DataStream dataStream = env.fromCollection(trainArray);
+trainData = tEnv.fromDataStream(dataStream, schema).as(LABEL_COL + "," 
+ VEC_COL);
+
+DataStream predDataStream = env.fromCollection(testArray);
+testData = tEnv.fromDataStream(predDataStream, schema).as(LABEL_COL + 
"," + VEC_COL);
+}
+
+// Executes the graph and returns a list which has true label and predict 
label.
+private static List> executeAndCollect(Table 
output) throws Exception {
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) 
output).getTableEnvironment();
+
+

[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-08 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r764658436



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##
@@ -0,0 +1,227 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.MapPartitionFunctionWrapper;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.flink.shaded.curator4.com.google.common.collect.Iterables;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * KNN is to classify unlabeled observations by assigning them to the class of 
the most similar
+ * labeled examples.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+protected Map<Param<?>, Object> params = new HashMap<>();
+
+/** Constructor. */
+public Knn() {
+ParamUtils.initializeMapWithDefaultValues(params, this);
+}
+
+/**
+ * Fits data and produces knn model.
+ *
+ * @param inputs A list of tables
+ * @return Knn model.
+ */
+@Override
+public KnnModel fit(Table... inputs) {
+Preconditions.checkArgument(inputs.length == 1);
+ResolvedSchema schema = inputs[0].getResolvedSchema();
+String[] colNames = schema.getColumnNames().toArray(new String[0]);
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+DataStream input = tEnv.toDataStream(inputs[0]);
+String labelCol = getLabelCol();
+String vecCol = getFeaturesCol();
+
+DataStream trainData =
+input.map(
+(MapFunction)
+value -> {
+Object label = 
String.valueOf(value.getField(labelCol));

Review comment:
   done

##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##
@@ -0,0 +1,412 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.utils.LogicalTypeParser;
+import org.apache.flink.table.types.utils.LogicalTypeDataTypeConverter;

[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-08 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r764650802



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseMatrixSerializer.java
##
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.ml.linalg.typeinfo;
+
+import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.ml.linalg.DenseMatrix;
+
+import java.io.IOException;
+import java.util.Arrays;
+
+/** Specialized serializer for {@code DenseMatrix}. */
+public final class DenseMatrixSerializer extends 
TypeSerializerSingleton {
+
+private static final long serialVersionUID = 1L;
+
+private static final double[] EMPTY = new double[0];
+
+private static final DenseMatrixSerializer INSTANCE = new 
DenseMatrixSerializer();
+
+@Override
+public boolean isImmutableType() {
+return false;
+}
+
+@Override
+public DenseMatrix createInstance() {
+return new DenseMatrix(0, 0, EMPTY);
+}
+
+@Override
+public DenseMatrix copy(DenseMatrix from) {
+return new DenseMatrix(
+from.numRows, from.numCols, Arrays.copyOf(from.values, 
from.values.length));
+}
+
+@Override
+public DenseMatrix copy(DenseMatrix from, DenseMatrix reuse) {
+if (from.values.length == reuse.values.length) {
+System.arraycopy(from.values, 0, reuse.values, 0, 
from.values.length);
+return reuse;
+}
+return copy(from);
+}
+
+@Override
+public int getLength() {
+return -1;
+}
+
+@Override
+public void serialize(DenseMatrix matrix, DataOutputView target) throws 
IOException {
+if (matrix == null) {
+throw new IllegalArgumentException("The matrix must not be null.");
+}
+
+final int len = matrix.values.length;
+target.writeInt(matrix.numRows);
+target.writeInt(matrix.numCols);
+for (int i = 0; i < len; i++) {
+target.writeDouble(matrix.values[i]);
+}
+}
+
+@Override
+public DenseMatrix deserialize(DataInputView source) throws IOException {
+int m = source.readInt();
+int n = source.readInt();
+double[] values = new double[m * n];
+for (int i = 0; i < m * n; i++) {
+values[i] = source.readDouble();
+}
+return new DenseMatrix(m, n, values);
+}
+
+private static void readDoubleArray(double[] dst, DataInputView source, 
int len)
+throws IOException {
+for (int i = 0; i < len; i++) {
+dst[i] = source.readDouble();
+}
+}
+
+@Override
+public DenseMatrix deserialize(DenseMatrix reuse, DataInputView source) 
throws IOException {
+int m = source.readInt();
+int n = source.readInt();
+
+double[] values = new double[m * n];
+readDoubleArray(values, source, m * n);
+return new DenseMatrix(m, n, values);
+}
+
+@Override
+public void copy(DataInputView source, DataOutputView target) throws 
IOException {
+int m = source.readInt();
+target.writeInt(m);
+int n = source.readInt();
+target.writeInt(n);
+
+target.write(source, m * n * 8);
+}
+
+// 
+
+@Override
+public TypeSerializerSnapshot snapshotConfiguration() {
+return new DenseMatrixSerializerSnapshot();
+}
+
+/** Serializer configuration snapshot for compatibility and format 
evolution. */
+@SuppressWarnings("WeakerAccess")

Review comment:
   @lindong28 I wrote this the same way as it is done in VectorSerializer. 
Could you explain what its functionality is?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-08 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r764649082



##
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasK.java
##
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared K param. */
+public interface HasK extends WithParams {

Review comment:
   OK, I have already defined this param inside KnnModelParams.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-07 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r764497065



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##
@@ -0,0 +1,217 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.MapPartitionFunctionWrapper;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the KNN algorithm. KNN is to classify 
unlabeled observations by
+ * assigning them to the class of the most similar labeled examples.
+ *
+ * See: https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm.
+ */
+public class Knn implements Estimator, KnnParams {
+
+protected Map, Object> params = new HashMap<>();
+
+/** Constructor. */
+public Knn() {
+ParamUtils.initializeMapWithDefaultValues(params, this);
+}
+
+/**
+ * Fits data and produces knn model.
+ *
+ * @param inputs A list of tables, including train data table.
+ * @return Knn model.
+ */
+@Override
+public KnnModel fit(Table... inputs) {
+Preconditions.checkArgument(inputs.length == 1);
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+DataStream input = tEnv.toDataStream(inputs[0]);
+String labelCol = getLabelCol();
+
+DataStream> trainData =
+input.map(
+new MapFunction>() {
+@Override
+public Tuple2 map(Row value) 
{
+Integer label = (Integer) 
value.getField(labelCol);
+DenseVector vec = (DenseVector) 
value.getField(getFeaturesCol());
+return Tuple2.of(vec, label);
+}
+});
+
+DataStream model = buildModel(trainData);
+KnnModel knnModel =
+new KnnModel()
+.setFeaturesCol(getFeaturesCol())
+.setK(getK())
+.setPredictionCol(getPredictionCol());
+knnModel.setModelData(tEnv.fromDataStream(model, 
KnnModelData.getModelSchema()));
+return knnModel;
+}
+
+/**
+ * Builds knn model.
+ *
+ * @param dataStream Input data.
+ * @return Knn model.
+ */
+private static DataStream buildModel(DataStream> dataStream) {
+Schema schema = KnnModelData.getModelSchema();
+return dataStream.transform(
+"build knn model",
+TableUtils.getRowTypeInfo(schema),
+new MapPartitionFunctionWrapper<>(
+new RichMapPartitionFunction, Row>() {
+@Override
+public void mapPartition(
+Iterable> 
values,
+Collector out) {
+List> list =
+prepareMatrixData(values);
+for (Tuple3 t3 : list) {
+Row ret = new Row(3);
+ret.setField(0, t3.f0);
+ret.setField(1, t3.f1);
+ret.setField(2, t3.f2);
+out.collect(ret);
+}
+}
+}));
+}
+
+/**
+ * Prepares matrix data, the output is a list of Tuple3, which includes 
vectors, vecNorms and
+ * label.
+ *
+ * @param trainData Input train data.
+ * @return Model data in format of list tuple3.
+ */
+private static List> 
prepareMatrixData(
+

[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-07 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r763930742



##
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasK.java
##
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared K param. */
+public interface HasK extends WithParams {

Review comment:
   I think these two params can't be merged into one param, because:
   1. The default values are different.
   2. The descriptions are different.
   3. The ParamValidators are different.
   
   If we replaced one with the other, we would need to update a lot of 
differing information. Using two different params may be better.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-07 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r763729532



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##
@@ -0,0 +1,412 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.utils.LogicalTypeParser;
+import org.apache.flink.table.types.utils.LogicalTypeDataTypeConverter;
+import org.apache.flink.types.Row;
+
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
+
+import dev.ludovic.netlib.blas.F2jBLAS;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.PriorityQueue;
+import java.util.function.Function;
+
+/** Knn classification model fitted by estimator. */
+public class KnnModel implements Model, KnnModelParams {
+protected Map, Object> params = new HashMap<>();
+private Table[] modelData;
+
+/** Constructor. */
+public KnnModel() {
+ParamUtils.initializeMapWithDefaultValues(params, this);
+}
+
+/**
+ * Sets model data for knn prediction.
+ *
+ * @param modelData knn model.
+ * @return knn model.
+ */
+@Override
+public KnnModel setModelData(Table... modelData) {
+this.modelData = modelData;
+return this;
+}
+
+/**
+ * Gets model data.
+ *
+ * @return table array.
+ */
+@Override
+public Table[] getModelData() {
+return modelData;
+}
+
+/**
+ * Predicts label with knn model.
+ *
+ * @param inputs a list of tables.
+ * @return result.
+ */
+@Override
+public Table[] transform(Table... inputs) {
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+DataStream input = tEnv.toDataStream(inputs[0]);
+DataStream model = tEnv.toDataStream(modelData[0]);
+final String broadcastKey = "broadcastModelKey";
+Map> broadcastMap = new HashMap<>(1);
+broadcastMap.put(broadcastKey, model);
+ResolvedSchema modelSchema = modelData[0].getResolvedSchema();
+
+DataType idType =
+
modelSchema.getColumnDataTypes().get(modelSchema.getColumnNames().size() - 1);
+String[] resultCols = new String[] {(String) 
params.get(KnnModelParams.PREDICTION_COL)};
+DataType[] resultTypes = new DataType[] {idType};
+
+ResolvedSchema outputSchema =
+TableUtils.getOutputSchema(inputs[0].getResolvedSchema(), 
resultCols, resultTypes);
+
+Function>, DataStream> function =
+dataStreams -> {
+DataStream stream = dataStreams.get(0);
+return stream.transform(
+"mapFunc",
+TableUtils.getRowTypeInfo(outputSchema),
+new PredictOperator(
+inputs[0].getResolvedSchema(),
+broadcastKey,
+getK(),
+getFeaturesCol()));
+ 

[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-07 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r763801857



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/DenseMatrix.java
##
@@ -0,0 +1,47 @@
+package org.apache.flink.ml.linalg;
+
+import java.io.Serializable;
+
+/**
+ * DenseMatrix stores dense matrix data and provides some methods to operate 
on the matrix it
+ * represents.
+ */
+public class DenseMatrix implements Serializable {
+
+/** Row dimension. */
+public int numRows;

Review comment:
   done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-07 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r763808468



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/DenseMatrix.java
##
@@ -0,0 +1,47 @@
+package org.apache.flink.ml.linalg;
+
+import java.io.Serializable;
+
+/**
+ * DenseMatrix stores dense matrix data and provides some methods to operate 
on the matrix it
+ * represents.
+ */
+public class DenseMatrix implements Serializable {

Review comment:
   done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-07 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r763808290



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/DenseMatrix.java
##
@@ -0,0 +1,47 @@
+package org.apache.flink.ml.linalg;
+
+import java.io.Serializable;
+
+/**
+ * DenseMatrix stores dense matrix data and provides some methods to operate 
on the matrix it

Review comment:
   done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-07 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r763802873



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/DenseMatrix.java
##
@@ -0,0 +1,47 @@
+package org.apache.flink.ml.linalg;
+
+import java.io.Serializable;
+
+/**
+ * DenseMatrix stores dense matrix data and provides some methods to operate 
on the matrix it
+ * represents.
+ */

Review comment:
   done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-07 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r763795024



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##
@@ -0,0 +1,412 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.utils.LogicalTypeParser;
+import org.apache.flink.table.types.utils.LogicalTypeDataTypeConverter;
+import org.apache.flink.types.Row;
+
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
+
+import dev.ludovic.netlib.blas.F2jBLAS;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.PriorityQueue;
+import java.util.function.Function;
+
+/** Knn classification model fitted by estimator. */
+public class KnnModel implements Model, KnnModelParams {
+protected Map, Object> params = new HashMap<>();
+private Table[] modelData;
+
+/** Constructor. */
+public KnnModel() {
+ParamUtils.initializeMapWithDefaultValues(params, this);
+}
+
+/**
+ * Sets model data for knn prediction.
+ *
+ * @param modelData knn model.
+ * @return knn model.
+ */
+@Override
+public KnnModel setModelData(Table... modelData) {
+this.modelData = modelData;
+return this;
+}
+
+/**
+ * Gets model data.
+ *
+ * @return table array.
+ */
+@Override
+public Table[] getModelData() {
+return modelData;
+}
+
+/**
+ * Predicts label with knn model.
+ *
+ * @param inputs a list of tables.
+ * @return result.
+ */
+@Override
+public Table[] transform(Table... inputs) {
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+DataStream input = tEnv.toDataStream(inputs[0]);
+DataStream model = tEnv.toDataStream(modelData[0]);
+final String broadcastKey = "broadcastModelKey";
+Map> broadcastMap = new HashMap<>(1);
+broadcastMap.put(broadcastKey, model);
+ResolvedSchema modelSchema = modelData[0].getResolvedSchema();
+
+DataType idType =
+
modelSchema.getColumnDataTypes().get(modelSchema.getColumnNames().size() - 1);
+String[] resultCols = new String[] {(String) 
params.get(KnnModelParams.PREDICTION_COL)};

Review comment:
   done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-07 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r763792288



##
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##
@@ -0,0 +1,278 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+
+import static org.junit.Assert.assertEquals;
+
+/** Knn algorithm test. */
+public class KnnTest {
+private StreamExecutionEnvironment env;
+private StreamTableEnvironment tEnv;
+private Table trainData;
+private static final String LABEL_COL = "label";
+private static final String PRED_COL = "pred";
+private static final String VEC_COL = "vec";
+List trainArray =
+new ArrayList<>(
+Arrays.asList(
+Row.of("f", Vectors.dense(2.0, 3.0)),

Review comment:
   Here, I use int instead of string.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-07 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r763791507



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##
@@ -0,0 +1,227 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.MapPartitionFunctionWrapper;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.flink.shaded.curator4.com.google.common.collect.Iterables;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * KNN is to classify unlabeled observations by assigning them to the class of 
the most similar
+ * labeled examples.
+ */
+public class Knn implements Estimator, KnnParams {
+
+protected Map, Object> params = new HashMap<>();
+
+/** Constructor. */
+public Knn() {
+ParamUtils.initializeMapWithDefaultValues(params, this);
+}
+
+/**
+ * Fits data and produces knn model.
+ *
+ * @param inputs A list of tables
+ * @return Knn model.
+ */
+@Override
+public KnnModel fit(Table... inputs) {
+Preconditions.checkArgument(inputs.length == 1);
+ResolvedSchema schema = inputs[0].getResolvedSchema();
+String[] colNames = schema.getColumnNames().toArray(new String[0]);
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+DataStream input = tEnv.toDataStream(inputs[0]);
+String labelCol = getLabelCol();
+String vecCol = getFeaturesCol();
+
+DataStream trainData =
+input.map(
+(MapFunction)
+value -> {
+Object label = 
String.valueOf(value.getField(labelCol));
+DenseVector vec = (DenseVector) 
value.getField(vecCol);
+return Row.of(label, vec);
+});
+DataType idType = null;
+for (int i = 0; i < colNames.length; i++) {
+if (labelCol.equalsIgnoreCase(colNames[i])) {

Review comment:
   This code has already been removed.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-07 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r763790521



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##
@@ -0,0 +1,227 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.MapPartitionFunctionWrapper;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.flink.shaded.curator4.com.google.common.collect.Iterables;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * KNN is to classify unlabeled observations by assigning them to the class of 
the most similar
+ * labeled examples.
+ */
+public class Knn implements Estimator, KnnParams {
+
+protected Map, Object> params = new HashMap<>();
+
+/** Constructor. */
+public Knn() {
+ParamUtils.initializeMapWithDefaultValues(params, this);
+}
+
+/**
+ * Fits data and produces knn model.
+ *
+ * @param inputs A list of tables
+ * @return Knn model.
+ */
+@Override
+public KnnModel fit(Table... inputs) {
+Preconditions.checkArgument(inputs.length == 1);
+ResolvedSchema schema = inputs[0].getResolvedSchema();
+String[] colNames = schema.getColumnNames().toArray(new String[0]);
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+DataStream input = tEnv.toDataStream(inputs[0]);
+String labelCol = getLabelCol();
+String vecCol = getFeaturesCol();
+
+DataStream trainData =
+input.map(
+(MapFunction)
+value -> {
+Object label = 
String.valueOf(value.getField(labelCol));

Review comment:
   done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-07 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r763730634



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##
@@ -0,0 +1,412 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.utils.LogicalTypeParser;
+import org.apache.flink.table.types.utils.LogicalTypeDataTypeConverter;
+import org.apache.flink.types.Row;
+
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
+
+import dev.ludovic.netlib.blas.F2jBLAS;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.PriorityQueue;
+import java.util.function.Function;
+
+/** Knn classification model fitted by estimator. */
+public class KnnModel implements Model, KnnModelParams {
+protected Map, Object> params = new HashMap<>();
+private Table[] modelData;
+
+/** Constructor. */
+public KnnModel() {
+ParamUtils.initializeMapWithDefaultValues(params, this);
+}
+
+/**
+ * Sets model data for knn prediction.
+ *
+ * @param modelData knn model.
+ * @return knn model.
+ */
+@Override
+public KnnModel setModelData(Table... modelData) {
+this.modelData = modelData;
+return this;
+}
+
+/**
+ * Gets model data.
+ *
+ * @return table array.
+ */
+@Override
+public Table[] getModelData() {
+return modelData;
+}
+
+/**
+ * Predicts label with knn model.
+ *
+ * @param inputs a list of tables.
+ * @return result.
+ */
+@Override
+public Table[] transform(Table... inputs) {
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+DataStream input = tEnv.toDataStream(inputs[0]);
+DataStream model = tEnv.toDataStream(modelData[0]);
+final String broadcastKey = "broadcastModelKey";
+Map> broadcastMap = new HashMap<>(1);
+broadcastMap.put(broadcastKey, model);
+ResolvedSchema modelSchema = modelData[0].getResolvedSchema();
+
+DataType idType =
+
modelSchema.getColumnDataTypes().get(modelSchema.getColumnNames().size() - 1);
+String[] resultCols = new String[] {(String) 
params.get(KnnModelParams.PREDICTION_COL)};
+DataType[] resultTypes = new DataType[] {idType};
+
+ResolvedSchema outputSchema =
+TableUtils.getOutputSchema(inputs[0].getResolvedSchema(), 
resultCols, resultTypes);
+
+Function>, DataStream> function =
+dataStreams -> {
+DataStream stream = dataStreams.get(0);
+return stream.transform(
+"mapFunc",
+TableUtils.getRowTypeInfo(outputSchema),
+new PredictOperator(
+inputs[0].getResolvedSchema(),
+broadcastKey,
+getK(),
+getFeaturesCol()));
+ 

[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-07 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r763729532



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##
@@ -0,0 +1,412 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.utils.LogicalTypeParser;
+import org.apache.flink.table.types.utils.LogicalTypeDataTypeConverter;
+import org.apache.flink.types.Row;
+
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
+
+import dev.ludovic.netlib.blas.F2jBLAS;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.PriorityQueue;
+import java.util.function.Function;
+
+/** Knn classification model fitted by estimator. */
+public class KnnModel implements Model, KnnModelParams {
+protected Map, Object> params = new HashMap<>();
+private Table[] modelData;
+
+/** Constructor. */
+public KnnModel() {
+ParamUtils.initializeMapWithDefaultValues(params, this);
+}
+
+/**
+ * Sets model data for knn prediction.
+ *
+ * @param modelData knn model.
+ * @return knn model.
+ */
+@Override
+public KnnModel setModelData(Table... modelData) {
+this.modelData = modelData;
+return this;
+}
+
+/**
+ * Gets model data.
+ *
+ * @return table array.
+ */
+@Override
+public Table[] getModelData() {
+return modelData;
+}
+
+/**
+ * Predicts label with knn model.
+ *
+ * @param inputs a list of tables.
+ * @return result.
+ */
+@Override
+public Table[] transform(Table... inputs) {
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+DataStream input = tEnv.toDataStream(inputs[0]);
+DataStream model = tEnv.toDataStream(modelData[0]);
+final String broadcastKey = "broadcastModelKey";
+Map> broadcastMap = new HashMap<>(1);
+broadcastMap.put(broadcastKey, model);
+ResolvedSchema modelSchema = modelData[0].getResolvedSchema();
+
+DataType idType =
+
modelSchema.getColumnDataTypes().get(modelSchema.getColumnNames().size() - 1);
+String[] resultCols = new String[] {(String) 
params.get(KnnModelParams.PREDICTION_COL)};
+DataType[] resultTypes = new DataType[] {idType};
+
+ResolvedSchema outputSchema =
+TableUtils.getOutputSchema(inputs[0].getResolvedSchema(), 
resultCols, resultTypes);
+
+Function>, DataStream> function =
+dataStreams -> {
+DataStream stream = dataStreams.get(0);
+return stream.transform(
+"mapFunc",
+TableUtils.getRowTypeInfo(outputSchema),
+new PredictOperator(
+inputs[0].getResolvedSchema(),
+broadcastKey,
+getK(),
+getFeaturesCol()));
+ 

[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-07 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r763728994



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##
@@ -0,0 +1,227 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.MapPartitionFunctionWrapper;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.flink.shaded.curator4.com.google.common.collect.Iterables;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * KNN is to classify unlabeled observations by assigning them to the class of 
the most similar
+ * labeled examples.
+ */
+public class Knn implements Estimator, KnnParams {
+
+protected Map, Object> params = new HashMap<>();
+
+/** Constructor. */
+public Knn() {
+ParamUtils.initializeMapWithDefaultValues(params, this);
+}
+
+/**
+ * Fits data and produces knn model.
+ *
+ * @param inputs A list of tables
+ * @return Knn model.
+ */
+@Override
+public KnnModel fit(Table... inputs) {
+Preconditions.checkArgument(inputs.length == 1);
+ResolvedSchema schema = inputs[0].getResolvedSchema();
+String[] colNames = schema.getColumnNames().toArray(new String[0]);
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+DataStream input = tEnv.toDataStream(inputs[0]);
+String labelCol = getLabelCol();
+String vecCol = getFeaturesCol();
+
+DataStream trainData =
+input.map(
+(MapFunction)
+value -> {
+Object label = 
String.valueOf(value.getField(labelCol));

Review comment:
   OK, I will try it.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-07 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r763727970



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##
@@ -0,0 +1,227 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.MapPartitionFunctionWrapper;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.flink.shaded.curator4.com.google.common.collect.Iterables;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * KNN is to classify unlabeled observations by assigning them to the class of 
the most similar
+ * labeled examples.
+ */
+public class Knn implements Estimator, KnnParams {
+
+protected Map, Object> params = new HashMap<>();
+
+/** Constructor. */
+public Knn() {
+ParamUtils.initializeMapWithDefaultValues(params, this);
+}
+
+/**
+ * Fits data and produces knn model.
+ *
+ * @param inputs A list of tables
+ * @return Knn model.
+ */
+@Override
+public KnnModel fit(Table... inputs) {
+Preconditions.checkArgument(inputs.length == 1);
+ResolvedSchema schema = inputs[0].getResolvedSchema();
+String[] colNames = schema.getColumnNames().toArray(new String[0]);
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+DataStream input = tEnv.toDataStream(inputs[0]);
+String labelCol = getLabelCol();
+String vecCol = getFeaturesCol();
+
+DataStream trainData =
+input.map(
+(MapFunction)
+value -> {
+Object label = 
String.valueOf(value.getField(labelCol));
+DenseVector vec = (DenseVector) 
value.getField(vecCol);
+return Row.of(label, vec);
+});
+DataType idType = null;
+for (int i = 0; i < colNames.length; i++) {
+if (labelCol.equalsIgnoreCase(colNames[i])) {

Review comment:
   OK




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-07 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r763727160



##
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##
@@ -0,0 +1,278 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+
+import static org.junit.Assert.assertEquals;
+
+/** Knn algorithm test. */
+public class KnnTest {
+private StreamExecutionEnvironment env;
+private StreamTableEnvironment tEnv;
+private Table trainData;
+private static final String LABEL_COL = "label";
+private static final String PRED_COL = "pred";
+private static final String VEC_COL = "vec";
+List trainArray =
+new ArrayList<>(
+Arrays.asList(
+Row.of("f", Vectors.dense(2.0, 3.0)),

Review comment:
   OK, the knn algorithm supports any label type.
   I will refine it later.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-07 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r763726499



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##
@@ -0,0 +1,412 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.utils.LogicalTypeParser;
+import org.apache.flink.table.types.utils.LogicalTypeDataTypeConverter;
+import org.apache.flink.types.Row;
+
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
+
+import dev.ludovic.netlib.blas.F2jBLAS;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.PriorityQueue;
+import java.util.function.Function;
+
+/** Knn classification model fitted by estimator. */
+public class KnnModel implements Model, KnnModelParams {
+protected Map, Object> params = new HashMap<>();
+private Table[] modelData;
+
+/** Constructor. */
+public KnnModel() {
+ParamUtils.initializeMapWithDefaultValues(params, this);
+}
+
+/**
+ * Sets model data for knn prediction.
+ *
+ * @param modelData knn model.
+ * @return knn model.
+ */
+@Override
+public KnnModel setModelData(Table... modelData) {
+this.modelData = modelData;
+return this;
+}
+
+/**
+ * Gets model data.
+ *
+ * @return table array.
+ */
+@Override
+public Table[] getModelData() {
+return modelData;
+}
+
+/**
+ * Predicts label with knn model.
+ *
+ * @param inputs a list of tables.
+ * @return result.
+ */
+@Override
+public Table[] transform(Table... inputs) {
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+DataStream input = tEnv.toDataStream(inputs[0]);
+DataStream model = tEnv.toDataStream(modelData[0]);
+final String broadcastKey = "broadcastModelKey";
+Map> broadcastMap = new HashMap<>(1);
+broadcastMap.put(broadcastKey, model);
+ResolvedSchema modelSchema = modelData[0].getResolvedSchema();
+
+DataType idType =
+
modelSchema.getColumnDataTypes().get(modelSchema.getColumnNames().size() - 1);
+String[] resultCols = new String[] {(String) 
params.get(KnnModelParams.PREDICTION_COL)};

Review comment:
   OK 




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-07 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r763726141



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/DenseMatrix.java
##
@@ -0,0 +1,47 @@
+package org.apache.flink.ml.linalg;
+
+import java.io.Serializable;
+
+/**
+ * DenseMatrix stores dense matrix data and provides some methods to operate 
on the matrix it
+ * represents.
+ */

Review comment:
   OK 
   




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-07 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r763725903



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/DenseMatrix.java
##
@@ -0,0 +1,47 @@
+package org.apache.flink.ml.linalg;
+
+import java.io.Serializable;
+
+/**
+ * DenseMatrix stores dense matrix data and provides some methods to operate 
on the matrix it
+ * represents.
+ */
+public class DenseMatrix implements Serializable {
+
+/** Row dimension. */
+public int numRows;

Review comment:
   DenseMatrix data may be changed by algorithms. I think final is not needed here.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-07 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r763724812



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/DenseMatrix.java
##
@@ -0,0 +1,47 @@
+package org.apache.flink.ml.linalg;
+
+import java.io.Serializable;
+
+/**
+ * DenseMatrix stores dense matrix data and provides some methods to operate 
on the matrix it

Review comment:
   OK




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-07 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r763724549



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseMatrixTypeInfo.java
##
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.linalg.typeinfo;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.ml.linalg.DenseMatrix;
+
+/** A {@link TypeInformation} for the {@link DenseMatrix} type. */
+public class DenseMatrixTypeInfo extends TypeInformation {
+private static final long serialVersionUID = 1L;
+
+public static final DenseMatrixTypeInfo INSTANCE = new 
DenseMatrixTypeInfo();
+
+public DenseMatrixTypeInfo() {}
+
+@Override
+public int getArity() {

Review comment:
   OK




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-07 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r763724378



##
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/DenseMatrix.java
##
@@ -0,0 +1,47 @@
+package org.apache.flink.ml.linalg;
+
+import java.io.Serializable;
+
+/**
+ * DenseMatrix stores dense matrix data and provides some methods to operate 
on the matrix it
+ * represents.
+ */
+public class DenseMatrix implements Serializable {

Review comment:
   OK, I will add it later.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-06 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r763606874



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/FastDistance.java
##
@@ -0,0 +1,192 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.flink.shaded.curator4.com.google.common.collect.Iterables;
+
+import dev.ludovic.netlib.blas.F2jBLAS;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * FastDistance is an accelerated distance calculation method. It uses matrix-vector 
operations to
+ * improve the speed of distance calculation.
+ *
+ * The distance type in this class is euclidean distance:
+ *
+ * https://en.wikipedia.org/wiki/Euclidean_distance
+ */
+public class FastDistance implements Serializable {

Review comment:
   This class has been removed.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-06 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r763606748



##
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##
@@ -0,0 +1,285 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.api.Stage;
+import org.apache.flink.ml.builder.Pipeline;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasK;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+/** knn algorithm test. */
+public class KnnTest {
+private StreamExecutionEnvironment env;
+private StreamTableEnvironment tEnv;
+private Table trainData;
+
+List trainArray =
+new ArrayList<>(
+Arrays.asList(
+Row.of("f", Vectors.dense(2.0, 3.0)),
+Row.of("f", Vectors.dense(2.1, 3.1)),
+Row.of("m", Vectors.dense(200.1, 300.1)),
+Row.of("m", Vectors.dense(200.2, 300.2)),
+Row.of("m", Vectors.dense(200.3, 300.3)),
+Row.of("m", Vectors.dense(200.4, 300.4)),
+Row.of("m", Vectors.dense(200.4, 300.4)),
+Row.of("m", Vectors.dense(200.6, 300.6)),
+Row.of("f", Vectors.dense(2.1, 3.1)),
+Row.of("f", Vectors.dense(2.1, 3.1)),
+Row.of("f", Vectors.dense(2.1, 3.1)),
+Row.of("f", Vectors.dense(2.1, 3.1)),
+Row.of("f", Vectors.dense(2.3, 3.2)),
+Row.of("f", Vectors.dense(2.3, 3.2)),
+Row.of("c", Vectors.dense(2.8, 3.2)),
+Row.of("d", Vectors.dense(300., 3.2)),
+Row.of("f", Vectors.dense(2.2, 3.2)),
+Row.of("e", Vectors.dense(2.4, 3.2)),
+Row.of("e", Vectors.dense(2.5, 3.2)),
+Row.of("e", Vectors.dense(2.5, 3.2)),
+Row.of("f", Vectors.dense(2.1, 3.1))));
+
+List testArray =
+new ArrayList<>(
+Arrays.asList(Row.of(Vectors.dense(4.0, 4.1)), 
Row.of(Vectors.dense(300, 42))));
+private Table testData;
+
+Row[] expectedData =
+new Row[] {Row.of("e", Vectors.dense(4.0, 4.1)), Row.of("m", 
Vectors.dense(300, 42))};
+
+@Before
+public void before() {
+Configuration config = new Configuration();
+
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, 
true);
+env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+env.setParallelism(4);
+env.enableCheckpointing(100);
+env.setRestartStrategy(RestartStrategies.noRestart());
+tEnv = StreamTableEnvironment.create(env);
+
+Schema schema =
+Schema.newBuilder()
+.column("f0", DataTypes.STRING())
+.column("f1", DataTypes.of(DenseVector.class))
+.build();
+
+DataStream dataStream = env.fromCollection(trainArray);
+trainData = tEnv.fromDataStream(dataStream, schema).as("label, vec");
+
+Schema outputSchema =
+Schema.newBuilder().column("f0", 
DataTypes.of(DenseVector.class)).build();
+
+DataStream predDataStream = env.fromCollection(testArray);
+testData = tEnv.fromDataStream(predDataStream, outputSchema).as("vec");
+}
+
+/** test knn Estimator. */
+@Test
+public void testFitAntTransform() throws Exception {
+Knn knn =
+new Knn()
+.setLabelCol("label")
+

[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-06 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r762978168



##
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##
@@ -0,0 +1,285 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.api.Stage;
+import org.apache.flink.ml.builder.Pipeline;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasK;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+/** knn algorithm test. */
+public class KnnTest {
+private StreamExecutionEnvironment env;
+private StreamTableEnvironment tEnv;
+private Table trainData;
+
+List trainArray =
+new ArrayList<>(
+Arrays.asList(
+Row.of("f", Vectors.dense(2.0, 3.0)),
+Row.of("f", Vectors.dense(2.1, 3.1)),
+Row.of("m", Vectors.dense(200.1, 300.1)),
+Row.of("m", Vectors.dense(200.2, 300.2)),
+Row.of("m", Vectors.dense(200.3, 300.3)),
+Row.of("m", Vectors.dense(200.4, 300.4)),
+Row.of("m", Vectors.dense(200.4, 300.4)),
+Row.of("m", Vectors.dense(200.6, 300.6)),
+Row.of("f", Vectors.dense(2.1, 3.1)),
+Row.of("f", Vectors.dense(2.1, 3.1)),
+Row.of("f", Vectors.dense(2.1, 3.1)),
+Row.of("f", Vectors.dense(2.1, 3.1)),
+Row.of("f", Vectors.dense(2.3, 3.2)),
+Row.of("f", Vectors.dense(2.3, 3.2)),
+Row.of("c", Vectors.dense(2.8, 3.2)),
+Row.of("d", Vectors.dense(300., 3.2)),
+Row.of("f", Vectors.dense(2.2, 3.2)),
+Row.of("e", Vectors.dense(2.4, 3.2)),
+Row.of("e", Vectors.dense(2.5, 3.2)),
+Row.of("e", Vectors.dense(2.5, 3.2)),
+Row.of("f", Vectors.dense(2.1, 3.1))));
+
+List testArray =
+new ArrayList<>(
+Arrays.asList(Row.of(Vectors.dense(4.0, 4.1)), 
Row.of(Vectors.dense(300, 42))));
+private Table testData;
+
+Row[] expectedData =
+new Row[] {Row.of("e", Vectors.dense(4.0, 4.1)), Row.of("m", 
Vectors.dense(300, 42))};
+
+@Before
+public void before() {
+Configuration config = new Configuration();
+
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, 
true);
+env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+env.setParallelism(4);
+env.enableCheckpointing(100);
+env.setRestartStrategy(RestartStrategies.noRestart());
+tEnv = StreamTableEnvironment.create(env);
+
+Schema schema =
+Schema.newBuilder()
+.column("f0", DataTypes.STRING())
+.column("f1", DataTypes.of(DenseVector.class))
+.build();
+
+DataStream dataStream = env.fromCollection(trainArray);
+trainData = tEnv.fromDataStream(dataStream, schema).as("label, vec");
+
+Schema outputSchema =
+Schema.newBuilder().column("f0", 
DataTypes.of(DenseVector.class)).build();
+
+DataStream predDataStream = env.fromCollection(testArray);
+testData = tEnv.fromDataStream(predDataStream, outputSchema).as("vec");
+}
+
+/** test knn Estimator. */
+@Test
+public void testFitAntTransform() throws Exception {
+Knn knn =
+new Knn()
+.setLabelCol("label")
+

[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-06 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r762977739



##
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##
@@ -0,0 +1,285 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.api.Stage;
+import org.apache.flink.ml.builder.Pipeline;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasK;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+/** knn algorithm test. */
+public class KnnTest {
+private StreamExecutionEnvironment env;
+private StreamTableEnvironment tEnv;
+private Table trainData;
+
+List trainArray =
+new ArrayList<>(
+Arrays.asList(
+Row.of("f", Vectors.dense(2.0, 3.0)),
+Row.of("f", Vectors.dense(2.1, 3.1)),
+Row.of("m", Vectors.dense(200.1, 300.1)),
+Row.of("m", Vectors.dense(200.2, 300.2)),
+Row.of("m", Vectors.dense(200.3, 300.3)),
+Row.of("m", Vectors.dense(200.4, 300.4)),
+Row.of("m", Vectors.dense(200.4, 300.4)),
+Row.of("m", Vectors.dense(200.6, 300.6)),
+Row.of("f", Vectors.dense(2.1, 3.1)),
+Row.of("f", Vectors.dense(2.1, 3.1)),
+Row.of("f", Vectors.dense(2.1, 3.1)),
+Row.of("f", Vectors.dense(2.1, 3.1)),
+Row.of("f", Vectors.dense(2.3, 3.2)),
+Row.of("f", Vectors.dense(2.3, 3.2)),
+Row.of("c", Vectors.dense(2.8, 3.2)),
+Row.of("d", Vectors.dense(300., 3.2)),
+Row.of("f", Vectors.dense(2.2, 3.2)),
+Row.of("e", Vectors.dense(2.4, 3.2)),
+Row.of("e", Vectors.dense(2.5, 3.2)),
+Row.of("e", Vectors.dense(2.5, 3.2)),
+Row.of("f", Vectors.dense(2.1, 3.1))));
+
+List testArray =
+new ArrayList<>(
+Arrays.asList(Row.of(Vectors.dense(4.0, 4.1)), 
Row.of(Vectors.dense(300, 42))));
+private Table testData;
+
+Row[] expectedData =
+new Row[] {Row.of("e", Vectors.dense(4.0, 4.1)), Row.of("m", 
Vectors.dense(300, 42))};
+
+@Before
+public void before() {
+Configuration config = new Configuration();
+
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, 
true);
+env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+env.setParallelism(4);
+env.enableCheckpointing(100);
+env.setRestartStrategy(RestartStrategies.noRestart());
+tEnv = StreamTableEnvironment.create(env);
+
+Schema schema =
+Schema.newBuilder()
+.column("f0", DataTypes.STRING())
+.column("f1", DataTypes.of(DenseVector.class))
+.build();
+
+DataStream dataStream = env.fromCollection(trainArray);
+trainData = tEnv.fromDataStream(dataStream, schema).as("label, vec");
+
+Schema outputSchema =
+Schema.newBuilder().column("f0", 
DataTypes.of(DenseVector.class)).build();
+
+DataStream predDataStream = env.fromCollection(testArray);
+testData = tEnv.fromDataStream(predDataStream, outputSchema).as("vec");
+}
+
+/** test knn Estimator. */
+@Test
+public void testFitAntTransform() throws Exception {
+Knn knn =
+new Knn()
+.setLabelCol("label")
+

[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-06 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r762977526



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/FastDistanceVectorData.java
##
@@ -0,0 +1,28 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.Serializable;
+
+/** Save the data for calculating distance fast. The FastDistanceMatrixData */
+public class FastDistanceVectorData implements Serializable, Cloneable {
+/** Stores the vector(sparse or dense). */
+final DenseVector vector;
+
+/**
+ * Stores some extra info extracted from the vector. For example, if we 
want to save the L1 norm
+ * and L2 norm of the vector, then the two values are viewed as a 
two-dimension label vector.
+ */
+public DenseVector label;

Review comment:
   This class has been deleted.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-06 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r762977288



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/FastDistanceVectorData.java
##
@@ -0,0 +1,28 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.Serializable;
+
+/** Save the data for calculating distance fast. The FastDistanceMatrixData */
+public class FastDistanceVectorData implements Serializable, Cloneable {
+/** Stores the vector(sparse or dense). */
+final DenseVector vector;

Review comment:
   This class has been deleted.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-06 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r762976987



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/FastDistanceVectorData.java
##
@@ -0,0 +1,28 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.Serializable;
+
+/** Save the data for calculating distance fast. The FastDistanceMatrixData */
+public class FastDistanceVectorData implements Serializable, Cloneable {
+/** Stores the vector(sparse or dense). */

Review comment:
   This class has been deleted.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-06 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r762976749



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/FastDistanceMatrixData.java
##
@@ -0,0 +1,122 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.util.Preconditions;
+
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Save the data for calculating distance fast. The FastDistanceMatrixData 
saves several dense
+ * vectors in a single matrix. The vectors are organized in columns, which 
means each column is a
+ * single vector. For example, vec1: 0,1,2, vec2: 3,4,5, vec3: 6,7,8, then the 
data in matrix is
+ * organized as: vec1,vec2,vec3. And the data array in vectors is 
{0,1,2,3,4,5,6,7,8}.
+ */
+public class FastDistanceMatrixData implements Serializable {
+
+/**
+ * Stores several dense vectors in columns. For example, if the vectorSize 
is n, and matrix
+ * saves m vectors, then the number of rows of vectors is n 
and the number of cols
+ * of vectors is m.
+ */
+public final DenseMatrix vectors;
+/**
+ * Save the extra info besides the vector. Each vector is related to one 
row. Thus, for
+ * FastDistanceVectorData, the length of rows is one. And for
+ * FastDistanceMatrixData, the length of rows is equal to the 
number of cols of
+ * matrix. Besides, the order of the rows are the same with 
the vectors.
+ */
+public final String[] ids;
+
+/**
+ * Stores some extra info extracted from the vector. It's also organized 
in columns. For
+ * example, if we want to save the L1 norm and L2 norm of the vector, then 
the two values are
+ * viewed as a two-dimension label vector. We organize the norm vectors 
together to get the
+ * label. If the number of cols of vectors is m, 
then in this case the
+ * dimension of label is 2 * m.
+ */
+public DenseMatrix label;
+
+public String[] getIds() {
+return ids;
+}
+
+/**
+ * Constructor, initialize the vector data and extra info.
+ *
+ * @param vectors DenseMatrix which saves vectors in columns.
+ * @param ids extra info besides the vector.
+ */
+public FastDistanceMatrixData(DenseMatrix vectors, String[] ids) {
+this.ids = ids;
+Preconditions.checkNotNull(vectors, "DenseMatrix should not be null!");
+if (null != ids) {
+Preconditions.checkArgument(
+vectors.numCols() == ids.length,
+"The column number of DenseMatrix must be equal to the 
rows array length!");
+}
+this.vectors = vectors;
+}
+
+/**
+ * serialization of FastDistanceMatrixData.
+ *
+ * @return json string.
+ */
+@Override
+public String toString() {

Review comment:
   This class has been deleted.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-06 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r762976486



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/FastDistanceMatrixData.java
##
@@ -0,0 +1,122 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.util.Preconditions;
+
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Save the data for calculating distance fast. The FastDistanceMatrixData 
saves several dense
+ * vectors in a single matrix. The vectors are organized in columns, which 
means each column is a
+ * single vector. For example, vec1: 0,1,2, vec2: 3,4,5, vec3: 6,7,8, then the 
data in matrix is
+ * organized as: vec1,vec2,vec3. And the data array in vectors is 
{0,1,2,3,4,5,6,7,8}.
+ */
+public class FastDistanceMatrixData implements Serializable {
+
+/**
+ * Stores several dense vectors in columns. For example, if the vectorSize 
is n, and matrix
+ * saves m vectors, then the number of rows of vectors is n 
and the number of cols
+ * of vectors is m.
+ */
+public final DenseMatrix vectors;
+/**
+ * Save the extra info besides the vector. Each vector is related to one 
row. Thus, for
+ * FastDistanceVectorData, the length of rows is one. And for
+ * FastDistanceMatrixData, the length of rows is equal to the 
number of cols of
+ * matrix. Besides, the order of the rows are the same with 
the vectors.
+ */
+public final String[] ids;
+
+/**
+ * Stores some extra info extracted from the vector. It's also organized 
in columns. For
+ * example, if we want to save the L1 norm and L2 norm of the vector, then 
the two values are
+ * viewed as a two-dimension label vector. We organize the norm vectors 
together to get the
+ * label. If the number of cols of vectors is m, 
then in this case the
+ * dimension of label is 2 * m.
+ */
+public DenseMatrix label;
+
+public String[] getIds() {

Review comment:
   This class has been deleted.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-06 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r762976263



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/FastDistanceMatrixData.java
##
@@ -0,0 +1,122 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.util.Preconditions;
+
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Save the data for calculating distance fast. The FastDistanceMatrixData 
saves several dense
+ * vectors in a single matrix. The vectors are organized in columns, which 
means each column is a
+ * single vector. For example, vec1: 0,1,2, vec2: 3,4,5, vec3: 6,7,8, then the 
data in matrix is
+ * organized as: vec1,vec2,vec3. And the data array in vectors is 
{0,1,2,3,4,5,6,7,8}.
+ */
+public class FastDistanceMatrixData implements Serializable {
+
+/**
+ * Stores several dense vectors in columns. For example, if the vectorSize 
is n, and matrix
+ * saves m vectors, then the number of rows of vectors is n 
and the number of cols
+ * of vectors is m.
+ */
+public final DenseMatrix vectors;
+/**
+ * Save the extra info besides the vector. Each vector is related to one 
row. Thus, for
+ * FastDistanceVectorData, the length of rows is one. And for
+ * FastDistanceMatrixData, the length of rows is equal to the 
number of cols of
+ * matrix. Besides, the order of the rows are the same with 
the vectors.
+ */
+public final String[] ids;

Review comment:
   This class has been deleted.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-06 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r762975880



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/FastDistance.java
##
@@ -0,0 +1,192 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.flink.shaded.curator4.com.google.common.collect.Iterables;
+
+import dev.ludovic.netlib.blas.F2jBLAS;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * FastDistance is an accelerated distance calculation method. It uses matrix-vector 
operations to
+ * improve the speed of distance calculation.
+ *
+ * The distance type in this class is euclidean distance:
+ *
+ * https://en.wikipedia.org/wiki/Euclidean_distance
+ */
+public class FastDistance implements Serializable {
+/** Label size. */
+private static final int LABEL_SIZE = 1;
+
+/** Maximum size of a matrix. */
+private static final int SIZE = 5 * 1024 * 1024;
+
+private static final int MAX_ROW_NUMBER = (int) Math.sqrt(200 * 1024 * 
1024 / 8.0);
+
+/** The blas used to accelerating speed. */
+private static final dev.ludovic.netlib.blas.F2jBLAS NATIVE_BLAS =
+(F2jBLAS) F2jBLAS.getInstance();
+
+/**
+ * Prepare the FastDistanceData, the output is a list of 
FastDistanceMatrixData. As the size of

Review comment:
   This class has been deleted.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-05 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r762657594



##
File path: flink-ml-api/src/main/java/org/apache/flink/ml/linalg/Vector.java
##
@@ -29,6 +29,10 @@
 /** Gets the value of the ith element. */
 double get(int i);
 
+
+/** set the value of the ith element. */
+void set(int i, double val);

Review comment:
   Done.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-02 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r760919513



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##
@@ -0,0 +1,489 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.VectorUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.utils.LogicalTypeParser;
+import org.apache.flink.table.types.utils.LogicalTypeDataTypeConverter;
+import org.apache.flink.types.Row;
+
+import org.apache.flink.shaded.curator4.com.google.common.collect.ImmutableMap;
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import java.util.TreeMap;
+import java.util.function.Function;
+
+/** Knn classification model fitted by estimator. */
+public class KnnModel implements Model, KnnParams {
+protected Map<Param<?>, Object> params = new HashMap<>();
+private Table[] modelData;
+
+/** constructor. */
+public KnnModel() {
+ParamUtils.initializeMapWithDefaultValues(params, this);
+}
+
+/**
+ * constructor.
+ *
+ * @param params parameters for algorithm.
+ */
+public KnnModel(Map<Param<?>, Object> params) {
+this.params = params;
+}
+
+/**
+ * Set model data for knn prediction.
+ *
+ * @param modelData knn model.
+ * @return knn model.
+ */
+@Override
+public KnnModel setModelData(Table... modelData) {
+this.modelData = modelData;
+return this;
+}
+
+/**
+ * get model data.
+ *
+ * @return list of tables.
+ */
+@Override
+public Table[] getModelData() {
+return modelData;
+}
+
+/**
+ * @param inputs a list of tables.
+ * @return result.
+ */
+@Override
+public Table[] transform(Table... inputs) {
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+DataStream input = tEnv.toDataStream(inputs[0]);
+DataStream model = tEnv.toDataStream(modelData[0]);
+final String BROADCAST_STR = "broadcastModelKey";
+Map> broadcastMap = new HashMap<>(1);
+broadcastMap.put(BROADCAST_STR, model);
+ResolvedSchema modelSchema = modelData[0].getResolvedSchema();
+DataType idType =
+
modelSchema.getColumnDataTypes().get(modelSchema.getColumnNames().size() - 1);
+String[] reservedCols =
+inputs[0].getResolvedSchema().getColumnNames().toArray(new 
String[0]);
+DataType[] reservedTypes =
+inputs[0].getResolvedSchema().getColumnDataTypes().toArray(new 
DataType[0]);
+String[] resultCols = new String[] {(String) 
params.get(KnnParams.PREDICTION_COL)};
+DataType[] resultTypes = new DataType[] {idType};
+ResolvedSchema outputSchema =
+ResolvedSchema.physical(
+

[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-02 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r760919046



##
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##
@@ -0,0 +1,273 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.api.Stage;
+import org.apache.flink.ml.builder.Pipeline;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasK;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+/** knn algorithm test. */
+public class KnnTest {
+private StreamExecutionEnvironment env;
+private StreamTableEnvironment tEnv;
+private Table trainData;
+
+List trainArray =
+new ArrayList<>(
+Arrays.asList(
+Row.of("f", "2.0 3.0", 1, 0, 1.47),
+Row.of("f", "2.1 3.1", 1, 0, 1.5),
+Row.of("m", "200.1 300.1", 1, 0, 1.5),
+Row.of("m", "200.2 300.2", 1, 0, 2.59),
+Row.of("m", "200.3 300.3", 1, 0, 2.55),
+Row.of("m", "200.4 300.4", 1, 0, 2.53),
+Row.of("m", "200.4 300.4", 1, 0, 2.52),
+Row.of("m", "200.6 300.6", 1, 0, 2.5),
+Row.of("f", "2.1 3.1", 1, 0, 1.5),
+Row.of("f", "2.1 3.1", 1, 0, 1.56),
+Row.of("f", "2.1 3.1", 1, 0, 1.51),
+Row.of("f", "2.1 3.1", 1, 0, 1.52),
+Row.of("f", "2.3 3.2", 1, 0, 1.53),
+Row.of("f", "2.3 3.2", 1, 0, 1.54),
+Row.of("c", "2.8 3.2", 3, 0, 1.6),
+Row.of("d", "300. 3.2", 5, 0, 1.5),
+Row.of("f", "2.2 3.2", 1, 0, 1.5),
+Row.of("e", "2.4 3.2", 2, 0, 1.3),
+Row.of("e", "2.5 3.2", 2, 0, 1.4),
+Row.of("e", "2.5 3.2", 2, 0, 1.5),
+Row.of("f", "2.1 3.1", 1, 0, 1.6)));
+
+List testArray =
+new ArrayList<>(
+Arrays.asList(
+Row.of("e", "4.0 4.1", 2, 0, 1.5), Row.of("m", 
"300 42", 1, 0, 2.59)));
+
+private Table testData;
+
+@Before
+public void before() {
+Configuration config = new Configuration();
+
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, 
true);
+env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+env.setParallelism(4);
+env.enableCheckpointing(100);
+env.setRestartStrategy(RestartStrategies.noRestart());
+tEnv = StreamTableEnvironment.create(env);
+
+DataStream dataStream =
+env.fromCollection(
+trainArray,
+new RowTypeInfo(
+new TypeInformation[] {
+Types.STRING, Types.STRING, Types.INT, 
Types.INT, Types.DOUBLE
+},
+new String[] {"label", "vec", "f0", "f1", 
"f2"}));
+trainData = tEnv.fromDataStream(dataStream);
+
+DataStream dataStreamStr =
+env.fromCollection(
+testArray,
+new RowTypeInfo(
+new TypeInformation[] {
+Types.STRING, Types.STRING, Types.INT, 
Types.INT, Types.DOUBLE
+},
+new String[] {"label", "vec", "f0", "f1", 
"f2"}));
+
+testData = tEnv.fromDataStream(dataStreamStr);
+}
+
+/** 

[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-02 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r760918901



##
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##
@@ -0,0 +1,273 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.api.Stage;
+import org.apache.flink.ml.builder.Pipeline;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasK;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+/** knn algorithm test. */
+public class KnnTest {
+private StreamExecutionEnvironment env;
+private StreamTableEnvironment tEnv;
+private Table trainData;
+
+List trainArray =
+new ArrayList<>(
+Arrays.asList(
+Row.of("f", "2.0 3.0", 1, 0, 1.47),
+Row.of("f", "2.1 3.1", 1, 0, 1.5),
+Row.of("m", "200.1 300.1", 1, 0, 1.5),
+Row.of("m", "200.2 300.2", 1, 0, 2.59),
+Row.of("m", "200.3 300.3", 1, 0, 2.55),
+Row.of("m", "200.4 300.4", 1, 0, 2.53),
+Row.of("m", "200.4 300.4", 1, 0, 2.52),
+Row.of("m", "200.6 300.6", 1, 0, 2.5),
+Row.of("f", "2.1 3.1", 1, 0, 1.5),
+Row.of("f", "2.1 3.1", 1, 0, 1.56),
+Row.of("f", "2.1 3.1", 1, 0, 1.51),
+Row.of("f", "2.1 3.1", 1, 0, 1.52),
+Row.of("f", "2.3 3.2", 1, 0, 1.53),
+Row.of("f", "2.3 3.2", 1, 0, 1.54),
+Row.of("c", "2.8 3.2", 3, 0, 1.6),
+Row.of("d", "300. 3.2", 5, 0, 1.5),
+Row.of("f", "2.2 3.2", 1, 0, 1.5),
+Row.of("e", "2.4 3.2", 2, 0, 1.3),
+Row.of("e", "2.5 3.2", 2, 0, 1.4),
+Row.of("e", "2.5 3.2", 2, 0, 1.5),
+Row.of("f", "2.1 3.1", 1, 0, 1.6)));
+
+List testArray =
+new ArrayList<>(
+Arrays.asList(
+Row.of("e", "4.0 4.1", 2, 0, 1.5), Row.of("m", 
"300 42", 1, 0, 2.59)));
+
+private Table testData;
+
+@Before
+public void before() {
+Configuration config = new Configuration();
+
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, 
true);
+env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+env.setParallelism(4);
+env.enableCheckpointing(100);
+env.setRestartStrategy(RestartStrategies.noRestart());
+tEnv = StreamTableEnvironment.create(env);
+
+DataStream dataStream =
+env.fromCollection(
+trainArray,
+new RowTypeInfo(
+new TypeInformation[] {
+Types.STRING, Types.STRING, Types.INT, 
Types.INT, Types.DOUBLE
+},
+new String[] {"label", "vec", "f0", "f1", 
"f2"}));
+trainData = tEnv.fromDataStream(dataStream);
+
+DataStream dataStreamStr =
+env.fromCollection(
+testArray,
+new RowTypeInfo(
+new TypeInformation[] {
+Types.STRING, Types.STRING, Types.INT, 
Types.INT, Types.DOUBLE
+},
+new String[] {"label", "vec", "f0", "f1", 
"f2"}));
+
+testData = tEnv.fromDataStream(dataStreamStr);
+}
+
+/** 

[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-02 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r760918561



##
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##
@@ -0,0 +1,273 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.api.Stage;
+import org.apache.flink.ml.builder.Pipeline;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasK;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+/** knn algorithm test. */
+public class KnnTest {
+private StreamExecutionEnvironment env;
+private StreamTableEnvironment tEnv;
+private Table trainData;
+
+List trainArray =
+new ArrayList<>(
+Arrays.asList(
+Row.of("f", "2.0 3.0", 1, 0, 1.47),
+Row.of("f", "2.1 3.1", 1, 0, 1.5),
+Row.of("m", "200.1 300.1", 1, 0, 1.5),
+Row.of("m", "200.2 300.2", 1, 0, 2.59),
+Row.of("m", "200.3 300.3", 1, 0, 2.55),
+Row.of("m", "200.4 300.4", 1, 0, 2.53),
+Row.of("m", "200.4 300.4", 1, 0, 2.52),
+Row.of("m", "200.6 300.6", 1, 0, 2.5),
+Row.of("f", "2.1 3.1", 1, 0, 1.5),
+Row.of("f", "2.1 3.1", 1, 0, 1.56),
+Row.of("f", "2.1 3.1", 1, 0, 1.51),
+Row.of("f", "2.1 3.1", 1, 0, 1.52),
+Row.of("f", "2.3 3.2", 1, 0, 1.53),
+Row.of("f", "2.3 3.2", 1, 0, 1.54),
+Row.of("c", "2.8 3.2", 3, 0, 1.6),
+Row.of("d", "300. 3.2", 5, 0, 1.5),
+Row.of("f", "2.2 3.2", 1, 0, 1.5),
+Row.of("e", "2.4 3.2", 2, 0, 1.3),
+Row.of("e", "2.5 3.2", 2, 0, 1.4),
+Row.of("e", "2.5 3.2", 2, 0, 1.5),
+Row.of("f", "2.1 3.1", 1, 0, 1.6)));
+
+List testArray =
+new ArrayList<>(
+Arrays.asList(
+Row.of("e", "4.0 4.1", 2, 0, 1.5), Row.of("m", 
"300 42", 1, 0, 2.59)));
+
+private Table testData;
+
+@Before
+public void before() {
+Configuration config = new Configuration();
+
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, 
true);
+env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+env.setParallelism(4);
+env.enableCheckpointing(100);
+env.setRestartStrategy(RestartStrategies.noRestart());
+tEnv = StreamTableEnvironment.create(env);
+
+DataStream dataStream =
+env.fromCollection(
+trainArray,
+new RowTypeInfo(
+new TypeInformation[] {
+Types.STRING, Types.STRING, Types.INT, 
Types.INT, Types.DOUBLE
+},
+new String[] {"label", "vec", "f0", "f1", 
"f2"}));
+trainData = tEnv.fromDataStream(dataStream);

Review comment:
   done.

##
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##
@@ -0,0 +1,273 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;

[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-12-02 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r760918295



##
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##
@@ -0,0 +1,273 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.api.Stage;
+import org.apache.flink.ml.builder.Pipeline;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasK;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+/** knn algorithm test. */
+public class KnnTest {
+private StreamExecutionEnvironment env;
+private StreamTableEnvironment tEnv;
+private Table trainData;
+
+List trainArray =
+new ArrayList<>(
+Arrays.asList(
+Row.of("f", "2.0 3.0", 1, 0, 1.47),

Review comment:
   Thanks for your comment.
   I have refined the code as you suggested.
   
   done.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-30 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r759912742



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFeatureColsDefaultAsNull.java
##
@@ -0,0 +1,28 @@
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Params of the names of the feature columns used for training in the input 
table. */
+public interface HasFeatureColsDefaultAsNull extends WithParams {
+/**
+ * @cn-name 特征列名数组
+ * @cn 特征列名数组,默认全选
+ */
+Param FEATURE_COLS =

Review comment:
   done
   




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-30 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r759909634



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java
##
@@ -0,0 +1,134 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.types.Row;
+
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+
+/** knn model data, which will be used to calculate the distances between 
nodes. */
+public class KnnModelData implements Serializable, Cloneable {
+private static final long serialVersionUID = -2940551481683238630L;

Review comment:
   done

##
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasK.java
##
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared K param. */
+public interface HasK extends WithParams {
+
+/**
+ * topK
+ */
+Param K = new IntParam("k", "k", 10, ParamValidators.gt(0));

Review comment:
   done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-30 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r759909541



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnParams.java
##
@@ -0,0 +1,18 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.ml.common.param.HasFeatureColsDefaultAsNull;
+import org.apache.flink.ml.common.param.HasK;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.common.param.HasVectorColDefaultAsNull;
+import org.apache.flink.ml.param.WithParams;
+
+/** knn parameters. */
+public interface KnnParams
+extends WithParams,

Review comment:
   done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-26 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r757375204



##
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasK.java
##
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared K param. */
+public interface HasK extends WithParams {
+
+/**
+ * topK
+ */
+Param K = new IntParam("k", "k", 10, ParamValidators.gt(0));

Review comment:
   OK, I will refine the other use cases.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-26 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r757366985



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##
@@ -0,0 +1,493 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.core.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.VectorUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.shaded.curator4.com.google.common.base.Preconditions;
+import org.apache.flink.shaded.curator4.com.google.common.collect.ImmutableMap;
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.utils.LogicalTypeParser;
+import org.apache.flink.table.types.utils.LogicalTypeDataTypeConverter;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import java.util.TreeMap;
+import java.util.function.Function;
+
+/** Knn classification model fitted by estimator. */
+public class KnnModel implements Model, KnnParams {
+protected Map, Object> params = new HashMap<>();
+private Table[] modelData;
+
+/** constructor. */
+public KnnModel() {
+ParamUtils.initializeMapWithDefaultValues(params, this);
+}
+
+/**
+ * constructor.
+ *
+ * @param params parameters for algorithm.
+ */
+public KnnModel(Map, Object> params) {
+this.params = params;
+}
+
+/**
+ * Set model data for knn prediction.
+ *
+ * @param modelData knn model.
+ * @return knn classification model.
+ */
+@Override
+public KnnModel setModelData(Table... modelData) {
+this.modelData = modelData;
+return this;
+}
+
+/**
+ * get model data.
+ *
+ * @return list of tables.
+ */
+@Override
+public Table[] getModelData() {
+return modelData;
+}
+
+/**
+ * @param inputs a list of tables.
+ * @return result.
+ */
+@Override
+public Table[] transform(Table... inputs) {
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+DataStream input = tEnv.toDataStream(inputs[0]);
+DataStream model = tEnv.toDataStream(modelData[0]);
+final String BROADCAST_STR = "broadcastModelKey";
+Map> broadcastMap = new HashMap<>(1);
+broadcastMap.put(BROADCAST_STR, model);
+ResolvedSchema modelSchema = modelData[0].getResolvedSchema();
+DataType idType =
+
modelSchema.getColumnDataTypes().get(modelSchema.getColumnNames().size() - 1);
+String[] reservedCols = 
inputs[0].getResolvedSchema().getColumnNames().toArray(new String[0]);
+DataType[] reservedTypes = 
inputs[0].getResolvedSchema().getColumnDataTypes().toArray(new DataType[0]);

Review comment:
   Different algorithms may have different output schemas. For example,
knn's result may differ from lr's, since lr has detail info which knn does not have.
   
   So, I think the output 

[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-26 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r757360455



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java
##
@@ -0,0 +1,134 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.types.Row;
+
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+
+/** knn model data, which will be used to calculate the distances between 
nodes. */
+public class KnnModelData implements Serializable, Cloneable {
+private static final long serialVersionUID = -2940551481683238630L;

Review comment:
   OK, I will refine it later.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-26 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r757359842



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnParams.java
##
@@ -0,0 +1,18 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.ml.common.param.HasFeatureColsDefaultAsNull;
+import org.apache.flink.ml.common.param.HasK;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.common.param.HasVectorColDefaultAsNull;
+import org.apache.flink.ml.param.WithParams;
+
+/** knn parameters. */
+public interface KnnParams
+extends WithParams,

Review comment:
   OK, I will refine it later.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-25 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r756819327



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFeatureColsDefaultAsNull.java
##
@@ -0,0 +1,28 @@
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Params of the names of the feature columns used for training in the input 
table. */
+public interface HasFeatureColsDefaultAsNull extends WithParams {
+/**
+ * @cn-name 特征列名数组
+ * @cn 特征列名数组,默认全选
+ */
+Param FEATURE_COLS =

Review comment:
   I think we need to discuss the format of the feature input for the algorithm:
a dense vector, a string-format vector, or several numerical columns?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-25 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r756818051



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##
@@ -0,0 +1,594 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.core.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.shaded.curator4.com.google.common.base.Preconditions;
+import org.apache.flink.shaded.curator4.com.google.common.collect.ImmutableMap;
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.type.TypeReference;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.utils.LogicalTypeParser;
+import org.apache.flink.table.types.utils.LogicalTypeDataTypeConverter;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.lang3.ArrayUtils;
+import sun.reflect.generics.reflectiveObjects.ParameterizedTypeImpl;
+
+import java.io.IOException;
+import java.lang.reflect.Type;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import java.util.TreeMap;
+import java.util.function.Function;
+
+/** Knn classification model fitted by estimator. */
+public class KnnModel implements Model, KnnParams {
+
+private static final long serialVersionUID = 1303892137143865652L;
+
+private static final String BROADCAST_STR = "broadcastModelKey";
+private static final int FASTDISTANCE_TYPE_INDEX = 0;
+private static final int DATA_INDEX = 1;
+
+protected Map, Object> params = new HashMap<>();
+
+private Table[] modelData;
+
+/** constructor. */
+public KnnModel() {
+ParamUtils.initializeMapWithDefaultValues(params, this);
+}
+
+/**
+ * constructor.
+ *
+ * @param params parameters for algorithm.
+ */
+public KnnModel(Map, Object> params) {
+this.params = params;
+}
+
+/**
+ * Set model data for knn prediction.
+ *
+ * @param modelData knn model.
+ * @return knn classification model.
+ */
+@Override
+public KnnModel setModelData(Table... modelData) {
+this.modelData = modelData;
+return this;
+}
+
+/**
+ * get model data.
+ *
+ * @return list of tables.
+ */
+@Override
+public Table[] getModelData() {
+return modelData;
+}
+
+/**
+ * @param inputs a list of tables.
+ * @return result.
+ */
+@Override
+public Table[] transform(Table... inputs) {
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+DataStream input = tEnv.toDataStream(inputs[0]);
+DataStream model = tEnv.toDataStream(modelData[0]);
+
+Map> broadcastMap = new HashMap<>(1);
+broadcastMap.put(BROADCAST_STR, model);
+ResolvedSchema modelSchema = modelData[0].getResolvedSchema();
+DataType idType =
+
modelSchema.getColumnDataTypes().get(modelSchema.getColumnNames().size() - 1);
+
+ResolvedSchema outputSchema =
+getOutputSchema(inputs[0].getResolvedSchema(), getParamMap(), 
idType);
+
+DataType[] dataTypes = 

[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-25 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r756817571



##
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##
@@ -0,0 +1,172 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.api.core.Pipeline;
+import org.apache.flink.ml.api.core.Stage;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.sink.SinkFunction;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+/** knn algorithm test. */
+public class KnnTest {
+private StreamExecutionEnvironment env;
+private StreamTableEnvironment tEnv;
+private Table trainData;
+
+List trainArray =
+new ArrayList<>(
+Arrays.asList(
+Row.of("f", "2.0 3.0"),
+Row.of("f", "2.1 3.1"),
+Row.of("m", "200.1 300.1"),
+Row.of("m", "200.2 300.2"),
+Row.of("m", "200.3 300.3"),
+Row.of("m", "200.4 300.4"),
+Row.of("m", "200.4 300.4"),
+Row.of("m", "200.6 300.6"),
+Row.of("f", "2.1 3.1"),
+Row.of("f", "2.1 3.1"),
+Row.of("f", "2.1 3.1"),
+Row.of("f", "2.1 3.1"),
+Row.of("f", "2.3 3.2"),
+Row.of("f", "2.3 3.2"),
+Row.of("c", "2.8 3.2"),
+Row.of("d", "300. 3.2"),
+Row.of("f", "2.2 3.2"),
+Row.of("e", "2.4 3.2"),
+Row.of("e", "2.5 3.2"),
+Row.of("e", "2.5 3.2"),
+Row.of("f", "2.1 3.1")));
+
+List testArray =
+new ArrayList<>(Arrays.asList(Row.of("e", "4.0 4.1"), Row.of("m", 
"300 42")));
+
+private Table testData;
+

Review comment:
   done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-25 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r756817467



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/EuclideanDistance.java
##
@@ -0,0 +1,272 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.flink.shaded.curator4.com.google.common.collect.Iterables;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+
+import static 
org.apache.flink.ml.classification.knn.KnnUtil.appendVectorToMatrix;
+
+/**
+ * Euclidean distance is the "ordinary" straight-line distance between two 
points in Euclidean
+ * space.
+ *
+ * https://en.wikipedia.org/wiki/Euclidean_distance
+ *
+ * Given two vectors a and b, Euclidean Distance = ||a - b||, where ||*|| 
means the L2 norm of
+ * the vector.
+ */
+public class EuclideanDistance implements Serializable {

Review comment:
   done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-24 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r756627878



##
File path: 
flink-ml-api/src/main/java/org/apache/flink/ml/linalg/DenseVector.java
##
@@ -60,6 +75,74 @@ public boolean equals(Object obj) {
 return Arrays.equals(values, ((DenseVector) obj).values);
 }
 
+/**
+ * Parse the dense vector from a formatted string.
+ *
+ * The format of a dense vector is space separated values such as "1 2 
3 4".
+ *
+ * @param str A string of space separated values.
+ * @return The parsed vector.
+ */
+public static DenseVector fromString(String str) {

Review comment:
   OK




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-24 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r756526795



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFeatureColsDefaultAsNull.java
##
@@ -0,0 +1,28 @@
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Params of the names of the feature columns used for training in the input 
table. */
+public interface HasFeatureColsDefaultAsNull extends WithParams {
+/**
+ * @cn-name 特征列名数组

Review comment:
   done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-24 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r756526696



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/MapPartitionFunctionWrapper.java
##
@@ -0,0 +1,68 @@
+package org.apache.flink.ml.common;
+
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.functions.util.FunctionUtils;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+
+/**
+ * MapPartitionFunction wrapper.
+ *
+ * @param  Input element type.
+ * @param  Output element type.
+ */
+public class MapPartitionFunctionWrapper extends 
AbstractStreamOperator

Review comment:
   done
   




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-23 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r755690633



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/EuclideanDistance.java
##
@@ -0,0 +1,272 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.flink.shaded.curator4.com.google.common.collect.Iterables;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+
+import static 
org.apache.flink.ml.classification.knn.KnnUtil.appendVectorToMatrix;
+
+/**
+ * Euclidean distance is the "ordinary" straight-line distance between two 
points in Euclidean
+ * space.
+ *
+ * https://en.wikipedia.org/wiki/Euclidean_distance
+ *
+ * Given two vectors a and b, Euclidean Distance = ||a - b||, where ||*|| 
means the L2 norm of
+ * the vector.
+ */
+public class EuclideanDistance implements Serializable {

Review comment:
   OK, knn's EuclideanDistance may not be a common distance; it is a fast 
distance. 
   I will rename knn's distance class, using FastDistance instead of 
EuclideanDistance.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-23 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r755688272



##
File path: flink-ml-api/src/main/java/org/apache/flink/ml/param/Param.java
##
@@ -75,7 +75,7 @@ public String jsonEncode(T value) throws IOException {
  */
 @SuppressWarnings("unchecked")
 public T jsonDecode(String json) throws IOException {
-return ReadWriteUtils.OBJECT_MAPPER.readValue(json, clazz);
+return ReadWriteUtils.OBJECT_MAPPER.fromJson(json, clazz);

Review comment:
   done

##
File path: flink-ml-api/src/main/java/org/apache/flink/ml/param/Param.java
##
@@ -64,7 +64,7 @@ public Param(
  * @return A json-formatted string.
  */
 public String jsonEncode(T value) throws IOException {
-return ReadWriteUtils.OBJECT_MAPPER.writeValueAsString(value);
+return ReadWriteUtils.OBJECT_MAPPER.toJson(value);

Review comment:
   done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-23 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r755688178



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##
@@ -0,0 +1,255 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.api.core.Estimator;
+import org.apache.flink.ml.common.MapPartitionFunctionWrapper;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * KNN is to classify unlabeled observations by assigning them to the class of 
the most similar
+ * labeled examples.
+ */
+public class Knn implements Estimator, KnnParams {
+
+private static final long serialVersionUID = 5292477422193301398L;
+private static final int ROW_SIZE = 2;
+private static final int FASTDISTANCE_TYPE_INDEX = 0;
+private static final int DATA_INDEX = 1;
+
+protected Map, Object> params = new HashMap<>();
+
+/** constructor. */
+public Knn() {
+ParamUtils.initializeMapWithDefaultValues(params, this);
+}
+
+/**
+ * constructor.
+ *
+ * @param params parameters for algorithm.
+ */
+public Knn(Map, Object> params) {
+this.params = params;
+}
+
+/**
+ * @param inputs a list of tables
+ * @return knn classification model.
+ */
+@Override
+public KnnModel fit(Table... inputs) {
+Preconditions.checkArgument(inputs.length == 1);
+ResolvedSchema schema = inputs[0].getResolvedSchema();
+String[] colNames = schema.getColumnNames().toArray(new String[0]);
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+DataStream input = tEnv.toDataStream(inputs[0]);
+String[] targetCols = getFeatureCols();
+final int[] featureIndices;
+if (targetCols == null) {
+featureIndices = new int[colNames.length];
+for (int i = 0; i < colNames.length; i++) {
+featureIndices[i] = i;
+}
+} else {
+featureIndices = new int[targetCols.length];
+for (int i = 0; i < featureIndices.length; i++) {
+featureIndices[i] = findColIndex(colNames, targetCols[i]);
+}
+}
+String labelCol = getLabelCol();
+final int labelIdx = findColIndex(colNames, labelCol);
+final int vecIdx =
+getVectorCol() != null
+? findColIndex(
+inputs[0]
+.getResolvedSchema()
+.getColumnNames()
+.toArray(new String[0]),
+getVectorCol())
+: -1;
+
+DataStream trainData =
+input.map(
+(MapFunction)
+value -> {
+Object label = value.getField(labelIdx);

Review comment:
   done

##
File path: 
flink-ml-api/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
##
@@ -43,8 +46,12 @@
 
 /** Utility methods for reading and writing stages. */
 public class ReadWriteUtils {
-public static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
-
+public static Gson OBJECT_MAPPER =

Review comment:
   done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-23 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r755684260



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnParams.java
##
@@ -0,0 +1,32 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.ml.common.param.HasFeatureColsDefaultAsNull;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.common.param.HasVectorColDefaultAsNull;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** knn fit parameters. */
+public interface KnnParams
+extends WithParams,
+HasVectorColDefaultAsNull,
+HasLabelCol,
+HasFeatureColsDefaultAsNull,
+HasPredictionCol {
+/**
+ * @cn-name topK
+ * @cn topK
+ */
+Param K = new IntParam("k", "k", 10, ParamValidators.gt(0));

Review comment:
   done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-23 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r755684156



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnParams.java
##
@@ -0,0 +1,32 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.ml.common.param.HasFeatureColsDefaultAsNull;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.common.param.HasVectorColDefaultAsNull;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** knn fit parameters. */
+public interface KnnParams
+extends WithParams,
+HasVectorColDefaultAsNull,
+HasLabelCol,
+HasFeatureColsDefaultAsNull,
+HasPredictionCol {
+/**
+ * @cn-name topK
+ * @cn topK
+ */
+Param K = new IntParam("k", "k", 10, ParamValidators.gt(0));
+
+default Integer getK() {

Review comment:
   done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-23 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r755682429



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##
@@ -0,0 +1,594 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.core.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.shaded.curator4.com.google.common.base.Preconditions;
+import org.apache.flink.shaded.curator4.com.google.common.collect.ImmutableMap;
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.type.TypeReference;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.utils.LogicalTypeParser;
+import org.apache.flink.table.types.utils.LogicalTypeDataTypeConverter;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.lang3.ArrayUtils;
+import sun.reflect.generics.reflectiveObjects.ParameterizedTypeImpl;
+
+import java.io.IOException;
+import java.lang.reflect.Type;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import java.util.TreeMap;
+import java.util.function.Function;
+
+/** Knn classification model fitted by estimator. */
+public class KnnModel implements Model, KnnParams {
+
+private static final long serialVersionUID = 1303892137143865652L;
+
+private static final String BROADCAST_STR = "broadcastModelKey";
+private static final int FASTDISTANCE_TYPE_INDEX = 0;
+private static final int DATA_INDEX = 1;
+
+protected Map, Object> params = new HashMap<>();
+
+private Table[] modelData;
+
+/** constructor. */
+public KnnModel() {
+ParamUtils.initializeMapWithDefaultValues(params, this);
+}
+
+/**
+ * constructor.
+ *
+ * @param params parameters for algorithm.
+ */
+public KnnModel(Map, Object> params) {
+this.params = params;
+}
+
+/**
+ * Set model data for knn prediction.
+ *
+ * @param modelData knn model.
+ * @return knn classification model.
+ */
+@Override
+public KnnModel setModelData(Table... modelData) {
+this.modelData = modelData;
+return this;
+}
+
+/**
+ * get model data.
+ *
+ * @return list of tables.
+ */
+@Override
+public Table[] getModelData() {
+return modelData;
+}
+
+/**
+ * @param inputs a list of tables.
+ * @return result.
+ */
+@Override
+public Table[] transform(Table... inputs) {
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+DataStream input = tEnv.toDataStream(inputs[0]);
+DataStream model = tEnv.toDataStream(modelData[0]);
+
+Map> broadcastMap = new HashMap<>(1);
+broadcastMap.put(BROADCAST_STR, model);
+ResolvedSchema modelSchema = modelData[0].getResolvedSchema();
+DataType idType =
+
modelSchema.getColumnDataTypes().get(modelSchema.getColumnNames().size() - 1);
+
+ResolvedSchema outputSchema =
+getOutputSchema(inputs[0].getResolvedSchema(), getParamMap(), 
idType);
+
+DataType[] dataTypes = 

[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-23 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r755682171



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##
@@ -0,0 +1,594 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.core.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.shaded.curator4.com.google.common.base.Preconditions;
+import org.apache.flink.shaded.curator4.com.google.common.collect.ImmutableMap;
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.type.TypeReference;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.utils.LogicalTypeParser;
+import org.apache.flink.table.types.utils.LogicalTypeDataTypeConverter;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.lang3.ArrayUtils;
+import sun.reflect.generics.reflectiveObjects.ParameterizedTypeImpl;
+
+import java.io.IOException;
+import java.lang.reflect.Type;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import java.util.TreeMap;
+import java.util.function.Function;
+
+/** Knn classification model fitted by estimator. */
+public class KnnModel implements Model, KnnParams {
+
+private static final long serialVersionUID = 1303892137143865652L;
+
+private static final String BROADCAST_STR = "broadcastModelKey";

Review comment:
   done
   

##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##
@@ -0,0 +1,594 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.core.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.shaded.curator4.com.google.common.base.Preconditions;
+import org.apache.flink.shaded.curator4.com.google.common.collect.ImmutableMap;
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.type.TypeReference;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import 

[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-23 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r755678624



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##
@@ -0,0 +1,594 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.core.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.shaded.curator4.com.google.common.base.Preconditions;
+import org.apache.flink.shaded.curator4.com.google.common.collect.ImmutableMap;
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.type.TypeReference;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.utils.LogicalTypeParser;
+import org.apache.flink.table.types.utils.LogicalTypeDataTypeConverter;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.lang3.ArrayUtils;
+import sun.reflect.generics.reflectiveObjects.ParameterizedTypeImpl;
+
+import java.io.IOException;
+import java.lang.reflect.Type;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import java.util.TreeMap;
+import java.util.function.Function;
+
+/** Knn classification model fitted by estimator. */
+public class KnnModel implements Model, KnnParams {
+
+private static final long serialVersionUID = 1303892137143865652L;
+
+private static final String BROADCAST_STR = "broadcastModelKey";

Review comment:
   OK, I will refine it later.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-23 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r755678460



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##
@@ -0,0 +1,594 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.core.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.shaded.curator4.com.google.common.base.Preconditions;
+import org.apache.flink.shaded.curator4.com.google.common.collect.ImmutableMap;
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.type.TypeReference;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.utils.LogicalTypeParser;
+import org.apache.flink.table.types.utils.LogicalTypeDataTypeConverter;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.lang3.ArrayUtils;
+import sun.reflect.generics.reflectiveObjects.ParameterizedTypeImpl;
+
+import java.io.IOException;
+import java.lang.reflect.Type;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import java.util.TreeMap;
+import java.util.function.Function;
+
+/** Knn classification model fitted by estimator. */
+public class KnnModel implements Model, KnnParams {
+
+private static final long serialVersionUID = 1303892137143865652L;
+
+private static final String BROADCAST_STR = "broadcastModelKey";
+private static final int FASTDISTANCE_TYPE_INDEX = 0;
+private static final int DATA_INDEX = 1;
+
+protected Map<Param<?>, Object> params = new HashMap<>();
+
+private Table[] modelData;
+
+/** constructor. */
+public KnnModel() {
+ParamUtils.initializeMapWithDefaultValues(params, this);
+}
+
+/**
+ * constructor.
+ *
+ * @param params parameters for algorithm.
+ */
+public KnnModel(Map<Param<?>, Object> params) {
+this.params = params;
+}
+
+/**
+ * Set model data for knn prediction.
+ *
+ * @param modelData knn model.
+ * @return knn classification model.
+ */
+@Override
+public KnnModel setModelData(Table... modelData) {
+this.modelData = modelData;
+return this;
+}
+
+/**
+ * get model data.
+ *
+ * @return list of tables.
+ */
+@Override
+public Table[] getModelData() {
+return modelData;
+}
+
+/**
+ * @param inputs a list of tables.
+ * @return result.
+ */
+@Override
+public Table[] transform(Table... inputs) {
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+DataStream input = tEnv.toDataStream(inputs[0]);
+DataStream model = tEnv.toDataStream(modelData[0]);
+
+Map> broadcastMap = new HashMap<>(1);
+broadcastMap.put(BROADCAST_STR, model);
+ResolvedSchema modelSchema = modelData[0].getResolvedSchema();
+DataType idType =
+
modelSchema.getColumnDataTypes().get(modelSchema.getColumnNames().size() - 1);
+
+ResolvedSchema outputSchema =
+getOutputSchema(inputs[0].getResolvedSchema(), getParamMap(), 
idType);
+
+DataType[] dataTypes = 

[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-23 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r755678350



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##
@@ -0,0 +1,594 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.core.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.shaded.curator4.com.google.common.base.Preconditions;
+import org.apache.flink.shaded.curator4.com.google.common.collect.ImmutableMap;
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.type.TypeReference;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.utils.LogicalTypeParser;
+import org.apache.flink.table.types.utils.LogicalTypeDataTypeConverter;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.lang3.ArrayUtils;
+import sun.reflect.generics.reflectiveObjects.ParameterizedTypeImpl;
+
+import java.io.IOException;
+import java.lang.reflect.Type;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import java.util.TreeMap;
+import java.util.function.Function;
+
+/** Knn classification model fitted by estimator. */
+public class KnnModel implements Model, KnnParams {
+
+private static final long serialVersionUID = 1303892137143865652L;
+
+private static final String BROADCAST_STR = "broadcastModelKey";
+private static final int FASTDISTANCE_TYPE_INDEX = 0;
+private static final int DATA_INDEX = 1;
+
+protected Map<Param<?>, Object> params = new HashMap<>();
+
+private Table[] modelData;
+
+/** constructor. */
+public KnnModel() {
+ParamUtils.initializeMapWithDefaultValues(params, this);
+}
+
+/**
+ * constructor.
+ *
+ * @param params parameters for algorithm.
+ */
+public KnnModel(Map<Param<?>, Object> params) {
+this.params = params;
+}
+
+/**
+ * Set model data for knn prediction.
+ *
+ * @param modelData knn model.
+ * @return knn classification model.
+ */
+@Override
+public KnnModel setModelData(Table... modelData) {
+this.modelData = modelData;
+return this;
+}
+
+/**
+ * get model data.
+ *
+ * @return list of tables.
+ */
+@Override
+public Table[] getModelData() {
+return modelData;
+}
+
+/**
+ * @param inputs a list of tables.
+ * @return result.
+ */
+@Override
+public Table[] transform(Table... inputs) {
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+DataStream input = tEnv.toDataStream(inputs[0]);
+DataStream model = tEnv.toDataStream(modelData[0]);
+
+Map> broadcastMap = new HashMap<>(1);
+broadcastMap.put(BROADCAST_STR, model);
+ResolvedSchema modelSchema = modelData[0].getResolvedSchema();
+DataType idType =
+
modelSchema.getColumnDataTypes().get(modelSchema.getColumnNames().size() - 1);
+
+ResolvedSchema outputSchema =
+getOutputSchema(inputs[0].getResolvedSchema(), getParamMap(), 
idType);
+
+DataType[] dataTypes = 

[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-23 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r755677909



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##
@@ -0,0 +1,594 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.core.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.shaded.curator4.com.google.common.base.Preconditions;
+import org.apache.flink.shaded.curator4.com.google.common.collect.ImmutableMap;
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.type.TypeReference;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.utils.LogicalTypeParser;
+import org.apache.flink.table.types.utils.LogicalTypeDataTypeConverter;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.lang3.ArrayUtils;
+import sun.reflect.generics.reflectiveObjects.ParameterizedTypeImpl;
+
+import java.io.IOException;
+import java.lang.reflect.Type;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import java.util.TreeMap;
+import java.util.function.Function;
+
+/** Knn classification model fitted by estimator. */
+public class KnnModel implements Model, KnnParams {
+
+private static final long serialVersionUID = 1303892137143865652L;
+
+private static final String BROADCAST_STR = "broadcastModelKey";
+private static final int FASTDISTANCE_TYPE_INDEX = 0;
+private static final int DATA_INDEX = 1;
+
+protected Map<Param<?>, Object> params = new HashMap<>();
+
+private Table[] modelData;
+
+/** constructor. */
+public KnnModel() {
+ParamUtils.initializeMapWithDefaultValues(params, this);
+}
+
+/**
+ * constructor.
+ *
+ * @param params parameters for algorithm.
+ */
+public KnnModel(Map<Param<?>, Object> params) {
+this.params = params;
+}
+
+/**
+ * Set model data for knn prediction.
+ *
+ * @param modelData knn model.
+ * @return knn classification model.
+ */
+@Override
+public KnnModel setModelData(Table... modelData) {
+this.modelData = modelData;
+return this;
+}
+
+/**
+ * get model data.
+ *
+ * @return list of tables.
+ */
+@Override
+public Table[] getModelData() {
+return modelData;
+}
+
+/**
+ * @param inputs a list of tables.
+ * @return result.
+ */
+@Override
+public Table[] transform(Table... inputs) {
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+DataStream input = tEnv.toDataStream(inputs[0]);
+DataStream model = tEnv.toDataStream(modelData[0]);
+
+Map> broadcastMap = new HashMap<>(1);
+broadcastMap.put(BROADCAST_STR, model);
+ResolvedSchema modelSchema = modelData[0].getResolvedSchema();
+DataType idType =
+
modelSchema.getColumnDataTypes().get(modelSchema.getColumnNames().size() - 1);
+
+ResolvedSchema outputSchema =
+getOutputSchema(inputs[0].getResolvedSchema(), getParamMap(), 
idType);
+
+DataType[] dataTypes = 

[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-23 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r755677169



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFeatureColsDefaultAsNull.java
##
@@ -0,0 +1,28 @@
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Params of the names of the feature columns used for training in the input 
table. */
+public interface HasFeatureColsDefaultAsNull extends WithParams {
+/**
+ * @cn-name 特征列名数组
+ * @cn 特征列名数组,默认全选
+ */
+Param FEATURE_COLS =

Review comment:
   OK, I will refine it later.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-23 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r755676992



##
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##
@@ -0,0 +1,172 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.api.core.Pipeline;
+import org.apache.flink.ml.api.core.Stage;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.sink.SinkFunction;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+/** knn algorithm test. */
+public class KnnTest {
+private StreamExecutionEnvironment env;
+private StreamTableEnvironment tEnv;
+private Table trainData;
+
+List trainArray =
+new ArrayList<>(
+Arrays.asList(
+Row.of("f", "2.0 3.0"),
+Row.of("f", "2.1 3.1"),
+Row.of("m", "200.1 300.1"),
+Row.of("m", "200.2 300.2"),
+Row.of("m", "200.3 300.3"),
+Row.of("m", "200.4 300.4"),
+Row.of("m", "200.4 300.4"),
+Row.of("m", "200.6 300.6"),
+Row.of("f", "2.1 3.1"),
+Row.of("f", "2.1 3.1"),
+Row.of("f", "2.1 3.1"),
+Row.of("f", "2.1 3.1"),
+Row.of("f", "2.3 3.2"),
+Row.of("f", "2.3 3.2"),
+Row.of("c", "2.8 3.2"),
+Row.of("d", "300. 3.2"),
+Row.of("f", "2.2 3.2"),
+Row.of("e", "2.4 3.2"),
+Row.of("e", "2.5 3.2"),
+Row.of("e", "2.5 3.2"),
+Row.of("f", "2.1 3.1")));
+
+List testArray =
+new ArrayList<>(Arrays.asList(Row.of("e", "4.0 4.1"), Row.of("m", 
"300 42")));
+
+private Table testData;
+

Review comment:
   OK, I will do it later.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-23 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r755676909



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnParams.java
##
@@ -0,0 +1,32 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.ml.common.param.HasFeatureColsDefaultAsNull;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.common.param.HasVectorColDefaultAsNull;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** knn fit parameters. */
+public interface KnnParams
+extends WithParams,
+HasVectorColDefaultAsNull,
+HasLabelCol,
+HasFeatureColsDefaultAsNull,
+HasPredictionCol {
+/**
+ * @cn-name topK
+ * @cn topK
+ */
+Param K = new IntParam("k", "k", 10, ParamValidators.gt(0));
+
+default Integer getK() {

Review comment:
   K can be obtained in two ways: getK() or params.get(KnnParams.K).
   There is no problem with the algorithm's implementation. I just used the
   second pattern; maybe I need to change the pattern later.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-23 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r755675421



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/FastDistanceMatrixData.java
##
@@ -0,0 +1,93 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Save the data for calculating distance fast. The FastDistanceMatrixData 
saves several dense
+ * vectors in a single matrix. The vectors are organized in columns, which 
means each column is a
+ * single vector. For example, vec1: 0,1,2, vec2: 3,4,5, vec3: 6,7,8, then the 
data in matrix is
+ * organized as: vec1,vec2,vec3. And the data array in vectors is 
{0,1,2,3,4,5,6,7,8}.
+ */
+public class FastDistanceMatrixData implements Serializable {
+private static final long serialVersionUID = 3093977891649431843L;
+
+/**
+ * Stores several dense vectors in columns. For example, if the vectorSize 
is n, and matrix
+ * saves m vectors, then the number of rows of vectors is n 
and the number of cols
+ * of vectors is m.
+ */
+public final DenseMatrix vectors;
+/**
+ * Save the extra info besides the vector. Each vector is related to one 
row. Thus, for
+ * FastDistanceVectorData, the length of rows is one. And for
+ * FastDistanceMatrixData, the length of rows is equal to the 
number of cols of
+ * matrix. Besides, the order of the rows are the same with 
the vectors.
+ */
+public final Row[] rows;
+
+/**
+ * Stores some extra info extracted from the vector. It's also organized 
in columns. For
+ * example, if we want to save the L1 norm and L2 norm of the vector, then 
the two values are
+ * viewed as a two-dimension label vector. We organize the norm vectors 
together to get the
+ * label. If the number of cols of vectors is m, 
then in this case the
+ * dimension of label is 2 * m.
+ */
+public DenseMatrix label;
+
+public Row[] getRows() {
+return rows;
+}
+
+/**
+ * Constructor, initialize the vector data and extra info.
+ *
+ * @param vectors DenseMatrix which saves vectors in columns.
+ * @param rows extra info besides the vector.
+ */
+public FastDistanceMatrixData(DenseMatrix vectors, Row[] rows) {
+this.rows = rows;
+Preconditions.checkNotNull(vectors, "DenseMatrix should not be null!");
+if (null != rows) {
+Preconditions.checkArgument(
+vectors.numCols() == rows.length,
+"The column number of DenseMatrix must be equal to the 
rows array length!");
+}
+this.vectors = vectors;
+}
+
+/**
+ * serialization of FastDistanceMatrixData.
+ *
+ * @return json string.
+ */
+@Override
+public String toString() {
+Map params = new HashMap<>(3);
+params.put("vectors", ReadWriteUtils.OBJECT_MAPPER.toJson(vectors));
+params.put("label", ReadWriteUtils.OBJECT_MAPPER.toJson(label));
+params.put("rows", ReadWriteUtils.OBJECT_MAPPER.toJson(rows));
+return ReadWriteUtils.OBJECT_MAPPER.toJson(params);
+}
+
+/**
+ * deserialization of FastDistanceMatrixData.
+ *
+ * @param modelStr string of model serialization.
+ * @return FastDistanceMatrixData
+ */
+public static FastDistanceMatrixData fromString(String modelStr) {

Review comment:
   It's not a public data structure; it is used only in the knn algorithm. I think 
toString and fromString are OK.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-23 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r755674770



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/MapPartitionFunctionWrapper.java
##
@@ -0,0 +1,68 @@
+package org.apache.flink.ml.common;
+
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.functions.util.FunctionUtils;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+
+/**
+ * MapPartitionFunction wrapper.
+ *
+ * @param  Input element type.
+ * @param  Output element type.
+ */
+public class MapPartitionFunctionWrapper extends 
AbstractStreamOperator

Review comment:
   OK, I will refine it.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-23 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r755674858



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnParams.java
##
@@ -0,0 +1,32 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.ml.common.param.HasFeatureColsDefaultAsNull;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.common.param.HasVectorColDefaultAsNull;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** knn fit parameters. */
+public interface KnnParams
+extends WithParams,
+HasVectorColDefaultAsNull,
+HasLabelCol,
+HasFeatureColsDefaultAsNull,
+HasPredictionCol {
+/**
+ * @cn-name topK
+ * @cn topK
+ */
+Param K = new IntParam("k", "k", 10, ParamValidators.gt(0));

Review comment:
   OK, I will refine it.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-23 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r755674542



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##
@@ -0,0 +1,255 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.api.core.Estimator;
+import org.apache.flink.ml.common.MapPartitionFunctionWrapper;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * KNN is to classify unlabeled observations by assigning them to the class of 
the most similar
+ * labeled examples.
+ */
+public class Knn implements Estimator, KnnParams {
+
+private static final long serialVersionUID = 5292477422193301398L;
+private static final int ROW_SIZE = 2;
+private static final int FASTDISTANCE_TYPE_INDEX = 0;
+private static final int DATA_INDEX = 1;
+
+protected Map, Object> params = new HashMap<>();
+
+/** constructor. */
+public Knn() {
+ParamUtils.initializeMapWithDefaultValues(params, this);
+}
+
+/**
+ * constructor.
+ *
+ * @param params parameters for algorithm.
+ */
+public Knn(Map, Object> params) {
+this.params = params;
+}
+
+/**
+ * @param inputs a list of tables
+ * @return knn classification model.
+ */
+@Override
+public KnnModel fit(Table... inputs) {
+Preconditions.checkArgument(inputs.length == 1);
+ResolvedSchema schema = inputs[0].getResolvedSchema();
+String[] colNames = schema.getColumnNames().toArray(new String[0]);
+StreamTableEnvironment tEnv =
+(StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+DataStream input = tEnv.toDataStream(inputs[0]);
+String[] targetCols = getFeatureCols();
+final int[] featureIndices;
+if (targetCols == null) {
+featureIndices = new int[colNames.length];
+for (int i = 0; i < colNames.length; i++) {
+featureIndices[i] = i;
+}
+} else {
+featureIndices = new int[targetCols.length];
+for (int i = 0; i < featureIndices.length; i++) {
+featureIndices[i] = findColIndex(colNames, targetCols[i]);
+}
+}
+String labelCol = getLabelCol();
+final int labelIdx = findColIndex(colNames, labelCol);
+final int vecIdx =
+getVectorCol() != null
+? findColIndex(
+inputs[0]
+.getResolvedSchema()
+.getColumnNames()
+.toArray(new String[0]),
+getVectorCol())
+: -1;
+
+DataStream trainData =
+input.map(
+(MapFunction)
+value -> {
+Object label = value.getField(labelIdx);

Review comment:
   OK, using this pattern may be better. I will modify it later.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-23 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r755674684



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFeatureColsDefaultAsNull.java
##
@@ -0,0 +1,28 @@
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Params of the names of the feature columns used for training in the input 
table. */
+public interface HasFeatureColsDefaultAsNull extends WithParams {
+/**
+ * @cn-name 特征列名数组

Review comment:
   OK, I will refine it.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-23 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r755674205



##
File path: 
flink-ml-api/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
##
@@ -43,8 +46,12 @@
 
 /** Utility methods for reading and writing stages. */
 public class ReadWriteUtils {
-public static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
-
+public static Gson OBJECT_MAPPER =

Review comment:
   After discussing with lindong, I will switch back to ObjectMapper.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-23 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r755673857



##
File path: flink-ml-api/src/main/java/org/apache/flink/ml/param/Param.java
##
@@ -75,7 +75,7 @@ public String jsonEncode(T value) throws IOException {
  */
 @SuppressWarnings("unchecked")
 public T jsonDecode(String json) throws IOException {
-return ReadWriteUtils.OBJECT_MAPPER.readValue(json, clazz);
+return ReadWriteUtils.OBJECT_MAPPER.fromJson(json, clazz);

Review comment:
   Just like writeValueAsString. I will modify it later.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-23 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r755673679



##
File path: flink-ml-api/src/main/java/org/apache/flink/ml/param/Param.java
##
@@ -64,7 +64,7 @@ public Param(
  * @return A json-formatted string.
  */
 public String jsonEncode(T value) throws IOException {
-return ReadWriteUtils.OBJECT_MAPPER.writeValueAsString(value);
+return ReadWriteUtils.OBJECT_MAPPER.toJson(value);

Review comment:
   After discussing with lindong, I will modify this code. Using 
writeValueAsString may be better.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-23 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r755673190



##
File path: 
flink-ml-api/src/main/java/org/apache/flink/ml/linalg/DenseVector.java
##
@@ -42,14 +45,26 @@ public double get(int i) {
 return values[i];
 }
 
+@Override
+public void set(int i, double val) {
+values[i] = val;
+}
+
 @Override
 public double[] toArray() {
 return values;
 }
 
 @Override
 public String toString() {
-return Arrays.toString(values);
+StringBuilder sbd = new StringBuilder();

Review comment:
   DenseVectorSerializer is the serializer used by Flink.
   Here, we just want to convert a DenseVector to a string, and to convert a 
string back to a DenseVector.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-18 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r752853592



##
File path: flink-ml-lib/pom.xml
##
@@ -106,6 +112,11 @@ under the License.
   jar
   test
 
+  
+  com.google.code.gson
+  gson
+  2.8.6
+  

Review comment:
   ReadWriteUtils may not satisfy my needs; it is only a utility to help 
serialize model data.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: [Flink 24557] - Add knn algorithm to flink-ml

2021-11-18 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r752930662



##
File path: flink-ml-lib/pom.xml
##
@@ -106,6 +112,11 @@ under the License.
   jar
   test
 
+  
+  com.google.code.gson
+  gson
+  2.8.6
+  

Review comment:
   The JSON tool that ReadWriteUtils supplies does not support serialization 
of some classes. Here, I use Google's Gson to replace the default JSON tool, 
so that we still have only one JSON tool.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [flink-ml] weibozhao commented on a change in pull request #24: Flink 24557- add knn algorithm to flink-ml

2021-11-18 Thread GitBox


weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r752903128



##
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/distance/EuclideanDistance.java
##
@@ -0,0 +1,259 @@
+package org.apache.flink.ml.classification.knn.distance;
+
+import org.apache.flink.ml.common.linalg.BLAS;
+import org.apache.flink.ml.common.linalg.DenseMatrix;
+import org.apache.flink.ml.common.linalg.DenseVector;
+import org.apache.flink.ml.common.linalg.MatVecOp;
+import org.apache.flink.ml.common.linalg.SparseVector;
+import org.apache.flink.ml.common.linalg.Vector;
+import org.apache.flink.util.Preconditions;
+
+import java.util.Arrays;
+
+/**
+ * Euclidean distance is the "ordinary" straight-line distance between two 
points in Euclidean
+ * space.
+ *
+ * https://en.wikipedia.org/wiki/Euclidean_distance
+ *
+ * Given two vectors a and b, Euclidean Distance = ||a - b||, where ||*|| 
means the L2 norm of
+ * the vector.
+ */
+public class EuclideanDistance extends BaseFastDistance {

Review comment:
   yes
   




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org