Repository: spark Updated Branches: refs/heads/master 51d3c854c -> b34f7665d
[SPARK-19825][R][ML] spark.ml R API for FPGrowth ## What changes were proposed in this pull request? Adds SparkR API for FPGrowth: [SPARK-19825](https://issues.apache.org/jira/browse/SPARK-19825): - `spark.fpGrowth` -model training. - `freqItemsets` and `associationRules` methods with new corresponding generics. - Scala helper: `org.apache.spark.ml.r. FPGrowthWrapper` - unit tests. ## How was this patch tested? Feature specific unit tests. Author: zero323 <zero...@users.noreply.github.com> Closes #17170 from zero323/SPARK-19825. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b34f7665 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b34f7665 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b34f7665 Branch: refs/heads/master Commit: b34f7665ddb0a40044b4c2bc7d351599c125cb13 Parents: 51d3c85 Author: zero323 <zero...@users.noreply.github.com> Authored: Mon Apr 3 23:42:04 2017 -0700 Committer: Felix Cheung <felixche...@apache.org> Committed: Mon Apr 3 23:42:04 2017 -0700 ---------------------------------------------------------------------- R/pkg/DESCRIPTION | 1 + R/pkg/NAMESPACE | 5 +- R/pkg/R/generics.R | 12 ++ R/pkg/R/mllib_fpm.R | 158 +++++++++++++++++++ R/pkg/R/mllib_utils.R | 2 + R/pkg/inst/tests/testthat/test_mllib_fpm.R | 83 ++++++++++ .../org/apache/spark/ml/r/FPGrowthWrapper.scala | 86 ++++++++++ .../scala/org/apache/spark/ml/r/RWrappers.scala | 2 + 8 files changed, 348 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/b34f7665/R/pkg/DESCRIPTION ---------------------------------------------------------------------- diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 00dde64..f475ee8 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -44,6 +44,7 @@ Collate: 'jvm.R' 'mllib_classification.R' 'mllib_clustering.R' + 'mllib_fpm.R' 'mllib_recommendation.R' 'mllib_regression.R' 'mllib_stat.R' http://git-wip-us.apache.org/repos/asf/spark/blob/b34f7665/R/pkg/NAMESPACE ---------------------------------------------------------------------- diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index c02046c..9b7e95c 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -66,7 +66,10 @@ exportMethods("glm", "spark.randomForest", "spark.gbt", "spark.bisectingKmeans", - "spark.svmLinear") + "spark.svmLinear", + "spark.fpGrowth", + "spark.freqItemsets", + "spark.associationRules") # Job group lifecycle management methods export("setJobGroup", http://git-wip-us.apache.org/repos/asf/spark/blob/b34f7665/R/pkg/R/generics.R ---------------------------------------------------------------------- diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 80283e4..945676c 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1445,6 +1445,18 @@ setGeneric("spark.posterior", function(object, newData) { standardGeneric("spark #' @export setGeneric("spark.perplexity", function(object, data) { standardGeneric("spark.perplexity") }) +#' @rdname spark.fpGrowth +#' @export +setGeneric("spark.fpGrowth", function(data, ...) { standardGeneric("spark.fpGrowth") }) + +#' @rdname spark.fpGrowth +#' @export +setGeneric("spark.freqItemsets", function(object) { standardGeneric("spark.freqItemsets") }) + +#' @rdname spark.fpGrowth +#' @export +setGeneric("spark.associationRules", function(object) { standardGeneric("spark.associationRules") }) + #' @param object a fitted ML model object. #' @param path the directory where the model is saved. #' @param ... additional argument(s) passed to the method. http://git-wip-us.apache.org/repos/asf/spark/blob/b34f7665/R/pkg/R/mllib_fpm.R ---------------------------------------------------------------------- diff --git a/R/pkg/R/mllib_fpm.R b/R/pkg/R/mllib_fpm.R new file mode 100644 index 0000000..96251b2 --- /dev/null +++ b/R/pkg/R/mllib_fpm.R @@ -0,0 +1,158 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# mllib_fpm.R: Provides methods for MLlib frequent pattern mining algorithms integration + +#' S4 class that represents a FPGrowthModel +#' +#' @param jobj a Java object reference to the backing Scala FPGrowthModel +#' @export +#' @note FPGrowthModel since 2.2.0 +setClass("FPGrowthModel", slots = list(jobj = "jobj")) + +#' FP-growth +#' +#' A parallel FP-growth algorithm to mine frequent itemsets. +#' For more details, see +#' \href{https://spark.apache.org/docs/latest/mllib-frequent-pattern-mining.html#fp-growth}{ +#' FP-growth}. +#' +#' @param data A SparkDataFrame for training. +#' @param minSupport Minimal support level. +#' @param minConfidence Minimal confidence level. +#' @param itemsCol Features column name. +#' @param numPartitions Number of partitions used for fitting. +#' @param ... additional argument(s) passed to the method. +#' @return \code{spark.fpGrowth} returns a fitted FPGrowth model. +#' @rdname spark.fpGrowth +#' @name spark.fpGrowth +#' @aliases spark.fpGrowth,SparkDataFrame-method +#' @export +#' @examples +#' \dontrun{ +#' raw_data <- read.df( +#' "data/mllib/sample_fpgrowth.txt", +#' source = "csv", +#' schema = structType(structField("raw_items", "string"))) +#' +#' data <- selectExpr(raw_data, "split(raw_items, ' ') as items") +#' model <- spark.fpGrowth(data) +#' +#' # Show frequent itemsets +#' frequent_itemsets <- spark.freqItemsets(model) +#' showDF(frequent_itemsets) +#' +#' # Show association rules +#' association_rules <- spark.associationRules(model) +#' showDF(association_rules) +#' +#' # Predict on new data +#' new_itemsets <- data.frame(items = c("t", "t,s")) +#' new_data <- selectExpr(createDataFrame(new_itemsets), "split(items, ',') as items") +#' predict(model, new_data) +#' +#' # Save and load model +#' path <- "/path/to/model" +#' write.ml(model, path) +#' read.ml(path) +#' +#' # Optional arguments +#' baskets_data <- selectExpr(createDataFrame(itemsets), "split(items, ',') as baskets") +#' another_model <- spark.fpGrowth(data, minSupport = 0.1, minConfidence = 0.5, +#' itemsCol = "baskets", numPartitions = 10) +#' } +#' @note spark.fpGrowth since 2.2.0 +setMethod("spark.fpGrowth", signature(data = "SparkDataFrame"), + function(data, minSupport = 0.3, minConfidence = 0.8, + itemsCol = "items", numPartitions = NULL) { + if (!is.numeric(minSupport) || minSupport < 0 || minSupport > 1) { + stop("minSupport should be a number [0, 1].") + } + if (!is.numeric(minConfidence) || minConfidence < 0 || minConfidence > 1) { + stop("minConfidence should be a number [0, 1].") + } + if (!is.null(numPartitions)) { + numPartitions <- as.integer(numPartitions) + stopifnot(numPartitions > 0) + } + + jobj <- callJStatic("org.apache.spark.ml.r.FPGrowthWrapper", "fit", + data@sdf, as.numeric(minSupport), as.numeric(minConfidence), + itemsCol, numPartitions) + new("FPGrowthModel", jobj = jobj) + }) + +# Get frequent itemsets. + +#' @param object a fitted FPGrowth model. +#' @return A \code{SparkDataFrame} with frequent itemsets. +#' The \code{SparkDataFrame} contains two columns: +#' \code{items} (an array of the same type as the input column) +#' and \code{freq} (frequency of the itemset). +#' @rdname spark.fpGrowth +#' @aliases freqItemsets,FPGrowthModel-method +#' @export +#' @note spark.freqItemsets(FPGrowthModel) since 2.2.0 +setMethod("spark.freqItemsets", signature(object = "FPGrowthModel"), + function(object) { + dataFrame(callJMethod(object@jobj, "freqItemsets")) + }) + +# Get association rules. + +#' @return A \code{SparkDataFrame} with association rules. +#' The \code{SparkDataFrame} contains three columns: +#' \code{antecedent} (an array of the same type as the input column), +#' \code{consequent} (an array of the same type as the input column), +#' and \code{condfidence} (confidence). +#' @rdname spark.fpGrowth +#' @aliases associationRules,FPGrowthModel-method +#' @export +#' @note spark.associationRules(FPGrowthModel) since 2.2.0 +setMethod("spark.associationRules", signature(object = "FPGrowthModel"), + function(object) { + dataFrame(callJMethod(object@jobj, "associationRules")) + }) + +# Makes predictions based on generated association rules + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns a SparkDataFrame containing predicted values. +#' @rdname spark.fpGrowth +#' @aliases predict,FPGrowthModel-method +#' @export +#' @note predict(FPGrowthModel) since 2.2.0 +setMethod("predict", signature(object = "FPGrowthModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Saves the FPGrowth model to the output path. + +#' @param path the directory where the model is saved. +#' @param overwrite logical value indicating whether to overwrite if the output path +#' already exists. Default is FALSE which means throw exception +#' if the output path exists. +#' @rdname spark.fpGrowth +#' @aliases write.ml,FPGrowthModel,character-method +#' @export +#' @seealso \link{read.ml} +#' @note write.ml(FPGrowthModel, character) since 2.2.0 +setMethod("write.ml", signature(object = "FPGrowthModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) http://git-wip-us.apache.org/repos/asf/spark/blob/b34f7665/R/pkg/R/mllib_utils.R ---------------------------------------------------------------------- diff --git a/R/pkg/R/mllib_utils.R b/R/pkg/R/mllib_utils.R index 04a0a6f..5dfef86 100644 --- a/R/pkg/R/mllib_utils.R +++ b/R/pkg/R/mllib_utils.R @@ -118,6 +118,8 @@ read.ml <- function(path) { new("BisectingKMeansModel", jobj = jobj) } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LinearSVCWrapper")) { new("LinearSVCModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.FPGrowthWrapper")) { + new("FPGrowthModel", jobj = jobj) } else { stop("Unsupported model: ", jobj) } http://git-wip-us.apache.org/repos/asf/spark/blob/b34f7665/R/pkg/inst/tests/testthat/test_mllib_fpm.R ---------------------------------------------------------------------- diff --git a/R/pkg/inst/tests/testthat/test_mllib_fpm.R b/R/pkg/inst/tests/testthat/test_mllib_fpm.R new file mode 100644 index 0000000..c38f113 --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_mllib_fpm.R @@ -0,0 +1,83 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) + +context("MLlib frequent pattern mining") + +# Tests for MLlib frequent pattern mining algorithms in SparkR +sparkSession <- sparkR.session(enableHiveSupport = FALSE) + +test_that("spark.fpGrowth", { + data <- selectExpr(createDataFrame(data.frame(items = c( + "1,2", + "1,2", + "1,2,3", + "1,3" + ))), "split(items, ',') as items") + + model <- spark.fpGrowth(data, minSupport = 0.3, minConfidence = 0.8, numPartitions = 1) + + itemsets <- collect(spark.freqItemsets(model)) + + expected_itemsets <- data.frame( + items = I(list(list("3"), list("3", "1"), list("2"), list("2", "1"), list("1"))), + freq = c(2, 2, 3, 3, 4) + ) + + expect_equivalent(expected_itemsets, itemsets) + + expected_association_rules <- data.frame( + antecedent = I(list(list("2"), list("3"))), + consequent = I(list(list("1"), list("1"))), + confidence = c(1, 1) + ) + + expect_equivalent(expected_association_rules, collect(spark.associationRules(model))) + + new_data <- selectExpr(createDataFrame(data.frame(items = c( + "1,2", + "1,3", + "2,3" + ))), "split(items, ',') as items") + + expected_predictions <- data.frame( + items = I(list(list("1", "2"), list("1", "3"), list("2", "3"))), + prediction = I(list(list(), list(), list("1"))) + ) + + expect_equivalent(expected_predictions, collect(predict(model, new_data))) + + modelPath <- tempfile(pattern = "spark-fpm", fileext = ".tmp") + write.ml(model, modelPath, overwrite = TRUE) + loaded_model <- read.ml(modelPath) + + expect_equivalent( + itemsets, + collect(spark.freqItemsets(loaded_model))) + + unlink(modelPath) + + model_without_numpartitions <- spark.fpGrowth(data, minSupport = 0.3, minConfidence = 0.8) + expect_equal( + count(spark.freqItemsets(model_without_numpartitions)), + count(spark.freqItemsets(model)) + ) + +}) + +sparkR.session.stop() http://git-wip-us.apache.org/repos/asf/spark/blob/b34f7665/mllib/src/main/scala/org/apache/spark/ml/r/FPGrowthWrapper.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/FPGrowthWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/FPGrowthWrapper.scala new file mode 100644 index 0000000..b8151d8 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/FPGrowthWrapper.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.r + +import org.apache.hadoop.fs.Path +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.ml.fpm.{FPGrowth, FPGrowthModel} +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} + +private[r] class FPGrowthWrapper private (val fpGrowthModel: FPGrowthModel) extends MLWritable { + def freqItemsets: DataFrame = fpGrowthModel.freqItemsets + def associationRules: DataFrame = fpGrowthModel.associationRules + + def transform(dataset: Dataset[_]): DataFrame = { + fpGrowthModel.transform(dataset) + } + + override def write: MLWriter = new FPGrowthWrapper.FPGrowthWrapperWriter(this) +} + +private[r] object FPGrowthWrapper extends MLReadable[FPGrowthWrapper] { + + def fit( + data: DataFrame, + minSupport: Double, + minConfidence: Double, + itemsCol: String, + numPartitions: Integer): FPGrowthWrapper = { + val fpGrowth = new FPGrowth() + .setMinSupport(minSupport) + .setMinConfidence(minConfidence) + .setItemsCol(itemsCol) + + if (numPartitions != null && numPartitions > 0) { + fpGrowth.setNumPartitions(numPartitions) + } + + val fpGrowthModel = fpGrowth.fit(data) + + new FPGrowthWrapper(fpGrowthModel) + } + + override def read: MLReader[FPGrowthWrapper] = new FPGrowthWrapperReader + + class FPGrowthWrapperReader extends MLReader[FPGrowthWrapper] { + override def load(path: String): FPGrowthWrapper = { + val modelPath = new Path(path, "model").toString + val fPGrowthModel = FPGrowthModel.load(modelPath) + + new FPGrowthWrapper(fPGrowthModel) + } + } + + class FPGrowthWrapperWriter(instance: FPGrowthWrapper) extends MLWriter { + override protected def saveImpl(path: String): Unit = { + val modelPath = new Path(path, "model").toString + val rMetadataPath = new Path(path, "rMetadata").toString + + val rMetadataJson: String = compact(render( + "class" -> instance.getClass.getName + )) + + sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath) + + instance.fpGrowthModel.save(modelPath) + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/b34f7665/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala index 358e522..b30ce12 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala @@ -68,6 +68,8 @@ private[r] object RWrappers extends MLReader[Object] { BisectingKMeansWrapper.load(path) case "org.apache.spark.ml.r.LinearSVCWrapper" => LinearSVCWrapper.load(path) + case "org.apache.spark.ml.r.FPGrowthWrapper" => + FPGrowthWrapper.load(path) case _ => throw new SparkException(s"SparkR read.ml does not support load $className") } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org