Repository: spark
Updated Branches:
  refs/heads/master 70beb808e -> 3336c7b14


[SPARK-8559] [MLLIB] Support Association Rule Generation

Distributed generation of single-consequent association rules from an RDD of 
frequent itemsets. Test results are verified against `R`'s implementation of Apriori in 
[arules](http://cran.r-project.org/web/packages/arules/index.html).

Author: Feynman Liang <fli...@databricks.com>

Closes #7005 from feynmanliang/fp-association-rules-distributed and squashes 
the following commits:

466ced0 [Feynman Liang] Refactor AR generation impl
73c1cff [Feynman Liang] Make rule attributes public, remove numTransactions 
from FreqItemset
80f63ff [Feynman Liang] Change default confidence and optimize imports
04cf5b5 [Feynman Liang] Code review with @mengxr, add R to tests
0cc1a6a [Feynman Liang] Java compatibility test
f3c14b5 [Feynman Liang] Fix MiMa test
764375e [Feynman Liang] Fix tests
1187307 [Feynman Liang] Almost working tests
b20779b [Feynman Liang] Working implementation
5395c4e [Feynman Liang] Fix imports
2d34405 [Feynman Liang] Partial implementation of distributed ar
83ace4b [Feynman Liang] Local rule generation without pruning complete
69c2c87 [Feynman Liang] Working local implementation, now to parallelize../..
4e1ec9a [Feynman Liang] Pull FreqItemsets out, refactor type param, tests
69ccedc [Feynman Liang] First implementation of association rule generation


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3336c7b1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3336c7b1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3336c7b1

Branch: refs/heads/master
Commit: 3336c7b148ad543d1f9b64ca2b559ea04930f5be
Parents: 70beb80
Author: Feynman Liang <fli...@databricks.com>
Authored: Tue Jul 7 11:34:30 2015 -0700
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Tue Jul 7 11:34:30 2015 -0700

----------------------------------------------------------------------
 .../spark/mllib/fpm/AssociationRules.scala      | 108 +++++++++++++++++++
 .../org/apache/spark/mllib/fpm/FPGrowth.scala   |   2 +-
 .../mllib/fpm/JavaAssociationRulesSuite.java    |  58 ++++++++++
 .../spark/mllib/fpm/JavaFPGrowthSuite.java      |   5 +-
 .../spark/mllib/fpm/AssociationRulesSuite.scala |  89 +++++++++++++++
 5 files changed, 258 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3336c7b1/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala
new file mode 100644
index 0000000..4a0f842
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.mllib.fpm
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
+import org.apache.spark.mllib.fpm.AssociationRules.Rule
+import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset
+import org.apache.spark.rdd.RDD
+
+/**
+ * :: Experimental ::
+ *
+ * Generates association rules from an [[RDD[FreqItemset[Item]]]]. This method 
only generates
+ * association rules which have a single item as the consequent.
+ */
+@Experimental
+class AssociationRules private (
+    private var minConfidence: Double) extends Logging with Serializable {
+
+  /**
+   * Constructs a default instance with default parameters {minConfidence = 
0.8}.
+   */
+  def this() = this(0.8)
+
+  /**
+   * Sets the minimal confidence (default: `0.8`).
+   */
+  def setMinConfidence(minConfidence: Double): this.type = {
+    this.minConfidence = minConfidence
+    this
+  }
+
+  /**
+   * Computes the association rules with confidence at or above [[minConfidence]].
+   * @param freqItemsets frequent itemset model obtained from [[FPGrowth]]
+   * @return an [[RDD[Rule[Item]]]] containing the association rules.
+   */
+  def run[Item: ClassTag](freqItemsets: RDD[FreqItemset[Item]]): 
RDD[Rule[Item]] = {
+    // For candidate rule X => Y, generate (X, (Y, freq(X union Y)))
+    val candidates = freqItemsets.flatMap { itemset =>
+      val items = itemset.items
+      items.flatMap { item =>
+        items.partition(_ == item) match {
+          case (consequent, antecedent) if !antecedent.isEmpty =>
+            Some((antecedent.toSeq, (consequent.toSeq, itemset.freq)))
+          case _ => None
+        }
+      }
+    }
+
+    // Join to get (X, ((Y, freq(X union Y)), freq(X))), generate rules, and 
filter by confidence
+    candidates.join(freqItemsets.map(x => (x.items.toSeq, x.freq)))
+      .map { case (antecendent, ((consequent, freqUnion), freqAntecedent)) =>
+      new Rule(antecendent.toArray, consequent.toArray, freqUnion, 
freqAntecedent)
+    }.filter(_.confidence >= minConfidence)
+  }
+
+  def run[Item](freqItemsets: JavaRDD[FreqItemset[Item]]): JavaRDD[Rule[Item]] 
= {
+    val tag = fakeClassTag[Item]
+    run(freqItemsets.rdd)(tag)
+  }
+}
+
+object AssociationRules {
+
+  /**
+   * :: Experimental ::
+   *
+   * An association rule between sets of items.
+   * @param antecedent hypotheses of the rule
+   * @param consequent conclusion of the rule
+   * @tparam Item item type
+   */
+  @Experimental
+  class Rule[Item] private[mllib] (
+      val antecedent: Array[Item],
+      val consequent: Array[Item],
+      freqUnion: Double,
+      freqAntecedent: Double) extends Serializable {
+
+    def confidence: Double = freqUnion.toDouble / freqAntecedent
+
+    require(antecedent.toSet.intersect(consequent.toSet).isEmpty, {
+      val sharedItems = antecedent.toSet.intersect(consequent.toSet)
+      s"A valid association rule must have disjoint antecedent and " +
+        s"consequent but ${sharedItems} is present in both."
+    })
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/3336c7b1/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
index efa8459..0da59e8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -28,7 +28,7 @@ import org.apache.spark.{HashPartitioner, Logging, 
Partitioner, SparkException}
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
-import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset
+import org.apache.spark.mllib.fpm.FPGrowth._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
 

http://git-wip-us.apache.org/repos/asf/spark/blob/3336c7b1/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java 
b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java
new file mode 100644
index 0000000..b3815ae
--- /dev/null
+++ 
b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.mllib.fpm;
+
+import java.io.Serializable;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import com.google.common.collect.Lists;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset;
+
+
+public class JavaAssociationRulesSuite implements Serializable {
+  private transient JavaSparkContext sc;
+
+  @Before
+  public void setUp() {
+    sc = new JavaSparkContext("local", "JavaFPGrowth");
+  }
+
+  @After
+  public void tearDown() {
+    sc.stop();
+    sc = null;
+  }
+
+  @Test
+  public void runAssociationRules() {
+
+    @SuppressWarnings("unchecked")
+    JavaRDD<FPGrowth.FreqItemset<String>> freqItemsets = 
sc.parallelize(Lists.newArrayList(
+      new FreqItemset<String>(new String[] {"a"}, 15L),
+      new FreqItemset<String>(new String[] {"b"}, 35L),
+      new FreqItemset<String>(new String[] {"a", "b"}, 18L)
+    ));
+
+    JavaRDD<AssociationRules.Rule<String>> results = (new 
AssociationRules()).run(freqItemsets);
+  }
+}
+

http://git-wip-us.apache.org/repos/asf/spark/blob/3336c7b1/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java 
b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
index bd0edf2..9ce2c52 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
@@ -29,7 +29,6 @@ import static org.junit.Assert.*;
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset;
 
 public class JavaFPGrowthSuite implements Serializable {
   private transient JavaSparkContext sc;
@@ -62,10 +61,10 @@ public class JavaFPGrowthSuite implements Serializable {
       .setNumPartitions(2)
       .run(rdd);
 
-    List<FreqItemset<String>> freqItemsets = 
model.freqItemsets().toJavaRDD().collect();
+    List<FPGrowth.FreqItemset<String>> freqItemsets = 
model.freqItemsets().toJavaRDD().collect();
     assertEquals(18, freqItemsets.size());
 
-    for (FreqItemset<String> itemset: freqItemsets) {
+    for (FPGrowth.FreqItemset<String> itemset: freqItemsets) {
       // Test return types.
       List<String> items = itemset.javaItems();
       long freq = itemset.freq();

http://git-wip-us.apache.org/repos/asf/spark/blob/3336c7b1/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala 
b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala
new file mode 100644
index 0000000..77a2773
--- /dev/null
+++ 
b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.mllib.fpm
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  test("association rules using String type") {
+    val freqItemsets = sc.parallelize(Seq(
+      (Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), 
(Set("y"), 3L),
+      (Set("r"), 3L),
+      (Set("x", "z"), 3L), (Set("t", "y"), 3L), (Set("t", "x"), 3L), (Set("s", 
"x"), 3L),
+      (Set("y", "x"), 3L), (Set("y", "z"), 3L), (Set("t", "z"), 3L),
+      (Set("y", "x", "z"), 3L), (Set("t", "x", "z"), 3L), (Set("t", "y", "z"), 
3L),
+      (Set("t", "y", "x"), 3L),
+      (Set("t", "y", "x", "z"), 3L)
+    ).map {
+      case (items, freq) => new FPGrowth.FreqItemset(items.toArray, freq)
+    })
+
+    val ar = new AssociationRules()
+
+    val results1 = ar
+      .setMinConfidence(0.9)
+      .run(freqItemsets)
+      .collect()
+
+    /* Verify results using the `R` code:
+       transactions = as(sapply(
+         list("r z h k p",
+              "z y x w v u t s",
+              "s x o n r",
+              "x z y m t s q e",
+              "z",
+              "x z y r q t p"),
+         FUN=function(x) strsplit(x," ",fixed=TRUE)),
+         "transactions")
+       ars = apriori(transactions,
+                     parameter = list(support = 0.0, confidence = 0.5, 
target="rules", minlen=2))
+       arsDF = as(ars, "data.frame")
+       arsDF$support = arsDF$support * length(transactions)
+       names(arsDF)[names(arsDF) == "support"] = "freq"
+       > nrow(arsDF)
+       [1] 23
+       > sum(arsDF$confidence == 1)
+       [1] 23
+     */
+    assert(results1.size === 23)
+    assert(results1.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 
23)
+
+    val results2 = ar
+      .setMinConfidence(0)
+      .run(freqItemsets)
+      .collect()
+
+    /* Verify results using the `R` code:
+       ars = apriori(transactions,
+                  parameter = list(support = 0.5, confidence = 0.5, 
target="rules", minlen=2))
+       arsDF = as(ars, "data.frame")
+       arsDF$support = arsDF$support * length(transactions)
+       names(arsDF)[names(arsDF) == "support"] = "freq"
+       nrow(arsDF)
+       sum(arsDF$confidence == 1)
+       > nrow(arsDF)
+       [1] 30
+       > sum(arsDF$confidence == 1)
+       [1] 23
+     */
+    assert(results2.size === 30)
+    assert(results2.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 
23)
+  }
+}
+


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to