Repository: spark
Updated Branches:
  refs/heads/branch-2.2 72fca9a0a -> ca3f7edba


[SPARK-20587][ML] Improve performance of ML ALS recommendForAll

This PR is the `DataFrame` counterpart of #17742 for
[SPARK-11968](https://issues.apache.org/jira/browse/SPARK-11968) and improves
the performance of the `recommendForAll` methods.

## How was this patch tested?

Existing unit tests.

Author: Nick Pentreath <ni...@za.ibm.com>

Closes #17845 from MLnick/ml-als-perf.

(cherry picked from commit 10b00abadf4a3473332eef996db7b66f491316f2)
Signed-off-by: Nick Pentreath <ni...@za.ibm.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ca3f7edb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ca3f7edb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ca3f7edb

Branch: refs/heads/branch-2.2
Commit: ca3f7edbad6a2e7fcd1c1d3dbd1a522cd0d7c476
Parents: 72fca9a
Author: Nick Pentreath <ni...@za.ibm.com>
Authored: Tue May 9 10:13:15 2017 +0200
Committer: Nick Pentreath <ni...@za.ibm.com>
Committed: Tue May 9 10:13:36 2017 +0200

----------------------------------------------------------------------
 .../apache/spark/ml/recommendation/ALS.scala    | 71 ++++++++++++++++++--
 1 file changed, 64 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ca3f7edb/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index a20ef72..4a130e1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -45,7 +45,7 @@ import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{BoundedPriorityQueue, Utils}
 import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, 
SortDataFormat, Sorter}
 import org.apache.spark.util.random.XORShiftRandom
 
@@ -356,6 +356,19 @@ class ALSModel private[ml] (
 
   /**
    * Makes recommendations for all users (or items).
+   *
+   * Note: the previous approach used for computing top-k recommendations
+   * used a cross-join followed by predicting a score for each row of the joined dataset.
+   * However, this results in exploding the size of intermediate data. While Spark SQL makes it
+   * relatively efficient, the approach implemented here is significantly more efficient.
+   *
+   * This approach groups factors into blocks and computes the top-k elements per block,
+   * using a simple dot product (instead of gemm) and an efficient [[BoundedPriorityQueue]].
+   * It then computes the global top-k by aggregating the per block top-k elements with
+   * a [[TopByKeyAggregator]]. This significantly reduces the size of intermediate and shuffle data.
+   * This is the DataFrame equivalent to the approach used in
+   * [[org.apache.spark.mllib.recommendation.MatrixFactorizationModel]].
+   *
    * @param srcFactors src factors for which to generate recommendations
    * @param dstFactors dst factors used to make recommendations
    * @param srcOutputColumn name of the column for the source ID in the output DataFrame
@@ -372,11 +385,43 @@ class ALSModel private[ml] (
       num: Int): DataFrame = {
     import srcFactors.sparkSession.implicits._
 
-    val ratings = srcFactors.crossJoin(dstFactors)
-      .select(
-        srcFactors("id"),
-        dstFactors("id"),
-        predict(srcFactors("features"), dstFactors("features")))
+    val srcFactorsBlocked = blockify(srcFactors.as[(Int, Array[Float])])
+    val dstFactorsBlocked = blockify(dstFactors.as[(Int, Array[Float])])
+    val ratings = srcFactorsBlocked.crossJoin(dstFactorsBlocked)
+      .as[(Seq[(Int, Array[Float])], Seq[(Int, Array[Float])])]
+      .flatMap { case (srcIter, dstIter) =>
+        val m = srcIter.size
+        val n = math.min(dstIter.size, num)
+        val output = new Array[(Int, Int, Float)](m * n)
+        var j = 0
+        val pq = new BoundedPriorityQueue[(Int, Float)](num)(Ordering.by(_._2))
+        srcIter.foreach { case (srcId, srcFactor) =>
+          dstIter.foreach { case (dstId, dstFactor) =>
+            /*
+             * The below code is equivalent to
+             *    `val score = blas.sdot(rank, srcFactor, 1, dstFactor, 1)`
+             * This handwritten version is as or more efficient as BLAS calls in this case.
+             */
+            var score = 0.0f
+            var k = 0
+            while (k < rank) {
+              score += srcFactor(k) * dstFactor(k)
+              k += 1
+            }
+            pq += dstId -> score
+          }
+          val pqIter = pq.iterator
+          var i = 0
+          while (i < n) {
+            val (dstId, score) = pqIter.next()
+            output(j + i) = (srcId, dstId, score)
+            i += 1
+          }
+          j += n
+          pq.clear()
+        }
+        output.toSeq
+      }
     // We'll force the IDs to be Int. Unfortunately this converts IDs to Int in the output.
     val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, 
Ordering.by(_._2))
     val recs = ratings.as[(Int, Int, 
Float)].groupByKey(_._1).agg(topKAggregator.toColumn)
@@ -387,8 +432,20 @@ class ALSModel private[ml] (
         .add(dstOutputColumn, IntegerType)
         .add("rating", FloatType)
     )
-    recs.select($"id" as srcOutputColumn, $"recommendations" cast arrayType)
+    recs.select($"id".as(srcOutputColumn), $"recommendations".cast(arrayType))
   }
+
+  /**
+   * Blockifies factors per partition (at most `blockSize` each) to speed up the cross join.
+   * TODO: SPARK-20443 - expose blockSize as a param?
+   */
+  private def blockify(
+      factors: Dataset[(Int, Array[Float])],
+      blockSize: Int = 4096): Dataset[Seq[(Int, Array[Float])]] = {
+    import factors.sparkSession.implicits._
+    factors.mapPartitions(_.grouped(blockSize))
+  }
+
 }
 
 @Since("1.6.0")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to