Module: Mesa
Branch: main
Commit: ea3119353218e5c0344ecef39ca26d0f43864d70
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=ea3119353218e5c0344ecef39ca26d0f43864d70

Author: Samuel Pitoiset <[email protected]>
Date:   Fri Aug  4 17:48:10 2023 +0200

radv: update the number of scratch waves for RT prolog at bind time

The compute scratch size is still computed later because the RT stack size
can be dynamic, but the number of scratch waves shouldn't change, so it can
be updated when the pipeline is bound.

Signed-off-by: Samuel Pitoiset <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24502>

---

 src/amd/vulkan/radv_cmd_buffer.c | 3 +++
 src/amd/vulkan/radv_pipeline.c   | 2 +-
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/amd/vulkan/radv_cmd_buffer.c b/src/amd/vulkan/radv_cmd_buffer.c
index 6397985c17e..d795c36934c 100644
--- a/src/amd/vulkan/radv_cmd_buffer.c
+++ b/src/amd/vulkan/radv_cmd_buffer.c
@@ -6570,6 +6570,9 @@ radv_CmdBindPipeline(VkCommandBuffer commandBuffer, 
VkPipelineBindPoint pipeline
       /* Bind the stack size when it's not dynamic. */
       if (rt_pipeline->stack_size != -1u)
          cmd_buffer->state.rt_stack_size = rt_pipeline->stack_size;
+
+      const unsigned max_scratch_waves = 
radv_get_max_scratch_waves(cmd_buffer->device, rt_pipeline->prolog);
+      cmd_buffer->compute_scratch_waves_wanted = 
MAX2(cmd_buffer->compute_scratch_waves_wanted, max_scratch_waves);
       break;
    }
    case VK_PIPELINE_BIND_POINT_GRAPHICS: {
diff --git a/src/amd/vulkan/radv_pipeline.c b/src/amd/vulkan/radv_pipeline.c
index 203264b6851..89de2271f2a 100644
--- a/src/amd/vulkan/radv_pipeline.c
+++ b/src/amd/vulkan/radv_pipeline.c
@@ -129,7 +129,7 @@ radv_DestroyPipeline(VkDevice _device, VkPipeline 
_pipeline, const VkAllocationC
 void
 radv_pipeline_init_scratch(const struct radv_device *device, struct 
radv_pipeline *pipeline, struct radv_shader *shader)
 {
-   if (!shader->config.scratch_bytes_per_wave && pipeline->type != 
RADV_PIPELINE_RAY_TRACING)
+   if (!shader->config.scratch_bytes_per_wave)
       return;
 
    pipeline->scratch_bytes_per_wave = MAX2(pipeline->scratch_bytes_per_wave, 
shader->config.scratch_bytes_per_wave);

Reply via email to