This is an automated email from the ASF dual-hosted git repository.

aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-statistics.git


The following commit(s) were added to refs/heads/master by this push:
     new 3c2edbc8 STATISTICS-92: TruncatedNormalDistribution rejection-sampler uses incorrect bounds for mirroring check (#125)
3c2edbc8 is described below

commit 3c2edbc8365056361328a5e8bdeea36a557466ed
Author: Kevin Milner <[email protected]>
AuthorDate: Mon Mar 16 10:43:04 2026 -0700

    STATISTICS-92: TruncatedNormalDistribution rejection-sampler uses incorrect bounds for mirroring check (#125)
    
    Fix the rejection sampler to use the standard normal bounds for the mirroring check
    
    ---------
    
    Co-authored-by: Kevin Milner <[email protected]>
---
 .../distribution/TruncatedNormalDistribution.java  | 16 ++++----
 .../TruncatedNormalDistributionTest.java           | 40 ++++++++++++++++++++
 .../test.truncatednormal.10.properties             | 36 ++++++++++++++++++
 .../test.truncatednormal.11.properties             | 43 ++++++++++++++++++++++
 4 files changed, 127 insertions(+), 8 deletions(-)

diff --git 
a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/TruncatedNormalDistribution.java
 
b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/TruncatedNormalDistribution.java
index 61d34174..92656007 100644
--- 
a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/TruncatedNormalDistribution.java
+++ 
b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/TruncatedNormalDistribution.java
@@ -235,12 +235,17 @@ public final class TruncatedNormalDistribution extends 
AbstractContinuousDistrib
     /** {@inheritDoc} */
     @Override
     public Sampler createSampler(UniformRandomProvider rng) {
+        // Map the bounds to a standard normal distribution
+        final double u = parentNormal.getMean();
+        final double s = parentNormal.getStandardDeviation();
+        final double a = (lower - u) / s;
+        final double b = (upper - u) / s;
         // If the truncation covers a reasonable amount of the normal 
distribution
         // then a rejection sampler can be used.
         double threshold = REJECTION_THRESHOLD;
         // If the truncation is entirely in the upper or lower half then 
adjust the
         // threshold as twice the samples can be used
-        if (lower >= 0 || upper <= 0) {
+        if (a >= 0 || b <= 0) {
             threshold *= 0.5;
         }
 
@@ -249,21 +254,16 @@ public final class TruncatedNormalDistribution extends 
AbstractContinuousDistrib
             final ZigguratSampler.NormalizedGaussian sampler = 
ZigguratSampler.NormalizedGaussian.of(rng);
             final DoubleSupplier gen;
             // Use mirroring if possible
-            if (lower >= 0) {
+            if (a >= 0) {
                 // Return the upper-half of the Gaussian
                 gen = () -> Math.abs(sampler.sample());
-            } else if (upper <= 0) {
+            } else if (b <= 0) {
                 // Return the lower-half of the Gaussian
                 gen = () -> -Math.abs(sampler.sample());
             } else {
                 // Return the full range of the Gaussian
                 gen = sampler::sample;
             }
-            // Map the bounds to a standard normal distribution
-            final double u = parentNormal.getMean();
-            final double s = parentNormal.getStandardDeviation();
-            final double a = (lower - u) / s;
-            final double b = (upper - u) / s;
             // Sample in [a, b] using rejection
             return () -> {
                 double x = gen.getAsDouble();
diff --git 
a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/TruncatedNormalDistributionTest.java
 
b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/TruncatedNormalDistributionTest.java
index d4bdba7f..c935b873 100644
--- 
a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/TruncatedNormalDistributionTest.java
+++ 
b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/TruncatedNormalDistributionTest.java
@@ -17,8 +17,10 @@
 
 package org.apache.commons.statistics.distribution;
 
+import java.time.Duration;
 import org.apache.commons.numbers.gamma.Erf;
 import org.apache.commons.numbers.gamma.Erfcx;
+import org.apache.commons.rng.simple.RandomSource;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
@@ -351,6 +353,44 @@ class TruncatedNormalDistributionTest extends 
BaseContinuousDistributionTest {
         }
     }
 
+    /**
+     * This tests that the sampler can correctly sample values when the range
+     * is positive but fully below the mean value. This tests the case where
+     * the rejection sampler threshold is met, and fails due to a timeout
+     * if the rejection sampler incorrectly triggers the upper-half clause and
+     * is stuck in an infinite loop.
+     */
+    @Test
+    void testSamplerPositiveBelowMeanWithRejection() {
+        // this triggers the rejection-sampler case in 
TruncatedNormalDistribution.createSampler(...)
+        final TruncatedNormalDistribution dist = 
TruncatedNormalDistribution.of(1d, 0.1, 0.7, 0.99);
+
+        final double x = 
Assertions.assertTimeoutPreemptively(Duration.ofSeconds(1),
+                () -> 
dist.createSampler(RandomSource.XO_SHI_RO_256_PP.create(123456789L)).sample());
+
+        Assertions.assertTrue(x >= dist.getSupportLowerBound() && x <= 
dist.getSupportUpperBound(),
+                () -> "Sample outside support: " + x);
+    }
+
+    /**
+     * This tests that the sampler can correctly sample values when the range
+     * is negative but fully above the mean value. This tests the case where
+     * the rejection sampler threshold is met, and fails due to a timeout
+     * if the rejection sampler incorrectly triggers the lower-half clause and
+     * is stuck in an infinite loop.
+     */
+    @Test
+    void testSamplerNegativeAboveMeanWithRejection() {
+        // this triggers the rejection-sampler case in 
TruncatedNormalDistribution.createSampler(...)
+        final TruncatedNormalDistribution dist = 
TruncatedNormalDistribution.of(-1d, 0.1, -0.99, -0.7);
+
+        final double x = 
Assertions.assertTimeoutPreemptively(Duration.ofSeconds(1),
+                () -> 
dist.createSampler(RandomSource.XO_SHI_RO_256_PP.create(123456789L)).sample());
+
+        Assertions.assertTrue(x >= dist.getSupportLowerBound() && x <= 
dist.getSupportUpperBound(),
+                () -> "Sample outside support: " + x);
+    }
+
     /**
      * Assert the mean of the truncated normal distribution is within the 
provided relative error.
      */
diff --git 
a/commons-statistics-distribution/src/test/resources/org/apache/commons/statistics/distribution/test.truncatednormal.10.properties
 
b/commons-statistics-distribution/src/test/resources/org/apache/commons/statistics/distribution/test.truncatednormal.10.properties
new file mode 100644
index 00000000..46774a08
--- /dev/null
+++ 
b/commons-statistics-distribution/src/test/resources/org/apache/commons/statistics/distribution/test.truncatednormal.10.properties
@@ -0,0 +1,36 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Bounds fully positive in the original scale; standardized bounds (-3, 3) span zero
+# STATISTICS-92
+parameters = 1.0, 0.1, 0.7, 1.3
+# Computed using Python with SciPy v1.16.3:
+# mean, std, clip_a, clip_b = 1.0, 0.1, 0.7, 1.3
+# a, b = (clip_a - mean) / std, (clip_b - mean) / std
+# truncnorm.var(a, b, loc=mean, scale=std)
+mean = 1.0
+variance = 0.009733369246625417
+lower = 0.7
+upper = 1.3
+cdf.points = \
+  0.55, 0.65, 0.75, 0.85, 0.95, 1.05, 1.15, 1.25, 1.35
+cdf.values = \
+  0.                 , 0.                 , 0.00487292319299906,\
+  0.06563450301006861, 0.30801922979837476, 0.6919807702016253 ,\
+  0.9343654969899313 , 0.995127076807001  , 1.
+pdf.values = \
+  0.                 , 0.                 , 0.17575751438109988,\
+  1.298682133570557  , 3.530184044629268  , 3.530184044629268  ,\
+  1.2986821335705594 , 0.17575751438109988, 0.
diff --git 
a/commons-statistics-distribution/src/test/resources/org/apache/commons/statistics/distribution/test.truncatednormal.11.properties
 
b/commons-statistics-distribution/src/test/resources/org/apache/commons/statistics/distribution/test.truncatednormal.11.properties
new file mode 100644
index 00000000..02af336c
--- /dev/null
+++ 
b/commons-statistics-distribution/src/test/resources/org/apache/commons/statistics/distribution/test.truncatednormal.11.properties
@@ -0,0 +1,43 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Bounds fully negative in the original scale; standardized bounds (-4, 2) span zero
+# STATISTICS-92
+parameters = -1.0, 0.1, -1.4, -0.8
+# Computed using Python with SciPy v1.16.3:
+# mean, std, clip_a, clip_b = -1.0, 0.1, -1.4, -0.8
+# a, b = (clip_a - mean) / std, (clip_b - mean) / std
+# truncnorm.var(a, b, loc=mean, scale=std)
+mean = -1.0055112703041382
+variance = 0.00885915482691343
+lower = -1.4
+upper = -0.8
+cdf.points = \
+  -1.45              , -1.3722222222222222, -1.2944444444444443,\
+  -1.2166666666666666, -1.1388888888888888, -1.0611111111111111,\
+  -0.9833333333333333, -0.9055555555555554, -0.8277777777777777,\
+  -0.75
+cdf.values = \
+  0.0000000000000000e+00, 6.8630844903959877e-05,\
+  1.6229782566954216e-03, 1.5450458063193950e-02,\
+  8.4322619964621676e-02, 2.7683821471716175e-01,\
+  5.7935081767532270e-01, 8.4678840615650053e-01,\
+  9.7977472840069957e-01, 1.0000000000000000e+00
+pdf.values = \
+  0.0000000000000000e+00, 4.0027355759704591e-03,\
+  5.3494059692712141e-02, 3.9042072314005954e-01,\
+  1.5561046905609728e+00, 3.3870640463590616e+00,\
+  4.0261194212948235e+00, 2.6135363401653997e+00,\
+  9.2650780123937682e-01, 0.0000000000000000e+00

Reply via email to