Module: Mesa
Branch: main
Commit: 037bbabcb968b8a911e90ce61c202c76d3cc7a67
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=037bbabcb968b8a911e90ce61c202c76d3cc7a67

Author: Mike Blumenkrantz <[email protected]>
Date:   Mon Oct 17 10:11:08 2022 -0400

zink: pass KERNEL shaders through successfully

basically just merging with COMPUTE cases

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/19327>

---

 src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c |  6 ++++--
 src/gallium/drivers/zink/zink_compiler.c             | 17 +++++++++--------
 src/gallium/drivers/zink/zink_compiler.h             |  5 +++++
 src/gallium/drivers/zink/zink_descriptors.c          |  5 +++--
 4 files changed, 21 insertions(+), 12 deletions(-)

diff --git a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c
index a0a93abb1fe..d640ed72526 100644
--- a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c
+++ b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c
@@ -3335,7 +3335,7 @@ emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
       break;
 
    case nir_intrinsic_control_barrier:
-      if (ctx->stage == MESA_SHADER_COMPUTE)
+      if (gl_shader_stage_is_compute(ctx->stage))
          spirv_builder_emit_control_barrier(&ctx->builder, SpvScopeWorkgroup,
                                             SpvScopeWorkgroup,
                                             
SpvMemorySemanticsWorkgroupMemoryMask | SpvMemorySemanticsAcquireReleaseMask);
@@ -4428,7 +4428,7 @@ nir_to_spirv(struct nir_shader *s, const struct zink_shader_info *sinfo, uint32_
    ctx.explicit_lod = true;
    spirv_builder_emit_source(&ctx.builder, SpvSourceLanguageUnknown, 0);
 
-   if (s->info.stage == MESA_SHADER_COMPUTE) {
+   if (gl_shader_stage_is_compute(s->info.stage)) {
       SpvAddressingModel model;
       if (s->info.cs.ptr_size == 32)
          model = SpvAddressingModelPhysical32;
@@ -4474,6 +4474,7 @@ nir_to_spirv(struct nir_shader *s, const struct zink_shader_info *sinfo, uint32_
       exec_model = SpvExecutionModelFragment;
       break;
    case MESA_SHADER_COMPUTE:
+   case MESA_SHADER_KERNEL:
       exec_model = SpvExecutionModelGLCompute;
       break;
    default:
@@ -4597,6 +4598,7 @@ nir_to_spirv(struct nir_shader *s, const struct zink_shader_info *sinfo, uint32_
                                            SpvExecutionModeOutputVertices,
                                            MAX2(s->info.gs.vertices_out, 1));
       break;
+   case MESA_SHADER_KERNEL:
    case MESA_SHADER_COMPUTE:
       if (s->info.workgroup_size[0] || s->info.workgroup_size[1] || 
s->info.workgroup_size[2])
          spirv_builder_emit_exec_mode_literal3(&ctx.builder, entry_point, 
SpvExecutionModeLocalSize,
diff --git a/src/gallium/drivers/zink/zink_compiler.c b/src/gallium/drivers/zink/zink_compiler.c
index ddbe583f6cf..ddd36838a04 100644
--- a/src/gallium/drivers/zink/zink_compiler.c
+++ b/src/gallium/drivers/zink/zink_compiler.c
@@ -2245,7 +2245,7 @@ zink_shader_spirv_compile(struct zink_screen *screen, struct zink_shader *zs, st
       }
       nir_shader *nir = spirv_to_nir(spirv->words, spirv->num_words,
                          spec_entries, num_spec_entries,
-                         zs->nir->info.stage, "main", &spirv_options, 
&screen->nir_options);
+                         clamp_stage(zs->nir), "main", &spirv_options, 
&screen->nir_options);
       assert(nir);
       ralloc_free(nir);
       free(spec_entries);
@@ -2791,7 +2791,7 @@ zink_binding(gl_shader_stage stage, VkDescriptorType type, int index, bool compa
    } else {
       unsigned base = stage;
       /* clamp compute bindings for better driver efficiency */
-      if (stage == MESA_SHADER_COMPUTE)
+      if (gl_shader_stage_is_compute(stage))
          base = 0;
       switch (type) {
       case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
@@ -3263,7 +3263,7 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir,
       subgroup_options.ballot_bit_size = 32;
       subgroup_options.ballot_components = 4;
       subgroup_options.lower_subgroup_masks = true;
-      if (!(screen->info.subgroup.supportedStages & 
mesa_to_vk_shader_stage(nir->info.stage))) {
+      if (!(screen->info.subgroup.supportedStages & 
mesa_to_vk_shader_stage(clamp_stage(nir)))) {
          subgroup_options.subgroup_size = 1;
          subgroup_options.lower_vote_trivial = true;
       }
@@ -3325,8 +3325,8 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir,
             ztype = ZINK_DESCRIPTOR_TYPE_UBO;
             /* buffer 0 is a push descriptor */
             var->data.descriptor_set = !!var->data.driver_location;
-            var->data.binding = !var->data.driver_location ? nir->info.stage :
-                                zink_binding(nir->info.stage,
+            var->data.binding = !var->data.driver_location ? clamp_stage(nir) :
+                                zink_binding(clamp_stage(nir),
                                              VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
                                              var->data.driver_location,
                                              screen->compact_descriptors);
@@ -3347,7 +3347,7 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir,
          } else if (var->data.mode == nir_var_mem_ssbo) {
             ztype = ZINK_DESCRIPTOR_TYPE_SSBO;
             var->data.descriptor_set = screen->desc_set_id[ztype];
-            var->data.binding = zink_binding(nir->info.stage,
+            var->data.binding = zink_binding(clamp_stage(nir),
                                              VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
                                              var->data.driver_location,
                                              screen->compact_descriptors);
@@ -3370,7 +3370,7 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir,
                   ret->num_texel_buffers++;
                var->data.driver_location = var->data.binding;
                var->data.descriptor_set = screen->desc_set_id[ztype];
-               var->data.binding = zink_binding(nir->info.stage, vktype, 
var->data.driver_location, screen->compact_descriptors);
+               var->data.binding = zink_binding(clamp_stage(nir), vktype, 
var->data.driver_location, screen->compact_descriptors);
                ret->bindings[ztype][ret->num_bindings[ztype]].index = 
var->data.driver_location;
                ret->bindings[ztype][ret->num_bindings[ztype]].binding = 
var->data.binding;
                ret->bindings[ztype][ret->num_bindings[ztype]].type = vktype;
@@ -3389,7 +3389,8 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir,
 
    if (!screen->info.feats.features.shaderInt64 || 
!screen->info.feats.features.shaderFloat64)
       NIR_PASS_V(nir, lower_64bit_vars, 
screen->info.feats.features.shaderInt64);
-   NIR_PASS_V(nir, match_tex_dests);
+   if (nir->info.stage != MESA_SHADER_KERNEL)
+      NIR_PASS_V(nir, match_tex_dests);
 
    ret->nir = nir;
    nir_foreach_shader_out_variable(var, nir)
diff --git a/src/gallium/drivers/zink/zink_compiler.h b/src/gallium/drivers/zink/zink_compiler.h
index 21f6bab1ef1..1572aa3b239 100644
--- a/src/gallium/drivers/zink/zink_compiler.h
+++ b/src/gallium/drivers/zink/zink_compiler.h
@@ -40,6 +40,11 @@ struct spirv_shader;
 
 struct tgsi_token;
 
+static inline gl_shader_stage
+clamp_stage(nir_shader *nir)
+{
+   return nir->info.stage == MESA_SHADER_KERNEL ? MESA_SHADER_COMPUTE : 
nir->info.stage;
+}
 
 const void *
 zink_get_compiler_options(struct pipe_screen *screen,
diff --git a/src/gallium/drivers/zink/zink_descriptors.c b/src/gallium/drivers/zink/zink_descriptors.c
index e9caa26d313..00b07465a75 100644
--- a/src/gallium/drivers/zink/zink_descriptors.c
+++ b/src/gallium/drivers/zink/zink_descriptors.c
@@ -26,6 +26,7 @@
  */
 
 #include "zink_context.h"
+#include "zink_compiler.h"
 #include "zink_descriptors.h"
 #include "zink_program.h"
 #include "zink_render_pass.h"
@@ -308,7 +309,7 @@ init_template_entry(struct zink_shader *shader, enum zink_descriptor_type type,
                     unsigned idx, VkDescriptorUpdateTemplateEntry *entry, 
unsigned *entry_idx)
 {
     int index = shader->bindings[type][idx].index;
-    gl_shader_stage stage = shader->nir->info.stage;
+    gl_shader_stage stage = clamp_stage(shader->nir);
     entry->dstArrayElement = 0;
     entry->dstBinding = shader->bindings[type][idx].binding;
     entry->descriptorCount = shader->bindings[type][idx].size;
@@ -423,7 +424,7 @@ zink_descriptor_program_init(struct zink_context *ctx, struct zink_program *pg)
       if (!shader)
          continue;
 
-      gl_shader_stage stage = shader->nir->info.stage;
+      gl_shader_stage stage = clamp_stage(shader->nir);
       VkShaderStageFlagBits stage_flags = mesa_to_vk_shader_stage(stage);
       /* uniform ubos handled in push */
       if (shader->has_uniforms) {

Reply via email to