Hi Jim.

I've attached another modification of your CubicSolver.java.
I tried many different things, but I think what's in the
attachment is the only satisfactory implementation. The logic
is somewhat similar to what you suggested, in that it computes
3 roots for D < 0. However, once the roots are computed, it doesn't
try very hard to eliminate extra ones if too many were computed.
I tried doing that, but it was very problematic. The only reliable
way to count roots is to split the domain at the roots of the
derivative into intervals where the polynomial is strictly increasing
or decreasing, evaluate the polynomial at the interval end points, and
count sign changes. This works for well-behaved polynomials, but not
in the edge cases where D is very small, which are the only cases
where we actually need it to work. When 3 roots are computed but only
2 exist, one of the critical points is also a root, so the function is
locally flat at that root. Near it, the value of solveEqn(eqn, 3, x)
fluctuates a lot, so its sign can't be trusted and the assumption that
the function is monotonic in each interval no longer holds. Also, it's
better to have too many roots than too few.
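For reference, here is a minimal sketch of the sign-change root count described above. It assumes a nonzero leading coefficient and the same coefficient layout as solveCubic (eqn[i] multiplies x^i); RootCountSketch and countRealRoots are hypothetical names, not part of CubicSolver.java. It is only reliable when the value at each critical point is clearly positive, clearly negative, or exactly zero. In the near-double-root case that value is rounding noise, which is why the attached version does not depend on this count.

```java
// Sketch: count the distinct real roots of a cubic by splitting the real line
// at the critical points (roots of the derivative) and counting strict sign
// changes between consecutive breakpoints. Hypothetical names; assumes eqn[3] != 0.
public class RootCountSketch {

    // Horner evaluation of eqn[3]*x^3 + eqn[2]*x^2 + eqn[1]*x + eqn[0].
    static double eval(double[] eqn, double x) {
        return ((eqn[3] * x + eqn[2]) * x + eqn[1]) * x + eqn[0];
    }

    static int countRealRoots(double[] eqn) {
        double[] signs = new double[4];
        int n = 0;
        // As x -> -infinity the cubic has the opposite sign of eqn[3].
        signs[n++] = -Math.signum(eqn[3]);
        // Derivative: 3*eqn[3]*x^2 + 2*eqn[2]*x + eqn[1].
        double qa = 3 * eqn[3], qb = 2 * eqn[2], qc = eqn[1];
        double disc = qb * qb - 4 * qa * qc;
        int roots = 0;
        if (disc > 0) {
            double s = Math.sqrt(disc);
            double x1 = (-qb - s) / (2 * qa);
            double x2 = (-qb + s) / (2 * qa);
            for (double crit : new double[] {Math.min(x1, x2), Math.max(x1, x2)}) {
                double f = eval(eqn, crit);
                if (f == 0) {
                    roots++; // the critical point is itself a (multiple) root
                }
                signs[n++] = Math.signum(f);
            }
        }
        // As x -> +infinity the cubic has the sign of eqn[3].
        signs[n++] = Math.signum(eqn[3]);
        // The cubic is strictly monotonic between consecutive breakpoints, so
        // each strict sign change marks exactly one root inside that interval.
        for (int i = 0; i + 1 < n; i++) {
            if (signs[i] * signs[i + 1] < 0) {
                roots++;
            }
        }
        return roots;
    }

    public static void main(String[] args) {
        System.out.println(countRealRoots(new double[] {-6, 11, -6, 1})); // (x-1)(x-2)(x-3)
        System.out.println(countRealRoots(new double[] {1, 1, 0, 1}));    // x^3 + x + 1
        System.out.println(countRealRoots(new double[] {-2, 5, -4, 1}));  // (x-1)^2 (x-2)
    }
}
```

The double-root example (x-1)^2(x-2) only comes out as 2 because its critical point x = 1, and the cubic's value there, happen to be computed exactly.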

I modified trySolve3 to count the calls to solveCubicNew that find too
many or too few roots. When I run trySolve 1000 times, it never finds
fewer roots than there actually are (or if it does, it's extremely rare;
I don't remember seeing any instances of it). It finds too many roots
in ~3000 cases, compared to ~2500 for the version that doesn't call
fixRoots() (note that this isn't the same fixRoots as the one used by
the old function). I think this is very good.

As for performance, it's not as good as the version that doesn't
call fixRoots, but accuracy has improved a lot. I tried to calibrate
Newton's method in the root-refining function to get good accuracy
with as few iterations as possible. Three iterations is a very good
compromise (although we might be able to get away with 2).
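A fixed-iteration Newton refinement along these lines might look like the following sketch. It mirrors the loop in refineRootWithHint but leaves out that method's interval and distance checks; NewtonRefineSketch, refine and the maxIters parameter are hypothetical names chosen for illustration.

```java
// Sketch: refine an approximate root of eqn[3]*x^3 + ... + eqn[0] with at most
// maxIters Newton steps. Hypothetical names; no interval or distance checks.
public class NewtonRefineSketch {

    static double eval(double[] eqn, double x) {
        return ((eqn[3] * x + eqn[2]) * x + eqn[1]) * x + eqn[0];
    }

    // Derivative 3*eqn[3]*x^2 + 2*eqn[2]*x + eqn[1], in Horner form.
    static double evalDeriv(double[] eqn, double x) {
        return (3 * eqn[3] * x + 2 * eqn[2]) * x + eqn[1];
    }

    static double refine(double[] eqn, double t, int maxIters) {
        for (int i = 0; i < maxIters; i++) {
            double y = eval(eqn, t);
            double slope = evalDeriv(eqn, t);
            if (y == 0 || slope == 0) {
                break; // exact root, or a flat spot where the Newton step is undefined
            }
            double next = t - y / slope;
            if (next == t) {
                break; // the step no longer changes t
            }
            t = next;
        }
        return t;
    }

    public static void main(String[] args) {
        // (x-1)(x-2)(x-3), starting from a perturbed estimate of the root at 2.
        double[] eqn = {-6, 11, -6, 1};
        System.out.println(refine(eqn, 2.001, 3));
    }
}
```

Near a double root the slope goes to zero and Newton's method converges only linearly, so a small fixed iteration count mostly pays off for simple roots.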

Regards,
Denis.

----- Original Message -----
> Hi Denis,
> 
> What about logic like this:
> 
> boolean checkRoots = false;
> if (D < 0) {
>     // 3 solution form is possible, so use it
>     checkRoots = (D > -TINY); // Check them if we were borderline
>     // compute 3 roots as before
> } else {
>     double u = ...;
>     double v = ...;
>     res[0] = u+v; // should be 2*u if D is near zero
>     if (u close to v) { // Will be true for D near zero
>         res[1] = -res[0]/2; // should be -u if D is near zero
>         checkRoots = true; // Check them if we were borderline
>         // Note that q=0 case ends up here as well...
>     }
> }
> if (checkRoots) {
>     if (num > 2 && (res[2] == res[1] || res[2] == res[0])) {
>         num--;
>     }
>     if (num > 1 && res[1] == res[0]) {
>         res[1] = res[--num]; // Copies res[2] to res[1] if needed
>     }
>     for (int i = num-1; i >= 0; i--) {
>         res[i] = refine(res[i]);
>         for (int j = i+1; j < num; j++) {
>             if (res[i] == res[j]) {
>                 res[i] = res[--num];
>                 break;
>             }
>         }
>     }
> }
> 
> Note that we lose the optimization of calculating just 2*u and -u for
> the 2-root case, but that only happened in rare circumstances. Also, if
> D is near zero and negative, then we generate 3 roots using
> transcendentals and potentially refine one away, but that should also be
> an uncommon situation and "there but for the grace of being a tiny
> negative number would we have gone anyway", so I think it is OK to take
> the long way to the answer.
> 
> Also, one could argue that if we used the transcendentals to calculate
> the 3 roots, it couldn't hurt to refine the answers anyway. The other
> solutions should have higher precision, but the transcendental results
> will be much less accurate.
> 
> Finally, this lacks the "refine them anyway if any of them are near 0 or
> 1" rule; the original only did that if the transcendentals were used,
> but it would be nice to do it in all of the cases. It might make
> sense to have a variant that takes a boolean indicating whether to
> ensure higher accuracy around 0 and 1, but that would require an API
> change request...
> 
> ...jim
> 
> On 1/4/11 2:02 PM, Denis Lila wrote:
> > Hi Jim.
> >
> >> The test as it is has a test case (I just chose random numbers to
> >> check and got lucky - d'oh!) that generates 1 solution from the new
> >> code even though the equation had 2 distinct solutions that weren't
> >> even near each other...
> >
> > I figured out why this happens. It's because of cancellation in the
> > computation of D (two large numbers are subtracted and the result is
> > supposed to be 0 or close to 0, but it's about 1e-7, which wasn't
> > small enough to pass the iszero test). I've been working on this and
> > I came up with a couple of different fixes. They are in the attached
> > file (a modified version of your CubicSolve.java).
> >
> > The first thing I did was to modify solveCubicOld. I tried to get
> > a bit fancy and although I think I fixed the problems it had, the
> > end result is ugly, complicated and it has small problems, like
> > returning 3 very close roots when there should only be one.
> >
> > The other solution is to just check whether the roots of the
> > derivative are also roots of the cubic polynomial when only 1 root was
> > computed by the closed-form algorithm. This doesn't have the numerical
> > accuracy of the first approach (which used bisectRoots when things
> > went wrong), but it's much faster, doesn't have the multiple-roots
> > problem, and it's much simpler. I called your trySolve function on a
> > few hundred polynomials with random roots in [-10, 10] and it never
> > found fewer roots than there actually are. Sometimes it finds 3 roots
> > when there are only 2, but I don't think this is a huge problem.
> >
> > I've attached what I have so far.
> >
> > Regards,
> > Denis.
> >
> > ----- Original Message -----
> >> Hi Denis,
> >>
> >> I'm attaching a test program I wrote that compares the old and new
> >> algorithms.
> >>
> >> Obviously the old one missed a bunch of solutions because it
> >> classified all solutions as 1 or 3, but the new one also sometimes
> >> misses a solution. You might want to turn this into an automated test
> >> for the bug (and maybe use it as a stress test with a random number
> >> generator).
> >>
> >> I think one problem might be that you use "is close to zero" to
> >> check if you should use special processing. I think any tests which
> >> say "do it this way and get fewer roots" should be conservative: if
> >> we are on the borderline and can run the code that generates more
> >> solutions, then we should generate more, then maybe refine the roots
> >> and eliminate duplicates. That way we can be (more) sure not to leave
> >> any roots unsolved.
> >>
> >
> >>
> >> ...jim
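
Two points raised in this thread, the cancellation in D and the idea of being conservative near the borderline, suggest judging "D is near zero" against the size of the terms that cancel rather than against a fixed threshold. The sketch below is one possible way to do that, assuming the same p, q and D = q^2 + p^3 as solveCubicNew. DiscriminantSketch, cardano, nearZeroRelative, withRoots and the 1e-9 relative tolerance are hypothetical choices, not values from the thread.

```java
// Sketch: judge whether the Cardano discriminant D = q^2 + p^3 is "near zero"
// relative to the magnitude of the two terms that cancel in it.
// Hypothetical names and tolerance, for illustration only.
public class DiscriminantSketch {

    /** Returns {p, q, D} for eqn[3]*x^3 + eqn[2]*x^2 + eqn[1]*x + eqn[0]. */
    static double[] cardano(double[] eqn) {
        double a = eqn[2] / eqn[3];
        double b = eqn[1] / eqn[3];
        double c = eqn[0] / eqn[3];
        // The same p and q as in solveCubicNew, up to rounding.
        double p = (b - a * a / 3) / 3;
        double q = (2.0 / 27 * a * a * a - a * b / 3 + c) / 2;
        return new double[] {p, q, q * q + p * p * p};
    }

    /** True if |D| is at most relTol times the larger of q^2 and |p^3|. */
    static boolean nearZeroRelative(double p, double q, double d, double relTol) {
        double scale = Math.max(q * q, Math.abs(p * p * p));
        return Math.abs(d) <= relTol * scale;
    }

    /** Coefficients of (x - r1)(x - r2)(x - r3), like makeEqnWithRoots. */
    static double[] withRoots(double r1, double r2, double r3) {
        return new double[] {-(r1 * r2 * r3), r1 * r2 + r2 * r3 + r3 * r1, -(r1 + r2 + r3), 1};
    }

    public static void main(String[] args) {
        // The double-root case hard-coded in main() of CubicSolver.java.
        double[] pqd = cardano(withRoots(-798.2134325057966, -798.2134325057966, -218.8642676477499));
        System.out.println("D = " + pqd[2] + ", near zero: "
                + nearZeroRelative(pqd[0], pqd[1], pqd[2], 1e-9));
        // Three well-separated roots: D is clearly negative.
        pqd = cardano(withRoots(1, 2, 3));
        System.out.println("D = " + pqd[2] + ", near zero: "
                + nearZeroRelative(pqd[0], pqd[1], pqd[2], 1e-9));
    }
}
```

The relative tolerance here is a guess; it would need calibrating against the same random tests before it could replace the existing checks.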
import java.awt.geom.QuadCurve2D;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import static java.lang.System.out;

import static java.lang.Math.abs;
import static java.lang.Math.max;
import static java.lang.Math.ulp;

public class CubicSolver {
    public static int solveCubicOld(double eqn[], double res[]) {
        if (res == eqn) {
            // Copy the eqn so that we don't clobber it with the
            // roots.
            eqn = new double[4];
            System.arraycopy(res, 0, eqn, 0, 4);
        }

        // From Numerical Recipes, 5.6, Quadratic and Cubic Equations
        double d = eqn[3];
        if (d == 0.0) {
            // The cubic has degenerated to quadratic (or line or ...).
            return QuadCurve2D.solveQuadratic(eqn, res);
        }
        double a = eqn[2] / d;
        double b = eqn[1] / d;
        double c = eqn[0] / d;
        int roots = 0;
        double Q = (a * a - 3.0 * b) / 9.0;
        double R = (2.0 * a * a * a - 9.0 * a * b + 27.0 * c) / 54.0;
        double R2 = R * R;
        double Q3 = Q * Q * Q;
        a = a / 3.0;
        if (R2 < Q3) {
            double theta = Math.acos(R / Math.sqrt(Q3));
            Q = -2.0 * Math.sqrt(Q);

            res[roots++] = Q * Math.cos(theta / 3.0) - a;
            res[roots++] = Q * Math.cos((theta + Math.PI * 2.0)/ 3.0) - a;
            res[roots++] = Q * Math.cos((theta - Math.PI * 2.0)/ 3.0) - a;
        } else {
            boolean neg = (R < 0.0);
            double S = Math.sqrt(R2 - Q3);
            if (neg) {
                R = -R;
            }
            double A = Math.pow(R + S, 1.0 / 3.0);
            if (!neg) {
                A = -A;
            }
            double B = (A == 0.0) ? 0.0 : (Q / A);
            res[roots++] = (A + B) - a;
        }

        if (roots == 3) {
            fixRoots(res, eqn);
        }

        return roots;
    }

    private static boolean iszero(double x, double err) {
        return within(x, 0, err);
    }
    
    private static boolean within(final double x, final double y, final double err) {
        final double d = y - x;
        return (d <= err && d >= -err);
    }

    private static void fixRoots(double res[], double eqn[]) {
        final double EPSILON = 1E-5;
        for (int i = 0; i < 3; i++) {
            double t = res[i];
            if (Math.abs(t) < EPSILON) {
                res[i] = findZero(t, 0, eqn);
            } else if (Math.abs(t - 1) < EPSILON) {
                res[i] = findZero(t, 1, eqn);
            }
        }
    }
    public static int solveCubicNew(double eqn[], double res[]) {
        // From Graphics Gems:
        // http://tog.acm.org/resources/GraphicsGems/gems/Roots3And4.c
        final double d = eqn[3];
        if (d == 0) {
            return QuadCurve2D.solveQuadratic(eqn, res);
        }

        /* normal form: x^3 + Ax^2 + Bx + C = 0 */
        final double A = eqn[2] / d;
        final double B = eqn[1] / d;
        final double C = eqn[0] / d;


        //  substitute x = y - A/3 to eliminate quadratic term:
        //     x^3 +px + q = 0

        double sq_A = A * A;
        double p = 1.0/3 * (-1.0/3 * sq_A + B);
        double q = 1.0/2 * (2.0/27 * A * sq_A - 1.0/3 * A * B + C);

        /* use Cardano's formula */

        double cb_p = p * p * p;
        double D = q * q + cb_p;

        final double sub = 1.0/3 * A;

        int num;
        if (D < 0) { /* Casus irreducibilis: three real solutions */
            final double phi = 1.0/3 * Math.acos(-q / Math.sqrt(-cb_p));
            final double t = 2 * Math.sqrt(-p);

            if (res == eqn) {
                eqn = Arrays.copyOf(res, 4);
            }

            res[ 0 ] =  ( t * Math.cos(phi));
            res[ 1 ] =  (-t * Math.cos(phi + Math.PI / 3));
            res[ 2 ] =  (-t * Math.cos(phi - Math.PI / 3));
            num = 3;

            for (int i = 0; i < num; ++i)
                res[ i ] -= sub;

        } else {
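            final double sqrt_D = Math.sqrt(D);
            final double u = Math.cbrt(sqrt_D - q);
            final double v = - Math.cbrt(sqrt_D + q);
            final double uv = u+v;

            // Treat D as zero (a repeated root) if it is within a tolerance
            // scaled to the magnitudes of uv and sub.
            final double err = 1200000000*ulp(abs(uv) + abs(sub));
            final boolean repeatedRoot = iszero(D, err) || within(u, v, err);
            if (repeatedRoot && res == eqn) {
                // Copy eqn before writing any roots into res, so that
                // fixRoots() below still sees the original coefficients.
                eqn = Arrays.copyOf(res, 4);
            }

            res[ 0 ] =  uv - sub;
            num = 1;
            if (repeatedRoot) {
                res[1] = -(uv / 2) - sub;
                num = 2;
            }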
        }

        if (num == 3 || num == 2) {
            num = fixRoots(eqn, res, num);
        }
        if (num > 2 && (res[2] == res[1] || res[2] == res[0])) {
            num--;
        }
        if (num > 1 && res[1] == res[0]) {
            res[1] = res[--num]; // Copies res[2] to res[1] if needed
        }
        return num;
    }

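    // Uses the critical points of the cubic (the roots of its derivative) to
    // clean up the roots computed by solveCubicNew. With 3 computed roots and
    // 2 critical points, each root is refined with Newton's method inside its
    // monotonic interval; otherwise a single root is located by bisection (or
    // is the critical point itself). With 2 computed roots and 2 critical
    // points, the critical point farther from res[0] is checked: if the cubic
    // is nearly zero there, res[1] becomes whichever of that point and the old
    // res[1] has the smaller residual; otherwise only res[0] is kept.
    // Returns the number of roots left in res.
    private static int fixRoots(double[] eqn, double[] res, final int num) {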
    	double[] intervals = {eqn[1], 2*eqn[2], 3*eqn[3]};
    	int critCount = QuadCurve2D.solveQuadratic(intervals, intervals);
    	if (critCount == 2 && intervals[0] == intervals[1]) {
    		critCount--;
        }
    	if (critCount == 2 && intervals[0] > intervals[1]) {
    		double tmp = intervals[0];
    		intervals[0] = intervals[1];
    		intervals[1] = tmp;
    	}

    	double xe = getRootUpperBound(eqn);
    	xe += max(ulp(xe), 1);
        double x0 = -xe;

    	double fx0 = eqn[3] > 0 ? -1 : 1;
    	double fxe = -fx0;

    	if (num == 3) {
    		Arrays.sort(res, 0, num);
    		if (critCount == 2) {
    			res[0] = refineRootWithHint(eqn, x0, intervals[0], res[0]);
    			res[1] = refineRootWithHint(eqn, intervals[0], intervals[1], res[1]);
    			res[2] = refineRootWithHint(eqn, intervals[1], xe, res[2]);
    			return 3;
    		} else if (critCount == 1) {
    			double x1 = intervals[0];
    			double fx1 = solveEqn(eqn, 3, x1);

    			if (oppositeSigns(fx0, fx1)) {
    				res[0] = bisectRootWithHint(eqn, x0, x1, res[0]);
    			} else if (oppositeSigns(fx1, fxe)) {
    				res[0] = bisectRootWithHint(eqn, x1, xe, res[2]);
    			} else /* fx1 must be 0 */ {
    				res[0] = x1;
    			}
    		} else if (critCount == 0) {
    			res[0] = bisectRootWithHint(eqn, x0, xe, res[1]);
    		}
    	} else if (num == 2 && critCount == 2) {
    		double badRoot = res[1];
    		double goodRoot = res[0];
    		double x1 = intervals[0];
    		double x2 = intervals[1];
    		double x = abs(x1 - goodRoot) > abs(x2 - goodRoot) ? x1 : x2;
    		double fx = solveEqn(eqn, 3, x);

    		if (iszero(fx, 10000000*ulp(x))) { 
    			double badRootVal = solveEqn(eqn, 3, badRoot);
    			res[1] = abs(badRootVal) < abs(fx) ? badRoot : x;
    			return 2;
    		}
    	} // else there can only be one root - goodRoot, and it is already in res[0]

        return 1;
    }

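    // Refines t with at most 3 Newton iterations. The result is kept only if it
    // stays within 1000 ulps of the original t and inside [min, max]; otherwise
    // the original t is returned.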
    private static double refineRootWithHint(double[] eqn, double min, double max, double t) {
    	if (!inInterval(t, min, max)) {
    		return t;
    	}
        double[] deriv = {eqn[1], 2*eqn[2], 3*eqn[3]};
        double slope;
        double origt = t;
        for (int i = 0; i < 3; i++) {
            slope = solveEqn(deriv, 2, t);
            double y = solveEqn(eqn, 3, t);
            double delta = - (y / slope);
            double newt = t + delta;

            if (slope == 0 || y == 0 || t == newt) {
                break;
            }

            t = newt;
        }
        if (within(t, origt, 1000*ulp(origt)) && inInterval(t, min, max)) {
            return t;
        }
        return origt;
    }
    
    private static boolean inInterval(double t, double min, double max) {
    	return min <= t && t <= max;
    }

    public static int solveCubic(double eqn[], double res[]) {
        // From Graphics Gems:
        // http://tog.acm.org/resources/GraphicsGems/gems/Roots3And4.c
        final double d = eqn[3];
        if (d == 0) {
            return QuadCurve2D.solveQuadratic(eqn, res);
        }

        if (res == eqn) {
            // Copy the eqn so that we don't clobber it with the roots.
            eqn = new double[4];
            System.arraycopy(res, 0, eqn, 0, 4);
        }

        /* normal form: x^3 + Ax^2 + Bx + C = 0 */
        final double A = eqn[2] / d;
        final double B = eqn[1] / d;
        final double C = eqn[0] / d;


        //  substitute x = y - A/3 to eliminate quadratic term:
        //     x^3 +px + q = 0

        double sq_A = A * A;
        double p = 1.0/3 * (-1.0/3 * sq_A + B);
        double q = 1.0/2 * (2.0/27 * A * sq_A - 1.0/3 * A * B + C);

        /* use Cardano's formula */

        double cb_p = p * p * p;
        double D = q * q + cb_p;

        int num;
        if (D < 0) { /* Casus irreducibilis: three real solutions */
            final double phi = 1.0/3 * Math.acos(-q / Math.sqrt(-cb_p));
            final double t = 2 * Math.sqrt(-p);

            res[ 0 ] =  ( t * Math.cos(phi));
            res[ 1 ] =  (-t * Math.cos(phi + Math.PI / 3));
            res[ 2 ] =  (-t * Math.cos(phi - Math.PI / 3));
            num = 3;
        } else { /* one real solution */
            final double sqrt_D = Math.sqrt(D);
            final double u = Math.cbrt(sqrt_D - q);
            final double v = -Math.cbrt(sqrt_D + q);

            res[ 0 ] =  u + v;
            num = 1;
            if (within(u, v, 1e-5)) {
            	res[1] = -res[0] / 2;
            	num = 2;
            }
        }

        /* resubstitute */

        final double sub = 1.0/3 * A;

        for (int i = 0; i < num; ++i)
            res[ i ] -= sub;
        return num;
    }
    
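    // Looks for a root in [x0, xe] near hint. It shrinks a bracket around hint
    // by a factor of 64 per step for as long as the bracket still contains a
    // sign change, returning early if an endpoint is an exact root. Then it
    // bisects the last bracket that contained a sign change (or [x0, xe] if
    // the first one did not).
    private static double bisectRootWithHint(double[] eqn, double x0, double xe, double hint) {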
    	double delta1 = Math.min(abs(hint - x0) / 64, 0.0625);
    	double delta2 = Math.min(abs(hint - xe) / 64, 0.0625);
    	double x02 = hint - delta1;
    	double xe2 = hint + delta2;
    	double fx02 = solveEqn(eqn, 3, x02);
    	double fxe2 = solveEqn(eqn, 3, xe2);
    	while(oppositeSigns(fx02, fxe2)) {
    		if (x02 >= xe2) {
    			return x02;
    		}
    		x0 = x02;
    		xe = xe2;
    		delta1 /= 64;
    		delta2 /= 64;
    		x02 = hint - delta1;
    		xe2 = hint + delta2;
        	fx02 = solveEqn(eqn, 3, x02);
        	fxe2 = solveEqn(eqn, 3, xe2);
    	}
      	if (fx02 == 0) {
    	    return x02;
    	}
    	if (fxe2 == 0) {
    	    return xe2;
    	}
    	
    	return bisectRoot(eqn, x0, xe);
    }

    private static double bisectRoot(double[] eqn, double x0, double xe) {
		double fx0 = solveEqn(eqn, 3, x0);
		double m = x0 + (xe - x0) / 2;
		double fm = solveEqn(eqn, 3, m);
		while (m != x0 && m != xe) {
			if (fm == 0) {
				return m;
			}
			if (oppositeSigns(fx0, fm)) {
				xe = m;
			} else {
				fx0 = fm;
				x0 = m;
			}
			m = x0 + (xe-x0)/2;
			fm = solveEqn(eqn, 3, m);
		}
		return m;
    }

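    // Cauchy's bound: every root x of the cubic satisfies
    // |x| <= 1 + max(|a|, |b|, |c|) / |d|. The result is padded a little
    // further so that it lies strictly beyond every root.
    private static double getRootUpperBound(double[] eqn) {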
        final double d = eqn[3];
        final double a = eqn[2];
        final double b = eqn[1];
        final double c = eqn[0];
        
        double M = 1 + max(max(abs(a), abs(b)), abs(c)) / abs(d);
        M += max(ulp(M), 1);
        return M;
    }

    private static boolean oppositeSigns(double x1, double x2) {
    	return (x1 < 0 && x2 > 0) || (x1 > 0 && x2 < 0); 
    }

    private static double solveEqn(double eqn[], int order, double t) {
        double v = eqn[order];
        while (--order >= 0) {
            v = v * t + eqn[order];
        }
        return v;
    }
    
    private static double findZero(double t, double target, double eqn[]) {
        double slopeqn[] = {eqn[1], 2*eqn[2], 3*eqn[3]};
        double slope;
        double origdelta = 0;
        double origt = t;
        while (true) {
            slope = solveEqn(slopeqn, 2, t);
            if (slope == 0) {
                // At a local minima - must return
                return t;
            }
            double y = solveEqn(eqn, 3, t);
            if (y == 0) {
                // Found it! - return it
                return t;
            }
            // assert(slope != 0 && y != 0);
            double delta = - (y / slope);
            // assert(delta != 0);
            if (origdelta == 0) {
                origdelta = delta;
            }
            if (t < target) {
                if (delta < 0) return t;
            } else if (t > target) {
                if (delta > 0) return t;
            } else { /* t == target */
                return (delta > 0
                        ? (target + java.lang.Double.MIN_VALUE)
                        : (target - java.lang.Double.MIN_VALUE));
            }
            double newt = t + delta;
            if (t == newt) {
                // The deltas are so small that we aren't moving...
                return t;
            }
            if (delta * origdelta < 0) {
                // We have reversed our path.
                int tag = (origt < t
                           ? getTag(target, origt, t)
                           : getTag(target, t, origt));
                if (tag != INSIDE) {
                    // Local minima found away from target - return the middle
                    return (origt + t) / 2;
                }
                // Local minima somewhere near target - move to target
                // and let the slope determine the resulting t.
                t = target;
            } else {
                t = newt;
            }
        }
    }

    private static final int BELOW = -2;
    private static final int LOWEDGE = -1;
    private static final int INSIDE = 0;
    private static final int HIGHEDGE = 1;
    private static final int ABOVE = 2;

    private static int getTag(double coord, double low, double high) {
        if (coord <= low) {
            return (coord < low ? BELOW : LOWEDGE);
        }
        if (coord >= high) {
            return (coord > high ? ABOVE : HIGHEDGE);
        }
        return INSIDE;
    }

    private static double[] makeEqnWithRoots(double r1, double r2, double r3) {
        double eqn[] = new double[4];
        eqn[3] = 1.0;
        eqn[2] = - (r1 + r2 + r3);
        eqn[1] = r1*r2 + r2*r3 + r3*r1;
        eqn[0] = - (r1 * r2 * r3);
    	return eqn;
    }

    private static int oldTooMany = 0;
    private static int oldTooFew = 0;
    private static int newTooFew = 0;
    private static int newTooMany = 0;
    private static List<Double> oldDiffs = new ArrayList<Double>();
    private static List<Double> newDiffs = new ArrayList<Double>();
    private static int oldWins = 0;
    private static int newWins = 0;

    private static double rd(Random r, double min, double max) {
    	return r.nextDouble() * (max - min) + min;
    }
    private static void perfTest(int n) {
    	double[][] eqns = new double[n][];
    	double[][] oldres = new double[n][4];
    	int[] oldNumRoots = new int[n];
        double[][] newres = new double[n][4];
        int[] newNumRoots = new int[n];
        double min = -1000;
        double max = 1000;
    	Random r = new Random();
    	for(int i = 0; i < n; i++) {
    	    eqns[i] = makeEqnWithRoots(rd(r, min, max), rd(r, min, max), rd(r, min, max));
    	}

    	long t1 = System.nanoTime();
    	for(int i = 0; i < n; i++) {
    		newNumRoots[i] = solveCubicNew(eqns[i], newres[i]);
    	}
    	long t2 = System.nanoTime();
    	for(int i = 0; i < n; i++) {
    		oldNumRoots[i] = solveCubic(eqns[i], oldres[i]);
    	}
    	long t3 = System.nanoTime();
    	System.out.println("solveCubicNew took (ns): " + (t2 - t1));
    	System.out.println("solveCubic took (ns):    " + (t3 - t2));
        System.out.println(Arrays.toString(newres[abs(r.nextInt())%n]));
        System.out.println(Arrays.toString(oldres[abs(r.nextInt())%n]));

        compStats(eqns, oldres, oldNumRoots, newres, newNumRoots);
    }

    private static void compStats(double[][] eqns, double[][] oldres, int[] oldNumRoots, double[][] newres, int[] newNumRoots) {
        int disagreements = 0;
        int num = oldNumRoots.length;
        List<Double> olddiffs = new LinkedList<Double>();
        List<Double> newdiffs = new LinkedList<Double>();
        for (int i = 0; i < num; i++) {
        	int oldNum = oldNumRoots[i];
        	int newNum = newNumRoots[i];
        	double[] eqn = eqns[i];
        	if (oldNum == newNum) {
        		for (int j = 0; j < oldNum; j++) {
        			double oldroot = oldres[i][j];
        			double newroot = newres[i][j];
        			olddiffs.add(abs(solveEqn(eqn, 3, oldroot)));
        			newdiffs.add(abs(solveEqn(eqn, 3, newroot)));
        		}
        	} else {
        		disagreements++;
        	}
        }

        int same = 0, newWin = 0, oldWin = 0;
        double maxoldfiff = 0, maxnewdiff = 0;
        double meanolddiff = 0, meannewdiff = 0;
        for (int i = 0; i < olddiffs.size(); i++) {
        	double o = olddiffs.get(i);
        	double n = newdiffs.get(i);

        	meanolddiff += o;
        	meannewdiff += n;

        	maxoldfiff = max(maxoldfiff, o);
        	maxnewdiff = max(maxnewdiff, n);

        	if (o < n) {
        		oldWin++;
        	} else if (n < o) {
        		newWin++;
        	} else {
        		same++;
        	}
        }

        out.println("Out of " + num + " equations:");
        out.printf("%d disagreements on the number of roots%n", disagreements);
        out.printf("sameDiffs = %d, newWins = %d, oldwins = %d%n", same, newWin, oldWin);
        out.printf("maxolddiff   = %.22f, maxnewdiff   = %.22f%n", maxoldfiff, maxnewdiff);
        out.printf("totalolddiff = %.22f, totalnewdiff = %.22f%n", meanolddiff, meannewdiff);
        meanolddiff /= olddiffs.size();
        meannewdiff /= newdiffs.size();
        out.printf("meanolddiff  = %.22f, meannewdiff  = %.22f%n", meanolddiff, meannewdiff);
    }

    
    public static void trySolve3(double r1, double r2, double r3, int count) {
        // (x-r1)*(x-r2)*(x-r3)
        // = (xx - r1x - r2x + r1r2)*(x-r3)
        // = xxx - r1xx - r2xx + r1r2x - r3xx + r1r3x + r2r3x - r1r2r3
        // = xxx - (r1 + r2 + r3)xx + (r1r2 + r2r3 + r3r1)x - r1r2r3
//        System.out.println("solving: (x - "+r1+") * (x - "+r2+") * (x - "+r3+") = 0");
        double eqn[] = makeEqnWithRoots(r1, r2, r3);
        double resOld[] = new double[4];
        int n = solveCubic(eqn, resOld);
        Arrays.sort(resOld, 0, n);
        boolean otf = false;
        if (n < count) {
        	otf = true;
        	oldTooFew++;
        } else if (n > count) {
        	oldTooMany++;
        }

        double resNew[] = new double[4];
        int m = solveCubicNew(eqn, resNew);
        Arrays.sort(resNew, 0, m);
        boolean ntf = false;
        if (m < count) {
        	ntf = true;
        	newTooFew++;
        } else if (m > count) {
        	newTooMany++;
        }

        if (n == count && m == count) {
        	double[] roots = {r1, r2, r3};
        	Arrays.sort(roots, 0, 3);
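        	// For a double root, drop one copy so that the first count entries
        	// are the distinct roots in ascending order.
        	if (roots[0] == roots[1]) {
        		roots[1] = roots[2];
        	}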
        	double dold = 0;
        	double dnew = 0;
            for (int i = 0; i < count; i++) {
            	double oldRoot = resOld[i];
            	double newRoot = resNew[i];
            	double actRoot = roots[i];
            	dold += abs(oldRoot - actRoot);
            	dnew += abs(newRoot - actRoot);
            }

            oldDiffs.add(dold);
            newDiffs.add(dnew);
            if (dold < dnew) {
            	System.out.printf(r1 + ", " + r2 + ", " + r3 + "\n");
            	oldWins++;
            } else if (dnew < dold) {
            	newWins++;
            }
        }

//        if (m < count) {
//        	System.out.printf(r1 + ", " + r2 + ", " + r3 + "\n");
//        }
        if (m != count) {
        	System.out.printf(r1 + ", " + r2 + ", " + r3 + "\n");
        	System.out.print("There should be " + count + " roots. new method returned "+m+" roots: ");
        	for (int i = 0; i < m; i++) {
        		System.out.print(resNew[i]+", ");
        	}
        	System.out.println();
        	System.out.println();
        } else {
        	System.out.print("Good: ");
        	for (int i = 0; i < m; i++) {
        		System.out.print(resNew[i]+", ");
        	}
        	System.out.println();
        }
    }

    private static void printStats() {
    	double mean = 0;
    	for (Double d : oldDiffs) {
    		mean += d;
    	}
    	mean /= oldDiffs.size();
    	double stdDev = 0;
    	for (Double d : oldDiffs) {
    		stdDev += (d - mean) * (d - mean);
    	}
    	stdDev /= oldDiffs.size();
    	stdDev = Math.sqrt(stdDev);

    	System.out.printf("Old mean = %f, stdDev = %f, old wins =%d%n", mean, stdDev, oldWins);

    	mean = 0;
    	for (Double d : newDiffs) {
    		mean += d;
    	}
    	mean /= newDiffs.size();
    	stdDev = 0;
    	for (Double d : newDiffs) {
    		stdDev += (d - mean) * (d - mean);
    	}
    	stdDev /= newDiffs.size();
    	stdDev = Math.sqrt(stdDev);
    	System.out.printf("New mean = %f, stdDev = %f, new wins =%d%n", mean, stdDev, newWins);
    	System.out.println();
    	System.out.println("oldTooMany : " + oldTooMany);
    	System.out.println("oldTooFew  : " + oldTooFew );
    	System.out.println("newTooFew  : " + newTooFew );
    	System.out.println("newTooMany : " + newTooMany);
    }

    public static void trySolve2(double r1, double r2) {
        trySolve3(r1, r2, r2, 2);
        trySolve3(r1, r1, r2, 2);
    }

    public static void trySolve1(double r) {
        trySolve3(r, r, r, 1);
    }

    public static void trySolve(double r1, double r2, double r3) {
        trySolve3(r1, r2, r3, 3);
        trySolve2(r1, r2);
        trySolve2(r2, r3);
        trySolve2(r3, r1);
        trySolve1(r1);
        trySolve1(r2);
        trySolve1(r3);
        System.out.println("--------------------------------------------------------------");
    }

    public static void main(String argv[]) {
    	trySolve3(-798.2134325057966, -798.2134325057966, -218.8642676477499, 2);
    	Random r = new Random();
    	double min = -1000, max = 1000;
    	int numtests = 1000;
    	for (int i = 0; i < numtests; i++) {
    		double r1 = rd(r, min, max), r2 = rd(r, min, max), r3 = rd(r, min, max);
    		trySolve(r1, r2, r3);
    	}
    	printStats();

//        perfTest(20000);
    }
}
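
Jim's pseudocode earlier in the thread eliminates duplicate roots by exact comparison after refining each one, and the end of solveCubicNew does a special-cased version of the same check for up to 3 roots. A general, order-preserving version of that step might look like this sketch; DedupRootsSketch and removeDuplicateRoots are hypothetical names.

```java
import java.util.Arrays;

// Sketch: remove exact duplicates from the first num entries of res, keeping
// the first occurrence of each value, and return the new count.
// Hypothetical names, following the equality-based elimination in the thread.
public class DedupRootsSketch {

    static int removeDuplicateRoots(double[] res, int num) {
        int kept = 0;
        for (int i = 0; i < num; i++) {
            boolean duplicate = false;
            for (int j = 0; j < kept; j++) {
                if (res[i] == res[j]) {
                    duplicate = true;
                    break;
                }
            }
            if (!duplicate) {
                res[kept++] = res[i];
            }
        }
        return kept;
    }

    public static void main(String[] args) {
        double[] res = {2.0, -1.0, 2.0};
        int num = removeDuplicateRoots(res, 3);
        System.out.println(num + " " + Arrays.toString(Arrays.copyOf(res, num))); // prints "2 [2.0, -1.0]"
    }
}
```

Because the comparison uses ==, roots that differ by even one ulp are kept as distinct, so the comparison is only meaningful after refinement, as in Jim's loop.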
