Repository: spark Updated Branches: refs/heads/master faeb41de2 -> 84324fbcb
[SPARK-4355][MLLIB] fix OnlineSummarizer.merge when other.mean is zero See inline comment about the bug. I also did some code clean-up. dbtsai I moved `update` to a private method of `MultivariateOnlineSummarizer`. I don't think it will cause performance regression, but it would be great if you have some time to test. Author: Xiangrui Meng <m...@databricks.com> Closes #3220 from mengxr/SPARK-4355 and squashes the following commits: 5ef601f [Xiangrui Meng] fix OnlineSummarizer.merge when other.mean is zero and some code clean-up Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/84324fbc Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/84324fbc Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/84324fbc Branch: refs/heads/master Commit: 84324fbcb987db6e10e435f463eacace1bae43e2 Parents: faeb41d Author: Xiangrui Meng <m...@databricks.com> Authored: Wed Nov 12 01:50:11 2014 -0800 Committer: Xiangrui Meng <m...@databricks.com> Committed: Wed Nov 12 01:50:11 2014 -0800 ---------------------------------------------------------------------- .../stat/MultivariateOnlineSummarizer.scala | 85 +++++++++----------- .../MultivariateOnlineSummarizerSuite.scala | 11 +++ 2 files changed, 51 insertions(+), 45 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/84324fbc/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index fab7c44..654479a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -50,6 +50,29 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S private var currMin: BDV[Double] = _ /** + * Adds input value to position i. + */ + private[this] def add(i: Int, value: Double) = { + if (value != 0.0) { + if (currMax(i) < value) { + currMax(i) = value + } + if (currMin(i) > value) { + currMin(i) = value + } + + val prevMean = currMean(i) + val diff = value - prevMean + currMean(i) = prevMean + diff / (nnz(i) + 1.0) + currM2n(i) += (value - currMean(i)) * diff + currM2(i) += value * value + currL1(i) += math.abs(value) + + nnz(i) += 1.0 + } + } + + /** * Add a new sample to this summarizer, and update the statistical summary. * * @param sample The sample in dense/sparse vector format to be added into this summarizer. @@ -72,37 +95,18 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S require(n == sample.size, s"Dimensions mismatch when adding new sample." + s" Expecting $n but got ${sample.size}.") - @inline def update(i: Int, value: Double) = { - if (value != 0.0) { - if (currMax(i) < value) { - currMax(i) = value - } - if (currMin(i) > value) { - currMin(i) = value - } - - val tmpPrevMean = currMean(i) - currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0) - currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean) - currM2(i) += value * value - currL1(i) += math.abs(value) - - nnz(i) += 1.0 - } - } - sample match { case dv: DenseVector => { var j = 0 while (j < dv.size) { - update(j, dv.values(j)) + add(j, dv.values(j)) j += 1 } } case sv: SparseVector => var j = 0 while (j < sv.indices.size) { - update(sv.indices(j), sv.values(j)) + add(sv.indices(j), sv.values(j)) j += 1 } case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) @@ -124,37 +128,28 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " + s"Expecting $n but got ${other.n}.") totalCnt += other.totalCnt - val deltaMean: BDV[Double] = currMean - other.currMean var i = 0 while (i < n) { - // merge mean together - if (other.currMean(i) != 0.0) { - currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) / - (nnz(i) + other.nnz(i)) - } - // merge m2n together - if (nnz(i) + other.nnz(i) != 0.0) { - currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) / - (nnz(i) + other.nnz(i)) - } - // merge m2 together - if (nnz(i) + other.nnz(i) != 0.0) { + val thisNnz = nnz(i) + val otherNnz = other.nnz(i) + val totalNnz = thisNnz + otherNnz + if (totalNnz != 0.0) { + val deltaMean = other.currMean(i) - currMean(i) + // merge mean together + currMean(i) += deltaMean * otherNnz / totalNnz + // merge m2n together + currM2n(i) += other.currM2n(i) + deltaMean * deltaMean * thisNnz * otherNnz / totalNnz + // merge m2 together currM2(i) += other.currM2(i) - } - // merge l1 together - if (nnz(i) + other.nnz(i) != 0.0) { + // merge l1 together currL1(i) += other.currL1(i) + // merge max and min + currMax(i) = math.max(currMax(i), other.currMax(i)) + currMin(i) = math.min(currMin(i), other.currMin(i)) } - - if (currMax(i) < other.currMax(i)) { - currMax(i) = other.currMax(i) - } - if (currMin(i) > other.currMin(i)) { - currMin(i) = other.currMin(i) - } + nnz(i) = totalNnz i += 1 } - nnz += other.nnz } else if (totalCnt == 0 && other.totalCnt != 0) { this.n = other.n this.currMean = other.currMean.copy http://git-wip-us.apache.org/repos/asf/spark/blob/84324fbc/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala index 1e94152..23b0eec 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala @@ -208,4 +208,15 @@ class MultivariateOnlineSummarizerSuite extends FunSuite { assert(summarizer2.variance ~== Vectors.dense(0, 0, 0) absTol 1E-5, "variance mismatch") } + + test("merging summarizer when one side has zero mean (SPARK-4355)") { + val s0 = new MultivariateOnlineSummarizer() + .add(Vectors.dense(2.0)) + .add(Vectors.dense(2.0)) + val s1 = new MultivariateOnlineSummarizer() + .add(Vectors.dense(1.0)) + .add(Vectors.dense(-1.0)) + s0.merge(s1) + assert(s0.mean(0) ~== 1.0 absTol 1e-14) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org