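From: Nicolai Hähnle <nicolai.haeh...@amd.com> Special-case 1-bit (i1) sources combined with nir_op_iadd in the inclusive and exclusive scans by using ballot + mbcnt instead of the generic DPP-based scan. This allows for a unified but efficient treatment of adding a bitmask over a wave or an entire threadgroup. --- src/amd/common/ac_llvm_build.c | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-)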
diff --git a/src/amd/common/ac_llvm_build.c b/src/amd/common/ac_llvm_build.c index 932f4bbdeef..eb840369d07 100644 --- a/src/amd/common/ac_llvm_build.c +++ b/src/amd/common/ac_llvm_build.c @@ -3391,36 +3391,57 @@ ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValu if (maxprefix <= 32) return result; tmp = ac_build_dpp(ctx, identity, result, dpp_row_bcast31, 0xc, 0xf, false); result = ac_build_alu_op(ctx, result, tmp, op); return result; } LLVMValueRef ac_build_inclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op) { - ac_build_optimization_barrier(ctx, &src); LLVMValueRef result; + + if (LLVMTypeOf(src) == ctx->i1 && op == nir_op_iadd) { + LLVMBuilderRef builder = ctx->builder; + src = LLVMBuildZExt(builder, src, ctx->i32, ""); + result = ac_build_ballot(ctx, src); + result = ac_build_mbcnt(ctx, result); + result = LLVMBuildAdd(builder, result, src, ""); + return result; + } + + ac_build_optimization_barrier(ctx, &src); + LLVMValueRef identity = get_reduction_identity(ctx, op, ac_get_type_size(LLVMTypeOf(src))); result = LLVMBuildBitCast(ctx->builder, ac_build_set_inactive(ctx, src, identity), LLVMTypeOf(identity), ""); result = ac_build_scan(ctx, op, result, identity, 64); return ac_build_wwm(ctx, result); } LLVMValueRef ac_build_exclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op) { - ac_build_optimization_barrier(ctx, &src); LLVMValueRef result; + + if (LLVMTypeOf(src) == ctx->i1 && op == nir_op_iadd) { + LLVMBuilderRef builder = ctx->builder; + src = LLVMBuildZExt(builder, src, ctx->i32, ""); + result = ac_build_ballot(ctx, src); + result = ac_build_mbcnt(ctx, result); + return result; + } + + ac_build_optimization_barrier(ctx, &src); + LLVMValueRef identity = get_reduction_identity(ctx, op, ac_get_type_size(LLVMTypeOf(src))); result = LLVMBuildBitCast(ctx->builder, ac_build_set_inactive(ctx, src, identity), LLVMTypeOf(identity), ""); result = ac_build_dpp(ctx, identity, result, dpp_wf_sr1, 0xf, 
0xf, false); result = ac_build_scan(ctx, op, result, identity, 64); return ac_build_wwm(ctx, result); } @@ -3585,20 +3606,22 @@ ac_build_wg_wavescan(struct ac_llvm_context *ctx, struct ac_wg_scan *ws) * "Top half" of a scan that reduces per-thread values across an entire * workgroup. * * All lanes must be active when this code runs. */ void ac_build_wg_scan_top(struct ac_llvm_context *ctx, struct ac_wg_scan *ws) { if (ws->enable_exclusive) { ws->extra = ac_build_exclusive_scan(ctx, ws->src, ws->op); + if (LLVMTypeOf(ws->src) == ctx->i1 && ws->op == nir_op_iadd) + ws->src = LLVMBuildZExt(ctx->builder, ws->src, ctx->i32, ""); ws->src = ac_build_alu_op(ctx, ws->extra, ws->src, ws->op); } else { ws->src = ac_build_inclusive_scan(ctx, ws->src, ws->op); } bool enable_inclusive = ws->enable_inclusive; bool enable_exclusive = ws->enable_exclusive; ws->enable_inclusive = false; ws->enable_exclusive = ws->enable_exclusive || enable_inclusive; ac_build_wg_wavescan_top(ctx, ws); -- 2.19.1 _______________________________________________ mesa-dev mailing list mesa-dev@lists.freedesktop.org https://lists.freedesktop.org/mailman/listinfo/mesa-dev