Module: Mesa
Branch: main
Commit: 6bc8567bb98e3fca1a786b334702bc99a21c56b0
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=6bc8567bb98e3fca1a786b334702bc99a21c56b0

Author: Faith Ekstrand <faith.ekstr...@collabora.com>
Date:   Wed Apr 19 11:52:33 2023 -0500

nir: Handle array-deref-of-vec in vars_to_ssa

Closes: https://gitlab.freedesktop.org/mesa/mesa/-/issues/7746
Reviewed-by: Karol Herbst <kher...@redhat.com>
Reviewed-by: Alyssa Rosenzweig <aly...@rosenzweig.io>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22580>

---

 src/compiler/nir/nir_lower_vars_to_ssa.c | 63 ++++++++++++++++++++++++++++----
 1 file changed, 56 insertions(+), 7 deletions(-)

diff --git a/src/compiler/nir/nir_lower_vars_to_ssa.c b/src/compiler/nir/nir_lower_vars_to_ssa.c
index b9adc17c8ae..505c232bf89 100644
--- a/src/compiler/nir/nir_lower_vars_to_ssa.c
+++ b/src/compiler/nir/nir_lower_vars_to_ssa.c
@@ -173,7 +173,11 @@ get_deref_node_recur(nir_deref_instr *deref,
       return parent->children[deref->strct.index];
 
    case nir_deref_type_array: {
-      if (nir_src_is_const(deref->arr.index)) {
+      if (glsl_type_is_vector_or_scalar(parent->type)) {
+         /* For an array deref of a vector, return the vector */
+         assert(glsl_type_is_vector(parent->type));
+         return parent;
+      } else if (nir_src_is_const(deref->arr.index)) {
          uint32_t index = nir_src_as_uint(deref->arr.index);
          /* This is possible if a loop unrolls and generates an
           * out-of-bounds offset.  We need to handle this at least
@@ -252,7 +256,8 @@ foreach_deref_node_worker(struct deref_node *node, nir_deref_instr **path,
                                      struct lower_variables_state *state),
                           struct lower_variables_state *state)
 {
-   if (*path == NULL) {
+   if (glsl_type_is_vector_or_scalar(node->type)) {
+      assert(*path == NULL || (*path)->deref_type == nir_deref_type_array);
       cb(node, state);
       return;
    }
@@ -266,6 +271,9 @@ foreach_deref_node_worker(struct deref_node *node, nir_deref_instr **path,
       return;
 
    case nir_deref_type_array: {
+      if (glsl_type_is_vector_or_scalar(node->type))
+         return;
+
       uint32_t index = nir_src_as_uint((*path)->arr.index);
 
       if (node->children[index]) {
@@ -330,6 +338,13 @@ path_may_be_aliased_node(struct deref_node *node, nir_deref_instr **path,
       }
 
    case nir_deref_type_array: {
+      /* If the node is a vector, we consider it to not be aliased by any
+       * indirects for the purposes of this pass.  We'll insert a pile of
+       * bcsel if needed to resolve indirects.
+       */
+      if (glsl_type_is_vector_or_scalar(node->type))
+         return false;
+
       if (!nir_src_is_const((*path)->arr.index))
          return true;
 
@@ -357,6 +372,10 @@ path_may_be_aliased_node(struct deref_node *node, nir_deref_instr **path,
 }
 
 /* Returns true if there are no indirects that can ever touch this deref.
+ *
+ * The one exception here is that we allow indirects which select components
+ * of vectors.  These are handled by this pass by inserting the requisite
+ * pile of bcsel().
  *
  * For example, if the given deref is a[6].foo, then any uses of a[i].foo
  * would cause this to return false, but a[i].bar would not affect it
@@ -563,6 +582,24 @@ lower_copies_to_load_store(struct deref_node *node,
    node->copies = NULL;
 }
 
+static nir_def *
+deref_vec_component(nir_deref_instr *deref)
+{
+   if (deref->deref_type != nir_deref_type_array) {
+      assert(glsl_type_is_vector_or_scalar(deref->type));
+      return NULL;
+   }
+
+   nir_deref_instr *parent = nir_deref_instr_parent(deref);
+   if (glsl_type_is_vector_or_scalar(parent->type)) {
+      assert(glsl_type_is_scalar(deref->type));
+      return deref->arr.index.ssa;
+   } else {
+      assert(glsl_type_is_vector_or_scalar(deref->type));
+      return NULL;
+   }
+}
+
 /* Performs variable renaming
  *
  * This algorithm is very similar to the one outlined in "Efficiently
@@ -621,7 +658,15 @@ rename_variables(struct lower_variables_state *state)
             val = nir_mov(&b, val);
 
             assert(val->bit_size == intrin->def.bit_size);
-            assert(val->num_components == intrin->def.num_components);
+
+            nir_def *comp = deref_vec_component(deref);
+            if (comp == NULL) {
+               assert(val->num_components == intrin->def.num_components);
+            } else {
+               assert(intrin->def.num_components == 1);
+               b.cursor = nir_before_instr(&intrin->instr);
+               val = nir_vector_extract(&b, val, comp);
+            }
 
             nir_def_rewrite_uses(&intrin->def, val);
             nir_instr_remove(&intrin->instr);
@@ -646,13 +691,19 @@ rename_variables(struct lower_variables_state *state)
                continue;
 
             assert(intrin->num_components ==
-                   glsl_get_vector_elements(node->type));
+                   glsl_get_vector_elements(deref->type));
 
             nir_def *new_def;
             b.cursor = nir_before_instr(&intrin->instr);
 
+            nir_def *comp = deref_vec_component(deref);
             unsigned wrmask = nir_intrinsic_write_mask(intrin);
-            if (wrmask == (1 << intrin->num_components) - 1) {
+            if (comp != NULL) {
+               assert(wrmask == 1 && intrin->num_components == 1);
+               nir_def *old_def =
+                  nir_phi_builder_value_get_block_def(node->pb_value, block);
+               new_def = nir_vector_insert(&b, old_def, value, comp);
+            } else if (wrmask == (1 << intrin->num_components) - 1) {
                /* Whole variable store - just copy the source.  Note that
                 * intrin->num_components and value->num_components
                 * may differ.
@@ -681,8 +732,6 @@ rename_variables(struct lower_variables_state *state)
                new_def = nir_vec_scalars(&b, srcs, intrin->num_components);
             }
 
-            assert(new_def->num_components == intrin->num_components);
-
             nir_phi_builder_value_set_block_def(node->pb_value, block, new_def);
             nir_instr_remove(&intrin->instr);
             break;

Reply via email to