Module: Mesa
Branch: master
Commit: 7b3073ad44af8b9576203cdee2dc4e2ca10a7b54
URL:    
http://cgit.freedesktop.org/mesa/mesa/commit/?id=7b3073ad44af8b9576203cdee2dc4e2ca10a7b54

Author: Dave Airlie <[email protected]>
Date:   Fri Mar 19 12:02:50 2021 +1000

gallivm: add subgroup reduction + in/ex scan support

Reviewed-by: Roland Scheidegger <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/9645>

---

 src/gallium/auxiliary/gallivm/lp_bld_nir.c     |   5 +
 src/gallium/auxiliary/gallivm/lp_bld_nir.h     |   1 +
 src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c | 151 +++++++++++++++++++++++++
 3 files changed, 157 insertions(+)

diff --git a/src/gallium/auxiliary/gallivm/lp_bld_nir.c 
b/src/gallium/auxiliary/gallivm/lp_bld_nir.c
index d41be6ba0fc..aaa88cdf494 100644
--- a/src/gallium/auxiliary/gallivm/lp_bld_nir.c
+++ b/src/gallium/auxiliary/gallivm/lp_bld_nir.c
@@ -1789,6 +1789,11 @@ static void visit_intrinsic(struct lp_build_nir_context 
*bld_base,
    case nir_intrinsic_elect:
       bld_base->elect(bld_base, result);
       break;
+   case nir_intrinsic_reduce:
+   case nir_intrinsic_inclusive_scan:
+   case nir_intrinsic_exclusive_scan:
+      bld_base->reduce(bld_base, cast_type(bld_base, get_src(bld_base, 
instr->src[0]), nir_type_int, nir_src_bit_size(instr->src[0])), instr, result);
+      break;
    case nir_intrinsic_interp_deref_at_offset:
    case nir_intrinsic_interp_deref_at_centroid:
    case nir_intrinsic_interp_deref_at_sample:
diff --git a/src/gallium/auxiliary/gallivm/lp_bld_nir.h 
b/src/gallium/auxiliary/gallivm/lp_bld_nir.h
index e396db71eb8..d65309e6b0a 100644
--- a/src/gallium/auxiliary/gallivm/lp_bld_nir.h
+++ b/src/gallium/auxiliary/gallivm/lp_bld_nir.h
@@ -186,6 +186,7 @@ struct lp_build_nir_context
 
    void (*vote)(struct lp_build_nir_context *bld_base, LLVMValueRef src, 
nir_intrinsic_instr *instr, LLVMValueRef dst[4]);
    void (*elect)(struct lp_build_nir_context *bld_base, LLVMValueRef dst[4]);
+   void (*reduce)(struct lp_build_nir_context *bld_base, LLVMValueRef src, 
nir_intrinsic_instr *instr, LLVMValueRef dst[4]);
    void (*helper_invocation)(struct lp_build_nir_context *bld_base, 
LLVMValueRef *dst);
 
    void (*interp_at)(struct lp_build_nir_context *bld_base,
diff --git a/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c 
b/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c
index 6c122fc18ba..5965f863f86 100644
--- a/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c
+++ b/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c
@@ -1920,6 +1920,156 @@ static void emit_elect(struct lp_build_nir_context 
*bld_base, LLVMValueRef resul
                                       "");
 }
 
+/**
+ * Emit code for the reduce / inclusive_scan / exclusive_scan intrinsics.
+ *
+ * The operation is nir_intrinsic_reduction_op(instr).  It is applied, in
+ * lane order, to every lane of \p src whose execution mask is non-zero.
+ * A scalar loop over the lanes is generated because LLVM's vector
+ * reduction intrinsics cannot take the execution mask into account.
+ *
+ * \param bld_base  NIR -> LLVM build context.
+ * \param src       source vector, handled here as an integer vector of the
+ *                  source bit size; for float ops each lane is bitcast to
+ *                  the float element type before being combined.
+ * \param instr     the intrinsic; selects reduce vs. inclusive/exclusive scan
+ *                  and supplies the reduction op and source bit size.
+ * \param result    only result[0] is written.  For reduce it is the final
+ *                  accumulator broadcast to all lanes.  For the scans it is
+ *                  the per-lane vector built in the loop; lanes whose mask
+ *                  is zero are never written by the loop and keep whatever
+ *                  lp_build_alloca() initialised the storage with
+ *                  (NOTE(review): presumably zero - confirm).
+ */
+static void emit_reduce(struct lp_build_nir_context *bld_base, LLVMValueRef src,
+                        nir_intrinsic_instr *instr, LLVMValueRef result[4])
+{
+   struct gallivm_state *gallivm = bld_base->base.gallivm;
+   LLVMBuilderRef builder = gallivm->builder;
+   uint32_t bit_size = nir_src_bit_size(instr->src[0]);
+   /* can't use llvm reduction intrinsics because of exec_mask */
+   LLVMValueRef exec_mask = mask_vec(bld_base);
+   struct lp_build_loop_state loop_state;
+   nir_op reduction_op = nir_intrinsic_reduction_op(instr);
+
+   /* res_store: per-lane result vector, only needed for the scans.
+    * scan_store: scalar running accumulator, kept as an integer of
+    * bit_size bits regardless of the operation's real type.
+    */
+   LLVMValueRef res_store = NULL;
+   LLVMValueRef scan_store;
+   struct lp_build_context *int_bld = get_int_bld(bld_base, true, bit_size);
+
+   /* the sign-bit shifts below need a sane integer width */
+   assert(bit_size >= 1 && bit_size <= 64);
+
+   if (instr->intrinsic != nir_intrinsic_reduce)
+      res_store = lp_build_alloca(gallivm, int_bld->vec_type, "");
+
+   scan_store = lp_build_alloca(gallivm, int_bld->elem_type, "");
+
+   struct lp_build_context elem_bld;
+   bool is_flt = reduction_op == nir_op_fadd ||
+      reduction_op == nir_op_fmul ||
+      reduction_op == nir_op_fmin ||
+      reduction_op == nir_op_fmax;
+   bool is_unsigned = reduction_op == nir_op_umin ||
+      reduction_op == nir_op_umax;
+
+   struct lp_build_context *vec_bld = is_flt ?
+      get_flt_bld(bld_base, bit_size) :
+      get_int_bld(bld_base, is_unsigned, bit_size);
+
+   /* scalar build context with the element type of the chosen vector type */
+   lp_build_context_init(&elem_bld, gallivm, lp_elem_type(vec_bld->type));
+
+   LLVMValueRef store_val = NULL;
+   /*
+    * Put the identity value for the operation into the storage.
+    *
+    * Integer identities are built directly in int_bld->elem_type so that
+    * they have the width of the accumulator for every bit size, not just
+    * 32 bits.  Ops not listed here (add, or, xor, umax) store nothing and
+    * rely on the initial contents of scan_store.
+    */
+   switch (reduction_op) {
+   case nir_op_fmin: {
+      LLVMValueRef flt_max = bit_size == 64 ?
+         LLVMConstReal(LLVMDoubleTypeInContext(gallivm->context), INFINITY) :
+         lp_build_const_float(gallivm, INFINITY);
+      store_val = LLVMBuildBitCast(builder, flt_max, int_bld->elem_type, "");
+      break;
+   }
+   case nir_op_fmax: {
+      LLVMValueRef flt_min = bit_size == 64 ?
+         LLVMConstReal(LLVMDoubleTypeInContext(gallivm->context), -INFINITY) :
+         lp_build_const_float(gallivm, -INFINITY);
+      store_val = LLVMBuildBitCast(builder, flt_min, int_bld->elem_type, "");
+      break;
+   }
+   case nir_op_fmul: {
+      LLVMValueRef flt_one = bit_size == 64 ?
+         LLVMConstReal(LLVMDoubleTypeInContext(gallivm->context), 1.0) :
+         lp_build_const_float(gallivm, 1.0);
+      store_val = LLVMBuildBitCast(builder, flt_one, int_bld->elem_type, "");
+      break;
+   }
+   case nir_op_umin:
+   case nir_op_iand:
+      /* all bits set: the unsigned maximum, and the neutral mask for and */
+      store_val = LLVMConstAllOnes(int_bld->elem_type);
+      break;
+   case nir_op_imin:
+      /* signed maximum: every bit except the sign bit */
+      store_val = LLVMConstInt(int_bld->elem_type,
+                               (1ull << (bit_size - 1)) - 1, 0);
+      break;
+   case nir_op_imax:
+      /* signed minimum: only the sign bit */
+      store_val = LLVMConstInt(int_bld->elem_type,
+                               1ull << (bit_size - 1), 0);
+      break;
+   case nir_op_imul:
+      store_val = LLVMConstInt(int_bld->elem_type, 1, 0);
+      break;
+   default:
+      break;
+   }
+   if (store_val)
+      LLVMBuildStore(builder, store_val, scan_store);
+
+   /* per-lane boolean: true where the execution mask is non-zero */
+   LLVMValueRef outer_cond = LLVMBuildICmp(builder, LLVMIntNE, exec_mask,
+                                           bld_base->uint_bld.zero, "");
+
+   /* loop over the lanes, starting at lane 0 */
+   lp_build_loop_begin(&loop_state, gallivm, lp_build_const_int32(gallivm, 0));
+
+   /* only lanes with a non-zero mask take part in the operation */
+   struct lp_build_if_state ifthen;
+   LLVMValueRef if_cond = LLVMBuildExtractElement(builder, outer_cond,
+                                                  loop_state.counter, "");
+   lp_build_if(&ifthen, gallivm, if_cond);
+   LLVMValueRef value = LLVMBuildExtractElement(builder, src,
+                                                loop_state.counter, "");
+
+   LLVMValueRef res = NULL;
+   LLVMValueRef scan_val = LLVMBuildLoad(builder, scan_store, "");
+   if (instr->intrinsic != nir_intrinsic_reduce)
+      res = LLVMBuildLoad(builder, res_store, "");
+
+   /* exclusive scan: the lane gets the accumulator *before* its own value
+    * is combined in.
+    */
+   if (instr->intrinsic == nir_intrinsic_exclusive_scan)
+      res = LLVMBuildInsertElement(builder, res, scan_val,
+                                   loop_state.counter, "");
+
+   /* the accumulator and source are integers; view them as floats for
+    * the float operations.
+    */
+   if (is_flt) {
+      scan_val = LLVMBuildBitCast(builder, scan_val, elem_bld.elem_type, "");
+      value = LLVMBuildBitCast(builder, value, elem_bld.elem_type, "");
+   }
+   switch (reduction_op) {
+   case nir_op_fadd:
+   case nir_op_iadd:
+      scan_val = lp_build_add(&elem_bld, value, scan_val);
+      break;
+   case nir_op_fmul:
+   case nir_op_imul:
+      scan_val = lp_build_mul(&elem_bld, value, scan_val);
+      break;
+   case nir_op_imin:
+   case nir_op_umin:
+   case nir_op_fmin:
+      scan_val = lp_build_min(&elem_bld, value, scan_val);
+      break;
+   case nir_op_imax:
+   case nir_op_umax:
+   case nir_op_fmax:
+      scan_val = lp_build_max(&elem_bld, value, scan_val);
+      break;
+   case nir_op_iand:
+      scan_val = lp_build_and(&elem_bld, value, scan_val);
+      break;
+   case nir_op_ior:
+      scan_val = lp_build_or(&elem_bld, value, scan_val);
+      break;
+   case nir_op_ixor:
+      scan_val = lp_build_xor(&elem_bld, value, scan_val);
+      break;
+   default:
+      assert(0);
+      break;
+   }
+   /* back to the integer representation used by the storage */
+   if (is_flt)
+      scan_val = LLVMBuildBitCast(builder, scan_val, int_bld->elem_type, "");
+   LLVMBuildStore(builder, scan_val, scan_store);
+
+   /* inclusive scan: the lane gets the accumulator *after* its own value
+    * has been combined in.
+    */
+   if (instr->intrinsic == nir_intrinsic_inclusive_scan) {
+      res = LLVMBuildInsertElement(builder, res, scan_val,
+                                   loop_state.counter, "");
+   }
+
+   if (instr->intrinsic != nir_intrinsic_reduce)
+      LLVMBuildStore(builder, res, res_store);
+   lp_build_endif(&ifthen);
+
+   /* the loop's end value is the vector length of uint_bld */
+   lp_build_loop_end_cond(&loop_state,
+                          lp_build_const_int32(gallivm, bld_base->uint_bld.type.length),
+                          NULL, LLVMIntUGE);
+   if (instr->intrinsic == nir_intrinsic_reduce)
+      result[0] = lp_build_broadcast_scalar(int_bld,
+                                            LLVMBuildLoad(builder, scan_store, ""));
+   else
+      result[0] = LLVMBuildLoad(builder, res_store, "");
+}
+
 static void
 emit_interp_at(struct lp_build_nir_context *bld_base,
                unsigned num_components,
@@ -2166,6 +2316,7 @@ void lp_build_nir_soa(struct gallivm_state *gallivm,
    bld.bld_base.image_size = emit_image_size;
    bld.bld_base.vote = emit_vote;
    bld.bld_base.elect = emit_elect;
+   bld.bld_base.reduce = emit_reduce;
    bld.bld_base.helper_invocation = emit_helper_invocation;
    bld.bld_base.interp_at = emit_interp_at;
    bld.bld_base.load_scratch = emit_load_scratch;

_______________________________________________
mesa-commit mailing list
[email protected]
https://lists.freedesktop.org/mailman/listinfo/mesa-commit

Reply via email to