Repository: spark
Updated Branches:
  refs/heads/master 369a148e5 -> 70f9d7f71


[SPARK-19535][ML] RecommendForAllUsers RecommendForAllItems for ALS on Dataframe

## What changes were proposed in this pull request?

This is a simple implementation of `recommendForAllUsers` and `recommendForAllItems` for the DataFrame version of ALS. It uses DataFrame operations directly rather than wrapping the RDD implementation. It has not been benchmarked against such a wrapper, but the unit test examples work.
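
For context, here is a minimal usage sketch of the new methods. The `ratings` DataFrame and the `user`/`item`/`rating` column names are assumptions for illustration only:

```scala
import org.apache.spark.ml.recommendation.ALS

// Hypothetical training data with columns (user: Int, item: Int, rating: Float).
val als = new ALS()
  .setUserCol("user")
  .setItemCol("item")
  .setRatingCol("rating")
val model = als.fit(ratings)

// New in this patch: top-k recommendations for every user and for every item.
val userRecs = model.recommendForAllUsers(10)  // (user, recommendations: Array[(item, rating)])
val itemRecs = model.recommendForAllItems(10)  // (item, recommendations: Array[(user, rating)])
```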

## How was this patch tested?

Unit tests
```
$ build/sbt
> mllib/testOnly *ALSSuite -- -z "recommendFor"
> mllib/testOnly
```

Author: Your Name <y...@example.com>
Author: sueann <sue...@databricks.com>

Closes #17090 from sueann/SPARK-19535.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/70f9d7f7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/70f9d7f7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/70f9d7f7

Branch: refs/heads/master
Commit: 70f9d7f71c63d2b1fdfed75cb7a59285c272a62b
Parents: 369a148
Author: Sue Ann Hong <sue...@databricks.com>
Authored: Sun Mar 5 16:49:31 2017 -0800
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Sun Mar 5 16:49:31 2017 -0800

----------------------------------------------------------------------
 .../apache/spark/ml/recommendation/ALS.scala    | 79 ++++++++++++++--
 .../ml/recommendation/TopByKeyAggregator.scala  | 60 +++++++++++++
 .../spark/ml/recommendation/ALSSuite.scala      | 94 ++++++++++++++++++++
 .../TopByKeyAggregatorSuite.scala               | 73 +++++++++++++++
 4 files changed, 297 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/70f9d7f7/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 799e881..60dd736 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -40,7 +40,8 @@ import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.CholeskyDecomposition
 import org.apache.spark.mllib.optimization.NNLS
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
@@ -284,18 +285,20 @@ class ALSModel private[ml] (
   @Since("2.2.0")
  def setColdStartStrategy(value: String): this.type = set(coldStartStrategy, value)
 
+  private val predict = udf { (featuresA: Seq[Float], featuresB: Seq[Float]) =>
+    if (featuresA != null && featuresB != null) {
+      // TODO(SPARK-19759): try dot-producting on Seqs or another non-converted type for
+      // potential optimization.
+      blas.sdot(rank, featuresA.toArray, 1, featuresB.toArray, 1)
+    } else {
+      Float.NaN
+    }
+  }
+
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema)
-    // Register a UDF for DataFrame, and then
     // create a new column named map(predictionCol) by running the predict UDF.
-    val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
-      if (userFeatures != null && itemFeatures != null) {
-        blas.sdot(rank, userFeatures.toArray, 1, itemFeatures.toArray, 1)
-      } else {
-        Float.NaN
-      }
-    }
     val predictions = dataset
       .join(userFactors,
         checkedCast(dataset($(userCol))) === userFactors("id"), "left")
@@ -327,6 +330,64 @@ class ALSModel private[ml] (
 
   @Since("1.6.0")
   override def write: MLWriter = new ALSModel.ALSModelWriter(this)
+
+  /**
+   * Returns top `numItems` items recommended for each user, for all users.
+   * @param numItems max number of recommendations for each user
+   * @return a DataFrame of (userCol: Int, recommendations), where recommendations are
+   *         stored as an array of (itemCol: Int, rating: Float) Rows.
+   */
+  @Since("2.2.0")
+  def recommendForAllUsers(numItems: Int): DataFrame = {
+    recommendForAll(userFactors, itemFactors, $(userCol), $(itemCol), numItems)
+  }
+
+  /**
+   * Returns top `numUsers` users recommended for each item, for all items.
+   * @param numUsers max number of recommendations for each item
+   * @return a DataFrame of (itemCol: Int, recommendations), where recommendations are
+   *         stored as an array of (userCol: Int, rating: Float) Rows.
+   */
+  @Since("2.2.0")
+  def recommendForAllItems(numUsers: Int): DataFrame = {
+    recommendForAll(itemFactors, userFactors, $(itemCol), $(userCol), numUsers)
+  }
+
+  /**
+   * Makes recommendations for all users (or items).
+   * @param srcFactors src factors for which to generate recommendations
+   * @param dstFactors dst factors used to make recommendations
+   * @param srcOutputColumn name of the column for the source ID in the output DataFrame
+   * @param dstOutputColumn name of the column for the destination ID in the output DataFrame
+   * @param num max number of recommendations for each record
+   * @return a DataFrame of (srcOutputColumn: Int, recommendations), where recommendations are
+   *         stored as an array of (dstOutputColumn: Int, rating: Float) Rows.
+   */
+  private def recommendForAll(
+      srcFactors: DataFrame,
+      dstFactors: DataFrame,
+      srcOutputColumn: String,
+      dstOutputColumn: String,
+      num: Int): DataFrame = {
+    import srcFactors.sparkSession.implicits._
+
+    val ratings = srcFactors.crossJoin(dstFactors)
+      .select(
+        srcFactors("id"),
+        dstFactors("id"),
+        predict(srcFactors("features"), dstFactors("features")))
+    // We'll force the IDs to be Int. Unfortunately this converts IDs to Int in the output.
+    val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2))
+    val recs = ratings.as[(Int, Int, Float)].groupByKey(_._1).agg(topKAggregator.toColumn)
+      .toDF("id", "recommendations")
+
+    val arrayType = ArrayType(
+      new StructType()
+        .add(dstOutputColumn, IntegerType)
+        .add("rating", FloatType)
+    )
+    recs.select($"id" as srcOutputColumn, $"recommendations" cast arrayType)
+  }
 }
 
 @Since("1.6.0")

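One note on the shape of the result: `recommendForAll` returns the top-k list as an array-of-structs column named `recommendations`, so callers who want flat rows can explode it with standard DataFrame operations. A minimal sketch, assuming the model was configured with `user`/`item` column names and that `userRecs` came from `recommendForAllUsers` as in the earlier example:

```scala
import org.apache.spark.sql.functions.{col, explode}

// Flatten (user, recommendations: Array[Struct(item, rating)]) into one row per (user, item, rating).
val flat = userRecs
  .select(col("user"), explode(col("recommendations")).as("rec"))
  .select(col("user"), col("rec.item").as("item"), col("rec.rating").as("rating"))
```
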
http://git-wip-us.apache.org/repos/asf/spark/blob/70f9d7f7/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala
new file mode 100644
index 0000000..517179c
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.recommendation
+
+import scala.language.implicitConversions
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.sql.{Encoder, Encoders}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.expressions.Aggregator
+import org.apache.spark.util.BoundedPriorityQueue
+
+
+/**
+ * Works on rows of the form (K1, K2, V) where K1 & K2 are IDs and V is the score value. Finds
+ * the top `num` K2 items based on the given Ordering.
+ */
+private[recommendation] class TopByKeyAggregator[K1: TypeTag, K2: TypeTag, V: TypeTag]
+  (num: Int, ord: Ordering[(K2, V)])
+  extends Aggregator[(K1, K2, V), BoundedPriorityQueue[(K2, V)], Array[(K2, V)]] {
+
+  override def zero: BoundedPriorityQueue[(K2, V)] = new BoundedPriorityQueue[(K2, V)](num)(ord)
+
+  override def reduce(
+      q: BoundedPriorityQueue[(K2, V)],
+      a: (K1, K2, V)): BoundedPriorityQueue[(K2, V)] = {
+    q += {(a._2, a._3)}
+  }
+
+  override def merge(
+      q1: BoundedPriorityQueue[(K2, V)],
+      q2: BoundedPriorityQueue[(K2, V)]): BoundedPriorityQueue[(K2, V)] = {
+    q1 ++= q2
+  }
+
+  override def finish(r: BoundedPriorityQueue[(K2, V)]): Array[(K2, V)] = {
+    r.toArray.sorted(ord.reverse)
+  }
+
+  override def bufferEncoder: Encoder[BoundedPriorityQueue[(K2, V)]] = {
+    Encoders.kryo[BoundedPriorityQueue[(K2, V)]]
+  }
+
+  override def outputEncoder: Encoder[Array[(K2, V)]] = ExpressionEncoder[Array[(K2, V)]]()
+}

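A side note on `finish`: `BoundedPriorityQueue` is a Spark-internal (`private[spark]`) helper that keeps only the top `num` elements under `ord` by evicting the current minimum, and its iteration order is not the final ranking, which is why `finish` re-sorts with `ord.reverse`. The same bounded top-k behavior can be sketched outside Spark with a plain `PriorityQueue`; this is an illustration of the idea, not the actual class:

```scala
import scala.collection.mutable

// Illustration only: keep the `num` largest elements under `ord`, like the
// aggregator's BoundedPriorityQueue buffer, then return them in descending order.
def topK[T](items: Seq[T], num: Int)(implicit ord: Ordering[T]): Seq[T] = {
  // With the reversed ordering, the queue's head is the smallest retained
  // element, so it is the one evicted when the buffer exceeds `num`.
  val queue = mutable.PriorityQueue.empty[T](ord.reverse)
  items.foreach { item =>
    queue.enqueue(item)
    if (queue.size > num) queue.dequeue()
  }
  // Mirrors finish(): sort the surviving elements from highest to lowest.
  queue.dequeueAll.reverse
}

// e.g. topK(Seq((3, 54f), (4, 44f), (5, 42f)), 2)(Ordering.by(_._2))
//      => Seq((3, 54.0), (4, 44.0))
```
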
http://git-wip-us.apache.org/repos/asf/spark/blob/70f9d7f7/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index c8228dd..e494ea8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -22,6 +22,7 @@ import java.util.Random
 
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.WrappedArray
 import scala.collection.JavaConverters._
 import scala.language.existentials
 
@@ -660,6 +661,99 @@ class ALSSuite
       model.setColdStartStrategy(s).transform(data)
     }
   }
+
+  private def getALSModel = {
+    val spark = this.spark
+    import spark.implicits._
+
+    val userFactors = Seq(
+      (0, Array(6.0f, 4.0f)),
+      (1, Array(3.0f, 4.0f)),
+      (2, Array(3.0f, 6.0f))
+    ).toDF("id", "features")
+    val itemFactors = Seq(
+      (3, Array(5.0f, 6.0f)),
+      (4, Array(6.0f, 2.0f)),
+      (5, Array(3.0f, 6.0f)),
+      (6, Array(4.0f, 1.0f))
+    ).toDF("id", "features")
+    val als = new ALS().setRank(2)
+    new ALSModel(als.uid, als.getRank, userFactors, itemFactors)
+      .setUserCol("user")
+      .setItemCol("item")
+  }
+
+  test("recommendForAllUsers with k < num_items") {
+    val topItems = getALSModel.recommendForAllUsers(2)
+    assert(topItems.count() == 3)
+    assert(topItems.columns.contains("user"))
+
+    val expected = Map(
+      0 -> Array((3, 54f), (4, 44f)),
+      1 -> Array((3, 39f), (5, 33f)),
+      2 -> Array((3, 51f), (5, 45f))
+    )
+    checkRecommendations(topItems, expected, "item")
+  }
+
+  test("recommendForAllUsers with k = num_items") {
+    val topItems = getALSModel.recommendForAllUsers(4)
+    assert(topItems.count() == 3)
+    assert(topItems.columns.contains("user"))
+
+    val expected = Map(
+      0 -> Array((3, 54f), (4, 44f), (5, 42f), (6, 28f)),
+      1 -> Array((3, 39f), (5, 33f), (4, 26f), (6, 16f)),
+      2 -> Array((3, 51f), (5, 45f), (4, 30f), (6, 18f))
+    )
+    checkRecommendations(topItems, expected, "item")
+  }
+
+  test("recommendForAllItems with k < num_users") {
+    val topUsers = getALSModel.recommendForAllItems(2)
+    assert(topUsers.count() == 4)
+    assert(topUsers.columns.contains("item"))
+
+    val expected = Map(
+      3 -> Array((0, 54f), (2, 51f)),
+      4 -> Array((0, 44f), (2, 30f)),
+      5 -> Array((2, 45f), (0, 42f)),
+      6 -> Array((0, 28f), (2, 18f))
+    )
+    checkRecommendations(topUsers, expected, "user")
+  }
+
+  test("recommendForAllItems with k = num_users") {
+    val topUsers = getALSModel.recommendForAllItems(3)
+    assert(topUsers.count() == 4)
+    assert(topUsers.columns.contains("item"))
+
+    val expected = Map(
+      3 -> Array((0, 54f), (2, 51f), (1, 39f)),
+      4 -> Array((0, 44f), (2, 30f), (1, 26f)),
+      5 -> Array((2, 45f), (0, 42f), (1, 33f)),
+      6 -> Array((0, 28f), (2, 18f), (1, 16f))
+    )
+    checkRecommendations(topUsers, expected, "user")
+  }
+
+  private def checkRecommendations(
+      topK: DataFrame,
+      expected: Map[Int, Array[(Int, Float)]],
+      dstColName: String): Unit = {
+    val spark = this.spark
+    import spark.implicits._
+
+    assert(topK.columns.contains("recommendations"))
+    topK.as[(Int, Seq[(Int, Float)])].collect().foreach { case (id: Int, recs: Seq[(Int, Float)]) =>
+      assert(recs === expected(id))
+    }
+    topK.collect().foreach { row =>
+      val recs = row.getAs[WrappedArray[Row]]("recommendations")
+      assert(recs(0).fieldIndex(dstColName) == 0)
+      assert(recs(0).fieldIndex("rating") == 1)
+    }
+  }
 }
 
 class ALSCleanerSuite extends SparkFunSuite {

http://git-wip-us.apache.org/repos/asf/spark/blob/70f9d7f7/mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala
new file mode 100644
index 0000000..5e763a8
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.recommendation
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.Dataset
+
+
+class TopByKeyAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  private def getTopK(k: Int): Dataset[(Int, Array[(Int, Float)])] = {
+    val sqlContext = spark.sqlContext
+    import sqlContext.implicits._
+
+    val topKAggregator = new TopByKeyAggregator[Int, Int, Float](k, Ordering.by(_._2))
+    Seq(
+      (0, 3, 54f),
+      (0, 4, 44f),
+      (0, 5, 42f),
+      (0, 6, 28f),
+      (1, 3, 39f),
+      (2, 3, 51f),
+      (2, 5, 45f),
+      (2, 6, 18f)
+    ).toDS().groupByKey(_._1).agg(topKAggregator.toColumn)
+  }
+
+  test("topByKey with k < #items") {
+    val topK = getTopK(2)
+    assert(topK.count() === 3)
+
+    val expected = Map(
+      0 -> Array((3, 54f), (4, 44f)),
+      1 -> Array((3, 39f)),
+      2 -> Array((3, 51f), (5, 45f))
+    )
+    checkTopK(topK, expected)
+  }
+
+  test("topByKey with k > #items") {
+    val topK = getTopK(5)
+    assert(topK.count() === 3)
+
+    val expected = Map(
+      0 -> Array((3, 54f), (4, 44f), (5, 42f), (6, 28f)),
+      1 -> Array((3, 39f)),
+      2 -> Array((3, 51f), (5, 45f), (6, 18f))
+    )
+    checkTopK(topK, expected)
+  }
+
+  private def checkTopK(
+      topK: Dataset[(Int, Array[(Int, Float)])],
+      expected: Map[Int, Array[(Int, Float)]]): Unit = {
+    topK.collect().foreach { case (id, recs) => assert(recs === expected(id)) }
+  }
+}

