Module: Mesa
Branch: main
Commit: eb633b49782be3f6df90a5030e92f771896a7e17
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=eb633b49782be3f6df90a5030e92f771896a7e17

Author: Faith Ekstrand <[email protected]>
Date:   Tue Dec  5 15:09:34 2023 -0600

nak: Wire up DP4

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26533>

---

 src/nouveau/compiler/nak/api.rs         |  3 +++
 src/nouveau/compiler/nak/encode_sm70.rs | 28 ++++++++++++++++++++++++++++
 src/nouveau/compiler/nak/from_nir.rs    | 27 +++++++++++++++++++++++++++
 src/nouveau/compiler/nak/ir.rs          | 28 ++++++++++++++++++++++++++++
 src/nouveau/compiler/nak/legalize.rs    | 14 +++++++++++++-
 5 files changed, 99 insertions(+), 1 deletion(-)

diff --git a/src/nouveau/compiler/nak/api.rs b/src/nouveau/compiler/nak/api.rs
index 80ddac62f43..44d96a87618 100644
--- a/src/nouveau/compiler/nak/api.rs
+++ b/src/nouveau/compiler/nak/api.rs
@@ -118,6 +118,9 @@ fn nir_options(dev: &nv_device_info) -> nir_shader_compiler_options {
     op.lower_scmp = true;
     op.lower_uadd_carry = true;
     op.lower_usub_borrow = true;
+    op.has_sdot_4x8 = dev.sm >= 70;
+    op.has_udot_4x8 = dev.sm >= 70;
+    op.has_sudot_4x8 = dev.sm >= 70;
 
     op
 }
diff --git a/src/nouveau/compiler/nak/encode_sm70.rs b/src/nouveau/compiler/nak/encode_sm70.rs
index ee83be9a627..6f2c16f3210 100644
--- a/src/nouveau/compiler/nak/encode_sm70.rs
+++ b/src/nouveau/compiler/nak/encode_sm70.rs
@@ -661,6 +661,33 @@ impl SM70Instr {
         self.set_pred_src(77..80, 80, op.carry[1]);
     }
 
+    fn encode_idp4(&mut self, op: &OpIDp4) {
+        self.encode_alu(
+            0x026,
+            Some(op.dst),
+            ALUSrc::from_src(&op.srcs[0]),
+            ALUSrc::from_src(&op.srcs[1]),
+            ALUSrc::from_src(&op.srcs[2]),
+        );
+
+        self.set_bit(
+            73,
+            match op.src_types[0] {
+                IntType::U8 => false,
+                IntType::I8 => true,
+                _ => panic!("Invalid DP4 source type"),
+            },
+        );
+        self.set_bit(
+            74,
+            match op.src_types[1] {
+                IntType::U8 => false,
+                IntType::I8 => true,
+                _ => panic!("Invalid DP4 source type"),
+            },
+        );
+    }
+
     fn encode_imad(&mut self, op: &OpIMad) {
         self.encode_alu(
             0x024,
@@ -1915,6 +1942,7 @@ impl SM70Instr {
             Op::IAbs(op) => si.encode_iabs(&op),
             Op::IAdd3(op) => si.encode_iadd3(&op),
             Op::IAdd3X(op) => si.encode_iadd3x(&op),
+            Op::IDp4(op) => si.encode_idp4(&op),
             Op::IMad(op) => si.encode_imad(&op),
             Op::IMad64(op) => si.encode_imad64(&op),
             Op::IMnMx(op) => si.encode_imnmx(&op),
diff --git a/src/nouveau/compiler/nak/from_nir.rs b/src/nouveau/compiler/nak/from_nir.rs
index 1e308f2f635..e24c96a660c 100644
--- a/src/nouveau/compiler/nak/from_nir.rs
+++ b/src/nouveau/compiler/nak/from_nir.rs
@@ -1024,6 +1024,33 @@ impl<'a> ShaderFromNir<'a> {
 
                 b.prmt(low.into(), high.into(), [0, 1, 4, 5])
             }
+            nir_op_sdot_4x8_iadd => {
+                let dst = b.alloc_ssa(RegFile::GPR, 1);
+                b.push_op(OpIDp4 {
+                    dst: dst.into(),
+                    src_types: [IntType::I8, IntType::I8],
+                    srcs: [srcs[0], srcs[1], srcs[2]],
+                });
+                dst
+            }
+            nir_op_sudot_4x8_iadd => {
+                let dst = b.alloc_ssa(RegFile::GPR, 1);
+                b.push_op(OpIDp4 {
+                    dst: dst.into(),
+                    src_types: [IntType::I8, IntType::U8],
+                    srcs: [srcs[0], srcs[1], srcs[2]],
+                });
+                dst
+            }
+            nir_op_udot_4x8_uadd => {
+                let dst = b.alloc_ssa(RegFile::GPR, 1);
+                b.push_op(OpIDp4 {
+                    dst: dst.into(),
+                    src_types: [IntType::U8, IntType::U8],
+                    srcs: [srcs[0], srcs[1], srcs[2]],
+                });
+                dst
+            }
             nir_op_u2f16 | nir_op_u2f32 | nir_op_u2f64 => {
                 let src_bits = alu.get_src(0).src.bit_size();
                 let dst_bits = alu.def.bit_size();
diff --git a/src/nouveau/compiler/nak/ir.rs b/src/nouveau/compiler/nak/ir.rs
index 87571b8c46a..0dd1bf9163c 100644
--- a/src/nouveau/compiler/nak/ir.rs
+++ b/src/nouveau/compiler/nak/ir.rs
@@ -2450,6 +2450,32 @@ impl DisplayOp for OpIAdd3X {
 }
 impl_display_for_op!(OpIAdd3X);
 
+#[repr(C)]
+#[derive(SrcsAsSlice, DstsAsSlice)]
+pub struct OpIDp4 {
+    pub dst: Dst,
+
+    pub src_types: [IntType; 2],
+
+    #[src_type(I32)]
+    pub srcs: [Src; 3],
+}
+
+impl DisplayOp for OpIDp4 {
+    fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(
+            f,
+            "idp4{}{} {} {} {}",
+            self.src_types[0],
+            self.src_types[1],
+            self.srcs[0],
+            self.srcs[1],
+            self.srcs[2],
+        )
+    }
+}
+impl_display_for_op!(OpIDp4);
+
 #[repr(C)]
 #[derive(SrcsAsSlice, DstsAsSlice)]
 pub struct OpIMad {
@@ -4299,6 +4325,7 @@ pub enum Op {
     INeg(OpINeg),
     IAdd3(OpIAdd3),
     IAdd3X(OpIAdd3X),
+    IDp4(OpIDp4),
     IMad(OpIMad),
     IMad64(OpIMad64),
     IMnMx(OpIMnMx),
@@ -4731,6 +4758,7 @@ impl Instr {
             | Op::INeg(_)
             | Op::IAdd3(_)
             | Op::IAdd3X(_)
+            | Op::IDp4(_)
             | Op::IMad(_)
             | Op::IMad64(_)
             | Op::IMnMx(_)
diff --git a/src/nouveau/compiler/nak/legalize.rs b/src/nouveau/compiler/nak/legalize.rs
index 61b7e02a7d3..894d08a1200 100644
--- a/src/nouveau/compiler/nak/legalize.rs
+++ b/src/nouveau/compiler/nak/legalize.rs
@@ -63,9 +63,12 @@ fn copy_src_if_not_reg(b: &mut impl SSABuilder, src: &mut Src, file: RegFile) {
     }
 }
 
-fn swap_srcs_if_not_reg(x: &mut Src, y: &mut Src) {
+fn swap_srcs_if_not_reg(x: &mut Src, y: &mut Src) -> bool {
     if !src_is_reg(x) && src_is_reg(y) {
         std::mem::swap(x, y);
+        true
+    } else {
+        false
     }
 }
 
@@ -153,6 +156,15 @@ fn legalize_instr(
             copy_src_if_not_reg(b, src0, RegFile::GPR);
             copy_src_if_not_reg(b, src2, RegFile::GPR);
         }
+        Op::IDp4(op) => {
+            let [ref mut src_type0, ref mut src_type1] = op.src_types;
+            let [ref mut src0, ref mut src1, ref mut src2] = op.srcs;
+            if swap_srcs_if_not_reg(src0, src1) {
+                std::mem::swap(src_type0, src_type1);
+            }
+            copy_src_if_not_reg(b, src0, RegFile::GPR);
+            copy_src_if_not_reg(b, src2, RegFile::GPR);
+        }
         Op::IMad(op) => {
             let [ref mut src0, ref mut src1, ref mut src2] = op.srcs;
             swap_srcs_if_not_reg(src0, src1);

Reply via email to