This is an automated email from the ASF dual-hosted git repository.
aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-rng.git
The following commit(s) were added to refs/heads/master by this push:
new 6edeb10 RNG-146: Prevent infinite standard deviation
6edeb10 is described below
commit 6edeb102c5310895480781e14c14facaf12ed864
Author: aherbert <[email protected]>
AuthorDate: Fri Jul 9 11:29:17 2021 +0100
RNG-146: Prevent infinite standard deviation
---
.../rng/sampling/distribution/GaussianSampler.java | 22 +++++++--
.../sampling/distribution/GaussianSamplerTest.java | 54 +++++++++++++++++++++-
src/changes/changes.xml | 3 ++
3 files changed, 74 insertions(+), 5 deletions(-)
diff --git
a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GaussianSampler.java
b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GaussianSampler.java
index 5540018..38e0537 100644
---
a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GaussianSampler.java
+++
b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GaussianSampler.java
@@ -22,6 +22,14 @@ import org.apache.commons.rng.UniformRandomProvider;
* Sampling from a Gaussian distribution with given mean and
* standard deviation.
*
+ * <h2>Note</h2>
+ *
+ * <p>The mean and standard deviation are validated to ensure they are finite. This prevents
+ * generation of NaN samples by avoiding invalid arithmetic (inf * 0 or inf - inf).
+ * However, use of an extremely large standard deviation and/or mean may result in samples that
+ * are infinite; that is, the parameters are not validated to prevent truncation of the output
+ * distribution.
+ *
* @since 1.1
*/
public class GaussianSampler implements SharedStateContinuousSampler {
@@ -36,14 +44,19 @@ public class GaussianSampler implements
SharedStateContinuousSampler {
* @param normalized Generator of N(0,1) Gaussian distributed random
numbers.
* @param mean Mean of the Gaussian distribution.
* @param standardDeviation Standard deviation of the Gaussian
distribution.
- * @throws IllegalArgumentException if {@code standardDeviation <= 0}
+ * @throws IllegalArgumentException if {@code standardDeviation <= 0} or is not finite;
+ * or {@code mean} is not finite
*/
public GaussianSampler(NormalizedGaussianSampler normalized,
double mean,
double standardDeviation) {
- if (standardDeviation <= 0) {
+ if (!(standardDeviation > 0 && standardDeviation <
Double.POSITIVE_INFINITY)) {
throw new IllegalArgumentException(
- "standard deviation is not strictly positive: " +
standardDeviation);
+ "standard deviation is not strictly positive and finite: " +
standardDeviation);
+ }
+ // To be replaced by JDK 1.8 Double.isFinite. This will detect NaN values.
+ if (!(Math.abs(mean) <= Double.MAX_VALUE)) {
+ throw new IllegalArgumentException("mean is not finite: " + mean);
}
this.normalized = normalized;
this.mean = mean;
@@ -102,7 +115,8 @@ public class GaussianSampler implements
SharedStateContinuousSampler {
* @param mean Mean of the Gaussian distribution.
* @param standardDeviation Standard deviation of the Gaussian
distribution.
* @return the sampler
- * @throws IllegalArgumentException if {@code standardDeviation <= 0}
+ * @throws IllegalArgumentException if {@code standardDeviation <= 0} or is not finite;
+ * or {@code mean} is not finite
* @see #withUniformRandomProvider(UniformRandomProvider)
* @since 1.3
*/
diff --git
a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/GaussianSamplerTest.java
b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/GaussianSamplerTest.java
index 9bfabba..2ad15f8 100644
---
a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/GaussianSamplerTest.java
+++
b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/GaussianSamplerTest.java
@@ -28,7 +28,7 @@ import org.junit.Test;
*/
public class GaussianSamplerTest {
/**
- * Test the constructor with a bad standard deviation.
+ * Test the constructor with a zero standard deviation.
*/
@Test(expected = IllegalArgumentException.class)
public void testConstructorThrowsWithZeroStandardDeviation() {
@@ -41,6 +41,58 @@ public class GaussianSamplerTest {
}
/**
+ * Test the constructor with an infinite standard deviation.
+ */
+ @Test(expected = IllegalArgumentException.class)
+ public void testConstructorThrowsWithInfiniteStandardDeviation() {
+ final RestorableUniformRandomProvider rng =
+ RandomSource.SPLIT_MIX_64.create(0L);
+ final NormalizedGaussianSampler gauss = new
ZigguratNormalizedGaussianSampler(rng);
+ final double mean = 1;
+ final double standardDeviation = Double.POSITIVE_INFINITY;
+ GaussianSampler.of(gauss, mean, standardDeviation);
+ }
+
+ /**
+ * Test the constructor with a NaN standard deviation.
+ */
+ @Test(expected = IllegalArgumentException.class)
+ public void testConstructorThrowsWithNaNStandardDeviation() {
+ final RestorableUniformRandomProvider rng =
+ RandomSource.SPLIT_MIX_64.create(0L);
+ final NormalizedGaussianSampler gauss = new
ZigguratNormalizedGaussianSampler(rng);
+ final double mean = 1;
+ final double standardDeviation = Double.NaN;
+ GaussianSampler.of(gauss, mean, standardDeviation);
+ }
+
+ /**
+ * Test the constructor with an infinite mean.
+ */
+ @Test(expected = IllegalArgumentException.class)
+ public void testConstructorThrowsWithInfiniteMean() {
+ final RestorableUniformRandomProvider rng =
+ RandomSource.SPLIT_MIX_64.create(0L);
+ final NormalizedGaussianSampler gauss = new
ZigguratNormalizedGaussianSampler(rng);
+ final double mean = Double.POSITIVE_INFINITY;
+ final double standardDeviation = 1;
+ GaussianSampler.of(gauss, mean, standardDeviation);
+ }
+
+ /**
+ * Test the constructor with a NaN mean.
+ */
+ @Test(expected = IllegalArgumentException.class)
+ public void testConstructorThrowsWithNaNMean() {
+ final RestorableUniformRandomProvider rng =
+ RandomSource.SPLIT_MIX_64.create(0L);
+ final NormalizedGaussianSampler gauss = new
ZigguratNormalizedGaussianSampler(rng);
+ final double mean = Double.NaN;
+ final double standardDeviation = 1;
+ GaussianSampler.of(gauss, mean, standardDeviation);
+ }
+
+ /**
* Test the SharedStateSampler implementation.
*/
@Test
diff --git a/src/changes/changes.xml b/src/changes/changes.xml
index e211d01..c71270f 100644
--- a/src/changes/changes.xml
+++ b/src/changes/changes.xml
@@ -77,6 +77,9 @@ re-run tests that fail, and pass the build if they succeed
within the allotted number of reruns (the test will be marked
as 'flaky' in the report).
">
+ <action dev="aherbert" type="fix" issue="146">
+ "GaussianSampler": Prevent non-finite (infinite or NaN) mean and standard deviation.
+ </action>
<action dev="aherbert" type="update" issue="154">
Update Gaussian samplers to avoid infinity in the tails of the
distribution. Applies
to: ZigguratNormalisedGaussianSampler;
BoxMullerNormalizedGaussianSampler; and