Use softfloat-parts.h so that we can more naturally perform the required operations with a single rounding step. This happens to also simplify the NaN detection step.
Signed-off-by: Richard Henderson <[email protected]>
---
 target/arm/tcg/sme_helper.c | 96 ++++++++++++++--------------------
 1 file changed, 40 insertions(+), 56 deletions(-)

diff --git a/target/arm/tcg/sme_helper.c b/target/arm/tcg/sme_helper.c
index 0702a1b129..16b96890cc 100644
--- a/target/arm/tcg/sme_helper.c
+++ b/target/arm/tcg/sme_helper.c
@@ -27,6 +27,7 @@
 #include "accel/tcg/helper-retaddr.h"
 #include "qemu/int128.h"
 #include "fpu/softfloat.h"
+#include "fpu/softfloat-parts.h"
 #include "vec_internal.h"
 #include "sve_ldst_internal.h"
 
@@ -1227,18 +1228,15 @@ static inline uint32_t bf16mop_ah_neg_adj_pair(uint32_t pair, uint32_t pg)
 }
 
 static float32 f16_dotadd(float32 sum, uint32_t e1, uint32_t e2,
-                          float_status *s_f16, float_status *s_std,
-                          float_status *s_odd)
+                          float_status *s_f16, float_status *s_std)
 {
     /*
-     * We need three different float_status for different parts of this
+     * We need two different float_status for different parts of this
      * operation:
      *  - the input conversion of the float16 values must use the
      *    f16-specific float_status, so that the FPCR.FZ16 control is applied
      *  - operations on float32 including the final accumulation must use
      *    the normal float_status, so that FPCR.FZ is applied
-     *  - we have pre-set-up copy of s_std which is set to round-to-odd,
-     *    for the multiply (see below)
      */
     float16 h1r = e1 & 0xffff;
     float16 h1c = e1 >> 16;
@@ -1246,48 +1244,49 @@ static float32 f16_dotadd(float32 sum, uint32_t e1, uint32_t e2,
     float16 h2c = e2 >> 16;
     float32 t32;
 
+    FloatParts64 p1r = float16_unpack_canonical(h1r, s_f16);
+    FloatParts64 p1c = float16_unpack_canonical(h1c, s_f16);
+    FloatParts64 p2r = float16_unpack_canonical(h2r, s_f16);
+    FloatParts64 p2c = float16_unpack_canonical(h2c, s_f16);
+
+    int all_mask = (float_cmask(p1r.cls) | float_cmask(p1c.cls) |
+                    float_cmask(p2r.cls) | float_cmask(p2c.cls));
+
     /* C.f. FPProcessNaNs4 */
-    if (float16_is_any_nan(h1r) || float16_is_any_nan(h1c) ||
-        float16_is_any_nan(h2r) || float16_is_any_nan(h2c)) {
+    if (unlikely(all_mask & float_cmask_anynan)) {
         float16 t16;
 
-        if (float16_is_signaling_nan(h1r, s_f16)) {
-            t16 = h1r;
-        } else if (float16_is_signaling_nan(h1c, s_f16)) {
-            t16 = h1c;
-        } else if (float16_is_signaling_nan(h2r, s_f16)) {
-            t16 = h2r;
-        } else if (float16_is_signaling_nan(h2c, s_f16)) {
-            t16 = h2c;
-        } else if (float16_is_any_nan(h1r)) {
-            t16 = h1r;
-        } else if (float16_is_any_nan(h1c)) {
-            t16 = h1c;
-        } else if (float16_is_any_nan(h2r)) {
-            t16 = h2r;
+        if (unlikely(all_mask & float_cmask_snan)) {
+            if (p1r.cls == float_class_snan) {
+                t16 = h1r;
+            } else if (p1c.cls == float_class_snan) {
+                t16 = h1c;
+            } else if (p2r.cls == float_class_snan) {
+                t16 = h2r;
+            } else {
+                t16 = h2c;
+            }
         } else {
-            t16 = h2c;
+            if (p1r.cls == float_class_qnan) {
+                t16 = h1r;
+            } else if (p1c.cls == float_class_qnan) {
+                t16 = h1c;
+            } else if (p2r.cls == float_class_qnan) {
+                t16 = h2r;
+            } else {
+                t16 = h2c;
+            }
         }
         t32 = float16_to_float32(t16, true, s_f16);
     } else {
-        float64 e1r = float16_to_float64(h1r, true, s_f16);
-        float64 e1c = float16_to_float64(h1c, true, s_f16);
-        float64 e2r = float16_to_float64(h2r, true, s_f16);
-        float64 e2c = float16_to_float64(h2c, true, s_f16);
-        float64 t64;
-
         /*
          * The ARM pseudocode function FPDot performs both multiplies
-         * and the add with a single rounding operation. Emulate this
-         * by performing the first multiply in round-to-odd, then doing
-         * the second multiply as fused multiply-add, and rounding to
-         * float32 all in one step.
+         * and the add with a single rounding operation.
          */
-        t64 = float64_mul(e1r, e2r, s_odd);
-        t64 = float64r32_muladd(e1c, e2c, t64, 0, s_std);
+        FloatParts64 tmp = parts64_mul(&p1r, &p2r, s_std);
+        tmp = parts64_muladd(&p1c, &p2c, &tmp, 0, s_std);
 
-        /* This conversion is exact, because we've already rounded. */
-        t32 = float64_to_float32(t64, s_std);
+        t32 = float32_round_pack_canonical(&tmp, s_std);
     }
 
     /* The final accumulation step is not fused. */
@@ -1299,9 +1298,6 @@ static void do_fmopa_w_h(void *vza, void *vzn, void *vzm, uint16_t *pn,
                          uint32_t negx, bool ah_neg)
 {
     intptr_t row, col, oprsz = simd_maxsz(desc);
-    float_status fpst_odd = env->vfp.fp_status[FPST_ZA];
-
-    set_float_rounding_mode(float_round_to_odd, &fpst_odd);
 
     for (row = 0; row < oprsz; ) {
         uint16_t prow = pn[H2(row >> 4)];
@@ -1325,8 +1321,7 @@ static void do_fmopa_w_h(void *vza, void *vzn, void *vzm, uint16_t *pn,
                         m = f16mop_adj_pair(m, pcol, 0);
                         *a = f16_dotadd(*a, n, m,
                                         &env->vfp.fp_status[FPST_ZA_F16],
-                                        &env->vfp.fp_status[FPST_ZA],
-                                        &fpst_odd);
+                                        &env->vfp.fp_status[FPST_ZA]);
                     }
                     col += 4;
                     pcol >>= 4;
@@ -1363,15 +1358,12 @@ void HELPER(sme2_fdot_h)(void *vd, void *vn, void *vm, void *va,
     bool za = extract32(desc, SIMD_DATA_SHIFT, 1);
     float_status *fpst_std = &env->vfp.fp_status[za ? FPST_ZA : FPST_A64];
     float_status *fpst_f16 = &env->vfp.fp_status[za ? FPST_ZA_F16 : FPST_A64_F16];
-    float_status fpst_odd = *fpst_std;
     float32 *d = vd, *a = va;
     uint32_t *n = vn, *m = vm;
 
-    set_float_rounding_mode(float_round_to_odd, &fpst_odd);
-
     for (i = 0; i < oprsz / sizeof(float32); ++i) {
         d[H4(i)] = f16_dotadd(a[H4(i)], n[H4(i)], m[H4(i)],
-                              fpst_f16, fpst_std, &fpst_odd);
+                              fpst_f16, fpst_std);
     }
 }
 
@@ -1385,17 +1377,14 @@ void HELPER(sme2_fdot_idx_h)(void *vd, void *vn, void *vm, void *va,
     bool za = extract32(desc, SIMD_DATA_SHIFT + 2, 1);
     float_status *fpst_std = &env->vfp.fp_status[za ? FPST_ZA : FPST_A64];
     float_status *fpst_f16 = &env->vfp.fp_status[za ? FPST_ZA_F16 : FPST_A64_F16];
-    float_status fpst_odd = *fpst_std;
     float32 *d = vd, *a = va;
     uint32_t *n = vn, *m = (uint32_t *)vm + H4(idx);
 
-    set_float_rounding_mode(float_round_to_odd, &fpst_odd);
-
     for (i = 0; i < elements; i += eltspersegment) {
         uint32_t mm = m[i];
         for (j = 0; j < eltspersegment; ++j) {
             d[H4(i + j)] = f16_dotadd(a[H4(i + j)], n[H4(i + j)], mm,
-                                      fpst_f16, fpst_std, &fpst_odd);
+                                      fpst_f16, fpst_std);
         }
     }
 }
@@ -1408,24 +1397,19 @@ void HELPER(sme2_fvdot_idx_h)(void *vd, void *vn, void *vm, void *va,
     intptr_t eltspersegment = MIN(4, elements);
     int idx = extract32(desc, SIMD_DATA_SHIFT, 2);
     int sel = extract32(desc, SIMD_DATA_SHIFT + 2, 1);
-    float_status fpst_odd, *fpst_std, *fpst_f16;
     float32 *d = vd, *a = va;
     uint16_t *n0 = vn;
     uint16_t *n1 = vn + sizeof(ARMVectorReg);
     uint32_t *m = (uint32_t *)vm + H4(idx);
 
-    fpst_std = &env->vfp.fp_status[FPST_ZA];
-    fpst_f16 = &env->vfp.fp_status[FPST_ZA_F16];
-    fpst_odd = *fpst_std;
-    set_float_rounding_mode(float_round_to_odd, &fpst_odd);
-
     for (i = 0; i < elements; i += eltspersegment) {
         uint32_t mm = m[i];
         for (j = 0; j < eltspersegment; ++j) {
             uint32_t nn = (n0[H2(2 * (i + j) + sel)]) |
                           (n1[H2(2 * (i + j) + sel)] << 16);
             d[i + H4(j)] = f16_dotadd(a[i + H4(j)], nn, mm,
-                                      fpst_f16, fpst_std, &fpst_odd);
+                                      &env->vfp.fp_status[FPST_ZA_F16],
+                                      &env->vfp.fp_status[FPST_ZA]);
         }
     }
 }
-- 
2.43.0
