RNG-30: Unit tests for discrete distributions.

Project: http://git-wip-us.apache.org/repos/asf/commons-rng/repo
Commit: http://git-wip-us.apache.org/repos/asf/commons-rng/commit/438b67b3
Tree: http://git-wip-us.apache.org/repos/asf/commons-rng/tree/438b67b3
Diff: http://git-wip-us.apache.org/repos/asf/commons-rng/diff/438b67b3

Branch: refs/heads/RNG-30__sampling
Commit: 438b67b3258427e8c5a20d17b5fad0ee4629f0e8
Parents: 74bbdd2
Author: Gilles <er...@apache.org>
Authored: Sat Nov 12 16:51:01 2016 +0100
Committer: Gilles <er...@apache.org>
Committed: Sat Nov 12 16:51:01 2016 +0100

----------------------------------------------------------------------
 .../DiscreteSamplerParametricTest.java          | 161 +++++++++++++++++
 .../distribution/DiscreteSamplerTestData.java   |  60 +++++++
 .../distribution/DiscreteSamplersList.java      | 180 +++++++++++++++++++
 3 files changed, 401 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/commons-rng/blob/438b67b3/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplerParametricTest.java
----------------------------------------------------------------------
diff --git 
a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplerParametricTest.java
 
b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplerParametricTest.java
new file mode 100644
index 0000000..d96fcb1
--- /dev/null
+++ 
b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplerParametricTest.java
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.rng.sampling.distribution;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.ArrayList;
+import java.util.concurrent.Callable;
+import java.io.IOException;
+import java.io.ObjectOutputStream;
+import java.io.ObjectInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.ByteArrayInputStream;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.Assume;
+import org.junit.Ignore;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+import org.apache.commons.math3.distribution.ChiSquaredDistribution;
+
+import org.apache.commons.rng.UniformRandomProvider;
+import org.apache.commons.rng.RandomProviderState;
+import org.apache.commons.rng.RestorableUniformRandomProvider;
+import org.apache.commons.rng.core.RandomProviderDefaultState;
+import org.apache.commons.rng.sampling.DiscreteSampler;
+
+/**
+ * Parametric test checking that the outcomes generated by each of the
+ * samplers returned by {@link DiscreteSamplersList#list()} occur with
+ * the expected frequencies.
+ */
+@RunWith(value=Parameterized.class)
+public class DiscreteSamplerParametricTest {
+    /** Sampler under test, with its outcomes and expected probabilities. */
+    private final DiscreteSamplerTestData sampler;
+
+    /**
+     * Stores the data of the sampler to be tested.
+     *
+     * @param data Sampler to be tested, together with the outcomes to
+     * count and their expected probabilities.
+     */
+    public DiscreteSamplerParametricTest(DiscreteSamplerTestData data) {
+        sampler = data;
+    }
+
+    /**
+     * @return the data of all the samplers to be tested (one element per
+     * instance of this test).
+     */
+    @Parameters(name = "{index}: data={0}")
+    public static Iterable<DiscreteSamplerTestData[]> getList() {
+        return DiscreteSamplersList.list();
+    }
+
+    /**
+     * Compares the observed counts of the selected outcomes with the
+     * counts expected from their probabilities.
+     */
+    @Test
+    public void testSampling() {
+        final int sampleSize = 10000;
+
+        final double[] prob = sampler.getProbabilities();
+        final int len = prob.length;
+        final long[] expected = new long[len];
+        for (int i = 0; i < len; i++) {
+            // The cast truncates: a probability smaller than
+            // "1 / sampleSize" yields an expected count of zero.
+            expected[i] = (long) (prob[i] * sampleSize);
+        }
+        check(sampleSize,
+              sampler.getSampler(),
+              sampler.getPoints(),
+              expected);
+    }
+
+    /**
+     * Repeatedly performs a chi-square test of homogeneity of the observed
+     * distribution with the expected distribution.
+     * The check fails when more than 2% of the repetitions reject the
+     * expected distribution.
+     * Sampled values that are not among the given outcomes are discarded.
+     *
+     * @param sampleSize Number of random values to generate.
+     * @param sampler Sampler.
+     * @param points Outcomes.
+     * @param expected Expected counts of the given outcomes.
+     */
+    private void check(long sampleSize,
+                       DiscreteSampler sampler,
+                       int[] points,
+                       long[] expected) {
+        final int numTests = 50;
+        // A repetition fails when its p-value is below this threshold.
+        final double minPValue = 0.001;
+        // Largest acceptable fraction of failed repetitions.
+        final double maxFailureRatio = 0.02;
+
+        final int numBins = points.length;
+        final long[] observed = new long[numBins];
+
+        // For storing the p-values smaller than the threshold.
+        final List<Double> failedStat = new ArrayList<Double>();
+        for (int i = 0; i < numTests; i++) {
+            Arrays.fill(observed, 0);
+            for (long j = 0; j < sampleSize; j++) {
+                final int value = sampler.sample();
+
+                for (int k = 0; k < numBins; k++) {
+                    if (value == points[k]) {
+                        ++observed[k];
+                        break;
+                    }
+                }
+            }
+
+            final double pValue = chiSquarePValue(expected, observed);
+            if (pValue < minPValue) {
+                failedStat.add(pValue);
+            }
+        }
+
+        final int numFailures = failedStat.size();
+        if ((double) numFailures / (double) numTests > maxFailureRatio) {
+            Assert.fail(sampler + ": Too many failures for sample size = " +
+                        sampleSize + " (" + numFailures + " out of " +
+                        numTests + " tests failed, p-values=" +
+                        failedStat + ")");
+        }
+    }
+
+    /**
+     * Computes the p-value of the chi-square statistic of the observed
+     * counts with respect to the expected counts.
+     * Bins where both counts are zero are left out (they would otherwise
+     * contribute "0 / 0" and turn the result into NaN, which no threshold
+     * comparison can detect); the number of degrees of freedom is reduced
+     * accordingly.
+     *
+     * @param expected Counts.
+     * @param observed Counts.
+     * @return the probability of obtaining a chi-square statistic at
+     * least as large as the one computed from the given counts.
+     * It is zero if an outcome whose expected count is zero was observed.
+     * @throws IllegalArgumentException if fewer than two bins have a
+     * non-zero expected count.
+     */
+    private static double chiSquarePValue(long[] expected,
+                                          long[] observed) {
+        double chi2 = 0;
+        int usedBins = 0;
+        for (int i = 0; i < expected.length; i++) {
+            if (expected[i] == 0) {
+                if (observed[i] != 0) {
+                    // Unexpected outcome: the statistic is infinite.
+                    return 0;
+                }
+                // Nothing expected, nothing observed: uninformative bin.
+                continue;
+            }
+            final long diff = observed[i] - expected[i];
+            chi2 += (diff / (double) expected[i]) * diff;
+            ++usedBins;
+        }
+
+        if (usedBins < 2) {
+            throw new IllegalArgumentException("At least two outcomes with" +
+                                               " non-zero expected count" +
+                                               " are required: " + usedBins);
+        }
+
+        final int dof = usedBins - 1;
+        // No RNG is needed: only the cumulative probability is computed.
+        final ChiSquaredDistribution dist =
+            new ChiSquaredDistribution(null, dof, 1e-8);
+
+        return 1 - dist.cumulativeProbability(chi2);
+    }
+}

http://git-wip-us.apache.org/repos/asf/commons-rng/blob/438b67b3/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplerTestData.java
----------------------------------------------------------------------
diff --git 
a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplerTestData.java
 
b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplerTestData.java
new file mode 100644
index 0000000..fd52f29
--- /dev/null
+++ 
b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplerTestData.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.rng.sampling.distribution;
+
+import java.util.Arrays;
+
+import org.apache.commons.rng.sampling.DiscreteSampler;
+
+/**
+ * Data store for {@link DiscreteSamplerParametricTest}: a sampler together
+ * with a selection of outcomes and the probability of each of them.
+ */
+class DiscreteSamplerTestData {
+    /** Sampler. */
+    private final DiscreteSampler sampler;
+    /** Selected outcomes. */
+    private final int[] points;
+    /** Probabilities of the selected outcomes (same order as "points"). */
+    private final double[] probabilities;
+
+    /**
+     * @param sampler Sampler.
+     * @param points Outcomes (the array is copied).
+     * @param probabilities Probability of each of the {@code points}
+     * (the array is copied).
+     * @throws IllegalArgumentException if {@code points} and
+     * {@code probabilities} do not have the same length.
+     */
+    public DiscreteSamplerTestData(DiscreteSampler sampler,
+                                   int[] points,
+                                   double[] probabilities) {
+        // Fail here rather than with an out-of-bounds access later on.
+        if (points.length != probabilities.length) {
+            throw new IllegalArgumentException("Size mismatch: " +
+                                               points.length + " points, " +
+                                               probabilities.length +
+                                               " probabilities");
+        }
+        this.sampler = sampler;
+        this.points = points.clone();
+        this.probabilities = probabilities.clone();
+    }
+
+    /**
+     * @return the sampler.
+     */
+    public DiscreteSampler getSampler() {
+        return sampler;
+    }
+
+    /**
+     * @return a copy of the outcomes.
+     */
+    public int[] getPoints() {
+        return points.clone();
+    }
+
+    /**
+     * @return a copy of the probabilities of the outcomes.
+     */
+    public double[] getProbabilities() {
+        return probabilities.clone();
+    }
+
+    /**
+     * @return the string representation of the sampler followed by the
+     * list of "p(outcome)=probability" pairs.
+     */
+    @Override
+    public String toString() {
+        final int len = points.length;
+        final String[] p = new String[len];
+        for (int i = 0; i < len; i++) {
+            p[i] = "p(" + points[i] + ")=" + probabilities[i];
+        }
+        return sampler.toString() + ": " + Arrays.toString(p);
+    }
+}

http://git-wip-us.apache.org/repos/asf/commons-rng/blob/438b67b3/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java
----------------------------------------------------------------------
diff --git 
a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java
 
b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java
new file mode 100644
index 0000000..7a7425f
--- /dev/null
+++ 
b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.rng.sampling.distribution;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.ArrayList;
+import java.util.Collections;
+
+import org.apache.commons.math3.util.MathArrays;
+
+import org.apache.commons.rng.UniformRandomProvider;
+import org.apache.commons.rng.sampling.DiscreteSampler;
+import org.apache.commons.rng.simple.RandomSource;
+
+/**
+ * Collection of the discrete samplers, with the outcomes and probabilities
+ * used for checking them, that is passed to the parametric test.
+ */
+public class DiscreteSamplersList {
+    /** Test data: one single-element array per sampler. */
+    private static final List<DiscreteSamplerTestData[]> LIST =
+        new ArrayList<DiscreteSamplerTestData[]>();
+
+    static {
+        try {
+            // Each entry gives the reference distribution, the outcomes to
+            // be counted and either a RNG (from which an "inverse method"
+            // sampler is built) or a dedicated sampler.
+
+            // Binomial ("inverse method").
+            final int trialsBinomial = 20;
+            final double probSuccessBinomial = 0.67;
+            add(LIST,
+                new org.apache.commons.math3.distribution.BinomialDistribution(
+                    trialsBinomial, probSuccessBinomial),
+                MathArrays.sequence(8, 9, 1),
+                RandomSource.create(RandomSource.KISS));
+
+            // Geometric ("inverse method").
+            final double probSuccessGeometric = 0.21;
+            add(LIST,
+                new org.apache.commons.math3.distribution.GeometricDistribution(
+                    probSuccessGeometric),
+                MathArrays.sequence(10, 0, 1),
+                RandomSource.create(RandomSource.ISAAC));
+
+            // Hypergeometric ("inverse method").
+            final int popSizeHyper = 34;
+            final int numSuccessesHyper = 11;
+            final int sampleSizeHyper = 12;
+            add(LIST,
+                new org.apache.commons.math3.distribution.HypergeometricDistribution(
+                    popSizeHyper, numSuccessesHyper, sampleSizeHyper),
+                MathArrays.sequence(10, 0, 1),
+                RandomSource.create(RandomSource.MT));
+
+            // Pascal ("inverse method").
+            final int numSuccessesPascal = 6;
+            final double probSuccessPascal = 0.2;
+            add(LIST,
+                new org.apache.commons.math3.distribution.PascalDistribution(
+                    numSuccessesPascal, probSuccessPascal),
+                MathArrays.sequence(18, 1, 1),
+                RandomSource.create(RandomSource.TWO_CMRES));
+
+            // Uniform ("inverse method").
+            final int loUniform = -3;
+            final int hiUniform = 4;
+            add(LIST,
+                new org.apache.commons.math3.distribution.UniformIntegerDistribution(
+                    loUniform, hiUniform),
+                MathArrays.sequence(10, -4, 1),
+                RandomSource.create(RandomSource.SPLIT_MIX_64));
+            // Uniform.
+            add(LIST,
+                new org.apache.commons.math3.distribution.UniformIntegerDistribution(
+                    loUniform, hiUniform),
+                MathArrays.sequence(10, -4, 1),
+                new DiscreteUniformSampler(RandomSource.create(RandomSource.MT_64),
+                                           loUniform,
+                                           hiUniform));
+
+            // Zipf ("inverse method").
+            final int numElementsZipf = 5;
+            final double exponentZipf = 2.345;
+            add(LIST,
+                new org.apache.commons.math3.distribution.ZipfDistribution(
+                    numElementsZipf, exponentZipf),
+                MathArrays.sequence(5, 0, 1),
+                RandomSource.create(RandomSource.XOR_SHIFT_1024_S));
+            // Zipf.
+            add(LIST,
+                new org.apache.commons.math3.distribution.ZipfDistribution(
+                    numElementsZipf, exponentZipf),
+                MathArrays.sequence(7, 0, 1),
+                new RejectionInversionZipfSampler(
+                    RandomSource.create(RandomSource.WELL_19937_C),
+                    numElementsZipf,
+                    exponentZipf));
+
+            // Poisson ("inverse method").
+            final double meanPoisson = 3.21;
+            add(LIST,
+                new org.apache.commons.math3.distribution.PoissonDistribution(
+                    meanPoisson),
+                MathArrays.sequence(10, 0, 1),
+                RandomSource.create(RandomSource.MWC_256));
+            // Poisson.
+            add(LIST,
+                new org.apache.commons.math3.distribution.PoissonDistribution(
+                    meanPoisson),
+                MathArrays.sequence(10, 0, 1),
+                new PoissonSampler(RandomSource.create(RandomSource.KISS),
+                                   meanPoisson));
+            // Poisson (mean > 40).
+            final double largeMeanPoisson = 543.21;
+            add(LIST,
+                new org.apache.commons.math3.distribution.PoissonDistribution(
+                    largeMeanPoisson),
+                MathArrays.sequence(100, (int) (largeMeanPoisson - 50), 1),
+                new PoissonSampler(RandomSource.create(RandomSource.SPLIT_MIX_64),
+                                   largeMeanPoisson));
+        } catch (Exception e) {
+            System.err.println("Unexpected exception while creating the list of samplers: " + e);
+            e.printStackTrace(System.err);
+            throw new RuntimeException(e);
+        }
+    }
+
+    /**
+     * Class contains only static methods.
+     */
+    private DiscreteSamplersList() {}
+
+    /**
+     * Registers a sampler that applies the "inverse method" to the inverse
+     * cumulative probability function of the given distribution.
+     *
+     * @param list List of data (one of the "parameters" tested by the
+     * Junit parametric test).
+     * @param dist Distribution to which the samples are supposed to conform.
+     * @param points Outcomes selection.
+     * @param rng Generator of uniformly distributed sequences.
+     */
+    private static void add(List<DiscreteSamplerTestData[]> list,
+                            final org.apache.commons.math3.distribution.IntegerDistribution dist,
+                            int[] points,
+                            UniformRandomProvider rng) {
+        final DiscreteInverseCumulativeProbabilityFunction inverseCdf =
+            new DiscreteInverseCumulativeProbabilityFunction() {
+                @Override
+                public int inverseCumulativeProbability(double p) {
+                    return dist.inverseCumulativeProbability(p);
+                }
+            };
+        final DiscreteSampler inverseMethodSampler =
+            new InverseMethodDiscreteSampler(rng, inverseCdf);
+
+        add(list, dist, points, inverseMethodSampler);
+    }
+
+    /**
+     * Registers the given sampler.
+     *
+     * @param list List of data (one of the "parameters" tested by the
+     * Junit parametric test).
+     * @param dist Distribution to which the samples are supposed to conform.
+     * @param points Outcomes selection.
+     * @param sampler Sampler.
+     */
+    private static void add(List<DiscreteSamplerTestData[]> list,
+                            final org.apache.commons.math3.distribution.IntegerDistribution dist,
+                            int[] points,
+                            final DiscreteSampler sampler) {
+        final DiscreteSamplerTestData data =
+            new DiscreteSamplerTestData(sampler,
+                                        points,
+                                        getProbabilities(dist, points));
+        list.add(new DiscreteSamplerTestData[] { data });
+    }
+
+    /**
+     * Subclasses that are "parametric" tests can forward the call to
+     * the "@Parameters"-annotated method to this method.
+     *
+     * @return an unmodifiable view of the list of all samplers' test data.
+     */
+    public static Iterable<DiscreteSamplerTestData[]> list() {
+        return Collections.unmodifiableList(LIST);
+    }
+
+    /**
+     * @param dist Distribution.
+     * @param points Points.
+     * @return the probabilities of the given points according to the
+     * distribution.
+     */
+    private static double[] getProbabilities(org.apache.commons.math3.distribution.IntegerDistribution dist,
+                                             int[] points) {
+        final double[] result = new double[points.length];
+        int index = 0;
+        for (int outcome : points) {
+            result[index++] = dist.probability(outcome);
+        }
+        return result;
+    }
+}

Reply via email to