Github user WeichenXu123 commented on a diff in the pull request: https://github.com/apache/spark/pull/18748#discussion_r139161851 --- Diff: mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala --- @@ -356,6 +371,40 @@ class ALSModel private[ml] ( } /** + * Returns top `numUsers` users recommended for each item id in the input data set. Note that if + * there are duplicate ids in the input dataset, only one set of recommendations per unique id + * will be returned. + * @param dataset a Dataset containing a column of item ids. The column name must match `itemCol`. + * @param numUsers max number of recommendations for each item. + * @return a DataFrame of (itemCol: Int, recommendations), where recommendations are + * stored as an array of (userCol: Int, rating: Float) Rows. + */ + @Since("2.3.0") + def recommendForItemSubset(dataset: Dataset[_], numUsers: Int): DataFrame = { + val srcFactorSubset = getSourceFactorSubset(dataset, itemFactors, $(itemCol)) + recommendForAll(srcFactorSubset, userFactors, $(itemCol), $(userCol), numUsers) + } + + /** + * Returns a subset of a factor DataFrame limited to only those unique ids contained + * in the input dataset. + * @param dataset input Dataset containing id column to user to filter factors. + * @param factors factor DataFrame to filter. + * @param column column name containing the ids in the input dataset. + * @return DataFrame containing factors only for those ids present in both the input dataset and + * the factor DataFrame. + */ + private def getSourceFactorSubset( + dataset: Dataset[_], + factors: DataFrame, + column: String): DataFrame = { + dataset.select(column) + .distinct() + .join(factors, dataset(column) === factors("id")) + .select(factors("id"), factors("features")) + } --- End diff -- Oh! 
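But the order of the tables in a left semi join matters. You should use `factors.join(dataset.select("user"), factors("id") === dataset("user"), "left_semi")` instead of `dataset.select("user").join(factors, dataset("user") === factors("id"), "left_semi")`, because the two produce different results. A left semi join returns only the rows of the left-hand table that have a match in the right-hand table, and only the left-hand table's columns. So `factors` must be on the left for the result to contain the `id` and `features` columns. Each matching factor row is returned at most once, even if `dataset` contains duplicate ids.
```
scala> factors.join(dataset.select("user"), factors("id") === dataset("user"), "left_semi").show
+---+--------+
| id|features|
+---+--------+
|  0|  [0, 1]|
|  3|  [3, 4]|
+---+--------+
```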
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org