Author: gregs
Date: Sat Sep 10 04:18:31 2011
New Revision: 1167451
URL: http://svn.apache.org/viewvc?rev=1167451&view=rev
Log:
(MATH-649) SimpleRegression needs the ability to suppress the intercept
This commit pushes changes to allow the estimation of a regression in which
the intercept is constrained to be zero. I am also pushing two unit tests.
Modified:
commons/proper/math/trunk/src/main/java/org/apache/commons/math/stat/regression/SimpleRegression.java
commons/proper/math/trunk/src/test/java/org/apache/commons/math/stat/regression/SimpleRegressionTest.java
Modified:
commons/proper/math/trunk/src/main/java/org/apache/commons/math/stat/regression/SimpleRegression.java
URL:
http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math/stat/regression/SimpleRegression.java?rev=1167451&r1=1167450&r2=1167451&view=diff
==============================================================================
---
commons/proper/math/trunk/src/main/java/org/apache/commons/math/stat/regression/SimpleRegression.java
(original)
+++
commons/proper/math/trunk/src/main/java/org/apache/commons/math/stat/regression/SimpleRegression.java
Sat Sep 10 04:18:31 2011
@@ -84,13 +84,23 @@ public class SimpleRegression implements
/** mean of accumulated y values, used in updating formulas */
private double ybar = 0;
+ /** include an intercept or not */
+ private final boolean hasIntercept;
// ---------------------Public methods--------------------------------------
/**
* Create an empty SimpleRegression instance
*/
public SimpleRegression() {
+ this(true);
+ }
+ /**
+ * Secondary constructor which allows the user the ability to include/exclude const
+ * @param includeIntercept boolean flag, true includes an intercept
+ */
+ public SimpleRegression(boolean includeIntercept){
super();
+ hasIntercept = includeIntercept;
}
/**
@@ -106,22 +116,32 @@ public class SimpleRegression implements
* @param x independent variable value
* @param y dependent variable value
*/
- public void addData(double x, double y) {
+ public void addData(final double x, final double y){
if (n == 0) {
xbar = x;
ybar = y;
} else {
- double dx = x - xbar;
- double dy = y - ybar;
- sumXX += dx * dx * (double) n / (n + 1d);
- sumYY += dy * dy * (double) n / (n + 1d);
- sumXY += dx * dy * (double) n / (n + 1d);
- xbar += dx / (n + 1.0);
- ybar += dy / (n + 1.0);
+ if( hasIntercept ){
+ final double fact1 = 1.0 + (double) n;
+ final double fact2 = ((double) n) / (1.0 + (double) n);
+ final double dx = x - xbar;
+ final double dy = y - ybar;
+ sumXX += dx * dx * fact2;
+ sumYY += dy * dy * fact2;
+ sumXY += dx * dy * fact2;
+ xbar += dx / fact1;
+ ybar += dy / fact1;
+ }
+ }
+ if( !hasIntercept ){
+ sumXX += x * x ;
+ sumYY += y * y ;
+ sumXY += x * y ;
}
sumX += x;
sumY += y;
n++;
+ return;
}
@@ -140,17 +160,29 @@ public class SimpleRegression implements
*/
public void removeData(double x, double y) {
if (n > 0) {
- double dx = x - xbar;
- double dy = y - ybar;
- sumXX -= dx * dx * (double) n / (n - 1d);
- sumYY -= dy * dy * (double) n / (n - 1d);
- sumXY -= dx * dy * (double) n / (n - 1d);
- xbar -= dx / (n - 1.0);
- ybar -= dy / (n - 1.0);
- sumX -= x;
- sumY -= y;
- n--;
+ if (hasIntercept) {
+ final double fact1 = (double) n - 1.0;
+ final double fact2 = ((double) n) / ((double) n - 1.0);
+ final double dx = x - xbar;
+ final double dy = y - ybar;
+ sumXX -= dx * dx * fact2;
+ sumYY -= dy * dy * fact2;
+ sumXY -= dx * dy * fact2;
+ xbar -= dx / fact1;
+ ybar -= dy / fact1;
+ } else {
+ final double fact1 = (double) n - 1.0;
+ sumXX -= x * x;
+ sumYY -= y * y;
+ sumXY -= x * y;
+ xbar -= x / fact1;
+ ybar -= y / fact1;
+ }
+ sumX -= x;
+ sumY -= y;
+ n--;
}
+ return;
}
/**
@@ -235,7 +267,10 @@ public class SimpleRegression implements
*/
public double predict(double x) {
double b1 = getSlope();
- return getIntercept(b1) + b1 * x;
+ if (hasIntercept) {
+ return getIntercept(b1) + b1 * x;
+ }
+ return b1 * x;
}
/**
@@ -255,7 +290,16 @@ public class SimpleRegression implements
* @return the intercept of the regression line
*/
public double getIntercept() {
- return getIntercept(getSlope());
+ return hasIntercept ? getIntercept(getSlope()) : 0.0;
+ }
+
+ /**
+ * Returns true if a constant has been included false otherwise.
+ *
+ * @return true if constant exists, false otherwise
+ */
+ public boolean hasIntercept(){
+ return hasIntercept;
}
/**
@@ -391,7 +435,7 @@ public class SimpleRegression implements
if (n < 3) {
return Double.NaN;
}
- return getSumSquaredErrors() / (n - 2);
+ return hasIntercept ? (getSumSquaredErrors() / (n - 2)) : (getSumSquaredErrors() / (n - 1));
}
/**
@@ -443,11 +487,15 @@ public class SimpleRegression implements
* <p>
* If there are fewer that <strong>three</strong> observations in the
* model, or if there is no variation in x, this returns
- * <code>Double.NaN</code>.</p>
+ * <code>Double.NaN</code>.</p> Additionally, a <code>Double.NaN</code> is
+ * returned when the intercept is constrained to be zero
*
* @return standard error associated with intercept estimate
*/
public double getInterceptStdErr() {
+ if( !hasIntercept ){
+ return Double.NaN;
+ }
return FastMath.sqrt(
getMeanSquareError() * ((1d / (double) n) + (xbar * xbar) / sumXX));
}
@@ -572,8 +620,11 @@ public class SimpleRegression implements
* @param slope current slope
* @return the intercept of the regression line
*/
- private double getIntercept(double slope) {
+ private double getIntercept(double slope){
+ if( hasIntercept){
return (sumY - slope * sumX) / n;
+ }
+ return 0.0;
}
/**
Modified:
commons/proper/math/trunk/src/test/java/org/apache/commons/math/stat/regression/SimpleRegressionTest.java
URL:
http://svn.apache.org/viewvc/commons/proper/math/trunk/src/test/java/org/apache/commons/math/stat/regression/SimpleRegressionTest.java?rev=1167451&r1=1167450&r2=1167451&view=diff
==============================================================================
---
commons/proper/math/trunk/src/test/java/org/apache/commons/math/stat/regression/SimpleRegressionTest.java
(original)
+++
commons/proper/math/trunk/src/test/java/org/apache/commons/math/stat/regression/SimpleRegressionTest.java
Sat Sep 10 04:18:31 2011
@@ -80,6 +80,76 @@ public final class SimpleRegressionTest
{5, -1 }, {6, 12 }
};
+
+ /*
+ * Data from NIST NOINT1
+ */
+ private double[][] noint1 = {
+ {130.0,60.0},
+ {131.0,61.0},
+ {132.0,62.0},
+ {133.0,63.0},
+ {134.0,64.0},
+ {135.0,65.0},
+ {136.0,66.0},
+ {137.0,67.0},
+ {138.0,68.0},
+ {139.0,69.0},
+ {140.0,70.0}
+ };
+
+ /*
+ * Data from NIST NOINT2
+ *
+ */
+ private double[][] noint2 = {
+ {3.0,4},
+ {4,5},
+ {4,6}
+ };
+
+ @Test
+ public void testNoInterceot_noint2(){
+ SimpleRegression regression = new SimpleRegression(false);
+ regression.addData(noint2[0][1], noint2[0][0]);
+ regression.addData(noint2[1][1], noint2[1][0]);
+ regression.addData(noint2[2][1], noint2[2][0]);
+ Assert.assertEquals("slope", 0.727272727272727,
+ regression.getSlope(), 10E-12);
+ Assert.assertEquals("slope std err", 0.420827318078432E-01,
+ regression.getSlopeStdErr(),10E-12);
+ Assert.assertEquals("number of observations", 3, regression.getN());
+ Assert.assertEquals("r-square", 0.993348115299335,
+ regression.getRSquare(), 10E-12);
+ Assert.assertEquals("SSR", 40.7272727272727,
+ regression.getRegressionSumSquares(), 10E-9);
+ Assert.assertEquals("MSE", 0.136363636363636,
+ regression.getMeanSquareError(), 10E-10);
+ Assert.assertEquals("SSE", 0.272727272727273,
+ regression.getSumSquaredErrors(),10E-9);
+ }
+
+ @Test
+ public void testNoIntercept_noint1(){
+ SimpleRegression regression = new SimpleRegression(false);
+ for (int i = 0; i < noint1.length; i++) {
+ regression.addData(noint1[i][1], noint1[i][0]);
+ }
+ Assert.assertEquals("slope", 2.07438016528926, regression.getSlope(), 10E-12);
+ Assert.assertEquals("slope std err", 0.165289256198347E-01,
+ regression.getSlopeStdErr(),10E-12);
+ Assert.assertEquals("number of observations", 11, regression.getN());
+ Assert.assertEquals("r-square", 0.999365492298663,
+ regression.getRSquare(), 10E-12);
+ Assert.assertEquals("SSR", 200457.727272727,
+ regression.getRegressionSumSquares(), 10E-9);
+ Assert.assertEquals("MSE", 12.7272727272727,
+ regression.getMeanSquareError(), 10E-10);
+ Assert.assertEquals("SSE", 127.272727272727,
+ regression.getSumSquaredErrors(),10E-9);
+
+ }
+
@Test
public void testNorris() {
SimpleRegression regression = new SimpleRegression();