Repository: spark Updated Branches: refs/heads/master e10b8741d -> 25db51675
[SPARK-16561][MLLIB] fix multivarOnlineSummary min/max bug ## What changes were proposed in this pull request? Rename variables to make the code clearer: nnz => weightSum, weightSum => totalWeightSum. Add a new member vector `nnz` (distinct from the `nnz` in the previous code, which is now renamed to `weightSum`) that counts the number of non-zero values in each dimension. Use this new `nnz` instead of `weightSum` when calculating min/max, which fixes numerical errors in some extreme cases (e.g. when a very small sample weight is lost to floating-point rounding against a very large weight sum). ## How was this patch tested? A new test case was added. Author: WeichenXu <weichenxu...@outlook.com> Closes #14216 from WeichenXu123/multivarOnlineSummary. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/25db5167 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/25db5167 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/25db5167 Branch: refs/heads/master Commit: 25db51675f43048d61ced8221dcb4885cc5143c1 Parents: e10b874 Author: WeichenXu <weichenxu...@outlook.com> Authored: Sat Jul 23 12:32:30 2016 +0100 Committer: Sean Owen <so...@cloudera.com> Committed: Sat Jul 23 12:32:30 2016 +0100 ---------------------------------------------------------------------- .../stat/MultivariateOnlineSummarizer.scala | 63 +++++++++++--------- .../MultivariateOnlineSummarizerSuite.scala | 25 ++++++++ 2 files changed, 60 insertions(+), 28 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/25db5167/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index d4de0fd..964f419 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ 
b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -47,9 +47,10 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S private var currM2: Array[Double] = _ private var currL1: Array[Double] = _ private var totalCnt: Long = 0 - private var weightSum: Double = 0.0 + private var totalWeightSum: Double = 0.0 private var weightSquareSum: Double = 0.0 - private var nnz: Array[Double] = _ + private var weightSum: Array[Double] = _ + private var nnz: Array[Long] = _ private var currMax: Array[Double] = _ private var currMin: Array[Double] = _ @@ -74,7 +75,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S currM2n = Array.ofDim[Double](n) currM2 = Array.ofDim[Double](n) currL1 = Array.ofDim[Double](n) - nnz = Array.ofDim[Double](n) + weightSum = Array.ofDim[Double](n) + nnz = Array.ofDim[Long](n) currMax = Array.fill[Double](n)(Double.MinValue) currMin = Array.fill[Double](n)(Double.MaxValue) } @@ -86,7 +88,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S val localCurrM2n = currM2n val localCurrM2 = currM2 val localCurrL1 = currL1 - val localNnz = nnz + val localWeightSum = weightSum + val localNumNonzeros = nnz val localCurrMax = currMax val localCurrMin = currMin instance.foreachActive { (index, value) => @@ -100,16 +103,17 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S val prevMean = localCurrMean(index) val diff = value - prevMean - localCurrMean(index) = prevMean + weight * diff / (localNnz(index) + weight) + localCurrMean(index) = prevMean + weight * diff / (localWeightSum(index) + weight) localCurrM2n(index) += weight * (value - localCurrMean(index)) * diff localCurrM2(index) += weight * value * value localCurrL1(index) += weight * math.abs(value) - localNnz(index) += weight + localWeightSum(index) += weight + localNumNonzeros(index) += 1 } } - weightSum += weight + totalWeightSum += weight 
weightSquareSum += weight * weight totalCnt += 1 this @@ -124,17 +128,18 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.1.0") def merge(other: MultivariateOnlineSummarizer): this.type = { - if (this.weightSum != 0.0 && other.weightSum != 0.0) { + if (this.totalWeightSum != 0.0 && other.totalWeightSum != 0.0) { require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " + s"Expecting $n but got ${other.n}.") totalCnt += other.totalCnt - weightSum += other.weightSum + totalWeightSum += other.totalWeightSum weightSquareSum += other.weightSquareSum var i = 0 while (i < n) { - val thisNnz = nnz(i) - val otherNnz = other.nnz(i) + val thisNnz = weightSum(i) + val otherNnz = other.weightSum(i) val totalNnz = thisNnz + otherNnz + val totalCnnz = nnz(i) + other.nnz(i) if (totalNnz != 0.0) { val deltaMean = other.currMean(i) - currMean(i) // merge mean together @@ -149,18 +154,20 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S currMax(i) = math.max(currMax(i), other.currMax(i)) currMin(i) = math.min(currMin(i), other.currMin(i)) } - nnz(i) = totalNnz + weightSum(i) = totalNnz + nnz(i) = totalCnnz i += 1 } - } else if (weightSum == 0.0 && other.weightSum != 0.0) { + } else if (totalWeightSum == 0.0 && other.totalWeightSum != 0.0) { this.n = other.n this.currMean = other.currMean.clone() this.currM2n = other.currM2n.clone() this.currM2 = other.currM2.clone() this.currL1 = other.currL1.clone() this.totalCnt = other.totalCnt - this.weightSum = other.weightSum + this.totalWeightSum = other.totalWeightSum this.weightSquareSum = other.weightSquareSum + this.weightSum = other.weightSum.clone() this.nnz = other.nnz.clone() this.currMax = other.currMax.clone() this.currMin = other.currMin.clone() @@ -174,12 +181,12 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.1.0") override def mean: Vector = { - require(weightSum > 0, 
s"Nothing has been added to this summarizer.") + require(totalWeightSum > 0, s"Nothing has been added to this summarizer.") val realMean = Array.ofDim[Double](n) var i = 0 while (i < n) { - realMean(i) = currMean(i) * (nnz(i) / weightSum) + realMean(i) = currMean(i) * (weightSum(i) / totalWeightSum) i += 1 } Vectors.dense(realMean) @@ -191,11 +198,11 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.1.0") override def variance: Vector = { - require(weightSum > 0, s"Nothing has been added to this summarizer.") + require(totalWeightSum > 0, s"Nothing has been added to this summarizer.") val realVariance = Array.ofDim[Double](n) - val denominator = weightSum - (weightSquareSum / weightSum) + val denominator = totalWeightSum - (weightSquareSum / totalWeightSum) // Sample variance is computed, if the denominator is less than 0, the variance is just 0. if (denominator > 0.0) { @@ -203,8 +210,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S var i = 0 val len = currM2n.length while (i < len) { - realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * - (weightSum - nnz(i)) / weightSum) / denominator + realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) * + (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator i += 1 } } @@ -224,9 +231,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.1.0") override def numNonzeros: Vector = { - require(weightSum > 0, s"Nothing has been added to this summarizer.") + require(totalWeightSum > 0, s"Nothing has been added to this summarizer.") - Vectors.dense(nnz) + Vectors.dense(weightSum) } /** @@ -235,11 +242,11 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.1.0") override def max: Vector = { - require(weightSum > 0, s"Nothing has been added to this summarizer.") + require(totalWeightSum > 0, s"Nothing has been 
added to this summarizer.") var i = 0 while (i < n) { - if ((nnz(i) < weightSum) && (currMax(i) < 0.0)) currMax(i) = 0.0 + if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0 i += 1 } Vectors.dense(currMax) @@ -251,11 +258,11 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.1.0") override def min: Vector = { - require(weightSum > 0, s"Nothing has been added to this summarizer.") + require(totalWeightSum > 0, s"Nothing has been added to this summarizer.") var i = 0 while (i < n) { - if ((nnz(i) < weightSum) && (currMin(i) > 0.0)) currMin(i) = 0.0 + if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0 i += 1 } Vectors.dense(currMin) @@ -267,7 +274,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.2.0") override def normL2: Vector = { - require(weightSum > 0, s"Nothing has been added to this summarizer.") + require(totalWeightSum > 0, s"Nothing has been added to this summarizer.") val realMagnitude = Array.ofDim[Double](n) @@ -286,7 +293,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S */ @Since("1.2.0") override def normL1: Vector = { - require(weightSum > 0, s"Nothing has been added to this summarizer.") + require(totalWeightSum > 0, s"Nothing has been added to this summarizer.") Vectors.dense(currL1) } http://git-wip-us.apache.org/repos/asf/spark/blob/25db5167/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala index b6d41db..165a3f3 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala +++ 
b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala @@ -245,4 +245,29 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite { absTol 1E-8, "normL2 mismatch") assert(summarizer.normL1 ~== Vectors.dense(0.21, 0.4265, 0.61) absTol 1E-10, "normL1 mismatch") } + + test("test min/max with weighted samples (SPARK-16561)") { + val summarizer1 = new MultivariateOnlineSummarizer() + .add(Vectors.dense(10.0, -10.0), 1e10) + .add(Vectors.dense(0.0, 0.0), 1e-7) + + val summarizer2 = new MultivariateOnlineSummarizer() + summarizer2.add(Vectors.dense(10.0, -10.0), 1e10) + for (i <- 1 to 100) { + summarizer2.add(Vectors.dense(0.0, 0.0), 1e-7) + } + + val summarizer3 = new MultivariateOnlineSummarizer() + for (i <- 1 to 100) { + summarizer3.add(Vectors.dense(0.0, 0.0), 1e-7) + } + summarizer3.add(Vectors.dense(10.0, -10.0), 1e10) + + assert(summarizer1.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14) + assert(summarizer1.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14) + assert(summarizer2.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14) + assert(summarizer2.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14) + assert(summarizer3.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14) + assert(summarizer3.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org