Module: Mesa
Branch: main
Commit: 35196b6d89169d0af3f0e61d711a59a58850d199
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=35196b6d89169d0af3f0e61d711a59a58850d199

Author: Rhys Perry <[email protected]>
Date:   Mon Jan 17 17:48:33 2022 +0000

aco: combine add/mul as v_fma_mix into fma

fossil-db (Sienna Cichlid):
Totals from 7345 (5.44% of 134913) affected shaders:
CodeSize: 73840060 -> 73768936 (-0.10%); split: -0.10%, +0.00%
Instrs: 13701603 -> 13684183 (-0.13%); split: -0.13%, +0.00%
Latency: 185389373 -> 185306538 (-0.04%); split: -0.04%, +0.00%
InvThroughput: 33785020 -> 33757593 (-0.08%); split: -0.08%, +0.00%
VClause: 237337 -> 237338 (+0.00%)
SClause: 485728 -> 485720 (-0.00%)
Copies: 935900 -> 935279 (-0.07%); split: -0.07%, +0.00%
Branches: 480721 -> 480722 (+0.00%)

fossil-db (Navi):
Totals from 10649 (7.89% of 134913) affected shaders:
VGPRs: 756624 -> 756516 (-0.01%); split: -0.02%, +0.01%
CodeSize: 92156580 -> 91707900 (-0.49%); split: -0.49%, +0.00%
MaxWaves: 159402 -> 159476 (+0.05%); split: +0.07%, -0.02%
Instrs: 17155827 -> 17070449 (-0.50%); split: -0.50%, +0.00%
Latency: 246296456 -> 245487120 (-0.33%); split: -0.33%, +0.00%
InvThroughput: 41438159 -> 41117424 (-0.77%); split: -0.77%, +0.00%
VClause: 323790 -> 323867 (+0.02%); split: -0.00%, +0.03%
SClause: 612077 -> 612034 (-0.01%); split: -0.01%, +0.00%
Copies: 1103012 -> 1102775 (-0.02%); split: -0.03%, +0.01%
Branches: 555893 -> 555896 (+0.00%); split: -0.00%, +0.00%
PreSGPRs: 824372 -> 824378 (+0.00%)
PreVGPRs: 740390 -> 740363 (-0.00%); split: -0.01%, +0.01%

fossil-db (Vega):
Totals from 10950 (8.11% of 135048) affected shaders:
SGPRs: 1034528 -> 1034560 (+0.00%)
VGPRs: 794092 -> 794104 (+0.00%); split: -0.01%, +0.01%
CodeSize: 94409768 -> 93955568 (-0.48%); split: -0.48%, +0.00%
MaxWaves: 38950 -> 38939 (-0.03%); split: +0.00%, -0.03%
Instrs: 18162637 -> 18070934 (-0.50%); split: -0.51%, +0.00%
Latency: 291718455 -> 290772451 (-0.32%); split: -0.32%, +0.00%
InvThroughput: 109114674 -> 108489767 (-0.57%); split: -0.57%, +0.00%
VClause: 334498 -> 334579 (+0.02%); split: -0.01%, +0.03%
SClause: 628871 -> 628825 (-0.01%); split: -0.01%, +0.00%
Copies: 1674477 -> 1674850 (+0.02%); split: -0.02%, +0.04%
PreSGPRs: 834800 -> 834802 (+0.00%)
PreVGPRs: 750460 -> 750415 (-0.01%); split: -0.01%, +0.01%

Signed-off-by: Rhys Perry <[email protected]>
Reviewed-by: Daniel Schürmann <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/14769>

---

 src/amd/compiler/aco_optimizer.cpp | 137 ++++++++++++++++++++++++++++---------
 1 file changed, 103 insertions(+), 34 deletions(-)

diff --git a/src/amd/compiler/aco_optimizer.cpp 
b/src/amd/compiler/aco_optimizer.cpp
index 6e9ca7a69f5..d99a97e7d9e 100644
--- a/src/amd/compiler/aco_optimizer.cpp
+++ b/src/amd/compiler/aco_optimizer.cpp
@@ -3689,18 +3689,25 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& 
instr)
    }
 
    /* combine mul+add -> mad */
+   bool is_add_mix =
+      (instr->opcode == aco_opcode::v_fma_mix_f32 ||
+       instr->opcode == aco_opcode::v_fma_mixlo_f16) &&
+      !instr->vop3p().neg_lo[0] &&
+      ((instr->operands[0].constantEquals(0x3f800000) && 
(instr->vop3p().opsel_hi & 0x1) == 0) ||
+       (instr->operands[0].constantEquals(0x3C00) && (instr->vop3p().opsel_hi 
& 0x1) &&
+        !(instr->vop3p().opsel_lo & 0x1)));
    bool mad32 = instr->opcode == aco_opcode::v_add_f32 || instr->opcode == 
aco_opcode::v_sub_f32 ||
                 instr->opcode == aco_opcode::v_subrev_f32;
    bool mad16 = instr->opcode == aco_opcode::v_add_f16 || instr->opcode == 
aco_opcode::v_sub_f16 ||
                 instr->opcode == aco_opcode::v_subrev_f16;
    bool mad64 = instr->opcode == aco_opcode::v_add_f64;
-   if (mad16 || mad32 || mad64) {
+   if (is_add_mix || mad16 || mad32 || mad64) {
       Instruction* mul_instr = nullptr;
       unsigned add_op_idx = 0;
       uint32_t uses = UINT32_MAX;
       bool emit_fma = false;
       /* find the 'best' mul instruction to combine with the add */
-      for (unsigned i = 0; i < 2; i++) {
+      for (unsigned i = is_add_mix ? 1 : 0; i < instr->operands.size(); i++) {
          if (!instr->operands[i].isTemp() || 
!ctx.info[instr->operands[i].tempId()].is_mul())
             continue;
          ssa_info& info = ctx.info[instr->operands[i].tempId()];
@@ -3708,26 +3715,39 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& 
instr)
          /* no clamp/omod allowed between mul and add */
          if (info.instr->isVOP3() && (info.instr->vop3().clamp || 
info.instr->vop3().omod))
             continue;
-         if (info.instr->isVOP3P())
+         if (info.instr->isVOP3P() && info.instr->vop3p().clamp)
+            continue;
+         /* v_fma_mix_f32/etc can't do omod */
+         if (info.instr->isVOP3P() && instr->isVOP3() && instr->vop3().omod)
+            continue;
+         /* don't promote fp16 to fp32 or remove fp32->fp16->fp32 conversions 
*/
+         if (is_add_mix && info.instr->definitions[0].bytes() == 2)
             continue;
 
          if (get_operand_size(instr, i) != info.instr->definitions[0].bytes() 
* 8)
             continue;
 
          bool legacy = info.instr->opcode == aco_opcode::v_mul_legacy_f32;
+         bool mad_mix = is_add_mix || info.instr->isVOP3P();
 
          bool has_fma = mad16 || mad64 || (legacy && ctx.program->chip_class 
>= GFX10_3) ||
-                        (mad32 && !legacy && ctx.program->dev.has_fast_fma32);
-         bool has_mad = (mad32 && ctx.program->chip_class < GFX10_3) ||
-                        (mad16 && ctx.program->chip_class <= GFX9);
+                        (mad32 && !legacy && !mad_mix && 
ctx.program->dev.has_fast_fma32) ||
+                        (mad_mix && ctx.program->dev.fused_mad_mix);
+         bool has_mad = mad_mix ? !ctx.program->dev.fused_mad_mix
+                                : ((mad32 && ctx.program->chip_class < 
GFX10_3) ||
+                                   (mad16 && ctx.program->chip_class <= GFX9));
          bool can_use_fma = has_fma && !info.instr->definitions[0].isPrecise() 
&&
                             !instr->definitions[0].isPrecise();
          bool can_use_mad =
-            has_mad && (mad32 ? ctx.fp_mode.denorm32 : 
ctx.fp_mode.denorm16_64) == 0;
+            has_mad && (mad_mix || mad32 ? ctx.fp_mode.denorm32 : 
ctx.fp_mode.denorm16_64) == 0;
+         if (mad_mix && legacy)
+            continue;
          if (!can_use_fma && !can_use_mad)
             continue;
 
-         Operand op[3] = {info.instr->operands[0], info.instr->operands[1], 
instr->operands[1 - i]};
+         unsigned candidate_add_op_idx = is_add_mix ? (3 - i) : (1 - i);
+         Operand op[3] = {info.instr->operands[0], info.instr->operands[1],
+                          instr->operands[candidate_add_op_idx]};
          if (info.instr->isSDWA() || info.instr->isDPP() || 
!check_vop3_operands(ctx, 3, op) ||
              ctx.uses[instr->operands[i].tempId()] > uses)
             continue;
@@ -3740,7 +3760,7 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& 
instr)
          }
 
          mul_instr = info.instr;
-         add_op_idx = 1 - i;
+         add_op_idx = candidate_add_op_idx;
          uses = ctx.uses[instr->operands[i].tempId()];
          emit_fma = !can_use_mad;
       }
@@ -3761,6 +3781,8 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& 
instr)
          bool abs[3] = {false, false, false};
          unsigned omod = 0;
          bool clamp = false;
+         uint8_t opsel_lo = 0;
+         uint8_t opsel_hi = 0;
 
          if (mul_instr->isVOP3()) {
             VOP3_instruction& vop3 = mul_instr->vop3();
@@ -3768,6 +3790,14 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& 
instr)
             neg[1] = vop3.neg[1];
             abs[0] = vop3.abs[0];
             abs[1] = vop3.abs[1];
+         } else if (mul_instr->isVOP3P()) {
+            VOP3P_instruction& vop3p = mul_instr->vop3p();
+            neg[0] = vop3p.neg_lo[0];
+            neg[1] = vop3p.neg_lo[1];
+            abs[0] = vop3p.neg_hi[0];
+            abs[1] = vop3p.neg_hi[1];
+            opsel_lo = vop3p.opsel_lo & 0x3;
+            opsel_hi = vop3p.opsel_hi & 0x3;
          }
 
          if (instr->isVOP3()) {
@@ -3785,41 +3815,79 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& 
instr)
             }
             /* neg of the multiplication result */
             neg[1] = neg[1] ^ vop3.neg[1 - add_op_idx];
+         } else if (instr->isVOP3P()) {
+            VOP3P_instruction& vop3p = instr->vop3p();
+            neg[2] = vop3p.neg_lo[add_op_idx];
+            abs[2] = vop3p.neg_hi[add_op_idx];
+            opsel_lo |= vop3p.opsel_lo & (1 << add_op_idx) ? 0x4 : 0x0;
+            opsel_hi |= vop3p.opsel_hi & (1 << add_op_idx) ? 0x4 : 0x0;
+            clamp = vop3p.clamp;
+            /* abs of the multiplication result */
+            if (vop3p.neg_hi[3 - add_op_idx]) {
+               neg[0] = false;
+               neg[1] = false;
+               abs[0] = true;
+               abs[1] = true;
+            }
+            /* neg of the multiplication result */
+            neg[1] = neg[1] ^ vop3p.neg_lo[3 - add_op_idx];
          }
+
          if (instr->opcode == aco_opcode::v_sub_f32 || instr->opcode == 
aco_opcode::v_sub_f16)
             neg[1 + add_op_idx] = neg[1 + add_op_idx] ^ true;
          else if (instr->opcode == aco_opcode::v_subrev_f32 ||
                   instr->opcode == aco_opcode::v_subrev_f16)
             neg[2 - add_op_idx] = neg[2 - add_op_idx] ^ true;
 
-         aco_opcode mad_op = emit_fma ? aco_opcode::v_fma_f32 : 
aco_opcode::v_mad_f32;
-         if (mul_instr->opcode == aco_opcode::v_mul_legacy_f32) {
-            assert(emit_fma == (ctx.program->chip_class >= GFX10_3));
-            mad_op = emit_fma ? aco_opcode::v_fma_legacy_f32 : 
aco_opcode::v_mad_legacy_f32;
-         } else if (mad16) {
-            mad_op = emit_fma ? (ctx.program->chip_class == GFX8 ? 
aco_opcode::v_fma_legacy_f16
-                                                                 : 
aco_opcode::v_fma_f16)
-                              : (ctx.program->chip_class == GFX8 ? 
aco_opcode::v_mad_legacy_f16
-                                                                 : 
aco_opcode::v_mad_f16);
-         } else if (mad64) {
-            mad_op = aco_opcode::v_fma_f64;
-         }
+         aco_ptr<Instruction> add_instr = std::move(instr);
+         if (add_instr->isVOP3P() || mul_instr->isVOP3P()) {
+            assert(!omod);
+
+            aco_opcode mad_op = add_instr->definitions[0].bytes() == 2 ? 
aco_opcode::v_fma_mixlo_f16
+                                                                       : 
aco_opcode::v_fma_mix_f32;
+            aco_ptr<VOP3P_instruction> mad{
+               create_instruction<VOP3P_instruction>(mad_op, Format::VOP3P, 3, 
1)};
+            for (unsigned i = 0; i < 3; i++) {
+               mad->operands[i] = op[i];
+               mad->neg_lo[i] = neg[i];
+               mad->neg_hi[i] = abs[i];
+            }
+            mad->clamp = clamp;
+            mad->opsel_lo = opsel_lo;
+            mad->opsel_hi = opsel_hi;
 
-         aco_ptr<VOP3_instruction> mad{
-            create_instruction<VOP3_instruction>(mad_op, Format::VOP3, 3, 1)};
-         for (unsigned i = 0; i < 3; i++) {
-            mad->operands[i] = op[i];
-            mad->neg[i] = neg[i];
-            mad->abs[i] = abs[i];
+            instr = std::move(mad);
+         } else {
+            aco_opcode mad_op = emit_fma ? aco_opcode::v_fma_f32 : 
aco_opcode::v_mad_f32;
+            if (mul_instr->opcode == aco_opcode::v_mul_legacy_f32) {
+               assert(emit_fma == (ctx.program->chip_class >= GFX10_3));
+               mad_op = emit_fma ? aco_opcode::v_fma_legacy_f32 : 
aco_opcode::v_mad_legacy_f32;
+            } else if (mad16) {
+               mad_op = emit_fma ? (ctx.program->chip_class == GFX8 ? 
aco_opcode::v_fma_legacy_f16
+                                                                    : 
aco_opcode::v_fma_f16)
+                                 : (ctx.program->chip_class == GFX8 ? 
aco_opcode::v_mad_legacy_f16
+                                                                    : 
aco_opcode::v_mad_f16);
+            } else if (mad64) {
+               mad_op = aco_opcode::v_fma_f64;
+            }
+
+            aco_ptr<VOP3_instruction> mad{
+               create_instruction<VOP3_instruction>(mad_op, Format::VOP3, 3, 
1)};
+            for (unsigned i = 0; i < 3; i++) {
+               mad->operands[i] = op[i];
+               mad->neg[i] = neg[i];
+               mad->abs[i] = abs[i];
+            }
+            mad->omod = omod;
+            mad->clamp = clamp;
+
+            instr = std::move(mad);
          }
-         mad->omod = omod;
-         mad->clamp = clamp;
-         mad->definitions[0] = instr->definitions[0];
+         instr->definitions[0] = add_instr->definitions[0];
 
          /* mark this ssa_def to be re-checked for profitability and literals 
*/
-         ctx.mad_infos.emplace_back(std::move(instr), 
mul_instr->definitions[0].tempId());
-         ctx.info[mad->definitions[0].tempId()].set_mad(mad.get(), 
ctx.mad_infos.size() - 1);
-         instr = std::move(mad);
+         ctx.mad_infos.emplace_back(std::move(add_instr), 
mul_instr->definitions[0].tempId());
+         ctx.info[instr->definitions[0].tempId()].set_mad(instr.get(), 
ctx.mad_infos.size() - 1);
          return;
       }
    }
@@ -4084,7 +4152,8 @@ select_instruction(opt_ctx& ctx, aco_ptr<Instruction>& 
instr)
          mad_info = NULL;
       }
       /* check literals */
-      else if (!instr->usesModifiers() && instr->opcode != 
aco_opcode::v_fma_f64 &&
+      else if (!instr->usesModifiers() && !instr->isVOP3P() &&
+               instr->opcode != aco_opcode::v_fma_f64 &&
                instr->opcode != aco_opcode::v_mad_legacy_f32 &&
                instr->opcode != aco_opcode::v_fma_legacy_f32) {
          /* FMA can only take literals on GFX10+ */

Reply via email to