Repository: spark Updated Branches: refs/heads/master 86174ea89 -> e300a5a14
[SPARK-20300][ML][PYSPARK] Python API for ALSModel.recommendForAllUsers,Items Add Python API for `ALSModel` methods `recommendForAllUsers`, `recommendForAllItems` ## How was this patch tested? New doc tests. Author: Nick Pentreath <ni...@za.ibm.com> Closes #17622 from MLnick/SPARK-20300-pyspark-recall. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e300a5a1 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e300a5a1 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e300a5a1 Branch: refs/heads/master Commit: e300a5a145820ecd466885c73245d6684e8cb0aa Parents: 86174ea Author: Nick Pentreath <ni...@za.ibm.com> Authored: Tue May 2 10:49:13 2017 +0200 Committer: Nick Pentreath <ni...@za.ibm.com> Committed: Tue May 2 10:49:13 2017 +0200 ---------------------------------------------------------------------- python/pyspark/ml/recommendation.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/e300a5a1/python/pyspark/ml/recommendation.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index 8bc899a..bcfb368 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -82,6 +82,14 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha Row(user=1, item=0, prediction=2.6258413791656494) >>> predictions[2] Row(user=2, item=0, prediction=-1.5018409490585327) + >>> user_recs = model.recommendForAllUsers(3) + >>> user_recs.where(user_recs.user == 0)\ + .select("recommendations.item", "recommendations.rating").collect() + [Row(item=[0, 1, 2], rating=[3.910..., 1.992..., -0.138...])] + >>> item_recs = model.recommendForAllItems(3) + >>> item_recs.where(item_recs.item == 2)\ + .select("recommendations.user", "recommendations.rating").collect() + [Row(user=[2, 1, 0], rating=[4.901..., 3.981..., -0.138...])] >>> als_path = temp_path + "/als" >>> als.save(als_path) >>> als2 = ALS.load(als_path) @@ -384,6 +392,28 @@ class ALSModel(JavaModel, JavaMLWritable, JavaMLReadable): """ return self._call_java("itemFactors") + @since("2.2.0") + def recommendForAllUsers(self, numItems): + """ + Returns top `numItems` items recommended for each user, for all users. + + :param numItems: max number of recommendations for each user + :return: a DataFrame of (userCol, recommendations), where recommendations are + stored as an array of (itemCol, rating) Rows. + """ + return self._call_java("recommendForAllUsers", numItems) + + @since("2.2.0") + def recommendForAllItems(self, numUsers): + """ + Returns top `numUsers` users recommended for each item, for all items. + + :param numUsers: max number of recommendations for each item + :return: a DataFrame of (itemCol, recommendations), where recommendations are + stored as an array of (userCol, rating) Rows. + """ + return self._call_java("recommendForAllItems", numUsers) + if __name__ == "__main__": import doctest --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org