4vtomat created this revision.
Herald added subscribers: jobnoorman, luke, VincentWu, vkmr, frasercrmck, 
luismarques, apazos, sameer.abuasal, s.egerton, Jim, benna, psnobl, jocewei, 
PkmX, the_o, brucehoult, MartinMosbeck, rogfer01, edward-jones, zzheng, jrtc27, 
shiva0217, kito-cheng, niosHD, sabuasal, simoncook, johnrusso, rbar, asb, 
arichardson.
Herald added a project: All.
4vtomat requested review of this revision.
Herald added subscribers: cfe-commits, wangpc, eopXD, MaskRay.
Herald added a project: clang.

This patch implements the calling convention for vector type and
tuple type arguments. A vector type argument can be passed either
directly in registers or by reference. A tuple type argument is
split into multiple vector type arguments, and all of the resulting
arguments must be passed the same way: either all by reference
or all in registers.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D158050

Files:
  clang/include/clang/AST/Type.h
  clang/lib/CodeGen/Targets/RISCV.cpp
  clang/test/CodeGen/RISCV/riscv-vector-callingconv-llvm-ir.c

Index: clang/test/CodeGen/RISCV/riscv-vector-callingconv-llvm-ir.c
===================================================================
--- /dev/null
+++ clang/test/CodeGen/RISCV/riscv-vector-callingconv-llvm-ir.c
@@ -0,0 +1,26 @@
+// REQUIRES: riscv-registered-target
+// RUN: %clang_cc1 -triple riscv64 -target-feature +v \
+// RUN:   -emit-llvm %s -o - | FileCheck -check-prefix=CHECK-LLVM %s
+
+#include <riscv_vector.h>
+
+// CHECK-LLVM: void @call1(<vscale x 4 x i32> %v0, <vscale x 8 x i32> %v1.coerce0, <vscale x 8 x i32> %v1.coerce1, <vscale x 8 x i32> %v2, <vscale x 2 x i32> %v3)
+void call1(vint32m2_t v0, vint32m4x2_t v1, vint32m4_t v2, vint32m1_t v3) {}
+
+// CHECK-LLVM: void @call2(<vscale x 2 x i32> %v0.coerce0, <vscale x 2 x i32> %v0.coerce1, <vscale x 2 x i32> %v0.coerce2, <vscale x 8 x i32> %v1.coerce0, <vscale x 8 x i32> %v1.coerce1, <vscale x 8 x i32> %v2, ptr noundef %0)
+void call2(vint32m1x3_t v0, vint32m4x2_t v1, vint32m4_t v2, vint32m2_t v3) {}
+
+// CHECK-LLVM: void @call3(<vscale x 8 x i32> %v0.coerce0, <vscale x 8 x i32> %v0.coerce1, ptr noundef %0, <vscale x 8 x i32> %v2.coerce0, <vscale x 8 x i32> %v2.coerce1)
+void call3(vint32m4x2_t v0, vint32m1_t v1, vint32m4x2_t v2) {}
+
+// CHECK-LLVM: void @call4(<vscale x 16 x i32> %v0, ptr noundef %0, <vscale x 16 x i32> %v2)
+void call4(vint32m8_t v0, vint32m1_t v1, vint32m8_t v2) {}
+
+// CHECK-LLVM: void @call5(ptr noundef %0, <vscale x 16 x i32> %v1, ptr noundef %1, <vscale x 16 x i32> %v3)
+void call5(vint32m1_t v0, vint32m8_t v1, vint32m1_t v2, vint32m8_t v3) {}
+
+// CHECK-LLVM: void @call6(<vscale x 1 x i8> %v0, <vscale x 64 x i8> %v1, <vscale x 2 x i32> %v2, <vscale x 1 x i8> %v3)
+void call6(vint8mf8_t v0, vint8m8_t v1, vint32m1_t v2, vint8mf8_t v3) {}
+
+// CHECK-LLVM: void @call7(ptr noundef %0, <vscale x 64 x i8> %v1, <vscale x 16 x i32> %v2, ptr noundef %1)
+void call7(vint8mf8_t v0, vint8m8_t v1, vint32m8_t v2, vint8mf8_t v3) {}
Index: clang/lib/CodeGen/Targets/RISCV.cpp
===================================================================
--- clang/lib/CodeGen/Targets/RISCV.cpp
+++ clang/lib/CodeGen/Targets/RISCV.cpp
@@ -8,6 +8,7 @@
 
 #include "ABIInfoImpl.h"
 #include "TargetInfo.h"
+#include "llvm/TargetParser/RISCVTargetParser.h"
 
 using namespace clang;
 using namespace clang::CodeGen;
@@ -19,6 +20,9 @@
 namespace {
 class RISCVABIInfo : public DefaultABIInfo {
 private:
+  using ArgRegPair = std::pair<CGFunctionInfoArgInfo *, unsigned>;
+  using ArgRegPairs = llvm::SmallVector<ArgRegPair>;
+
   // Size of the integer ('x') registers in bits.
   unsigned XLen;
   // Size of the floating point ('f') registers in bits. Note that the target
@@ -27,11 +31,15 @@
   unsigned FLen;
   static const int NumArgGPRs = 8;
   static const int NumArgFPRs = 8;
+  static const int NumArgVRs = 16;
   bool detectFPCCEligibleStructHelper(QualType Ty, CharUnits CurOff,
                                       llvm::Type *&Field1Ty,
                                       CharUnits &Field1Off,
                                       llvm::Type *&Field2Ty,
                                       CharUnits &Field2Off) const;
+  unsigned
+  computeMaxAssignedRegs(ArgRegPairs &RVVArgRegPairs,
+                         std::vector<std::vector<unsigned>> &MaxRegs) const;
 
 public:
   RISCVABIInfo(CodeGen::CodeGenTypes &CGT, unsigned XLen, unsigned FLen)
@@ -41,6 +49,9 @@
   // non-virtual, but computeInfo is virtual, so we overload it.
   void computeInfo(CGFunctionInfo &FI) const override;
 
+  ArgRegPairs calculateRVVArgVRegs(CGFunctionInfo &FI) const;
+  void classifyRVVArgumentType(ArgRegPairs RVVArgRegPairs) const;
+
   ABIArgInfo classifyArgumentType(QualType Ty, bool IsFixed, int &ArgGPRsLeft,
                                   int &ArgFPRsLeft) const;
   ABIArgInfo classifyReturnType(QualType RetTy) const;
@@ -92,9 +103,98 @@
   int ArgNum = 0;
   for (auto &ArgInfo : FI.arguments()) {
     bool IsFixed = ArgNum < NumFixedArgs;
+    ArgNum++;
+
+    if (ArgInfo.type.getTypePtr()->isRVVType())
+      continue;
+
     ArgInfo.info =
         classifyArgumentType(ArgInfo.type, IsFixed, ArgGPRsLeft, ArgFPRsLeft);
-    ArgNum++;
+  }
+
+  classifyRVVArgumentType(calculateRVVArgVRegs(FI));
+}
+
+// Calculate total vregs each RVV argument needs.
+RISCVABIInfo::ArgRegPairs
+RISCVABIInfo::calculateRVVArgVRegs(CGFunctionInfo &FI) const {
+  RISCVABIInfo::ArgRegPairs RVVArgRegPairs;
+  for (auto &ArgInfo : FI.arguments()) {
+    const QualType &Ty = ArgInfo.type;
+    if (!Ty->isRVVType())
+      continue;
+
+    // Calculate the registers needed for each RVV type.
+    unsigned ElemSize = Ty->isRVVType(8, false)    ? 8
+                        : Ty->isRVVType(16, false) ? 16
+                        : Ty->isRVVType(32, false) ? 32
+                                                   : 64;
+    unsigned ElemCount = Ty->isRVVType(1)    ? 1
+                         : Ty->isRVVType(2)  ? 2
+                         : Ty->isRVVType(4)  ? 4
+                         : Ty->isRVVType(8)  ? 8
+                         : Ty->isRVVType(16) ? 16
+                         : Ty->isRVVType(32) ? 32
+                                             : 64;
+    unsigned RegsPerGroup =
+        std::max((ElemSize * ElemCount) / llvm::RISCV::RVVBitsPerBlock, 1U);
+
+    unsigned NumGroups = 1;
+    if (Ty->isRVVTupleType())
+      // Get the number of groups(NF) for each RVV type.
+      NumGroups = Ty->isRVVTupleType(2)   ? 2
+                  : Ty->isRVVTupleType(3) ? 3
+                  : Ty->isRVVTupleType(4) ? 4
+                  : Ty->isRVVTupleType(5) ? 5
+                  : Ty->isRVVTupleType(6) ? 6
+                  : Ty->isRVVTupleType(7) ? 7
+                                          : 8;
+
+    RVVArgRegPairs.push_back(
+        std::make_pair(&ArgInfo, NumGroups * RegsPerGroup));
+  }
+
+  return RVVArgRegPairs;
+}
+
+// Dynamic programming approach for finding the best vector register usage.
+// We can reduce the problem to a 0/1 knapsack problem with:
+//   1. capacity == NumArgVRs
+//   2. weight == value == total VRs needed
+unsigned RISCVABIInfo::computeMaxAssignedRegs(
+    ArgRegPairs &RVVArgRegPairs,
+    std::vector<std::vector<unsigned>> &MaxRegs) const {
+  for (unsigned i = 1; i <= RVVArgRegPairs.size(); ++i) {
+    unsigned RegsNeeded = RVVArgRegPairs[i - 1].second;
+    for (unsigned j = 1; j <= NumArgVRs; ++j)
+      if (j < RegsNeeded)
+        MaxRegs[i][j] = MaxRegs[i - 1][j];
+      else
+        MaxRegs[i][j] = std::max(RegsNeeded + MaxRegs[i - 1][j - RegsNeeded],
+                                 MaxRegs[i - 1][j]);
+  }
+
+  return MaxRegs[RVVArgRegPairs.size()][NumArgVRs];
+}
+
+void RISCVABIInfo::classifyRVVArgumentType(ArgRegPairs RVVArgRegPairs) const {
+  unsigned ToBeAssigned = RVVArgRegPairs.size();
+  std::vector<std::vector<unsigned>> MaxRegs(
+      ToBeAssigned + 1, std::vector<unsigned>(NumArgVRs + 1, 0));
+  computeMaxAssignedRegs(RVVArgRegPairs, MaxRegs);
+
+  // Walk back through MaxRegs to determine which arguments are passed by
+  // register.
+  unsigned RegsLeft = NumArgVRs;
+  while (ToBeAssigned--) {
+    auto *ArgInfo = RVVArgRegPairs[ToBeAssigned].first;
+    if (!RegsLeft ||
+        MaxRegs[ToBeAssigned + 1][RegsLeft] == MaxRegs[ToBeAssigned][RegsLeft])
+      ArgInfo->info = getNaturalAlignIndirect(ArgInfo->type, /*ByVal=*/false);
+    else {
+      ArgInfo->info = ABIArgInfo::getDirect();
+      RegsLeft -= RVVArgRegPairs[ToBeAssigned].second;
+    }
   }
 }
 
Index: clang/include/clang/AST/Type.h
===================================================================
--- clang/include/clang/AST/Type.h
+++ clang/include/clang/AST/Type.h
@@ -2332,6 +2332,9 @@
 
   bool isRVVType(unsigned Bitwidth, bool IsFloat) const;
 
+  bool isRVVTupleType() const;
+  bool isRVVTupleType(unsigned NumGroups) const;
+
   /// Return the implicit lifetime for this type, which must not be dependent.
   Qualifiers::ObjCLifetime getObjCARCImplicitLifetime() const;
 
@@ -7279,6 +7282,25 @@
   return Ret;
 }
 
+inline bool Type::isRVVTupleType() const {
+#define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned,   \
+                        IsFP)                                                  \
+  (isSpecificBuiltinType(BuiltinType::Id) && NF != 1) ||
+  return
+#include "clang/Basic/RISCVVTypes.def"
+      false; // end of boolean or operation.
+}
+
+inline bool Type::isRVVTupleType(unsigned NumGroups) const {
+  bool Ret = false;
+#define RVV_VECTOR_TYPE(Name, Id, SingletonId, NumEls, ElBits, NF, IsSigned,   \
+                        IsFP)                                                  \
+  if (NF == NumGroups)                                                         \
+    Ret |= isSpecificBuiltinType(BuiltinType::Id);
+#include "clang/Basic/RISCVVTypes.def"
+  return Ret;
+}
+
 inline bool Type::isTemplateTypeParmType() const {
   return isa<TemplateTypeParmType>(CanonicalType);
 }
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to