Module: Mesa
Branch: main
Commit: 44e82a6cf11317c2486f8470839cf9c8789326fc
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=44e82a6cf11317c2486f8470839cf9c8789326fc

Author: Samuel Pitoiset <[email protected]>
Date:   Fri Aug  4 17:43:20 2023 +0200

radv: add a helper to get the maximum number of scratch waves per shader

Signed-off-by: Samuel Pitoiset <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24502>

---

 src/amd/vulkan/radv_pipeline.c | 4 +---
 src/amd/vulkan/radv_shader.c   | 8 ++++++++
 src/amd/vulkan/radv_shader.h   | 2 ++
 3 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/src/amd/vulkan/radv_pipeline.c b/src/amd/vulkan/radv_pipeline.c
index 2420dc18497..203264b6851 100644
--- a/src/amd/vulkan/radv_pipeline.c
+++ b/src/amd/vulkan/radv_pipeline.c
@@ -134,9 +134,7 @@ radv_pipeline_init_scratch(const struct radv_device *device, struct radv_pipelin
 
    pipeline->scratch_bytes_per_wave = MAX2(pipeline->scratch_bytes_per_wave, shader->config.scratch_bytes_per_wave);
 
-   unsigned max_stage_waves = device->scratch_waves;
-   max_stage_waves = MIN2(max_stage_waves, 4 * device->physical_device->rad_info.num_cu *
-                                              radv_get_max_waves(device, shader, shader->info.stage));
+   const unsigned max_stage_waves = radv_get_max_scratch_waves(device, shader);
    pipeline->max_waves = MAX2(pipeline->max_waves, max_stage_waves);
 }
 
diff --git a/src/amd/vulkan/radv_shader.c b/src/amd/vulkan/radv_shader.c
index ea001f7f0f9..dcdfacb0976 100644
--- a/src/amd/vulkan/radv_shader.c
+++ b/src/amd/vulkan/radv_shader.c
@@ -2832,6 +2832,14 @@ radv_get_max_waves(const struct radv_device *device, struct radv_shader *shader,
    return gfx_level >= GFX10 ? max_simd_waves * (wave_size / 32) : max_simd_waves;
 }
 
+unsigned
+radv_get_max_scratch_waves(const struct radv_device *device, struct radv_shader *shader)
+{
+   const unsigned num_cu = device->physical_device->rad_info.num_cu;
+
+   return MIN2(device->scratch_waves, 4 * num_cu * radv_get_max_waves(device, shader, shader->info.stage));
+}
+
 unsigned
 radv_compute_spi_ps_input(const struct radv_pipeline_key *pipeline_key, const struct radv_shader_info *info)
 {
diff --git a/src/amd/vulkan/radv_shader.h b/src/amd/vulkan/radv_shader.h
index 72436d4a1e1..6b02b14a94e 100644
--- a/src/amd/vulkan/radv_shader.h
+++ b/src/amd/vulkan/radv_shader.h
@@ -699,6 +699,8 @@ struct radv_shader *radv_find_shader(struct radv_device *device, uint64_t pc);
 
 unsigned radv_get_max_waves(const struct radv_device *device, struct radv_shader *shader, gl_shader_stage stage);
 
+unsigned radv_get_max_scratch_waves(const struct radv_device *device, struct radv_shader *shader);
+
 const char *radv_get_shader_name(const struct radv_shader_info *info, gl_shader_stage stage);
 
 unsigned radv_compute_spi_ps_input(const struct radv_pipeline_key *pipeline_key, const struct radv_shader_info *info);

Reply via email to