Github user WeichenXu123 commented on a diff in the pull request:

    https://github.com/apache/spark/pull/18748#discussion_r139161851
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala ---
    @@ -356,6 +371,40 @@ class ALSModel private[ml] (
       }
     
       /**
    +   * Returns top `numUsers` users recommended for each item id in the input data set. Note that if
    +   * there are duplicate ids in the input dataset, only one set of recommendations per unique id
    +   * will be returned.
    +   * @param dataset a Dataset containing a column of item ids. The column name must match `itemCol`.
    +   * @param numUsers max number of recommendations for each item.
    +   * @return a DataFrame of (itemCol: Int, recommendations), where recommendations are
    +   *         stored as an array of (userCol: Int, rating: Float) Rows.
    +   */
    +  @Since("2.3.0")
    +  def recommendForItemSubset(dataset: Dataset[_], numUsers: Int): DataFrame = {
    +    val srcFactorSubset = getSourceFactorSubset(dataset, itemFactors, $(itemCol))
    +    recommendForAll(srcFactorSubset, userFactors, $(itemCol), $(userCol), numUsers)
    +  }
    +
    +  /**
    +   * Returns a subset of a factor DataFrame limited to only those unique ids contained
    +   * in the input dataset.
    +   * @param dataset input Dataset containing id column to use to filter factors.
    +   * @param factors factor DataFrame to filter.
    +   * @param column column name containing the ids in the input dataset.
    +   * @return DataFrame containing factors only for those ids present in both the input dataset and
    +   *         the factor DataFrame.
    +   */
    +  private def getSourceFactorSubset(
    +      dataset: Dataset[_],
    +      factors: DataFrame,
    +      column: String): DataFrame = {
    +    dataset.select(column)
    +      .distinct()
    +      .join(factors, dataset(column) === factors("id"))
    +      .select(factors("id"), factors("features"))
    +  }
    --- End diff --
    
    Oh! But the order of the tables in a left-semi join matters.
    
    You should use
    `factors.join(dataset.select("user"), factors("id") === dataset("user"), "left_semi")`
    instead of
    `dataset.select("user").join(factors, dataset("user") === factors("id"), "left_semi")`
    
    They generate different results: a left-semi join returns only the rows and columns of the left-hand table.
    
    ```
    scala> factors.join(dataset.select("user"), factors("id") === dataset("user"), "left_semi").show
    +---+--------+
    | id|features|
    +---+--------+
    |  0|  [0, 1]|
    |  3|  [3, 4]|
    +---+--------+
    
    ```
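To illustrate why the order matters, here is a minimal plain-Scala model of left-semi join semantics. It has no Spark dependency and only mirrors the documented behavior of `"left_semi"`, not Spark's implementation. The `Factor` case class, the `leftSemi` helper, and the `users` data are hypothetical stand-ins for `factors` and `dataset.select("user")`. A left-semi join keeps each row of the left side that has at least one match on the right, emits only the left side's columns, and never repeats a left row.

```scala
// Plain-Scala model of "left_semi" join semantics (hypothetical helper, not Spark's code).
object LeftSemiOrder {
  // Hypothetical stand-in for a row of the `factors` DataFrame: (id, features).
  final case class Factor(id: Int, features: Seq[Float])

  // Keep every left row whose key appears at least once among the right keys.
  // Only left rows are returned, each at most once, regardless of duplicate right keys.
  def leftSemi[L, R, K](left: Seq[L], right: Seq[R])(leftKey: L => K, rightKey: R => K): Seq[L] = {
    val rightKeys: Set[K] = right.map(rightKey).toSet
    left.filter(row => rightKeys.contains(leftKey(row)))
  }

  val factors: Seq[Factor] = Seq(
    Factor(0, Seq(0f, 1f)),
    Factor(1, Seq(1f, 2f)),
    Factor(2, Seq(2f, 3f)),
    Factor(3, Seq(3f, 4f)))

  // Hypothetical user column: user 3 occurs twice, user 5 has no factors.
  val users: Seq[Int] = Seq(0, 3, 3, 5)

  // factors LEFT SEMI JOIN users: factor rows (id + features), each at most once.
  val factorSubset: Seq[Factor] = leftSemi(factors, users)(_.id, u => u)

  // users LEFT SEMI JOIN factors: only user ids, duplicates kept, no features column.
  val userSubset: Seq[Int] = leftSemi(users, factors)(u => u, _.id)

  def main(args: Array[String]): Unit = {
    println(factorSubset.map(f => s"${f.id} -> ${f.features.mkString("[", ", ", "]")}"))
    println(userSubset)
  }
}
```

With `factors` on the left the result carries the `features` column that `recommendForAll` needs; with the user column on the left it carries only ids. Since a left row is never emitted twice, putting `factors` on the left also appears to make the `.distinct()` in `getSourceFactorSubset` unnecessary.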


