Author: srowen
Date: Tue Mar 22 16:15:36 2011
New Revision: 1084234
URL: http://svn.apache.org/viewvc?rev=1084234&view=rev
Log:
MAHOUT-630 weighted average fix and add stddev
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageAndStdDev.java
- copied, changed from r1083546, mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverage.java
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageTest.java
- copied, changed from r1083546, mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageTest.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverage.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverage.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverage.java?rev=1084234&r1=1084233&r2=1084234&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverage.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverage.java Tue Mar 22 16:15:36 2011
@@ -21,7 +21,7 @@ import java.io.Serializable;
import com.google.common.base.Preconditions;
-public final class WeightedRunningAverage implements RunningAverage,
Serializable {
+public class WeightedRunningAverage implements RunningAverage, Serializable {
private double totalWeight;
private double average;
@@ -42,7 +42,7 @@ public final class WeightedRunningAverag
if (oldTotalWeight <= 0.0) {
average = datum * weight;
} else {
- average = average * oldTotalWeight / totalWeight + datum / totalWeight;
+ average = average * oldTotalWeight / totalWeight + datum * weight /
totalWeight;
}
}
@@ -58,7 +58,7 @@ public final class WeightedRunningAverag
average = Double.NaN;
totalWeight = 0.0;
} else {
- average = average * oldTotalWeight / totalWeight - datum / totalWeight;
+ average = average * oldTotalWeight / totalWeight - datum * weight /
totalWeight;
}
}
Copied:
Copied: mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageAndStdDev.java (from r1083546, mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverage.java)
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageAndStdDev.java?p2=mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageAndStdDev.java&p1=mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverage.java&r1=1083546&r2=1084234&rev=1084234&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverage.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageAndStdDev.java Tue Mar 22 16:15:36 2011
@@ -17,79 +17,68 @@
package org.apache.mahout.cf.taste.impl.common;
-import java.io.Serializable;
+/**
+ * This subclass also provides for a weighted estimate of the sample standard
deviation.
+ * See <a
href="http://en.wikipedia.org/wiki/Mean_square_weighted_deviation">estimate
formulae here</a>.
+ */
+public final class WeightedRunningAverageAndStdDev extends
WeightedRunningAverage implements RunningAverageAndStdDev {
-import com.google.common.base.Preconditions;
+ private double totalSquaredWeight;
+ private double totalWeightedData;
+ private double totalWeightedSquaredData;
-public final class WeightedRunningAverage implements RunningAverage,
Serializable {
-
- private double totalWeight;
- private double average;
-
- public WeightedRunningAverage() {
- totalWeight = 0.0;
- average = Double.NaN;
+ public WeightedRunningAverageAndStdDev() {
+ totalSquaredWeight = 0.0;
+ totalWeightedData = 0.0;
+ totalWeightedSquaredData = 0.0;
}
@Override
- public synchronized void addDatum(double datum) {
- addDatum(datum, 1.0);
- }
-
public synchronized void addDatum(double datum, double weight) {
- double oldTotalWeight = totalWeight;
- totalWeight += weight;
- if (oldTotalWeight <= 0.0) {
- average = datum * weight;
- } else {
- average = average * oldTotalWeight / totalWeight + datum / totalWeight;
- }
+ super.addDatum(datum, weight);
+ totalSquaredWeight += weight * weight;
+ double weightedData = datum * weight;
+ totalWeightedData += weightedData;
+ totalWeightedSquaredData += weightedData * datum;
}
@Override
- public synchronized void removeDatum(double datum) {
- removeDatum(datum, 1.0);
- }
-
public synchronized void removeDatum(double datum, double weight) {
- double oldTotalWeight = totalWeight;
- totalWeight -= weight;
- if (totalWeight <= 0.0) {
- average = Double.NaN;
- totalWeight = 0.0;
- } else {
- average = average * oldTotalWeight / totalWeight - datum / totalWeight;
+ super.removeDatum(datum, weight);
+ totalSquaredWeight -= weight * weight;
+ if (totalSquaredWeight <= 0.0) {
+ totalSquaredWeight = 0.0;
+ }
+ double weightedData = datum * weight;
+ totalWeightedData -= weightedData;
+ if (totalWeightedData <= 0.0) {
+ totalWeightedData = 0.0;
+ }
+ totalWeightedSquaredData -= weightedData * datum;
+ if (totalWeightedSquaredData <= 0.0) {
+ totalWeightedSquaredData = 0.0;
}
}
-
+
+ /**
+ * @throws UnsupportedOperationException
+ */
@Override
- public synchronized void changeDatum(double delta) {
- changeDatum(delta, 1.0);
- }
-
public synchronized void changeDatum(double delta, double weight) {
- Preconditions.checkArgument(weight <= totalWeight);
- average += delta * weight / totalWeight;
- }
-
- public synchronized double getTotalWeight() {
- return totalWeight;
- }
-
- /** @return {@link #getTotalWeight()} */
- @Override
- public synchronized int getCount() {
- return (int) totalWeight;
+ throw new UnsupportedOperationException();
}
+
@Override
- public synchronized double getAverage() {
- return average;
+ public synchronized double getStandardDeviation() {
+ double totalWeight = getTotalWeight();
+ return Math.sqrt((totalWeightedSquaredData * totalWeight -
totalWeightedData * totalWeightedData) /
+ (totalWeight * totalWeight - totalSquaredWeight));
}
@Override
public synchronized String toString() {
- return String.valueOf(average);
+ return String.valueOf(String.valueOf(getAverage()) + ',' +
getStandardDeviation());
}
-
+
}
Copied: mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageTest.java (from r1083546, mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageTest.java)
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageTest.java?p2=mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageTest.java&p1=mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageTest.java&r1=1083546&r2=1084234&rev=1084234&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageTest.java Tue Mar 22 16:15:36 2011
@@ -20,49 +20,66 @@ package org.apache.mahout.cf.taste.impl.
import org.apache.mahout.cf.taste.impl.TasteTestCase;
import org.junit.Test;
-/** <p>Tests {@link FullRunningAverage}.</p> */
-public final class RunningAverageTest extends TasteTestCase {
+/**
+ * <p>Tests {@link WeightedRunningAverage} and {@link
WeightedRunningAverageAndStdDev}.</p>
+ */
+public final class WeightedRunningAverageTest extends TasteTestCase {
@Test
- public void testFull() {
- doTestRunningAverage(new FullRunningAverage());
- }
-
- @Test
- public void testCompact() {
- doTestRunningAverage(new CompactRunningAverage());
- }
+ public void testWeighted() {
- private static void doTestRunningAverage(RunningAverage runningAverage) {
+ WeightedRunningAverage runningAverage = new WeightedRunningAverage();
assertEquals(0, runningAverage.getCount());
assertTrue(Double.isNaN(runningAverage.getAverage()));
runningAverage.addDatum(1.0);
- assertEquals(1, runningAverage.getCount());
assertEquals(1.0, runningAverage.getAverage(), EPSILON);
- runningAverage.addDatum(1.0);
- assertEquals(2, runningAverage.getCount());
+ runningAverage.addDatum(1.0, 2.0);
assertEquals(1.0, runningAverage.getAverage(), EPSILON);
- runningAverage.addDatum(4.0);
- assertEquals(3, runningAverage.getCount());
+ runningAverage.addDatum(8.0, 0.5);
assertEquals(2.0, runningAverage.getAverage(), EPSILON);
runningAverage.addDatum(-4.0);
- assertEquals(4, runningAverage.getCount());
- assertEquals(0.5, runningAverage.getAverage(), EPSILON);
+ assertEquals(2.0/3.0, runningAverage.getAverage(), EPSILON);
runningAverage.removeDatum(-4.0);
- assertEquals(3, runningAverage.getCount());
assertEquals(2.0, runningAverage.getAverage(), EPSILON);
- runningAverage.removeDatum(4.0);
- assertEquals(2, runningAverage.getCount());
- assertEquals(1.0, runningAverage.getAverage(), EPSILON);
+ runningAverage.removeDatum(2.0, 2.0);
+ assertEquals(2.0, runningAverage.getAverage(), EPSILON);
runningAverage.changeDatum(0.0);
- assertEquals(2, runningAverage.getCount());
+ assertEquals(2.0, runningAverage.getAverage(), EPSILON);
+ runningAverage.changeDatum(4.0, 0.5);
+ assertEquals(5.0/1.5, runningAverage.getAverage(), EPSILON);
+ }
+
+ @Test
+ public void testWeightedAndStdDev() {
+
+ WeightedRunningAverageAndStdDev runningAverage = new
WeightedRunningAverageAndStdDev();
+
+ assertEquals(0, runningAverage.getCount());
+ assertTrue(Double.isNaN(runningAverage.getAverage()));
+ assertTrue(Double.isNaN(runningAverage.getStandardDeviation()));
+
+ runningAverage.addDatum(1.0);
+ assertEquals(1.0, runningAverage.getAverage(), EPSILON);
+ assertTrue(Double.isNaN(runningAverage.getStandardDeviation()));
+ runningAverage.addDatum(1.0, 2.0);
assertEquals(1.0, runningAverage.getAverage(), EPSILON);
- runningAverage.changeDatum(2.0);
- assertEquals(2, runningAverage.getCount());
+ assertEquals(0.0, runningAverage.getStandardDeviation(), EPSILON);
+ runningAverage.addDatum(8.0, 0.5);
+ assertEquals(2.0, runningAverage.getAverage(), EPSILON);
+ assertEquals(Math.sqrt(10.5), runningAverage.getStandardDeviation(),
EPSILON);
+ runningAverage.addDatum(-4.0);
+ assertEquals(2.0/3.0, runningAverage.getAverage(), EPSILON);
+ assertEquals(Math.sqrt(15.75), runningAverage.getStandardDeviation(),
EPSILON);
+
+ runningAverage.removeDatum(-4.0);
+ assertEquals(2.0, runningAverage.getAverage(), EPSILON);
+ assertEquals(Math.sqrt(10.5), runningAverage.getStandardDeviation(),
EPSILON);
+ runningAverage.removeDatum(2.0, 2.0);
assertEquals(2.0, runningAverage.getAverage(), EPSILON);
+ assertEquals(Math.sqrt(31.5), runningAverage.getStandardDeviation(),
EPSILON);
}
}