huaxingao commented on a change in pull request #26415: [SPARK-18409][ML] LSH approxNearestNeighbors should use approxQuantile instead of sort
URL: https://github.com/apache/spark/pull/26415#discussion_r346986229
########## File path: mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala ##########
@@ -137,14 +139,23 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
     val hashDistUDF = udf((x: Seq[Vector]) => hashDistance(x, keyHash), DataTypes.DoubleType)
     val hashDistCol = hashDistUDF(col($(outputCol)))

-    // Compute threshold to get exact k elements.
-    // TODO: SPARK-18409: Use approxQuantile to get the threshold
-    val modelDatasetSortedByHash = modelDataset.sort(hashDistCol).limit(numNearestNeighbors)
-    val thresholdDataset = modelDatasetSortedByHash.select(max(hashDistCol))
-    val hashThreshold = thresholdDataset.take(1).head.getDouble(0)
-
-    // Filter the dataset where the hash value is less than the threshold.
-    modelDataset.filter(hashDistCol <= hashThreshold)
+    val modelDatasetWithDist = modelDataset.withColumn(distCol, hashDistCol)
+    var filtered: DataFrame = null
+    var requestedNum = numNearestNeighbors
+    do {
+      requestedNum *= 2
+      if (requestedNum > modelDataset.count()) {
+        requestedNum = modelDataset.count().toInt
+      }
+      var quantile = requestedNum.toDouble / modelDataset.count()
+      var hashThreshold = modelDatasetWithDist.stat
+        .approxQuantile(distCol, Array(quantile), 0.001)
+
+      // Filter the dataset where the hash value is less than the threshold.
+      filtered = modelDatasetWithDist.filter(hashDistCol <= hashThreshold(0))

Review comment:
It seems to me that I have to run the filter to find out whether it yields enough nearest neighbors. If it doesn't, I go back into the loop and double the quantile.

I am debating whether I should continue this PR. Its purpose is to improve performance, but if the first round of the loop doesn't return enough nearest neighbors and we have to go through the loop several times, each extra round adds another `approxQuantile` pass and another pass to filter and check the candidates, so the performance could end up worse than the original code's.

The doc of `approxNearestNeighbors` says:
```
Given a large dataset and an item, approximately find at most k items which have the closest distance to the item.
```
If this is true, then I guess we can just use a quantile that should yield 2x the number of results, and if that gives us fewer than k elements, that's OK, since the doc only promises at most k. However, the original implementation returns exactly k elements, and I am not sure whether we can change the original behavior.
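To make the trade-off concrete, here is a Spark-free sketch in plain Scala comparing three ways to pick the hash-distance threshold: the original exact k-th value, the doubling loop from the diff, and the single-shot "2x" quantile. It is an illustration under stated assumptions, not the PR's code: plain `Double`s stand in for the hash-distance column, `worstCaseQuantile` is a hypothetical stand-in for `approxQuantile` that always lands at the low edge of its documented rank bound `floor((p - err) * N) <= rank(x)`, the loop's exit condition is assumed because the diff above is cut off, and all names (`ThresholdSketch` and its methods) are hypothetical.

```scala
// Spark-free sketch (hypothetical names throughout): plain Doubles stand in for the
// hash-distance column, and worstCaseQuantile stands in for DataFrame.stat.approxQuantile.
object ThresholdSketch {

  /** Pessimistic stand-in for approxQuantile: Spark documents the bound
    * floor((p - err) * N) <= rank(x) <= ceil((p + err) * N); this always returns the value
    * at the low end (target rank minus floor(err * N), clamped to rank 1). */
  def worstCaseQuantile(sorted: IndexedSeq[Double], targetRank: Int, relativeError: Double): Double = {
    val slack = math.floor(relativeError * sorted.length).toInt
    val rank = math.min(math.max(1, targetRank - slack), sorted.length)
    sorted(rank - 1)
  }

  /** Original approach: the k-th smallest hash distance is an exact threshold. */
  def exactThreshold(hashDists: IndexedSeq[Double], k: Int): Double = {
    require(hashDists.nonEmpty && k > 0)
    hashDists.sorted.apply(math.min(k, hashDists.length) - 1)
  }

  /** Doubling loop modeled on the diff; the exit condition is an assumption because the
    * diff is truncated. Returns (threshold, rounds, candidates kept by the filter). */
  def doublingThreshold(hashDists: IndexedSeq[Double], k: Int, relativeError: Double): (Double, Int, Int) = {
    require(hashDists.nonEmpty && k > 0)
    val n = hashDists.length
    // Sorted once here; in the diff, every round calls approxQuantile, which scans the data again.
    val sorted = hashDists.sorted
    var requested = k
    var rounds = 0
    var threshold = 0.0
    var kept = 0
    var done = false
    while (!done) {
      requested = math.min(requested * 2, n) // quantile = requested / n, i.e. target rank = requested
      threshold = worstCaseQuantile(sorted, requested, relativeError)
      kept = hashDists.count(_ <= threshold) // the filter-and-count pass each round needs
      rounds += 1
      done = kept >= k || requested == n
    }
    (threshold, rounds, kept)
  }

  /** Single-shot alternative: one quantile aimed at 2k rows, accepting fewer than k. */
  def singleShotThreshold(hashDists: IndexedSeq[Double], k: Int, relativeError: Double): (Double, Int) = {
    require(hashDists.nonEmpty && k > 0)
    val n = hashDists.length
    val threshold = worstCaseQuantile(hashDists.sorted, math.min(2 * k, n), relativeError)
    (threshold, hashDists.count(_ <= threshold))
  }
}
```

Under this worst-case model, with 10,000 distinct hash distances, k = 5 and relativeError = 0.001 (a slack of 10 ranks), `singleShotThreshold` keeps only 1 candidate and `doublingThreshold` needs a second round to keep 10; with relativeError = 0.0, one round keeps 10. The sketch only shows that extra rounds become possible when 2k is small relative to relativeError * N; how often they happen on real data is a separate question.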