Repository: spark
Updated Branches:
  refs/heads/master a59759e6c -> b28bbffba


[SPARK-20003][ML] FPGrowthModel setMinConfidence should affect rules generation and transform

## What changes were proposed in this pull request?

jira: https://issues.apache.org/jira/browse/SPARK-20003
While testing, I found that `setMinConfidence` on ml.fpm.FPGrowthModel does not
always take effect, although it should always affect rule generation and
transform. Currently `associationRules` in FPGrowthModel is a lazy val, so
`setMinConfidence` has no impact once `associationRules` has been computed.

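This patch caches `associationRules` to avoid recomputation when `minConfidence`
has not changed. Only the rules for the most recently used `minConfidence` are
kept, so changing the value triggers a recomputation. This makes FPGrowthModel
somewhat stateful; let me know if there are any concerns.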

## How was this patch tested?

Added a new unit test, and strengthened the model save/load unit test to verify
the caching mechanism.

Author: Yuhao Yang <yuhao.y...@intel.com>

Closes #17336 from hhbyyh/fpmodelminconf.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b28bbffb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b28bbffb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b28bbffb

Branch: refs/heads/master
Commit: b28bbffbadf7ebc4349666e8f17111f6fca18c9a
Parents: a59759e
Author: Yuhao Yang <yuhao.y...@intel.com>
Authored: Tue Apr 4 17:51:45 2017 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Tue Apr 4 17:51:45 2017 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/fpm/FPGrowth.scala      | 21 ++++++--
 .../org/apache/spark/ml/fpm/FPGrowthSuite.scala | 56 +++++++++++++-------
 2 files changed, 56 insertions(+), 21 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b28bbffb/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala 
b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
index 65cc806..d604c1a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -218,13 +218,28 @@ class FPGrowthModel private[ml] (
   def setPredictionCol(value: String): this.type = set(predictionCol, value)
 
   /**
-   * Get association rules fitted by AssociationRules using the minConfidence. 
Returns a dataframe
+   * Cache minConfidence and associationRules to avoid redundant computation 
for association rules
+   * during transform. The associationRules will only be re-computed when 
minConfidence changed.
+   */
+  @transient private var _cachedMinConf: Double = Double.NaN
+
+  @transient private var _cachedRules: DataFrame = _
+
+  /**
+   * Get association rules fitted using the minConfidence. Returns a dataframe
    * with three fields, "antecedent", "consequent" and "confidence", where 
"antecedent" and
    * "consequent" are Array[T] and "confidence" is Double.
    */
   @Since("2.2.0")
-  @transient lazy val associationRules: DataFrame = {
-    AssociationRules.getAssociationRulesFromFP(freqItemsets, "items", "freq", 
$(minConfidence))
+  @transient def associationRules: DataFrame = {
+    if ($(minConfidence) == _cachedMinConf) {
+      _cachedRules
+    } else {
+      _cachedRules = AssociationRules
+        .getAssociationRulesFromFP(freqItemsets, "items", "freq", 
$(minConfidence))
+      _cachedMinConf = $(minConfidence)
+      _cachedRules
+    }
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/b28bbffb/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
index 4603a61..6bec057 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
@@ -17,7 +17,7 @@
 package org.apache.spark.ml.fpm
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.functions._
@@ -85,38 +85,58 @@ class FPGrowthSuite extends SparkFunSuite with 
MLlibTestSparkContext with Defaul
     
assert(prediction.select("prediction").where("id=3").first().getSeq[String](0).isEmpty)
   }
 
+  test("FPGrowth prediction should not contain duplicates") {
+    // This should generate rule 1 -> 3, 2 -> 3
+    val dataset = spark.createDataFrame(Seq(
+      Array("1", "3"),
+      Array("2", "3")
+    ).map(Tuple1(_))).toDF("items")
+    val model = new FPGrowth().fit(dataset)
+
+    val prediction = model.transform(
+      spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("items")
+    ).first().getAs[Seq[String]]("prediction")
+
+    assert(prediction === Seq("3"))
+  }
+
+  test("FPGrowthModel setMinConfidence should affect rules generation and 
transform") {
+    val model = new 
FPGrowth().setMinSupport(0.1).setMinConfidence(0.1).fit(dataset)
+    val oldRulesNum = model.associationRules.count()
+    val oldPredict = model.transform(dataset)
+
+    model.setMinConfidence(0.8765)
+    assert(oldRulesNum > model.associationRules.count())
+    
assert(!model.transform(dataset).collect().toSet.equals(oldPredict.collect().toSet))
+
+    // association rules should stay the same for same minConfidence
+    model.setMinConfidence(0.1)
+    assert(oldRulesNum === model.associationRules.count())
+    
assert(model.transform(dataset).collect().toSet.equals(oldPredict.collect().toSet))
+  }
+
   test("FPGrowth parameter check") {
     val fpGrowth = new FPGrowth().setMinSupport(0.4567)
     val model = fpGrowth.fit(dataset)
       .setMinConfidence(0.5678)
     assert(fpGrowth.getMinSupport === 0.4567)
     assert(model.getMinConfidence === 0.5678)
+    MLTestingUtils.checkCopy(model)
   }
 
   test("read/write") {
     def checkModelData(model: FPGrowthModel, model2: FPGrowthModel): Unit = {
-      assert(model.freqItemsets.sort("items").collect() ===
-        model2.freqItemsets.sort("items").collect())
+      assert(model.freqItemsets.collect().toSet.equals(
+        model2.freqItemsets.collect().toSet))
+      assert(model.associationRules.collect().toSet.equals(
+        model2.associationRules.collect().toSet))
+      
assert(model.setMinConfidence(0.9).associationRules.collect().toSet.equals(
+        model2.setMinConfidence(0.9).associationRules.collect().toSet))
     }
     val fPGrowth = new FPGrowth()
     testEstimatorAndModelReadWrite(fPGrowth, dataset, 
FPGrowthSuite.allParamSettings,
       FPGrowthSuite.allParamSettings, checkModelData)
   }
-
-  test("FPGrowth prediction should not contain duplicates") {
-    // This should generate rule 1 -> 3, 2 -> 3
-    val dataset = spark.createDataFrame(Seq(
-      Array("1", "3"),
-      Array("2", "3")
-    ).map(Tuple1(_))).toDF("items")
-    val model = new FPGrowth().fit(dataset)
-
-    val prediction = model.transform(
-      spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("items")
-    ).first().getAs[Seq[String]]("prediction")
-
-    assert(prediction === Seq("3"))
-  }
 }
 
 object FPGrowthSuite {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to