Module: Mesa
Branch: main
Commit: 44e82a6cf11317c2486f8470839cf9c8789326fc
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=44e82a6cf11317c2486f8470839cf9c8789326fc
Author: Samuel Pitoiset <[email protected]>
Date:   Fri Aug 4 17:43:20 2023 +0200

radv: add a helper to get the maximum number of scratch waves per shader

Signed-off-by: Samuel Pitoiset <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24502>

---

 src/amd/vulkan/radv_pipeline.c | 4 +---
 src/amd/vulkan/radv_shader.c   | 8 ++++++++
 src/amd/vulkan/radv_shader.h   | 2 ++
 3 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/src/amd/vulkan/radv_pipeline.c b/src/amd/vulkan/radv_pipeline.c
index 2420dc18497..203264b6851 100644
--- a/src/amd/vulkan/radv_pipeline.c
+++ b/src/amd/vulkan/radv_pipeline.c
@@ -134,9 +134,7 @@ radv_pipeline_init_scratch(const struct radv_device *device, struct radv_pipelin
 
    pipeline->scratch_bytes_per_wave = MAX2(pipeline->scratch_bytes_per_wave, shader->config.scratch_bytes_per_wave);
 
-   unsigned max_stage_waves = device->scratch_waves;
-   max_stage_waves = MIN2(max_stage_waves, 4 * device->physical_device->rad_info.num_cu *
-                                              radv_get_max_waves(device, shader, shader->info.stage));
+   const unsigned max_stage_waves = radv_get_max_scratch_waves(device, shader);
    pipeline->max_waves = MAX2(pipeline->max_waves, max_stage_waves);
 }
 
diff --git a/src/amd/vulkan/radv_shader.c b/src/amd/vulkan/radv_shader.c
index ea001f7f0f9..dcdfacb0976 100644
--- a/src/amd/vulkan/radv_shader.c
+++ b/src/amd/vulkan/radv_shader.c
@@ -2832,6 +2832,14 @@ radv_get_max_waves(const struct radv_device *device, struct radv_shader *shader,
    return gfx_level >= GFX10 ? max_simd_waves * (wave_size / 32) : max_simd_waves;
 }
 
+unsigned
+radv_get_max_scratch_waves(const struct radv_device *device, struct radv_shader *shader)
+{
+   const unsigned num_cu = device->physical_device->rad_info.num_cu;
+
+   return MIN2(device->scratch_waves, 4 * num_cu * radv_get_max_waves(device, shader, shader->info.stage));
+}
+
 unsigned
 radv_compute_spi_ps_input(const struct radv_pipeline_key *pipeline_key, const struct radv_shader_info *info)
 {
diff --git a/src/amd/vulkan/radv_shader.h b/src/amd/vulkan/radv_shader.h
index 72436d4a1e1..6b02b14a94e 100644
--- a/src/amd/vulkan/radv_shader.h
+++ b/src/amd/vulkan/radv_shader.h
@@ -699,6 +699,8 @@ struct radv_shader *radv_find_shader(struct radv_device *device, uint64_t pc);
 
 unsigned radv_get_max_waves(const struct radv_device *device, struct radv_shader *shader, gl_shader_stage stage);
 
+unsigned radv_get_max_scratch_waves(const struct radv_device *device, struct radv_shader *shader);
+
 const char *radv_get_shader_name(const struct radv_shader_info *info, gl_shader_stage stage);
 
 unsigned radv_compute_spi_ps_input(const struct radv_pipeline_key *pipeline_key, const struct radv_shader_info *info);
