Repository: spark
Updated Branches:
  refs/heads/master 5e96a57b2 -> d4a637cd4


[SPARK-19940][ML][MINOR] FPGrowthModel.transform should skip duplicated items

## What changes were proposed in this pull request?

This commit moves `distinct` to its intended place to avoid duplicated
predictions and adds a unit test covering the issue.

## How was this patch tested?

Unit tests.

Author: zero323 <zero...@users.noreply.github.com>

Closes #17283 from zero323/SPARK-19940.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d4a637cd
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d4a637cd
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d4a637cd

Branch: refs/heads/master
Commit: d4a637cd46b6dd5cc71ea17a55c4a26186e592c7
Parents: 5e96a57
Author: zero323 <zero...@users.noreply.github.com>
Authored: Tue Mar 14 07:34:44 2017 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Tue Mar 14 07:34:44 2017 -0700

----------------------------------------------------------------------
 .../main/scala/org/apache/spark/ml/fpm/FPGrowth.scala |  4 ++--
 .../scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala | 14 ++++++++++++++
 2 files changed, 16 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d4a637cd/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
index 417968d..fa39dd9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -245,10 +245,10 @@ class FPGrowthModel private[ml] (
             rule._2.filter(item => !itemset.contains(item))
           } else {
             Seq.empty
-          })
+          }).distinct
       } else {
         Seq.empty
-      }.distinct }, dt)
+      }}, dt)
     dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d4a637cd/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
index 076d55c..910d4b0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
@@ -103,6 +103,20 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
       FPGrowthSuite.allParamSettings, checkModelData)
   }
 
+  test("FPGrowth prediction should not contain duplicates") {
+    // This should generate rule 1 -> 3, 2 -> 3
+    val dataset = spark.createDataFrame(Seq(
+      Array("1", "3"),
+      Array("2", "3")
+    ).map(Tuple1(_))).toDF("features")
+    val model = new FPGrowth().fit(dataset)
+
+    val prediction = model.transform(
+      spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("features")
+    ).first().getAs[Seq[String]]("prediction")
+
+    assert(prediction === Seq("3"))
+  }
 }
 
 object FPGrowthSuite {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to