There is a much better way to implement java.lang.Math.fma(double, double, 
double) than the current default implementation in 
src/java.base/share/classes/java/lang/Math.java.

Here is a much better implementation of java.lang.Math.fma. Benchmarking has
shown it to be much faster than
(new BigDecimal(a)).multiply(new BigDecimal(b)).add(new BigDecimal(c)).doubleValue()
when a, b, and c are all finite floating-point values:
    private static double scalbFiniteF64Sum(double hi, double lo, int exp) {
        // Both hi and lo should be finite, and |hi| >= |lo| should be true if
        // hi != 0.0

        final double sum0 = hi + lo;
        final double scaleResult0 = Math.scalb(sum0, exp);
        if (!(Math.abs(scaleResult0) <= Double.MIN_NORMAL) || lo == 0.0) {
            // If |scaleResult0| > Double.MIN_NORMAL or lo == 0.0, then
            // scaleResult0 will be correctly rounded.

            // Return scaleResult0 in this case.

            return scaleResult0;
        }

        // scaleResult1 is equal to scaleResult0 with its ULP bit cleared.
        // This ensures that the final result will be correctly rounded if
        // the ULP bit of scaleResult0 is set.
        final double scaleResult1 = Double.longBitsToDouble(
                Double.doubleToRawLongBits(scaleResult0) & (-2L));

        // scaleErr is equal to the error of scaling sum0 by 2**exp, scaled back
        // up by 2**-exp. scaleResult1 is used to compute scaleErr to ensure
        // that the value of the ULP bit of scaleResult0, scaled up by 2**-exp,
        // is included in scaleErr to ensure that the final result is correctly
        // rounded.
        final double scaleErr = sum0 - Math.scalb(scaleResult1, -exp);

        // err0 is equal to the error of hi + lo
        final double err0 = (hi - sum0) + lo;

        // Compute scaleErr + err0 using the Fast2Sum algorithm as
        // |scaleErr| >= |err0| should be true if |scaleErr| > 0
        final double sum1 = scaleErr + err0;
        final double err1 = (scaleErr - sum1) + err0;

        final long sum1Bits = Double.doubleToRawLongBits(sum1);
        final long err1Bits = Double.doubleToRawLongBits(err1);
        final long sum1IsInexactInSignBit = err1Bits
                ^ (err1Bits + 0x7FFF_FFFF_FFFF_FFFFL);

        // roundedToOddSum1 is equal to the rounded to odd sum of scaleErr and
        // err0
        final double roundedToOddSum1 = Double.longBitsToDouble((sum1Bits
                + (((sum1Bits ^ err1Bits) & sum1IsInexactInSignBit) >> 63))
                | (sum1IsInexactInSignBit >>> 63));

        // scaleResult2 is equal to the correctly rounded value of
        // (hi + lo) * 2**exp
        final double scaleResult2 = scaleResult1
                + Math.scalb(roundedToOddSum1, exp);

        return scaleResult2;
    }

    public static double fma(double a, double b, double c) {
        final long aBits = Double.doubleToRawLongBits(a);
        final long bBits = Double.doubleToRawLongBits(b);
        final long cBits = Double.doubleToRawLongBits(c);

        final int aBiasedExp = (int) (aBits >>> 52) & 0x7FF;
        final int bBiasedExp = (int) (bBits >>> 52) & 0x7FF;
        final int cBiasedExp = (int) (cBits >>> 52) & 0x7FF;

        if (((aBits ^ (aBits - 1)) | (bBits ^ (bBits - 1))
                | (cBits ^ (cBits - 1)) | (0x7FE - aBiasedExp)
                | (0x7FE - bBiasedExp)) < 0) {
            // If at least one of a, b, or c is zero, or if a or b is
            // non-finite, the result is equal to a * b + c.
            return a * b + c;
        } else if (cBiasedExp == 0x7FF) {
            // If a and b are both nonzero finite numbers and c is a non-finite
            // value, simply return c since the exact result of a * b + c will
            // be equal to c in this case.
            return c;
        }

        // a, b, and c are all nonzero finite values at this point

        // Normalize a and b to normal floating-point numbers

        final int aIsDenormalMask = (aBiasedExp - 1) >> 31;
        final int bIsDenormalMask = (bBiasedExp - 1) >> 31;

        final long aNormalizeAdjBits = (aBits & 0x8000_0000_0000_0000L)
                | (aIsDenormalMask & 0x0350_0000_0000_0000L);
        final long bNormalizeAdjBits = (bBits & 0x8000_0000_0000_0000L)
                | (bIsDenormalMask & 0x0350_0000_0000_0000L);

        final double normalizedA =
                Double.longBitsToDouble(aBits | aNormalizeAdjBits)
                - Double.longBitsToDouble(aNormalizeAdjBits);
        final double normalizedB =
                Double.longBitsToDouble(bBits | bNormalizeAdjBits)
                - Double.longBitsToDouble(bNormalizeAdjBits);

        // If a is a denormal number, normalizedA is equal to a * 2**52.
        // Otherwise, if a is already a normal number, normalizedA is equal to
        // a.

        // If b is a denormal number, normalizedB is equal to b * 2**52.
        // Otherwise, if b is already a normal number, normalizedB is equal to
        // b.

        final long normalizedABits = Double.doubleToRawLongBits(normalizedA);
        final long normalizedBBits = Double.doubleToRawLongBits(normalizedB);

        final int normalizedABiasedExp = (int) (normalizedABits >>> 52) & 0x7FF;
        final int normalizedBBiasedExp = (int) (normalizedBBits >>> 52) & 0x7FF;

        // minPBiasedExp is the smallest possible biased exponent (with an
        // exponent bias of 1023) of the exact value of a * b.
        final int minPBiasedExp = normalizedABiasedExp + normalizedBBiasedExp
                + (aIsDenormalMask & -52) + (bIsDenormalMask & -52) + -1023;

        if (minPBiasedExp >= 2048) {
            // If minPBiasedExp >= 2048, then a * b + c is known to overflow to
            // infinity. Return a * b + c in this case.
            return a * b + c;
        }

        // Normalize c to a normal floating-point number

        final int cIsDenormalMask = (cBiasedExp - 1) >> 31;

        final long cNormalizeAdjBits = (cBits & 0x8000_0000_0000_0000L)
                | (cIsDenormalMask & 0x0350_0000_0000_0000L);

        final double normalizedC =
                Double.longBitsToDouble(cBits | cNormalizeAdjBits)
                - Double.longBitsToDouble(cNormalizeAdjBits);

        // If c is a denormal number, normalizedC is equal to c * 2**52.
        // Otherwise, if c is already a normal number, normalizedC is equal to
        // c.

        final long normalizedCBits = Double.doubleToRawLongBits(normalizedC);
        final int normalizedCBiasedExp = (int) (normalizedCBits >>> 52) & 0x7FF;

        final int adjCBiasedExp = normalizedCBiasedExp
                + (cIsDenormalMask & -52);

        // adjCBiasedExp is equal to floor(log2(|c|)) + 1023, even if c is a
        // denormal number

        final int expDiff = adjCBiasedExp - minPBiasedExp;
        if (expDiff >= 56) {
            // If expDiff >= 56, then 0 < |a * b| < 2**(floor(log2(|c|)) - 54),
            // which is less than half the distance from c to either of its
            // neighbors, including the closer neighbor below |c| when c is a
            // power of two. The correctly rounded result of a * b + c is
            // therefore equal to c, so return c in this case.

            // A threshold of 55 is not sufficient: for example, the exact
            // value of fma(-0x1.8p-55, 1.5, 1.0) is 1 - 0x1.2p-54, which
            // rounds to Math.nextDown(1.0), not to 1.0.

            return c;
        }

        // aMant is the significand of a, with the sign of a and an exponent
        // of 0, so that 1 <= |aMant| < 2.
        // bMant is the significand of b, with the sign of b and an exponent
        // of 0, so that 1 <= |bMant| < 2.
        final double aMant = Double.longBitsToDouble(
                (normalizedABits | 0x3FF0_0000_0000_0000L)
                        & 0xBFFF_FFFF_FFFF_FFFFL);
        final double bMant = Double.longBitsToDouble(
                (normalizedBBits | 0x3FF0_0000_0000_0000L)
                        & 0xBFFF_FFFF_FFFF_FFFFL);

        // Split aMant and bMant using Veltkamp-Dekker splitting
        // (134217729.0 == 2**27 + 1)
        final double aMantGamma = aMant * 134217729.0;
        final double bMantGamma = bMant * 134217729.0;

        final double aMantHi = aMantGamma + (aMant - aMantGamma);
        final double bMantHi = bMantGamma + (bMant - bMantGamma);

        final double aMantLo = aMant - aMantHi;
        final double bMantLo = bMant - bMantHi;

        // pHi + pLo == aMant * bMant exactly (Dekker's product)
        final double pHi = aMant * bMant;
        final double pLo = ((aMantHi * bMantHi - pHi)
                + (aMantHi * bMantLo + aMantLo * bMantHi)) + aMantLo * bMantLo;

        if (minPBiasedExp >= 1 && minPBiasedExp <= 2045 && pLo == 0.0) {
            // If 1 <= minPBiasedExp <= 2045 and pLo == 0.0, then a * b is
            // known to be exact, normal, and finite
            // (2**-1022 <= |a * b| < 2**1024). Return a * b + c in this case.

            // minPBiasedExp values of 2046 and 2047 are excluded because
            // a * b can overflow to infinity even though the exact value of
            // a * b + c is finite; for example,
            // fma(0x1p512, 0x1p512, -Double.MAX_VALUE) must return 0x1p971.

            return a * b + c;
        }

        final int resultScaleUpExp = minPBiasedExp - 1023;

        if (expDiff <= -106) {
            // If expDiff <= -106 is true, then |a * b| >= 2**-968 must be true
            // since 2**-1074 <= |c| < 2**-105 * |a * b|.

            if (pLo == 0.0) {
                // If expDiff <= -106 and pLo == 0.0 are both true, then |c|
                // is too small to affect the correctly rounded result.

                // Return a * b in this case: a * b is known to be either an
                // exact normal finite number or infinity, and |c| is too
                // small to affect the correctly rounded result of a * b + c.
                return a * b;
            } else {
                // If expDiff <= -106 and pLo != 0.0 are both true, then
                // 0 < |c * 2**-resultScaleUpExp| < 2**-105 is known to be
                // true, while pHi and pLo are both multiples of 2**-104 (their
                // sum is the exact product of two 53-bit significands in
                // [1, 2)). c can therefore only act as a sticky bit below pLo.

                // In this case, the effect of c * 2**-resultScaleUpExp can be
                // folded into pLo by decrementing pLoBits by 1 if pLo and c
                // have different signs, followed by setting the least
                // significant bit.

                final long pLoBits = Double.doubleToRawLongBits(pLo);

                // roundedToOddLoSum is pLo with the sticky contribution of
                // c * 2**-resultScaleUpExp folded into its least significant
                // bit
                final double roundedToOddLoSum = Double.longBitsToDouble(
                        (pLoBits + ((pLoBits ^ cBits) >> 63)) | 1L);

                // Return the correctly rounded value of
                // (pHi + roundedToOddLoSum) * 2**resultScaleUpExp
                return scalbFiniteF64Sum(pHi, roundedToOddLoSum,
                        resultScaleUpExp);
            }
        }

        // -105 <= expDiff <= 55 is now true at this point

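        // scaledC is equal to c * 2**-resultScaleUpExp, so that
        // pHi + pLo + scaledC == (a * b + c) * 2**-resultScaleUpExp
        final double scaledC = Double.longBitsToDouble(
                ((normalizedCBits | 0x3FF0_0000_0000_0000L)
                        & 0xBFFF_FFFF_FFFF_FFFFL) + ((long) expDiff << 52));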

        // Compute pLo + scaledC using the 2Sum algorithm

        final double s0 = pLo + scaledC;
        final double v0 = s0 - pLo;
        final double e0 = (pLo - (s0 - v0)) + (scaledC - v0);

        // s0 + e0 == pLo + scaledC

        // Compute pHi + s0 using the 2Sum algorithm

        final double s1 = pHi + s0;
        final double v1 = s1 - pHi;
        final double e1 = (pHi - (s1 - v1)) + (s0 - v1);

        // s1 + e1 + e0 == pHi + pLo + scaledC

        // Compute the rounded to odd sum of e1 and e0 to ensure that the final
        // result is correctly rounded.

        // Fast2Sum is sufficient for computing e1 + e0 as either |e1| == 0.0 or
        // |e1| >= |e0| should be true

        final double s2 = e1 + e0;
        final double e2 = (e1 - s2) + e0;

        // s1 + s2 + e2 == pHi + pLo + scaledC

        final long s2Bits = Double.doubleToRawLongBits(s2);
        final long e2Bits = Double.doubleToRawLongBits(e2);
        final long s2IsInexactInSignBit = e2Bits
                ^ (e2Bits + 0x7FFF_FFFF_FFFF_FFFFL);

        // roundedToOddS2 is equal to the rounded to odd sum of e1 and e0 and
        // is used to ensure that the final result is correctly rounded
        final double roundedToOddS2 = Double.longBitsToDouble(
                (s2Bits + (((s2Bits ^ e2Bits) & s2IsInexactInSignBit) >> 63))
                        | (s2IsInexactInSignBit >>> 63));

        // Return the correctly rounded value of
        // (s1 + roundedToOddS2) * 2**resultScaleUpExp
        return scalbFiniteF64Sum(s1, roundedToOddS2, resultScaleUpExp);
    }
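
Two of the building blocks above can be exercised in isolation. The following is a minimal standalone sketch; the class FmaBuildingBlocks and its helpers normalizeSubnormal and twoProduct are hypothetical names, not part of the JDK. It repeats the subnormal normalization trick applied to a, b, and c, and the Veltkamp-Dekker split plus Dekker product used to compute pHi and pLo, then checks the product against BigDecimal. It assumes inputs for which x * 134217729.0 does not overflow and the partial products do not underflow.

```java
import java.math.BigDecimal;

// Hypothetical standalone helpers mirroring two steps of the fma code above.
final class FmaBuildingBlocks {
    static final double SPLITTER = 134217729.0; // 2**27 + 1

    // For a subnormal x, OR-ing in the biased exponent 0x035 (2**-970) and
    // then subtracting 2**-970 (with the sign of x) leaves exactly x * 2**52.
    // For a normal x, adjBits is just the sign bit, so x - (+/-0.0) == x.
    // (fma above never reaches this step with a zero input.)
    static double normalizeSubnormal(double x) {
        long bits = Double.doubleToRawLongBits(x);
        int biasedExp = (int) (bits >>> 52) & 0x7FF;
        int isDenormalMask = (biasedExp - 1) >> 31; // -1 iff biasedExp == 0
        long adjBits = (bits & 0x8000_0000_0000_0000L)
                | (isDenormalMask & 0x0350_0000_0000_0000L);
        return Double.longBitsToDouble(bits | adjBits)
                - Double.longBitsToDouble(adjBits);
    }

    // Veltkamp-Dekker split followed by Dekker's product, with the same
    // association as the pLo computation above. Returns {hi, lo} where hi is
    // the rounded product and hi + lo == x * y exactly, assuming no overflow
    // in x * SPLITTER and no underflow in the partial products.
    static double[] twoProduct(double x, double y) {
        double xGamma = x * SPLITTER;
        double yGamma = y * SPLITTER;
        double xHi = xGamma + (x - xGamma);
        double yHi = yGamma + (y - yGamma);
        double xLo = x - xHi;
        double yLo = y - yHi;
        double hi = x * y;
        double lo = ((xHi * yHi - hi) + (xHi * yLo + xLo * yHi)) + xLo * yLo;
        return new double[] {hi, lo};
    }

    public static void main(String[] args) {
        // The smallest positive subnormal, 2**-1074, normalizes to 2**-1022.
        System.out.println(normalizeSubnormal(Double.MIN_VALUE) == 0x1p-1022);

        double x = 0.1, y = 0.3;
        double[] p = twoProduct(x, y);
        BigDecimal exact = new BigDecimal(x).multiply(new BigDecimal(y));
        BigDecimal split = new BigDecimal(p[0]).add(new BigDecimal(p[1]));
        System.out.println(exact.compareTo(split) == 0);
    }
}
```

Both lines printed by main should be true; the BigDecimal comparison is the same kind of exact reference the current Math.fma implementation relies on.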

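The above implementation relies on a few bit-manipulation tricks.

If xBits is the result of Double.doubleToRawLongBits(x), then:

xBits ^ (xBits - 1) is equal to -1 if x is +0.0 or -0.0, and is non-negative
otherwise. OR-ing several such values and checking whether the result is
negative (as fma does for a, b, and c) tests whether any of them is zero.

xBits ^ (xBits + 0x7FFF_FFFF_FFFF_FFFFL) is equal to Long.MAX_VALUE if x is
+0.0 or -0.0, and is negative otherwise. fma uses this to move "the error term
is nonzero" into the sign bit.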
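
The two identities can be checked directly. The sketch below (ZeroTricks and its methods are hypothetical names used only for illustration) evaluates both expressions for a few special values and shows the branch-free "is any input zero?" test built from the first one.

```java
// Hypothetical demo of the two sign-bit tricks described above.
final class ZeroTricks {
    // -1 if x is +0.0 or -0.0, non-negative for every other value.
    static long zeroMask(double x) {
        long xBits = Double.doubleToRawLongBits(x);
        return xBits ^ (xBits - 1);
    }

    // Long.MAX_VALUE if x is +0.0 or -0.0, negative for every other value.
    static long nonzeroInSignBit(double x) {
        long xBits = Double.doubleToRawLongBits(x);
        return xBits ^ (xBits + 0x7FFF_FFFF_FFFF_FFFFL);
    }

    // Same pattern as the first test in fma: one sign check after OR-ing.
    static boolean anyZero(double a, double b, double c) {
        return (zeroMask(a) | zeroMask(b) | zeroMask(c)) < 0;
    }

    public static void main(String[] args) {
        double[] samples = {0.0, -0.0, Double.MIN_VALUE, -1.0,
                Double.MAX_VALUE, Double.NEGATIVE_INFINITY, Double.NaN};
        for (double x : samples) {
            System.out.printf("%s: zeroMask = %d, nonzeroInSignBit < 0 is %b%n",
                    x, zeroMask(x), nonzeroInSignBit(x) < 0);
        }
        System.out.println(anyZero(1.0, -0.0, 2.0)); // true
        System.out.println(anyZero(1.0, 3.0, 2.0));  // false
    }
}
```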

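The rounded-to-odd result of an intermediate sum is used in a few places in the
above implementation of fma to ensure that the final sum is correctly rounded.
Rounding to odd keeps an inexact intermediate result off the rounding midpoints
of the final result (its least significant bit acts as a sticky bit), so as
long as the intermediate carries enough extra precision, the subsequent
round-to-nearest step does not introduce a double-rounding error.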
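
A small sketch of why this matters. The hypothetical helper roundToOddSum below repeats the Fast2Sum-plus-bit-adjustment pattern used for roundedToOddS2, so it assumes |x| >= |y| and that x + y does not overflow. The example inputs are chosen so that the exact sum 1.0 + 0x1p-53 + 0x1p-110 lies just above the midpoint between 1.0 and Math.nextUp(1.0); a naive two-step sum loses the tiny term and then hits a tie.

```java
// Hypothetical demo of round-to-odd as an intermediate rounding mode.
final class RoundToOddDemo {
    // Rounded-to-odd value of x + y, using Fast2Sum followed by the same bit
    // adjustment as roundedToOddS2 in fma above. Assumes |x| >= |y| and that
    // x + y does not overflow.
    static double roundToOddSum(double x, double y) {
        double s = x + y;
        double e = (x - s) + y; // Fast2Sum: x + y == s + e exactly
        long sBits = Double.doubleToRawLongBits(s);
        long eBits = Double.doubleToRawLongBits(e);
        // Negative iff e != 0, i.e. iff s is inexact
        long inexactInSignBit = eBits ^ (eBits + 0x7FFF_FFFF_FFFF_FFFFL);
        // If inexact: step s one ulp toward zero when e has the opposite
        // sign, then force the last bit to 1 so the result is the odd one
        // of the two doubles bracketing the exact sum.
        return Double.longBitsToDouble(
                (sBits + (((sBits ^ eBits) & inexactInSignBit) >> 63))
                        | (inexactInSignBit >>> 63));
    }

    public static void main(String[] args) {
        double a = 1.0, b = 0x1p-53, c = 0x1p-110;

        // b + c rounds to b, and 1.0 + 0x1p-53 is then a tie that rounds to
        // the even neighbor, 1.0.
        double naive = a + (b + c);
        // Rounding b + c to odd keeps a sticky bit, so the final sum rounds
        // up to Math.nextUp(1.0), the correctly rounded value.
        double viaOdd = a + roundToOddSum(b, c);

        System.out.println(naive);  // 1.0
        System.out.println(viaOdd); // 1.0000000000000002
    }
}
```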
