This is an automated email from the ASF dual-hosted git repository.

aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-rng.git

commit 34117b8fec244931274fd2c4950e5db481951aa6
Author: aherbert <aherb...@apache.org>
AuthorDate: Fri Jul 2 14:33:16 2021 +0100

    RNG-149: Add ZigguratExponentialSampler
---
 .../ContinuousSamplersPerformance.java             |   4 +
 .../distribution/ZigguratExponentialSampler.java   | 231 +++++++++++++++++++++
 .../distribution/ContinuousSamplersList.java       |   3 +
 .../ZigguratExponentialSamplerTest.java            |  64 ++++++
 4 files changed, 302 insertions(+)

diff --git 
a/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ContinuousSamplersPerformance.java
 
b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ContinuousSamplersPerformance.java
index 2adbd7b..91bb2ee 100644
--- 
a/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ContinuousSamplersPerformance.java
+++ 
b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ContinuousSamplersPerformance.java
@@ -29,6 +29,7 @@ import 
org.apache.commons.rng.sampling.distribution.InverseTransformParetoSample
 import org.apache.commons.rng.sampling.distribution.LevySampler;
 import org.apache.commons.rng.sampling.distribution.LogNormalSampler;
 import 
org.apache.commons.rng.sampling.distribution.MarsagliaNormalizedGaussianSampler;
+import org.apache.commons.rng.sampling.distribution.ZigguratExponentialSampler;
 import 
org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler;
 
 import org.openjdk.jmh.annotations.Benchmark;
@@ -77,6 +78,7 @@ public class ContinuousSamplersPerformance {
                 "MarsagliaNormalizedGaussianSampler",
                 "ZigguratNormalizedGaussianSampler",
                 "AhrensDieterExponentialSampler",
+                "ZigguratExponentialSampler",
                 "AhrensDieterGammaSampler",
                 "MarsagliaTsangGammaSampler",
                 "LevySampler",
@@ -113,6 +115,8 @@ public class ContinuousSamplersPerformance {
                 sampler = ZigguratNormalizedGaussianSampler.of(rng);
             } else if ("AhrensDieterExponentialSampler".equals(samplerType)) {
                 sampler = AhrensDieterExponentialSampler.of(rng, 4.56);
+            } else if ("ZigguratExponentialSampler".equals(samplerType)) {
+                sampler = ZigguratExponentialSampler.of(rng, 4.56);
             } else if ("AhrensDieterGammaSampler".equals(samplerType)) {
                 // This tests the Ahrens-Dieter algorithm since alpha < 1
                 sampler = AhrensDieterMarsagliaTsangGammaSampler.of(rng, 0.76, 
9.8);
diff --git 
a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/ZigguratExponentialSampler.java
 
b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/ZigguratExponentialSampler.java
new file mode 100644
index 0000000..09276ee
--- /dev/null
+++ 
b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/ZigguratExponentialSampler.java
@@ -0,0 +1,231 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.rng.sampling.distribution;
+
+import org.apache.commons.rng.UniformRandomProvider;
+
+/**
+ * <a href="https://en.wikipedia.org/wiki/Ziggurat_algorithm">
+ * Marsaglia and Tsang "Ziggurat" method</a> for sampling from an exponential
+ * distribution.
+ *
+ * <p>The algorithm is explained in this
+ * <a href="http://www.jstatsoft.org/article/view/v005i08/ziggurat.pdf">paper</a>
+ * and this implementation has been adapted from the C code provided therein.</p>
+ *
+ * <p>Sampling uses:</p>
+ *
+ * <ul>
+ *   <li>{@link UniformRandomProvider#nextLong()}
+ *   <li>{@link UniformRandomProvider#nextDouble()}
+ * </ul>
+ *
+ * @since 1.4
+ */
+public class ZigguratExponentialSampler implements 
SharedStateContinuousSampler {
+    /** Start of tail. */
+    private static final double R = 7.69711747013104972;
+    /** Index of last entry in the tables (which have a size that is a power 
of 2). */
+    private static final int LAST = 255;
+    /** Auxiliary table. */
+    private static final long[] K;
+    /** Auxiliary table. */
+    private static final double[] W;
+    /** Auxiliary table. */
+    private static final double[] F;
+
+    /** Underlying source of randomness. */
+    private final UniformRandomProvider rng;
+
+    static {
+        // Filling the tables.
+        // Rectangle area.
+        final double v = 0.0039496598225815571993;
+        // No support for unsigned long so the upper bound is 2^63
+        final double max = Math.pow(2, 63);
+        final double oneOverMax = 1d / max;
+
+        K = new long[LAST + 1];
+        W = new double[LAST + 1];
+        F = new double[LAST + 1];
+
+        double d = R;
+        double t = d;
+        double fd = pdf(d);
+        final double q = v / fd;
+
+        K[0] = (long) ((d / q) * max);
+        K[1] = 0;
+
+        W[0] = q * oneOverMax;
+        W[LAST] = d * oneOverMax;
+
+        F[0] = 1;
+        F[LAST] = fd;
+
+        for (int i = LAST - 1; i >= 1; i--) {
+            d = -Math.log(v / d + fd);
+            fd = pdf(d);
+
+            K[i + 1] = (long) ((d / t) * max);
+            t = d;
+
+            F[i] = fd;
+
+            W[i] = d * oneOverMax;
+        }
+    }
+
+    /**
+     * Specialisation of the ZigguratExponentialSampler which multiplies the 
standard
+     * exponential result by the mean.
+     */
+    private static class ZigguratExponentialMeanSampler extends 
ZigguratExponentialSampler {
+        /** Mean. */
+        private final double mean;
+
+        /**
+         * @param rng Generator of uniformly distributed random numbers.
+         * @param mean Mean.
+         */
+        ZigguratExponentialMeanSampler(UniformRandomProvider rng, double mean) 
{
+            super(rng);
+            this.mean = mean;
+        }
+
+        @Override
+        public double sample() {
+            return createSample() * mean;
+        }
+
+        @Override
+        public ZigguratExponentialMeanSampler 
withUniformRandomProvider(UniformRandomProvider rng) {
+            return new ZigguratExponentialMeanSampler(rng, this.mean);
+        }
+    }
+
+    /**
+     * @param rng Generator of uniformly distributed random numbers.
+     */
+    private ZigguratExponentialSampler(UniformRandomProvider rng) {
+        this.rng = rng;
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public double sample() {
+        return createSample();
+    }
+
+    /**
+     * Creates the exponential sample with {@code mean = 1}.
+     *
+     * <p>Note: This has been extracted to a separate method so that the recursive call
+     * made when sampling tries again targets this function. Otherwise a retry from the
+     * sub-class {@code ZigguratExponentialMeanSampler.sample()} would recursively call
+     * the overridden {@code sample()} method, which creates a bad sample due to
+     * compound multiplication of the mean.</p>
+     *
+     * @return the sample
+     */
+    final double createSample() {
+        // An unsigned long in [0, 2^63)
+        final long j = rng.nextLong() >>> 1;
+        final int i = ((int) j) & LAST;
+        if (j < K[i]) {
+            // This branch is called about 0.977777 times per call into createSample.
+            // Note: Frequencies have been empirically measured for the first call to
+            // createSample; recursion due to retries has been ignored. Frequencies sum to 1.
+            return j * W[i];
+        }
+        return fix(j, i);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public String toString() {
+        return "Ziggurat exponential deviate [" + rng.toString() + "]";
+    }
+
+    /**
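+     * Computes the sample when the {@code j < K[i]} test in {@code createSample} fails:
+     * samples from the tail (base strip, {@code iz == 0}), accepts a point in the wedge
+     * of another strip, or tries again.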
+     *
+     * @param jz Start random integer.
+     * @param iz Index of cell corresponding to {@code jz}.
+     * @return the requested random value.
+     */
+    private double fix(long jz,
+                       int iz) {
+        if (iz == 0) {
+            // Base strip.
+            // This branch is called about 0.000448867 times per call into 
createSample.
+            return R - Math.log(rng.nextDouble());
+        }
+        // Wedge of other strips.
+        final double x = jz * W[iz];
+        if (F[iz] + rng.nextDouble() * (F[iz - 1] - F[iz]) < pdf(x)) {
+            // This branch is called about 0.0107820 times per call into 
createSample.
+            return x;
+        }
+        // Try again.
+        // This branch is called about 0.0109920 times per call into 
createSample
+        // i.e. this is the recursion frequency.
+        return createSample();
+    }
+
+    /**
+     * Compute the exponential probability density function {@code f(x) = 
e^-x}.
+     *
+     * @param x Argument.
+     * @return {@code e^-x}
+     */
+    private static double pdf(double x) {
+        return Math.exp(-x);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public ZigguratExponentialSampler 
withUniformRandomProvider(UniformRandomProvider rng) {
+        return new ZigguratExponentialSampler(rng);
+    }
+
+    /**
+     * Create a new exponential sampler with {@code mean = 1}.
+     *
+     * @param rng Generator of uniformly distributed random numbers.
+     * @return the sampler
+     */
+    public static ZigguratExponentialSampler of(UniformRandomProvider rng) {
+        return new ZigguratExponentialSampler(rng);
+    }
+
+    /**
+     * Create a new exponential sampler with the specified {@code mean}.
+     *
+     * @param rng Generator of uniformly distributed random numbers.
+     * @param mean Mean.
+     * @return the sampler
+     * @throws IllegalArgumentException if the mean is not strictly positive
+     * ({@code mean <= 0} or {@code NaN})
+     */
+    public static ZigguratExponentialSampler of(UniformRandomProvider rng, 
double mean) {
+        if (mean > 0) {
+            return new ZigguratExponentialMeanSampler(rng, mean);
+        }
+        throw new IllegalArgumentException("Mean is not strictly positive: " + 
mean);
+    }
+}
diff --git 
a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/ContinuousSamplersList.java
 
b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/ContinuousSamplersList.java
index 6802b49..94c6e3c 100644
--- 
a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/ContinuousSamplersList.java
+++ 
b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/ContinuousSamplersList.java
@@ -99,6 +99,9 @@ public final class ContinuousSamplersList {
             // Exponential.
             add(LIST, new 
org.apache.commons.math3.distribution.ExponentialDistribution(unusedRng, 
meanExp),
                 AhrensDieterExponentialSampler.of(RandomSource.MT.create(), 
meanExp));
+            // Exponential ("Ziggurat").
+            add(LIST, new 
org.apache.commons.math3.distribution.ExponentialDistribution(unusedRng, 
meanExp),
+                ZigguratExponentialSampler.of(RandomSource.KISS.create(), 
meanExp));
 
             // F ("inverse method").
             final int numDofF = 4;
diff --git 
a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/ZigguratExponentialSamplerTest.java
 
b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/ZigguratExponentialSamplerTest.java
new file mode 100644
index 0000000..d390cc6
--- /dev/null
+++ 
b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/ZigguratExponentialSamplerTest.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.rng.sampling.distribution;
+
+import org.junit.Test;
+
+import org.apache.commons.rng.RestorableUniformRandomProvider;
+import org.apache.commons.rng.UniformRandomProvider;
+import org.apache.commons.rng.sampling.RandomAssert;
+import org.apache.commons.rng.simple.RandomSource;
+
+/**
+ * Test for {@link ZigguratExponentialSampler}.
+ */
+public class ZigguratExponentialSamplerTest {
+    /**
+     * Test that the factory method throws for a zero mean.
+     */
+    @Test(expected = IllegalArgumentException.class)
+    public void testConstructorThrowsWithZeroMean() {
+        final RestorableUniformRandomProvider rng = 
RandomSource.SPLIT_MIX_64.create(0L);
+        final double mean = 0;
+        ZigguratExponentialSampler.of(rng, mean);
+    }
+
+    /**
+     * Test the SharedStateSampler implementation.
+     */
+    @Test
+    public void testSharedStateSampler() {
+        final UniformRandomProvider rng1 = 
RandomSource.SPLIT_MIX_64.create(0L);
+        final UniformRandomProvider rng2 = 
RandomSource.SPLIT_MIX_64.create(0L);
+        final ZigguratExponentialSampler sampler1 = 
ZigguratExponentialSampler.of(rng1);
+        final ZigguratExponentialSampler sampler2 = 
sampler1.withUniformRandomProvider(rng2);
+        RandomAssert.assertProduceSameSequence(sampler1, sampler2);
+    }
+
+    /**
+     * Test the SharedStateSampler implementation with a mean.
+     */
+    @Test
+    public void testSharedStateSamplerWithMean() {
+        final UniformRandomProvider rng1 = 
RandomSource.SPLIT_MIX_64.create(0L);
+        final UniformRandomProvider rng2 = 
RandomSource.SPLIT_MIX_64.create(0L);
+        final double mean = 1.23;
+        final ZigguratExponentialSampler sampler1 = 
ZigguratExponentialSampler.of(rng1, mean);
+        final ZigguratExponentialSampler sampler2 = 
sampler1.withUniformRandomProvider(rng2);
+        RandomAssert.assertProduceSameSequence(sampler1, sampler2);
+    }
+}

Reply via email to