Repository: spark Updated Branches: refs/heads/master b678e465a -> 7aeb20be7
[MINOR][ML] Avoid 2D array flatten in NB training. ## What changes were proposed in this pull request? Avoid flattening a 2D array in `NaiveBayes` training, since the `flatten` method can be expensive (it creates another array and copies the data into it). ## How was this patch tested? Existing tests. Author: Yanbo Liang <yblia...@gmail.com> Closes #15359 from yanboliang/nb-theta. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7aeb20be Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7aeb20be Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7aeb20be Branch: refs/heads/master Commit: 7aeb20be7e999523784aca7be1a7c9c99dec125e Parents: b678e46 Author: Yanbo Liang <yblia...@gmail.com> Authored: Wed Oct 5 23:03:09 2016 -0700 Committer: Yanbo Liang <yblia...@gmail.com> Committed: Wed Oct 5 23:03:09 2016 -0700 ---------------------------------------------------------------------- .../org/apache/spark/ml/classification/NaiveBayes.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/7aeb20be/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 6775745..e565a6f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -176,8 +176,8 @@ class NaiveBayes @Since("1.5.0") ( val numLabels = aggregated.length val numDocuments = aggregated.map(_._2._1).sum - val piArray = Array.fill[Double](numLabels)(0.0) - val thetaArrays = Array.fill[Double](numLabels, numFeatures)(0.0) + val piArray = new 
Array[Double](numLabels) + val thetaArray = new Array[Double](numLabels * numFeatures) val lambda = $(smoothing) val piLogDenom = math.log(numDocuments + numLabels * lambda) @@ -193,14 +193,14 @@ class NaiveBayes @Since("1.5.0") ( } var j = 0 while (j < numFeatures) { - thetaArrays(i)(j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom + thetaArray(i * numFeatures + j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom j += 1 } i += 1 } val pi = Vectors.dense(piArray) - val theta = new DenseMatrix(numLabels, thetaArrays(0).length, thetaArrays.flatten, true) + val theta = new DenseMatrix(numLabels, numFeatures, thetaArray, true) new NaiveBayesModel(uid, pi, theta) } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org