Repository: commons-math
Updated Branches:
  refs/heads/master 1b5925b56 -> e14d9ce8e


[MATH-837] Support aggregation of any kind of StatisticalSummary in 
AggregateSummaryStatistics.


Project: http://git-wip-us.apache.org/repos/asf/commons-math/repo
Commit: http://git-wip-us.apache.org/repos/asf/commons-math/commit/e14d9ce8
Tree: http://git-wip-us.apache.org/repos/asf/commons-math/tree/e14d9ce8
Diff: http://git-wip-us.apache.org/repos/asf/commons-math/diff/e14d9ce8

Branch: refs/heads/master
Commit: e14d9ce8e3c57b7352a647c0a8e62eeb88ebfba4
Parents: 1b5925b
Author: Thomas Neidhart <[email protected]>
Authored: Mon Oct 19 21:41:16 2015 +0200
Committer: Thomas Neidhart <[email protected]>
Committed: Mon Oct 19 21:41:16 2015 +0200

----------------------------------------------------------------------
 src/changes/changes.xml                         |  4 ++
 .../descriptive/AggregateSummaryStatistics.java | 12 +++---
 .../AggregateSummaryStatisticsTest.java         | 44 ++++++++++++++++----
 3 files changed, 48 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/commons-math/blob/e14d9ce8/src/changes/changes.xml
----------------------------------------------------------------------
diff --git a/src/changes/changes.xml b/src/changes/changes.xml
index 91553f2..0ec27e7 100644
--- a/src/changes/changes.xml
+++ b/src/changes/changes.xml
@@ -54,6 +54,10 @@ If the output is not quite correct, check for invisible trailing spaces!
     </release>
 
     <release version="4.0" date="XXXX-XX-XX" description="">
+      <action dev="tn" type="add" issue="MATH-837"> <!-- backported to 3.6 -->
+        "AggregateSummaryStatistics" can now aggregate any kind of
+        "StatisticalSummary".
+      </action>
       <action dev="erans" type="fix" issue="MATH-1279"> <!-- backported to 3.6 -->
         Check precondition (class "o.a.c.m.random.EmpiricalDistribution").
       </action>

http://git-wip-us.apache.org/repos/asf/commons-math/blob/e14d9ce8/src/main/java/org/apache/commons/math4/stat/descriptive/AggregateSummaryStatistics.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math4/stat/descriptive/AggregateSummaryStatistics.java b/src/main/java/org/apache/commons/math4/stat/descriptive/AggregateSummaryStatistics.java
index 52fc9cc..40d989c 100644
--- a/src/main/java/org/apache/commons/math4/stat/descriptive/AggregateSummaryStatistics.java
+++ b/src/main/java/org/apache/commons/math4/stat/descriptive/AggregateSummaryStatistics.java
@@ -309,20 +309,21 @@ public class AggregateSummaryStatistics implements StatisticalSummary,
      * @param statistics collection of SummaryStatistics to aggregate
      * @return summary statistics for the combined dataset
      */
-    public static StatisticalSummaryValues aggregate(Collection<SummaryStatistics> statistics) {
+    public static StatisticalSummaryValues aggregate(Collection<? extends StatisticalSummary> statistics) {
         if (statistics == null) {
             return null;
         }
-        Iterator<SummaryStatistics> iterator = statistics.iterator();
+        Iterator<? extends StatisticalSummary> iterator = statistics.iterator();
         if (!iterator.hasNext()) {
             return null;
         }
-        SummaryStatistics current = iterator.next();
+        StatisticalSummary current = iterator.next();
         long n = current.getN();
         double min = current.getMin();
         double sum = current.getSum();
         double max = current.getMax();
-        double m2 = current.getSecondMoment();
+        double var = current.getVariance();
+        double m2 = var * (n - 1d);
         double mean = current.getMean();
         while (iterator.hasNext()) {
             current = iterator.next();
@@ -338,7 +339,8 @@ public class AggregateSummaryStatistics implements StatisticalSummary,
             n += curN;
             final double meanDiff = current.getMean() - mean;
             mean = sum / n;
-            m2 = m2 + current.getSecondMoment() + meanDiff * meanDiff * oldN * curN / n;
+            final double curM2 = current.getVariance() * (curN - 1d);
+            m2 = m2 + curM2 + meanDiff * meanDiff * oldN * curN / n;
         }
         final double variance;
         if (n == 0) {

http://git-wip-us.apache.org/repos/asf/commons-math/blob/e14d9ce8/src/test/java/org/apache/commons/math4/stat/descriptive/AggregateSummaryStatisticsTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/commons/math4/stat/descriptive/AggregateSummaryStatisticsTest.java b/src/test/java/org/apache/commons/math4/stat/descriptive/AggregateSummaryStatisticsTest.java
index 1a8324d..5324ecf 100644
--- a/src/test/java/org/apache/commons/math4/stat/descriptive/AggregateSummaryStatisticsTest.java
+++ b/src/test/java/org/apache/commons/math4/stat/descriptive/AggregateSummaryStatisticsTest.java
@@ -25,10 +25,6 @@ import org.apache.commons.math4.distribution.IntegerDistribution;
 import org.apache.commons.math4.distribution.RealDistribution;
 import org.apache.commons.math4.distribution.UniformIntegerDistribution;
 import org.apache.commons.math4.distribution.UniformRealDistribution;
-import org.apache.commons.math4.stat.descriptive.AggregateSummaryStatistics;
-import org.apache.commons.math4.stat.descriptive.StatisticalSummary;
-import org.apache.commons.math4.stat.descriptive.StatisticalSummaryValues;
-import org.apache.commons.math4.stat.descriptive.SummaryStatistics;
 import org.apache.commons.math4.util.Precision;
 import org.junit.Assert;
 import org.junit.Test;
@@ -36,7 +32,6 @@ import org.junit.Test;
 
 /**
  * Test cases for {@link AggregateSummaryStatistics}
- *
  */
 public class AggregateSummaryStatisticsTest {
 
@@ -132,7 +127,6 @@ public class AggregateSummaryStatisticsTest {
      * partition and comparing the result of aggregate(...) applied to the collection
      * of per-partition SummaryStatistics with a single SummaryStatistics computed
      * over the full sample.
-     *
      */
     @Test
     public void testAggregate() {
@@ -166,6 +160,42 @@ public class AggregateSummaryStatisticsTest {
         assertEquals(totalStats.getSummary(), aggregatedStats, 10E-12);
     }
 
+    /**
+     * Similar to {@link #testAggregate()} but operating on
+     * {@link StatisticalSummary} instead.
+     */
+    @Test
+    public void testAggregateStatisticalSummary() {
+
+        // Generate a random sample and random partition
+        double[] totalSample = generateSample();
+        double[][] subSamples = generatePartition(totalSample);
+        int nSamples = subSamples.length;
+
+        // Compute combined stats directly
+        SummaryStatistics totalStats = new SummaryStatistics();
+        for (int i = 0; i < totalSample.length; i++) {
+            totalStats.addValue(totalSample[i]);
+        }
+
+        // Now compute subsample stats individually and aggregate
+        SummaryStatistics[] subSampleStats = new SummaryStatistics[nSamples];
+        for (int i = 0; i < nSamples; i++) {
+            subSampleStats[i] = new SummaryStatistics();
+        }
+        Collection<StatisticalSummary> aggregate = new ArrayList<StatisticalSummary>();
+        for (int i = 0; i < nSamples; i++) {
+            for (int j = 0; j < subSamples[i].length; j++) {
+                subSampleStats[i].addValue(subSamples[i][j]);
+            }
+            aggregate.add(subSampleStats[i].getSummary());
+        }
+
+        // Compare values
+        StatisticalSummary aggregatedStats = AggregateSummaryStatistics.aggregate(aggregate);
+        assertEquals(totalStats.getSummary(), aggregatedStats, 10E-12);
+    }
+
 
     @Test
     public void testAggregateDegenerate() {
@@ -269,7 +299,7 @@ public class AggregateSummaryStatisticsTest {
         final double[][] out = new double[5][];
         int cur = 0;          // beginning of current partition segment
         int offset = 0;       // end of current partition segment
-        int sampleCount = 0;  // number of segments defined 
+        int sampleCount = 0;  // number of segments defined
         for (int i = 0; i < 5; i++) {
             if (cur == length || offset == length) {
                 break;

Reply via email to