weibozhao commented on a change in pull request #24: URL: https://github.com/apache/flink-ml/pull/24#discussion_r763728994
########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java ########## @@ -0,0 +1,227 @@ +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.datastream.MapPartitionFunctionWrapper; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.table.catalog.ResolvedSchema; +import org.apache.flink.table.types.DataType; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import org.apache.flink.shaded.curator4.com.google.common.collect.Iterables; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +/** + * KNN is to classify unlabeled observations by assigning them to the class of the most similar + * labeled examples. + */ +public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> { + + protected Map<Param<?>, Object> params = new HashMap<>(); + + /** Constructor. */ + public Knn() { + ParamUtils.initializeMapWithDefaultValues(params, this); + } + + /** + * Fits data and produces knn model. + * + * @param inputs A list of tables + * @return Knn model. + */ + @Override + public KnnModel fit(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + ResolvedSchema schema = inputs[0].getResolvedSchema(); + String[] colNames = schema.getColumnNames().toArray(new String[0]); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Row> input = tEnv.toDataStream(inputs[0]); + String labelCol = getLabelCol(); + String vecCol = getFeaturesCol(); + + DataStream<Row> trainData = + input.map( + (MapFunction<Row, Row>) + value -> { + Object label = String.valueOf(value.getField(labelCol)); Review comment: OK, I will try it. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org