Repository: spark
Updated Branches:
  refs/heads/master 86174ea89 -> e300a5a14


[SPARK-20300][ML][PYSPARK] Python API for ALSModel.recommendForAllUsers,Items

Add Python API for `ALSModel` methods `recommendForAllUsers` and `recommendForAllItems`.
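
For context, here is a minimal end-to-end sketch of how the new methods are meant to be used. The toy ratings DataFrame, the parameter values, and the variable names are illustrative assumptions only; ALS's default `userCol`/`itemCol`/`ratingCol` names ("user", "item", "rating") are assumed.

```python
from pyspark.sql import SparkSession
from pyspark.ml.recommendation import ALS

spark = SparkSession.builder.getOrCreate()

# Tiny illustrative ratings DataFrame using the default column names.
ratings = spark.createDataFrame(
    [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)],
    ["user", "item", "rating"])

model = ALS(rank=5, maxIter=5, seed=0).fit(ratings)

# New in this patch: top-3 items per user and top-3 users per item.
user_recs = model.recommendForAllUsers(3)
item_recs = model.recommendForAllItems(3)
user_recs.show(truncate=False)
```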

## How was this patch tested?

New doc tests.

Author: Nick Pentreath <ni...@za.ibm.com>

Closes #17622 from MLnick/SPARK-20300-pyspark-recall.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e300a5a1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e300a5a1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e300a5a1

Branch: refs/heads/master
Commit: e300a5a145820ecd466885c73245d6684e8cb0aa
Parents: 86174ea
Author: Nick Pentreath <ni...@za.ibm.com>
Authored: Tue May 2 10:49:13 2017 +0200
Committer: Nick Pentreath <ni...@za.ibm.com>
Committed: Tue May 2 10:49:13 2017 +0200

----------------------------------------------------------------------
 python/pyspark/ml/recommendation.py | 30 ++++++++++++++++++++++++++++++
 1 file changed, 30 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e300a5a1/python/pyspark/ml/recommendation.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index 8bc899a..bcfb368 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -82,6 +82,14 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
     Row(user=1, item=0, prediction=2.6258413791656494)
     >>> predictions[2]
     Row(user=2, item=0, prediction=-1.5018409490585327)
+    >>> user_recs = model.recommendForAllUsers(3)
+    >>> user_recs.where(user_recs.user == 0)\
+        .select("recommendations.item", "recommendations.rating").collect()
+    [Row(item=[0, 1, 2], rating=[3.910..., 1.992..., -0.138...])]
+    >>> item_recs = model.recommendForAllItems(3)
+    >>> item_recs.where(item_recs.item == 2)\
+        .select("recommendations.user", "recommendations.rating").collect()
+    [Row(user=[2, 1, 0], rating=[4.901..., 3.981..., -0.138...])]
     >>> als_path = temp_path + "/als"
     >>> als.save(als_path)
     >>> als2 = ALS.load(als_path)
@@ -384,6 +392,28 @@ class ALSModel(JavaModel, JavaMLWritable, JavaMLReadable):
         """
         return self._call_java("itemFactors")
 
+    @since("2.2.0")
+    def recommendForAllUsers(self, numItems):
+        """
+        Returns top `numItems` items recommended for each user, for all users.
+
+        :param numItems: max number of recommendations for each user
+        :return: a DataFrame of (userCol, recommendations), where recommendations are
+                 stored as an array of (itemCol, rating) Rows.
+        """
+        return self._call_java("recommendForAllUsers", numItems)
+
+    @since("2.2.0")
+    def recommendForAllItems(self, numUsers):
+        """
+        Returns top `numUsers` users recommended for each item, for all items.
+
+        :param numUsers: max number of recommendations for each item
+        :return: a DataFrame of (itemCol, recommendations), where recommendations are
+                 stored as an array of (userCol, rating) Rows.
+        """
+        return self._call_java("recommendForAllItems", numUsers)
+
 
 if __name__ == "__main__":
     import doctest

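The `recommendations` column returned by both methods is an array of structs, so callers will typically flatten it. A minimal sketch, assuming `user_recs` holds the result of `model.recommendForAllUsers(3)` with the default column names (the variable and alias names are illustrative):

```python
from pyspark.sql import functions as F

# Flatten the nested recommendations into one (user, item, rating) row each.
flat = (user_recs
        .select("user", F.explode("recommendations").alias("rec"))
        .select("user",
                F.col("rec.item").alias("item"),
                F.col("rec.rating").alias("rating")))
flat.show()
```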

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
