This is an automated email from the ASF dual-hosted git repository.

erans pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-math.git

commit e0b2efc2acc999ffe7df6dd6d2799de294dd811f
Author: Sam Ritchie <samritc...@google.com>
AuthorDate: Tue Oct 20 16:24:29 2020 -0600

    MATH-1558: Fix MidPointIntegrator incremental implementation
---
 .../analysis/integration/MidPointIntegrator.java   | 36 ++++++++++++++--------
 .../integration/MidPointIntegratorTest.java        | 36 ++++++++++++++++++----
 2 files changed, 53 insertions(+), 19 deletions(-)

diff --git 
a/src/main/java/org/apache/commons/math4/analysis/integration/MidPointIntegrator.java
 
b/src/main/java/org/apache/commons/math4/analysis/integration/MidPointIntegrator.java
index bb8b763..de252a4 100644
--- 
a/src/main/java/org/apache/commons/math4/analysis/integration/MidPointIntegrator.java
+++ 
b/src/main/java/org/apache/commons/math4/analysis/integration/MidPointIntegrator.java
@@ -25,7 +25,7 @@ import 
org.apache.commons.math4.exception.TooManyEvaluationsException;
 import org.apache.commons.math4.util.FastMath;
 
 /**
- * Implements the <a href="http://en.wikipedia.org/wiki/Midpoint_method">
+ * Implements the <a href="https://en.wikipedia.org/wiki/Riemann_sum#Midpoint_rule">
  * Midpoint Rule</a> for integration of real univariate functions. For
  * reference, see <b>Numerical Mathematics</b>, ISBN 0387989595,
  * chapter 9.2.
@@ -36,8 +36,10 @@ import org.apache.commons.math4.util.FastMath;
  */
 public class MidPointIntegrator extends BaseAbstractUnivariateIntegrator {
 
-    /** Maximum number of iterations for midpoint. */
-    private static final int MIDPOINT_MAX_ITERATIONS_COUNT = 63;
+    /** Maximum number of iterations for midpoint. 39 = floor(log_3(2^63)), the
+     * maximum number of triplings allowed before exceeding 64-bit bounds.
+     */
+    private static final int MIDPOINT_MAX_ITERATIONS_COUNT = 39;
 
     /**
      * Build a midpoint integrator with given accuracies and iterations counts.
@@ -50,7 +52,7 @@ public class MidPointIntegrator extends 
BaseAbstractUnivariateIntegrator {
      * @exception NumberIsTooSmallException if maximal number of iterations
      * is lesser than or equal to the minimal number of iterations
      * @exception NumberIsTooLargeException if maximal number of iterations
-     * is greater than 63.
+     * is greater than 39.
      */
     public MidPointIntegrator(final double relativeAccuracy,
                               final double absoluteAccuracy,
@@ -73,7 +75,7 @@ public class MidPointIntegrator extends 
BaseAbstractUnivariateIntegrator {
      * @exception NumberIsTooSmallException if maximal number of iterations
      * is lesser than or equal to the minimal number of iterations
      * @exception NumberIsTooLargeException if maximal number of iterations
-     * is greater than 63.
+     * is greater than 39.
      */
     public MidPointIntegrator(final int minimalIterationCount,
                               final int maximalIterationCount)
@@ -98,11 +100,11 @@ public class MidPointIntegrator extends 
BaseAbstractUnivariateIntegrator {
      * This function should only be called by API <code>integrate()</code> in 
the package.
      * To save time it does not verify arguments - caller does.
      * <p>
-     * The interval is divided equally into 2^n sections rather than an
+     * The interval is divided equally into 3^n sections rather than an
      * arbitrary m sections because this configuration can best utilize the
      * already computed values.</p>
      *
-     * @param n the stage of 1/2 refinement. Must be larger than 0.
+     * @param n the stage of 1/3 refinement. Must be larger than 0.
      * @param previousStageResult Result from the previous call to the
      * {@code stage} method.
      * @param min Lower bound of the integration interval.
@@ -118,21 +120,29 @@ public class MidPointIntegrator extends 
BaseAbstractUnivariateIntegrator {
                          double diffMaxMin)
         throws TooManyEvaluationsException {
 
-        // number of new points in this stage
-        final long np = 1L << (n - 1);
+        // number of points in the previous stage. This stage will contribute
+        // 2*3^{n-1} more points.
+        final long np = (long) FastMath.pow(3, n - 1);
         double sum = 0;
 
         // spacing between adjacent new points
         final double spacing = diffMaxMin / np;
+        final double leftOffset = spacing / 6;
+        final double rightOffset = 5 * leftOffset;
 
-        // the first new point
-        double x = min + 0.5 * spacing;
+        double x = min;
         for (long i = 0; i < np; i++) {
-            sum += computeObjectiveValue(x);
+            // The first and second new points are located at the new midpoints
+            // generated when each previous integration slice is split into 3.
+            //
+            // |--------x--------|
+            // |--x--|--x--|--x--|
+            sum += computeObjectiveValue(x + leftOffset);
+            sum += computeObjectiveValue(x + rightOffset);
             x += spacing;
         }
         // add the new sum to previously calculated result
-        return 0.5 * (previousStageResult + sum * spacing);
+        return (previousStageResult + sum * spacing) / 3.0;
     }
 
 
diff --git 
a/src/test/java/org/apache/commons/math4/analysis/integration/MidPointIntegratorTest.java
 
b/src/test/java/org/apache/commons/math4/analysis/integration/MidPointIntegratorTest.java
index 0474d27..1d227dd 100644
--- 
a/src/test/java/org/apache/commons/math4/analysis/integration/MidPointIntegratorTest.java
+++ 
b/src/test/java/org/apache/commons/math4/analysis/integration/MidPointIntegratorTest.java
@@ -36,6 +36,25 @@ public final class MidPointIntegratorTest {
     private static final int NUM_ITER = 30;
 
     /**
+     * The initial iteration contributes 1 evaluation. Each successive 
iteration
+     * contributes 2 points to each previous slice.
+     *
+     * The total evaluation count == 1 + 2*3^0 + 2*3^1 + ... + 2*3^n
+     *
+     * The series 3^0 + 3^1 + ... + 3^n sums to (3^(n+1) - 1) / (3-1), so the
+     * total expected evaluations == 1 + 2*(3^(n+1) - 1)/2 == 3^(n+1).
+     *
+     * The n in the series above is one less than the MidPointIntegrator
+     * iteration count, so the actual result == 3^iterations.
+     *
+     * Without the incremental implementation, the same result would require
+     * (3^(iterations + 1) - 1) / 2 evaluations; just under 50% more.
+     */
+    private long expectedEvaluations(int iterations) {
+        return (long) FastMath.pow(3, iterations);
+    }
+
+    /**
      * Test of integrator for the sine function.
      */
     @Test
@@ -48,8 +67,9 @@ public final class MidPointIntegratorTest {
         double expected = -3697001.0 / 48.0;
         double tolerance = FastMath.abs(expected * 
integrator.getRelativeAccuracy());
         double result = integrator.integrate(Integer.MAX_VALUE, f, min, max);
-        Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 2);
+        Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 3);
         Assert.assertTrue(integrator.getIterations() < NUM_ITER);
+        Assert.assertEquals(expectedEvaluations(integrator.getIterations()), 
integrator.getEvaluations());
         Assert.assertEquals(expected, result, tolerance);
 
     }
@@ -67,8 +87,9 @@ public final class MidPointIntegratorTest {
         double expected = 2;
         double tolerance = FastMath.abs(expected * 
integrator.getRelativeAccuracy());
         double result = integrator.integrate(Integer.MAX_VALUE, f, min, max);
-        Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 2);
+        Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 3);
         Assert.assertTrue(integrator.getIterations() < NUM_ITER);
+        Assert.assertEquals(expectedEvaluations(integrator.getIterations()), 
integrator.getEvaluations());
         Assert.assertEquals(expected, result, tolerance);
 
         min = -FastMath.PI/3;
@@ -76,8 +97,9 @@ public final class MidPointIntegratorTest {
         expected = -0.5;
         tolerance = FastMath.abs(expected * integrator.getRelativeAccuracy());
         result = integrator.integrate(Integer.MAX_VALUE, f, min, max);
-        Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 2);
+        Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 3);
         Assert.assertTrue(integrator.getIterations() < NUM_ITER);
+        Assert.assertEquals(expectedEvaluations(integrator.getIterations()), 
integrator.getEvaluations());
         Assert.assertEquals(expected, result, tolerance);
 
     }
@@ -95,8 +117,9 @@ public final class MidPointIntegratorTest {
         double expected = -1.0 / 48;
         double tolerance = FastMath.abs(expected * 
integrator.getRelativeAccuracy());
         double result = integrator.integrate(Integer.MAX_VALUE, f, min, max);
-        Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 2);
+        Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 3);
         Assert.assertTrue(integrator.getIterations() < NUM_ITER);
+        Assert.assertEquals(expectedEvaluations(integrator.getIterations()), 
integrator.getEvaluations());
         Assert.assertEquals(expected, result, tolerance);
 
         min = 0;
@@ -104,7 +127,7 @@ public final class MidPointIntegratorTest {
         expected = 11.0 / 768;
         tolerance = FastMath.abs(expected * integrator.getRelativeAccuracy());
         result = integrator.integrate(Integer.MAX_VALUE, f, min, max);
-        Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 2);
+        Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 3);
         Assert.assertTrue(integrator.getIterations() < NUM_ITER);
         Assert.assertEquals(expected, result, tolerance);
 
@@ -113,8 +136,9 @@ public final class MidPointIntegratorTest {
         expected = 2048 / 3.0 - 78 + 1.0 / 48;
         tolerance = FastMath.abs(expected * integrator.getRelativeAccuracy());
         result = integrator.integrate(Integer.MAX_VALUE, f, min, max);
-        Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 2);
+        Assert.assertTrue(integrator.getEvaluations() < Integer.MAX_VALUE / 3);
         Assert.assertTrue(integrator.getIterations() < NUM_ITER);
+        Assert.assertEquals(expectedEvaluations(integrator.getIterations()), 
integrator.getEvaluations());
         Assert.assertEquals(expected, result, tolerance);
 
     }

Reply via email to