Attention is currently required from: Richard Cooper.

Hello Richard Cooper,

I'd like you to do a code review.
Please visit

    https://gem5-review.googlesource.com/c/public/gem5/+/70732?usp=email

to review the following change.


Change subject: arch-arm: Added Armv8.2-I8MM SVE mixed-sign dot product instrs.
......................................................................

arch-arm: Added Armv8.2-I8MM SVE mixed-sign dot product instrs.

Add support for the SVE mixed-sign dot product instructions (USDOT,
SUDOT) required by the Armv8.2 SVE Int8 matrix multiplication
extension (Armv8.2-I8MM).

For more information, please refer to the "ARM Architecture Reference
Manual Supplement - The Scalable Vector Extension (SVE), for ARMv8-A"
(https://developer.arm.com/architectures/cpu-architecture/a-profile/docs/arm-architecture-reference-manual-supplement-armv8-a).

Change-Id: I83841654cee74b940f967b3a37b99d87c01bd92c
Reviewed-by: Richard Cooper <richard.coo...@arm.com>
---
M src/arch/arm/isa/formats/sve_2nd_level.isa
M src/arch/arm/isa/insts/sve.isa
M src/arch/arm/isa/templates/sve.isa
3 files changed, 91 insertions(+), 47 deletions(-)



diff --git a/src/arch/arm/isa/formats/sve_2nd_level.isa b/src/arch/arm/isa/formats/sve_2nd_level.isa
index 4148b96..051323f 100644
--- a/src/arch/arm/isa/formats/sve_2nd_level.isa
+++ b/src/arch/arm/isa/formats/sve_2nd_level.isa
@@ -2256,19 +2256,19 @@
         uint8_t usig = (uint8_t) bits(machInst, 10);
         if (size & 0x1) {
             if (usig) {
-                return new SveUdotv<uint16_t, uint64_t>(machInst,
-                                                        zda, zn, zm);
+                return new SveUdotv<uint16_t, uint16_t, uint64_t>
+                                        (machInst, zda, zn, zm);
             } else {
-                return new SveSdotv<int16_t, int64_t>(machInst,
-                                                        zda, zn, zm);
+                return new SveSdotv<int16_t, int16_t, int64_t>
+                                        (machInst, zda, zn, zm);
             }
         } else {
             if (usig) {
-                return new SveUdotv<uint8_t, uint32_t>(machInst,
-                                                        zda, zn, zm);
+                return new SveUdotv<uint8_t, uint8_t, uint32_t>
+                                        (machInst, zda, zn, zm);
             } else {
-                return new SveSdotv<int8_t, int32_t>(machInst,
-                                                        zda, zn, zm);
+                return new SveSdotv<int8_t, int8_t, int32_t>
+                                        (machInst, zda, zn, zm);
             }
         }

@@ -2292,21 +2292,21 @@
             RegIndex zm = (RegIndex) (uint8_t) bits(machInst, 19, 16);
             uint8_t i1 = (uint8_t) bits(machInst, 20);
             if (usig) {
-                return new SveUdoti<uint16_t, uint64_t>(machInst,
-                                                        zda, zn, zm, i1);
+                return new SveUdoti<uint16_t, uint16_t, uint64_t>
+                                       (machInst, zda, zn, zm, i1);
             } else {
-                return new SveSdoti<int16_t, int64_t>(machInst,
-                                                        zda, zn, zm, i1);
+                return new SveSdoti<int16_t, int16_t, int64_t>
+                                       (machInst, zda, zn, zm, i1);
             }
         } else {
             RegIndex zm = (RegIndex) (uint8_t) bits(machInst, 18, 16);
             uint8_t i2 = (uint8_t) bits(machInst, 20, 19);
             if (usig) {
-                return new SveUdoti<uint8_t, uint32_t>(machInst,
-                                                        zda, zn, zm, i2);
+                return new SveUdoti<uint8_t, uint8_t, uint32_t>
+                                        (machInst, zda, zn, zm, i2);
             } else {
-                return new SveSdoti<int8_t, int32_t>(machInst,
-                                                        zda, zn, zm, i2);
+                return new SveSdoti<int8_t, int8_t, int32_t>
+                                        (machInst, zda, zn, zm, i2);
             }
         }
         return new Unknown64(machInst);
@@ -2327,9 +2327,8 @@
         RegIndex zm M5_VAR_USED = (RegIndex)
                                           (uint8_t) bits(machInst, 20, 16);

-        // Placeholder for SveUsdotv
- //return SveUsdotv<int32_t, uint8_t, int8_t>(machInst, zda, zn, zm);
-        return new Unknown64(machInst);
+        return new SveUsdotv<uint8_t, int8_t, int32_t>
+                                 (machInst, zda, zn, zm);
     } // decodeSveMixedSignDotProduct

     StaticInstPtr
@@ -2351,15 +2350,11 @@
         uint8_t usig = (uint8_t) bits(machInst, 10);

         if (usig) {
-            // Placeholder for SveSudoti
-            //return SveSudoti<int32_t, int8_t, uint8_t>
-            //                         (machInst, zda, zn, zm, i2);
-            return new Unknown64(machInst);
+            return new SveSudoti<int8_t, uint8_t, int32_t>
+                                     (machInst, zda, zn, zm, i2);
         } else {
-            // Placeholder for SveUsdoti
-            //return SveUsdoti<int32_t, uint8_t, int8_t>
-            //                         (machInst, zda, zn, zm, i2);
-            return new Unknown64(machInst);
+            return new SveUsdoti<uint8_t, int8_t, int32_t>
+                                     (machInst, zda, zn, zm, i2);
         }

     } // decodeSveMixedSignDotProductIndexed
diff --git a/src/arch/arm/isa/insts/sve.isa b/src/arch/arm/isa/insts/sve.isa
index f7bf8a3..b6b560d 100644
--- a/src/arch/arm/isa/insts/sve.isa
+++ b/src/arch/arm/isa/insts/sve.isa
@@ -3056,6 +3056,15 @@
     def sveDotInst(name, Name, opClass, types, isIndexed = True):
         global header_output, exec_output, decoders
         code = sveEnabledCheckCode + '''
+        // Types of the extended versions of the source elements.
+ // Required to make sure the intermediate calculations don't overflow.
+        using ExtendedElementA = typename vector_element_traits::
+                                   extend_element<DElement,
+                                                  SElementA>::type;
+        using ExtendedElementB = typename vector_element_traits::
+                                   extend_element<DElement,
+                                                  SElementB>::type;
+
         unsigned eCount = ArmStaticInst::getCurSveVecLen<Element>(
                 xc->tcBase());
         for (int i = 0; i < eCount; ++i) {'''
@@ -3065,17 +3074,21 @@
             int s = segbase + imm;'''
         code += '''
             DElement res = AA64FpDest_xd[i];
-            DElement srcElem1, srcElem2;
+            ExtendedElementA srcElemA;
+            ExtendedElementB srcElemB;
             for (int j = 0; j <= 3; ++j) {
- srcElem1 = static_cast<DElement>(AA64FpOp1_xs[4 * i + j]);'''
+                srcElemA = static_cast<ExtendedElementA>
+                                          (AA64FpOp1_srcA[4 * i + j]);'''
         if isIndexed:
             code += '''
- srcElem2 = static_cast<DElement>(AA64FpOp2_xs[4 * s + j]);'''
+                srcElemB = static_cast<ExtendedElementB>
+                                          (AA64FpOp2_srcB[4 * s + j]);'''
         else:
             code += '''
- srcElem2 = static_cast<DElement>(AA64FpOp2_xs[4 * i + j]);'''
+                srcElemB = static_cast<ExtendedElementB>
+                                          (AA64FpOp2_srcB[4 * i + j]);'''
         code += '''
-                res += srcElem1 * srcElem2;
+                res += srcElemA * srcElemB;
             }
             AA64FpDestMerge_xd[i] = res;
         }'''
@@ -3087,7 +3100,7 @@
             header_output += SveWideningTerImmOpDeclare.subst(iop)
         else:
             header_output += SveWideningTerOpDeclare.subst(iop)
-        exec_output += SveWideningOpExecute.subst(iop)
+        exec_output += SveWideningTerOpExecute.subst(iop)
         for type in types:
             substDict = {'targs': type, 'class_name': 'Sve' + Name}
             exec_output += SveOpExecDeclare.subst(substDict)
@@ -4412,11 +4425,14 @@
     sveBinInst('sdivr', 'Sdivr', 'SimdDivOp', signedTypes, sdivrCode,
                PredType.MERGE, True)
     # SDOT (indexed)
-    sveDotInst('sdot', 'Sdoti', 'SimdAluOp', ['int8_t, int32_t',
-        'int16_t, int64_t'], isIndexed = True)
+    sveDotInst('sdot', 'Sdoti', 'SimdAluOp', ['int8_t, int8_t, int32_t',
+        'int16_t, int16_t, int64_t'], isIndexed = True)
     # SDOT (vectors)
-    sveDotInst('sdot', 'Sdotv', 'SimdAluOp', ['int8_t, int32_t',
-        'int16_t, int64_t'], isIndexed = False)
+    sveDotInst('sdot', 'Sdotv', 'SimdAluOp', ['int8_t, int8_t, int32_t',
+        'int16_t, int16_t, int64_t'], isIndexed = False)
+    # SUDOT (indexed)
+ sveDotInst('sudot', 'Sudoti', 'SimdAluOp', ['int8_t, uint8_t, int32_t'],
+               isIndexed = True)
     # SEL (predicates)
     selCode = 'destElem = srcElem1;'
     svePredLogicalInst('sel', 'PredSel', 'SimdPredAluOp', ('uint8_t',),
@@ -4801,11 +4817,17 @@
     sveBinInst('udivr', 'Udivr', 'SimdDivOp', unsignedTypes, udivrCode,
                PredType.MERGE, True)
     # UDOT (indexed)
-    sveDotInst('udot', 'Udoti', 'SimdAluOp', ['uint8_t, uint32_t',
-        'uint16_t, uint64_t'], isIndexed = True)
+    sveDotInst('udot', 'Udoti', 'SimdAluOp', ['uint8_t, uint8_t, uint32_t',
+        'uint16_t, uint16_t, uint64_t'], isIndexed = True)
     # UDOT (vectors)
-    sveDotInst('udot', 'Udotv', 'SimdAluOp', ['uint8_t, uint32_t',
-        'uint16_t, uint64_t'], isIndexed = False)
+    sveDotInst('udot', 'Udotv', 'SimdAluOp', ['uint8_t, uint8_t, uint32_t',
+        'uint16_t, uint16_t, uint64_t'], isIndexed = False)
+    # USDOT (indexed)
+ sveDotInst('usdot', 'Usdoti', 'SimdAluOp', ['uint8_t, int8_t, int32_t'],
+               isIndexed = True)
+    # USDOT (vectors)
+ sveDotInst('usdot', 'Usdotv', 'SimdAluOp', ['uint8_t, int8_t, int32_t'],
+               isIndexed = False)
     # UMAX (immediate)
     sveWideImmInst('umax', 'UmaxImm', 'SimdCmpOp', unsignedTypes, maxCode)
     # UMAX (vectors)
diff --git a/src/arch/arm/isa/templates/sve.isa b/src/arch/arm/isa/templates/sve.isa
index 9043e23..7c41b40 100644
--- a/src/arch/arm/isa/templates/sve.isa
+++ b/src/arch/arm/isa/templates/sve.isa
@@ -1113,17 +1113,22 @@
 }};

 def template SveWideningTerImmOpDeclare {{
-template <class _SElement, class _DElement>
+template <class _SElementA, class _SElementB, class _DElement>
 class %(class_name)s : public %(base_class)s
 {
+  static_assert(sizeof(_SElementA) == sizeof(_SElementB),
+                "Source elements must have the same size.");
+
   private:
     %(reg_idx_arr_decl)s;

   protected:
     typedef _DElement Element;
-    typedef _SElement SElement;
+    typedef _SElementA SElementA;
+    typedef _SElementB SElementB;
     typedef _DElement DElement;
-    typedef _SElement TPSElem;
+    typedef _SElementA TPSrcAElem;
+    typedef _SElementB TPSrcBElem;
     typedef _DElement TPDElem;

   public:
@@ -1142,7 +1147,7 @@
 }};

 def template SveWideningTerOpDeclare {{
-template <class _SElement, class _DElement>
+template <class _SElementA, class _SElementB, class _DElement>
 class %(class_name)s : public %(base_class)s
 {
   private:
@@ -1150,9 +1155,11 @@

   protected:
     typedef _DElement Element;
-    typedef _SElement SElement;
+    typedef _SElementA SElementA;
+    typedef _SElementB SElementB;
     typedef _DElement DElement;
-    typedef _SElement TPSElem;
+    typedef _SElementA TPSrcAElem;
+    typedef _SElementB TPSrcBElem;
     typedef _DElement TPDElem;

   public:
@@ -1269,6 +1276,26 @@
     }
 }};

+def template SveWideningTerOpExecute {{
+    template <class SElementA, class SElementB, class DElement>
+    Fault %(class_name)s<SElementA, SElementB, DElement>::execute
+           (ExecContext *xc,
+            trace::InstRecord *traceData) const
+    {
+        Fault fault = NoFault;
+        %(op_decl)s;
+        %(op_rd)s;
+
+        %(code)s;
+        if (fault == NoFault)
+        {
+            %(op_wb)s;
+        }
+
+        return fault;
+    }
+}};
+
 def template SveNonTemplatedOpExecute {{
     Fault
     %(class_name)s::execute(ExecContext *xc,

--
To view, visit https://gem5-review.googlesource.com/c/public/gem5/+/70732?usp=email To unsubscribe, or for help writing mail filters, visit https://gem5-review.googlesource.com/settings?usp=email

Gerrit-MessageType: newchange
Gerrit-Project: public/gem5
Gerrit-Branch: develop
Gerrit-Change-Id: I83841654cee74b940f967b3a37b99d87c01bd92c
Gerrit-Change-Number: 70732
Gerrit-PatchSet: 1
Gerrit-Owner: Giacomo Travaglini <giacomo.travagl...@arm.com>
Gerrit-Reviewer: Richard Cooper <richard.coo...@arm.com>
Gerrit-Attention: Richard Cooper <richard.coo...@arm.com>
_______________________________________________
gem5-dev mailing list -- gem5-dev@gem5.org
To unsubscribe send an email to gem5-dev-le...@gem5.org

Reply via email to