Module: Mesa Branch: main Commit: 79c8740c6e758bef29d43fa352d7b5f4668f78a8 URL: http://cgit.freedesktop.org/mesa/mesa/commit/?id=79c8740c6e758bef29d43fa352d7b5f4668f78a8
Author: Rhys Perry <[email protected]> Date: Wed Feb 23 11:33:16 2022 +0000 aco: fix fp16 opcode definitions The v_fma_mix optimizations assume v_cvt_f16_f32 and v_mul_f16 use a v2b definition. Signed-off-by: Rhys Perry <[email protected]> Reviewed-by: Daniel Schürmann <[email protected]> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/14769> --- src/amd/compiler/aco_instruction_selection.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp index ea4518c2943..ebcb0a63410 100644 --- a/src/amd/compiler/aco_instruction_selection.cpp +++ b/src/amd/compiler/aco_instruction_selection.cpp @@ -2520,7 +2520,7 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr) aco_ptr<Instruction> norm; if (dst.regClass() == v2b) { Temp half_pi = bld.copy(bld.def(s1), Operand::c32(0x3118u)); - Temp tmp = bld.vop2(aco_opcode::v_mul_f16, bld.def(v1), half_pi, src); + Temp tmp = bld.vop2(aco_opcode::v_mul_f16, bld.def(v2b), half_pi, src); aco_opcode opcode = instr->op == nir_op_fsin ? aco_opcode::v_sin_f16 : aco_opcode::v_cos_f16; bld.vop1(opcode, Definition(dst), tmp); @@ -3334,7 +3334,7 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr) } case nir_op_fquantize2f16: { Temp src = get_alu_src(ctx, instr->src[0]); - Temp f16 = bld.vop1(aco_opcode::v_cvt_f16_f32, bld.def(v1), src); + Temp f16 = bld.vop1(aco_opcode::v_cvt_f16_f32, bld.def(v2b), src); Temp f32, cmp_res; if (ctx->program->chip_class >= GFX8) { @@ -7642,7 +7642,8 @@ emit_addition_uniform_reduce(isel_context* ctx, nir_op op, Definition dst, nir_s if (op == nir_op_fadd) { src_tmp = as_vgpr(ctx, src_tmp); - Temp tmp = dst.regClass() == s1 ? bld.tmp(src_tmp.regClass()) : dst.getTemp(); + Temp tmp = dst.regClass() == s1 ? bld.tmp(RegClass::get(RegType::vgpr, src.ssa->bit_size / 8)) + : dst.getTemp(); if (src.ssa->bit_size == 16) { count = bld.vop1(aco_opcode::v_cvt_f16_u16, bld.def(v2b), count);
