Github user smurching commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19381#discussion_r148708939

--- Diff: mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala ---
@@ -165,6 +165,35 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
       Vector, NaiveBayesModel](model, testDataset)
   }

+  test("prediction on single instance") {
+    val nPoints = 1000
+    val piArray = Array(0.5, 0.1, 0.4).map(math.log)
+    val thetaArray = Array(
+      Array(0.70, 0.10, 0.10, 0.10), // label 0
+      Array(0.10, 0.70, 0.10, 0.10), // label 1
+      Array(0.10, 0.10, 0.70, 0.10)  // label 2
+    ).map(_.map(math.log))
+    val pi = Vectors.dense(piArray)
+    val theta = new DenseMatrix(3, 4, thetaArray.flatten, true)
+
+    val testDataset =
+      generateNaiveBayesInput(piArray, thetaArray, nPoints, seed, "multinomial").toDF()
+    val nb = new NaiveBayes().setSmoothing(1.0).setModelType("multinomial")
+    val model = nb.fit(testDataset)
+
+    validateModelFit(pi, theta, model)
--- End diff --

Do we need lines 184-186? They seem unrelated to what we want to test (that `predict` produces the same result as `transform` on a single instance). Similarly, I don't think we need to create `piArray`, `thetaArray`, `pi`, `theta`, etc.; this test should just fit a model on a dataset and compare the fitted model's `predict` and `transform` outputs.
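For illustration, the simplified check the reviewer describes could look roughly like the following self-contained sketch. This is an assumption, not the patch itself: it presumes the single-instance `predict` method this PR exposes on the fitted model, and the `local[1]` session and two-row toy DataFrame are made up for the example.

```scala
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.SparkSession

object SinglePredictSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("single-instance-predict-sketch")
      .getOrCreate()
    import spark.implicits._

    // A tiny hand-written dataset is enough; no piArray/thetaArray generator needed.
    val df = Seq(
      (0.0, Vectors.dense(1.0, 0.0)),
      (1.0, Vectors.dense(0.0, 1.0))
    ).toDF("label", "features")

    val model = new NaiveBayes().fit(df)

    // Compare transform's "prediction" column against predict on each single row.
    model.transform(df).select("features", "prediction").collect().foreach { row =>
      val features = row.getAs[Vector](0)
      val transformPrediction = row.getDouble(1)
      assert(model.predict(features) == transformPrediction)
    }

    spark.stop()
  }
}
```

The point of the design is that the test only asserts agreement between the two code paths, so any dataset the model can fit on will do; validating the fitted coefficients belongs in a separate test.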