Module: Mesa
Branch: staging/23.3
Commit: 35a94b16470254e944a4a7b1e8f7aa745a2cb76a
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=35a94b16470254e944a4a7b1e8f7aa745a2cb76a

Author: Bas Nieuwenhuizen <[email protected]>
Date:   Wed Dec 20 00:19:55 2023 +0100

radv: Use correct writemask for cooperative matrix ordering.

This is not expected to fix anything externally visible, but it
avoids passing an invalid writemask when the resulting vector is
not 16 elements long (e.g. the C/result matrix).

Fixes: 9df4703fbb5 ("radv: Add cooperative matrix lowering.")
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26768>
(cherry picked from commit 07ad6fd34a6ed32b74a3f9697545261a3fd84de2)

---

 .pick_status.json                                    |  2 +-
 .../vulkan/nir/radv_nir_lower_cooperative_matrix.c   | 20 ++++++++++++--------
 2 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/.pick_status.json b/.pick_status.json
index 84ffb2b364e..7da0abcc41e 100644
--- a/.pick_status.json
+++ b/.pick_status.json
@@ -264,7 +264,7 @@
         "description": "radv: Use correct writemask for cooperative matrix 
ordering.",
         "nominated": true,
         "nomination_type": 1,
-        "resolution": 0,
+        "resolution": 1,
         "main_sha": null,
         "because_sha": "9df4703fbb59d1295a9d3daf6320f329c9de2d66",
         "notes": null
diff --git a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c
index d81231b0137..e882100e141 100644
--- a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c
+++ b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c
@@ -181,7 +181,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, 
unsigned wave_size)
 
                nir_def *elem = intr->src[1].ssa;
                nir_def *r = nir_vector_insert(&b, src1, elem, index);
-               nir_store_deref(&b, dst_deref, r, 0xffff);
+               nir_store_deref(&b, dst_deref, r, 
nir_component_mask(r->num_components));
                nir_instr_remove(instr);
                progress = true;
                break;
@@ -193,7 +193,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, 
unsigned wave_size)
 
                nir_def *r = nir_replicate(&b, elem, radv_nir_cmat_length(desc, 
wave_size));
 
-               nir_store_deref(&b, dst_deref, r, 0xffff);
+               nir_store_deref(&b, dst_deref, r, 
nir_component_mask(r->num_components));
                nir_instr_remove(instr);
                progress = true;
                break;
@@ -253,7 +253,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, 
unsigned wave_size)
                }
 
                nir_def *mat = nir_vec(&b, vars, length);
-               nir_store_deref(&b, dst_deref, mat, 0xffff);
+               nir_store_deref(&b, dst_deref, mat, 
nir_component_mask(mat->num_components));
                nir_instr_remove(instr);
                progress = true;
                break;
@@ -332,7 +332,8 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, 
unsigned wave_size)
                ret = nir_cmat_muladd_amd(&b, A, B, C, .saturate = 
nir_intrinsic_saturate(intr),
                                          .cmat_signed_mask = 
nir_intrinsic_cmat_signed_mask(intr));
 
-               nir_store_deref(&b, 
nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret, 0xffff);
+               nir_store_deref(&b, 
nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
+                               nir_component_mask(ret->num_components));
                nir_instr_remove(instr);
                progress = true;
                break;
@@ -366,7 +367,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, 
unsigned wave_size)
                   ret = nir_vec(&b, components, ret->num_components * 2);
                }
 
-               nir_store_deref(&b, dst_deref, ret, 0xffff);
+               nir_store_deref(&b, dst_deref, ret, 
nir_component_mask(ret->num_components));
                nir_instr_remove(instr);
                progress = true;
                break;
@@ -375,7 +376,8 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, 
unsigned wave_size)
                nir_def *src1 = radv_nir_load_cmat(&b, wave_size, 
intr->src[1].ssa);
                nir_op op = nir_intrinsic_alu_op(intr);
                nir_def *ret = nir_build_alu2(&b, op, src1, intr->src[2].ssa);
-               nir_store_deref(&b, 
nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret, 0xffff);
+               nir_store_deref(&b, 
nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
+                               nir_component_mask(ret->num_components));
                nir_instr_remove(instr);
                progress = true;
                break;
@@ -385,14 +387,16 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, 
unsigned wave_size)
                nir_def *src2 = radv_nir_load_cmat(&b, wave_size, 
intr->src[2].ssa);
                nir_op op = nir_intrinsic_alu_op(intr);
                nir_def *ret = nir_build_alu2(&b, op, src1, src2);
-               nir_store_deref(&b, 
nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret, 0xffff);
+               nir_store_deref(&b, 
nir_instr_as_deref(intr->src[0].ssa->parent_instr), ret,
+                               nir_component_mask(ret->num_components));
                nir_instr_remove(instr);
                progress = true;
                break;
             }
             case nir_intrinsic_cmat_bitcast: {
                nir_def *src1 = radv_nir_load_cmat(&b, wave_size, 
intr->src[1].ssa);
-               nir_store_deref(&b, 
nir_instr_as_deref(intr->src[0].ssa->parent_instr), src1, 0xffff);
+               nir_store_deref(&b, 
nir_instr_as_deref(intr->src[0].ssa->parent_instr), src1,
+                               nir_component_mask(src1->num_components));
                nir_instr_remove(instr);
                progress = true;
                break;

Reply via email to