This is an automated email from the ASF dual-hosted git repository. aherbert pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/commons-rng.git
commit 2070821b022cee2e79954953c1400166dc2c6b0b Author: Alex Herbert <[email protected]> AuthorDate: Sun Jun 16 21:55:18 2019 +0100 RNG-100: Add a GuideTableDiscreteSampler. This can sample from any distribution defined by an array of probabilities. --- .../distribution/DiscreteSamplersPerformance.java | 18 +- .../distribution/GuideTableDiscreteSampler.java | 201 +++++++++++++++++ .../rng/sampling/distribution/InternalUtils.java | 17 +- .../MarsagliaTsangWangDiscreteSampler.java | 7 +- .../distribution/DiscreteSamplersList.java | 2 + .../GuideTableDiscreteSamplerTest.java | 237 +++++++++++++++++++++ 6 files changed, 473 insertions(+), 9 deletions(-) diff --git a/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/distribution/DiscreteSamplersPerformance.java b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/distribution/DiscreteSamplersPerformance.java index 0a72b23..641b652 100644 --- a/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/distribution/DiscreteSamplersPerformance.java +++ b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/distribution/DiscreteSamplersPerformance.java @@ -22,6 +22,7 @@ import org.apache.commons.rng.examples.jmh.RandomSources; import org.apache.commons.rng.sampling.distribution.DiscreteSampler; import org.apache.commons.rng.sampling.distribution.DiscreteUniformSampler; import org.apache.commons.rng.sampling.distribution.GeometricSampler; +import org.apache.commons.rng.sampling.distribution.GuideTableDiscreteSampler; import org.apache.commons.rng.sampling.distribution.LargeMeanPoissonSampler; import org.apache.commons.rng.sampling.distribution.MarsagliaTsangWangDiscreteSampler; import org.apache.commons.rng.sampling.distribution.RejectionInversionZipfSampler; @@ -59,6 +60,17 @@ public class DiscreteSamplersPerformance { */ @State(Scope.Benchmark) public static class Sources extends RandomSources { + /** 
The probabilities for the discrete distribution. */ + private static final double[] DISCRETE_PROBABILITIES; + + static { + // This is not normalised to sum to 1. The samplers should handle this. + DISCRETE_PROBABILITIES = new double[25]; + for (int i = 0; i < DISCRETE_PROBABILITIES.length; i++) { + DISCRETE_PROBABILITIES[i] = (i + 1.0) / DISCRETE_PROBABILITIES.length; + } + } + /** * The sampler type. */ @@ -70,6 +82,7 @@ public class DiscreteSamplersPerformance { "MarsagliaTsangWangDiscreteSampler", "MarsagliaTsangWangPoissonSampler", "MarsagliaTsangWangBinomialSampler", + "GuideTableDiscreteSampler", }) private String samplerType; @@ -101,12 +114,13 @@ public class DiscreteSamplersPerformance { } else if ("GeometricSampler".equals(samplerType)) { sampler = new GeometricSampler(rng, 0.21); } else if ("MarsagliaTsangWangDiscreteSampler".equals(samplerType)) { - sampler = MarsagliaTsangWangDiscreteSampler.createDiscreteDistribution(rng, - new double[] {0.1, 0.2, 0.3, 0.4}); + sampler = MarsagliaTsangWangDiscreteSampler.createDiscreteDistribution(rng, DISCRETE_PROBABILITIES); } else if ("MarsagliaTsangWangPoissonSampler".equals(samplerType)) { sampler = MarsagliaTsangWangDiscreteSampler.createPoissonDistribution(rng, 8.9); } else if ("MarsagliaTsangWangBinomialSampler".equals(samplerType)) { sampler = MarsagliaTsangWangDiscreteSampler.createBinomialDistribution(rng, 20, 0.33); + } else if ("GuideTableDiscreteSampler".equals(samplerType)) { + sampler = new GuideTableDiscreteSampler(rng, DISCRETE_PROBABILITIES); } } } diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GuideTableDiscreteSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GuideTableDiscreteSampler.java new file mode 100644 index 0000000..beb7a5f --- /dev/null +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GuideTableDiscreteSampler.java @@ -0,0 +1,201 @@ +/* + * Licensed to the 
Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.rng.sampling.distribution; + +import org.apache.commons.rng.UniformRandomProvider; + +/** + * Compute a sample from a discrete probability distribution. The cumulative probability + * distribution is searched using a guide table to set an initial start point. This implementation + * is based on: + * + * <ul> + * <li> + * <blockquote> + * Devroye, Luc (1986). Non-Uniform Random Variate Generation. + * New York: Springer-Verlag. Chapter 3.2.4 "The method of guide tables" p. 96. + * </blockquote> + * </li> + * </ul> + * + * <p>The size of the guide table can be controlled using a parameter. A larger guide table + * will improve performance at the cost of storage space.</p> + * + * <p>Sampling uses {@link UniformRandomProvider#nextDouble()}.</p> + * + * @since 1.3 + */ +public class GuideTableDiscreteSampler + implements DiscreteSampler { + /** The default value for {@code alpha}. */ + private static final double DEFAULT_ALPHA = 1.0; + /** Underlying source of randomness. */ + private final UniformRandomProvider rng; + /** + * The cumulative probability table ({@code f(x)}). 
+ */ + private final double[] cumulativeProbabilities; + /** + * The inverse cumulative probability guide table. This is a guide map between the cumulative + * probability (f(x)) and the value x. It is used to set the initial point for search + * of the cumulative probability table. + * + * <p>The index in the map is obtained using {@code p * (map.length - 1)} where {@code p} is the + * known cumulative probability {@code f(x)} or a uniform random deviate {@code u}. The value + * stored at the index is the value {@code x+1} when {@code p = f(x)} such that it is the + * exclusive upper bound on the sample value {@code x} for searching the cumulative probability + * table {@code f(x)}. The search of the cumulative probability is towards zero.</p> + */ + private final int[] guideTable; + + /** + * Create a new instance using the default guide table size. + * + * @param rng Generator of uniformly distributed random numbers. + * @param probabilities The probabilities. + * @throws IllegalArgumentException if {@code probabilities} is null or empty, a + * probability is negative, infinite or {@code NaN}, or the sum of all + * probabilities is not strictly positive. + */ + public GuideTableDiscreteSampler(UniformRandomProvider rng, + double[] probabilities) { + this(rng, probabilities, DEFAULT_ALPHA); + } + + /** + * Create a new instance. + * + * <p>The size of the guide table is approximately {@code alpha * probabilities.length}. + * + * @param rng Generator of uniformly distributed random numbers. + * @param probabilities The probabilities. + * @param alpha The alpha factor used to set the guide table size. + * @throws IllegalArgumentException if {@code probabilities} is null or empty, a + * probability is negative, infinite or {@code NaN}, the sum of all + * probabilities is not strictly positive, or {@code alpha} is not strictly positive. 
+ */ + public GuideTableDiscreteSampler(UniformRandomProvider rng, + double[] probabilities, + double alpha) { + validateParameters(probabilities, alpha); + + final int size = probabilities.length; + cumulativeProbabilities = new double[size]; + + double sumProb = 0; + int count = 0; + for (final double prob : probabilities) { + InternalUtils.validateProbability(prob); + + // Compute and store cumulative probability. + sumProb += prob; + cumulativeProbabilities[count++] = sumProb; + } + + if (Double.isInfinite(sumProb) || sumProb <= 0) { + throw new IllegalArgumentException("Invalid sum of probabilities: " + sumProb); + } + + this.rng = rng; + + // Note: The guide table is at least length 1. Compute the size avoiding overflow + // in case (alpha * size) is too large. + final int guideTableSize = (int) Math.ceil(alpha * size); + guideTable = new int[Math.max(guideTableSize, guideTableSize + 1)]; + + // Compute and store cumulative probability. + for (int x = 0; x < size; x++) { + final double norm = cumulativeProbabilities[x] / sumProb; + cumulativeProbabilities[x] = (norm < 1) ? norm : 1.0; + + // Set the guide table value as an exclusive upper bound (x + 1) + guideTable[getGuideTableIndex(cumulativeProbabilities[x])] = x + 1; + } + + // Edge case for round-off + cumulativeProbabilities[size - 1] = 1.0; + // The final guide table entry is (maximum value of x + 1) + guideTable[guideTable.length - 1] = size; + + // The first non-zero value in the guide table is from f(x=0). + // Any probabilities mapped below this must be sample x=0 so the + // table may initially be filled with zeros. + + // Fill missing values in the guide table. + for (int i = 1; i < guideTable.length; i++) { + guideTable[i] = Math.max(guideTable[i - 1], guideTable[i]); + } + } + + /** + * Validate the parameters. + * + * @param probabilities The probabilities. + * @param alpha The alpha factor used to set the guide table size. 
+ * @throws IllegalArgumentException if {@code probabilities} is null or empty, or + * {@code alpha} is not strictly positive. + */ + private static void validateParameters(double[] probabilities, double alpha) { + if (probabilities == null || probabilities.length == 0) { + throw new IllegalArgumentException("Probabilities must not be empty."); + } + if (alpha <= 0) { + throw new IllegalArgumentException("Alpha must be strictly positive."); + } + } + + /** + * Gets the guide table index for the probability. This is obtained using + * {@code p * (guideTable.length - 1)} so is inside the length of the table. + * + * @param p Cumulative probability. + * @return the guide table index. + */ + private int getGuideTableIndex(double p) { + // Note: This is only ever called when p is in the range of the cumulative + // probability table. So assume 0 <= p <= 1. + return (int) (p * (guideTable.length - 1)); + } + + /** {@inheritDoc} */ + @Override + public int sample() { + // Compute a probability + final double u = rng.nextDouble(); + + // Initialise the search using the guide table to find an initial guess. + // The table provides an upper bound on the sample (x+1) for a known + // cumulative probability (f(x)). + int x = guideTable[getGuideTableIndex(u)]; + // Search down. + // In the edge case where u is 1.0 then 'x' will be 1 outside the range of the + // cumulative probability table and this will decrement to a valid range. + // In the case where 'u' is mapped to the same guide table index as a lower + // cumulative probability f(x) (due to rounding down) then this will not decrement + // and return the exclusive upper bound (x+1). 
+ while (x != 0 && u <= cumulativeProbabilities[x - 1]) { + x--; + } + return x; + } + + /** {@inheritDoc} */ + @Override + public String toString() { + return "Guide table deviate [" + rng.toString() + "]"; + } +} diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InternalUtils.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InternalUtils.java index 8d8e010..73d6f16 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InternalUtils.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/InternalUtils.java @@ -20,7 +20,7 @@ package org.apache.commons.rng.sampling.distribution; /** * Functions used by some of the samplers. * This class is not part of the public API, as it would be - * better to group these utilities in a dedicated components. + * better to group these utilities in a dedicated component. */ final class InternalUtils { // Class is package-private on purpose; do not make it public. /** All long-representable factorials. */ @@ -50,6 +50,21 @@ final class InternalUtils { // Class is package-private on purpose; do not make } /** + * Validate that the probability is a finite non-negative number. + * + * @param probability Probability. + * @throws IllegalArgumentException if {@code probability} is negative, infinite or {@code NaN}. + */ + public static void validateProbability(double probability) { + if (probability < 0 || + Double.isInfinite(probability) || + Double.isNaN(probability)) { + throw new IllegalArgumentException("Invalid probability: " + + probability); + } + } + + /** * Class for computing the natural logarithm of the factorial of {@code n}. * It allows to allocate a cache of precomputed values. 
* In case of cache miss, computation is performed by a call to diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangDiscreteSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangDiscreteSampler.java index 59e5618..e8e4685 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangDiscreteSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangDiscreteSampler.java @@ -669,12 +669,7 @@ public abstract class MarsagliaTsangWangDiscreteSampler implements DiscreteSampl double sumProb = 0; for (final double prob : probabilities) { - if (prob < 0 || - Double.isInfinite(prob) || - Double.isNaN(prob)) { - throw new IllegalArgumentException("Invalid probability: " + - prob); - } + InternalUtils.validateProbability(prob); sumProb += prob; } diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java index b021b50..bf8e2fd 100644 --- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java @@ -169,6 +169,8 @@ public final class DiscreteSamplersList { final double[] discreteProbabilities = new double[] {0.1, 0.2, 0.3, 0.4, 0.5}; add(LIST, discreteProbabilities, MarsagliaTsangWangDiscreteSampler.createDiscreteDistribution(RandomSource.create(RandomSource.XO_SHI_RO_512_PLUS), discreteProbabilities)); + add(LIST, discreteProbabilities, + new GuideTableDiscreteSampler(RandomSource.create(RandomSource.XO_SHI_RO_512_SS), discreteProbabilities)); } catch (Exception e) { // CHECKSTYLE: stop Regexp System.err.println("Unexpected exception 
while creating the list of samplers: " + e); diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/GuideTableDiscreteSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/GuideTableDiscreteSamplerTest.java new file mode 100644 index 0000000..f312c93 --- /dev/null +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/GuideTableDiscreteSamplerTest.java @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.rng.sampling.distribution; + +import org.apache.commons.math3.distribution.BinomialDistribution; +import org.apache.commons.math3.distribution.PoissonDistribution; +import org.apache.commons.math3.stat.inference.ChiSquareTest; +import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.simple.RandomSource; +import org.junit.Assert; +import org.junit.Test; + +/** + * Test for the {@link GuideTableDiscreteSampler}. 
+ */ +public class GuideTableDiscreteSamplerTest { + @Test(expected = IllegalArgumentException.class) + public void testConstructorThrowsWithNullProbabilites() { + createSampler(null, 1.0); + } + + @Test(expected = IllegalArgumentException.class) + public void testConstructorThrowsWithZeroLengthProbabilites() { + createSampler(new double[0], 1.0); + } + + @Test(expected = IllegalArgumentException.class) + public void testConstructorThrowsWithNegativeProbabilites() { + createSampler(new double[] {-1, 0.1, 0.2}, 1.0); + } + + @Test(expected = IllegalArgumentException.class) + public void testConstructorThrowsWithNaNProbabilites() { + createSampler(new double[] {0.1, Double.NaN, 0.2}, 1.0); + } + + @Test(expected = IllegalArgumentException.class) + public void testConstructorThrowsWithInfiniteProbabilites() { + createSampler(new double[] {0.1, Double.POSITIVE_INFINITY, 0.2}, 1.0); + } + + @Test(expected = IllegalArgumentException.class) + public void testConstructorThrowsWithInfiniteSumProbabilites() { + createSampler(new double[] {Double.MAX_VALUE, Double.MAX_VALUE}, 1.0); + } + + @Test(expected = IllegalArgumentException.class) + public void testConstructorThrowsWithZeroSumProbabilites() { + createSampler(new double[4], 1.0); + } + + @Test(expected = IllegalArgumentException.class) + public void testConstructorThrowsWithZeroAlpha() { + createSampler(new double[] {0.5, 0.5}, 0.0); + } + + @Test(expected = IllegalArgumentException.class) + public void testConstructorThrowsWithNegativeAlpha() { + createSampler(new double[] {0.5, 0.5}, -1.0); + } + + @Test + public void testToString() { + final GuideTableDiscreteSampler sampler = createSampler(new double[] {0.5, 0.5}, 1.0); + Assert.assertTrue(sampler.toString().toLowerCase().contains("guide table")); + } + + /** + * Creates the sampler. 
+ * + * @param probabilities the probabilities + * @return the guide table discrete sampler + */ + private static GuideTableDiscreteSampler createSampler(double[] probabilities, double alpha) { + final UniformRandomProvider rng = RandomSource.create(RandomSource.SPLIT_MIX_64); + return new GuideTableDiscreteSampler(rng, probabilities, alpha); + } + + /** + * Test sampling from a binomial distribution. + */ + @Test + public void testBinomialSamples() { + final int trials = 67; + final double probabilityOfSuccess = 0.345; + final BinomialDistribution dist = new BinomialDistribution(null, trials, probabilityOfSuccess); + final double[] expected = new double[trials + 1]; + for (int i = 0; i < expected.length; i++) { + expected[i] = dist.probability(i); + } + checkSamples(expected, 1.0); + } + + /** + * Test sampling from a Poisson distribution. + */ + @Test + public void testPoissonSamples() { + final double mean = 3.14; + final PoissonDistribution dist = new PoissonDistribution(null, mean, + PoissonDistribution.DEFAULT_EPSILON, PoissonDistribution.DEFAULT_MAX_ITERATIONS); + final int maxN = dist.inverseCumulativeProbability(1 - 1e-6); + final double[] expected = new double[maxN]; + for (int i = 0; i < expected.length; i++) { + expected[i] = dist.probability(i); + } + checkSamples(expected, 1.0); + } + + /** + * Test sampling from a non-uniform distribution of probabilities (these sum to 1). + */ + @Test + public void testNonUniformSamplesWithProbabilities() { + final double[] expected = {0.1, 0.2, 0.3, 0.1, 0.3}; + checkSamples(expected, 1.0); + } + + /** + * Test sampling from a non-uniform distribution of probabilities with an alpha smaller than + * the default. + */ + @Test + public void testNonUniformSamplesWithProbabilitiesWithSmallAlpha() { + final double[] expected = {0.1, 0.2, 0.3, 0.1, 0.3}; + checkSamples(expected, 0.1); + } + + /** + * Test sampling from a non-uniform distribution of probabilities with an alpha larger than + * the default. 
+ */ + @Test + public void testNonUniformSamplesWithProbabilitiesWithLargeAlpha() { + final double[] expected = {0.1, 0.2, 0.3, 0.1, 0.3}; + checkSamples(expected, 10.0); + } + + /** + * Test sampling from a non-uniform distribution of observations (i.e. the sum is not 1 as per + * probabilities). + */ + @Test + public void testNonUniformSamplesWithObservations() { + final double[] expected = {1, 2, 3, 1, 3}; + checkSamples(expected, 1.0); + } + + /** + * Test sampling from a non-uniform distribution of probabilities (these sum to 1). + * Extra zero-values are added. + */ + @Test + public void testNonUniformSamplesWithZeroProbabilities() { + final double[] expected = {0.1, 0, 0.2, 0.3, 0.1, 0.3, 0}; + checkSamples(expected, 1.0); + } + + /** + * Test sampling from a non-uniform distribution of observations (i.e. the sum is not 1 as per + * probabilities). Extra zero-values are added. + */ + @Test + public void testNonUniformSamplesWithZeroObservations() { + final double[] expected = {1, 2, 3, 0, 1, 3, 0}; + checkSamples(expected, 1.0); + } + + /** + * Test sampling from a uniform distribution. This is an edge case where there + * are no probabilities less than the mean. + */ + @Test + public void testUniformSamplesWithNoObservationLessThanTheMean() { + final double[] expected = {2, 2, 2, 2, 2, 2}; + checkSamples(expected, 1.0); + } + + /** + * Check the distribution of samples match the expected probabilities. + * + * <p>If the expected probability is zero then this should never be sampled. 
The non-zero + * probabilities are compared to the sample distribution using a Chi-square test.</p> + * + * @param probabilies the probabilities + * @param alpha the alpha factor used to set the guide table size + */ + private static void checkSamples(double[] probabilies, double alpha) { + final GuideTableDiscreteSampler sampler = createSampler(probabilies, alpha); + + final int numberOfSamples = 10000; + final long[] samples = new long[probabilies.length]; + for (int i = 0; i < numberOfSamples; i++) { + samples[sampler.sample()]++; + } + + // Handle a test with some zero-probability observations by mapping them out. + // The result is that the Chi-square test is performed using only the non-zero probabilities. + int mapSize = 0; + for (int i = 0; i < probabilies.length; i++) { + if (probabilies[i] != 0) { + mapSize++; + } + } + + final double[] expected = new double[mapSize]; + final long[] observed = new long[mapSize]; + for (int i = 0; i < probabilies.length; i++) { + if (probabilies[i] == 0) { + Assert.assertEquals("No samples expected from zero probability", 0, samples[i]); + } else { + // This can be added for the Chi-square test + --mapSize; + expected[mapSize] = probabilies[i]; + observed[mapSize] = samples[i]; + } + } + + final ChiSquareTest chiSquareTest = new ChiSquareTest(); + // Pass if we cannot reject the null hypothesis that the distributions are the same. + Assert.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001)); + } +}
