Repository: spark
Updated Branches:
  refs/heads/master e10b8741d -> 25db51675


[SPARK-16561][MLLIB] fix multivarOnlineSummary min/max bug

## What changes were proposed in this pull request?

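Rename variables to make the code clearer:
nnz => weightSum
weightSum => totalWeightSum

Add a new member vector `nnz` (distinct from the old `nnz`, which is renamed to
`weightSum`) that counts the number of non-zero values in each dimension.
Use this new `nnz` instead of `weightSum` when computing min/max. Previously,
min/max decided whether a dimension had implicit zeros by comparing that
dimension's weight sum against the total weight sum. When sample weights differ
by many orders of magnitude (e.g. 1e10 and 1e-7), floating-point rounding can make
those two sums equal, so the implicit zeros were missed. Comparing integer counts
(`nnz(i) < totalCnt`) fixes these numerical errors in such extreme cases.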

## How was this patch tested?

A new test case was added.

Author: WeichenXu <weichenxu...@outlook.com>

Closes #14216 from WeichenXu123/multivarOnlineSummary.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/25db5167
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/25db5167
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/25db5167

Branch: refs/heads/master
Commit: 25db51675f43048d61ced8221dcb4885cc5143c1
Parents: e10b874
Author: WeichenXu <weichenxu...@outlook.com>
Authored: Sat Jul 23 12:32:30 2016 +0100
Committer: Sean Owen <so...@cloudera.com>
Committed: Sat Jul 23 12:32:30 2016 +0100

----------------------------------------------------------------------
 .../stat/MultivariateOnlineSummarizer.scala     | 63 +++++++++++---------
 .../MultivariateOnlineSummarizerSuite.scala     | 25 ++++++++
 2 files changed, 60 insertions(+), 28 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/25db5167/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
 
b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
index d4de0fd..964f419 100644
--- 
a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -47,9 +47,10 @@ class MultivariateOnlineSummarizer extends 
MultivariateStatisticalSummary with S
   private var currM2: Array[Double] = _
   private var currL1: Array[Double] = _
   private var totalCnt: Long = 0
-  private var weightSum: Double = 0.0
+  private var totalWeightSum: Double = 0.0
   private var weightSquareSum: Double = 0.0
-  private var nnz: Array[Double] = _
+  private var weightSum: Array[Double] = _
+  private var nnz: Array[Long] = _
   private var currMax: Array[Double] = _
   private var currMin: Array[Double] = _
 
@@ -74,7 +75,8 @@ class MultivariateOnlineSummarizer extends 
MultivariateStatisticalSummary with S
       currM2n = Array.ofDim[Double](n)
       currM2 = Array.ofDim[Double](n)
       currL1 = Array.ofDim[Double](n)
-      nnz = Array.ofDim[Double](n)
+      weightSum = Array.ofDim[Double](n)
+      nnz = Array.ofDim[Long](n)
       currMax = Array.fill[Double](n)(Double.MinValue)
       currMin = Array.fill[Double](n)(Double.MaxValue)
     }
@@ -86,7 +88,8 @@ class MultivariateOnlineSummarizer extends 
MultivariateStatisticalSummary with S
     val localCurrM2n = currM2n
     val localCurrM2 = currM2
     val localCurrL1 = currL1
-    val localNnz = nnz
+    val localWeightSum = weightSum
+    val localNumNonzeros = nnz
     val localCurrMax = currMax
     val localCurrMin = currMin
     instance.foreachActive { (index, value) =>
@@ -100,16 +103,17 @@ class MultivariateOnlineSummarizer extends 
MultivariateStatisticalSummary with S
 
         val prevMean = localCurrMean(index)
         val diff = value - prevMean
-        localCurrMean(index) = prevMean + weight * diff / (localNnz(index) + 
weight)
+        localCurrMean(index) = prevMean + weight * diff / 
(localWeightSum(index) + weight)
         localCurrM2n(index) += weight * (value - localCurrMean(index)) * diff
         localCurrM2(index) += weight * value * value
         localCurrL1(index) += weight * math.abs(value)
 
-        localNnz(index) += weight
+        localWeightSum(index) += weight
+        localNumNonzeros(index) += 1
       }
     }
 
-    weightSum += weight
+    totalWeightSum += weight
     weightSquareSum += weight * weight
     totalCnt += 1
     this
@@ -124,17 +128,18 @@ class MultivariateOnlineSummarizer extends 
MultivariateStatisticalSummary with S
    */
   @Since("1.1.0")
   def merge(other: MultivariateOnlineSummarizer): this.type = {
-    if (this.weightSum != 0.0 && other.weightSum != 0.0) {
+    if (this.totalWeightSum != 0.0 && other.totalWeightSum != 0.0) {
       require(n == other.n, s"Dimensions mismatch when merging with another 
summarizer. " +
         s"Expecting $n but got ${other.n}.")
       totalCnt += other.totalCnt
-      weightSum += other.weightSum
+      totalWeightSum += other.totalWeightSum
       weightSquareSum += other.weightSquareSum
       var i = 0
       while (i < n) {
-        val thisNnz = nnz(i)
-        val otherNnz = other.nnz(i)
+        val thisNnz = weightSum(i)
+        val otherNnz = other.weightSum(i)
         val totalNnz = thisNnz + otherNnz
+        val totalCnnz = nnz(i) + other.nnz(i)
         if (totalNnz != 0.0) {
           val deltaMean = other.currMean(i) - currMean(i)
           // merge mean together
@@ -149,18 +154,20 @@ class MultivariateOnlineSummarizer extends 
MultivariateStatisticalSummary with S
           currMax(i) = math.max(currMax(i), other.currMax(i))
           currMin(i) = math.min(currMin(i), other.currMin(i))
         }
-        nnz(i) = totalNnz
+        weightSum(i) = totalNnz
+        nnz(i) = totalCnnz
         i += 1
       }
-    } else if (weightSum == 0.0 && other.weightSum != 0.0) {
+    } else if (totalWeightSum == 0.0 && other.totalWeightSum != 0.0) {
       this.n = other.n
       this.currMean = other.currMean.clone()
       this.currM2n = other.currM2n.clone()
       this.currM2 = other.currM2.clone()
       this.currL1 = other.currL1.clone()
       this.totalCnt = other.totalCnt
-      this.weightSum = other.weightSum
+      this.totalWeightSum = other.totalWeightSum
       this.weightSquareSum = other.weightSquareSum
+      this.weightSum = other.weightSum.clone()
       this.nnz = other.nnz.clone()
       this.currMax = other.currMax.clone()
       this.currMin = other.currMin.clone()
@@ -174,12 +181,12 @@ class MultivariateOnlineSummarizer extends 
MultivariateStatisticalSummary with S
    */
   @Since("1.1.0")
   override def mean: Vector = {
-    require(weightSum > 0, s"Nothing has been added to this summarizer.")
+    require(totalWeightSum > 0, s"Nothing has been added to this summarizer.")
 
     val realMean = Array.ofDim[Double](n)
     var i = 0
     while (i < n) {
-      realMean(i) = currMean(i) * (nnz(i) / weightSum)
+      realMean(i) = currMean(i) * (weightSum(i) / totalWeightSum)
       i += 1
     }
     Vectors.dense(realMean)
@@ -191,11 +198,11 @@ class MultivariateOnlineSummarizer extends 
MultivariateStatisticalSummary with S
    */
   @Since("1.1.0")
   override def variance: Vector = {
-    require(weightSum > 0, s"Nothing has been added to this summarizer.")
+    require(totalWeightSum > 0, s"Nothing has been added to this summarizer.")
 
     val realVariance = Array.ofDim[Double](n)
 
-    val denominator = weightSum - (weightSquareSum / weightSum)
+    val denominator = totalWeightSum - (weightSquareSum / totalWeightSum)
 
     // Sample variance is computed, if the denominator is less than 0, the 
variance is just 0.
     if (denominator > 0.0) {
@@ -203,8 +210,8 @@ class MultivariateOnlineSummarizer extends 
MultivariateStatisticalSummary with S
       var i = 0
       val len = currM2n.length
       while (i < len) {
-        realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) *
-          (weightSum - nnz(i)) / weightSum) / denominator
+        realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * 
weightSum(i) *
+          (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator
         i += 1
       }
     }
@@ -224,9 +231,9 @@ class MultivariateOnlineSummarizer extends 
MultivariateStatisticalSummary with S
    */
   @Since("1.1.0")
   override def numNonzeros: Vector = {
-    require(weightSum > 0, s"Nothing has been added to this summarizer.")
+    require(totalWeightSum > 0, s"Nothing has been added to this summarizer.")
 
-    Vectors.dense(nnz)
+    Vectors.dense(weightSum)
   }
 
   /**
@@ -235,11 +242,11 @@ class MultivariateOnlineSummarizer extends 
MultivariateStatisticalSummary with S
    */
   @Since("1.1.0")
   override def max: Vector = {
-    require(weightSum > 0, s"Nothing has been added to this summarizer.")
+    require(totalWeightSum > 0, s"Nothing has been added to this summarizer.")
 
     var i = 0
     while (i < n) {
-      if ((nnz(i) < weightSum) && (currMax(i) < 0.0)) currMax(i) = 0.0
+      if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0
       i += 1
     }
     Vectors.dense(currMax)
@@ -251,11 +258,11 @@ class MultivariateOnlineSummarizer extends 
MultivariateStatisticalSummary with S
    */
   @Since("1.1.0")
   override def min: Vector = {
-    require(weightSum > 0, s"Nothing has been added to this summarizer.")
+    require(totalWeightSum > 0, s"Nothing has been added to this summarizer.")
 
     var i = 0
     while (i < n) {
-      if ((nnz(i) < weightSum) && (currMin(i) > 0.0)) currMin(i) = 0.0
+      if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0
       i += 1
     }
     Vectors.dense(currMin)
@@ -267,7 +274,7 @@ class MultivariateOnlineSummarizer extends 
MultivariateStatisticalSummary with S
    */
   @Since("1.2.0")
   override def normL2: Vector = {
-    require(weightSum > 0, s"Nothing has been added to this summarizer.")
+    require(totalWeightSum > 0, s"Nothing has been added to this summarizer.")
 
     val realMagnitude = Array.ofDim[Double](n)
 
@@ -286,7 +293,7 @@ class MultivariateOnlineSummarizer extends 
MultivariateStatisticalSummary with S
    */
   @Since("1.2.0")
   override def normL1: Vector = {
-    require(weightSum > 0, s"Nothing has been added to this summarizer.")
+    require(totalWeightSum > 0, s"Nothing has been added to this summarizer.")
 
     Vectors.dense(currL1)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/25db5167/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
index b6d41db..165a3f3 100644
--- 
a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
@@ -245,4 +245,29 @@ class MultivariateOnlineSummarizerSuite extends 
SparkFunSuite {
       absTol 1E-8, "normL2 mismatch")
     assert(summarizer.normL1 ~== Vectors.dense(0.21, 0.4265, 0.61) absTol 
1E-10, "normL1 mismatch")
   }
+
+  test("test min/max with weighted samples (SPARK-16561)") {
+    val summarizer1 = new MultivariateOnlineSummarizer()
+      .add(Vectors.dense(10.0, -10.0), 1e10)
+      .add(Vectors.dense(0.0, 0.0), 1e-7)
+
+    val summarizer2 = new MultivariateOnlineSummarizer()
+    summarizer2.add(Vectors.dense(10.0, -10.0), 1e10)
+    for (i <- 1 to 100) {
+      summarizer2.add(Vectors.dense(0.0, 0.0), 1e-7)
+    }
+
+    val summarizer3 = new MultivariateOnlineSummarizer()
+    for (i <- 1 to 100) {
+      summarizer3.add(Vectors.dense(0.0, 0.0), 1e-7)
+    }
+    summarizer3.add(Vectors.dense(10.0, -10.0), 1e10)
+
+    assert(summarizer1.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14)
+    assert(summarizer1.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14)
+    assert(summarizer2.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14)
+    assert(summarizer2.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14)
+    assert(summarizer3.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14)
+    assert(summarizer3.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to