Author: tdunning
Date: Thu Sep  2 04:33:02 2010
New Revision: 991804

URL: http://svn.apache.org/viewvc?rev=991804&view=rev
Log:
MAHOUT-495 - Undeprecate Normal distribution.  Extract common test patterns 
into DistributionTest

Added:
    
mahout/trunk/math/src/test/java/org/apache/mahout/math/jet/random/DistributionTest.java
    
mahout/trunk/math/src/test/java/org/apache/mahout/math/jet/random/NormalTest.java
Modified:
    mahout/trunk/math/pom.xml
    
mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/AbstractContinousDistribution.java
    
mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/AbstractDistribution.java
    
mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/Normal.java
    
mahout/trunk/math/src/test/java/org/apache/mahout/math/jet/random/ExponentialTest.java

Modified: mahout/trunk/math/pom.xml
URL: 
http://svn.apache.org/viewvc/mahout/trunk/math/pom.xml?rev=991804&r1=991803&r2=991804&view=diff
==============================================================================
--- mahout/trunk/math/pom.xml (original)
+++ mahout/trunk/math/pom.xml Thu Sep  2 04:33:02 2010
@@ -87,6 +87,13 @@
 
   <dependencies>
     <dependency>
+      <groupId>org.apache.commons</groupId>
+      <artifactId>commons-math</artifactId>
+      <version>2.1</version>
+      <scope>test</scope>
+    </dependency>
+
+    <dependency>
       <groupId>com.google.guava</groupId>
       <artifactId>guava</artifactId>
       <version>r03</version>

Modified: 
mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/AbstractContinousDistribution.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/AbstractContinousDistribution.java?rev=991804&r1=991803&r2=991804&view=diff
==============================================================================
--- 
mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/AbstractContinousDistribution.java
 (original)
+++ 
mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/AbstractContinousDistribution.java
 Thu Sep  2 04:33:02 2010
@@ -27,8 +27,16 @@ It is provided "as is" without expressed
 package org.apache.mahout.math.jet.random;
 
 /**
- * Abstract base class for all continuous distributions.
+ * Abstract base class for all continuous distributions.  Continuous 
distributions have
+ * a probability density function (pdf) and a cumulative distribution function (cdf).
  *
  */
 public abstract class AbstractContinousDistribution extends 
AbstractDistribution {
+  public double cdf(double x) {
+    throw new UnsupportedOperationException("Can't compute pdf for " + 
this.getClass().getName());
+  }
+  
+  public double pdf(double x) {
+    throw new UnsupportedOperationException("Can't compute pdf for " + 
this.getClass().getName());
+  }
 }

Modified: 
mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/AbstractDistribution.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/AbstractDistribution.java?rev=991804&r1=991803&r2=991804&view=diff
==============================================================================
--- 
mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/AbstractDistribution.java
 (original)
+++ 
mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/AbstractDistribution.java
 Thu Sep  2 04:33:02 2010
@@ -31,7 +31,6 @@ import org.apache.mahout.math.function.I
 import org.apache.mahout.math.jet.random.engine.RandomEngine;
 
 public abstract class AbstractDistribution extends PersistentObject implements 
UnaryFunction, IntFunction {
-
   protected RandomEngine randomGenerator;
 
   /** Makes this class non instantiable, but still let's others inherit from 
it. */
@@ -93,22 +92,6 @@ public abstract class AbstractDistributi
     return (int) Math.round(nextDouble());
   }
   
-  public byte nextByte() {
-    return (byte)nextInt();
-  }
-  
-  public char nextChar() {
-    return (char)nextInt();
-  }
-  
-  public long nextLong() {
-    return Math.round(nextDouble());
-  }
-  
-  public float nextFloat() {
-    return (float)nextDouble();
-  }
-
   /** Sets the uniform random generator internally used. */
   protected void setRandomGenerator(RandomEngine randomGenerator) {
     this.randomGenerator = randomGenerator;

Modified: 
mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/Normal.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/Normal.java?rev=991804&r1=991803&r2=991804&view=diff
==============================================================================
--- 
mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/Normal.java 
(original)
+++ 
mahout/trunk/math/src/main/java/org/apache/mahout/math/jet/random/Normal.java 
Thu Sep  2 04:33:02 2010
@@ -11,8 +11,11 @@ package org.apache.mahout.math.jet.rando
 import org.apache.mahout.math.jet.random.engine.RandomEngine;
 import org.apache.mahout.math.jet.stat.Probability;
 
-/** @deprecated until unit tests are in place.  Until this time, this 
class/interface is unsupported. */
-@Deprecated
+import java.util.Locale;
+
+/**
+ * Implements a normal distribution specified by its mean and standard deviation.
+ */
 public class Normal extends AbstractContinousDistribution {
 
   private double mean;
@@ -24,30 +27,39 @@ public class Normal extends AbstractCont
 
   private double normalizer; // performance cache
 
-  // The uniform random number generated shared by all <b>static</b> methods.
-  private static final Normal shared = new Normal(0.0, 1.0, 
makeDefaultGenerator());
-
-  /** Constructs a normal (gauss) distribution. Example: mean=0.0, 
standardDeviation=1.0. */
+  /**
+   * @param mean               The mean of the resulting distribution.
+   * @param standardDeviation  The standard deviation of the distribution.
+   * @param randomGenerator    The random number generator to use.  This can 
be null if you don't
+   * need to generate any numbers.
+   */
   public Normal(double mean, double standardDeviation, RandomEngine 
randomGenerator) {
     setRandomGenerator(randomGenerator);
     setState(mean, standardDeviation);
   }
 
-  /** Returns the cumulative distribution function. */
+  /**
+   * Returns the cumulative distribution function.
+   */
+  @Override
   public double cdf(double x) {
     return Probability.normal(mean, variance, x);
   }
 
-  /** Returns a random number from the distribution. */
+  /** Returns the probability density function. */
   @Override
-  public double nextDouble() {
-    return nextDouble(this.mean, this.standardDeviation);
+  public double pdf(double x) {
+    double diff = x - mean;
+    return normalizer * Math.exp(-(diff * diff) / (2.0 * variance));
   }
 
-  /** Returns a random number from the distribution; bypasses the internal 
state. */
-  public double nextDouble(double mean, double standardDeviation) {
+  /**
+   * Returns a random number from the distribution.
+   */
+  @Override
+  public double nextDouble() {
     // Uses polar Box-Muller transformation.
-    if (cacheFilled && this.mean == mean && this.standardDeviation == 
standardDeviation) {
+    if (cacheFilled) {
       cacheFilled = false;
       return cache;
     }
@@ -62,26 +74,23 @@ public class Normal extends AbstractCont
     } while (r >= 1.0);
 
     double z = Math.sqrt(-2.0 * Math.log(r) / r);
-    cache = mean + standardDeviation * x * z;
+    cache = this.mean + this.standardDeviation * x * z;
     cacheFilled = true;
-    return mean + standardDeviation * y * z;
-  }
-
-  /** Returns the probability distribution function. */
-  public double pdf(double x) {
-    double diff = x - mean;
-    return normalizer * Math.exp(-(diff * diff) / (2.0 * variance));
+    return this.mean + this.standardDeviation * y * z;
   }
 
   /** Sets the uniform random generator internally used. */
-  @Override
-  protected void setRandomGenerator(RandomEngine randomGenerator) {
+  public final void setRandomGenerator(RandomEngine randomGenerator) {
     super.setRandomGenerator(randomGenerator);
     this.cacheFilled = false;
   }
 
-  /** Sets the mean and variance. */
-  public void setState(double mean, double standardDeviation) {
+  /**
+   * Sets the mean and standard deviation.
+   * @param mean The new value for the mean.
+   * @param standardDeviation The new value for the standard deviation.
+   */
+  public final void setState(double mean, double standardDeviation) {
     if (mean != this.mean || standardDeviation != this.standardDeviation) {
       this.mean = mean;
       this.standardDeviation = standardDeviation;
@@ -92,16 +101,8 @@ public class Normal extends AbstractCont
     }
   }
 
-  /** Returns a random number from the distribution with the given mean and 
standard deviation. */
-  public static double staticNextDouble(double mean, double standardDeviation) 
{
-    synchronized (shared) {
-      return shared.nextDouble(mean, standardDeviation);
-    }
-  }
-
   /** Returns a String representation of the receiver. */
   public String toString() {
-    return this.getClass().getName() + '(' + mean + ',' + standardDeviation + 
')';
+    return String.format(Locale.ENGLISH, "%s(m=%f, sd=%f)", 
this.getClass().getName(), mean, standardDeviation);
   }
-
 }

Added: 
mahout/trunk/math/src/test/java/org/apache/mahout/math/jet/random/DistributionTest.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/jet/random/DistributionTest.java?rev=991804&view=auto
==============================================================================
--- 
mahout/trunk/math/src/test/java/org/apache/mahout/math/jet/random/DistributionTest.java
 (added)
+++ 
mahout/trunk/math/src/test/java/org/apache/mahout/math/jet/random/DistributionTest.java
 Thu Sep  2 04:33:02 2010
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.jet.random;
+
+import org.apache.commons.math.ConvergenceException;
+import org.apache.commons.math.FunctionEvaluationException;
+import org.apache.commons.math.analysis.UnivariateRealFunction;
+import org.apache.commons.math.analysis.integration.RombergIntegrator;
+import org.apache.commons.math.analysis.integration.UnivariateRealIntegrator;
+import org.junit.Assert;
+
+import java.util.Arrays;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Provides a consistency check for continuous distributions that relates the 
pdf, cdf and
+ * samples.  The pdf is checked against the cdf by quadrature.  The sampling 
is checked
+ * against the cdf using a G^2 (similar to chi^2) test.
+ */
+public class DistributionTest {
+  public void checkDistribution(final AbstractContinousDistribution dist, 
double[] x, double offset, double scale, int n) throws ConvergenceException, 
FunctionEvaluationException {
+    double[] xs = Arrays.copyOf(x, x.length);
+    for (int i = 0; i < xs.length; i++) {
+      xs[i] = xs[i]*scale+ offset;
+    }
+    Arrays.sort(xs);
+
+    // collect samples
+    double[] y = new double[n];
+    for (int i = 0; i < n; i++) {
+      y[i] = dist.nextDouble();
+    }
+    Arrays.sort(y);
+
+    // compute probabilities for bins
+    double[] p = new double[xs.length + 1];
+    double lastP = 0;
+    for (int i = 0; i < xs.length; i++) {
+      double thisP = dist.cdf(xs[i]);
+      p[i] = thisP - lastP;
+      lastP = thisP;
+    }
+    p[p.length - 1] = 1 - lastP;
+
+    // count samples in each bin
+    int[] k = new int[xs.length + 1];
+    int lastJ = 0;
+    for (int i = 0; i < k.length - 1; i++) {
+      int j = 0;
+      while (j < n && y[j] < xs[i]) {
+        j++;
+      }
+      k[i] = j - lastJ;
+      lastJ = j;
+    }
+    k[k.length - 1] = n - lastJ;
+
+    // now verify probabilities by comparing to integral of pdf
+    UnivariateRealIntegrator integrator = new RombergIntegrator();
+    for (int i = 0; i < xs.length - 1; i++) {
+      double delta = integrator.integrate(new UnivariateRealFunction() {
+        public double value(double v) throws FunctionEvaluationException {
+          return dist.pdf(v);
+        }
+      }, xs[i], xs[i + 1]);
+      assertEquals(delta, p[i + 1], 1e-6);
+    }
+
+    // finally compute G^2 of observed versus predicted.  See 
http://en.wikipedia.org/wiki/G-test
+    double sum = 0;
+    for (int i = 0; i < k.length; i++) {
+      if (k[i] != 0) {
+        sum += k[i] * Math.log(k[i] / p[i] / n);
+      }
+    }
+    sum *= 2;
+
+    // sum is chi^2 distributed with degrees of freedom equal to number of 
partitions - 1
+    int dof = k.length - 1;
+    // Fisher's approximation: sqrt(2*x) is approximately normal with unit variance and 
mean sqrt(2*dof-1)
+    double z = Math.sqrt(2 * sum) - Math.sqrt(2 * dof - 1);
+    Assert.assertTrue(String.format("offset=%.3f scale=%.3f Z = %.1f", offset, 
scale, z), Math.abs(z) < 3);
+  }
+
+  protected void checkCdf(double offset, double scale, 
AbstractContinousDistribution dist, double[] breaks, double[] quantiles) {
+    int i = 0;
+    for (double x : breaks) {
+      assertEquals(String.format("m=%.3f sd=%.3f x=%.3f", offset, scale, x), 
quantiles[i], dist.cdf(x * scale + offset), 1e-6);
+      i++;
+    }
+  }
+}

Modified: 
mahout/trunk/math/src/test/java/org/apache/mahout/math/jet/random/ExponentialTest.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/jet/random/ExponentialTest.java?rev=991804&r1=991803&r2=991804&view=diff
==============================================================================
--- 
mahout/trunk/math/src/test/java/org/apache/mahout/math/jet/random/ExponentialTest.java
 (original)
+++ 
mahout/trunk/math/src/test/java/org/apache/mahout/math/jet/random/ExponentialTest.java
 Thu Sep  2 04:33:02 2010
@@ -17,6 +17,8 @@
 
 package org.apache.mahout.math.jet.random;
 
+import org.apache.commons.math.ConvergenceException;
+import org.apache.commons.math.FunctionEvaluationException;
 import org.apache.mahout.math.jet.random.engine.MersenneTwister;
 import org.junit.Test;
 
@@ -29,7 +31,7 @@ import static org.junit.Assert.assertEqu
  * Created by IntelliJ IDEA. User: tdunning Date: Aug 31, 2010 Time: 7:14:19 
PM To change this
  * template use File | Settings | File Templates.
  */
-public class ExponentialTest {
+public class ExponentialTest extends DistributionTest {
   @Test
   public void testCdf() {
     Exponential dist = new Exponential(5.0, new MersenneTwister(1));
@@ -65,10 +67,13 @@ public class ExponentialTest {
   }
 
   @Test
-  public void testNextDouble() {
-    for (double lambda : new double[] {13.0, 0.02, 1.6}) {
-      Exponential dist = new Exponential(lambda, new MersenneTwister(1));
+  public void testNextDouble() throws ConvergenceException, 
FunctionEvaluationException {
+    double[] x = {-0.01, 0.1053605, 0.2231436, 0.3566749, 0.5108256, 
0.6931472, 0.9162907, 1.2039728, 1.6094379, 2.3025851};
+    Exponential dist = new Exponential(1, new MersenneTwister(1));
+    for (double lambda : new double[]{13.0, 0.02, 1.6}) {
+      dist.setState(lambda);
       checkEmpiricalDistribution(dist, 10000, lambda);
+      checkDistribution(dist, x, 0, 1 / lambda, 10000);
     }
   }
 

Added: 
mahout/trunk/math/src/test/java/org/apache/mahout/math/jet/random/NormalTest.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/jet/random/NormalTest.java?rev=991804&view=auto
==============================================================================
--- 
mahout/trunk/math/src/test/java/org/apache/mahout/math/jet/random/NormalTest.java
 (added)
+++ 
mahout/trunk/math/src/test/java/org/apache/mahout/math/jet/random/NormalTest.java
 Thu Sep  2 04:33:02 2010
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.jet.random;
+
+import org.apache.commons.math.ConvergenceException;
+import org.apache.commons.math.FunctionEvaluationException;
+import org.apache.mahout.math.jet.random.engine.MersenneTwister;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.Locale;
+import java.util.Random;
+
+/**
+ * Created by IntelliJ IDEA. User: tdunning Date: Sep 1, 2010 Time: 9:09:44 AM 
To change this
+ * template use File | Settings | File Templates.
+ */
+public class NormalTest extends DistributionTest {
+  private double[] breaks = {-1.2815516, -0.8416212, -0.5244005, -0.2533471, 
0.0000000, 0.2533471, 0.5244005, 0.8416212, 1.2815516};
+  private double[] quantiles = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9};
+
+  @Test
+  public void testCdf() {
+    Random gen = new Random(1);
+    double offset = 0;
+    double scale = 1;
+    for (int k = 0; k < 20; k++) {
+      Normal dist = new Normal(offset, scale, null);
+      checkCdf(offset, scale, dist, breaks, quantiles);
+      offset = gen.nextGaussian();
+      scale = Math.exp(3 * gen.nextGaussian());
+    }
+  }
+
+  @Test
+  public void consistency() throws ConvergenceException, 
FunctionEvaluationException {
+    Random gen = new Random(1);
+    double offset = 0;
+    double scale = 1;
+    for (int k = 0; k < 20; k++) {
+      Normal dist = new Normal(offset, scale, new MersenneTwister());
+      checkDistribution(dist, breaks, offset, scale, 10000);
+      offset = gen.nextGaussian();
+      scale = Math.exp(3 * gen.nextGaussian());
+    }
+  }
+
+  @Test
+  public void testSetState() throws ConvergenceException, 
FunctionEvaluationException {
+    Normal dist = new Normal(0, 1, new MersenneTwister());
+    dist.setState(1.3, 5.9);
+    checkDistribution(dist, breaks, 1.3, 5.9, 10000);
+  }
+
+  @Test
+  public void testToString() {
+    Locale d = Locale.getDefault();
+    Locale.setDefault(Locale.GERMAN);
+    Assert.assertEquals("org.apache.mahout.math.jet.random.Normal(m=1.300000, 
sd=5.900000)", new Normal(1.3, 5.9, null).toString());
+    Locale.setDefault(d);
+  }
+}


Reply via email to