Github user felixcheung commented on a diff in the pull request: https://github.com/apache/spark/pull/14384#discussion_r74868437 --- Diff: R/pkg/R/mllib.R --- @@ -632,3 +642,159 @@ setMethod("predict", signature(object = "AFTSurvivalRegressionModel"), function(object, newData) { return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf))) }) + + +#' Alternating Least Squares (ALS) for Collaborative Filtering +#' +#' \code{spark.als} learns latent factors in collaborative filtering via alternating least +#' squares. Users can call \code{summary} to obtain fitted latent factors, \code{predict} +#' to make predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models. +#' +#' For more details, see +#' \href{http://spark.apache.org/docs/latest/ml-collaborative-filtering.html}{MLlib: +#' Collaborative Filtering}. +#' Additional arguments can be passed to the methods. +#' \describe{ +#' \item{nonnegative}{logical value indicating whether to apply nonnegativity constraints. +#' Default: FALSE} +#' \item{implicitPrefs}{logical value indicating whether to use implicit preference. +#' Default: FALSE} +#' \item{alpha}{alpha parameter in the implicit preference formulation (>= 0). Default: 1.0} +#' \item{seed}{integer seed for random number generation. Default: 0} +#' \item{numUserBlocks}{number of user blocks used to parallelize computation (> 0). +#' Default: 10} +#' \item{numItemBlocks}{number of item blocks used to parallelize computation (> 0). +#' Default: 10} +#' \item{checkpointInterval}{number of checkpoint intervals (>= 1) or disable checkpoint (-1). +#' Default: 10} +#' } +#' +#' @param data A SparkDataFrame for training +#' @param ratingCol column name for ratings +#' @param userCol column name for user ids. Ids must be (or can be coerced into) integers +#' @param itemCol column name for item ids. 
Ids must be (or can be coerced into) integers +#' @param rank rank of the matrix factorization (> 0) +#' @param reg regularization parameter (>= 0) +#' @param maxIter maximum number of iterations (>= 0) +#' @param ... additional named argument(s) such as \code{nonnegative}. +#' @return \code{spark.als} returns a fitted ALS model +#' @rdname spark.als +#' @aliases spark.als,SparkDataFrame +#' @name spark.als +#' @export +#' @examples +#' \dontrun{ +#' ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), +#' list(2, 1, 1.0), list(2, 2, 5.0)) +#' df <- createDataFrame(ratings, c("user", "item", "rating")) +#' model <- spark.als(df, "rating", "user", "item") +#' +#' # extract latent factors +#' stats <- summary(model) +#' userFactors <- stats$userFactors +#' itemFactors <- stats$itemFactors +#' +#' # make predictions +#' predicted <- predict(model, df) +#' showDF(predicted) +#' +#' # save and load the model +#' path <- "path/to/model" +#' write.ml(model, path) +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' +#' # set other arguments +#' modelS <- spark.als(df, "rating", "user", "item", rank = 20, +#' reg = 0.1, nonnegative = TRUE) +#' statsS <- summary(modelS) +#' } +#' @note spark.als since 2.1.0 +setMethod("spark.als", signature(data = "SparkDataFrame"), + function(data, ratingCol = "rating", userCol = "user", itemCol = "item", + rank = 10, reg = 1.0, maxIter = 10, ...) { + + if (!is.numeric(rank) || rank <= 0) { + stop("rank should be a positive number.") + } + if (!is.numeric(reg) || reg < 0) { + stop("reg should be a nonnegative number.") + } + if (!is.numeric(maxIter) || maxIter <= 0) { + stop("maxIter should be a positive number.") + } + + `%||%` <- function(a, b) if (!is.null(a)) a else b + + args <- list(...) 
+ numUserBlocks <- args$numUserBlocks %||% 10 + numItemBlocks <- args$numItemBlocks %||% 10 + implicitPrefs <- args$implicitPrefs %||% FALSE + alpha <- args$alpha %||% 1.0 + nonnegative <- args$nonnegative %||% FALSE + checkpointInterval <- args$checkpointInterval %||% 10 + seed <- args$seed %||% 0 + + features <- array(c(ratingCol, userCol, itemCol)) + distParams <- array(as.integer(c(numUserBlocks, numItemBlocks, + checkpointInterval, seed))) + + jobj <- callJStatic("org.apache.spark.ml.r.ALSWrapper", + "fit", data@sdf, features, as.integer(rank), + reg, as.integer(maxIter), implicitPrefs, alpha, nonnegative, + distParams) + return(new("ALSModel", jobj = jobj)) + }) + +# Returns a summary of the ALS model produced by spark.als. + +#' @param object A fitted ALS model +#' @return \code{summary} returns a list containing the estimated user and item factors, +#' rank, regularization parameter and maximum number of iterations used in training +#' @rdname spark.als +#' @export +#' @note summary(ALSModel) since 2.1.0 +setMethod("summary", signature(object = "ALSModel"), +function(object, ...) { + jobj <- object@jobj + userFactors <- dataFrame(callJMethod(jobj, "rUserFactors")) + itemFactors <- dataFrame(callJMethod(jobj, "rItemFactors")) + rank <- callJMethod(jobj, "rRank") + regParam <- callJMethod(jobj, "rRegParam") + maxIter <- callJMethod(jobj, "rMaxIter") + return(list(userFactors = userFactors, itemFactors = itemFactors, rank = rank, + regParam = regParam, maxIter = maxIter)) +}) + + +# Makes predictions from an ALS model or a model produced by spark.als. + +#' @param newData A SparkDataFrame for testing +#' @return \code{predict} returns a SparkDataFrame containing predicted values +#' @rdname spark.als --- End diff -- add @aliases
--- If your project is set up for it, you can reply to this email and have your reply appear on GitHub as well. If your project does not have this feature enabled and would like it enabled, or if the feature is enabled but not working, please contact infrastructure at infrastruct...@apache.org or file a JIRA ticket with INFRA. --- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org