Module: Mesa
Branch: main
Commit: 59dbe633e3d243d65c06624345995d93758ee007
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=59dbe633e3d243d65c06624345995d93758ee007

Author: Rhys Perry <pendingchao...@gmail.com>
Date:   Wed Jan  3 17:52:52 2024 +0000

radv: use CS wave selection for task shaders

This uses wave32 for small workgroups and wave64 when certain subgroup
operations are used.

Signed-off-by: Rhys Perry <pendingchao...@gmail.com>
Reviewed-by: Daniel Schürmann <dan...@schuermann.dev>
Reviewed-by: Samuel Pitoiset <samuel.pitoi...@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26894>

---

 src/amd/vulkan/radv_shader_info.c | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/amd/vulkan/radv_shader_info.c b/src/amd/vulkan/radv_shader_info.c
index 9c74669ecc0..4270beab009 100644
--- a/src/amd/vulkan/radv_shader_info.c
+++ b/src/amd/vulkan/radv_shader_info.c
@@ -346,12 +346,10 @@ radv_get_wave_size(struct radv_device *device, gl_shader_stage stage, const stru
 
    if (stage == MESA_SHADER_GEOMETRY && !info->is_ngg)
       return 64;
-   else if (stage == MESA_SHADER_COMPUTE)
+   else if (stage == MESA_SHADER_COMPUTE || stage == MESA_SHADER_TASK)
       return info->cs.subgroup_size;
    else if (stage == MESA_SHADER_FRAGMENT)
       return device->physical_device->ps_wave_size;
-   else if (stage == MESA_SHADER_TASK)
-      return device->physical_device->cs_wave_size;
    else if (gl_shader_stage_is_rt(stage))
       return device->physical_device->rt_wave_size;
    else
@@ -933,10 +931,10 @@ gather_shader_info_cs(struct radv_device *device, const nir_shader *nir, const s
     * the subgroup size.
     */
    const bool require_full_subgroups =
-      pipeline_key->stage_info[MESA_SHADER_COMPUTE].subgroup_require_full || nir->info.cs.has_cooperative_matrix ||
+      pipeline_key->stage_info[nir->info.stage].subgroup_require_full || nir->info.cs.has_cooperative_matrix ||
       (default_wave_size == 32 && nir->info.uses_wide_subgroup_intrinsics && local_size % RADV_SUBGROUP_SIZE == 0);
 
-   const unsigned required_subgroup_size = pipeline_key->stage_info[MESA_SHADER_COMPUTE].subgroup_required_size * 32;
+   const unsigned required_subgroup_size = pipeline_key->stage_info[nir->info.stage].subgroup_required_size * 32;
 
    if (required_subgroup_size) {
       info->cs.subgroup_size = required_subgroup_size;
@@ -955,9 +953,11 @@ gather_shader_info_cs(struct radv_device *device, const nir_shader *nir, const s
 }
 
 static void
-gather_shader_info_task(const nir_shader *nir, const struct radv_pipeline_key *pipeline_key,
+gather_shader_info_task(struct radv_device *device, const nir_shader *nir, const struct radv_pipeline_key *pipeline_key,
                         struct radv_shader_info *info)
 {
+   gather_shader_info_cs(device, nir, pipeline_key, info);
+
    /* Task shaders always need these for the I/O lowering even if the API shader doesn't actually
     * use them.
     */
@@ -1196,7 +1196,7 @@ radv_nir_shader_info_pass(struct radv_device *device, const struct nir_shader *n
       gather_shader_info_cs(device, nir, pipeline_key, info);
       break;
    case MESA_SHADER_TASK:
-      gather_shader_info_task(nir, pipeline_key, info);
+      gather_shader_info_task(device, nir, pipeline_key, info);
       break;
    case MESA_SHADER_FRAGMENT:
       gather_shader_info_fs(device, nir, pipeline_key, info);

Reply via email to