weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r763791507



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##########
@@ -0,0 +1,227 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.MapPartitionFunctionWrapper;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.flink.shaded.curator4.com.google.common.collect.Iterables;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * KNN is to classify unlabeled observations by assigning them to the class of the most similar
+ * labeled examples.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+    protected Map<Param<?>, Object> params = new HashMap<>();
+
+    /** Constructor. */
+    public Knn() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    /**
+     * Fits data and produces knn model.
+     *
+     * @param inputs A list of tables
+     * @return Knn model.
+     */
+    @Override
+    public KnnModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        ResolvedSchema schema = inputs[0].getResolvedSchema();
+        String[] colNames = schema.getColumnNames().toArray(new String[0]);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+        String labelCol = getLabelCol();
+        String vecCol = getFeaturesCol();
+
+        DataStream<Row> trainData =
+                input.map(
+                        (MapFunction<Row, Row>)
+                                value -> {
+                                    Object label = String.valueOf(value.getField(labelCol));
+                                    DenseVector vec = (DenseVector) value.getField(vecCol);
+                                    return Row.of(label, vec);
+                                });
+        DataType idType = null;
+        for (int i = 0; i < colNames.length; i++) {
+            if (labelCol.equalsIgnoreCase(colNames[i])) {

Review comment:
       This code has already been removed.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to