weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r763729532



##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,412 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.utils.LogicalTypeParser;
+import org.apache.flink.table.types.utils.LogicalTypeDataTypeConverter;
+import org.apache.flink.types.Row;
+
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
+
+import dev.ludovic.netlib.blas.F2jBLAS;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.PriorityQueue;
+import java.util.function.Function;
+
+/** Knn classification model fitted by estimator. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+    protected Map<Param<?>, Object> params = new HashMap<>();
+    private Table[] modelData;
+
+    /** Constructor. */
+    public KnnModel() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    /**
+     * Sets model data for knn prediction.
+     *
+     * @param modelData knn model.
+     * @return knn model.
+     */
+    @Override
+    public KnnModel setModelData(Table... modelData) {
+        this.modelData = modelData;
+        return this;
+    }
+
+    /**
+     * Gets model data.
+     *
+     * @return table array.
+     */
+    @Override
+    public Table[] getModelData() {
+        return modelData;
+    }
+
+    /**
+     * Predicts label with knn model.
+     *
+     * @param inputs a list of tables.
+     * @return result.
+     */
+    @Override
+    public Table[] transform(Table... inputs) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+        DataStream<Row> model = tEnv.toDataStream(modelData[0]);
+        final String broadcastKey = "broadcastModelKey";
+        Map<String, DataStream<?>> broadcastMap = new HashMap<>(1);
+        broadcastMap.put(broadcastKey, model);
+        ResolvedSchema modelSchema = modelData[0].getResolvedSchema();
+
+        DataType idType =
+                
modelSchema.getColumnDataTypes().get(modelSchema.getColumnNames().size() - 1);
+        String[] resultCols = new String[] {(String) 
params.get(KnnModelParams.PREDICTION_COL)};
+        DataType[] resultTypes = new DataType[] {idType};
+
+        ResolvedSchema outputSchema =
+                TableUtils.getOutputSchema(inputs[0].getResolvedSchema(), 
resultCols, resultTypes);
+
+        Function<List<DataStream<?>>, DataStream<Row>> function =
+                dataStreams -> {
+                    DataStream stream = dataStreams.get(0);
+                    return stream.transform(
+                            "mapFunc",
+                            TableUtils.getRowTypeInfo(outputSchema),
+                            new PredictOperator(
+                                    inputs[0].getResolvedSchema(),
+                                    broadcastKey,
+                                    getK(),
+                                    getFeaturesCol()));
+                };
+
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(input), broadcastMap, 
function);
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    /** This operator load the model data and do the prediction. */
+    private static class PredictOperator
+            extends AbstractUdfStreamOperator<Row, AbstractRichFunction>
+            implements OneInputStreamOperator<Row, Row> {
+
+        private boolean firstEle = true;
+        private final String[] reservedCols;
+        private final String featureCol;
+        private transient KnnModelData modelData;
+        private final Integer topN;
+        private final String broadcastKey;
+
+        public PredictOperator(
+                ResolvedSchema dataSchema, String broadcastKey, int k, String 
featureCol) {
+            super(new AbstractRichFunction() {});
+            reservedCols = dataSchema.getColumnNames().toArray(new String[0]);
+            this.topN = k;
+            this.broadcastKey = broadcastKey;
+            this.featureCol = featureCol;
+        }
+
+        @Override
+        public void processElement(StreamRecord<Row> streamRecord) throws 
Exception {
+            Row value = streamRecord.getValue();
+            output.collect(new StreamRecord<>(map(value)));
+        }
+
+        public Row map(Row row) throws Exception {
+            if (firstEle) {
+                
loadModel(userFunction.getRuntimeContext().getBroadcastVariable(broadcastKey));
+                firstEle = false;
+            }
+            DenseVector vector = (DenseVector) row.getField(featureCol);
+            Tuple2<List<Object>, List<Double>> t2 = findNeighbor(vector, topN, 
modelData);
+            Row ret = new Row(reservedCols.length + 1);
+            for (int i = 0; i < reservedCols.length; ++i) {
+                ret.setField(i, row.getField(reservedCols[i]));
+            }
+
+            Tuple2<Object, String> tuple2 = getResultFormat(t2);
+            ret.setField(reservedCols.length, tuple2.f0);
+            return ret;
+        }
+
+        /**
+         * Finds the nearest topN neighbors from whole nodes.
+         *
+         * @param input input node.
+         * @param topN top N.
+         * @return neighbor.
+         */
+        private Tuple2<List<Object>, List<Double>> findNeighbor(
+                Object input, Integer topN, KnnModelData modelData) {
+            PriorityQueue<Tuple2<Double, Object>> priorityQueue =
+                    new PriorityQueue<>(modelData.getQueueComparator());
+            search(input, topN, priorityQueue, modelData);
+            List<Object> items = new ArrayList<>();
+            List<Double> metrics = new ArrayList<>();
+            while (!priorityQueue.isEmpty()) {
+                Tuple2<Double, Object> result = priorityQueue.poll();
+                items.add(result.f1);
+                metrics.add(result.f0);
+            }
+            Collections.reverse(items);
+            Collections.reverse(metrics);
+            priorityQueue.clear();
+            return Tuple2.of(items, metrics);
+        }
+
+        /**
+         * @param input input node.
+         * @param topN top N.
+         * @param priorityQueue priority queue.
+         */
+        private void search(
+                Object input,
+                Integer topN,
+                PriorityQueue<Tuple2<Double, Object>> priorityQueue,
+                KnnModelData modelData) {
+            Tuple2<DenseVector, Double> sample = computeNorm((DenseVector) 
input);
+            Tuple2<Double, Object> head = null;
+            for (int i = 0; i < modelData.getLength(); i++) {
+                ArrayList<Tuple2<Double, Object>> values = 
computeDistance(sample, i);
+                for (Tuple2<Double, Object> currentValue : values) {
+                    head = updateQueue(priorityQueue, topN, currentValue, 
head);
+                }
+            }
+        }
+
+        /**
+         * Updates queue.
+         *
+         * @param map queue.
+         * @param topN top N.
+         * @param newValue new value.
+         * @param head head value.
+         * @param <T> id type.
+         * @return head value.
+         */
+        private <T> Tuple2<Double, T> updateQueue(
+                PriorityQueue<Tuple2<Double, T>> map,
+                int topN,
+                Tuple2<Double, T> newValue,
+                Tuple2<Double, T> head) {
+            if (map.size() < topN) {
+                map.add(Tuple2.of(newValue.f0, newValue.f1));
+                head = map.peek();
+            } else {
+                if (map.comparator().compare(head, newValue) < 0) {
+                    Tuple2<Double, T> peek = map.poll();
+                    assert peek != null;
+                    peek.f0 = newValue.f0;
+                    peek.f1 = newValue.f1;
+                    map.add(peek);
+                    head = map.peek();
+                }
+            }
+            return head;
+        }
+
+        /**
+         * Computes distance between sample and dictionary vectors.
+         *
+         * @param input sample with l2 norm.
+         * @param index dictionary vectors index.
+         * @return distances.
+         */
+        private ArrayList<Tuple2<Double, Object>> computeDistance(
+                Tuple2<DenseVector, Double> input, Integer index) {
+            Tuple3<DenseMatrix, DenseVector, String[]> data = 
modelData.getDictData().get(index);
+
+            DenseMatrix res = calc(input, data);
+            ArrayList<Tuple2<Double, Object>> list = new ArrayList<>(0);
+            String[] curLabels = data.f2;
+            for (int i = 0; i < Objects.requireNonNull(curLabels).length; i++) 
{
+                Tuple2<Double, Object> tuple = Tuple2.of(res.values[i], 
curLabels[i]);
+                list.add(tuple);
+            }
+            return list;
+        }
+
+        /** The blas used to accelerating speed. */
+        private static final dev.ludovic.netlib.blas.F2jBLAS NATIVE_BLAS =
+                (F2jBLAS) F2jBLAS.getInstance();
+
+        /**
+         * Compute distance between sample and dictionary vectors.
+         *
+         * @param left Sample and norm.
+         * @param right Dictionary vectors with row format.
+         * @return a new DenseMatrix which store the result distance.
+         */
+        public DenseMatrix calc(
+                Tuple2<DenseVector, Double> left,
+                Tuple3<DenseMatrix, DenseVector, String[]> right) {
+            DenseMatrix vectors = right.f0;
+            DenseMatrix res = new 
DenseMatrix(Objects.requireNonNull(vectors).numCols, 1);
+            DenseVector norm = right.f1;
+            double[] normL2Square = Objects.requireNonNull(norm).values;
+
+            final int m = vectors.numRows;
+            final int n = vectors.numCols;
+            NATIVE_BLAS.dgemv(
+                    "T", m, n, -2.0, vectors.values, m, left.f0.toArray(), 1, 
0.0, res.values, 1);
+
+            for (int i = 0; i < res.values.length; i++) {
+                res.values[i] = Math.sqrt(Math.abs(res.values[i] + left.f1 + 
normL2Square[i]));
+            }
+            return res;
+        }
+
+        /**
+         * Computes norm2 of vector.
+         *
+         * @return Sample with norm2.
+         */
+        public static Tuple2<DenseVector, Double> computeNorm(DenseVector 
vector) {

Review comment:
       OK 
   

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,412 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.utils.LogicalTypeParser;
+import org.apache.flink.table.types.utils.LogicalTypeDataTypeConverter;
+import org.apache.flink.types.Row;
+
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
+
+import dev.ludovic.netlib.blas.F2jBLAS;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.PriorityQueue;
+import java.util.function.Function;
+
+/** Knn classification model fitted by estimator. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+    protected Map<Param<?>, Object> params = new HashMap<>();
+    private Table[] modelData;
+
+    /** Constructor. */
+    public KnnModel() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    /**
+     * Sets model data for knn prediction.
+     *
+     * @param modelData knn model.
+     * @return knn model.
+     */
+    @Override
+    public KnnModel setModelData(Table... modelData) {
+        this.modelData = modelData;
+        return this;
+    }
+
+    /**
+     * Gets model data.
+     *
+     * @return table array.
+     */
+    @Override
+    public Table[] getModelData() {
+        return modelData;
+    }
+
+    /**
+     * Predicts label with knn model.
+     *
+     * @param inputs a list of tables.
+     * @return result.
+     */
+    @Override
+    public Table[] transform(Table... inputs) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+        DataStream<Row> model = tEnv.toDataStream(modelData[0]);
+        final String broadcastKey = "broadcastModelKey";
+        Map<String, DataStream<?>> broadcastMap = new HashMap<>(1);
+        broadcastMap.put(broadcastKey, model);
+        ResolvedSchema modelSchema = modelData[0].getResolvedSchema();
+
+        DataType idType =
+                
modelSchema.getColumnDataTypes().get(modelSchema.getColumnNames().size() - 1);
+        String[] resultCols = new String[] {(String) 
params.get(KnnModelParams.PREDICTION_COL)};
+        DataType[] resultTypes = new DataType[] {idType};
+
+        ResolvedSchema outputSchema =
+                TableUtils.getOutputSchema(inputs[0].getResolvedSchema(), 
resultCols, resultTypes);
+
+        Function<List<DataStream<?>>, DataStream<Row>> function =
+                dataStreams -> {
+                    DataStream stream = dataStreams.get(0);
+                    return stream.transform(
+                            "mapFunc",
+                            TableUtils.getRowTypeInfo(outputSchema),
+                            new PredictOperator(
+                                    inputs[0].getResolvedSchema(),
+                                    broadcastKey,
+                                    getK(),
+                                    getFeaturesCol()));
+                };
+
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(input), broadcastMap, 
function);
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    /** This operator load the model data and do the prediction. */
+    private static class PredictOperator
+            extends AbstractUdfStreamOperator<Row, AbstractRichFunction>
+            implements OneInputStreamOperator<Row, Row> {
+
+        private boolean firstEle = true;
+        private final String[] reservedCols;
+        private final String featureCol;
+        private transient KnnModelData modelData;
+        private final Integer topN;
+        private final String broadcastKey;
+
+        public PredictOperator(
+                ResolvedSchema dataSchema, String broadcastKey, int k, String 
featureCol) {
+            super(new AbstractRichFunction() {});
+            reservedCols = dataSchema.getColumnNames().toArray(new String[0]);
+            this.topN = k;
+            this.broadcastKey = broadcastKey;
+            this.featureCol = featureCol;
+        }
+
+        @Override
+        public void processElement(StreamRecord<Row> streamRecord) throws 
Exception {
+            Row value = streamRecord.getValue();
+            output.collect(new StreamRecord<>(map(value)));
+        }
+
+        public Row map(Row row) throws Exception {
+            if (firstEle) {
+                
loadModel(userFunction.getRuntimeContext().getBroadcastVariable(broadcastKey));
+                firstEle = false;
+            }
+            DenseVector vector = (DenseVector) row.getField(featureCol);
+            Tuple2<List<Object>, List<Double>> t2 = findNeighbor(vector, topN, 
modelData);
+            Row ret = new Row(reservedCols.length + 1);
+            for (int i = 0; i < reservedCols.length; ++i) {
+                ret.setField(i, row.getField(reservedCols[i]));
+            }
+
+            Tuple2<Object, String> tuple2 = getResultFormat(t2);
+            ret.setField(reservedCols.length, tuple2.f0);
+            return ret;
+        }
+
+        /**
+         * Finds the nearest topN neighbors from whole nodes.
+         *
+         * @param input input node.
+         * @param topN top N.
+         * @return neighbor.
+         */
+        private Tuple2<List<Object>, List<Double>> findNeighbor(
+                Object input, Integer topN, KnnModelData modelData) {
+            PriorityQueue<Tuple2<Double, Object>> priorityQueue =
+                    new PriorityQueue<>(modelData.getQueueComparator());
+            search(input, topN, priorityQueue, modelData);
+            List<Object> items = new ArrayList<>();
+            List<Double> metrics = new ArrayList<>();
+            while (!priorityQueue.isEmpty()) {
+                Tuple2<Double, Object> result = priorityQueue.poll();
+                items.add(result.f1);
+                metrics.add(result.f0);
+            }
+            Collections.reverse(items);
+            Collections.reverse(metrics);
+            priorityQueue.clear();
+            return Tuple2.of(items, metrics);
+        }
+
+        /**
+         * @param input input node.
+         * @param topN top N.
+         * @param priorityQueue priority queue.
+         */
+        private void search(
+                Object input,
+                Integer topN,
+                PriorityQueue<Tuple2<Double, Object>> priorityQueue,
+                KnnModelData modelData) {
+            Tuple2<DenseVector, Double> sample = computeNorm((DenseVector) 
input);
+            Tuple2<Double, Object> head = null;
+            for (int i = 0; i < modelData.getLength(); i++) {
+                ArrayList<Tuple2<Double, Object>> values = 
computeDistance(sample, i);
+                for (Tuple2<Double, Object> currentValue : values) {
+                    head = updateQueue(priorityQueue, topN, currentValue, 
head);
+                }
+            }
+        }
+
+        /**
+         * Updates queue.
+         *
+         * @param map queue.
+         * @param topN top N.
+         * @param newValue new value.
+         * @param head head value.
+         * @param <T> id type.
+         * @return head value.
+         */
+        private <T> Tuple2<Double, T> updateQueue(
+                PriorityQueue<Tuple2<Double, T>> map,
+                int topN,
+                Tuple2<Double, T> newValue,
+                Tuple2<Double, T> head) {
+            if (map.size() < topN) {
+                map.add(Tuple2.of(newValue.f0, newValue.f1));
+                head = map.peek();
+            } else {
+                if (map.comparator().compare(head, newValue) < 0) {
+                    Tuple2<Double, T> peek = map.poll();
+                    assert peek != null;
+                    peek.f0 = newValue.f0;
+                    peek.f1 = newValue.f1;
+                    map.add(peek);
+                    head = map.peek();
+                }
+            }
+            return head;
+        }
+
+        /**
+         * Computes distance between sample and dictionary vectors.
+         *
+         * @param input sample with l2 norm.
+         * @param index dictionary vectors index.
+         * @return distances.
+         */
+        private ArrayList<Tuple2<Double, Object>> computeDistance(
+                Tuple2<DenseVector, Double> input, Integer index) {
+            Tuple3<DenseMatrix, DenseVector, String[]> data = 
modelData.getDictData().get(index);
+
+            DenseMatrix res = calc(input, data);
+            ArrayList<Tuple2<Double, Object>> list = new ArrayList<>(0);
+            String[] curLabels = data.f2;
+            for (int i = 0; i < Objects.requireNonNull(curLabels).length; i++) 
{
+                Tuple2<Double, Object> tuple = Tuple2.of(res.values[i], 
curLabels[i]);
+                list.add(tuple);
+            }
+            return list;
+        }
+
+        /** The blas used to accelerating speed. */
+        private static final dev.ludovic.netlib.blas.F2jBLAS NATIVE_BLAS =
+                (F2jBLAS) F2jBLAS.getInstance();
+
+        /**
+         * Compute distance between sample and dictionary vectors.
+         *
+         * @param left Sample and norm.
+         * @param right Dictionary vectors with row format.
+         * @return a new DenseMatrix which store the result distance.
+         */
+        public DenseMatrix calc(
+                Tuple2<DenseVector, Double> left,
+                Tuple3<DenseMatrix, DenseVector, String[]> right) {
+            DenseMatrix vectors = right.f0;
+            DenseMatrix res = new 
DenseMatrix(Objects.requireNonNull(vectors).numCols, 1);
+            DenseVector norm = right.f1;
+            double[] normL2Square = Objects.requireNonNull(norm).values;
+
+            final int m = vectors.numRows;
+            final int n = vectors.numCols;
+            NATIVE_BLAS.dgemv(
+                    "T", m, n, -2.0, vectors.values, m, left.f0.toArray(), 1, 
0.0, res.values, 1);
+
+            for (int i = 0; i < res.values.length; i++) {
+                res.values[i] = Math.sqrt(Math.abs(res.values[i] + left.f1 + 
normL2Square[i]));
+            }
+            return res;
+        }
+
+        /**
+         * Computes norm2 of vector.
+         *
+         * @return Sample with norm2.
+         */
+        public static Tuple2<DenseVector, Double> computeNorm(DenseVector 
vector) {

Review comment:
       done

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##########
@@ -0,0 +1,227 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.MapPartitionFunctionWrapper;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.flink.shaded.curator4.com.google.common.collect.Iterables;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * KNN is to classify unlabeled observations by assigning them to the class of 
the most similar
+ * labeled examples.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+    protected Map<Param<?>, Object> params = new HashMap<>();
+
+    /** Constructor. */
+    public Knn() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    /**
+     * Fits data and produces knn model.
+     *
+     * @param inputs A list of tables
+     * @return Knn model.
+     */
+    @Override
+    public KnnModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        ResolvedSchema schema = inputs[0].getResolvedSchema();
+        String[] colNames = schema.getColumnNames().toArray(new String[0]);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+        String labelCol = getLabelCol();
+        String vecCol = getFeaturesCol();
+
+        DataStream<Row> trainData =
+                input.map(
+                        (MapFunction<Row, Row>)
+                                value -> {
+                                    Object label = 
String.valueOf(value.getField(labelCol));

Review comment:
       OK, I will try it.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##########
@@ -0,0 +1,227 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.MapPartitionFunctionWrapper;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.flink.shaded.curator4.com.google.common.collect.Iterables;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * KNN is to classify unlabeled observations by assigning them to the class of 
the most similar
+ * labeled examples.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+    protected Map<Param<?>, Object> params = new HashMap<>();
+
+    /** Constructor. */
+    public Knn() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    /**
+     * Fits data and produces knn model.
+     *
+     * @param inputs A list of tables
+     * @return Knn model.
+     */
+    @Override
+    public KnnModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        ResolvedSchema schema = inputs[0].getResolvedSchema();
+        String[] colNames = schema.getColumnNames().toArray(new String[0]);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+        String labelCol = getLabelCol();
+        String vecCol = getFeaturesCol();
+
+        DataStream<Row> trainData =
+                input.map(
+                        (MapFunction<Row, Row>)
+                                value -> {
+                                    Object label = 
String.valueOf(value.getField(labelCol));

Review comment:
       done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to