[ https://issues.apache.org/jira/browse/FLINK-1745?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=15049558#comment-15049558 ]
ASF GitHub Bot commented on FLINK-1745: --------------------------------------- Github user danielblazevski commented on a diff in the pull request: https://github.com/apache/flink/pull/1220#discussion_r47166230 --- Diff: flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/nn/QuadTree.scala --- @@ -0,0 +1,340 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.nn.util + +import org.apache.flink.ml.math.{Breeze, Vector} +import Breeze._ + +import org.apache.flink.ml.metrics.distances.{SquaredEuclideanDistanceMetric, +EuclideanDistanceMetric, DistanceMetric} + +import scala.collection.mutable.ListBuffer +import scala.collection.mutable.PriorityQueue + +/** + * n-dimensional QuadTree data structure; partitions + * spatial data for faster queries (e.g. KNN query) + * The skeleton of the data structure was initially + * based off of the 2D Quadtree found here: + * http://www.cs.trinity.edu/~mlewis/CSCI1321-F11/Code/src/util/Quadtree.scala + * + * Many additional methods were added to the class both for + * efficient KNN queries and generalizing to n-dim. 
+ * + * @param minVec vector of the corner of the bounding box with smallest coordinates + * @param maxVec vector of the corner of the bounding box with largest coordinates + * @param distMetric metric, must be Euclidean or SquaredEuclidean + * @param maxPerBox threshold for number of points in each box before splitting a box + */ +class QuadTree(minVec: Vector, maxVec: Vector, distMetric: DistanceMetric, maxPerBox: Int){ + + class Node(center: Vector, width: Vector, var children: Seq[Node]) { + + val nodeElements = new ListBuffer[Vector] + + /** for testing purposes only; used in QuadTreeSuite.scala + * + * @return center and width of the box + */ + def getCenterWidth(): (Vector, Vector) = { + (center, width) + } + + def contains(queryPoint: Vector): Boolean = { + overlap(queryPoint, 0.0) + } + + /** Tests if queryPoint is within a radius of the node + * + * @param queryPoint + * @param radius + * @return + */ + def overlap(queryPoint: Vector, radius: Double): Boolean = { + var count = 0 + for (i <- 0 to queryPoint.size - 1) { + if (queryPoint(i) - radius < center(i) + width(i) / 2 && + queryPoint(i) + radius > center(i) - width(i) / 2) { + count += 1 + } + } + + if (count == queryPoint.size) { + true + } else { + false + } + } + + /** Tests if queryPoint is near a node + * + * @param queryPoint + * @param radius + * @return + */ + def isNear(queryPoint: Vector, radius: Double): Boolean = { + if (minDist(queryPoint) < radius) { + true + } else { + false + } + } + + /** + * used in error handling when computing minDist to make sure + * distMetric is Euclidean or SquaredEuclidean + * @param message + */ + case class metricException(message: String) extends Exception(message) + + /** + * minDist is defined so that every point in the box + * has distance to queryPoint greater than minDist + * (minDist adopted from "Nearest Neighbors Queries" by N. Roussopoulos et al.) 
+ * + * @param queryPoint + * @return + */ + + def minDist(queryPoint: Vector): Double = { + var minDist = 0.0 + for (i <- 0 to queryPoint.size - 1) { + if (queryPoint(i) < center(i) - width(i) / 2) { + minDist += math.pow(queryPoint(i) - center(i) + width(i) / 2, 2) + } else if (queryPoint(i) > center(i) + width(i) / 2) { + minDist += math.pow(queryPoint(i) - center(i) - width(i) / 2, 2) + } + } + + if (distMetric.isInstanceOf[SquaredEuclideanDistanceMetric]) { + minDist + } else if (distMetric.isInstanceOf[EuclideanDistanceMetric]) { + math.sqrt(minDist) + } else{ + throw metricException(s" Error: metric must be Euclidean or SquaredEuclidean!") + } + } + + /** + * Finds which child queryPoint lies in. node.children is a Seq[Node], and + * whichChild finds the appropriate index of that Seq. + * @param queryPoint + * @return + */ + def whichChild(queryPoint: Vector): Int = { + var count = 0 + for (i <- 0 to queryPoint.size - 1) { + if (queryPoint(i) > center(i)) { + count += Math.pow(2, queryPoint.size -1 - i).toInt + } + } + count + } + + def makeChildren() { + val centerClone = center.copy + val cPart = partitionBox(centerClone, width) + val mappedWidth = 0.5*width.asBreeze + children = cPart.map(p => new Node(p, mappedWidth.fromBreeze, null)) + + } + + /** + * Recursive function that partitions a n-dim box by taking the (n-1) dimensional + * plane through the center of the box keeping the n-th coordinate fixed, + * then shifting it in the n-th direction up and down + * and recursively applying partitionBox to the two shifted (n-1) dimensional planes. 
+ * + * @param center the center of the box + * @param width a vector of lengths of each dimension of the box + * @return + */ + def partitionBox(center: Vector, width: Vector): Seq[Vector] = { + + def partitionHelper(box: Seq[Vector], dim: Int): Seq[Vector] = { + if (dim >= width.size) { + box + } else { + val newBox = box.flatMap { + vector => + val (up, down) = (vector.copy, vector) + up.update(dim, up(dim) - width(dim) / 4) + down.update(dim, down(dim) + width(dim) / 4) + + Seq(up,down) + } + partitionHelper(newBox, dim + 1) + } + } + partitionHelper(Seq(center), 0) + } + } + + + val root = new Node( ((minVec.asBreeze + maxVec.asBreeze)*0.5).fromBreeze, + (maxVec.asBreeze - minVec.asBreeze).fromBreeze, null) + + /** + * Simple printing of tree for testing/debugging + */ + def printTree(): Unit = { + printTreeRecur(root) + } + + def printTreeRecur(node: Node){ + if(node.children != null) { + for (c <- node.children){ + printTreeRecur(c) + } + }else{ + println("printing tree: n.nodeElements " + node.nodeElements) + } + } + + /** + * Recursively adds an object to the tree + * @param queryPoint + */ + def insert(queryPoint: Vector){ + insertRecur(queryPoint,root) + } + + private def insertRecur(queryPoint: Vector,node: Node) { + if (node.children == null) { + if (node.nodeElements.length < maxPerBox ) { + node.nodeElements += queryPoint + } else{ + node.makeChildren() + for (o <- node.nodeElements){ + insertRecur(o, node.children(node.whichChild(o))) + } + node.nodeElements.clear() + insertRecur(queryPoint, node.children(node.whichChild(queryPoint))) + } + } else{ + insertRecur(queryPoint, node.children(node.whichChild(queryPoint))) + } + } + + /** + * Used to zoom in on a region near a test point for a fast KNN query. + * This capability is used in the KNN query to find k "near" neighbors n_1,...,n_k, from + * which one computes the max distance D_s to queryPoint. 
D_s is then used during the + * kNN query to find all points within a radius D_s of queryPoint using searchNeighbors. + * To find the "near" neighbors, a min-heap is defined on the leaf nodes of the leaf + * nodes of the minimal bounding box of the queryPoint. The priority of a leaf node + * is an appropriate notion of the distance between the test point and the node, + * which is defined by minDist(queryPoint), + * + * @param queryPoint + * @return + */ + def searchNeighborsSiblingQueue(queryPoint: Vector): ListBuffer[Vector] = { + var ret = new ListBuffer[Vector] + // edge case when the main box has not been partitioned at all + if (root.children == null) { + root.nodeElements.clone() + } else { + val nodeQueue = new PriorityQueue[(Double, Node)]()(Ordering.by(x => x._1)) + searchRecurSiblingQueue(queryPoint, root, nodeQueue) + + var count = 0 + while (count < maxPerBox) { + val dq = nodeQueue.dequeue() + if (dq._2.nodeElements.nonEmpty) { + ret ++= dq._2.nodeElements + count += dq._2.nodeElements.length + } + } + ret + } + } + + /** + * + * @param queryPoint + * @param node + * @param nodeQueue defined in searchSiblingQueue, this stores nodes based on their + * distance to node as defined by minDist + */ + private def searchRecurSiblingQueue(queryPoint: Vector, node: Node, + nodeQueue: PriorityQueue[(Double, Node)]) { + if (node.children != null) { + for (child <- node.children; if child.contains(queryPoint)) { + if (child.children == null) { + for (c <- node.children) { + MinNodes(queryPoint,c,nodeQueue) + } + } + else { + searchRecurSiblingQueue(queryPoint, child, nodeQueue) + } + } + } + } + + /** + * Goes down to minimal bounding box of queryPoint, and add elements to nodeQueue + * + * @param queryPoint + * @param node + * @param nodeQueue + */ + private def MinNodes(queryPoint: Vector, node: Node, nodeQueue: PriorityQueue[(Double, Node)]) { + if (node.children == null){ --- End diff -- done > Add exact k-nearest-neighbours algorithm to machine learning 
library > -------------------------------------------------------------------- > > Key: FLINK-1745 > URL: https://issues.apache.org/jira/browse/FLINK-1745 > Project: Flink > Issue Type: New Feature > Components: Machine Learning Library > Reporter: Till Rohrmann > Assignee: Daniel Blazevski > Labels: ML, Starter > > Even though the k-nearest-neighbours (kNN) [1,2] algorithm is quite trivial > it is still used as a means to classify data and to do regression. This issue > focuses on the implementation of an exact kNN (H-BNLJ, H-BRJ) algorithm as > proposed in [2]. > Could be a starter task. > Resources: > [1] [http://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm] > [2] [https://www.cs.utah.edu/~lifeifei/papers/mrknnj.pdf] -- This message was sent by Atlassian JIRA (v6.3.4#6332)