kito-cheng created this revision.
Herald added subscribers: sunshaoce, VincentWu, luke957, vkmr, frasercrmck, 
evandro, luismarques, apazos, sameer.abuasal, s.egerton, Jim, benna, psnobl, 
jocewei, PkmX, the_o, brucehoult, MartinMosbeck, rogfer01, edward-jones, 
zzheng, jrtc27, niosHD, sabuasal, simoncook, johnrusso, rbar, asb, arichardson.
Herald added a project: All.
kito-cheng requested review of this revision.
Herald added subscribers: cfe-commits, pcwang-thead, eopXD, MaskRay.
Herald added a project: clang.

This patch is preparation for D111617 <https://reviews.llvm.org/D111617>. It uses
class/struct/enum rather than
char/StringRef to represent internal information where possible, which
provides a more compact way to store that information and also makes it
easier to serialize/deserialize.

It also improves the readability of the code, e.g. TypeProfile::Vector
instead of "v".


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D124730

Files:
  clang/include/clang/Support/RISCVVIntrinsicUtils.h
  clang/lib/Support/RISCVVIntrinsicUtils.cpp
  clang/utils/TableGen/RISCVVEmitter.cpp

Index: clang/utils/TableGen/RISCVVEmitter.cpp
===================================================================
--- clang/utils/TableGen/RISCVVEmitter.cpp
+++ clang/utils/TableGen/RISCVVEmitter.cpp
@@ -48,7 +48,7 @@
   /// Emit all the information needed to map builtin -> LLVM IR intrinsic.
   void createCodeGen(raw_ostream &o);
 
-  std::string getSuffixStr(char Type, int Log2LMUL, StringRef Prototypes);
+  std::string getSuffixStr(BasicType Type, int Log2LMUL, StringRef Prototypes);
 
 private:
   /// Create all intrinsics and add them to \p Out
@@ -60,8 +60,9 @@
   /// or illegal set to avoid compute the  same config again. The result maybe
   /// have illegal RVVType.
   Optional<RVVTypes> computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
-                                  ArrayRef<std::string> PrototypeSeq);
-  Optional<RVVTypePtr> computeType(BasicType BT, int Log2LMUL, StringRef Proto);
+                                  ArrayRef<TypeProfile> PrototypeSeq);
+  Optional<RVVTypePtr> computeType(BasicType BT, int Log2LMUL,
+                                   TypeProfile Proto);
 
   /// Emit Acrh predecessor definitions and body, assume the element of Defs are
   /// sorted by extension.
@@ -76,11 +77,40 @@
   // Slice Prototypes string into sub prototype string and process each sub
   // prototype string individually in the Handler.
   void parsePrototypes(StringRef Prototypes,
-                       std::function<void(StringRef)> Handler);
+                       std::function<void(TypeProfile)> Handler);
 };
 
 } // namespace
 
+// Translate one character of a TableGen type range (the characters iterated
+// for TypeRange, and the "csil" loop that prints RVV header types) into the
+// corresponding BasicType:
+//   'c' -> Int8     's' -> Int16    'i' -> Int32    'l' -> Int64
+//   'x' -> Float16  'f' -> Float32  'd' -> Float64
+// Any other character yields BasicType::Unknown without a diagnostic here.
+// NOTE(review): confirm every TypeRange is limited to "csilxfd".
+static BasicType ParseBasicType(char c) {
+  switch (c) {
+  case 'c':
+    return BasicType::Int8;
+  case 's':
+    return BasicType::Int16;
+  case 'i':
+    return BasicType::Int32;
+  case 'l':
+    return BasicType::Int64;
+  case 'x':
+    return BasicType::Float16;
+  case 'f':
+    return BasicType::Float32;
+  case 'd':
+    return BasicType::Float64;
+
+  default:
+    return BasicType::Unknown;
+  }
+}
+
 void emitCodeGenSwitchBody(const RVVIntrinsic *RVVI, raw_ostream &OS) {
   if (!RVVI->getIRName().empty())
     OS << "  ID = Intrinsic::riscv_" + RVVI->getIRName() + ";\n";
@@ -202,24 +232,27 @@
   constexpr int Log2LMULs[] = {-3, -2, -1, 0, 1, 2, 3};
   // Print RVV boolean types.
   for (int Log2LMUL : Log2LMULs) {
-    auto T = computeType('c', Log2LMUL, "m");
+    auto T = computeType(BasicType::Int8, Log2LMUL, TypeProfile::Mask);
     if (T.hasValue())
       printType(T.getValue());
   }
   // Print RVV int/float types.
   for (char I : StringRef("csil")) {
+    BasicType BT = ParseBasicType(I);
     for (int Log2LMUL : Log2LMULs) {
-      auto T = computeType(I, Log2LMUL, "v");
+      auto T = computeType(BT, Log2LMUL, TypeProfile::Vector);
       if (T.hasValue()) {
         printType(T.getValue());
-        auto UT = computeType(I, Log2LMUL, "Uv");
+        auto UT = computeType(
+            BT, Log2LMUL,
+            TypeProfile(PrimitiveType::Vector, TypeModifier::UnsignedInteger));
         printType(UT.getValue());
       }
     }
   }
   OS << "#if defined(__riscv_zvfh)\n";
   for (int Log2LMUL : Log2LMULs) {
-    auto T = computeType('x', Log2LMUL, "v");
+    auto T = computeType(BasicType::Float16, Log2LMUL, TypeProfile::Vector);
     if (T.hasValue())
       printType(T.getValue());
   }
@@ -227,7 +260,7 @@
 
   OS << "#if defined(__riscv_f)\n";
   for (int Log2LMUL : Log2LMULs) {
-    auto T = computeType('f', Log2LMUL, "v");
+    auto T = computeType(BasicType::Float32, Log2LMUL, TypeProfile::Vector);
     if (T.hasValue())
       printType(T.getValue());
   }
@@ -235,7 +268,7 @@
 
   OS << "#if defined(__riscv_d)\n";
   for (int Log2LMUL : Log2LMULs) {
-    auto T = computeType('d', Log2LMUL, "v");
+    auto T = computeType(BasicType::Float64, Log2LMUL, TypeProfile::Vector);
     if (T.hasValue())
       printType(T.getValue());
   }
@@ -360,7 +393,7 @@
 }
 
 void RVVEmitter::parsePrototypes(StringRef Prototypes,
-                                 std::function<void(StringRef)> Handler) {
+                                 std::function<void(TypeProfile)> Handler) {
   const StringRef Primaries("evwqom0ztul");
   while (!Prototypes.empty()) {
     size_t Idx = 0;
@@ -370,15 +403,18 @@
       Idx = Prototypes.find_first_of(')');
     Idx = Prototypes.find_first_of(Primaries, Idx);
     assert(Idx != StringRef::npos);
-    Handler(Prototypes.slice(0, Idx + 1));
+    auto TP = TypeProfile::parseTypeProfile(Prototypes.slice(0, Idx + 1));
+    if (!TP)
+      PrintFatalError("Error during parsing prototype.");
+    Handler(*TP);
     Prototypes = Prototypes.drop_front(Idx + 1);
   }
 }
 
-std::string RVVEmitter::getSuffixStr(char Type, int Log2LMUL,
+std::string RVVEmitter::getSuffixStr(BasicType Type, int Log2LMUL,
                                      StringRef Prototypes) {
   SmallVector<std::string> SuffixStrs;
-  parsePrototypes(Prototypes, [&](StringRef Proto) {
+  parsePrototypes(Prototypes, [&](TypeProfile Proto) {
     auto T = computeType(Type, Log2LMUL, Proto);
     SuffixStrs.push_back(T.getValue()->getShortStr());
   });
@@ -419,13 +455,13 @@
 
     // Parse prototype and create a list of primitive type with transformers
     // (operand) in ProtoSeq. ProtoSeq[0] is output operand.
-    SmallVector<std::string> ProtoSeq;
-    parsePrototypes(Prototypes, [&ProtoSeq](StringRef Proto) {
-      ProtoSeq.push_back(Proto.str());
+    SmallVector<TypeProfile> ProtoSeq;
+    parsePrototypes(Prototypes, [&ProtoSeq](TypeProfile Proto) {
+      ProtoSeq.push_back(Proto);
     });
 
     // Compute Builtin types
-    SmallVector<std::string> ProtoMaskSeq = ProtoSeq;
+    SmallVector<TypeProfile> ProtoMaskSeq = ProtoSeq;
     if (HasMasked) {
       // If HasMaskedOffOperand, insert result type as first input operand.
       if (HasMaskedOffOperand) {
@@ -436,10 +472,10 @@
           // (void, op0 address, op1 address, ...)
           // to
           // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
+          TypeProfile MaskoffType = ProtoSeq[1];
+          MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer);
           for (unsigned I = 0; I < NF; ++I)
-            ProtoMaskSeq.insert(
-                ProtoMaskSeq.begin() + NF + 1,
-                ProtoSeq[1].substr(1)); // Use substr(1) to skip '*'
+            ProtoMaskSeq.insert(ProtoMaskSeq.begin() + NF + 1, MaskoffType);
         }
       }
       if (HasMaskedOffOperand && NF > 1) {
@@ -448,28 +484,29 @@
         // to
         // (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1,
         // ...)
-        ProtoMaskSeq.insert(ProtoMaskSeq.begin() + NF + 1, "m");
+        ProtoMaskSeq.insert(ProtoMaskSeq.begin() + NF + 1, TypeProfile::Mask);
       } else {
-        // If HasMasked, insert 'm' as first input operand.
-        ProtoMaskSeq.insert(ProtoMaskSeq.begin() + 1, "m");
+        // If HasMasked, insert TypeProfile::Mask as first input operand.
+        ProtoMaskSeq.insert(ProtoMaskSeq.begin() + 1, TypeProfile::Mask);
       }
     }
-    // If HasVL, append 'z' to last operand
+    // If HasVL, append TypeProfile::VL as the last operand.
     if (HasVL) {
-      ProtoSeq.push_back("z");
-      ProtoMaskSeq.push_back("z");
+      ProtoSeq.push_back(TypeProfile::VL);
+      ProtoMaskSeq.push_back(TypeProfile::VL);
     }
 
     // Create Intrinsics for each type and LMUL.
     for (char I : TypeRange) {
       for (int Log2LMUL : Log2LMULList) {
-        Optional<RVVTypes> Types = computeTypes(I, Log2LMUL, NF, ProtoSeq);
+        BasicType BT = ParseBasicType(I);
+        Optional<RVVTypes> Types = computeTypes(BT, Log2LMUL, NF, ProtoSeq);
         // Ignored to create new intrinsic if there are any illegal types.
         if (!Types.hasValue())
           continue;
 
-        auto SuffixStr = getSuffixStr(I, Log2LMUL, SuffixProto);
-        auto MangledSuffixStr = getSuffixStr(I, Log2LMUL, MangledSuffixProto);
+        auto SuffixStr = getSuffixStr(BT, Log2LMUL, SuffixProto);
+        auto MangledSuffixStr = getSuffixStr(BT, Log2LMUL, MangledSuffixProto);
         // Create a unmasked intrinsic
         Out.push_back(std::make_unique<RVVIntrinsic>(
             Name, SuffixStr, MangledName, MangledSuffixStr, IRName,
@@ -480,7 +517,7 @@
         if (HasMasked) {
           // Create a masked intrinsic
           Optional<RVVTypes> MaskTypes =
-              computeTypes(I, Log2LMUL, NF, ProtoMaskSeq);
+              computeTypes(BT, Log2LMUL, NF, ProtoMaskSeq);
           Out.push_back(std::make_unique<RVVIntrinsic>(
               Name, SuffixStr, MangledName, MangledSuffixStr, MaskedIRName,
               /*IsMasked=*/true, HasMaskedOffOperand, HasVL, MaskedPolicy,
@@ -503,13 +540,13 @@
 
 Optional<RVVTypes>
 RVVEmitter::computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
-                         ArrayRef<std::string> PrototypeSeq) {
+                         ArrayRef<TypeProfile> PrototypeSeq) {
   // LMUL x NF must be less than or equal to 8.
   if ((Log2LMUL >= 1) && (1 << Log2LMUL) * NF > 8)
     return llvm::None;
 
   RVVTypes Types;
-  for (const std::string &Proto : PrototypeSeq) {
+  for (const TypeProfile &Proto : PrototypeSeq) {
     auto T = computeType(BT, Log2LMUL, Proto);
     if (!T.hasValue())
       return llvm::None;
@@ -520,8 +557,10 @@
 }
 
 Optional<RVVTypePtr> RVVEmitter::computeType(BasicType BT, int Log2LMUL,
-                                             StringRef Proto) {
-  std::string Idx = Twine(Twine(BT) + Twine(Log2LMUL) + Proto).str();
+                                             TypeProfile Proto) {
+  std::string Idx =
+      Twine(Twine(static_cast<int>(BT)) + Twine(Log2LMUL) + Proto.IndexStr())
+          .str();
   // Search first
   auto It = LegalTypes.find(Idx);
   if (It != LegalTypes.end())
Index: clang/lib/Support/RISCVVIntrinsicUtils.cpp
===================================================================
--- clang/lib/Support/RISCVVIntrinsicUtils.cpp
+++ clang/lib/Support/RISCVVIntrinsicUtils.cpp
@@ -22,6 +22,9 @@
 namespace clang {
 namespace RISCV {
 
+const TypeProfile TypeProfile::Mask = TypeProfile(PrimitiveType::MaskVector);
+const TypeProfile TypeProfile::VL = TypeProfile(PrimitiveType::SizeT);
+const TypeProfile TypeProfile::Vector = TypeProfile(PrimitiveType::Vector);
 //===----------------------------------------------------------------------===//
 // Type implementation
 //===----------------------------------------------------------------------===//
@@ -70,7 +73,7 @@
   return *this;
 }
 
-RVVType::RVVType(BasicType BT, int Log2LMUL, StringRef prototype)
+RVVType::RVVType(BasicType BT, int Log2LMUL, const TypeProfile &prototype)
     : BT(BT), LMUL(LMULType(Log2LMUL)) {
   applyBasicType();
   applyModifier(prototype);
@@ -326,31 +329,31 @@
 
 void RVVType::applyBasicType() {
   switch (BT) {
-  case 'c':
+  case BasicType::Int8:
     ElementBitwidth = 8;
     ScalarType = ScalarTypeKind::SignedInteger;
     break;
-  case 's':
+  case BasicType::Int16:
     ElementBitwidth = 16;
     ScalarType = ScalarTypeKind::SignedInteger;
     break;
-  case 'i':
+  case BasicType::Int32:
     ElementBitwidth = 32;
     ScalarType = ScalarTypeKind::SignedInteger;
     break;
-  case 'l':
+  case BasicType::Int64:
     ElementBitwidth = 64;
     ScalarType = ScalarTypeKind::SignedInteger;
     break;
-  case 'x':
+  case BasicType::Float16:
     ElementBitwidth = 16;
     ScalarType = ScalarTypeKind::Float;
     break;
-  case 'f':
+  case BasicType::Float32:
     ElementBitwidth = 32;
     ScalarType = ScalarTypeKind::Float;
     break;
-  case 'd':
+  case BasicType::Float64:
     ElementBitwidth = 64;
     ScalarType = ScalarTypeKind::Float;
     break;
@@ -360,160 +363,417 @@
   assert(ElementBitwidth != 0 && "Bad element bitwidth!");
 }
 
-void RVVType::applyModifier(StringRef Transformer) {
-  if (Transformer.empty())
-    return;
+Optional<TypeProfile>
+TypeProfile::parseTypeProfile(llvm::StringRef TypeProfileStr) {
+  TypeProfile TP;
+  PrimitiveType PT = PrimitiveType::Invalid;
+  if (TypeProfileStr.empty())
+    return TP;
   // Handle primitive type transformer
-  auto PType = Transformer.back();
+  auto PType = TypeProfileStr.back();
   switch (PType) {
   case 'e':
-    Scale = 0;
+    PT = PrimitiveType::Scalar;
     break;
   case 'v':
-    Scale = LMUL.getScale(ElementBitwidth);
+    PT = PrimitiveType::Vector;
     break;
   case 'w':
-    ElementBitwidth *= 2;
-    LMUL *= 2;
-    Scale = LMUL.getScale(ElementBitwidth);
+    PT = PrimitiveType::Widening2XVector;
     break;
   case 'q':
-    ElementBitwidth *= 4;
-    LMUL *= 4;
-    Scale = LMUL.getScale(ElementBitwidth);
+    PT = PrimitiveType::Widening4XVector;
     break;
   case 'o':
-    ElementBitwidth *= 8;
-    LMUL *= 8;
-    Scale = LMUL.getScale(ElementBitwidth);
+    PT = PrimitiveType::Widening8XVector;
     break;
   case 'm':
-    ScalarType = ScalarTypeKind::Boolean;
-    Scale = LMUL.getScale(ElementBitwidth);
-    ElementBitwidth = 1;
+    PT = PrimitiveType::MaskVector;
     break;
   case '0':
-    ScalarType = ScalarTypeKind::Void;
+    PT = PrimitiveType::Void;
     break;
   case 'z':
-    ScalarType = ScalarTypeKind::Size_t;
+    PT = PrimitiveType::SizeT;
     break;
   case 't':
-    ScalarType = ScalarTypeKind::Ptrdiff_t;
+    PT = PrimitiveType::Ptrdiff;
     break;
   case 'u':
-    ScalarType = ScalarTypeKind::UnsignedLong;
+    PT = PrimitiveType::UnsignedLong;
     break;
   case 'l':
-    ScalarType = ScalarTypeKind::SignedLong;
+    PT = PrimitiveType::SignedLong;
     break;
   default:
     llvm_unreachable("Illegal primitive type transformers!");
   }
-  Transformer = Transformer.drop_back();
+  TP.PT = static_cast<uint8_t>(PT);
+  TypeProfileStr = TypeProfileStr.drop_back();
 
   // Extract and compute complex type transformer. It can only appear one time.
-  if (Transformer.startswith("(")) {
-    size_t Idx = Transformer.find(')');
+  if (TypeProfileStr.startswith("(")) {
+    size_t Idx = TypeProfileStr.find(')');
     assert(Idx != StringRef::npos);
-    StringRef ComplexType = Transformer.slice(1, Idx);
-    Transformer = Transformer.drop_front(Idx + 1);
-    assert(!Transformer.contains('(') &&
+    StringRef ComplexType = TypeProfileStr.slice(1, Idx);
+    TypeProfileStr = TypeProfileStr.drop_front(Idx + 1);
+    assert(!TypeProfileStr.contains('(') &&
            "Only allow one complex type transformer");
 
-    auto UpdateAndCheckComplexProto = [&]() {
-      Scale = LMUL.getScale(ElementBitwidth);
-      const StringRef VectorPrototypes("vwqom");
-      if (!VectorPrototypes.contains(PType))
-        llvm_unreachable("Complex type transformer only supports vector type!");
-      if (Transformer.find_first_of("PCKWS") != StringRef::npos)
-        llvm_unreachable(
-            "Illegal type transformer for Complex type transformer");
-    };
-    auto ComputeFixedLog2LMUL =
-        [&](StringRef Value,
-            std::function<bool(const int32_t &, const int32_t &)> Compare) {
-          int32_t Log2LMUL;
-          Value.getAsInteger(10, Log2LMUL);
-          if (!Compare(Log2LMUL, LMUL.Log2LMUL)) {
-            ScalarType = Invalid;
-            return false;
-          }
-          // Update new LMUL
-          LMUL = LMULType(Log2LMUL);
-          UpdateAndCheckComplexProto();
-          return true;
-        };
     auto ComplexTT = ComplexType.split(":");
+    VectorTypeModifier VTM = VectorTypeModifier::NoModifier;
     if (ComplexTT.first == "Log2EEW") {
       uint32_t Log2EEW;
-      ComplexTT.second.getAsInteger(10, Log2EEW);
-      // update new elmul = (eew/sew) * lmul
-      LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth));
-      // update new eew
-      ElementBitwidth = 1 << Log2EEW;
-      ScalarType = ScalarTypeKind::SignedInteger;
-      UpdateAndCheckComplexProto();
+      if (ComplexTT.second.getAsInteger(10, Log2EEW)) {
+        llvm_unreachable("Invalid Log2EEW value!");
+        return None;
+      }
+      switch (Log2EEW) {
+      case 3:
+        VTM = VectorTypeModifier::Log2EEW3;
+        break;
+      case 4:
+        VTM = VectorTypeModifier::Log2EEW4;
+        break;
+      case 5:
+        VTM = VectorTypeModifier::Log2EEW5;
+        break;
+      case 6:
+        VTM = VectorTypeModifier::Log2EEW6;
+        break;
+      default:
+        llvm_unreachable("Invalid Log2EEW value, should be [3-6]");
+        return None;
+      }
     } else if (ComplexTT.first == "FixedSEW") {
       uint32_t NewSEW;
-      ComplexTT.second.getAsInteger(10, NewSEW);
-      // Set invalid type if src and dst SEW are same.
-      if (ElementBitwidth == NewSEW) {
-        ScalarType = Invalid;
-        return;
+      if (ComplexTT.second.getAsInteger(10, NewSEW)) {
+        llvm_unreachable("Invalid FixedSEW value!");
+        return None;
+      }
+      switch (NewSEW) {
+      case 8:
+        VTM = VectorTypeModifier::FixedSEW8;
+        break;
+      case 16:
+        VTM = VectorTypeModifier::FixedSEW16;
+        break;
+      case 32:
+        VTM = VectorTypeModifier::FixedSEW32;
+        break;
+      case 64:
+        VTM = VectorTypeModifier::FixedSEW64;
+        break;
+      default:
+        llvm_unreachable("Invalid FixedSEW value, should be 8, 16, 32 or 64");
+        return None;
       }
-      // Update new SEW
-      ElementBitwidth = NewSEW;
-      UpdateAndCheckComplexProto();
     } else if (ComplexTT.first == "LFixedLog2LMUL") {
-      // New LMUL should be larger than old
-      if (!ComputeFixedLog2LMUL(ComplexTT.second, std::greater<int32_t>()))
-        return;
+      int32_t Log2LMUL;
+      if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
+        llvm_unreachable("Invalid LFixedLog2LMUL value!");
+        return None;
+      }
+      switch (Log2LMUL) {
+      case -3:
+        VTM = VectorTypeModifier::LFixedLog2LMULN3;
+        break;
+      case -2:
+        VTM = VectorTypeModifier::LFixedLog2LMULN2;
+        break;
+      case -1:
+        VTM = VectorTypeModifier::LFixedLog2LMULN1;
+        break;
+      case 0:
+        VTM = VectorTypeModifier::LFixedLog2LMUL0;
+        break;
+      case 1:
+        VTM = VectorTypeModifier::LFixedLog2LMUL1;
+        break;
+      case 2:
+        VTM = VectorTypeModifier::LFixedLog2LMUL2;
+        break;
+      case 3:
+        VTM = VectorTypeModifier::LFixedLog2LMUL3;
+        break;
+      default:
+        llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
+        return None;
+      }
     } else if (ComplexTT.first == "SFixedLog2LMUL") {
-      // New LMUL should be smaller than old
-      if (!ComputeFixedLog2LMUL(ComplexTT.second, std::less<int32_t>()))
-        return;
+      int32_t Log2LMUL;
+      if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
+        llvm_unreachable("Invalid SFixedLog2LMUL value!");
+        return None;
+      }
+      switch (Log2LMUL) {
+      case -3:
+        VTM = VectorTypeModifier::SFixedLog2LMULN3;
+        break;
+      case -2:
+        VTM = VectorTypeModifier::SFixedLog2LMULN2;
+        break;
+      case -1:
+        VTM = VectorTypeModifier::SFixedLog2LMULN1;
+        break;
+      case 0:
+        VTM = VectorTypeModifier::SFixedLog2LMUL0;
+        break;
+      case 1:
+        VTM = VectorTypeModifier::SFixedLog2LMUL1;
+        break;
+      case 2:
+        VTM = VectorTypeModifier::SFixedLog2LMUL2;
+        break;
+      case 3:
+        VTM = VectorTypeModifier::SFixedLog2LMUL3;
+        break;
+      default:
+        llvm_unreachable("Invalid SFixedLog2LMUL value, should be [-3, 3]");
+        return None;
+      }
+
     } else {
       llvm_unreachable("Illegal complex type transformers!");
     }
+    TP.VTM = static_cast<uint8_t>(VTM);
   }
 
   // Compute the remain type transformers
-  for (char I : Transformer) {
+  TypeModifier TM = TypeModifier::NoModifier;
+  for (char I : TypeProfileStr) {
     switch (I) {
     case 'P':
-      if (IsConstant)
+      if ((TM & TypeModifier::Const) == TypeModifier::Const)
         llvm_unreachable("'P' transformer cannot be used after 'C'");
-      if (IsPointer)
+      if ((TM & TypeModifier::Pointer) == TypeModifier::Pointer)
         llvm_unreachable("'P' transformer cannot be used twice");
-      IsPointer = true;
+      TM |= TypeModifier::Pointer;
       break;
     case 'C':
-      if (IsConstant)
-        llvm_unreachable("'C' transformer cannot be used twice");
-      IsConstant = true;
+      TM |= TypeModifier::Const;
       break;
     case 'K':
-      IsImmediate = true;
+      TM |= TypeModifier::Immediate;
       break;
     case 'U':
-      ScalarType = ScalarTypeKind::UnsignedInteger;
+      TM |= TypeModifier::UnsignedInteger;
       break;
     case 'I':
-      ScalarType = ScalarTypeKind::SignedInteger;
+      TM |= TypeModifier::SignedInteger;
       break;
     case 'F':
-      ScalarType = ScalarTypeKind::Float;
+      TM |= TypeModifier::Float;
       break;
     case 'S':
+      TM |= TypeModifier::LMUL1;
+      break;
+    default:
+      llvm_unreachable("Illegal non-primitive type transformer!");
+    }
+  }
+  TP.TM = static_cast<uint8_t>(TM);
+
+  return TP;
+}
+
+void RVVType::applyModifier(const TypeProfile &Transformer) {
+  // Handle primitive type transformer
+  switch (static_cast<PrimitiveType>(Transformer.PT)) {
+  case PrimitiveType::Scalar:
+    Scale = 0;
+    break;
+  case PrimitiveType::Vector:
+    Scale = LMUL.getScale(ElementBitwidth);
+    break;
+  case PrimitiveType::Widening2XVector:
+    ElementBitwidth *= 2;
+    LMUL *= 2;
+    Scale = LMUL.getScale(ElementBitwidth);
+    break;
+  case PrimitiveType::Widening4XVector:
+    ElementBitwidth *= 4;
+    LMUL *= 4;
+    Scale = LMUL.getScale(ElementBitwidth);
+    break;
+  case PrimitiveType::Widening8XVector:
+    ElementBitwidth *= 8;
+    LMUL *= 8;
+    Scale = LMUL.getScale(ElementBitwidth);
+    break;
+  case PrimitiveType::MaskVector:
+    ScalarType = ScalarTypeKind::Boolean;
+    Scale = LMUL.getScale(ElementBitwidth);
+    ElementBitwidth = 1;
+    break;
+  case PrimitiveType::Void:
+    ScalarType = ScalarTypeKind::Void;
+    break;
+  case PrimitiveType::SizeT:
+    ScalarType = ScalarTypeKind::Size_t;
+    break;
+  case PrimitiveType::Ptrdiff:
+    ScalarType = ScalarTypeKind::Ptrdiff_t;
+    break;
+  case PrimitiveType::UnsignedLong:
+    ScalarType = ScalarTypeKind::UnsignedLong;
+    break;
+  case PrimitiveType::SignedLong:
+    ScalarType = ScalarTypeKind::SignedLong;
+    break;
+  case PrimitiveType::Invalid:
+    ScalarType = ScalarTypeKind::Invalid;
+    return;
+  default:
+    llvm_unreachable("Illegal primitive type transformers!");
+  }
+
+  switch (static_cast<VectorTypeModifier>(Transformer.VTM)) {
+  case VectorTypeModifier::Log2EEW3:
+    applyLog2EEW(3);
+    break;
+  case VectorTypeModifier::Log2EEW4:
+    applyLog2EEW(4);
+    break;
+  case VectorTypeModifier::Log2EEW5:
+    applyLog2EEW(5);
+    break;
+  case VectorTypeModifier::Log2EEW6:
+    applyLog2EEW(6);
+    break;
+  case VectorTypeModifier::FixedSEW8:
+    applyFixedSEW(8);
+    break;
+  case VectorTypeModifier::FixedSEW16:
+    applyFixedSEW(16);
+    break;
+  case VectorTypeModifier::FixedSEW32:
+    applyFixedSEW(32);
+    break;
+  case VectorTypeModifier::FixedSEW64:
+    applyFixedSEW(64);
+    break;
+  case VectorTypeModifier::LFixedLog2LMULN3:
+    applyFixedLog2LMUL(-3, /* LargerThan= */ true);
+    break;
+  case VectorTypeModifier::LFixedLog2LMULN2:
+    applyFixedLog2LMUL(-2, /* LargerThan= */ true);
+    break;
+  case VectorTypeModifier::LFixedLog2LMULN1:
+    applyFixedLog2LMUL(-1, /* LargerThan= */ true);
+    break;
+  case VectorTypeModifier::LFixedLog2LMUL0:
+    applyFixedLog2LMUL(0, /* LargerThan= */ true);
+    break;
+  case VectorTypeModifier::LFixedLog2LMUL1:
+    applyFixedLog2LMUL(1, /* LargerThan= */ true);
+    break;
+  case VectorTypeModifier::LFixedLog2LMUL2:
+    applyFixedLog2LMUL(2, /* LargerThan= */ true);
+    break;
+  case VectorTypeModifier::LFixedLog2LMUL3:
+    applyFixedLog2LMUL(3, /* LargerThan= */ true);
+    break;
+  case VectorTypeModifier::SFixedLog2LMULN3:
+    applyFixedLog2LMUL(-3, /* LargerThan= */ false);
+    break;
+  case VectorTypeModifier::SFixedLog2LMULN2:
+    applyFixedLog2LMUL(-2, /* LargerThan= */ false);
+    break;
+  case VectorTypeModifier::SFixedLog2LMULN1:
+    applyFixedLog2LMUL(-1, /* LargerThan= */ false);
+    break;
+  case VectorTypeModifier::SFixedLog2LMUL0:
+    applyFixedLog2LMUL(0, /* LargerThan= */ false);
+    break;
+  case VectorTypeModifier::SFixedLog2LMUL1:
+    applyFixedLog2LMUL(1, /* LargerThan= */ false);
+    break;
+  case VectorTypeModifier::SFixedLog2LMUL2:
+    applyFixedLog2LMUL(2, /* LargerThan= */ false);
+    break;
+  case VectorTypeModifier::SFixedLog2LMUL3:
+    applyFixedLog2LMUL(3, /* LargerThan= */ false);
+    break;
+  case VectorTypeModifier::NoModifier:
+    break;
+  default:
+    llvm_unreachable("Illegal vector type modifier!");
+  }
+
+  for (unsigned TypeModifierMaskShift = 0;
+       TypeModifierMaskShift <= static_cast<unsigned>(TypeModifier::MaxOffset);
+       ++TypeModifierMaskShift) {
+    unsigned TypeModifierMask = 1 << TypeModifierMaskShift;
+    if ((static_cast<unsigned>(Transformer.TM) & TypeModifierMask) !=
+        TypeModifierMask)
+      continue;
+    switch (static_cast<TypeModifier>(TypeModifierMask)) {
+    case TypeModifier::Pointer:
+      IsPointer = true;
+      break;
+    case TypeModifier::Const:
+      IsConstant = true;
+      break;
+    case TypeModifier::Immediate:
+      IsImmediate = true;
+      IsConstant = true;
+      break;
+    case TypeModifier::UnsignedInteger:
+      ScalarType = ScalarTypeKind::UnsignedInteger;
+      break;
+    case TypeModifier::SignedInteger:
+      ScalarType = ScalarTypeKind::SignedInteger;
+      break;
+    case TypeModifier::Float:
+      ScalarType = ScalarTypeKind::Float;
+      break;
+    case TypeModifier::LMUL1:
       LMUL = LMULType(0);
       // Update ElementBitwidth need to update Scale too.
       Scale = LMUL.getScale(ElementBitwidth);
       break;
     default:
-      llvm_unreachable("Illegal non-primitive type transformer!");
+      llvm_unreachable("Unknown type modifier mask!");
+    }
+  }
+}
+
+void RVVType::applyLog2EEW(unsigned Log2EEW) {
+  // update new elmul = (eew/sew) * lmul
+  LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth));
+  // update new eew
+  ElementBitwidth = 1 << Log2EEW;
+  ScalarType = ScalarTypeKind::SignedInteger;
+  Scale = LMUL.getScale(ElementBitwidth);
+}
+
+void RVVType::applyFixedSEW(unsigned NewSEW) {
+  // Set invalid type if src and dst SEW are same.
+  if (ElementBitwidth == NewSEW) {
+    ScalarType = ScalarTypeKind::Invalid;
+    return;
+  }
+  // Update new SEW
+  ElementBitwidth = NewSEW;
+  Scale = LMUL.getScale(ElementBitwidth);
+}
+
+void RVVType::applyFixedLog2LMUL(int Log2LMUL, bool LargerThan) {
+  if (LargerThan) {
+    if (Log2LMUL <= LMUL.Log2LMUL) { // Require strictly larger LMUL.
+      ScalarType = ScalarTypeKind::Invalid;
+      return;
+    }
+  } else {
+    if (Log2LMUL >= LMUL.Log2LMUL) { // Require strictly smaller LMUL.
+      ScalarType = ScalarTypeKind::Invalid;
+      return;
     }
   }
+  // LMUL moved strictly in the requested direction: adopt it, refresh Scale.
+  LMUL = LMULType(Log2LMUL);
+  Scale = LMUL.getScale(ElementBitwidth);
 }
 
 //===----------------------------------------------------------------------===//
Index: clang/include/clang/Support/RISCVVIntrinsicUtils.h
===================================================================
--- clang/include/clang/Support/RISCVVIntrinsicUtils.h
+++ clang/include/clang/Support/RISCVVIntrinsicUtils.h
@@ -9,6 +9,7 @@
 #ifndef CLANG_SUPPORT_RISCVVINTRINSICUTILS_H
 #define CLANG_SUPPORT_RISCVVINTRINSICUTILS_H
 
+#include "llvm/ADT/BitmaskEnum.h"
 #include "llvm/ADT/Optional.h"
 #include "llvm/ADT/StringRef.h"
 #include <cstdint>
@@ -18,9 +19,126 @@
 namespace clang {
 namespace RISCV {
 
-using BasicType = char;
 using VScaleVal = llvm::Optional<unsigned>;
 
+// Modifier for vector type.
+enum class VectorTypeModifier : uint8_t {
+  NoModifier,
+  Log2EEW3,
+  Log2EEW4,
+  Log2EEW5,
+  Log2EEW6,
+  FixedSEW8,
+  FixedSEW16,
+  FixedSEW32,
+  FixedSEW64,
+  LFixedLog2LMULN3,
+  LFixedLog2LMULN2,
+  LFixedLog2LMULN1,
+  LFixedLog2LMUL0,
+  LFixedLog2LMUL1,
+  LFixedLog2LMUL2,
+  LFixedLog2LMUL3,
+  SFixedLog2LMULN3,
+  SFixedLog2LMULN2,
+  SFixedLog2LMULN1,
+  SFixedLog2LMUL0,
+  SFixedLog2LMUL1,
+  SFixedLog2LMUL2,
+  SFixedLog2LMUL3,
+};
+
+// Like BasicType, but describes what kind of type an operand is relative to
+// the basic vector type; used to compute type info of arguments.
+enum class PrimitiveType : uint8_t {
+  Invalid,
+  Scalar,
+  Vector,
+  Widening2XVector,
+  Widening4XVector,
+  Widening8XVector,
+  MaskVector,
+  Void,
+  SizeT,
+  Ptrdiff,
+  UnsignedLong,
+  SignedLong,
+};
+
+// Modifier for type, used for both scalar and vector types.
+enum class TypeModifier : uint8_t {
+  NoModifier = 0,
+  Pointer = 1 << 0,
+  Const = 1 << 1,
+  Immediate = 1 << 2,
+  UnsignedInteger = 1 << 3,
+  SignedInteger = 1 << 4,
+  Float = 1 << 5,
+  LMUL1 = 1 << 6,
+  MaxOffset = 6,
+  LLVM_MARK_AS_BITMASK_ENUM(LMUL1),
+};
+
+// TypeProfile is used to compute type info of arguments or return value.
+struct TypeProfile {
+  constexpr TypeProfile() = default;
+  constexpr TypeProfile(PrimitiveType PT) : PT(static_cast<uint8_t>(PT)) {}
+  constexpr TypeProfile(PrimitiveType PT, TypeModifier TM)
+      : PT(static_cast<uint8_t>(PT)), TM(static_cast<uint8_t>(TM)) {}
+  constexpr TypeProfile(uint8_t PT, uint8_t VTM, uint8_t TM)
+      : PT(PT), VTM(VTM), TM(TM) {}
+
+  uint8_t PT = static_cast<uint8_t>(PrimitiveType::Invalid);
+  uint8_t VTM = static_cast<uint8_t>(VectorTypeModifier::NoModifier);
+  uint8_t TM = static_cast<uint8_t>(TypeModifier::NoModifier);
+
+  std::string IndexStr() const {
+    return std::to_string(PT) + "_" + std::to_string(VTM) + "_" +
+           std::to_string(TM);
+  }
+
+  bool operator!=(const TypeProfile &TP) const {
+    return TP.PT != PT || TP.VTM != VTM || TP.TM != TM;
+  }
+  bool operator>(const TypeProfile &TP) const { // Orders by PT, VTM, then TM.
+    return PT != TP.PT ? PT > TP.PT : VTM != TP.VTM ? VTM > TP.VTM : TM > TP.TM;
+  }
+
+  static const TypeProfile Mask;
+  static const TypeProfile Vector;
+  static const TypeProfile VL;
+  static llvm::Optional<TypeProfile>
+  parseTypeProfile(llvm::StringRef PrototypeStr);
+};
+
+// Basic type of vector type.
+enum class BasicType : uint8_t {
+  Unknown = 0,
+  Int8 = 1 << 0,
+  Int16 = 1 << 1,
+  Int32 = 1 << 2,
+  Int64 = 1 << 3,
+  Float16 = 1 << 4,
+  Float32 = 1 << 5,
+  Float64 = 1 << 6,
+  MaxOffset = 6,
+  LLVM_MARK_AS_BITMASK_ENUM(Float64),
+};
+
+// Type of vector type.
+enum ScalarTypeKind : uint8_t {
+  Void,
+  Size_t,
+  Ptrdiff_t,
+  UnsignedLong,
+  SignedLong,
+  Boolean,
+  SignedInteger,
+  UnsignedInteger,
+  Float,
+  Invalid,
+};
+
 // Exponential LMUL
 struct LMULType {
   int Log2LMUL;
@@ -34,18 +152,6 @@
 
 // This class is compact representation of a valid and invalid RVVType.
 class RVVType {
-  enum ScalarTypeKind : uint32_t {
-    Void,
-    Size_t,
-    Ptrdiff_t,
-    UnsignedLong,
-    SignedLong,
-    Boolean,
-    SignedInteger,
-    UnsignedInteger,
-    Float,
-    Invalid,
-  };
   BasicType BT;
   ScalarTypeKind ScalarType = Invalid;
   LMULType LMUL;
@@ -64,8 +170,8 @@
   std::string ShortStr;
 
 public:
-  RVVType() : RVVType(BasicType(), 0, llvm::StringRef()) {}
-  RVVType(BasicType BT, int Log2LMUL, llvm::StringRef prototype);
+  RVVType() : BT(BasicType::Unknown), LMUL(0), Valid(false) {}
+  RVVType(BasicType BT, int Log2LMUL, const TypeProfile &Profile);
 
   // Return the string representation of a type, which is an encoded string for
   // passing to the BUILTIN() macro in Builtins.def.
@@ -114,7 +220,11 @@
 
   // Applies a prototype modifier to the current type. The result maybe an
   // invalid type.
-  void applyModifier(llvm::StringRef prototype);
+  void applyModifier(const TypeProfile &prototype);
+
+  void applyLog2EEW(unsigned Log2EEW);
+  void applyFixedSEW(unsigned NewSEW);
+  void applyFixedLog2LMUL(int Log2LMUL, bool LargerThan);
 
   // Compute and record a string for legal type.
   void initBuiltinStr();
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to