Repository: spark
Updated Branches:
  refs/heads/master b678e465a -> 7aeb20be7


[MINOR][ML] Avoid 2D array flatten in NB training.

## What changes were proposed in this pull request?
Avoid 2D array flatten in ```NaiveBayes``` training, since the ```flatten``` method can be
expensive: it allocates a second array and copies all of the data into it.
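To make the cost concrete, here is a minimal sketch in plain Scala (no Spark dependency) contrasting the two approaches. `FlatThetaSketch`, `viaFlatten`, `direct`, and the `cell` function are hypothetical stand-ins for the values computed in `NaiveBayes` training. Both methods produce the same row-major buffer, but the 2D version pays for one extra allocation and a full copy in `flatten`.

```scala
// Hypothetical sketch: `cell(i, j)` stands in for
// math.log(sumTermFreqs(j) + lambda) - thetaLogDenom in the real code.
object FlatThetaSketch {
  // Old approach: one inner array per label; `flatten` then allocates a
  // new numLabels * numFeatures array and copies every element into it.
  def viaFlatten(numLabels: Int, numFeatures: Int, cell: (Int, Int) => Double): Array[Double] = {
    val thetaArrays = Array.fill[Double](numLabels, numFeatures)(0.0)
    var i = 0
    while (i < numLabels) {
      var j = 0
      while (j < numFeatures) {
        thetaArrays(i)(j) = cell(i, j)
        j += 1
      }
      i += 1
    }
    thetaArrays.flatten
  }

  // New approach: allocate the flat buffer once and write each cell at its
  // row-major offset i * numFeatures + j, so no copy is needed afterwards.
  def direct(numLabels: Int, numFeatures: Int, cell: (Int, Int) => Double): Array[Double] = {
    val thetaArray = new Array[Double](numLabels * numFeatures)
    var i = 0
    while (i < numLabels) {
      var j = 0
      while (j < numFeatures) {
        thetaArray(i * numFeatures + j) = cell(i, j)
        j += 1
      }
      i += 1
    }
    thetaArray
  }
}
```

For example, with `cell = (i, j) => i * 10.0 + j`, `numLabels = 3`, and `numFeatures = 4`, both methods return the same 12 elements, and `direct(...)(1 * 4 + 2)` is `12.0`. Note that `new Array[Double](n)` is already zero-initialized on the JVM, which is why the patch can also drop `Array.fill(...)(0.0)` for `piArray`.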

## How was this patch tested?
Existing tests.

Author: Yanbo Liang <yblia...@gmail.com>

Closes #15359 from yanboliang/nb-theta.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7aeb20be
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7aeb20be
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7aeb20be

Branch: refs/heads/master
Commit: 7aeb20be7e999523784aca7be1a7c9c99dec125e
Parents: b678e46
Author: Yanbo Liang <yblia...@gmail.com>
Authored: Wed Oct 5 23:03:09 2016 -0700
Committer: Yanbo Liang <yblia...@gmail.com>
Committed: Wed Oct 5 23:03:09 2016 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/classification/NaiveBayes.scala      | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7aeb20be/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 6775745..e565a6f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -176,8 +176,8 @@ class NaiveBayes @Since("1.5.0") (
     val numLabels = aggregated.length
     val numDocuments = aggregated.map(_._2._1).sum
 
-    val piArray = Array.fill[Double](numLabels)(0.0)
-    val thetaArrays = Array.fill[Double](numLabels, numFeatures)(0.0)
+    val piArray = new Array[Double](numLabels)
+    val thetaArray = new Array[Double](numLabels * numFeatures)
 
     val lambda = $(smoothing)
     val piLogDenom = math.log(numDocuments + numLabels * lambda)
@@ -193,14 +193,14 @@ class NaiveBayes @Since("1.5.0") (
       }
       var j = 0
       while (j < numFeatures) {
-        thetaArrays(i)(j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom
+        thetaArray(i * numFeatures + j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom
         j += 1
       }
       i += 1
     }
 
     val pi = Vectors.dense(piArray)
-    val theta = new DenseMatrix(numLabels, thetaArrays(0).length, thetaArrays.flatten, true)
+    val theta = new DenseMatrix(numLabels, numFeatures, thetaArray, true)
     new NaiveBayesModel(uid, pi, theta)
   }
 


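The new code relies on the flat buffer being row-major, which is why `DenseMatrix` is still constructed with its last argument (`isTransposed`) set to `true`. In Spark's `ml.linalg.DenseMatrix`, `isTransposed = true` means `values` is stored in row-major order; otherwise it is column-major. The sketch below illustrates that index mapping in plain Scala. `LayoutSketch`, `offset`, `cellAt`, and `toColumnMajor` are hypothetical helpers, not Spark APIs.

```scala
// Hypothetical helpers illustrating how a dense matrix with an
// `isTransposed` flag maps (row i, column j) to an offset in `values`.
object LayoutSketch {
  def offset(numRows: Int, numCols: Int, isTransposed: Boolean, i: Int, j: Int): Int =
    if (isTransposed) i * numCols + j // row-major: rows are contiguous
    else i + j * numRows              // column-major: columns are contiguous

  def cellAt(values: Array[Double], numRows: Int, numCols: Int,
             isTransposed: Boolean, i: Int, j: Int): Double =
    values(offset(numRows, numCols, isTransposed, i, j))

  // Reorder a row-major buffer into column-major order. This extra copy is
  // what would be needed if the matrix were built with isTransposed = false.
  def toColumnMajor(rowMajor: Array[Double], numRows: Int, numCols: Int): Array[Double] = {
    val out = new Array[Double](rowMajor.length)
    var i = 0
    while (i < numRows) {
      var j = 0
      while (j < numCols) {
        out(i + j * numRows) = rowMajor(i * numCols + j)
        j += 1
      }
      i += 1
    }
    out
  }
}
```

Passing the row-major `thetaArray` with `isTransposed = true` lets the matrix use the buffer as-is, so the patch avoids both the `flatten` copy and any reordering copy.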