This is an automated email from the ASF dual-hosted git repository. aherbert pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/commons-rng.git
commit 9f96ca27d0611c213c0cd68cb4e8126872d12c84 Author: aherbert <[email protected]> AuthorDate: Mon May 23 13:18:45 2022 +0100 RNG-179: Add fast loaded dice roller discrete sampler --- .../EnumeratedDistributionSamplersPerformance.java | 182 ++++- .../FastLoadedDiceRollerDiscreteSampler.java | 856 +++++++++++++++++++++ .../distribution/DiscreteSamplersList.java | 5 + .../FastLoadedDiceRollerDiscreteSamplerTest.java | 452 +++++++++++ src/main/resources/pmd/pmd-ruleset.xml | 16 +- 5 files changed, 1501 insertions(+), 10 deletions(-) diff --git a/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/distribution/EnumeratedDistributionSamplersPerformance.java b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/distribution/EnumeratedDistributionSamplersPerformance.java index bc55377c..b91969ce 100644 --- a/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/distribution/EnumeratedDistributionSamplersPerformance.java +++ b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/distribution/EnumeratedDistributionSamplersPerformance.java @@ -22,7 +22,9 @@ import org.apache.commons.math3.distribution.IntegerDistribution; import org.apache.commons.math3.distribution.PoissonDistribution; import org.apache.commons.rng.UniformRandomProvider; import org.apache.commons.rng.sampling.distribution.AliasMethodDiscreteSampler; +import org.apache.commons.rng.sampling.distribution.DirichletSampler; import org.apache.commons.rng.sampling.distribution.DiscreteSampler; +import org.apache.commons.rng.sampling.distribution.FastLoadedDiceRollerDiscreteSampler; import org.apache.commons.rng.sampling.distribution.GuideTableDiscreteSampler; import org.apache.commons.rng.sampling.distribution.MarsagliaTsangWangDiscreteSampler; import org.apache.commons.rng.simple.RandomSource; @@ -41,7 +43,6 @@ import org.openjdk.jmh.annotations.State; 
import org.openjdk.jmh.annotations.Warmup; import java.util.Arrays; -import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; import java.util.function.Supplier; @@ -115,6 +116,9 @@ public class EnumeratedDistributionSamplersPerformance { "AliasMethodDiscreteSampler", "GuideTableDiscreteSampler", "MarsagliaTsangWangDiscreteSampler", + "FastLoadedDiceRollerDiscreteSampler", + "FastLoadedDiceRollerDiscreteSamplerLong", + "FastLoadedDiceRollerDiscreteSampler53", // Uncomment to test non-default parameters //"AliasMethodDiscreteSamplerNoPad", // Not optimal for sampling @@ -187,6 +191,19 @@ public class EnumeratedDistributionSamplersPerformance { factory = () -> GuideTableDiscreteSampler.of(rng, probabilities, 8); } else if ("MarsagliaTsangWangDiscreteSampler".equals(samplerType)) { factory = () -> MarsagliaTsangWangDiscreteSampler.Enumerated.of(rng, probabilities); + } else if ("FastLoadedDiceRollerDiscreteSampler".equals(samplerType)) { + factory = () -> FastLoadedDiceRollerDiscreteSampler.of(rng, probabilities); + } else if ("FastLoadedDiceRollerDiscreteSamplerLong".equals(samplerType)) { + // Avoid exact floating-point arithmetic in construction. + // Frequencies must sum to less than 2^63; here the sum is ~2^62. + // This conversion may omit very small probabilities. 
+ final double sum = Arrays.stream(probabilities).sum(); + final long[] frequencies = Arrays.stream(probabilities) + .mapToLong(x -> Math.round(0x1.0p62 * x / sum)) + .toArray(); + factory = () -> FastLoadedDiceRollerDiscreteSampler.of(rng, frequencies); + } else if ("FastLoadedDiceRollerDiscreteSampler53".equals(samplerType)) { + factory = () -> FastLoadedDiceRollerDiscreteSampler.of(rng, probabilities, 53); } else { throw new IllegalStateException(); } @@ -335,12 +352,115 @@ public class EnumeratedDistributionSamplersPerformance { /** {@inheritDoc} */ @Override protected double[] createProbabilities() { - final double[] probabilities = new double[randomNonUniformSize]; - final ThreadLocalRandom rng = ThreadLocalRandom.current(); - for (int i = 0; i < probabilities.length; i++) { - probabilities[i] = rng.nextDouble(); - } - return probabilities; + return RandomSource.XO_RO_SHI_RO_128_PP.create() + .doubles(randomNonUniformSize).toArray(); + } + } + + /** + * Sample random probability arrays from a Dirichlet distribution. + * + * <p>The distribution ensures the probabilities sum to 1. + * The <a href="https://en.wikipedia.org/wiki/Entropy_(information_theory)">entropy</a> + * of the probabilities increases with parameters k and alpha. + * The following shows the mean and sd of the entropy from 100 samples + * for a range of parameters. + * <pre> + * k alpha mean sd + * 4 0.500 1.299 0.374 + * 4 1.000 1.531 0.294 + * 4 2.000 1.754 0.172 + * 8 0.500 2.087 0.348 + * 8 1.000 2.490 0.266 + * 8 2.000 2.707 0.142 + * 16 0.500 3.023 0.287 + * 16 1.000 3.454 0.166 + * 16 2.000 3.693 0.095 + * 32 0.500 4.008 0.182 + * 32 1.000 4.406 0.125 + * 32 2.000 4.692 0.075 + * 64 0.500 4.986 0.151 + * 64 1.000 5.392 0.115 + * 64 2.000 5.680 0.048 + * </pre> + */ + @State(Scope.Benchmark) + public static class DirichletDistributionSources extends SamplerSources { + /** Number of categories. */ + @Param({"4", "8", "16"}) + private int k; + + /** Concentration parameter. 
*/ + @Param({"0.5", "1", "2"}) + private double alpha; + + /** {@inheritDoc} */ + @Override + protected double[] createProbabilities() { + return DirichletSampler.symmetric(RandomSource.XO_RO_SHI_RO_128_PP.create(), + k, alpha).sample(); + } + } + + /** + * The {@link FastLoadedDiceRollerDiscreteSampler} samplers to use for testing. + * Creates the sampler for each random source and the probabilities using + * a Dirichlet distribution. + * + * <p>This class is a specialized source to allow examination of the effect of the + * {@link FastLoadedDiceRollerDiscreteSampler} {@code alpha} parameter. + */ + @State(Scope.Benchmark) + public static class FastLoadedDiceRollerDiscreteSamplerSources extends LocalRandomSources { + /** Number of categories. */ + @Param({"4", "8", "16"}) + private int k; + + /** Concentration parameter. */ + @Param({"0.5", "1", "2"}) + private double concentration; + + /** The constructor {@code alpha} parameter. */ + @Param({"0", "30", "53"}) + private int alpha; + + /** The factory. */ + private Supplier<DiscreteSampler> factory; + + /** The sampler. */ + private DiscreteSampler sampler; + + /** + * Gets the sampler. + * + * @return the sampler. + */ + public DiscreteSampler getSampler() { + return sampler; + } + + /** Create the distribution probabilities (per iteration as it may vary), the sampler + * factory and instantiates sampler. */ + @Override + @Setup(Level.Iteration) + public void setup() { + super.setup(); + + final double[] probabilities = + DirichletSampler.symmetric(RandomSource.XO_RO_SHI_RO_128_PP.create(), + k, concentration).sample(); + final UniformRandomProvider rng = getGenerator(); + factory = () -> FastLoadedDiceRollerDiscreteSampler.of(rng, probabilities, alpha); + sampler = factory.get(); + } + + /** + * Creates a new instance of the sampler. + * + * @return The sampler. 
+ */ + public DiscreteSampler createSampler() { + return factory.get(); } } @@ -480,7 +600,7 @@ public class EnumeratedDistributionSamplersPerformance { } /** - * Run the sampler. + * Create and run the sampler. * * @param sources Source of randomness. * @return the sample value @@ -502,7 +622,7 @@ public class EnumeratedDistributionSamplersPerformance { } /** - * Run the sampler. + * Create and run the sampler. * * @param sources Source of randomness. * @return the sample value @@ -511,4 +631,48 @@ public class EnumeratedDistributionSamplersPerformance { public int singleSampleRandom(RandomDistributionSources sources) { return sources.createSampler().sample(); } + + /** + * Run the sampler. + * + * @param sources Source of randomness. + * @return the sample value + */ + @Benchmark + public int sampleDirichlet(DirichletDistributionSources sources) { + return sources.getSampler().sample(); + } + + /** + * Create and run the sampler. + * + * @param sources Source of randomness. + * @return the sample value + */ + @Benchmark + public int singleSampleDirichlet(DirichletDistributionSources sources) { + return sources.createSampler().sample(); + } + + /** + * Run the sampler. + * + * @param sources Source of randomness. + * @return the sample value + */ + @Benchmark + public int sampleFast(FastLoadedDiceRollerDiscreteSamplerSources sources) { + return sources.getSampler().sample(); + } + + /** + * Create and run the sampler. + * + * @param sources Source of randomness. 
+ * @return the sample value + */ + @Benchmark + public int singleSampleFast(FastLoadedDiceRollerDiscreteSamplerSources sources) { + return sources.createSampler().sample(); + } } diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/FastLoadedDiceRollerDiscreteSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/FastLoadedDiceRollerDiscreteSampler.java new file mode 100644 index 00000000..625a73c1 --- /dev/null +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/FastLoadedDiceRollerDiscreteSampler.java @@ -0,0 +1,856 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.rng.sampling.distribution; + +import java.math.BigInteger; +import java.util.Arrays; +import org.apache.commons.rng.UniformRandomProvider; + +/** + * Distribution sampler that uses the Fast Loaded Dice Roller (FLDR). It can be used to + * sample from {@code n} values each with an associated relative weight. If all unique items + * are assigned the same weight it is more efficient to use the {@link DiscreteUniformSampler}. 
+ * + * <p>Given a list {@code L} of {@code n} positive numbers, + * where {@code L[i]} represents the relative weight of the {@code i}th side, FLDR returns + * integer {@code i} with relative probability {@code L[i]}. + * + * <p>FLDR produces <em>exact</em> samples from the specified probability distribution. + * <ul> + * <li>For integer weights, the probability of returning {@code i} is precisely equal to the + * rational number {@code L[i] / m}, where {@code m} is the sum of {@code L}. + * <li>For floating-point weights, each weight {@code L[i]} is converted to the + * corresponding rational number {@code p[i] / q[i]} where {@code p[i]} is a positive integer and + * {@code q[i]} is a power of 2. The rational weights are then normalized (exactly) to sum to unity. + * </ul> + * + * <p>Note that if <em>exact</em> samples are not required then an alternative sampler that + * ignores very small relative weights may have improved sampling performance. + * + * <p>This implementation is based on the algorithm in: + * + * <blockquote> + * Feras A. Saad, Cameron E. Freer, Martin C. Rinard, and Vikash K. Mansinghka. + * The Fast Loaded Dice Roller: A Near-Optimal Exact Sampler for Discrete Probability + * Distributions. In AISTATS 2020: Proceedings of the 23rd International Conference on + * Artificial Intelligence and Statistics, Proceedings of Machine Learning Research 108, + * Palermo, Sicily, Italy, 2020. + * </blockquote> + * + * <p>Sampling uses {@link UniformRandomProvider#nextInt()} as the source of random bits. + * + * @see <a href="https://arxiv.org/abs/2003.03830">Saad et al (2020) + * Proceedings of the 23rd International Conference on Artificial Intelligence and Statistics, + * PMLR 108:1036-1046.</a> + * @since 1.5 + */ +public abstract class FastLoadedDiceRollerDiscreteSampler + implements SharedStateDiscreteSampler { + /** + * The maximum size of an array. + * + * <p>This value is taken from the limit in Open JDK 8 {@code java.util.ArrayList}. 
+ * It allows VMs to reserve some header words in an array. + */ + private static final int MAX_ARRAY_SIZE = Integer.MAX_VALUE - 8; + /** The maximum biased exponent for a finite double. + * This is offset by 1023 from {@code Math.getExponent(Double.MAX_VALUE)}. */ + private static final int MAX_BIASED_EXPONENT = 2046; + /** Size of the mantissa of a double. Equal to 52 bits. */ + private static final int MANTISSA_SIZE = 52; + /** Mask to extract the 52-bit mantissa from a long representation of a double. */ + private static final long MANTISSA_MASK = 0x000f_ffff_ffff_ffffL; + /** BigInteger representation of {@link Long#MAX_VALUE}. */ + private static final BigInteger MAX_LONG = BigInteger.valueOf(Long.MAX_VALUE); + /** The maximum offset that will avoid loss of bits for a left shift of a 53-bit value. + * The value will remain positive for any shift {@code <=} this value. */ + private static final int MAX_OFFSET = 10; + /** Initial value for no leaf node label. */ + private static final int NO_LABEL = Integer.MAX_VALUE; + /** Name of the sampler. */ + private static final String SAMPLER_NAME = "Fast Loaded Dice Roller"; + + /** + * Class to handle the edge case of observations in only one category. + */ + private static class FixedValueDiscreteSampler extends FastLoadedDiceRollerDiscreteSampler { + /** The sample value. */ + private final int sampleValue; + + /** + * @param sampleValue Sample value. + */ + FixedValueDiscreteSampler(int sampleValue) { + this.sampleValue = sampleValue; + } + + @Override + public int sample() { + return sampleValue; + } + + @Override + public FastLoadedDiceRollerDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { + return this; + } + + @Override + public String toString() { + return SAMPLER_NAME; + } + } + + /** + * Class to implement the FLDR sample algorithm. + */ + private static class FLDRSampler extends FastLoadedDiceRollerDiscreteSampler { + /** Empty boolean source. 
This is the location of the sign-bit after 31 right shifts on + * the boolean source. */ + private static final int EMPTY_BOOL_SOURCE = 1; + + /** Underlying source of randomness. */ + private final UniformRandomProvider rng; + /** Number of categories. */ + private final int n; + /** Number of levels in the discrete distribution generating (DDG) tree. + * Equal to {@code ceil(log2(m))} where {@code m} is the sum of observations. */ + private final int k; + /** Number of leaf nodes at each level. */ + private final int[] h; + /** Stores the leaf node labels in increasing order. Named {@code H} in the FLDR paper. */ + private final int[] lH; + + /** + * Provides a bit source for booleans. + * + * <p>A cached value from a call to {@link UniformRandomProvider#nextInt()}. + * + * <p>Only stores 31-bits when full as 1 bit has already been consumed. + * The sign bit is a flag that shifts down so the source eventually equals 1 + * when all bits are consumed and will trigger a refill. + */ + private int booleanSource = EMPTY_BOOL_SOURCE; + + /** + * Creates a sampler. + * + * <p>The input parameters are not validated and must be correctly computed tables. + * + * @param rng Generator of uniformly distributed random numbers. + * @param n Number of categories + * @param k Number of levels in the discrete distribution generating (DDG) tree. + * Equal to {@code ceil(log2(m))} where {@code m} is the sum of observations. + * @param h Number of leaf nodes at each level. + * @param lH Stores the leaf node labels in increasing order. + */ + FLDRSampler(UniformRandomProvider rng, + int n, + int k, + int[] h, + int[] lH) { + this.rng = rng; + this.n = n; + this.k = k; + // Deliberate direct storage of input arrays + this.h = h; + this.lH = lH; + } + + /** + * Creates a copy with a new source of randomness. + * + * @param rng Generator of uniformly distributed random numbers. + * @param source Source to copy. 
+ */ + private FLDRSampler(UniformRandomProvider rng, + FLDRSampler source) { + this.rng = rng; + this.n = source.n; + this.k = source.k; + this.h = source.h; + this.lH = source.lH; + } + + /** {@inheritDoc} */ + @Override + public int sample() { + // ALGORITHM 5: SAMPLE + int c = 0; + int d = 0; + for (;;) { + // b = flip() + // d = 2 * d + (1 - b) + d = (d << 1) + flip(); + if (d < h[c]) { + // z = H[d][c] + final int z = lH[d * k + c]; + // assert z != NO_LABEL + if (z < n) { + return z; + } + d = 0; + c = 0; + } else { + d = d - h[c]; + c++; + } + } + } + + /** + * Provides a source of boolean bits. + * + * <p>Note: This replicates the boolean cache functionality of + * {@code o.a.c.rng.core.source32.IntProvider}. The method has been simplified to return + * an {@code int} value rather than a {@code boolean}. + * + * @return the bit (0 or 1) + */ + private int flip() { + int bits = booleanSource; + if (bits == 1) { + // Refill + bits = rng.nextInt(); + // Store a refill flag in the sign bit and the unused 31 bits, return lowest bit + booleanSource = Integer.MIN_VALUE | (bits >>> 1); + return bits & 0x1; + } + // Shift down eventually triggering refill, return current lowest bit + booleanSource = bits >>> 1; + return bits & 0x1; + } + + /** {@inheritDoc} */ + @Override + public String toString() { + return SAMPLER_NAME + " [" + rng.toString() + "]"; + } + + /** {@inheritDoc} */ + @Override + public FastLoadedDiceRollerDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng) { + return new FLDRSampler(rng, this); + } + } + + /** Package-private constructor. */ + FastLoadedDiceRollerDiscreteSampler() { + // Intentionally empty + } + + /** {@inheritDoc} */ + // Redeclare the signature to return a FastLoadedDiceRollerDiscreteSampler not a SharedStateDiscreteSampler + @Override + public abstract FastLoadedDiceRollerDiscreteSampler withUniformRandomProvider(UniformRandomProvider rng); + + /** + * Creates a sampler. 
+ * + * <p>Note: The discrete distribution generating (DDG) tree requires {@code (n + 1) * k} entries + * where {@code n} is the number of categories, {@code k == ceil(log2(m))} and {@code m} + * is the sum of the observed frequencies. An exception is raised if this cannot be allocated + * as a single array. + * + * <p>For reference the sum is limited to {@link Long#MAX_VALUE} and the value {@code k} to 63. + * The number of categories is limited to approximately {@code ((2^31 - 1) / k) = 34,087,042} + * when the sum of frequencies is large enough to create k=63. + * + * @param rng Generator of uniformly distributed random numbers. + * @param frequencies Observed frequencies of the discrete distribution. + * @return the sampler + * @throws IllegalArgumentException if {@code frequencies} is null or empty, a + * frequency is negative, the sum of all frequencies is either zero or + * above {@link Long#MAX_VALUE}, or the size of the discrete distribution generating tree + * is too large. + */ + public static FastLoadedDiceRollerDiscreteSampler of(UniformRandomProvider rng, + long[] frequencies) { + final long m = sum(frequencies); + + // Obtain indices of non-zero frequencies + final int[] indices = indicesOfNonZero(frequencies); + + // Edge case for 1 non-zero weight. This also handles edge case for 1 observation + // (as log2(m) == 0 will break the computation of the DDG tree). + if (indices.length == 1) { + return new FixedValueDiscreteSampler(indexOfNonZero(frequencies)); + } + + return createSampler(rng, frequencies, indices, m); + } + + /** + * Creates a sampler. + * + * <p>Weights are converted to rational numbers {@code p / q} where {@code q} is a power of 2. + * The numerators {@code p} are scaled to use a common denominator before summing. + * + * <p>All weights are used to create the sampler. 
Weights with a small magnitude relative + * to the largest weight can be excluded using the constructor method with the + * relative magnitude parameter {@code alpha} (see {@link #of(UniformRandomProvider, double[], int)}). + * + * @param rng Generator of uniformly distributed random numbers. + * @param weights Weights of the discrete distribution. + * @return the sampler + * @throws IllegalArgumentException if {@code weights} is null or empty, a + * weight is negative, infinite or {@code NaN}, the sum of all weights is zero, or the size + * of the discrete distribution generating tree is too large. + * @see #of(UniformRandomProvider, double[], int) + */ + public static FastLoadedDiceRollerDiscreteSampler of(UniformRandomProvider rng, + double[] weights) { + return of(rng, weights, 0); + } + + /** + * Creates a sampler. + * + * <p>Weights are converted to rational numbers {@code p / q} where {@code q} is + * a power of 2. The numerators {@code p} are scaled to use a common + * denominator before summing. + * + * <p>Note: The discrete distribution generating (DDG) tree requires + * {@code (n + 1) * k} entries where {@code n} is the number of categories, + * {@code k == ceil(log2(m))} and {@code m} is the sum of the weight numerators + * {@code p}. An exception is raised if this cannot be allocated as a single + * array. + * + * <p>For reference the value {@code k} is equal to or greater than the ratio of + * the largest to the smallest weight expressed as a power of 2. For + * {@code Double.MAX_VALUE / Double.MIN_VALUE} this is ~2098. The value + * {@code k} increases with the sum of the weight numerators. A number of + * weights in excess of 1,000,000 with values equal to {@link Double#MAX_VALUE} + * would be required to raise an exception when the minimum weight is + * {@link Double#MIN_VALUE}. + * + * <p>Weights with a small magnitude relative to the largest weight can be + * excluded using the relative magnitude parameter {@code alpha}. 
This will set + * any weight to zero if the magnitude is approximately 2<sup>alpha</sup> + * <em>smaller</em> than the largest weight. This comparison is made using only + * the exponent of the input weights. The {@code alpha} parameter is ignored if + * not above zero. Note that a small {@code alpha} parameter will exclude more + * weights than a large {@code alpha} parameter. + * + * <p>The alpha parameter can be used to exclude categories that + * have a very low probability of occurrence and will improve the construction + * performance of the sampler. The effect on sampling performance depends on + * the relative weights of the excluded categories; typically a high {@code alpha} + * is used to exclude categories that would be visited with a very low probability + * and the sampling performance is unchanged. + * + * <p><b>Implementation Note</b> + * + * <p>This method creates a sampler with <em>exact</em> samples from the + * specified probability distribution. It is recommended to use this method: + * <ul> + * <li>if the weights are computed, for example from a probability mass function; or + * <li>if the weights sum to an infinite value. + * </ul> + * + * <p>If the weights are computed from empirical observations then it is + * recommended to use the factory method + * {@link #of(UniformRandomProvider, long[]) accepting frequencies}. This + * requires the total number of observations to be representable as a long + * integer. + * + * <p>Note that if all weights are scaled by a power of 2 to be integers, and + * each integer can be represented as a positive 64-bit long value, then the + * sampler created using this method will match the output from a sampler + * created with the scaled weights converted to long values for the factory + * method {@link #of(UniformRandomProvider, long[]) accepting frequencies}. This + * assumes the sum of the integer values does not overflow. 
+ * + * <p>It should be noted that the conversion of weights to rational numbers has + * a performance overhead during construction (sampling performance is not + * affected). This may be avoided by first converting them to integer values + * that can be summed without overflow. For example by scaling values by + * {@code 2^62 / sum} and converting to long by casting or rounding. + * + * <p>This approach may increase the efficiency of construction. The resulting + * sampler may no longer produce <em>exact</em> samples from the distribution. + * In particular any weights with a converted frequency of zero cannot be + * sampled. + * + * @param rng Generator of uniformly distributed random numbers. + * @param weights Weights of the discrete distribution. + * @param alpha Alpha parameter. + * @return the sampler + * @throws IllegalArgumentException if {@code weights} is null or empty, a + * weight is negative, infinite or {@code NaN}, the sum of all weights is zero, + * or the size of the discrete distribution generating tree is too large. + * @see #of(UniformRandomProvider, long[]) + */ + public static FastLoadedDiceRollerDiscreteSampler of(UniformRandomProvider rng, + double[] weights, + int alpha) { + final int n = checkWeightsNonZeroLength(weights); + + // Convert floating-point double to a relative weight + // using a shifted integer representation + final long[] frequencies = new long[n]; + final int[] offsets = new int[n]; + convertToIntegers(weights, frequencies, offsets, alpha); + + // Obtain indices of non-zero weights + final int[] indices = indicesOfNonZero(frequencies); + + // Edge case for 1 non-zero weight. + if (indices.length == 1) { + return new FixedValueDiscreteSampler(indexOfNonZero(frequencies)); + } + + final BigInteger m = sum(frequencies, offsets, indices); + + // Use long arithmetic if possible. This occurs when the weights are similar in magnitude. 
+ if (m.compareTo(MAX_LONG) <= 0) { + // Apply the offset + for (int i = 0; i < n; i++) { + frequencies[i] <<= offsets[i]; + } + return createSampler(rng, frequencies, indices, m.longValue()); + } + + return createSampler(rng, frequencies, offsets, indices, m); + } + + /** + * Sum the frequencies. + * + * @param frequencies Frequencies. + * @return the sum + * @throws IllegalArgumentException if {@code frequencies} is null or empty, a + * frequency is negative, or the sum of all frequencies is either zero or above + * {@link Long#MAX_VALUE} + */ + private static long sum(long[] frequencies) { + // Validate + if (frequencies == null || frequencies.length == 0) { + throw new IllegalArgumentException("frequencies must contain at least 1 value"); + } + + // Sum the values. + // Combine all the sign bits in the observations and the intermediate sum in a flag. + long m = 0; + long signFlag = 0; + for (final long o : frequencies) { + m += o; + signFlag |= o | m; + } + + // Check for a sign-bit. + if (signFlag < 0) { + // One or more observations were negative, or the sum overflowed. + for (final long o : frequencies) { + if (o < 0) { + throw new IllegalArgumentException("frequencies must contain positive values: " + o); + } + } + throw new IllegalArgumentException("Overflow when summing frequencies"); + } + if (m == 0) { + throw new IllegalArgumentException("Sum of frequencies is zero"); + } + return m; + } + + /** + * Convert the floating-point weights to relative weights represented as + * integers {@code value * 2^exponent}. The relative weight as an integer is: + * + * <pre> + * BigInteger.valueOf(value).shiftLeft(exponent) + * </pre> + * + * <p>Note that the weights are created using a common power-of-2 scaling + * operation so the minimum exponent is zero. + * + * <p>A positive {@code alpha} parameter is used to set any weight to zero if + * the magnitude is approximately 2<sup>alpha</sup> <em>smaller</em> than the + * largest weight. 
This comparison is made using only the exponent of the input + * weights. + * + * @param weights Weights of the discrete distribution. + * @param values Output floating-point mantissas converted to 53-bit integers. + * @param exponents Output power of 2 exponent. + * @param alpha Alpha parameter. + * @throws IllegalArgumentException if a weight is negative, infinite or + * {@code NaN}, or the sum of all weights is zero. + */ + private static void convertToIntegers(double[] weights, long[] values, int[] exponents, int alpha) { + int maxExponent = Integer.MIN_VALUE; + for (int i = 0; i < weights.length; i++) { + final double weight = weights[i]; + // Ignore zero. + // When creating the integer value later using bit shifts the result will remain zero. + if (weight == 0) { + continue; + } + final long bits = Double.doubleToRawLongBits(weight); + + // For the IEEE 754 format see Double.longBitsToDouble(long). + + // Extract the exponent (with the sign bit) + int exp = (int) (bits >>> MANTISSA_SIZE); + // Detect negative, infinite or NaN. + // Note: Negative values sign bit will cause the exponent to be too high. + if (exp > MAX_BIASED_EXPONENT) { + throw new IllegalArgumentException("Invalid weight: " + weight); + } + long mantissa; + if (exp == 0) { + // Sub-normal number: + mantissa = (bits & MANTISSA_MASK) << 1; + // Here we convert to a normalised number by counting the leading zeros + // to obtain the number of shifts of the most significant bit in + // the mantissa that is required to get a 1 at position 53 (i.e. as + // if it were a normal number with assumed leading bit). + final int shift = Long.numberOfLeadingZeros(mantissa << 11); + mantissa <<= shift; + exp -= shift; + } else { + // Normal number. Add the implicit leading 1-bit. 
+ mantissa = (bits & MANTISSA_MASK) | (1L << MANTISSA_SIZE); + } + + // Here the floating-point number is equal to: + // mantissa * 2^(exp-1075) + + values[i] = mantissa; + exponents[i] = exp; + maxExponent = Math.max(maxExponent, exp); + } + + // No exponent indicates that all weights are zero + if (maxExponent == Integer.MIN_VALUE) { + throw new IllegalArgumentException("Sum of weights is zero"); + } + + filterWeights(values, exponents, alpha, maxExponent); + scaleWeights(values, exponents); + } + + /** + * Filters small weights using the {@code alpha} parameter. + * A positive {@code alpha} parameter is used to set any weight to zero if + * the magnitude is approximately 2<sup>alpha</sup> <em>smaller</em> than the + * largest weight. This comparison is made using only the exponent of the input + * weights. + * + * @param values 53-bit values. + * @param exponents Power of 2 exponent. + * @param alpha Alpha parameter. + * @param maxExponent Maximum exponent. + */ + private static void filterWeights(long[] values, int[] exponents, int alpha, int maxExponent) { + if (alpha > 0) { + // Filter weights. This must be done before the values are shifted so + // the exponent represents the approximate magnitude of the value. + for (int i = 0; i < exponents.length; i++) { + if (maxExponent - exponents[i] > alpha) { + values[i] = 0; + } + } + } + } + + /** + * Scale the weights represented as integers {@code value * 2^exponent} to use a + * minimum exponent of zero. The values are scaled to remove any common trailing zeros + * in their representation. This ultimately reduces the size of the discrete distribution + * generating (DDG) tree. + * + * @param values 53-bit values. + * @param exponents Power of 2 exponent. + */ + private static void scaleWeights(long[] values, int[] exponents) { + // Find the minimum exponent and common trailing zeros. 
+ int minExponent = Integer.MAX_VALUE; + for (int i = 0; i < exponents.length; i++) { + if (values[i] != 0) { + minExponent = Math.min(minExponent, exponents[i]); + } + } + // Trailing zeros occur when the original weights have a representation with + // less than 52 binary digits, e.g. {1.5, 0.5, 0.25}. + int trailingZeros = Long.SIZE; + for (int i = 0; i < values.length && trailingZeros != 0; i++) { + trailingZeros = Math.min(trailingZeros, Long.numberOfTrailingZeros(values[i])); + } + // Scale by a power of 2 so the minimum exponent is zero. + for (int i = 0; i < exponents.length; i++) { + exponents[i] -= minExponent; + } + // Remove common trailing zeros. + if (trailingZeros != 0) { + for (int i = 0; i < values.length; i++) { + values[i] >>>= trailingZeros; + } + } + } + + /** + * Sum the integers at the specified indices. + * Integers are represented as {@code value * 2^exponent}. + * + * @param values 53-bit values. + * @param exponents Power of 2 exponent. + * @param indices Indices to sum. + * @return the sum + */ + private static BigInteger sum(long[] values, int[] exponents, int[] indices) { + BigInteger m = BigInteger.ZERO; + for (final int i : indices) { + m = m.add(toBigInteger(values[i], exponents[i])); + } + return m; + } + + /** + * Convert the value and left shift offset to a BigInteger. + * It is assumed the value is at most 53-bits. This allows optimising the left + * shift if it is below 11 bits. + * + * @param value 53-bit value. + * @param offset Left shift offset (must be positive). + * @return the BigInteger + */ + private static BigInteger toBigInteger(long value, int offset) { + // Ignore zeros. The sum method uses indices of non-zero values. + if (offset <= MAX_OFFSET) { + // Assume (value << offset) <= Long.MAX_VALUE + return BigInteger.valueOf(value << offset); + } + return BigInteger.valueOf(value).shiftLeft(offset); + } + + /** + * Creates the sampler. 
+ * + * <p>It is assumed the frequencies are all positive and the sum does not + * overflow. + * + * @param rng Generator of uniformly distributed random numbers. + * @param frequencies Observed frequencies of the discrete distribution. + * @param indices Indices of non-zero frequencies. + * @param m Sum of the frequencies. + * @return the sampler + */ + private static FastLoadedDiceRollerDiscreteSampler createSampler(UniformRandomProvider rng, + long[] frequencies, + int[] indices, + long m) { + // ALGORITHM 5: PREPROCESS + // a == frequencies + // m = sum(a) + // h = leaf node count + // H = leaf node label (lH) + + final int n = frequencies.length; + + // k = ceil(log2(m)) + final int k = 64 - Long.numberOfLeadingZeros(m - 1); + // r = a(n+1) = 2^k - m + final long r = (1L << k) - m; + + // Note: + // A sparse matrix can often be used for H, as most of its entries are empty. + // This implementation uses a 1D array for efficiency at the cost of memory. + // This is limited to approximately ((2^31 - 1) / k), e.g. 34087042 when the sum of + // observations is large enough to create k=63. + // This could be handled using a 2D array. In practice a number of categories this + // large is not expected and is currently not supported. + final int[] h = new int[k]; + final int[] lH = new int[checkArraySize((n + 1L) * k)]; + Arrays.fill(lH, NO_LABEL); + + int d; + for (int j = 0; j < k; j++) { + final int shift = (k - 1) - j; + final long bitMask = 1L << shift; + + d = 0; + for (final int i : indices) { + // bool w ← (a[i] >> (k − 1) − j)) & 1 + // h[j] = h[j] + w + // if w then: + if ((frequencies[i] & bitMask) != 0) { + h[j]++; + // H[d][j] = i + lH[d * k + j] = i; + d++; + } + } + // process a(n+1) without extending the input frequencies array by 1 + if ((r & bitMask) != 0) { + h[j]++; + lH[d * k + j] = n; + } + } + + return new FLDRSampler(rng, n, k, h, lH); + } + + /** + * Creates the sampler. 
Frequencies are represented as a 53-bit value with a + * left-shift offset. + * <pre> + * BigInteger.valueOf(value).shiftLeft(offset) + * </pre> + * + * <p>It is assumed the frequencies are all positive. + * + * @param rng Generator of uniformly distributed random numbers. + * @param frequencies Observed frequencies of the discrete distribution. + * @param offsets Left shift offsets (must be positive). + * @param indices Indices of non-zero frequencies. + * @param m Sum of the frequencies. + * @return the sampler + */ + private static FastLoadedDiceRollerDiscreteSampler createSampler(UniformRandomProvider rng, + long[] frequencies, + int[] offsets, + int[] indices, + BigInteger m) { + // Repeat the logic from createSampler(...) using extended arithmetic to test the bits + + // ALGORITHM 5: PREPROCESS + // a == frequencies + // m = sum(a) + // h = leaf node count + // H = leaf node label (lH) + + final int n = frequencies.length; + + // k = ceil(log2(m)) + final int k = m.subtract(BigInteger.ONE).bitLength(); + // r = a(n+1) = 2^k - m + final BigInteger r = BigInteger.ONE.shiftLeft(k).subtract(m); + + final int[] h = new int[k]; + final int[] lH = new int[checkArraySize((n + 1L) * k)]; + Arrays.fill(lH, NO_LABEL); + + int d; + for (int j = 0; j < k; j++) { + final int shift = (k - 1) - j; + + d = 0; + for (final int i : indices) { + // bool w ← (a[i] >> (k − 1) − j)) & 1 + // h[j] = h[j] + w + // if w then: + if (testBit(frequencies[i], offsets[i], shift)) { + h[j]++; + // H[d][j] = i + lH[d * k + j] = i; + d++; + } + } + // process a(n+1) without extending the input frequencies array by 1 + if (r.testBit(shift)) { + h[j]++; + lH[d * k + j] = n; + } + } + + return new FLDRSampler(rng, n, k, h, lH); + } + + /** + * Test the logical bit of the shifted integer representation. + * The value is assumed to have at most 53-bits of information. The offset + * is assumed to be positive. 
This is functionally equivalent to: + * <pre> + * BigInteger.valueOf(value).shiftLeft(offset).testBit(n) + * </pre> + * + * @param value 53-bit value. + * @param offset Left shift offset. + * @param n Index of bit to test. + * @return true if the bit is 1 + */ + private static boolean testBit(long value, int offset, int n) { + if (n < offset) { + // All logical trailing bits are zero + return false; + } + // Test if outside the 53-bit value (note that the implicit 1 bit + // has been added to the 52-bit mantissas for 'normal' floating-point numbers). + final int bit = n - offset; + return bit <= MANTISSA_SIZE && (value & (1L << bit)) != 0; + } + + /** + * Check the weights have a non-zero length. + * + * @param weights Weights. + * @return the length + */ + private static int checkWeightsNonZeroLength(double[] weights) { + if (weights == null || weights.length == 0) { + throw new IllegalArgumentException("weights must contain at least 1 value"); + } + return weights.length; + } + + /** + * Create the indices of non-zero values. + * + * @param values Values. + * @return the indices + */ + private static int[] indicesOfNonZero(long[] values) { + int n = 0; + final int[] indices = new int[values.length]; + for (int i = 0; i < values.length; i++) { + if (values[i] != 0) { + indices[n++] = i; + } + } + return Arrays.copyOf(indices, n); + } + + /** + * Find the index of the first non-zero frequency. + * + * @param frequencies Frequencies. + * @return the index + * @throws IllegalStateException if all frequencies are zero. + */ + static int indexOfNonZero(long[] frequencies) { + for (int i = 0; i < frequencies.length; i++) { + if (frequencies[i] != 0) { + return i; + } + } + throw new IllegalStateException("All frequencies are zero"); + } + + /** + * Check the size is valid for a 1D array. + * + * @param size Size + * @return the size as an {@code int} + * @throws IllegalArgumentException if the size is too large for a 1D array. 
+ */ + static int checkArraySize(long size) { + if (size > MAX_ARRAY_SIZE) { + throw new IllegalArgumentException("Unable to allocate array of size: " + size); + } + return (int) size; + } +} diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java index 81218db1..1a33921d 100644 --- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java @@ -181,12 +181,17 @@ public final class DiscreteSamplersList { // Any discrete distribution final int[] discretePoints = {0, 1, 2, 3, 4}; final double[] discreteProbabilities = {0.1, 0.2, 0.3, 0.4, 0.5}; + final long[] discreteFrequencies = {1, 2, 3, 4, 5}; add(LIST, discretePoints, discreteProbabilities, MarsagliaTsangWangDiscreteSampler.Enumerated.of(RandomSource.XO_SHI_RO_512_PLUS.create(), discreteProbabilities)); add(LIST, discretePoints, discreteProbabilities, GuideTableDiscreteSampler.of(RandomSource.XO_SHI_RO_512_SS.create(), discreteProbabilities)); add(LIST, discretePoints, discreteProbabilities, AliasMethodDiscreteSampler.of(RandomSource.KISS.create(), discreteProbabilities)); + add(LIST, discretePoints, discreteProbabilities, + FastLoadedDiceRollerDiscreteSampler.of(RandomSource.L64_X128_MIX.create(), discreteFrequencies)); + add(LIST, discretePoints, discreteProbabilities, + FastLoadedDiceRollerDiscreteSampler.of(RandomSource.L64_X128_SS.create(), discreteProbabilities)); } catch (Exception e) { // CHECKSTYLE: stop Regexp System.err.println("Unexpected exception while creating the list of samplers: " + e); diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/FastLoadedDiceRollerDiscreteSamplerTest.java 
b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/FastLoadedDiceRollerDiscreteSamplerTest.java new file mode 100644 index 00000000..27350dc9 --- /dev/null +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/FastLoadedDiceRollerDiscreteSamplerTest.java @@ -0,0 +1,452 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.rng.sampling.distribution; + +import java.util.Arrays; +import java.util.function.DoubleUnaryOperator; +import java.util.stream.Stream; +import org.apache.commons.math3.stat.descriptive.moment.Mean; +import org.apache.commons.math3.stat.inference.ChiSquareTest; +import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.RandomAssert; +import org.apache.commons.rng.simple.RandomSource; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; + +/** + * Test for the {@link FastLoadedDiceRollerDiscreteSampler}. 
+ */ +class FastLoadedDiceRollerDiscreteSamplerTest { + /** + * Creates the sampler. + * + * @param frequencies Observed frequencies. + * @return the FLDR sampler + */ + private static SharedStateDiscreteSampler createSampler(long... frequencies) { + final UniformRandomProvider rng = RandomSource.SPLIT_MIX_64.create(); + return FastLoadedDiceRollerDiscreteSampler.of(rng, frequencies); + } + + /** + * Creates the sampler. + * + * @param weights Weights. + * @return the FLDR sampler + */ + private static SharedStateDiscreteSampler createSampler(double... weights) { + final UniformRandomProvider rng = RandomSource.SPLIT_MIX_64.create(); + return FastLoadedDiceRollerDiscreteSampler.of(rng, weights); + } + + /** + * Return a stream of invalid frequencies for a discrete distribution. + * + * @return the stream of invalid frequencies + */ + static Stream<long[]> testFactoryConstructorFrequencies() { + return Stream.of( + // Null or empty + (long[]) null, + new long[0], + // Negative + new long[] {-1, 2, 3}, + new long[] {1, -2, 3}, + new long[] {1, 2, -3}, + // Overflow of sum + new long[] {Long.MAX_VALUE, Long.MAX_VALUE}, + // x+x+2 == 0 + new long[] {Long.MAX_VALUE, Long.MAX_VALUE, 2}, + // x+x+x == x - 2 (i.e. positive) + new long[] {Long.MAX_VALUE, Long.MAX_VALUE, Long.MAX_VALUE}, + // Zero sum + new long[1], + new long[4] + ); + } + + @ParameterizedTest + @MethodSource + void testFactoryConstructorFrequencies(long[] frequencies) { + Assertions.assertThrows(IllegalArgumentException.class, () -> createSampler(frequencies)); + } + + /** + * Return a stream of invalid weights for a discrete distribution. 
+ * + * @return the stream of invalid weights + */ + static Stream<double[]> testFactoryConstructorWeights() { + return Stream.of( + // Null or empty + (double[]) null, + new double[0], + // Negative, infinite or NaN + new double[] {-1, 2, 3}, + new double[] {1, -2, 3}, + new double[] {1, 2, -3}, + new double[] {Double.POSITIVE_INFINITY, 2, 3}, + new double[] {1, Double.POSITIVE_INFINITY, 3}, + new double[] {1, 2, Double.POSITIVE_INFINITY}, + new double[] {Double.NaN, 2, 3}, + new double[] {1, Double.NaN, 3}, + new double[] {1, 2, Double.NaN}, + // Zero sum + new double[1], + new double[4] + ); + } + + @ParameterizedTest + @MethodSource + void testFactoryConstructorWeights(double[] weights) { + Assertions.assertThrows(IllegalArgumentException.class, () -> createSampler(weights)); + } + + @Test + void testToString() { + for (final long[] observed : new long[][] {{42}, {1, 2, 3}}) { + final SharedStateDiscreteSampler sampler = createSampler(observed); + Assertions.assertTrue(sampler.toString().toLowerCase().contains("fast loaded dice roller")); + } + } + + @Test + void testSingleCategory() { + final int n = 13; + final int[] expected = new int[n]; + Assertions.assertArrayEquals(expected, createSampler(42).samples(n).toArray()); + Assertions.assertArrayEquals(expected, createSampler(0.55).samples(n).toArray()); + } + + @Test + void testSingleFrequency() { + final long[] frequencies = new long[5]; + final int category = 2; + frequencies[category] = 1; + final SharedStateDiscreteSampler sampler = createSampler(frequencies); + final int n = 7; + final int[] expected = new int[n]; + Arrays.fill(expected, category); + Assertions.assertArrayEquals(expected, sampler.samples(n).toArray()); + } + + @Test + void testSingleWeight() { + final double[] weights = new double[5]; + final int category = 3; + weights[category] = 1.5; + final SharedStateDiscreteSampler sampler = createSampler(weights); + final int n = 6; + final int[] expected = new int[n]; + Arrays.fill(expected, 
category); + Assertions.assertArrayEquals(expected, sampler.samples(n).toArray()); + } + + @Test + void testIndexOfNonZero() { + Assertions.assertThrows(IllegalStateException.class, + () -> FastLoadedDiceRollerDiscreteSampler.indexOfNonZero(new long[3])); + final long[] data = new long[3]; + for (int i = 0; i < data.length; i++) { + data[i] = 13; + Assertions.assertEquals(i, FastLoadedDiceRollerDiscreteSampler.indexOfNonZero(data)); + data[i] = 0; + } + } + + @ParameterizedTest + @ValueSource(longs = {0, 1, -1, Integer.MAX_VALUE, 1L << 34}) + void testCheckArraySize(long size) { + // This is the same value as the sampler + final int max = Integer.MAX_VALUE - 8; + // Note: The method does not test for negatives. + // This is not required when validating a positive int multiplied by another positive int. + if (size > max) { + Assertions.assertThrows(IllegalArgumentException.class, + () -> FastLoadedDiceRollerDiscreteSampler.checkArraySize(size)); + } else { + Assertions.assertEquals((int) size, FastLoadedDiceRollerDiscreteSampler.checkArraySize(size)); + } + } + + /** + * Return a stream of expected frequencies for a discrete distribution. + * + * @return the stream of expected frequencies + */ + static Stream<long[]> testSamplesFrequencies() { + return Stream.of( + // Single category + new long[] {0, 0, 42, 0, 0}, + // Sum to a power of 2 + new long[] {1, 1, 2, 3, 1}, + new long[] {0, 1, 1, 0, 2, 3, 1, 0}, + // Do not sum to a power of 2 + new long[] {1, 2, 3, 1, 3}, + new long[] {1, 0, 2, 0, 3, 1, 3}, + // Large frequencies + new long[] {5126734627834L, 213267384684832L, 126781236718L, 71289979621378L} + ); + } + + /** + * Check the distribution of samples match the expected probabilities. + * + * @param expectedFrequencies Expected frequencies. 
+ */ + @ParameterizedTest + @MethodSource + void testSamplesFrequencies(long[] expectedFrequencies) { + final SharedStateDiscreteSampler sampler = createSampler(expectedFrequencies); + final int numberOfSamples = 10000; + final long[] samples = new long[expectedFrequencies.length]; + sampler.samples(numberOfSamples).forEach(x -> samples[x]++); + + // Handle a test with some zero-probability observations by mapping them out + int mapSize = 0; + double sum = 0; + for (final double f : expectedFrequencies) { + if (f != 0) { + mapSize++; + sum += f; + } + } + + // Single category will break the Chi-square test + if (mapSize == 1) { + int index = 0; + while (index < expectedFrequencies.length) { + if (expectedFrequencies[index] != 0) { + break; + } + index++; + } + Assertions.assertEquals(numberOfSamples, samples[index], "Invalid single category samples"); + return; + } + + final double[] expected = new double[mapSize]; + final long[] observed = new long[mapSize]; + for (int i = 0; i < expectedFrequencies.length; i++) { + if (expectedFrequencies[i] != 0) { + --mapSize; + expected[mapSize] = expectedFrequencies[i] / sum; + observed[mapSize] = samples[i]; + } else { + Assertions.assertEquals(0, samples[i], "No samples expected from zero probability"); + } + } + + final ChiSquareTest chiSquareTest = new ChiSquareTest(); + // Pass if we cannot reject null hypothesis that the distributions are the same. + Assertions.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001)); + } + + /** + * Return a stream of expected weights for a discrete distribution. 
+ * + * @return the stream of expected weights + */ + static Stream<double[]> testSamplesWeights() { + return Stream.of( + // Single category + new double[] {0, 0, 0.523, 0, 0}, + // Sum to a power of 2 + new double[] {0.125, 0.125, 0.25, 0.375, 0.125}, + new double[] {0, 0.125, 0.125, 0.25, 0, 0.375, 0.125, 0}, + // Do not sum to a power of 2 + new double[] {0.1, 0.2, 0.3, 0.1, 0.3}, + new double[] {0.1, 0, 0.2, 0, 0.3, 0.1, 0.3}, + // Sub-normal numbers + new double[] {5 * Double.MIN_NORMAL, 2 * Double.MIN_NORMAL, 3 * Double.MIN_NORMAL, 9 * Double.MIN_NORMAL}, + new double[] {2 * Double.MIN_NORMAL, Double.MIN_NORMAL, 0.5 * Double.MIN_NORMAL, 0.75 * Double.MIN_NORMAL}, + new double[] {Double.MIN_VALUE, 2 * Double.MIN_VALUE, 3 * Double.MIN_VALUE, 7 * Double.MIN_VALUE}, + // Large range of magnitude + new double[] {1.0, 2.0, Math.scalb(3.0, -32), Math.scalb(4.0, -65), 5.0}, + new double[] {Math.scalb(1.0, 35), Math.scalb(2.0, 35), Math.scalb(3.0, -32), Math.scalb(4.0, -65), Math.scalb(5.0, 35)}, + // Sum to infinite + new double[] {Double.MAX_VALUE, Double.MAX_VALUE, Double.MAX_VALUE / 2, Double.MAX_VALUE / 4} + ); + } + + /** + * Check the distribution of samples match the expected weights. + * + * @param weights Category weights. 
+ */ + @ParameterizedTest + @MethodSource + void testSamplesWeights(double[] weights) { + final SharedStateDiscreteSampler sampler = createSampler(weights); + final int numberOfSamples = 10000; + final long[] samples = new long[weights.length]; + sampler.samples(numberOfSamples).forEach(x -> samples[x]++); + + // Handle a test with some zero-probability observations by mapping them out + int mapSize = 0; + double sum = 0; + // Handle infinite sum using a rolling mean for normalisation + final Mean mean = new Mean(); + for (final double w : weights) { + if (w != 0) { + mapSize++; + sum += w; + mean.increment(w); + } + } + + // Single category will break the Chi-square test + if (mapSize == 1) { + int index = 0; + while (index < weights.length) { + if (weights[index] != 0) { + break; + } + index++; + } + Assertions.assertEquals(numberOfSamples, samples[index], "Invalid single category samples"); + return; + } + + final double mu = mean.getResult(); + final int n = mapSize; + final double s = sum; + final DoubleUnaryOperator normalise = Double.isInfinite(sum) ? + x -> (x / mu) * n : + x -> x / s; + + final double[] expected = new double[mapSize]; + final long[] observed = new long[mapSize]; + for (int i = 0; i < weights.length; i++) { + if (weights[i] != 0) { + --mapSize; + expected[mapSize] = normalise.applyAsDouble(weights[i]); + observed[mapSize] = samples[i]; + } else { + Assertions.assertEquals(0, samples[i], "No samples expected from zero probability"); + } + } + + final ChiSquareTest chiSquareTest = new ChiSquareTest(); + // Pass if we cannot reject null hypothesis that the distributions are the same. + Assertions.assertFalse(chiSquareTest.chiSquareTest(expected, observed, 0.001)); + } + + /** + * Check the distribution of samples when the frequencies can be converted to weights without + * loss of precision. + * + * @param frequencies Expected frequencies. 
+ */ + @ParameterizedTest + @MethodSource(value = {"testSamplesFrequencies"}) + void testSamplesWeightsMatchesFrequencies(long[] frequencies) { + final double[] weights = new double[frequencies.length]; + for (int i = 0; i < frequencies.length; i++) { + final double w = frequencies[i]; + Assumptions.assumeTrue((long) w == frequencies[i]); + // Ensure the exponent is set in the event of simple frequencies + weights[i] = Math.scalb(w, -35); + } + final long seed = RandomSource.createLong(); + final UniformRandomProvider rng1 = RandomSource.SPLIT_MIX_64.create(seed); + final UniformRandomProvider rng2 = RandomSource.SPLIT_MIX_64.create(seed); + final SharedStateDiscreteSampler sampler1 = + FastLoadedDiceRollerDiscreteSampler.of(rng1, frequencies); + final SharedStateDiscreteSampler sampler2 = + FastLoadedDiceRollerDiscreteSampler.of(rng2, weights); + RandomAssert.assertProduceSameSequence(sampler1, sampler2); + } + + /** + * Test scaled weights. The sampler uses the relative magnitude of weights and the + * output should be invariant to scaling. The weights are sampled from the 2^53 dyadic + * rationals in [0, 1). A scale factor of -1021 is the lower limit if a weight is + * 2^-53 to maintain a non-zero weight. The upper limit is 1023 if a weight is 1 to avoid + * infinite values. Note that it does not matter if the sum of weights is infinite; only + * the individual weights must be finite. 
+ * + * @param scaleFactor the scale factor + */ + @ParameterizedTest + @ValueSource(ints = {1023, 67, 1, -59, -1020, -1021}) + void testScaledWeights(int scaleFactor) { + // Weights in [0, 1) + final double[] w1 = RandomSource.KISS.create().doubles(10).toArray(); + final double scale = Math.scalb(1.0, scaleFactor); + final double[] w2 = Arrays.stream(w1).map(x -> x * scale).toArray(); + final long seed = RandomSource.createLong(); + final UniformRandomProvider rng1 = RandomSource.SPLIT_MIX_64.create(seed); + final UniformRandomProvider rng2 = RandomSource.SPLIT_MIX_64.create(seed); + final SharedStateDiscreteSampler sampler1 = + FastLoadedDiceRollerDiscreteSampler.of(rng1, w1); + final SharedStateDiscreteSampler sampler2 = + FastLoadedDiceRollerDiscreteSampler.of(rng2, w2); + RandomAssert.assertProduceSameSequence(sampler1, sampler2); + } + + /** + * Test the alpha parameter removes small relative weights. + * Weights should be removed if they are {@code 2^alpha} smaller than the largest + * weight. 
+ * + * @param alpha Alpha parameter + */ + @ParameterizedTest + @ValueSource(ints = {13, 30, 53}) + void testAlphaRemovesWeights(int alpha) { + // The small weight must be > 2^alpha smaller so scale by (alpha + 1) + final double small = Math.scalb(1.0, -(alpha + 1)); + final double[] w1 = {1, 0.5, 0.5, 0}; + final double[] w2 = {1, 0.5, 0.5, small}; + final long seed = RandomSource.createLong(); + final UniformRandomProvider rng1 = RandomSource.SPLIT_MIX_64.create(seed); + final UniformRandomProvider rng2 = RandomSource.SPLIT_MIX_64.create(seed); + final UniformRandomProvider rng3 = RandomSource.SPLIT_MIX_64.create(seed); + + final int n = 10; + final int[] s1 = FastLoadedDiceRollerDiscreteSampler.of(rng1, w1).samples(n).toArray(); + final int[] s2 = FastLoadedDiceRollerDiscreteSampler.of(rng2, w2, alpha).samples(n).toArray(); + final int[] s3 = FastLoadedDiceRollerDiscreteSampler.of(rng3, w2, alpha + 1).samples(n).toArray(); + + Assertions.assertArrayEquals(s1, s2, "alpha parameter should ignore the small weight"); + Assertions.assertFalse(Arrays.equals(s1, s3), "alpha+1 parameter should not ignore the small weight"); + } + + static Stream<long[]> testSharedStateSampler() { + return Stream.of( + new long[] {42}, + new long[] {1, 1, 2, 3, 1} + ); + } + + @ParameterizedTest + @MethodSource + void testSharedStateSampler(long[] frequencies) { + final UniformRandomProvider rng1 = RandomSource.SPLIT_MIX_64.create(0L); + final UniformRandomProvider rng2 = RandomSource.SPLIT_MIX_64.create(0L); + final SharedStateDiscreteSampler sampler1 = + FastLoadedDiceRollerDiscreteSampler.of(rng1, frequencies); + final SharedStateDiscreteSampler sampler2 = sampler1.withUniformRandomProvider(rng2); + RandomAssert.assertProduceSameSequence(sampler1, sampler2); + } +} diff --git a/src/main/resources/pmd/pmd-ruleset.xml b/src/main/resources/pmd/pmd-ruleset.xml index 288690f7..1d4f0ecf 100644 --- a/src/main/resources/pmd/pmd-ruleset.xml +++ b/src/main/resources/pmd/pmd-ruleset.xml @@ 
-77,7 +77,7 @@ <property name="violationSuppressXPath" value="//ClassOrInterfaceDeclaration[@SimpleName='PoissonSamplerCache' or @SimpleName='AliasMethodDiscreteSampler' or @SimpleName='GuideTableDiscreteSampler' or @SimpleName='SharedStateDiscreteProbabilitySampler' - or @SimpleName='DirichletSampler']"/> + or @SimpleName='DirichletSampler' or @SimpleName='FastLoadedDiceRollerDiscreteSampler']"/> </properties> </rule> <rule ref="category/java/bestpractices.xml/SystemPrintln"> @@ -144,6 +144,14 @@ <property name="violationSuppressXPath" value="//ClassOrInterfaceDeclaration[matches(@SimpleName, '^.*Builder$')]"/> </properties> </rule> + <rule ref="category/java/codestyle.xml/PrematureDeclaration"> + <properties> + <!-- False positive where minExponent is stored before a possible exit point. --> + <property name="violationSuppressXPath" + value="./ancestor::ClassOrInterfaceDeclaration[@SimpleName='FastLoadedDiceRollerDiscreteSampler'] and + ./ancestor::MethodName[@Image='of']"/> + </properties> + </rule> <rule ref="category/java/design.xml/NPathComplexity"> <properties> @@ -229,6 +237,12 @@ value="../MethodDeclaration[@Name='jump' or @Name='longJump']"/> </properties> </rule> + <rule ref="category/java/design.xml/GodClass"> + <properties> + <property name="violationSuppressXPath" + value="./ancestor-or-self::ClassOrInterfaceDeclaration[@SimpleName='FastLoadedDiceRollerDiscreteSampler']"/> + </properties> + </rule> <rule ref="category/java/errorprone.xml/AvoidLiteralsInIfCondition"> <properties>
