Repository: spark
Updated Branches:
  refs/heads/master faeb41de2 -> 84324fbcb


[SPARK-4355][MLLIB] fix OnlineSummarizer.merge when other.mean is zero

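The bug: `merge` only updated the mean when `other.currMean(i) != 0.0`. Merging a
summarizer whose column mean happens to be exactly zero (e.g. values 1.0 and -1.0)
therefore left the merged mean unchanged. The fix always merges the mean together with
the other statistics whenever the combined nonzero count is positive. I also did some
code clean-up. @dbtsai, I moved the inline `update` helper into a private method `add`
of `MultivariateOnlineSummarizer`. I don't expect a performance regression, but it
would be great if you have some time to test.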

Author: Xiangrui Meng <m...@databricks.com>

Closes #3220 from mengxr/SPARK-4355 and squashes the following commits:

5ef601f [Xiangrui Meng] fix OnlineSummarizer.merge when other.mean is zero and some code clean-up


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/84324fbc
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/84324fbc
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/84324fbc

Branch: refs/heads/master
Commit: 84324fbcb987db6e10e435f463eacace1bae43e2
Parents: faeb41d
Author: Xiangrui Meng <m...@databricks.com>
Authored: Wed Nov 12 01:50:11 2014 -0800
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Wed Nov 12 01:50:11 2014 -0800

----------------------------------------------------------------------
 .../stat/MultivariateOnlineSummarizer.scala     | 85 +++++++++-----------
 .../MultivariateOnlineSummarizerSuite.scala     | 11 +++
 2 files changed, 51 insertions(+), 45 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/84324fbc/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
index fab7c44..654479a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -50,6 +50,29 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
   private var currMin: BDV[Double] = _
 
   /**
+   * Adds input value to position i.
+   */
+  private[this] def add(i: Int, value: Double) = {
+    if (value != 0.0) {
+      if (currMax(i) < value) {
+        currMax(i) = value
+      }
+      if (currMin(i) > value) {
+        currMin(i) = value
+      }
+
+      val prevMean = currMean(i)
+      val diff = value - prevMean
+      currMean(i) = prevMean + diff / (nnz(i) + 1.0)
+      currM2n(i) += (value - currMean(i)) * diff
+      currM2(i) += value * value
+      currL1(i) += math.abs(value)
+
+      nnz(i) += 1.0
+    }
+  }
+
+  /**
    * Add a new sample to this summarizer, and update the statistical summary.
    *
    * @param sample The sample in dense/sparse vector format to be added into 
this summarizer.
@@ -72,37 +95,18 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
     require(n == sample.size, s"Dimensions mismatch when adding new sample." +
       s" Expecting $n but got ${sample.size}.")
 
-    @inline def update(i: Int, value: Double) = {
-      if (value != 0.0) {
-        if (currMax(i) < value) {
-          currMax(i) = value
-        }
-        if (currMin(i) > value) {
-          currMin(i) = value
-        }
-
-        val tmpPrevMean = currMean(i)
-        currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0)
-        currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean)
-        currM2(i) += value * value
-        currL1(i) += math.abs(value)
-
-        nnz(i) += 1.0
-      }
-    }
-
     sample match {
       case dv: DenseVector => {
         var j = 0
         while (j < dv.size) {
-          update(j, dv.values(j))
+          add(j, dv.values(j))
           j += 1
         }
       }
       case sv: SparseVector =>
         var j = 0
         while (j < sv.indices.size) {
-          update(sv.indices(j), sv.values(j))
+          add(sv.indices(j), sv.values(j))
           j += 1
         }
       case v => throw new IllegalArgumentException("Do not support vector type 
" + v.getClass)
@@ -124,37 +128,28 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
       require(n == other.n, s"Dimensions mismatch when merging with another 
summarizer. " +
         s"Expecting $n but got ${other.n}.")
       totalCnt += other.totalCnt
-      val deltaMean: BDV[Double] = currMean - other.currMean
       var i = 0
       while (i < n) {
-        // merge mean together
-        if (other.currMean(i) != 0.0) {
-          currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * 
other.nnz(i)) /
-            (nnz(i) + other.nnz(i))
-        }
-        // merge m2n together
-        if (nnz(i) + other.nnz(i) != 0.0) {
-          currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * 
nnz(i) * other.nnz(i) /
-            (nnz(i) + other.nnz(i))
-        }
-        // merge m2 together
-        if (nnz(i) + other.nnz(i) != 0.0) {
+        val thisNnz = nnz(i)
+        val otherNnz = other.nnz(i)
+        val totalNnz = thisNnz + otherNnz
+        if (totalNnz != 0.0) {
+          val deltaMean = other.currMean(i) - currMean(i)
+          // merge mean together
+          currMean(i) += deltaMean * otherNnz / totalNnz
+          // merge m2n together
+          currM2n(i) += other.currM2n(i) + deltaMean * deltaMean * thisNnz * 
otherNnz / totalNnz
+          // merge m2 together
           currM2(i) += other.currM2(i)
-        }
-        // merge l1 together
-        if (nnz(i) + other.nnz(i) != 0.0) {
+          // merge l1 together
           currL1(i) += other.currL1(i)
+          // merge max and min
+          currMax(i) = math.max(currMax(i), other.currMax(i))
+          currMin(i) = math.min(currMin(i), other.currMin(i))
         }
-
-        if (currMax(i) < other.currMax(i)) {
-          currMax(i) = other.currMax(i)
-        }
-        if (currMin(i) > other.currMin(i)) {
-          currMin(i) = other.currMin(i)
-        }
+        nnz(i) = totalNnz
         i += 1
       }
-      nnz += other.nnz
     } else if (totalCnt == 0 && other.totalCnt != 0) {
       this.n = other.n
       this.currMean = other.currMean.copy

http://git-wip-us.apache.org/repos/asf/spark/blob/84324fbc/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
index 1e94152..23b0eec 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
@@ -208,4 +208,15 @@ class MultivariateOnlineSummarizerSuite extends FunSuite {
 
     assert(summarizer2.variance ~== Vectors.dense(0, 0, 0) absTol 1E-5, 
"variance mismatch")
   }
+
+  test("merging summarizer when one side has zero mean (SPARK-4355)") {
+    val s0 = new MultivariateOnlineSummarizer()
+      .add(Vectors.dense(2.0))
+      .add(Vectors.dense(2.0))
+    val s1 = new MultivariateOnlineSummarizer()
+      .add(Vectors.dense(1.0))
+      .add(Vectors.dense(-1.0))
+    s0.merge(s1)
+    assert(s0.mean(0) ~== 1.0 absTol 1e-14)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to