Module: Mesa Branch: main Commit: eb633b49782be3f6df90a5030e92f771896a7e17 URL: http://cgit.freedesktop.org/mesa/mesa/commit/?id=eb633b49782be3f6df90a5030e92f771896a7e17
Author: Faith Ekstrand <[email protected]> Date: Tue Dec 5 15:09:34 2023 -0600 nak: Wire up DP4 Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26533> --- src/nouveau/compiler/nak/api.rs | 3 +++ src/nouveau/compiler/nak/encode_sm70.rs | 28 ++++++++++++++++++++++++++++ src/nouveau/compiler/nak/from_nir.rs | 27 +++++++++++++++++++++++++++ src/nouveau/compiler/nak/ir.rs | 28 ++++++++++++++++++++++++++++ src/nouveau/compiler/nak/legalize.rs | 14 +++++++++++++- 5 files changed, 99 insertions(+), 1 deletion(-) diff --git a/src/nouveau/compiler/nak/api.rs b/src/nouveau/compiler/nak/api.rs index 80ddac62f43..44d96a87618 100644 --- a/src/nouveau/compiler/nak/api.rs +++ b/src/nouveau/compiler/nak/api.rs @@ -118,6 +118,9 @@ fn nir_options(dev: &nv_device_info) -> nir_shader_compiler_options { op.lower_scmp = true; op.lower_uadd_carry = true; op.lower_usub_borrow = true; + op.has_sdot_4x8 = dev.sm >= 70; + op.has_udot_4x8 = dev.sm >= 70; + op.has_sudot_4x8 = dev.sm >= 70; op } diff --git a/src/nouveau/compiler/nak/encode_sm70.rs b/src/nouveau/compiler/nak/encode_sm70.rs index ee83be9a627..6f2c16f3210 100644 --- a/src/nouveau/compiler/nak/encode_sm70.rs +++ b/src/nouveau/compiler/nak/encode_sm70.rs @@ -661,6 +661,33 @@ impl SM70Instr { self.set_pred_src(77..80, 80, op.carry[1]); } + fn encode_idp4(&mut self, op: &OpIDp4) { + self.encode_alu( + 0x026, + Some(op.dst), + ALUSrc::from_src(&op.srcs[0]), + ALUSrc::from_src(&op.srcs[1]), + ALUSrc::from_src(&op.srcs[2]), + ); + + self.set_bit( + 73, + match op.src_types[0] { + IntType::U8 => false, + IntType::I8 => true, + _ => panic!("Invalid DP4 source type"), + }, + ); + self.set_bit( + 74, + match op.src_types[1] { + IntType::U8 => false, + IntType::I8 => true, + _ => panic!("Invalid DP4 source type"), + }, + ); + } + fn encode_imad(&mut self, op: &OpIMad) { self.encode_alu( 0x024, @@ -1915,6 +1942,7 @@ impl SM70Instr { Op::IAbs(op) => si.encode_iabs(&op), Op::IAdd3(op) => si.encode_iadd3(&op), Op::IAdd3X(op) => si.encode_iadd3x(&op), + Op::IDp4(op) => si.encode_idp4(&op), Op::IMad(op) => si.encode_imad(&op), Op::IMad64(op) => si.encode_imad64(&op), Op::IMnMx(op) => si.encode_imnmx(&op), diff --git a/src/nouveau/compiler/nak/from_nir.rs b/src/nouveau/compiler/nak/from_nir.rs index 1e308f2f635..e24c96a660c 100644 --- a/src/nouveau/compiler/nak/from_nir.rs +++ b/src/nouveau/compiler/nak/from_nir.rs @@ -1024,6 +1024,33 @@ impl<'a> ShaderFromNir<'a> { b.prmt(low.into(), high.into(), [0, 1, 4, 5]) } + nir_op_sdot_4x8_iadd => { + let dst = b.alloc_ssa(RegFile::GPR, 1); + b.push_op(OpIDp4 { + dst: dst.into(), + src_types: [IntType::I8, IntType::I8], + srcs: [srcs[0], srcs[1], srcs[2]], + }); + dst + } + nir_op_sudot_4x8_iadd => { + let dst = b.alloc_ssa(RegFile::GPR, 1); + b.push_op(OpIDp4 { + dst: dst.into(), + src_types: [IntType::I8, IntType::U8], + srcs: [srcs[0], srcs[1], srcs[2]], + }); + dst + } + nir_op_udot_4x8_uadd => { + let dst = b.alloc_ssa(RegFile::GPR, 1); + b.push_op(OpIDp4 { + dst: dst.into(), + src_types: [IntType::U8, IntType::U8], + srcs: [srcs[0], srcs[1], srcs[2]], + }); + dst + } nir_op_u2f16 | nir_op_u2f32 | nir_op_u2f64 => { let src_bits = alu.get_src(0).src.bit_size(); let dst_bits = alu.def.bit_size(); diff --git a/src/nouveau/compiler/nak/ir.rs b/src/nouveau/compiler/nak/ir.rs index 87571b8c46a..0dd1bf9163c 100644 --- a/src/nouveau/compiler/nak/ir.rs +++ b/src/nouveau/compiler/nak/ir.rs @@ -2450,6 +2450,32 @@ impl DisplayOp for OpIAdd3X { } impl_display_for_op!(OpIAdd3X); +#[repr(C)] +#[derive(SrcsAsSlice, DstsAsSlice)] +pub struct OpIDp4 { + pub dst: Dst, + + pub src_types: [IntType; 2], + + #[src_type(I32)] + pub srcs: [Src; 3], +} + +impl DisplayOp for OpIDp4 { + fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "idp4{}{} {} {} {}", + self.src_types[0], + self.src_types[1], + self.srcs[0], + self.srcs[1], + self.srcs[2], + ) + } +} +impl_display_for_op!(OpIDp4); + #[repr(C)] #[derive(SrcsAsSlice, DstsAsSlice)] pub struct OpIMad { @@ -4299,6 +4325,7 @@ pub enum Op { INeg(OpINeg), IAdd3(OpIAdd3), IAdd3X(OpIAdd3X), + IDp4(OpIDp4), IMad(OpIMad), IMad64(OpIMad64), IMnMx(OpIMnMx), @@ -4731,6 +4758,7 @@ impl Instr { | Op::INeg(_) | Op::IAdd3(_) | Op::IAdd3X(_) + | Op::IDp4(_) | Op::IMad(_) | Op::IMad64(_) | Op::IMnMx(_) diff --git a/src/nouveau/compiler/nak/legalize.rs b/src/nouveau/compiler/nak/legalize.rs index 61b7e02a7d3..894d08a1200 100644 --- a/src/nouveau/compiler/nak/legalize.rs +++ b/src/nouveau/compiler/nak/legalize.rs @@ -63,9 +63,12 @@ fn copy_src_if_not_reg(b: &mut impl SSABuilder, src: &mut Src, file: RegFile) { } } -fn swap_srcs_if_not_reg(x: &mut Src, y: &mut Src) { +fn swap_srcs_if_not_reg(x: &mut Src, y: &mut Src) -> bool { if !src_is_reg(x) && src_is_reg(y) { std::mem::swap(x, y); + true + } else { + false } } @@ -153,6 +156,15 @@ fn legalize_instr( copy_src_if_not_reg(b, src0, RegFile::GPR); copy_src_if_not_reg(b, src2, RegFile::GPR); } + Op::IDp4(op) => { + let [ref mut src_type0, ref mut src_type1] = op.src_types; + let [ref mut src0, ref mut src1, ref mut src2] = op.srcs; + if swap_srcs_if_not_reg(src0, src1) { + std::mem::swap(src_type0, src_type1); + } + copy_src_if_not_reg(b, src0, RegFile::GPR); + copy_src_if_not_reg(b, src2, RegFile::GPR); + } Op::IMad(op) => { let [ref mut src0, ref mut src1, ref mut src2] = op.srcs; swap_srcs_if_not_reg(src0, src1);
