This is an automated email from the ASF dual-hosted git repository.
aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-statistics.git
The following commit(s) were added to refs/heads/master by this push:
new 3c2edbc8 STATISTICS-92: TruncatedNormalDistribution rejection-sampler
uses incorrect bounds for mirroring check (#125)
3c2edbc8 is described below
commit 3c2edbc8365056361328a5e8bdeea36a557466ed
Author: Kevin Milner <[email protected]>
AuthorDate: Mon Mar 16 10:43:04 2026 -0700
STATISTICS-92: TruncatedNormalDistribution rejection-sampler uses incorrect
bounds for mirroring check (#125)
Fix: use the standardized (standard normal) bounds for the rejection sampler's threshold and mirroring checks
---------
Co-authored-by: Kevin Milner <[email protected]>
---
.../distribution/TruncatedNormalDistribution.java | 16 ++++----
.../TruncatedNormalDistributionTest.java | 40 ++++++++++++++++++++
.../test.truncatednormal.10.properties | 36 ++++++++++++++++++
.../test.truncatednormal.11.properties | 43 ++++++++++++++++++++++
4 files changed, 127 insertions(+), 8 deletions(-)
diff --git
a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/TruncatedNormalDistribution.java
b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/TruncatedNormalDistribution.java
index 61d34174..92656007 100644
---
a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/TruncatedNormalDistribution.java
+++
b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/TruncatedNormalDistribution.java
@@ -235,12 +235,17 @@ public final class TruncatedNormalDistribution extends
AbstractContinuousDistrib
/** {@inheritDoc} */
@Override
public Sampler createSampler(UniformRandomProvider rng) {
+ // Map the bounds to a standard normal distribution
+ final double u = parentNormal.getMean();
+ final double s = parentNormal.getStandardDeviation();
+ final double a = (lower - u) / s;
+ final double b = (upper - u) / s;
// If the truncation covers a reasonable amount of the normal
distribution
// then a rejection sampler can be used.
double threshold = REJECTION_THRESHOLD;
// If the truncation is entirely in the upper or lower half then
adjust the
// threshold as twice the samples can be used
- if (lower >= 0 || upper <= 0) {
+ if (a >= 0 || b <= 0) {
threshold *= 0.5;
}
@@ -249,21 +254,16 @@ public final class TruncatedNormalDistribution extends
AbstractContinuousDistrib
final ZigguratSampler.NormalizedGaussian sampler =
ZigguratSampler.NormalizedGaussian.of(rng);
final DoubleSupplier gen;
// Use mirroring if possible
- if (lower >= 0) {
+ if (a >= 0) {
// Return the upper-half of the Gaussian
gen = () -> Math.abs(sampler.sample());
- } else if (upper <= 0) {
+ } else if (b <= 0) {
// Return the lower-half of the Gaussian
gen = () -> -Math.abs(sampler.sample());
} else {
// Return the full range of the Gaussian
gen = sampler::sample;
}
- // Map the bounds to a standard normal distribution
- final double u = parentNormal.getMean();
- final double s = parentNormal.getStandardDeviation();
- final double a = (lower - u) / s;
- final double b = (upper - u) / s;
// Sample in [a, b] using rejection
return () -> {
double x = gen.getAsDouble();
diff --git
a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/TruncatedNormalDistributionTest.java
b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/TruncatedNormalDistributionTest.java
index d4bdba7f..c935b873 100644
---
a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/TruncatedNormalDistributionTest.java
+++
b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/TruncatedNormalDistributionTest.java
@@ -17,8 +17,10 @@
package org.apache.commons.statistics.distribution;
+import java.time.Duration;
import org.apache.commons.numbers.gamma.Erf;
import org.apache.commons.numbers.gamma.Erfcx;
+import org.apache.commons.rng.simple.RandomSource;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
@@ -351,6 +353,44 @@ class TruncatedNormalDistributionTest extends
BaseContinuousDistributionTest {
}
}
+ /**
+ * Tests that the sampler returns values when the truncation range is positive
+ * but lies entirely below the mean. The range meets the rejection sampler
+ * threshold; if that sampler wrongly mirrors to the upper half of the Gaussian
+ * it never produces a value in range, loops forever, and the test fails on
+ * the timeout.
+ */
+ @Test
+ void testSamplerPositiveBelowMeanWithRejection() {
+ // this triggers the rejection-sampler case in
TruncatedNormalDistribution.createSampler(...)
+ final TruncatedNormalDistribution dist =
TruncatedNormalDistribution.of(1d, 0.1, 0.7, 0.99);
+
+ final double x =
Assertions.assertTimeoutPreemptively(Duration.ofSeconds(1),
+ () ->
dist.createSampler(RandomSource.XO_SHI_RO_256_PP.create(123456789L)).sample());
+
+ Assertions.assertTrue(x >= dist.getSupportLowerBound() && x <=
dist.getSupportUpperBound(),
+ () -> "Sample outside support: " + x);
+ }
+
+ /**
+ * Tests that the sampler returns values when the truncation range is negative
+ * but lies entirely above the mean. The range meets the rejection sampler
+ * threshold; if that sampler wrongly mirrors to the lower half of the Gaussian
+ * it never produces a value in range, loops forever, and the test fails on
+ * the timeout.
+ */
+ @Test
+ void testSamplerNegativeAboveMeanWithRejection() {
+ // this triggers the rejection-sampler case in
TruncatedNormalDistribution.createSampler(...)
+ final TruncatedNormalDistribution dist =
TruncatedNormalDistribution.of(-1d, 0.1, -0.99, -0.7);
+
+ final double x =
Assertions.assertTimeoutPreemptively(Duration.ofSeconds(1),
+ () ->
dist.createSampler(RandomSource.XO_SHI_RO_256_PP.create(123456789L)).sample());
+
+ Assertions.assertTrue(x >= dist.getSupportLowerBound() && x <=
dist.getSupportUpperBound(),
+ () -> "Sample outside support: " + x);
+ }
+
/**
* Assert the mean of the truncated normal distribution is within the
provided relative error.
*/
diff --git
a/commons-statistics-distribution/src/test/resources/org/apache/commons/statistics/distribution/test.truncatednormal.10.properties
b/commons-statistics-distribution/src/test/resources/org/apache/commons/statistics/distribution/test.truncatednormal.10.properties
new file mode 100644
index 00000000..46774a08
--- /dev/null
+++
b/commons-statistics-distribution/src/test/resources/org/apache/commons/statistics/distribution/test.truncatednormal.10.properties
@@ -0,0 +1,36 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Bounds are fully positive before standardization but straddle zero once mapped to the standard normal (a = approx. -3, b = approx. 3)
+# STATISTICS-92
+parameters = 1.0, 0.1, 0.7, 1.3
+# Computed using Python with SciPy v1.16.3:
+# mean, std, clip_a, clip_b = 1.0, 0.1, 0.7, 1.3
+# a, b = (clip_a - mean) / std, (clip_b - mean) / std
+# truncnorm.var(a, b, loc=mean, scale=std)
+mean = 1.0
+variance = 0.009733369246625417
+lower = 0.7
+upper = 1.3
+cdf.points = \
+ 0.55, 0.65, 0.75, 0.85, 0.95, 1.05, 1.15, 1.25, 1.35
+cdf.values = \
+ 0. , 0. , 0.00487292319299906,\
+ 0.06563450301006861, 0.30801922979837476, 0.6919807702016253 ,\
+ 0.9343654969899313 , 0.995127076807001 , 1.
+pdf.values = \
+ 0. , 0. , 0.17575751438109988,\
+ 1.298682133570557 , 3.530184044629268 , 3.530184044629268 ,\
+ 1.2986821335705594 , 0.17575751438109988, 0.
diff --git
a/commons-statistics-distribution/src/test/resources/org/apache/commons/statistics/distribution/test.truncatednormal.11.properties
b/commons-statistics-distribution/src/test/resources/org/apache/commons/statistics/distribution/test.truncatednormal.11.properties
new file mode 100644
index 00000000..02af336c
--- /dev/null
+++
b/commons-statistics-distribution/src/test/resources/org/apache/commons/statistics/distribution/test.truncatednormal.11.properties
@@ -0,0 +1,43 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Bounds are fully negative before standardization but straddle zero once mapped to the standard normal (a = approx. -4, b = approx. 2)
+# STATISTICS-92
+parameters = -1.0, 0.1, -1.4, -0.8
+# Computed using Python with SciPy v1.16.3:
+# mean, std, clip_a, clip_b = -1.0, 0.1, -1.4, -0.8
+# a, b = (clip_a - mean) / std, (clip_b - mean) / std
+# truncnorm.var(a, b, loc=mean, scale=std)
+mean = -1.0055112703041382
+variance = 0.00885915482691343
+lower = -1.4
+upper = -0.8
+cdf.points = \
+ -1.45 , -1.3722222222222222, -1.2944444444444443,\
+ -1.2166666666666666, -1.1388888888888888, -1.0611111111111111,\
+ -0.9833333333333333, -0.9055555555555554, -0.8277777777777777,\
+ -0.75
+cdf.values = \
+ 0.0000000000000000e+00, 6.8630844903959877e-05,\
+ 1.6229782566954216e-03, 1.5450458063193950e-02,\
+ 8.4322619964621676e-02, 2.7683821471716175e-01,\
+ 5.7935081767532270e-01, 8.4678840615650053e-01,\
+ 9.7977472840069957e-01, 1.0000000000000000e+00
+pdf.values = \
+ 0.0000000000000000e+00, 4.0027355759704591e-03,\
+ 5.3494059692712141e-02, 3.9042072314005954e-01,\
+ 1.5561046905609728e+00, 3.3870640463590616e+00,\
+ 4.0261194212948235e+00, 2.6135363401653997e+00,\
+ 9.2650780123937682e-01, 0.0000000000000000e+00