Repository: spark Updated Branches: refs/heads/master a59759e6c -> b28bbffba
[SPARK-20003][ML] FPGrowthModel setMinConfidence should affect rules generation and transform ## What changes were proposed in this pull request? jira: https://issues.apache.org/jira/browse/SPARK-20003 I was doing some tests and found the issue. ml.fpm.FPGrowthModel `setMinConfidence` should always affect rules generation and transform. Currently associationRules in FPGrowthModel is a lazy val and `setMinConfidence` in FPGrowthModel has no impact once associationRules got computed. I tried to cache the associationRules to avoid re-computation if `minConfidence` is not changed, but this makes FPGrowthModel somehow stateful. Let me know if there's any concern. ## How was this patch tested? new unit test, and I strengthened the unit test for model save/load to ensure the cache mechanism. Author: Yuhao Yang <yuhao.y...@intel.com> Closes #17336 from hhbyyh/fpmodelminconf. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b28bbffb Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b28bbffb Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b28bbffb Branch: refs/heads/master Commit: b28bbffbadf7ebc4349666e8f17111f6fca18c9a Parents: a59759e Author: Yuhao Yang <yuhao.y...@intel.com> Authored: Tue Apr 4 17:51:45 2017 -0700 Committer: Joseph K. 
Bradley <jos...@databricks.com> Committed: Tue Apr 4 17:51:45 2017 -0700 ---------------------------------------------------------------------- .../org/apache/spark/ml/fpm/FPGrowth.scala | 21 ++++++-- .../org/apache/spark/ml/fpm/FPGrowthSuite.scala | 56 +++++++++++++------- 2 files changed, 56 insertions(+), 21 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/b28bbffb/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index 65cc806..d604c1a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -218,13 +218,28 @@ class FPGrowthModel private[ml] ( def setPredictionCol(value: String): this.type = set(predictionCol, value) /** - * Get association rules fitted by AssociationRules using the minConfidence. Returns a dataframe + * Cache minConfidence and associationRules to avoid redundant computation for association rules + * during transform. The associationRules will only be re-computed when minConfidence changed. + */ + @transient private var _cachedMinConf: Double = Double.NaN + + @transient private var _cachedRules: DataFrame = _ + + /** + * Get association rules fitted using the minConfidence. Returns a dataframe * with three fields, "antecedent", "consequent" and "confidence", where "antecedent" and * "consequent" are Array[T] and "confidence" is Double. 
*/ @Since("2.2.0") - @transient lazy val associationRules: DataFrame = { - AssociationRules.getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence)) + @transient def associationRules: DataFrame = { + if ($(minConfidence) == _cachedMinConf) { + _cachedRules + } else { + _cachedRules = AssociationRules + .getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence)) + _cachedMinConf = $(minConfidence) + _cachedRules + } } /** http://git-wip-us.apache.org/repos/asf/spark/blob/b28bbffb/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala index 4603a61..6bec057 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.fpm import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions._ @@ -85,38 +85,58 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul assert(prediction.select("prediction").where("id=3").first().getSeq[String](0).isEmpty) } + test("FPGrowth prediction should not contain duplicates") { + // This should generate rule 1 -> 3, 2 -> 3 + val dataset = spark.createDataFrame(Seq( + Array("1", "3"), + Array("2", "3") + ).map(Tuple1(_))).toDF("items") + val model = new FPGrowth().fit(dataset) + + val prediction = model.transform( + spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("items") + ).first().getAs[Seq[String]]("prediction") + + assert(prediction === Seq("3")) + } + + 
test("FPGrowthModel setMinConfidence should affect rules generation and transform") { + val model = new FPGrowth().setMinSupport(0.1).setMinConfidence(0.1).fit(dataset) + val oldRulesNum = model.associationRules.count() + val oldPredict = model.transform(dataset) + + model.setMinConfidence(0.8765) + assert(oldRulesNum > model.associationRules.count()) + assert(!model.transform(dataset).collect().toSet.equals(oldPredict.collect().toSet)) + + // association rules should stay the same for same minConfidence + model.setMinConfidence(0.1) + assert(oldRulesNum === model.associationRules.count()) + assert(model.transform(dataset).collect().toSet.equals(oldPredict.collect().toSet)) + } + test("FPGrowth parameter check") { val fpGrowth = new FPGrowth().setMinSupport(0.4567) val model = fpGrowth.fit(dataset) .setMinConfidence(0.5678) assert(fpGrowth.getMinSupport === 0.4567) assert(model.getMinConfidence === 0.5678) + MLTestingUtils.checkCopy(model) } test("read/write") { def checkModelData(model: FPGrowthModel, model2: FPGrowthModel): Unit = { - assert(model.freqItemsets.sort("items").collect() === - model2.freqItemsets.sort("items").collect()) + assert(model.freqItemsets.collect().toSet.equals( + model2.freqItemsets.collect().toSet)) + assert(model.associationRules.collect().toSet.equals( + model2.associationRules.collect().toSet)) + assert(model.setMinConfidence(0.9).associationRules.collect().toSet.equals( + model2.setMinConfidence(0.9).associationRules.collect().toSet)) } val fPGrowth = new FPGrowth() testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite.allParamSettings, FPGrowthSuite.allParamSettings, checkModelData) } - - test("FPGrowth prediction should not contain duplicates") { - // This should generate rule 1 -> 3, 2 -> 3 - val dataset = spark.createDataFrame(Seq( - Array("1", "3"), - Array("2", "3") - ).map(Tuple1(_))).toDF("items") - val model = new FPGrowth().fit(dataset) - - val prediction = model.transform( - 
spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("items") - ).first().getAs[Seq[String]]("prediction") - - assert(prediction === Seq("3")) - } } object FPGrowthSuite { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org