Hi Jim.

> The test as it is has a test case (I just chose random numbers to check
> and got lucky - d'oh!) that generates 1 solution from the new code even
> though the equation had 2 distinct solutions that weren't even near each
> other...

I figured out why this happens. It's because of cancellation in the
computation of D: two large numbers are subtracted and the result is
supposed to be 0 or close to 0, but it comes out at about 1e-7, which
isn't small enough to pass the iszero test. I've been working on this
and I came up with a couple of different fixes. They are in the
attached file (a modified version of your CubicSolve.java).
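Below is a minimal sketch of a magnitude-aware zero test for D. It assumes the Graphics Gems quantities used in solveCubicNew, where D = q*q + p*p*p. Instead of using a fixed 1e-9, it compares D against the size of the two terms that cancel. The helper names and the 8-ulp slack are illustrative assumptions, not values taken from the attached code.

```java
public class DiscriminantCheck {
    // Hypothetical helper: true when d is no bigger than the rounding error we
    // might expect from adding two terms of magnitude up to 'scale'. The 8-ulp
    // slack is an assumed allowance for rounding in q*q, p*p*p and the sum.
    static boolean isNegligible(double d, double scale) {
        return Math.abs(d) <= 8 * Math.ulp(scale);
    }

    // p and q computed as in solveCubicNew, for x^3 + A*x^2 + B*x + C = 0.
    static boolean discriminantIsZero(double A, double B, double C) {
        double p = (B - A * A / 3) / 3;
        double q = (2.0 / 27 * A * A * A - A * B / 3 + C) / 2;
        double qq = q * q;
        double ppp = p * p * p;
        double D = qq + ppp;
        // When D is near 0, ppp is close to -qq, so the size of these terms
        // (not the size of D) bounds how much rounding error D can contain.
        return isNegligible(D, Math.max(qq, Math.abs(ppp)));
    }

    public static void main(String[] args) {
        // (x - 1)^2 (x + 2) = x^3 - 3x + 2 has a double root, so D is 0.
        System.out.println(discriminantIsZero(0, -3, 2));   // true
        // (x - 1)(x - 2)(x - 3) = x^3 - 6x^2 + 11x - 6 has three distinct roots.
        System.out.println(discriminantIsZero(-6, 11, -6)); // false
    }
}
```

The slack factor is the open question, much like the XXX comment in solveCubicNew. Errors already present in p and q get squared or cubed, so a practical bound may need to be looser than a few ulps.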

The first thing I did was to modify solveCubicOld. I tried to get a
bit fancy, and although I think I fixed the problems it had, the end
result is ugly and complicated, and it still has small problems, like
returning 3 very close roots when there should be only one.
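Here is the core idea behind the modified solveCubicOld (the iterativeCubicRoots/bisectRoot path), sketched on its own. The critical points split the real line into pieces on which the cubic is monotonic, so each piece holds at most one root. Cauchy's bound closes off the two outer pieces. This is a simplified sketch that assumes eqn[3] != 0. Like the attached code, it uses java.awt.geom.QuadCurve2D.solveQuadratic for the derivative; the class and method names are illustrative. It only catches a double root when f is exactly zero at a critical point, which is the same weakness the tolerance discussion is about.

```java
import java.awt.geom.QuadCurve2D;
import java.util.Arrays;

public class BracketedCubic {
    // Horner evaluation of eqn[0] + eqn[1]*t + eqn[2]*t^2 + eqn[3]*t^3.
    static double eval(double[] eqn, double t) {
        return ((eqn[3] * t + eqn[2]) * t + eqn[1]) * t + eqn[0];
    }

    // Bisection on [lo, hi]; assumes f(lo) and f(hi) are nonzero with opposite signs.
    static double bisect(double[] eqn, double lo, double hi) {
        double flo = eval(eqn, lo);
        double m = lo + (hi - lo) / 2;
        while (m != lo && m != hi) {
            double fm = eval(eqn, m);
            if (fm == 0) {
                return m;
            }
            if ((flo < 0) != (fm < 0)) {
                hi = m;            // sign change is in [lo, m]
            } else {
                lo = m;            // sign change is in [m, hi]
                flo = fm;
            }
            m = lo + (hi - lo) / 2;
        }
        return m;
    }

    // Stores the real roots (eqn[3] != 0) in res in increasing order and
    // returns how many there are. res must have room for 3 values.
    static int solve(double[] eqn, double[] res) {
        // Critical points: roots of f'(t) = eqn[1] + 2*eqn[2]*t + 3*eqn[3]*t^2.
        double[] deriv = {eqn[1], 2 * eqn[2], 3 * eqn[3]};
        int crit = Math.max(0, QuadCurve2D.solveQuadratic(deriv, deriv));
        Arrays.sort(deriv, 0, crit);

        // Cauchy's bound: every root lies strictly inside (-M, M).
        double M = 1 + Math.max(Math.abs(eqn[0]),
                Math.max(Math.abs(eqn[1]), Math.abs(eqn[2]))) / Math.abs(eqn[3]);

        // Endpoints -M, the critical points, M; f is monotonic between them.
        double[] pts = new double[crit + 2];
        pts[0] = -M;
        System.arraycopy(deriv, 0, pts, 1, crit);
        pts[crit + 1] = M;

        int n = 0;
        for (int i = 0; i + 1 < pts.length; i++) {
            double a = pts[i], b = pts[i + 1];
            if (a == b) {
                continue;          // repeated critical point: empty interval
            }
            double fa = eval(eqn, a), fb = eval(eqn, b);
            if (fb == 0) {
                res[n++] = b;      // root exactly at an endpoint (e.g. a double root)
            } else if (fa != 0 && (fa < 0) != (fb < 0)) {
                res[n++] = bisect(eqn, a, b);
            }
        }
        return n;
    }

    public static void main(String[] args) {
        double[] res = new double[3];
        int n = solve(new double[] {-6, 11, -6, 1}, res); // (x-1)(x-2)(x-3)
        System.out.println(n + " roots: " + Arrays.toString(Arrays.copyOf(res, n)));
    }
}
```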

The other solution is to check, when the closed-form algorithm computes
only 1 root, whether the roots of the derivative are also roots of the
cubic polynomial. This doesn't have the numerical accuracy of the first
approach (which used bisectRoot when things went wrong), but it's much
faster, doesn't have the multiple roots problem, and is much simpler.
I called your trySolve function on a few hundred polynomials with
random roots in [-10, 10] and it never found fewer roots than there
actually are. Sometimes it finds 3 roots when there are only 2, but I
don't think this is a huge problem.
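The derivative check depends on deciding whether f(c) is zero at a critical point c, and a fixed 1e-9 cutoff has the same problem as the one used for D. Here is a sketch of a magnitude-aware alternative. It accepts c when |f(c)| is within the standard rounding-error bound for Horner's rule, which for degree n is about 2n*u*sum(|a_i|*|c|^i), where u is the unit roundoff. The helper name and the "slack" multiplier are assumptions. Some slack is probably needed because c is itself only an approximation of the true critical point.

```java
public class CriticalPointCheck {
    // Horner evaluation of eqn[0] + eqn[1]*t + eqn[2]*t^2 + eqn[3]*t^3.
    static double eval(double[] eqn, double t) {
        return ((eqn[3] * t + eqn[2]) * t + eqn[1]) * t + eqn[0];
    }

    // Hypothetical helper: true when |f(t)| is no larger than 'slack' times a
    // bound on Horner's rounding error at t. With n = 3 and u = ulp(1.0)/2,
    // 2n*u = 3*ulp(1.0).
    static boolean isRootWithinRoundoff(double[] eqn, double t, double slack) {
        double at = Math.abs(t);
        // sum of |a_i| * |t|^i, also evaluated by Horner's rule
        double bound = ((Math.abs(eqn[3]) * at + Math.abs(eqn[2])) * at
                + Math.abs(eqn[1])) * at + Math.abs(eqn[0]);
        return Math.abs(eval(eqn, t)) <= slack * 3 * Math.ulp(1.0) * bound;
    }

    public static void main(String[] args) {
        // (x - 1)^2 (x + 2) = x^3 - 3x + 2; the critical points are -1 and 1.
        double[] eqn = {2, -3, 0, 1};
        System.out.println(isRootWithinRoundoff(eqn, 1.0, 4));  // true: double root
        System.out.println(isRootWithinRoundoff(eqn, -1.0, 4)); // false: f(-1) = 4
    }
}
```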

I've attached what I have so far.

Regards,
Denis.

----- Original Message -----
> Hi Denis,
> 
> I'm attaching a test program I wrote that compares the old and new
> algorithms.
> 
> Obviously the old one missed a bunch of solutions because it classified
> all solutions as 1 or 3, but the new one also sometimes misses a
> solution. You might want to turn this into an automated test for the
> bug (and maybe use it as a stress test with a random number
> generator).
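Here is a sketch of such an automated stress test. So that it compiles on its own, it exercises the JDK's java.awt.geom.CubicCurve2D.solveCubic; pointing it at solveCubicNew from the attached file would be the natural change. The fixed seed, the 1e-6 tolerance, the root range and the helper names (fromRoots, countMissing) are assumptions. The tolerance is loose on purpose, because a double root can only be recovered to roughly the square root of the rounding error.

```java
import java.awt.geom.CubicCurve2D;
import java.util.Random;

public class CubicStressTest {
    // Coefficients of (x - r1)(x - r2)(x - r3), lowest degree first:
    // eqn[0] + eqn[1]*x + eqn[2]*x^2 + eqn[3]*x^3.
    static double[] fromRoots(double r1, double r2, double r3) {
        return new double[] {
            -(r1 * r2 * r3),
            r1 * r2 + r2 * r3 + r3 * r1,
            -(r1 + r2 + r3),
            1.0
        };
    }

    // Number of expected roots that have no computed root within tol.
    static int countMissing(double[] expected, double[] found, int n, double tol) {
        int missing = 0;
        for (double e : expected) {
            boolean hit = false;
            for (int i = 0; i < n && !hit; i++) {
                hit = Math.abs(found[i] - e) <= tol;
            }
            if (!hit) {
                missing++;
            }
        }
        return missing;
    }

    public static void main(String[] args) {
        Random rnd = new Random(42); // fixed seed so failures are reproducible
        double min = -10, max = 10, tol = 1e-6;
        int failures = 0;
        for (int i = 0; i < 1000; i++) {
            double r1 = min + rnd.nextDouble() * (max - min);
            double r2 = min + rnd.nextDouble() * (max - min);
            double[] roots = {r1, r2, r2}; // exercise the double-root case
            double[] res = new double[4];
            int n = CubicCurve2D.solveCubic(fromRoots(r1, r2, r2), res);
            if (countMissing(roots, res, Math.max(n, 0), tol) > 0) {
                failures++;
                System.out.println("missed a root of (x - " + r1 + ")(x - " + r2 + ")^2");
            }
        }
        System.out.println(failures + " failures");
    }
}
```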
> 
> I think one problem might be that you use "is close to zero" to check if
> you should use special processing. I think any tests which say "do it
> this way and get fewer roots" should be conservative: if we are on the
> borderline and can run the code that generates more solutions, then
> we should generate more and then maybe refine the roots and eliminate
> duplicates. That way we can be (more) sure not to leave any roots
> unsolved.
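Here is a minimal sketch of the "eliminate duplicates" step. It assumes that roots closer than a relative tolerance are copies of one root. The name mergeCloseRoots and the averaging of each cluster are illustrative choices. The tolerance has the same weakness as iszero: if it is too loose, it merges genuinely distinct roots. Refining the roots first, as suggested, would therefore come before this step.

```java
import java.util.Arrays;

public class RootDedup {
    // Sorts res[0..n) and collapses each run of roots lying within
    // relTol * max(1, |first root of the run|) of that first root,
    // replacing the run by its mean. Returns the new count; the merged
    // roots occupy res[0..count).
    static int mergeCloseRoots(double[] res, int n, double relTol) {
        Arrays.sort(res, 0, n);
        int out = 0;
        int i = 0;
        while (i < n) {
            double start = res[i];
            double sum = start;
            int j = i + 1;
            while (j < n && res[j] - start <= relTol * Math.max(1.0, Math.abs(start))) {
                sum += res[j];
                j++;
            }
            res[out++] = sum / (j - i); // out <= i, so no unread value is overwritten
            i = j;
        }
        return out;
    }

    public static void main(String[] args) {
        // Three nearly equal roots where there should be one.
        double[] res = {5.0, 5.0 + 1e-12, 5.0 - 1e-12};
        int n = mergeCloseRoots(res, 3, 1e-9);
        System.out.println(n + " root(s): " + Arrays.toString(Arrays.copyOf(res, n)));
    }
}
```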
> 
> ...jim
import java.awt.geom.QuadCurve2D;
import java.util.Arrays;
import java.util.Random;

import static java.lang.Math.abs;
import static java.lang.Math.max;
import static java.lang.Math.ulp;

public class CubicSolver {
    public static int solveCubicOld(double eqn[], double res[]) {
        if (res == eqn) {
            // Copy the eqn so that we don't clobber it with the
            // roots.
            eqn = new double[4];
            System.arraycopy(res, 0, eqn, 0, 4);
        }

        // From Numerical Recipes, 5.6, Quadratic and Cubic Equations
        double d = eqn[3];
        if (d == 0.0) {
            // The cubic has degenerated to quadratic (or line or ...).
            return QuadCurve2D.solveQuadratic(eqn, res);
        }
        double a = eqn[2] / d;
        double b = eqn[1] / d;
        double c = eqn[0] / d;
        int roots = 0;
        double Q = (a * a - 3.0 * b) / 9.0;
        double R = (2.0 * a * a * a - 9.0 * a * b + 27.0 * c) / 54.0;
        double R2 = R * R;
        double Q3 = Q * Q * Q;
        a = a / 3.0;
        if (R2 < Q3) {
            double theta = Math.acos(R / Math.sqrt(Q3));
            Q = -2.0 * Math.sqrt(Q);

            res[roots++] = Q * Math.cos(theta / 3.0) - a;
            res[roots++] = Q * Math.cos((theta + Math.PI * 2.0)/ 3.0) - a;
            res[roots++] = Q * Math.cos((theta - Math.PI * 2.0)/ 3.0) - a;
        } else {
            boolean neg = (R < 0.0);
            double S = Math.sqrt(R2 - Q3);
            if (neg) {
                R = -R;
            }
            double A = Math.pow(R + S, 1.0 / 3.0);
            if (!neg) {
                A = -A;
            }
            double B = (A == 0.0) ? 0.0 : (Q / A);
            res[roots++] = (A + B) - a;
        }
        
    	// deriv needs length 4: after the shifting below, the roots of the
    	// derivative are in deriv[1] (and deriv[2]), with -M and M on either side.
    	final double[] deriv = {eqn[1], 2*eqn[2], 3*eqn[3], 0};
        int critCount = QuadCurve2D.solveQuadratic(deriv, deriv);
        Arrays.sort(deriv, 0, critCount);
        Arrays.sort(res, 0, roots);
        // Even if there are fewer than 2 roots, this won't cause problems.
        deriv[2] = deriv[1];
        deriv[1] = deriv[0];
        // The roots of any polynomial must lie in [-M, M] where M = 1 + (max{i=0,n-1}abs(ai))/abs(an)
        // http://en.wikipedia.org/wiki/Sturm%27s_theorem#Applications
        // Wikipedia says this result is due to Cauchy. There's no proof in the link,
        // but I proved it myself (it's a bit long to include here).
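        // Use the original coefficients here: a, b and c were normalized by d
        // (and a was also divided by 3), so they would give the wrong bound.
        double M = 1 + max(max(abs(eqn[0]), abs(eqn[1])), abs(eqn[2])) / abs(d);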
        M += max(ulp(M), 1);
        deriv[0] = -M;
        deriv[critCount + 1] = M;
        
        if (critCount == 2 && deriv[1] == deriv[2]) {
        	// solveQuadratic shouldn't do this, but sometimes it does.
        	critCount = 1;
        	deriv[2] = deriv[3];
        }
        
        if (roots == 3) {
        	if (critCount == 2) {
        		// -eqn[3] has the same sign as lim{t->-inf}(eqn(t))
        		double c1 = solveEqn(eqn, 3, deriv[1]);
        		double c2 = solveEqn(eqn, 3, deriv[2]);
        		if (oppositeSigns(-eqn[3], c1) && res[0] < deriv[1] &&
        			oppositeSigns(c1, c2) && deriv[1] < res[1] && res[1] < deriv[2] &&
        			oppositeSigns(c2, eqn[3]) && res[2] > deriv[2]) {
        			// all is good.
        			return roots;
        		}
        	}
        	return iterativeCubicRoots(eqn, deriv, critCount + 1, res);
        } else {
        	if (critCount == 2) {
        		double c1 = solveEqn(eqn, 3, deriv[1]);
        		double c2 = solveEqn(eqn, 3, deriv[2]);
        		double root = res[0];
        		roots = 0;
        		// The following is almost an unrolled version of the loop
        		// in iterativeCubicRoots. We don't just call that function
        		// because this way we can possibly save a bisectRoot call 
        		// by using the root computed by the closed form algorithm
        		// if it lies in the correct interval.
        		boolean int1Root = false;
        		if (oppositeSigns(-eqn[3], c1)) {
        			if (root < deriv[1]) {
        				res[roots++] = root;
        			} else {
        				res[roots++] = bisectRoot(eqn, deriv[0], deriv[1]);
        			}
        			int1Root = true;
        		}

        		boolean int2Root = false;
        		if (oppositeSigns(c1, c2)) {
    				if (deriv[1] < root && root < deriv[2]) {
    					res[roots++] = root;
    				} else {
    					res[roots++] = bisectRoot(eqn, deriv[1], deriv[2]);
    				}
    				int2Root = true;
        		} else if (!int1Root && iszero(c1)) {
        			res[roots++] = deriv[1];
        			int2Root = true;
        		}

        		if (oppositeSigns(c2, eqn[3])) {
    				if (root > deriv[2]) {
    					res[roots++] = root;
    				} else {
    					res[roots++] = bisectRoot(eqn, deriv[2], deriv[3]);
    				}
        		} else if (!int2Root && iszero(c2)) {
        			res[roots++] = deriv[2];
        		}
        	}
        	// If there are fewer than 2 critical points, there's only 1
        	// root, so we assume we have the correct one. XXX: might want to
        	// check that it's in the correct interval anyway.
        }

        if (roots == 3) {
            fixRoots(res, eqn);
        }

        return roots;
    }

    private static int iterativeCubicRoots(double[] eqn, double[] intervals, int numInts, double[] res) {
		int numRoots = 0;

		double x0 = intervals[0];
		double fx0 = solveEqn(eqn, 3, x0);
		// if the polynomial is very flat we wouldn't want x0 to be mistaken
		// for a root.
		fx0 += Math.signum(fx0);
		// iterate the intervals.
		for (int i = 1; i < numInts + 1; i++) {
			final double xe = intervals[i];
			final double fxe = solveEqn(eqn, 3, xe);

			if (oppositeSigns(fx0, fxe)) {
				res[numRoots++] = bisectRoot(eqn, x0, xe);
			} else if (iszero(fx0)) {
				res[numRoots++] = x0;
			}

			x0 = xe;
			fx0 = fxe;
		}
		return numRoots;
    }

    private static double bisectRoot(double[] eqn, double x0, double xe) {
		double fx0 = solveEqn(eqn, 3, x0);
		double m = x0 + (xe - x0) / 2;
		double fm = solveEqn(eqn, 3, m);
		while (m != x0 && m != xe) {
			if (fm == 0) {
				return m;
			}
			if (oppositeSigns(fx0, fm)) {
				xe = m;
			} else {
				fx0 = fm;
				x0 = m;
			}
			m = x0 + (xe-x0)/2;
			fm = solveEqn(eqn, 3, m);
		}
		return m;
    }

    private static boolean iszero(double x) {
        return within(x, 0, 1e-9);
    }
    
    private static boolean within(final double x, final double y, final double err) {
        final double d = y - x;
        return (d <= err && d >= -err);
    }

    private static void fixRoots(double res[], double eqn[]) {
        final double EPSILON = 1E-5;
        for (int i = 0; i < 3; i++) {
            double t = res[i];
            if (Math.abs(t) < EPSILON) {
                res[i] = findZero(t, 0, eqn);
            } else if (Math.abs(t - 1) < EPSILON) {
                res[i] = findZero(t, 1, eqn);
            }
        }
    }

    public static int solveCubicNew(double eqn[], double res[]) {
        // From Graphics Gems:
        // http://tog.acm.org/resources/GraphicsGems/gems/Roots3And4.c
        final double d = eqn[3];
        if (d == 0) {
            return QuadCurve2D.solveQuadratic(eqn, res);
        }

        if (res == eqn) {
            // Copy the eqn so that we don't clobber it with the
            // roots.  This is needed so that fixRoots can do its
            // work with the original equation.
            eqn = new double[4];
            System.arraycopy(res, 0, eqn, 0, 4);
        }

        /* normal form: x^3 + Ax^2 + Bx + C = 0 */
        final double A = eqn[2] / d;
        final double B = eqn[1] / d;
        final double C = eqn[0] / d;


        // Substitute x = y - A/3 to eliminate the quadratic term. With p and q
        // scaled as below (as in Graphics Gems), this gives y^3 + 3py + 2q = 0.

        double sq_A = A * A;
        double p = 1.0/3 * (-1.0/3 * sq_A + B);
        double q = 1.0/2 * (2.0/27 * A * sq_A - 1.0/3 * A * B + C);

        /* use Cardano's formula */

        double cb_p = p * p * p;
        double D = q * q + cb_p;

        int num;
        // XXX: we consider 0 to be anything within 1e-9 of 0.
        // Is this really right? Maybe we should use a bound that changes
        // with the number being tested (Math.ulp would do that, but
        // what input do we give it? And do we scale it, or will
        // within(D, 0, Math.ulp(somenumber)) be good enough?)
        if (iszero(D)) {
            // XXX: do we really need iszero for q? All we do with it is
            // take its cube root, which works fine even for Double.MIN_VALUE.
            // Then again, if we remove it, we will get two extremely close
            // roots for equations where there is only one zero root.
            // We probably should use something like iszero, but much more
            // strict, say within(q, 0, 2*Math.ulp(0)).
            if (iszero(q)) { /* one triple solution */
                res[ 0 ] = 0;
                num = 1;
            } else { /* one single and one double solution */
                final double u = Math.cbrt(-q);
                res[ 0 ] = 2*u;
                res[ 1 ] = -u;
                num = 2;
            }
        } else if (D < 0) { /* Casus irreducibilis: three real solutions */
            final double phi = 1.0/3 * Math.acos(-q / Math.sqrt(-cb_p));
            final double t = 2 * Math.sqrt(-p);

            res[ 0 ] =  ( t * Math.cos(phi));
            res[ 1 ] =  (-t * Math.cos(phi + Math.PI / 3));
            res[ 2 ] =  (-t * Math.cos(phi - Math.PI / 3));
            num = 3;
        } else { /* one real solution */
            final double sqrt_D = Math.sqrt(D);
            final double u = Math.cbrt(sqrt_D - q);
            final double v = - Math.cbrt(sqrt_D + q);

            res[ 0 ] =  u + v;
            num = 1;
        }

        /* resubstitute */

        final double sub = 1.0/3 * A;

        for (int i = 0; i < num; ++i)
            res[ i ] -= sub;

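        if (num == 1) {
            // The closed form found only one root, but D may have been
            // misclassified because of cancellation. A critical point at which
            // the cubic is (close to) zero is a double root we would otherwise miss.
            final double[] deriv = {eqn[1], 2*eqn[2], 3*eqn[3]};
            int critCount = QuadCurve2D.solveQuadratic(deriv, deriv);
            Arrays.sort(deriv, 0, critCount);
            if (critCount == 2) {
                final double c1 = solveEqn(eqn, 3, deriv[0]);
                final boolean iszeroc1 = iszero(c1);
                final double c2 = solveEqn(eqn, 3, deriv[1]);
                final boolean iszeroc2 = iszero(c2);
                // If both are 0, the critical points are (numerically) at a
                // single triple root, so the 1 root we already have is enough.
                if (!(iszeroc1 && iszeroc2)) {
                    if (iszeroc1) {
                        res[num++] = deriv[0];
                    }
                    if (iszeroc2) {
                        res[num++] = deriv[1];
                    }
                }
            }
        }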

        return num;
    }

    private static boolean oppositeSigns(double x1, double x2) {
    	return (x1 < 0 && x2 > 0) || (x1 > 0 && x2 < 0); 
    }

    private static double solveEqn(double eqn[], int order, double t) {
        double v = eqn[order];
        while (--order >= 0) {
            v = v * t + eqn[order];
        }
        return v;
    }

    private static double findZero(double t, double target, double eqn[]) {
        double slopeqn[] = {eqn[1], 2*eqn[2], 3*eqn[3]};
        double slope;
        double origdelta = 0;
        double origt = t;
        while (true) {
            slope = solveEqn(slopeqn, 2, t);
            if (slope == 0) {
                // At a critical point (slope 0) - must return
                return t;
            }
            double y = solveEqn(eqn, 3, t);
            if (y == 0) {
                // Found it! - return it
                return t;
            }
            // assert(slope != 0 && y != 0);
            double delta = - (y / slope);
            // assert(delta != 0);
            if (origdelta == 0) {
                origdelta = delta;
            }
            if (t < target) {
                if (delta < 0) return t;
            } else if (t > target) {
                if (delta > 0) return t;
            } else { /* t == target */
                return (delta > 0
                        ? (target + java.lang.Double.MIN_VALUE)
                        : (target - java.lang.Double.MIN_VALUE));
            }
            double newt = t + delta;
            if (t == newt) {
                // The deltas are so small that we aren't moving...
                return t;
            }
            if (delta * origdelta < 0) {
                // We have reversed our path.
                int tag = (origt < t
                           ? getTag(target, origt, t)
                           : getTag(target, t, origt));
                if (tag != INSIDE) {
                    // Local minimum found away from target - return the middle
                    return (origt + t) / 2;
                }
                // Local minimum somewhere near target - move to target
                // and let the slope determine the resulting t.
                t = target;
            } else {
                t = newt;
            }
        }
    }

    private static final int BELOW = -2;
    private static final int LOWEDGE = -1;
    private static final int INSIDE = 0;
    private static final int HIGHEDGE = 1;
    private static final int ABOVE = 2;

    private static int getTag(double coord, double low, double high) {
        if (coord <= low) {
            return (coord < low ? BELOW : LOWEDGE);
        }
        if (coord >= high) {
            return (coord > high ? ABOVE : HIGHEDGE);
        }
        return INSIDE;
    }

    public static void trySolve3(double r1, double r2, double r3, int count) {
        // (x-r1)*(x-r2)*(x-r3)
        // = (xx - r1x - r2x + r1r2)*(x-r3)
        // = xxx - r1xx - r2xx + r1r2x - r3xx + r1r3x + r2r3x - r1r2r3
        // = xxx - (r1 + r2 + r3)xx + (r1r2 + r2r3 + r3r1)x - r1r2r3
//        System.out.println("solving: (x - "+r1+") * (x - "+r2+") * (x - "+r3+") = 0");
        double eqn[] = new double[4];
        eqn[3] = 1.0;
        eqn[2] = - (r1 + r2 + r3);
        eqn[1] = r1*r2 + r2*r3 + r3*r1;
        eqn[0] = - (r1 * r2 * r3);
        double res[] = new double[4];
        int n = solveCubicOld(eqn, res);
//        System.out.print("old method returned "+n+" roots: ");
//        for (int i = 0; i < n; i++) {
//            System.out.print(res[i]+", ");
//        }
//        System.out.println();
        n = solveCubicNew(eqn, res);
        Arrays.sort(res, 0, n);
        if (n != count) {
        	System.out.println(r1 + ", " + r2 + ", " + r3);
        	System.out.print("There should be " + count + " roots. new method returned "+n+" roots: ");
        	for (int i = 0; i < n; i++) {
        		System.out.print(res[i]+", ");
        	}
        	System.out.println();
        	System.out.println();
        }
    }

    public static void trySolve2(double r1, double r2) {
        trySolve3(r1, r2, r2, 2);
        trySolve3(r1, r1, r2, 2);
    }

    public static void trySolve1(double r) {
        trySolve3(r, r, r, 1);
    }

    public static void trySolve(double r1, double r2, double r3) {
        trySolve3(r1, r2, r3, 3);
        trySolve2(r1, r2);
        trySolve2(r2, r3);
        trySolve2(r3, r1);
        trySolve1(r1);
        trySolve1(r2);
        trySolve1(r3);
        System.out.println("--------------------------------------------------------------");
    }

    public static void main(String argv[]) {
//        trySolve(1, 2, 3);
//        trySolve(15, 103, 27);
//        trySolve3(24.105582313136892, 10.806395151914266, 10.806395151914266, 2);
    	trySolve3(57.732106900005846, 57.732106900005846, -7.5303056978166, 2);
//        double[] eqn = {-6+10, 11, -6, 1};
//        double[] res = new double[4];
//        solveCubicNew(eqn, res);
    	Random r = new Random();
    	double min = -10, max = 100;
    	int numtests = 100;
    	for (int i = 0; i < numtests; i++) {
    		double r1 = r.nextDouble(), r2 = r.nextDouble(), r3 = r.nextDouble();
    		r1 = r1 * (max - min) + min;
    		r2 = r2 * (max - min) + min;
    		r3 = r3 * (max - min) + min;
    		
    		trySolve(r1, r2, r3);
    	}
    }
}
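The comment in solveCubicOld cites Cauchy's bound and says the proof is a bit long to include. The standard argument is short, though. Here is a sketch, with a_i standing for eqn[i] and n = 3 in this file:

```latex
\begin{aligned}
&\text{Let } p(x) = \sum_{i=0}^{n} a_i x^i,\ a_n \neq 0,\quad
  H = \max_{0 \le i < n} \frac{|a_i|}{|a_n|},\quad M = 1 + H.\\
&\text{If } H = 0, \text{ then } p(x) = a_n x^n \text{ and its only root is } 0 \in (-M, M).\\
&\text{If } H > 0 \text{ and } |x| \ge M, \text{ then } |x| > 1 \text{ and}\\
&\Bigl|\sum_{i=0}^{n-1} a_i x^i\Bigr|
  \le |a_n| H \sum_{i=0}^{n-1} |x|^i
  = |a_n| H \,\frac{|x|^n - 1}{|x| - 1}
  < |a_n|\,|x|^n \,\frac{H}{|x| - 1}
  \le |a_n|\,|x|^n,\\
&\text{so } |p(x)| \ge |a_n|\,|x|^n - \Bigl|\sum_{i=0}^{n-1} a_i x^i\Bigr| > 0.\\
&\text{Hence every root satisfies } |x| < M.
\end{aligned}
```

The last inequality uses |x| - 1 >= H. Because the bound is strict, padding M, as the code does with M += max(ulp(M), 1), only serves as a guard against rounding in M itself.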
