Module: Mesa
Branch: main
Commit: b2067001d43fd83ae4942fd2412f6a75b3769167
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=b2067001d43fd83ae4942fd2412f6a75b3769167

Author: Friedrich Vock <[email protected]>
Date:   Thu Nov 16 14:34:24 2023 +0100

radv/rt: bsearch inlined shaders

When there are lots of inlined shaders, going over each one and checking
if the call index matches becomes expensive. Instead, use a
binary-search-like selection to skip most of the checks.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26380>

---

 src/amd/vulkan/radv_rt_shader.c | 54 +++++++++++++++++++++++++++++++++++------
 1 file changed, 47 insertions(+), 7 deletions(-)

diff --git a/src/amd/vulkan/radv_rt_shader.c b/src/amd/vulkan/radv_rt_shader.c
index fbc0dbc790f..9ce5bc202c3 100644
--- a/src/amd/vulkan/radv_rt_shader.c
+++ b/src/amd/vulkan/radv_rt_shader.c
@@ -40,6 +40,9 @@
 
 #define RADV_RT_SWITCH_NULL_CHECK_THRESHOLD 3
 
+/* Minimum number of inlined shaders to use binary search to select which shader to run. */
+#define INLINED_SHADER_BSEARCH_THRESHOLD 16
+
 struct radv_rt_case_data {
    struct radv_device *device;
    struct radv_ray_tracing_pipeline *pipeline;
@@ -51,12 +54,44 @@ typedef void (*radv_get_group_info)(struct radv_ray_tracing_group *, uint32_t *,
 typedef void (*radv_insert_shader_case)(nir_builder *, nir_def *, struct radv_ray_tracing_group *,
                                         struct radv_rt_case_data *);
 
+struct inlined_shader_case {
+   struct radv_ray_tracing_group *group;
+   uint32_t call_idx;
+};
+
+static int
+compare_inlined_shader_case(const void *a, const void *b)
+{
+   const struct inlined_shader_case *visit_a = a;
+   const struct inlined_shader_case *visit_b = b;
+   return visit_a->call_idx > visit_b->call_idx ? 1 : visit_a->call_idx < visit_b->call_idx ? -1 : 0;
+}
+
+static void
+insert_inlined_range(nir_builder *b, nir_def *sbt_idx, radv_insert_shader_case shader_case,
+                     struct radv_rt_case_data *data, struct inlined_shader_case *cases, uint32_t length)
+{
+   if (length >= INLINED_SHADER_BSEARCH_THRESHOLD) {
+      nir_push_if(b, nir_ige_imm(b, sbt_idx, cases[length / 2].call_idx));
+      {
+         insert_inlined_range(b, sbt_idx, shader_case, data, cases + (length / 2), length - (length / 2));
+      }
+      nir_push_else(b, NULL);
+      {
+         insert_inlined_range(b, sbt_idx, shader_case, data, cases, length / 2);
+      }
+      nir_pop_if(b, NULL);
+   } else {
+      for (uint32_t i = 0; i < length; ++i)
+         shader_case(b, sbt_idx, cases[i].group, data);
+   }
+}
+
 static void
 radv_visit_inlined_shaders(nir_builder *b, nir_def *sbt_idx, bool can_have_null_shaders, struct radv_rt_case_data *data,
                            radv_get_group_info group_info, radv_insert_shader_case shader_case)
 {
-   struct radv_ray_tracing_group **groups =
-      calloc(data->pipeline->group_count, sizeof(struct radv_ray_tracing_group *));
+   struct inlined_shader_case *cases = calloc(data->pipeline->group_count, sizeof(struct inlined_shader_case));
    uint32_t case_count = 0;
 
    for (unsigned i = 0; i < data->pipeline->group_count; i++) {
@@ -81,23 +116,28 @@ radv_visit_inlined_shaders(nir_builder *b, nir_def *sbt_idx, bool can_have_null_
          }
       }
 
-      if (!duplicate)
-         groups[case_count++] = group;
+      if (!duplicate) {
+         cases[case_count++] = (struct inlined_shader_case){
+            .group = group,
+            .call_idx = handle_index,
+         };
+      }
    }
 
+   qsort(cases, case_count, sizeof(struct inlined_shader_case), compare_inlined_shader_case);
+
    /* Do not emit 'if (sbt_idx != 0) { ... }' is there are only a few cases. */
    can_have_null_shaders &= case_count >= RADV_RT_SWITCH_NULL_CHECK_THRESHOLD;
 
    if (can_have_null_shaders)
       nir_push_if(b, nir_ine_imm(b, sbt_idx, 0));
 
-   for (unsigned i = 0; i < case_count; i++)
-      shader_case(b, sbt_idx, groups[i], data);
+   insert_inlined_range(b, sbt_idx, shader_case, data, cases, case_count);
 
    if (can_have_null_shaders)
       nir_pop_if(b, NULL);
 
-   free(groups);
+   free(cases);
 }
 
 static bool

Reply via email to