[ https://issues.apache.org/jira/browse/FLINK-1745?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=15288930#comment-15288930 ]
ASF GitHub Bot commented on FLINK-1745: --------------------------------------- Github user tillrohrmann commented on a diff in the pull request: https://github.com/apache/flink/pull/1220#discussion_r63698942 --- Diff: flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/nn/KNN.scala --- @@ -0,0 +1,354 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.nn + +import org.apache.flink.api.common.operators.Order +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.scala.utils._ +import org.apache.flink.api.scala._ +import org.apache.flink.ml.common._ +import org.apache.flink.ml.math.{Vector => FlinkVector, DenseVector} +import org.apache.flink.ml.metrics.distances.{SquaredEuclideanDistanceMetric, DistanceMetric, +EuclideanDistanceMetric} +import org.apache.flink.ml.pipeline.{FitOperation, PredictDataSetOperation, Predictor} +import org.apache.flink.util.Collector +import org.apache.flink.api.common.operators.base.CrossOperatorBase.CrossHint + +import scala.collection.immutable.Vector +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +/** Implements a k-nearest neighbor join. 
+ * + * Calculates the `k`-nearest neighbor points in the training set for each point in the test set. + * + * @example + * {{{ + * val trainingDS: DataSet[Vector] = ... + * val testingDS: DataSet[Vector] = ... + * + * val knn = KNN() + * .setK(10) + * .setBlocks(5) + * .setDistanceMetric(EuclideanDistanceMetric()) + * + * knn.fit(trainingDS) + * + * val predictionDS: DataSet[(Vector, Array[Vector])] = knn.predict(testingDS) + * }}} + * + * =Parameters= + * + * - [[org.apache.flink.ml.nn.KNN.K]] + * Sets K, which is the number of points selected as neighbors. (Default value: '''5''') + * + * - [[org.apache.flink.ml.nn.KNN.DistanceMetric]] + * Sets the distance metric we use to calculate the distance between two points. If no metric is + * specified, then [[org.apache.flink.ml.metrics.distances.EuclideanDistanceMetric]] is used. + * (Default value: '''EuclideanDistanceMetric()''') + * + * - [[org.apache.flink.ml.nn.KNN.Blocks]] + * Sets the number of blocks into which the input data will be split. This number should be set + * at least to the degree of parallelism. If no value is specified, then the parallelism of the + * input [[DataSet]] is used as the number of blocks. (Default value: '''None''') + * + * - [[org.apache.flink.ml.nn.KNN.UseQuadTreeParam]] + * A boolean variable that determines whether or not to use a Quadtree to partition the training set + * to potentially simplify the KNN search. If no value is specified, the code will + * automatically decide whether or not to use a Quadtree. Use of a Quadtree scales well + * with the number of training and testing points, though poorly with the dimension. + * (Default value: '''None''') + * + * - [[org.apache.flink.ml.nn.KNN.SizeHint]] + * Specifies whether the training set or test set is small to optimize the cross + * product operation needed for the KNN search. If the training set is small, + * this should be `CrossHint.FIRST_IS_SMALL`; it should be `CrossHint.SECOND_IS_SMALL` + * if the test set is small.
+ * (Default value: '''None''') + * + */ + +class KNN extends Predictor[KNN] { + + import KNN._ + + var trainingSet: Option[DataSet[Block[FlinkVector]]] = None + + /** Sets K + * @param k the number of selected points as neighbors + */ + def setK(k: Int): KNN = { + require(k > 0, "K must be positive.") + parameters.add(K, k) + this + } + + /** Sets the distance metric + * @param metric the distance metric to calculate distance between two points + */ + def setDistanceMetric(metric: DistanceMetric): KNN = { + parameters.add(DistanceMetric, metric) + this + } + + /** Sets the number of data blocks/partitions + * @param n the number of data blocks + */ + def setBlocks(n: Int): KNN = { + require(n > 0, "Number of blocks must be positive.") + parameters.add(Blocks, n) + this + } + + /** + * Sets the Boolean variable that decides whether to use the QuadTree or not + */ + def setUseQuadTree(useQuadTree: Boolean): KNN = { + if (useQuadTree){ + require(parameters(DistanceMetric).isInstanceOf[SquaredEuclideanDistanceMetric] || + parameters(DistanceMetric).isInstanceOf[EuclideanDistanceMetric]) --- End diff -- What happens if we change the distance metric after we've activated the quad tree usage? Is this condition checked later on again? > Add exact k-nearest-neighbours algorithm to machine learning library > -------------------------------------------------------------------- > > Key: FLINK-1745 > URL: https://issues.apache.org/jira/browse/FLINK-1745 > Project: Flink > Issue Type: New Feature > Components: Machine Learning Library > Reporter: Till Rohrmann > Assignee: Daniel Blazevski > Labels: ML, Starter > > Even though the k-nearest-neighbours (kNN) [1,2] algorithm is quite trivial > it is still used as a means to classify data and to do regression. This issue > focuses on the implementation of an exact kNN (H-BNLJ, H-BRJ) algorithm as > proposed in [2]. > Could be a starter task.
> Resources: > [1] [http://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm] > [2] [https://www.cs.utah.edu/~lifeifei/papers/mrknnj.pdf] -- This message was sent by Atlassian JIRA (v6.3.4#6332)