Module: Mesa
Branch: main
Commit: 4035b0fa42ff886c4b31656ea9c1e1f347b16ba3
URL: http://cgit.freedesktop.org/mesa/mesa/commit/?id=4035b0fa42ff886c4b31656ea9c1e1f347b16ba3

Author: Konstantin Seurer <[email protected]>
Date:   Wed Jul 20 19:23:11 2022 +0200

radv: Use a lds stack for ray queries when possible

Signed-off-by: Konstantin Seurer <[email protected]>
Reviewed-by: Bas Nieuwenhuizen <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/17663>

---

 src/amd/vulkan/radv_nir_lower_ray_queries.c | 70 ++++++++++++++++++++++-------
 src/amd/vulkan/radv_shader.c                | 10 +++++
 2 files changed, 64 insertions(+), 16 deletions(-)

diff --git a/src/amd/vulkan/radv_nir_lower_ray_queries.c b/src/amd/vulkan/radv_nir_lower_ray_queries.c
index e6786b09bcd..579037984cc 100644
--- a/src/amd/vulkan/radv_nir_lower_ray_queries.c
+++ b/src/amd/vulkan/radv_nir_lower_ray_queries.c
@@ -34,7 +34,8 @@
 /* Traversal stack size. Traversal supports backtracking so we can go deeper than this size if
  * needed. However, we keep a large stack size to avoid it being put into registers, which hurts
  * occupancy. */
-#define MAX_STACK_ENTRY_COUNT 76
+#define MAX_SCRATCH_STACK_ENTRY_COUNT 76
+#define MAX_SHARED_STACK_ENTRY_COUNT  8
 
 typedef struct {
    nir_variable *variable;
@@ -176,6 +177,7 @@ struct ray_query_vars {
    struct ray_query_traversal_vars trav;
 
    rq_variable *stack;
+   uint32_t shared_base;
 };
 
 #define VAR_NAME(name)                                                                             \
@@ -244,7 +246,7 @@ init_ray_query_intersection_vars(void *ctx, nir_shader *shader, unsigned array_l
 
 static void
 init_ray_query_vars(nir_shader *shader, unsigned array_length, struct ray_query_vars *dst,
-                    const char *base_name)
+                    const char *base_name, uint32_t max_shared_size)
 {
    void *ctx = dst;
    const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
@@ -268,16 +270,27 @@ init_ray_query_vars(nir_shader *shader, unsigned array_length, struct ray_query_
 
    dst->trav = init_ray_query_traversal_vars(dst, shader, array_length, VAR_NAME("_top"));
 
-   dst->stack = rq_variable_create(dst, shader, array_length,
-                                   glsl_array_type(glsl_uint_type(), MAX_STACK_ENTRY_COUNT,
-                                                   glsl_get_explicit_stride(glsl_uint_type())),
-                                   VAR_NAME("_stack"));
+   uint32_t workgroup_size = shader->info.workgroup_size[0] * shader->info.workgroup_size[1] *
+                             shader->info.workgroup_size[2];
+   uint32_t shared_stack_size = workgroup_size * MAX_SHARED_STACK_ENTRY_COUNT * 4;
+   uint32_t shared_offset = align(shader->info.shared_size, 4);
+   if (shader->info.stage != MESA_SHADER_COMPUTE || array_length > 1 ||
+       shared_offset + shared_stack_size > max_shared_size) {
+      dst->stack = rq_variable_create(
+         dst, shader, array_length,
+         glsl_array_type(glsl_uint_type(), MAX_SCRATCH_STACK_ENTRY_COUNT, 0), VAR_NAME("_stack"));
+   } else {
+      dst->stack = NULL;
+      dst->shared_base = shared_offset;
+      shader->info.shared_size = shared_offset + shared_stack_size;
+   }
 }
 
 #undef VAR_NAME
 
 static void
-lower_ray_query(nir_shader *shader, nir_variable *ray_query, struct hash_table *ht)
+lower_ray_query(nir_shader *shader, nir_variable *ray_query, struct hash_table *ht,
+                uint32_t max_shared_size)
 {
    struct ray_query_vars *vars = ralloc(ht, struct ray_query_vars);
 
@@ -285,7 +298,8 @@ lower_ray_query(nir_shader *shader, nir_variable *ray_query, struct hash_table *
    if (glsl_type_is_array(ray_query->type))
       array_length = glsl_get_length(ray_query->type);
 
-   init_ray_query_vars(shader, array_length, vars, ray_query->name == NULL ? "" : ray_query->name);
+   init_ray_query_vars(shader, array_length, vars, ray_query->name == NULL ? "" : ray_query->name,
+                       max_shared_size);
 
    _mesa_hash_table_insert(ht, ray_query, vars);
 }
@@ -385,7 +399,17 @@ lower_rq_initialize(nir_builder *b, nir_ssa_def *index, nir_intrinsic_instr *ins
       rq_store_var(b, index, vars->root_bvh_base, bvh_base, 0x1);
       rq_store_var(b, index, vars->trav.bvh_base, bvh_base, 1);
 
-      rq_store_var(b, index, vars->trav.stack, nir_imm_int(b, 0), 0x1);
+      if (vars->stack) {
+         rq_store_var(b, index, vars->trav.stack, nir_imm_int(b, 0), 0x1);
+         rq_store_var(b, index, vars->trav.stack_base, nir_imm_int(b, 0), 0x1);
+      } else {
+         nir_ssa_def *base_offset =
+            nir_imul_imm(b, nir_load_local_invocation_index(b), sizeof(uint32_t));
+         base_offset = nir_iadd_imm(b, base_offset, vars->shared_base);
+         rq_store_var(b, index, vars->trav.stack, base_offset, 0x1);
+         rq_store_var(b, index, vars->trav.stack_base, base_offset, 0x1);
+      }
+
       rq_store_var(b, index, vars->trav.current_node, nir_imm_int(b, RADV_BVH_ROOT_NODE), 0x1);
       rq_store_var(b, index, vars->trav.previous_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
       rq_store_var(b, index, vars->trav.instance_top_node, nir_imm_int(b, RADV_BVH_INVALID_NODE),
@@ -393,7 +417,6 @@ lower_rq_initialize(nir_builder *b, nir_ssa_def *index, nir_intrinsic_instr *ins
       rq_store_var(b, index, vars->trav.instance_bottom_node, nir_imm_int(b, RADV_BVH_NO_INSTANCE_ROOT), 0x1);
 
       rq_store_var(b, index, vars->trav.top_stack, nir_imm_int(b, -1), 1);
-      rq_store_var(b, index, vars->trav.stack_base, nir_imm_int(b, 0), 1);
    }
    nir_push_else(b, NULL);
    {
@@ -614,14 +637,20 @@ store_stack_entry(nir_builder *b, nir_ssa_def *index, nir_ssa_def *value,
                   const struct radv_ray_traversal_args *args)
 {
    struct traversal_data *data = args->data;
-   rq_store_array(b, data->index, data->vars->stack, index, value, 1);
+   if (data->vars->stack)
+      rq_store_array(b, data->index, data->vars->stack, index, value, 1);
+   else
+      nir_store_shared(b, value, index, .base = 0, .align_mul = 4);
 }
 
 static nir_ssa_def *
 load_stack_entry(nir_builder *b, nir_ssa_def *index, const struct radv_ray_traversal_args *args)
 {
    struct traversal_data *data = args->data;
-   return rq_load_array(b, data->index, data->vars->stack, index);
+   if (data->vars->stack)
+      return rq_load_array(b, data->index, data->vars->stack, index);
+   else
+      return nir_load_shared(b, 1, 32, index, .base = 0, .align_mul = 4);
 }
 
 static nir_ssa_def *
@@ -658,8 +687,6 @@ lower_rq_proceed(nir_builder *b, nir_ssa_def *index, struct ray_query_vars *vars
       .tmin = rq_load_var(b, index, vars->tmin),
       .dir = rq_load_var(b, index, vars->direction),
       .vars = trav_vars,
-      .stack_stride = 1,
-      .stack_entries = MAX_STACK_ENTRY_COUNT,
       .stack_store_cb = store_stack_entry,
       .stack_load_cb = load_stack_entry,
       .aabb_cb = handle_candidate_aabb,
@@ -667,6 +694,17 @@ lower_rq_proceed(nir_builder *b, nir_ssa_def *index, struct ray_query_vars *vars
       .data = &data,
    };
 
+   if (vars->stack) {
+      args.stack_stride = 1;
+      args.stack_entries = MAX_SCRATCH_STACK_ENTRY_COUNT;
+   } else {
+      uint32_t workgroup_size = b->shader->info.workgroup_size[0] *
+                                b->shader->info.workgroup_size[1] *
+                                b->shader->info.workgroup_size[2];
+      args.stack_stride = workgroup_size * 4;
+      args.stack_entries = MAX_SHARED_STACK_ENTRY_COUNT;
+   }
+
    nir_push_if(b, rq_load_var(b, index, vars->incomplete));
    {
       nir_ssa_def *incomplete = radv_build_ray_traversal(device, b, &args);
@@ -695,7 +733,7 @@ radv_nir_lower_ray_queries(struct nir_shader *shader, struct radv_device *device
       if (!var->data.ray_query)
          continue;
 
-      lower_ray_query(shader, var, query_ht);
+      lower_ray_query(shader, var, query_ht, device->physical_device->max_shared_size);
       contains_ray_query = true;
    }
 
@@ -710,7 +748,7 @@ radv_nir_lower_ray_queries(struct nir_shader *shader, struct radv_device *device
          if (!var->data.ray_query)
             continue;
 
-         lower_ray_query(shader, var, query_ht);
+         lower_ray_query(shader, var, query_ht, device->physical_device->max_shared_size);
          contains_ray_query = true;
       }
 
diff --git a/src/amd/vulkan/radv_shader.c b/src/amd/vulkan/radv_shader.c
index 8a6e7c0419f..6f7b6284948 100644
--- a/src/amd/vulkan/radv_shader.c
+++ b/src/amd/vulkan/radv_shader.c
@@ -977,6 +977,16 @@ radv_shader_spirv_to_nir(struct radv_device *device, const struct radv_pipeline_
    nir_shader_gather_info(nir, nir_shader_get_entrypoint(nir));
 
    if (nir->info.ray_queries > 0) {
+      /* Lower shared variables early to prevent the over allocation of shared memory in
+       * radv_nir_lower_ray_queries.  */
+      if (nir->info.stage == MESA_SHADER_COMPUTE) {
+         if (!nir->info.shared_memory_explicit_layout)
+            NIR_PASS(_, nir, nir_lower_vars_to_explicit_types, nir_var_mem_shared, shared_var_info);
+
+         NIR_PASS(_, nir, nir_lower_explicit_io, nir_var_mem_shared,
+                  nir_address_format_32bit_offset);
+      }
+
       NIR_PASS(_, nir, nir_opt_ray_queries);
       NIR_PASS(_, nir, nir_opt_ray_query_ranges);
       NIR_PASS(_, nir, radv_nir_lower_ray_queries, device);

Reply via email to