Repository: spark
Updated Branches:
  refs/heads/master 07508bd01 -> 23405f324


[SPARK-15153][ML][SPARKR] Fix SparkR spark.naiveBayes error when label is numeric type

## What changes were proposed in this pull request?
Fix SparkR ```spark.naiveBayes``` error when response variable of dataset is numeric type.
See details and how to reproduce this bug at [SPARK-15153](https://issues.apache.org/jira/browse/SPARK-15153).

## How was this patch tested?
Add unit test.

Author: Yanbo Liang <yblia...@gmail.com>

Closes #15431 from yanboliang/spark-15153-2.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/23405f32
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/23405f32
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/23405f32

Branch: refs/heads/master
Commit: 23405f324a8089f86ebcbede9bb32944137508e8
Parents: 07508bd
Author: Yanbo Liang <yblia...@gmail.com>
Authored: Tue Oct 11 12:41:35 2016 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Tue Oct 11 12:41:35 2016 -0700

----------------------------------------------------------------------
 R/pkg/inst/tests/testthat/test_mllib.R                  | 10 ++++++++++
 .../scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala |  1 +
 2 files changed, 11 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/23405f32/R/pkg/inst/tests/testthat/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index a1eaaf2..c993157 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -481,6 +481,16 @@ test_that("spark.naiveBayes", {
     expect_error(m <- e1071::naiveBayes(Survived ~ ., data = t1), NA)
     expect_equal(as.character(predict(m, t1[1, ])), "Yes")
   }
+
+  # Test numeric response variable
+  t1$NumericSurvived <- ifelse(t1$Survived == "No", 0, 1)
+  t2 <- t1[-4]
+  df <- suppressWarnings(createDataFrame(t2))
+  m <- spark.naiveBayes(df, NumericSurvived ~ ., smoothing = 0.0)
+  s <- summary(m)
+  expect_equal(as.double(s$apriori[1, 1]), 0.5833333, tolerance = 1e-6)
+  expect_equal(sum(s$apriori), 1)
+  expect_equal(as.double(s$tables[1, "Age_Adult"]), 0.5714286, tolerance = 1e-6)
 })
 
 test_that("spark.survreg", {

http://git-wip-us.apache.org/repos/asf/spark/blob/23405f32/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
index d1a39fe..4fdab2d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
@@ -59,6 +59,7 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] {
   def fit(formula: String, data: DataFrame, smoothing: Double): NaiveBayesWrapper = {
     val rFormula = new RFormula()
       .setFormula(formula)
+      .setForceIndexLabel(true)
     RWrapperUtils.checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)
     // get labels and feature names from output schema


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org