kito-cheng updated this revision to Diff 426276.
kito-cheng added a comment.

Changes:

- Extract more utils functions to RISCVVIntrinsicUtils


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D124730/new/

https://reviews.llvm.org/D124730

Files:
  clang/include/clang/Support/RISCVVIntrinsicUtils.h
  clang/lib/Support/RISCVVIntrinsicUtils.cpp
  clang/utils/TableGen/RISCVVEmitter.cpp

Index: clang/utils/TableGen/RISCVVEmitter.cpp
===================================================================
--- clang/utils/TableGen/RISCVVEmitter.cpp
+++ clang/utils/TableGen/RISCVVEmitter.cpp
@@ -32,9 +32,6 @@
 class RVVEmitter {
 private:
   RecordKeeper &Records;
-  // Concat BasicType, LMUL and Proto as key
-  StringMap<RVVType> LegalTypes;
-  StringSet<> IllegalTypes;
 
 public:
   RVVEmitter(RecordKeeper &R) : Records(R) {}
@@ -48,20 +45,11 @@
   /// Emit all the information needed to map builtin -> LLVM IR intrinsic.
   void createCodeGen(raw_ostream &o);
 
-  std::string getSuffixStr(char Type, int Log2LMUL, StringRef Prototypes);
-
 private:
   /// Create all intrinsics and add them to \p Out
   void createRVVIntrinsics(std::vector<std::unique_ptr<RVVIntrinsic>> &Out);
   /// Print HeaderCode in RVVHeader Record to \p Out
   void printHeaderCode(raw_ostream &OS);
-  /// Compute output and input types by applying different config (basic type
-  /// and LMUL with type transformers). It also record result of type in legal
-  /// or illegal set to avoid compute the  same config again. The result maybe
-  /// have illegal RVVType.
-  Optional<RVVTypes> computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
-                                  ArrayRef<std::string> PrototypeSeq);
-  Optional<RVVTypePtr> computeType(BasicType BT, int Log2LMUL, StringRef Proto);
 
   /// Emit Acrh predecessor definitions and body, assume the element of Defs are
   /// sorted by extension.
@@ -73,14 +61,39 @@
   // non-empty string.
   bool emitMacroRestrictionStr(RISCVPredefinedMacroT PredefinedMacros,
                                raw_ostream &o);
-  // Slice Prototypes string into sub prototype string and process each sub
-  // prototype string individually in the Handler.
-  void parsePrototypes(StringRef Prototypes,
-                       std::function<void(StringRef)> Handler);
 };
 
 } // namespace
 
+/// Map a single type character to the BasicType it denotes.
+///
+/// Recognized characters: 'c' (Int8), 's' (Int16), 'i' (Int32), 'l' (Int64),
+/// 'x' (Float16), 'f' (Float32) and 'd' (Float64).
+///
+/// \returns the matching BasicType, or BasicType::Unknown for any other
+/// character.
+static BasicType ParseBasicType(char c) {
+  switch (c) {
+  case 'c':
+    return BasicType::Int8;
+  case 's':
+    return BasicType::Int16;
+  case 'i':
+    return BasicType::Int32;
+  case 'l':
+    return BasicType::Int64;
+  case 'x':
+    return BasicType::Float16;
+  case 'f':
+    return BasicType::Float32;
+  case 'd':
+    return BasicType::Float64;
+  default:
+    // Any character outside the set above has no basic type.
+    return BasicType::Unknown;
+  }
+}
+
 void emitCodeGenSwitchBody(const RVVIntrinsic *RVVI, raw_ostream &OS) {
   if (!RVVI->getIRName().empty())
     OS << "  ID = Intrinsic::riscv_" + RVVI->getIRName() + ";\n";
@@ -202,24 +215,28 @@
   constexpr int Log2LMULs[] = {-3, -2, -1, 0, 1, 2, 3};
   // Print RVV boolean types.
   for (int Log2LMUL : Log2LMULs) {
-    auto T = computeType('c', Log2LMUL, "m");
+    auto T = RVVType::computeType(BasicType::Int8, Log2LMUL, TypeProfile::Mask);
     if (T.hasValue())
       printType(T.getValue());
   }
   // Print RVV int/float types.
   for (char I : StringRef("csil")) {
+    BasicType BT = ParseBasicType(I);
     for (int Log2LMUL : Log2LMULs) {
-      auto T = computeType(I, Log2LMUL, "v");
+      auto T = RVVType::computeType(BT, Log2LMUL, TypeProfile::Vector);
       if (T.hasValue()) {
         printType(T.getValue());
-        auto UT = computeType(I, Log2LMUL, "Uv");
+        auto UT = RVVType::computeType(
+            BT, Log2LMUL,
+            TypeProfile(PrimitiveType::Vector, TypeModifier::UnsignedInteger));
         printType(UT.getValue());
       }
     }
   }
   OS << "#if defined(__riscv_zvfh)\n";
   for (int Log2LMUL : Log2LMULs) {
-    auto T = computeType('x', Log2LMUL, "v");
+    auto T =
+        RVVType::computeType(BasicType::Float16, Log2LMUL, TypeProfile::Vector);
     if (T.hasValue())
       printType(T.getValue());
   }
@@ -227,7 +244,8 @@
 
   OS << "#if defined(__riscv_f)\n";
   for (int Log2LMUL : Log2LMULs) {
-    auto T = computeType('f', Log2LMUL, "v");
+    auto T =
+        RVVType::computeType(BasicType::Float32, Log2LMUL, TypeProfile::Vector);
     if (T.hasValue())
       printType(T.getValue());
   }
@@ -235,7 +253,8 @@
 
   OS << "#if defined(__riscv_d)\n";
   for (int Log2LMUL : Log2LMULs) {
-    auto T = computeType('d', Log2LMUL, "v");
+    auto T =
+        RVVType::computeType(BasicType::Float64, Log2LMUL, TypeProfile::Vector);
     if (T.hasValue())
       printType(T.getValue());
   }
@@ -359,32 +378,6 @@
   OS << "\n";
 }
 
-void RVVEmitter::parsePrototypes(StringRef Prototypes,
-                                 std::function<void(StringRef)> Handler) {
-  const StringRef Primaries("evwqom0ztul");
-  while (!Prototypes.empty()) {
-    size_t Idx = 0;
-    // Skip over complex prototype because it could contain primitive type
-    // character.
-    if (Prototypes[0] == '(')
-      Idx = Prototypes.find_first_of(')');
-    Idx = Prototypes.find_first_of(Primaries, Idx);
-    assert(Idx != StringRef::npos);
-    Handler(Prototypes.slice(0, Idx + 1));
-    Prototypes = Prototypes.drop_front(Idx + 1);
-  }
-}
-
-std::string RVVEmitter::getSuffixStr(char Type, int Log2LMUL,
-                                     StringRef Prototypes) {
-  SmallVector<std::string> SuffixStrs;
-  parsePrototypes(Prototypes, [&](StringRef Proto) {
-    auto T = computeType(Type, Log2LMUL, Proto);
-    SuffixStrs.push_back(T.getValue()->getShortStr());
-  });
-  return join(SuffixStrs, "_");
-}
-
 void RVVEmitter::createRVVIntrinsics(
     std::vector<std::unique_ptr<RVVIntrinsic>> &Out) {
   std::vector<Record *> RV = Records.getAllDerivedDefinitions("RVVBuiltin");
@@ -419,13 +412,14 @@
 
     // Parse prototype and create a list of primitive type with transformers
     // (operand) in ProtoSeq. ProtoSeq[0] is output operand.
-    SmallVector<std::string> ProtoSeq;
-    parsePrototypes(Prototypes, [&ProtoSeq](StringRef Proto) {
-      ProtoSeq.push_back(Proto.str());
-    });
+    SmallVector<TypeProfile> ProtoSeq = parsePrototypes(Prototypes);
+
+    SmallVector<TypeProfile> SuffixProtoSeq = parsePrototypes(SuffixProto);
+    SmallVector<TypeProfile> MangledSuffixProtoSeq =
+        parsePrototypes(MangledSuffixProto);
 
     // Compute Builtin types
-    SmallVector<std::string> ProtoMaskSeq = ProtoSeq;
+    SmallVector<TypeProfile> ProtoMaskSeq = ProtoSeq;
     if (HasMasked) {
       // If HasMaskedOffOperand, insert result type as first input operand.
       if (HasMaskedOffOperand) {
@@ -436,10 +430,10 @@
           // (void, op0 address, op1 address, ...)
           // to
           // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
+          TypeProfile MaskoffType = ProtoSeq[1];
+          MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer);
           for (unsigned I = 0; I < NF; ++I)
-            ProtoMaskSeq.insert(
-                ProtoMaskSeq.begin() + NF + 1,
-                ProtoSeq[1].substr(1)); // Use substr(1) to skip '*'
+            ProtoMaskSeq.insert(ProtoMaskSeq.begin() + NF + 1, MaskoffType);
         }
       }
       if (HasMaskedOffOperand && NF > 1) {
@@ -448,28 +442,32 @@
         // to
         // (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1,
         // ...)
-        ProtoMaskSeq.insert(ProtoMaskSeq.begin() + NF + 1, "m");
+        ProtoMaskSeq.insert(ProtoMaskSeq.begin() + NF + 1, TypeProfile::Mask);
       } else {
-        // If HasMasked, insert 'm' as first input operand.
-        ProtoMaskSeq.insert(ProtoMaskSeq.begin() + 1, "m");
+        // If HasMasked, insert TypeProfile::Mask as first input operand.
+        ProtoMaskSeq.insert(ProtoMaskSeq.begin() + 1, TypeProfile::Mask);
       }
     }
-    // If HasVL, append 'z' to last operand
+    // If HasVL, append TypeProfile::VL as the last operand.
     if (HasVL) {
-      ProtoSeq.push_back("z");
-      ProtoMaskSeq.push_back("z");
+      ProtoSeq.push_back(TypeProfile::VL);
+      ProtoMaskSeq.push_back(TypeProfile::VL);
     }
 
     // Create Intrinsics for each type and LMUL.
     for (char I : TypeRange) {
       for (int Log2LMUL : Log2LMULList) {
-        Optional<RVVTypes> Types = computeTypes(I, Log2LMUL, NF, ProtoSeq);
+        BasicType BT = ParseBasicType(I);
+        Optional<RVVTypes> Types =
+            RVVType::computeTypes(BT, Log2LMUL, NF, ProtoSeq);
         // Ignored to create new intrinsic if there are any illegal types.
         if (!Types.hasValue())
           continue;
 
-        auto SuffixStr = getSuffixStr(I, Log2LMUL, SuffixProto);
-        auto MangledSuffixStr = getSuffixStr(I, Log2LMUL, MangledSuffixProto);
+        auto SuffixStr =
+            RVVIntrinsic::getSuffixStr(BT, Log2LMUL, SuffixProtoSeq);
+        auto MangledSuffixStr =
+            RVVIntrinsic::getSuffixStr(BT, Log2LMUL, MangledSuffixProtoSeq);
         // Create a unmasked intrinsic
         Out.push_back(std::make_unique<RVVIntrinsic>(
             Name, SuffixStr, MangledName, MangledSuffixStr, IRName,
@@ -480,7 +478,7 @@
         if (HasMasked) {
           // Create a masked intrinsic
           Optional<RVVTypes> MaskTypes =
-              computeTypes(I, Log2LMUL, NF, ProtoMaskSeq);
+              RVVType::computeTypes(BT, Log2LMUL, NF, ProtoMaskSeq);
           Out.push_back(std::make_unique<RVVIntrinsic>(
               Name, SuffixStr, MangledName, MangledSuffixStr, MaskedIRName,
               /*IsMasked=*/true, HasMaskedOffOperand, HasVL, MaskedPolicy,
@@ -501,45 +499,6 @@
   }
 }
 
-Optional<RVVTypes>
-RVVEmitter::computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
-                         ArrayRef<std::string> PrototypeSeq) {
-  // LMUL x NF must be less than or equal to 8.
-  if ((Log2LMUL >= 1) && (1 << Log2LMUL) * NF > 8)
-    return llvm::None;
-
-  RVVTypes Types;
-  for (const std::string &Proto : PrototypeSeq) {
-    auto T = computeType(BT, Log2LMUL, Proto);
-    if (!T.hasValue())
-      return llvm::None;
-    // Record legal type index
-    Types.push_back(T.getValue());
-  }
-  return Types;
-}
-
-Optional<RVVTypePtr> RVVEmitter::computeType(BasicType BT, int Log2LMUL,
-                                             StringRef Proto) {
-  std::string Idx = Twine(Twine(BT) + Twine(Log2LMUL) + Proto).str();
-  // Search first
-  auto It = LegalTypes.find(Idx);
-  if (It != LegalTypes.end())
-    return &(It->second);
-  if (IllegalTypes.count(Idx))
-    return llvm::None;
-  // Compute type and record the result.
-  RVVType T(BT, Log2LMUL, Proto);
-  if (T.isValid()) {
-    // Record legal type index and value.
-    LegalTypes.insert({Idx, T});
-    return &(LegalTypes[Idx]);
-  }
-  // Record illegal type index.
-  IllegalTypes.insert(Idx);
-  return llvm::None;
-}
-
 void RVVEmitter::emitArchMacroAndBody(
     std::vector<std::unique_ptr<RVVIntrinsic>> &Defs, raw_ostream &OS,
     std::function<void(raw_ostream &, const RVVIntrinsic &)> PrintBody) {
Index: clang/lib/Support/RISCVVIntrinsicUtils.cpp
===================================================================
--- clang/lib/Support/RISCVVIntrinsicUtils.cpp
+++ clang/lib/Support/RISCVVIntrinsicUtils.cpp
@@ -22,6 +22,14 @@
 namespace clang {
 namespace RISCV {
 
+const TypeProfile TypeProfile::Mask = TypeProfile(PrimitiveType::MaskVector);
+const TypeProfile TypeProfile::VL = TypeProfile(PrimitiveType::SizeT);
+const TypeProfile TypeProfile::Vector = TypeProfile(PrimitiveType::Vector);
+
+// Concat BasicType, LMUL and Proto as key
+static StringMap<RVVType> LegalTypes;
+static StringSet<> IllegalTypes;
+
 //===----------------------------------------------------------------------===//
 // Type implementation
 //===----------------------------------------------------------------------===//
@@ -70,7 +78,7 @@
   return *this;
 }
 
-RVVType::RVVType(BasicType BT, int Log2LMUL, StringRef prototype)
+RVVType::RVVType(BasicType BT, int Log2LMUL, const TypeProfile &prototype)
     : BT(BT), LMUL(LMULType(Log2LMUL)) {
   applyBasicType();
   applyModifier(prototype);
@@ -326,31 +334,31 @@
 
 void RVVType::applyBasicType() {
   switch (BT) {
-  case 'c':
+  case BasicType::Int8:
     ElementBitwidth = 8;
     ScalarType = ScalarTypeKind::SignedInteger;
     break;
-  case 's':
+  case BasicType::Int16:
     ElementBitwidth = 16;
     ScalarType = ScalarTypeKind::SignedInteger;
     break;
-  case 'i':
+  case BasicType::Int32:
     ElementBitwidth = 32;
     ScalarType = ScalarTypeKind::SignedInteger;
     break;
-  case 'l':
+  case BasicType::Int64:
     ElementBitwidth = 64;
     ScalarType = ScalarTypeKind::SignedInteger;
     break;
-  case 'x':
+  case BasicType::Float16:
     ElementBitwidth = 16;
     ScalarType = ScalarTypeKind::Float;
     break;
-  case 'f':
+  case BasicType::Float32:
     ElementBitwidth = 32;
     ScalarType = ScalarTypeKind::Float;
     break;
-  case 'd':
+  case BasicType::Float64:
     ElementBitwidth = 64;
     ScalarType = ScalarTypeKind::Float;
     break;
@@ -360,162 +368,460 @@
   assert(ElementBitwidth != 0 && "Bad element bitwidth!");
 }
 
-void RVVType::applyModifier(StringRef Transformer) {
-  if (Transformer.empty())
-    return;
+Optional<TypeProfile>
+TypeProfile::parseTypeProfile(llvm::StringRef TypeProfileStr) {
+  TypeProfile TP;
+  PrimitiveType PT = PrimitiveType::Invalid;
+  if (TypeProfileStr.empty())
+    return TP;
   // Handle primitive type transformer
-  auto PType = Transformer.back();
+  auto PType = TypeProfileStr.back();
   switch (PType) {
   case 'e':
-    Scale = 0;
+    PT = PrimitiveType::Scalar;
     break;
   case 'v':
-    Scale = LMUL.getScale(ElementBitwidth);
+    PT = PrimitiveType::Vector;
     break;
   case 'w':
-    ElementBitwidth *= 2;
-    LMUL *= 2;
-    Scale = LMUL.getScale(ElementBitwidth);
+    PT = PrimitiveType::Widening2XVector;
     break;
   case 'q':
-    ElementBitwidth *= 4;
-    LMUL *= 4;
-    Scale = LMUL.getScale(ElementBitwidth);
+    PT = PrimitiveType::Widening4XVector;
     break;
   case 'o':
-    ElementBitwidth *= 8;
-    LMUL *= 8;
-    Scale = LMUL.getScale(ElementBitwidth);
+    PT = PrimitiveType::Widening8XVector;
     break;
   case 'm':
-    ScalarType = ScalarTypeKind::Boolean;
-    Scale = LMUL.getScale(ElementBitwidth);
-    ElementBitwidth = 1;
+    PT = PrimitiveType::MaskVector;
     break;
   case '0':
-    ScalarType = ScalarTypeKind::Void;
+    PT = PrimitiveType::Void;
     break;
   case 'z':
-    ScalarType = ScalarTypeKind::Size_t;
+    PT = PrimitiveType::SizeT;
     break;
   case 't':
-    ScalarType = ScalarTypeKind::Ptrdiff_t;
+    PT = PrimitiveType::Ptrdiff;
     break;
   case 'u':
-    ScalarType = ScalarTypeKind::UnsignedLong;
+    PT = PrimitiveType::UnsignedLong;
     break;
   case 'l':
-    ScalarType = ScalarTypeKind::SignedLong;
+    PT = PrimitiveType::SignedLong;
     break;
   default:
     llvm_unreachable("Illegal primitive type transformers!");
   }
-  Transformer = Transformer.drop_back();
+  TP.PT = static_cast<uint8_t>(PT);
+  TypeProfileStr = TypeProfileStr.drop_back();
 
   // Extract and compute complex type transformer. It can only appear one time.
-  if (Transformer.startswith("(")) {
-    size_t Idx = Transformer.find(')');
+  if (TypeProfileStr.startswith("(")) {
+    size_t Idx = TypeProfileStr.find(')');
     assert(Idx != StringRef::npos);
-    StringRef ComplexType = Transformer.slice(1, Idx);
-    Transformer = Transformer.drop_front(Idx + 1);
-    assert(!Transformer.contains('(') &&
+    StringRef ComplexType = TypeProfileStr.slice(1, Idx);
+    TypeProfileStr = TypeProfileStr.drop_front(Idx + 1);
+    assert(!TypeProfileStr.contains('(') &&
            "Only allow one complex type transformer");
 
-    auto UpdateAndCheckComplexProto = [&]() {
-      Scale = LMUL.getScale(ElementBitwidth);
-      const StringRef VectorPrototypes("vwqom");
-      if (!VectorPrototypes.contains(PType))
-        llvm_unreachable("Complex type transformer only supports vector type!");
-      if (Transformer.find_first_of("PCKWS") != StringRef::npos)
-        llvm_unreachable(
-            "Illegal type transformer for Complex type transformer");
-    };
-    auto ComputeFixedLog2LMUL =
-        [&](StringRef Value,
-            std::function<bool(const int32_t &, const int32_t &)> Compare) {
-          int32_t Log2LMUL;
-          Value.getAsInteger(10, Log2LMUL);
-          if (!Compare(Log2LMUL, LMUL.Log2LMUL)) {
-            ScalarType = Invalid;
-            return false;
-          }
-          // Update new LMUL
-          LMUL = LMULType(Log2LMUL);
-          UpdateAndCheckComplexProto();
-          return true;
-        };
     auto ComplexTT = ComplexType.split(":");
+    VectorTypeModifier VTM = VectorTypeModifier::NoModifier;
     if (ComplexTT.first == "Log2EEW") {
       uint32_t Log2EEW;
-      ComplexTT.second.getAsInteger(10, Log2EEW);
-      // update new elmul = (eew/sew) * lmul
-      LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth));
-      // update new eew
-      ElementBitwidth = 1 << Log2EEW;
-      ScalarType = ScalarTypeKind::SignedInteger;
-      UpdateAndCheckComplexProto();
+      if (ComplexTT.second.getAsInteger(10, Log2EEW)) {
+        llvm_unreachable("Invalid Log2EEW value!");
+        return None;
+      }
+      switch (Log2EEW) {
+      case 3:
+        VTM = VectorTypeModifier::Log2EEW3;
+        break;
+      case 4:
+        VTM = VectorTypeModifier::Log2EEW4;
+        break;
+      case 5:
+        VTM = VectorTypeModifier::Log2EEW5;
+        break;
+      case 6:
+        VTM = VectorTypeModifier::Log2EEW6;
+        break;
+      default:
+        llvm_unreachable("Invalid Log2EEW value, should be [3-6]");
+        return None;
+      }
     } else if (ComplexTT.first == "FixedSEW") {
       uint32_t NewSEW;
-      ComplexTT.second.getAsInteger(10, NewSEW);
-      // Set invalid type if src and dst SEW are same.
-      if (ElementBitwidth == NewSEW) {
-        ScalarType = Invalid;
-        return;
+      if (ComplexTT.second.getAsInteger(10, NewSEW)) {
+        llvm_unreachable("Invalid FixedSEW value!");
+        return None;
+      }
+      switch (NewSEW) {
+      case 8:
+        VTM = VectorTypeModifier::FixedSEW8;
+        break;
+      case 16:
+        VTM = VectorTypeModifier::FixedSEW16;
+        break;
+      case 32:
+        VTM = VectorTypeModifier::FixedSEW32;
+        break;
+      case 64:
+        VTM = VectorTypeModifier::FixedSEW64;
+        break;
+      default:
+        llvm_unreachable("Invalid FixedSEW value, should be 8, 16, 32 or 64");
+        return None;
       }
-      // Update new SEW
-      ElementBitwidth = NewSEW;
-      UpdateAndCheckComplexProto();
     } else if (ComplexTT.first == "LFixedLog2LMUL") {
-      // New LMUL should be larger than old
-      if (!ComputeFixedLog2LMUL(ComplexTT.second, std::greater<int32_t>()))
-        return;
+      int32_t Log2LMUL;
+      if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
+        llvm_unreachable("Invalid LFixedLog2LMUL value!");
+        return None;
+      }
+      switch (Log2LMUL) {
+      case -3:
+        VTM = VectorTypeModifier::LFixedLog2LMULN3;
+        break;
+      case -2:
+        VTM = VectorTypeModifier::LFixedLog2LMULN2;
+        break;
+      case -1:
+        VTM = VectorTypeModifier::LFixedLog2LMULN1;
+        break;
+      case 0:
+        VTM = VectorTypeModifier::LFixedLog2LMUL0;
+        break;
+      case 1:
+        VTM = VectorTypeModifier::LFixedLog2LMUL1;
+        break;
+      case 2:
+        VTM = VectorTypeModifier::LFixedLog2LMUL2;
+        break;
+      case 3:
+        VTM = VectorTypeModifier::LFixedLog2LMUL3;
+        break;
+      default:
+        llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
+        return None;
+      }
     } else if (ComplexTT.first == "SFixedLog2LMUL") {
-      // New LMUL should be smaller than old
-      if (!ComputeFixedLog2LMUL(ComplexTT.second, std::less<int32_t>()))
-        return;
+      int32_t Log2LMUL;
+      if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
+        llvm_unreachable("Invalid SFixedLog2LMUL value!");
+        return None;
+      }
+      switch (Log2LMUL) {
+      case -3:
+        VTM = VectorTypeModifier::SFixedLog2LMULN3;
+        break;
+      case -2:
+        VTM = VectorTypeModifier::SFixedLog2LMULN2;
+        break;
+      case -1:
+        VTM = VectorTypeModifier::SFixedLog2LMULN1;
+        break;
+      case 0:
+        VTM = VectorTypeModifier::SFixedLog2LMUL0;
+        break;
+      case 1:
+        VTM = VectorTypeModifier::SFixedLog2LMUL1;
+        break;
+      case 2:
+        VTM = VectorTypeModifier::SFixedLog2LMUL2;
+        break;
+      case 3:
+        VTM = VectorTypeModifier::SFixedLog2LMUL3;
+        break;
+      default:
+        llvm_unreachable("Invalid SFixedLog2LMUL value, should be [-3, 3]");
+        return None;
+      }
+
     } else {
       llvm_unreachable("Illegal complex type transformers!");
     }
+    TP.VTM = static_cast<uint8_t>(VTM);
   }
 
   // Compute the remain type transformers
-  for (char I : Transformer) {
+  TypeModifier TM = TypeModifier::NoModifier;
+  for (char I : TypeProfileStr) {
     switch (I) {
     case 'P':
-      if (IsConstant)
+      if ((TM & TypeModifier::Const) == TypeModifier::Const)
         llvm_unreachable("'P' transformer cannot be used after 'C'");
-      if (IsPointer)
+      if ((TM & TypeModifier::Pointer) == TypeModifier::Pointer)
         llvm_unreachable("'P' transformer cannot be used twice");
-      IsPointer = true;
+      TM |= TypeModifier::Pointer;
       break;
     case 'C':
-      if (IsConstant)
-        llvm_unreachable("'C' transformer cannot be used twice");
-      IsConstant = true;
+      TM |= TypeModifier::Const;
       break;
     case 'K':
-      IsImmediate = true;
+      TM |= TypeModifier::Immediate;
       break;
     case 'U':
-      ScalarType = ScalarTypeKind::UnsignedInteger;
+      TM |= TypeModifier::UnsignedInteger;
       break;
     case 'I':
-      ScalarType = ScalarTypeKind::SignedInteger;
+      TM |= TypeModifier::SignedInteger;
       break;
     case 'F':
-      ScalarType = ScalarTypeKind::Float;
+      TM |= TypeModifier::Float;
       break;
     case 'S':
+      TM |= TypeModifier::LMUL1;
+      break;
+    default:
+      llvm_unreachable("Illegal non-primitive type transformer!");
+    }
+  }
+  TP.TM = static_cast<uint8_t>(TM);
+
+  return TP;
+}
+
+void RVVType::applyModifier(const TypeProfile &Transformer) {
+  // Handle primitive type transformer
+  switch (static_cast<PrimitiveType>(Transformer.PT)) {
+  case PrimitiveType::Scalar:
+    Scale = 0;
+    break;
+  case PrimitiveType::Vector:
+    Scale = LMUL.getScale(ElementBitwidth);
+    break;
+  case PrimitiveType::Widening2XVector:
+    ElementBitwidth *= 2;
+    LMUL *= 2;
+    Scale = LMUL.getScale(ElementBitwidth);
+    break;
+  case PrimitiveType::Widening4XVector:
+    ElementBitwidth *= 4;
+    LMUL *= 4;
+    Scale = LMUL.getScale(ElementBitwidth);
+    break;
+  case PrimitiveType::Widening8XVector:
+    ElementBitwidth *= 8;
+    LMUL *= 8;
+    Scale = LMUL.getScale(ElementBitwidth);
+    break;
+  case PrimitiveType::MaskVector:
+    ScalarType = ScalarTypeKind::Boolean;
+    Scale = LMUL.getScale(ElementBitwidth);
+    ElementBitwidth = 1;
+    break;
+  case PrimitiveType::Void:
+    ScalarType = ScalarTypeKind::Void;
+    break;
+  case PrimitiveType::SizeT:
+    ScalarType = ScalarTypeKind::Size_t;
+    break;
+  case PrimitiveType::Ptrdiff:
+    ScalarType = ScalarTypeKind::Ptrdiff_t;
+    break;
+  case PrimitiveType::UnsignedLong:
+    ScalarType = ScalarTypeKind::UnsignedLong;
+    break;
+  case PrimitiveType::SignedLong:
+    ScalarType = ScalarTypeKind::SignedLong;
+    break;
+  case PrimitiveType::Invalid:
+    ScalarType = ScalarTypeKind::Invalid;
+    return;
+  default:
+    llvm_unreachable("Illegal primitive type transformers!");
+  }
+
+  switch (static_cast<VectorTypeModifier>(Transformer.VTM)) {
+  case VectorTypeModifier::Log2EEW3:
+    applyLog2EEW(3);
+    break;
+  case VectorTypeModifier::Log2EEW4:
+    applyLog2EEW(4);
+    break;
+  case VectorTypeModifier::Log2EEW5:
+    applyLog2EEW(5);
+    break;
+  case VectorTypeModifier::Log2EEW6:
+    applyLog2EEW(6);
+    break;
+  case VectorTypeModifier::FixedSEW8:
+    applyFixedSEW(8);
+    break;
+  case VectorTypeModifier::FixedSEW16:
+    applyFixedSEW(16);
+    break;
+  case VectorTypeModifier::FixedSEW32:
+    applyFixedSEW(32);
+    break;
+  case VectorTypeModifier::FixedSEW64:
+    applyFixedSEW(64);
+    break;
+  case VectorTypeModifier::LFixedLog2LMULN3:
+    applyFixedLog2LMUL(-3, /* LargerThan= */ true);
+    break;
+  case VectorTypeModifier::LFixedLog2LMULN2:
+    applyFixedLog2LMUL(-2, /* LargerThan= */ true);
+    break;
+  case VectorTypeModifier::LFixedLog2LMULN1:
+    applyFixedLog2LMUL(-1, /* LargerThan= */ true);
+    break;
+  case VectorTypeModifier::LFixedLog2LMUL0:
+    applyFixedLog2LMUL(0, /* LargerThan= */ true);
+    break;
+  case VectorTypeModifier::LFixedLog2LMUL1:
+    applyFixedLog2LMUL(1, /* LargerThan= */ true);
+    break;
+  case VectorTypeModifier::LFixedLog2LMUL2:
+    applyFixedLog2LMUL(2, /* LargerThan= */ true);
+    break;
+  case VectorTypeModifier::LFixedLog2LMUL3:
+    applyFixedLog2LMUL(3, /* LargerThan= */ true);
+    break;
+  case VectorTypeModifier::SFixedLog2LMULN3:
+    applyFixedLog2LMUL(-3, /* LargerThan= */ false);
+    break;
+  case VectorTypeModifier::SFixedLog2LMULN2:
+    applyFixedLog2LMUL(-2, /* LargerThan= */ false);
+    break;
+  case VectorTypeModifier::SFixedLog2LMULN1:
+    applyFixedLog2LMUL(-1, /* LargerThan= */ false);
+    break;
+  case VectorTypeModifier::SFixedLog2LMUL0:
+    applyFixedLog2LMUL(0, /* LargerThan= */ false);
+    break;
+  case VectorTypeModifier::SFixedLog2LMUL1:
+    applyFixedLog2LMUL(1, /* LargerThan= */ false);
+    break;
+  case VectorTypeModifier::SFixedLog2LMUL2:
+    applyFixedLog2LMUL(2, /* LargerThan= */ false);
+    break;
+  case VectorTypeModifier::SFixedLog2LMUL3:
+    applyFixedLog2LMUL(3, /* LargerThan= */ false);
+    break;
+  case VectorTypeModifier::NoModifier:
+    break;
+  default:
+    llvm_unreachable("Illegal vector type modifier!");
+  }
+
+  for (unsigned TypeModifierMaskShift = 0;
+       TypeModifierMaskShift <= static_cast<unsigned>(TypeModifier::MaxOffset);
+       ++TypeModifierMaskShift) {
+    unsigned TypeModifierMask = 1 << TypeModifierMaskShift;
+    if ((static_cast<unsigned>(Transformer.TM) & TypeModifierMask) !=
+        TypeModifierMask)
+      continue;
+    switch (static_cast<TypeModifier>(TypeModifierMask)) {
+    case TypeModifier::Pointer:
+      IsPointer = true;
+      break;
+    case TypeModifier::Const:
+      IsConstant = true;
+      break;
+    case TypeModifier::Immediate:
+      IsImmediate = true;
+      IsConstant = true;
+      break;
+    case TypeModifier::UnsignedInteger:
+      ScalarType = ScalarTypeKind::UnsignedInteger;
+      break;
+    case TypeModifier::SignedInteger:
+      ScalarType = ScalarTypeKind::SignedInteger;
+      break;
+    case TypeModifier::Float:
+      ScalarType = ScalarTypeKind::Float;
+      break;
+    case TypeModifier::LMUL1:
       LMUL = LMULType(0);
       // Update ElementBitwidth need to update Scale too.
       Scale = LMUL.getScale(ElementBitwidth);
       break;
     default:
-      llvm_unreachable("Illegal non-primitive type transformer!");
+      llvm_unreachable("Unknown type modifier mask!");
     }
   }
 }
 
+void RVVType::applyLog2EEW(unsigned Log2EEW) {
+  // update new elmul = (eew/sew) * lmul
+  LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth));
+  // update new eew
+  ElementBitwidth = 1 << Log2EEW;
+  ScalarType = ScalarTypeKind::SignedInteger;
+  Scale = LMUL.getScale(ElementBitwidth);
+}
+
+void RVVType::applyFixedSEW(unsigned NewSEW) {
+  // Set invalid type if src and dst SEW are same.
+  if (ElementBitwidth == NewSEW) {
+    ScalarType = ScalarTypeKind::Invalid;
+    return;
+  }
+  // Update new SEW
+  ElementBitwidth = NewSEW;
+  Scale = LMUL.getScale(ElementBitwidth);
+}
+
+void RVVType::applyFixedLog2LMUL(int Log2LMUL, bool LargerThan) {
+  // Replace the current LMUL with the fixed Log2LMUL. The new LMUL must be
+  // strictly larger (LargerThan) or strictly smaller (!LargerThan) than the
+  // current one; otherwise, including when both are equal, the type is marked
+  // invalid and left otherwise untouched.
+  const bool IsLegal =
+      LargerThan ? Log2LMUL > LMUL.Log2LMUL : Log2LMUL < LMUL.Log2LMUL;
+  if (!IsLegal) {
+    ScalarType = ScalarTypeKind::Invalid;
+    return;
+  }
+
+  // Scale is derived from LMUL, so recompute it after the update.
+  LMUL = LMULType(Log2LMUL);
+  Scale = LMUL.getScale(ElementBitwidth);
+}
+
+Optional<RVVTypes> RVVType::computeTypes(BasicType BT, int Log2LMUL,
+                                         unsigned NF,
+                                         ArrayRef<TypeProfile> PrototypeSeq) {
+  // LMUL x NF must be less than or equal to 8.
+  if ((Log2LMUL >= 1) && (1 << Log2LMUL) * NF > 8)
+    return llvm::None;
+
+  RVVTypes Types;
+  for (const TypeProfile &Proto : PrototypeSeq) {
+    auto T = computeType(BT, Log2LMUL, Proto);
+    if (!T.hasValue())
+      return llvm::None;
+    // Record legal type index
+    Types.push_back(T.getValue());
+  }
+  return Types;
+}
+
+Optional<RVVTypePtr> RVVType::computeType(BasicType BT, int Log2LMUL,
+                                          TypeProfile Proto) {
+  std::string Idx =
+      Twine(Twine(static_cast<int>(BT)) + Twine(Log2LMUL) + Proto.IndexStr())
+          .str();
+  // Search first
+  auto It = LegalTypes.find(Idx);
+  if (It != LegalTypes.end())
+    return &(It->second);
+  if (IllegalTypes.count(Idx))
+    return llvm::None;
+  // Compute type and record the result.
+  RVVType T(BT, Log2LMUL, Proto);
+  if (T.isValid()) {
+    // Record legal type index and value.
+    LegalTypes.insert({Idx, T});
+    return &(LegalTypes[Idx]);
+  }
+  // Record illegal type index.
+  IllegalTypes.insert(Idx);
+  return llvm::None;
+}
+
 //===----------------------------------------------------------------------===//
 // RVVIntrinsic implementation
 //===----------------------------------------------------------------------===//
@@ -593,5 +899,36 @@
   return S;
 }
 
+std::string
+RVVIntrinsic::getSuffixStr(BasicType Type, int Log2LMUL,
+                           const llvm::SmallVector<TypeProfile> &TypeProfiles) {
+  SmallVector<std::string> SuffixStrs;
+  for (auto TP : TypeProfiles) {
+    auto T = RVVType::computeType(Type, Log2LMUL, TP);
+    SuffixStrs.push_back(T.getValue()->getShortStr());
+  }
+  return join(SuffixStrs, "_");
+}
+
+SmallVector<TypeProfile> parsePrototypes(StringRef Prototypes) {
+  // Each sub-prototype ends at its primitive type character; parse them in
+  // order and collect the resulting TypeProfiles.
+  SmallVector<TypeProfile> TypeProfiles;
+  const StringRef Primaries("evwqom0ztul");
+  while (!Prototypes.empty()) {
+    // A leading "(...)" complex transformer may itself contain primitive type
+    // characters, so start the search at its closing parenthesis.
+    size_t Idx = Prototypes[0] == '(' ? Prototypes.find_first_of(')') : 0;
+    Idx = Prototypes.find_first_of(Primaries, Idx);
+    assert(Idx != StringRef::npos);
+    auto TP = TypeProfile::parseTypeProfile(Prototypes.slice(0, Idx + 1));
+    if (!TP)
+      llvm_unreachable("Error during parsing prototype.");
+    TypeProfiles.push_back(*TP);
+    Prototypes = Prototypes.drop_front(Idx + 1);
+  }
+  return TypeProfiles;
+}
+
 } // end namespace RISCV
 } // end namespace clang
Index: clang/include/clang/Support/RISCVVIntrinsicUtils.h
===================================================================
--- clang/include/clang/Support/RISCVVIntrinsicUtils.h
+++ clang/include/clang/Support/RISCVVIntrinsicUtils.h
@@ -9,7 +9,10 @@
 #ifndef CLANG_SUPPORT_RISCVVINTRINSICUTILS_H
 #define CLANG_SUPPORT_RISCVVINTRINSICUTILS_H
 
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/BitmaskEnum.h"
 #include "llvm/ADT/Optional.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include <cstdint>
 #include <string>
@@ -18,9 +21,128 @@
 namespace clang {
 namespace RISCV {
 
-using BasicType = char;
 using VScaleVal = llvm::Optional<unsigned>;
 
+// Modifier for a vector type; TypeProfile::VTM holds these as uint8_t.
+enum class VectorTypeModifier : uint8_t {
+  NoModifier,
+  Log2EEW3, // Log2EEW*: presumably used by applyLog2EEW -- TODO confirm.
+  Log2EEW4,
+  Log2EEW5,
+  Log2EEW6,
+  FixedSEW8, // FixedSEW*: presumably used by applyFixedSEW -- TODO confirm.
+  FixedSEW16,
+  FixedSEW32,
+  FixedSEW64,
+  LFixedLog2LMULN3, // NOTE(review): "N" looks like a negative log2 -- confirm.
+  LFixedLog2LMULN2,
+  LFixedLog2LMULN1,
+  LFixedLog2LMUL0,
+  LFixedLog2LMUL1,
+  LFixedLog2LMUL2,
+  LFixedLog2LMUL3,
+  SFixedLog2LMULN3, // NOTE(review): L*/S* may map to LargerThan -- confirm.
+  SFixedLog2LMULN2,
+  SFixedLog2LMULN1,
+  SFixedLog2LMUL0,
+  SFixedLog2LMUL1,
+  SFixedLog2LMUL2,
+  SFixedLog2LMUL3,
+};
+
+// Like BasicType, but describes which kind of type, relative to the basic
+// vector type, is meant; used to compute the type info of arguments.
+enum class PrimitiveType : uint8_t {
+  Invalid,
+  Scalar,
+  Vector,
+  Widening2XVector,
+  Widening4XVector,
+  Widening8XVector,
+  MaskVector,
+  Void,
+  SizeT,
+  Ptrdiff,
+  UnsignedLong,
+  SignedLong,
+};
+
+// Bit-flag modifiers for a type; used for both scalar and vector types.
+enum class TypeModifier : uint8_t {
+  NoModifier = 0,
+  Pointer = 1 << 0,
+  Const = 1 << 1,
+  Immediate = 1 << 2,
+  UnsignedInteger = 1 << 3,
+  SignedInteger = 1 << 4,
+  Float = 1 << 5,
+  LMUL1 = 1 << 6,
+  MaxOffset = 6, // Shift amount of the highest flag (LMUL1 == 1 << 6).
+  LLVM_MARK_AS_BITMASK_ENUM(LMUL1), // Enables bitwise ops; LMUL1 is largest.
+};
+
+// TypeProfile is used to compute type info of arguments or return value.
+struct TypeProfile {
+  constexpr TypeProfile() = default;
+  constexpr TypeProfile(PrimitiveType PT) : PT(static_cast<uint8_t>(PT)) {}
+  constexpr TypeProfile(PrimitiveType PT, TypeModifier TM)
+      : PT(static_cast<uint8_t>(PT)), TM(static_cast<uint8_t>(TM)) {}
+  constexpr TypeProfile(uint8_t PT, uint8_t VTM, uint8_t TM)
+      : PT(PT), VTM(VTM), TM(TM) {}
+  // PrimitiveType, VectorTypeModifier and TypeModifier values, as uint8_t.
+  uint8_t PT = static_cast<uint8_t>(PrimitiveType::Invalid);
+  uint8_t VTM = static_cast<uint8_t>(VectorTypeModifier::NoModifier);
+  uint8_t TM = static_cast<uint8_t>(TypeModifier::NoModifier);
+  // Textual key "<PT>_<VTM>_<TM>" with each field printed as a decimal.
+  std::string IndexStr() const {
+    return std::to_string(PT) + "_" + std::to_string(VTM) + "_" +
+           std::to_string(TM);
+  }
+  // Comparisons. operator> yields true iff some field of TP exceeds ours.
+  bool operator!=(const TypeProfile &TP) const {
+    return TP.PT != PT || TP.VTM != VTM || TP.TM != TM;
+  }
+  bool operator>(const TypeProfile &TP) const {
+    return !(TP.PT <= PT && TP.VTM <= VTM && TP.TM <= TM);
+  }
+  // Common profiles and the prototype parser; all defined out of line.
+  static const TypeProfile Mask;
+  static const TypeProfile Vector;
+  static const TypeProfile VL;
+  static llvm::Optional<TypeProfile>
+  parseTypeProfile(llvm::StringRef PrototypeStr);
+};
+
+llvm::SmallVector<TypeProfile> parsePrototypes(llvm::StringRef Prototypes);
+
+// Basic type of a vector type; one bit each so they can be OR-ed together.
+enum class BasicType : uint8_t {
+  Unknown = 0,
+  Int8 = 1 << 0,
+  Int16 = 1 << 1,
+  Int32 = 1 << 2,
+  Int64 = 1 << 3,
+  Float16 = 1 << 4,
+  Float32 = 1 << 5,
+  Float64 = 1 << 6,
+  MaxOffset = 6, // Shift amount of the highest flag (Float64 == 1 << 6).
+  LLVM_MARK_AS_BITMASK_ENUM(Float64), // Enables bitwise ops; Float64 is max.
+};
+
+// Scalar type kind recorded in RVVType::ScalarType.
+enum ScalarTypeKind : uint8_t { // Unscoped; names live in clang::RISCV.
+  Void,
+  Size_t,
+  Ptrdiff_t,
+  UnsignedLong,
+  SignedLong,
+  Boolean,
+  SignedInteger,
+  UnsignedInteger,
+  Float,
+  Invalid,
+};
+
 // Exponential LMUL
 struct LMULType {
   int Log2LMUL;
@@ -32,20 +154,12 @@
   LMULType &operator*=(uint32_t RHS);
 };
 
+class RVVType;
+using RVVTypePtr = RVVType *;
+using RVVTypes = std::vector<RVVTypePtr>;
+
 // This class is compact representation of a valid and invalid RVVType.
 class RVVType {
-  enum ScalarTypeKind : uint32_t {
-    Void,
-    Size_t,
-    Ptrdiff_t,
-    UnsignedLong,
-    SignedLong,
-    Boolean,
-    SignedInteger,
-    UnsignedInteger,
-    Float,
-    Invalid,
-  };
   BasicType BT;
   ScalarTypeKind ScalarType = Invalid;
   LMULType LMUL;
@@ -64,8 +178,8 @@
   std::string ShortStr;
 
 public:
-  RVVType() : RVVType(BasicType(), 0, llvm::StringRef()) {}
-  RVVType(BasicType BT, int Log2LMUL, llvm::StringRef prototype);
+  RVVType() : BT(BasicType::Unknown), LMUL(0), Valid(false) {}
+  RVVType(BasicType BT, int Log2LMUL, const TypeProfile &Profile);
 
   // Return the string representation of a type, which is an encoded string for
   // passing to the BUILTIN() macro in Builtins.def.
@@ -114,7 +228,11 @@
 
   // Applies a prototype modifier to the current type. The result maybe an
   // invalid type.
-  void applyModifier(llvm::StringRef prototype);
+  void applyModifier(const TypeProfile &prototype);
+
+  void applyLog2EEW(unsigned Log2EEW);
+  void applyFixedSEW(unsigned NewSEW);
+  void applyFixedLog2LMUL(int Log2LMUL, bool LargerThan);
 
   // Compute and record a string for legal type.
   void initBuiltinStr();
@@ -124,10 +242,19 @@
   void initTypeStr();
   // Compute and record a short name of a type for C/C++ name suffix.
   void initShortStr();
+
+public:
+  /// Compute output and input types by applying different configs (basic type
+  /// and LMUL with type transformers). It also records each result in the
+  /// legal or illegal set to avoid computing the same config again. The result
+  /// may contain an illegal RVVType.
+  static llvm::Optional<RVVTypes>
+  computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
+               llvm::ArrayRef<TypeProfile> PrototypeSeq);
+  static llvm::Optional<RVVTypePtr> computeType(BasicType BT, int Log2LMUL,
+                                                TypeProfile Proto);
 };
 
-using RVVTypePtr = RVVType *;
-using RVVTypes = std::vector<RVVTypePtr>;
 using RISCVPredefinedMacroT = uint8_t;
 
 enum RISCVPredefinedMacro : RISCVPredefinedMacroT {
@@ -206,6 +333,10 @@
 
   // Return the type string for a BUILTIN() macro in Builtins.def.
   std::string getBuiltinTypeStr() const;
+
+  static std::string
+  getSuffixStr(BasicType Type, int Log2LMUL,
+               const llvm::SmallVector<TypeProfile> &TypeProfiles);
 };
 
 } // end namespace RISCV
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to