Repository: spark Updated Branches: refs/heads/master 5e96a57b2 -> d4a637cd4
[SPARK-19940][ML][MINOR] FPGrowthModel.transform should skip duplicated items ## What changes were proposed in this pull request? This commit moves `distinct` to its intended place to avoid duplicated predictions and adds a unit test covering the issue. ## How was this patch tested? Unit tests. Author: zero323 <zero...@users.noreply.github.com> Closes #17283 from zero323/SPARK-19940. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d4a637cd Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d4a637cd Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d4a637cd Branch: refs/heads/master Commit: d4a637cd46b6dd5cc71ea17a55c4a26186e592c7 Parents: 5e96a57 Author: zero323 <zero...@users.noreply.github.com> Authored: Tue Mar 14 07:34:44 2017 -0700 Committer: Joseph K. Bradley <jos...@databricks.com> Committed: Tue Mar 14 07:34:44 2017 -0700 ---------------------------------------------------------------------- .../main/scala/org/apache/spark/ml/fpm/FPGrowth.scala | 4 ++-- .../scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/d4a637cd/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index 417968d..fa39dd9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -245,10 +245,10 @@ class FPGrowthModel private[ml] ( rule._2.filter(item => !itemset.contains(item)) } else { Seq.empty - }) + }).distinct } else { Seq.empty - }.distinct }, dt) + }}, dt) dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) 
} http://git-wip-us.apache.org/repos/asf/spark/blob/d4a637cd/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala index 076d55c..910d4b0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala @@ -103,6 +103,20 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul FPGrowthSuite.allParamSettings, checkModelData) } + test("FPGrowth prediction should not contain duplicates") { + // This should generate rule 1 -> 3, 2 -> 3 + val dataset = spark.createDataFrame(Seq( + Array("1", "3"), + Array("2", "3") + ).map(Tuple1(_))).toDF("features") + val model = new FPGrowth().fit(dataset) + + val prediction = model.transform( + spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("features") + ).first().getAs[Seq[String]]("prediction") + + assert(prediction === Seq("3")) + } } object FPGrowthSuite { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org