Module: Mesa
Branch: main
Commit: eeef1bbe6530f9a86f71af355f13a37f65a6d6e3
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=eeef1bbe6530f9a86f71af355f13a37f65a6d6e3

Author: Rhys Perry <[email protected]>
Date:   Mon Jan 17 17:33:25 2022 +0000

aco: refactor selection of mad/fma

In the future, whether we need to use fma will depend on which
multiplication is chosen.

Signed-off-by: Rhys Perry <[email protected]>
Reviewed-by: Daniel Schürmann <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/14769>

---

 src/amd/compiler/aco_optimizer.cpp | 32 ++++++++++++++++----------------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp
index 448c3a744f5..a775fd29edb 100644
--- a/src/amd/compiler/aco_optimizer.cpp
+++ b/src/amd/compiler/aco_optimizer.cpp
@@ -3549,25 +3549,15 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
                 instr->opcode == aco_opcode::v_subrev_f16;
    bool mad64 = instr->opcode == aco_opcode::v_add_f64;
    if (mad16 || mad32 || mad64) {
-      bool need_fma =
-         mad32 ? (ctx.fp_mode.denorm32 != 0 || ctx.program->chip_class >= 
GFX10_3)
-               : (ctx.fp_mode.denorm16_64 != 0 || ctx.program->chip_class >= 
GFX10 || mad64);
-      if (need_fma && instr->definitions[0].isPrecise())
-         return;
-      if (need_fma && mad32 && !ctx.program->dev.has_fast_fma32)
-         return;
-
       Instruction* mul_instr = nullptr;
       unsigned add_op_idx = 0;
       uint32_t uses = UINT32_MAX;
+      bool emit_fma = false;
       /* find the 'best' mul instruction to combine with the add */
       for (unsigned i = 0; i < 2; i++) {
          if (!instr->operands[i].isTemp() || 
!ctx.info[instr->operands[i].tempId()].is_mul())
             continue;
-         /* check precision requirements */
          ssa_info& info = ctx.info[instr->operands[i].tempId()];
-         if (need_fma && info.instr->definitions[0].isPrecise())
-            continue;
 
          /* no clamp/omod allowed between mul and add */
          if (info.instr->isVOP3() && (info.instr->vop3().clamp || 
info.instr->vop3().omod))
@@ -3577,7 +3567,16 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
             continue;
 
          bool legacy = info.instr->opcode == aco_opcode::v_mul_legacy_f32;
-         if (legacy && need_fma && ctx.program->chip_class < GFX10_3)
+
+         bool has_fma = mad16 || mad64 || (legacy && ctx.program->chip_class 
>= GFX10_3) ||
+                        (mad32 && !legacy && ctx.program->dev.has_fast_fma32);
+         bool has_mad = (mad32 && ctx.program->chip_class < GFX10_3) ||
+                        (mad16 && ctx.program->chip_class <= GFX9);
+         bool can_use_fma = has_fma && !info.instr->definitions[0].isPrecise() 
&&
+                            !instr->definitions[0].isPrecise();
+         bool can_use_mad =
+            has_mad && (mad32 ? ctx.fp_mode.denorm32 : 
ctx.fp_mode.denorm16_64) == 0;
+         if (!can_use_fma && !can_use_mad)
             continue;
 
          Operand op[3] = {info.instr->operands[0], info.instr->operands[1], 
instr->operands[1 - i]};
@@ -3595,6 +3594,7 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
          mul_instr = info.instr;
          add_op_idx = 1 - i;
          uses = ctx.uses[instr->operands[i].tempId()];
+         emit_fma = !can_use_mad;
       }
 
       if (mul_instr) {
@@ -3644,12 +3644,12 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
                   instr->opcode == aco_opcode::v_subrev_f16)
             neg[2 - add_op_idx] = neg[2 - add_op_idx] ^ true;
 
-         aco_opcode mad_op = need_fma ? aco_opcode::v_fma_f32 : 
aco_opcode::v_mad_f32;
+         aco_opcode mad_op = emit_fma ? aco_opcode::v_fma_f32 : 
aco_opcode::v_mad_f32;
          if (mul_instr->opcode == aco_opcode::v_mul_legacy_f32) {
-            assert(need_fma == (ctx.program->chip_class >= GFX10_3));
-            mad_op = need_fma ? aco_opcode::v_fma_legacy_f32 : 
aco_opcode::v_mad_legacy_f32;
+            assert(emit_fma == (ctx.program->chip_class >= GFX10_3));
+            mad_op = emit_fma ? aco_opcode::v_fma_legacy_f32 : 
aco_opcode::v_mad_legacy_f32;
          } else if (mad16) {
-            mad_op = need_fma ? (ctx.program->chip_class == GFX8 ? 
aco_opcode::v_fma_legacy_f16
+            mad_op = emit_fma ? (ctx.program->chip_class == GFX8 ? 
aco_opcode::v_fma_legacy_f16
                                                                  : 
aco_opcode::v_fma_f16)
                               : (ctx.program->chip_class == GFX8 ? 
aco_opcode::v_mad_legacy_f16
                                                                  : 
aco_opcode::v_mad_f16);

Reply via email to