https://github.com/momchil-velikov updated 
https://github.com/llvm/llvm-project/pull/123604

>From e825bc0f660eb3dce41ee062d04e4e39bbac5d2a Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.veli...@arm.com>
Date: Fri, 6 Dec 2024 13:09:23 +0000
Subject: [PATCH] [AArch64] Refactor implementation of FP8 types (NFC)

* The FP8 scalar type (`__mfp8`) was described as a vector type; it is now
  described as a scalar type.
* The FP8 vector types were described/assumed to have an integer element
  type; their element type is now `__mfp8`.
* Add support for the `m` type specifier (denoting `__mfp8`)
  in `DecodeTypeFromStr` and create SVE builtin prototypes using
  that specifier instead of `c` (`char`).
* Change the type of the FPM operand in SVE builtin prototypes from
  `int64_t` (`Wi`) to `uint64_t` (`UWi`).

[fixup] Add a comment about special case of mapping FP8 vectors to LLVM vector 
types
---
 .../clang/Basic/AArch64SVEACLETypes.def       | 35 +++++++++----------
 clang/lib/AST/ASTContext.cpp                  | 30 +++++++++-------
 clang/lib/AST/ItaniumMangle.cpp               |  2 +-
 clang/lib/AST/Type.cpp                        |  4 +--
 clang/lib/CodeGen/CodeGenTypes.cpp            | 22 +++++++-----
 clang/lib/CodeGen/Targets/AArch64.cpp         |  7 ++--
 clang/utils/TableGen/SveEmitter.cpp           |  4 +--
 7 files changed, 58 insertions(+), 46 deletions(-)

diff --git a/clang/include/clang/Basic/AArch64SVEACLETypes.def 
b/clang/include/clang/Basic/AArch64SVEACLETypes.def
index 2dd2754e778d60..a408bb0c54057c 100644
--- a/clang/include/clang/Basic/AArch64SVEACLETypes.def
+++ b/clang/include/clang/Basic/AArch64SVEACLETypes.def
@@ -57,6 +57,11 @@
 //  - IsBF true for vector of brain float elements.
 
//===----------------------------------------------------------------------===//
 
+#ifndef SVE_SCALAR_TYPE
+#define SVE_SCALAR_TYPE(Name, MangledName, Id, SingletonId, Bits) \
+  SVE_TYPE(Name, Id, SingletonId)
+#endif
+
 #ifndef SVE_VECTOR_TYPE
 #define SVE_VECTOR_TYPE(Name, MangledName, Id, SingletonId) \
   SVE_TYPE(Name, Id, SingletonId)
@@ -72,6 +77,11 @@
   SVE_VECTOR_TYPE_DETAILS(Name, MangledName, Id, SingletonId, NumEls, ElBits, 
NF, false, false, true)
 #endif
 
+#ifndef SVE_VECTOR_TYPE_MFLOAT
+#define SVE_VECTOR_TYPE_MFLOAT(Name, MangledName, Id, SingletonId, NumEls, 
ElBits, NF) \
+  SVE_VECTOR_TYPE_DETAILS(Name, MangledName, Id, SingletonId, NumEls, ElBits, 
NF, false, false, false)
+#endif
+
 #ifndef SVE_VECTOR_TYPE_FLOAT
 #define SVE_VECTOR_TYPE_FLOAT(Name, MangledName, Id, SingletonId, NumEls, 
ElBits, NF) \
   SVE_VECTOR_TYPE_DETAILS(Name, MangledName, Id, SingletonId, NumEls, ElBits, 
NF, false, true, false)
@@ -97,16 +107,6 @@
   SVE_TYPE(Name, Id, SingletonId)
 #endif
 
-#ifndef AARCH64_VECTOR_TYPE
-#define AARCH64_VECTOR_TYPE(Name, MangledName, Id, SingletonId) \
-  SVE_TYPE(Name, Id, SingletonId)
-#endif
-
-#ifndef AARCH64_VECTOR_TYPE_MFLOAT
-#define AARCH64_VECTOR_TYPE_MFLOAT(Name, MangledName, Id, SingletonId, NumEls, 
ElBits, NF) \
-  AARCH64_VECTOR_TYPE(Name, MangledName, Id, SingletonId)
-#endif
-
 //===- Vector point types -----------------------------------------------===//
 
 SVE_VECTOR_TYPE_INT("__SVInt8_t",  "__SVInt8_t",  SveInt8,  SveInt8Ty, 16,  8, 
1, true)
@@ -125,8 +125,7 @@ SVE_VECTOR_TYPE_FLOAT("__SVFloat64_t", "__SVFloat64_t", 
SveFloat64, SveFloat64Ty
 
 SVE_VECTOR_TYPE_BFLOAT("__SVBfloat16_t", "__SVBfloat16_t", SveBFloat16, 
SveBFloat16Ty, 8, 16, 1)
 
-// This is a 8 bits opaque type.
-SVE_VECTOR_TYPE_INT("__SVMfloat8_t", "__SVMfloat8_t",  SveMFloat8, 
SveMFloat8Ty, 16, 8, 1, false)
+SVE_VECTOR_TYPE_MFLOAT("__SVMfloat8_t", "__SVMfloat8_t",  SveMFloat8, 
SveMFloat8Ty, 16, 8, 1)
 
 //
 // x2
@@ -148,7 +147,7 @@ SVE_VECTOR_TYPE_FLOAT("__clang_svfloat64x2_t", 
"svfloat64x2_t", SveFloat64x2, Sv
 
 SVE_VECTOR_TYPE_BFLOAT("__clang_svbfloat16x2_t", "svbfloat16x2_t", 
SveBFloat16x2, SveBFloat16x2Ty, 8, 16, 2)
 
-SVE_VECTOR_TYPE_INT("__clang_svmfloat8x2_t", "svmfloat8x2_t", SveMFloat8x2, 
SveMFloat8x2Ty, 16, 8, 2, false)
+SVE_VECTOR_TYPE_MFLOAT("__clang_svmfloat8x2_t", "svmfloat8x2_t", SveMFloat8x2, 
SveMFloat8x2Ty, 16, 8, 2)
 
 //
 // x3
@@ -170,7 +169,7 @@ SVE_VECTOR_TYPE_FLOAT("__clang_svfloat64x3_t", 
"svfloat64x3_t", SveFloat64x3, Sv
 
 SVE_VECTOR_TYPE_BFLOAT("__clang_svbfloat16x3_t", "svbfloat16x3_t", 
SveBFloat16x3, SveBFloat16x3Ty, 8, 16, 3)
 
-SVE_VECTOR_TYPE_INT("__clang_svmfloat8x3_t", "svmfloat8x3_t", SveMFloat8x3, 
SveMFloat8x3Ty, 16, 8, 3, false)
+SVE_VECTOR_TYPE_MFLOAT("__clang_svmfloat8x3_t", "svmfloat8x3_t", SveMFloat8x3, 
SveMFloat8x3Ty, 16, 8, 3)
 
 //
 // x4
@@ -192,7 +191,7 @@ SVE_VECTOR_TYPE_FLOAT("__clang_svfloat64x4_t", 
"svfloat64x4_t", SveFloat64x4, Sv
 
 SVE_VECTOR_TYPE_BFLOAT("__clang_svbfloat16x4_t", "svbfloat16x4_t", 
SveBFloat16x4, SveBFloat16x4Ty, 8, 16, 4)
 
-SVE_VECTOR_TYPE_INT("__clang_svmfloat8x4_t", "svmfloat8x4_t", SveMFloat8x4, 
SveMFloat8x4Ty, 16, 8, 4, false)
+SVE_VECTOR_TYPE_MFLOAT("__clang_svmfloat8x4_t", "svmfloat8x4_t", SveMFloat8x4, 
SveMFloat8x4Ty, 16, 8, 4)
 
 SVE_PREDICATE_TYPE_ALL("__SVBool_t", "__SVBool_t", SveBool, SveBoolTy, 16, 1)
 SVE_PREDICATE_TYPE_ALL("__clang_svboolx2_t", "svboolx2_t", SveBoolx2, 
SveBoolx2Ty, 16, 2)
@@ -200,15 +199,15 @@ SVE_PREDICATE_TYPE_ALL("__clang_svboolx4_t", 
"svboolx4_t", SveBoolx4, SveBoolx4T
 
 SVE_OPAQUE_TYPE("__SVCount_t", "__SVCount_t", SveCount, SveCountTy)
 
-AARCH64_VECTOR_TYPE_MFLOAT("__mfp8", "__mfp8", MFloat8, MFloat8Ty, 1, 8, 1)
+SVE_SCALAR_TYPE("__mfp8", "__mfp8", MFloat8, MFloat8Ty, 8)
 
 #undef SVE_VECTOR_TYPE
+#undef SVE_VECTOR_TYPE_MFLOAT
 #undef SVE_VECTOR_TYPE_BFLOAT
 #undef SVE_VECTOR_TYPE_FLOAT
 #undef SVE_VECTOR_TYPE_INT
 #undef SVE_PREDICATE_TYPE
 #undef SVE_PREDICATE_TYPE_ALL
 #undef SVE_OPAQUE_TYPE
-#undef AARCH64_VECTOR_TYPE_MFLOAT
-#undef AARCH64_VECTOR_TYPE
+#undef SVE_SCALAR_TYPE
 #undef SVE_TYPE
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index a4ba9fd0553464..cd1bcb3b9a063d 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -2269,11 +2269,10 @@ TypeInfo ASTContext::getTypeInfoImpl(const Type *T) 
const {
     Width = 0;                                                                 
\
     Align = 16;                                                                
\
     break;
-#define AARCH64_VECTOR_TYPE_MFLOAT(Name, MangledName, Id, SingletonId, NumEls, 
\
-                                   ElBits, NF)                                 
\
+#define SVE_SCALAR_TYPE(Name, MangledName, Id, SingletonId, Bits)              
\
   case BuiltinType::Id:                                                        
\
-    Width = NumEls * ElBits * NF;                                              
\
-    Align = NumEls * ElBits;                                                   
\
+    Width = Bits;                                                              
\
+    Align = Bits;                                                              
\
     break;
 #include "clang/Basic/AArch64SVEACLETypes.def"
 #define PPC_VECTOR_TYPE(Name, Id, Size)                                        
\
@@ -4423,15 +4422,14 @@ ASTContext::getBuiltinVectorTypeInfo(const BuiltinType 
*Ty) const {
                                ElBits, NF)                                     
\
   case BuiltinType::Id:                                                        
\
     return {BFloat16Ty, llvm::ElementCount::getScalable(NumEls), NF};
+#define SVE_VECTOR_TYPE_MFLOAT(Name, MangledName, Id, SingletonId, NumEls,     
\
+                               ElBits, NF)                                     
\
+  case BuiltinType::Id:                                                        
\
+    return {MFloat8Ty, llvm::ElementCount::getScalable(NumEls), NF};
 #define SVE_PREDICATE_TYPE_ALL(Name, MangledName, Id, SingletonId, NumEls, NF) 
\
   case BuiltinType::Id:                                                        
\
     return {BoolTy, llvm::ElementCount::getScalable(NumEls), NF};
-#define AARCH64_VECTOR_TYPE_MFLOAT(Name, MangledName, Id, SingletonId, NumEls, 
\
-                                   ElBits, NF)                                 
\
-  case BuiltinType::Id:                                                        
\
-    return {getIntTypeForBitwidth(ElBits, false),                              
\
-            llvm::ElementCount::getFixed(NumEls), NF};
-#define SVE_OPAQUE_TYPE(Name, MangledName, Id, SingletonId)
+#define SVE_TYPE(Name, Id, SingletonId)
 #include "clang/Basic/AArch64SVEACLETypes.def"
 
 #define RVV_VECTOR_TYPE_INT(Name, Id, SingletonId, NumEls, ElBits, NF,         
\
@@ -4493,11 +4491,16 @@ QualType ASTContext::getScalableVectorType(QualType 
EltTy, unsigned NumElts,
       EltTySize == ElBits && NumElts == (NumEls * NF) && NumFields == 1) {     
\
     return SingletonId;                                                        
\
   }
+#define SVE_VECTOR_TYPE_MFLOAT(Name, MangledName, Id, SingletonId, NumEls,     
\
+                               ElBits, NF)                                     
\
+  if (EltTy->isMFloat8Type() && EltTySize == ElBits &&                         
\
+      NumElts == (NumEls * NF) && NumFields == 1) {                            
\
+    return SingletonId;                                                        
\
+  }
 #define SVE_PREDICATE_TYPE_ALL(Name, MangledName, Id, SingletonId, NumEls, NF) 
\
   if (EltTy->isBooleanType() && NumElts == (NumEls * NF) && NumFields == 1)    
\
     return SingletonId;
-#define SVE_OPAQUE_TYPE(Name, MangledName, Id, SingletonId)
-#define AARCH64_VECTOR_TYPE(Name, MangledName, Id, SingletonId)
+#define SVE_TYPE(Name, Id, SingletonId)
 #include "clang/Basic/AArch64SVEACLETypes.def"
   } else if (Target->hasRISCVVTypes()) {
     uint64_t EltTySize = getTypeSize(EltTy);
@@ -12382,6 +12385,9 @@ static QualType DecodeTypeFromStr(const char *&Str, 
const ASTContext &Context,
   case 'p':
     Type = Context.getProcessIDType();
     break;
+  case 'm':
+    Type = Context.MFloat8Ty;
+    break;
   }
 
   // If there are modifiers and if we're allowed to parse them, go for it.
diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp
index 9948963d7f44b3..49089c0ea3c8ac 100644
--- a/clang/lib/AST/ItaniumMangle.cpp
+++ b/clang/lib/AST/ItaniumMangle.cpp
@@ -3433,7 +3433,7 @@ void CXXNameMangler::mangleType(const BuiltinType *T) {
     type_name = MangledName;                                                   
\
     Out << (type_name == Name ? "u" : "") << type_name.size() << type_name;    
\
     break;
-#define AARCH64_VECTOR_TYPE(Name, MangledName, Id, SingletonId)                
\
+#define SVE_SCALAR_TYPE(Name, MangledName, Id, SingletonId, Bits)              
\
   case BuiltinType::Id:                                                        
\
     type_name = MangledName;                                                   
\
     Out << (type_name == Name ? "u" : "") << type_name.size() << type_name;    
\
diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
index caa0ac858a1bea..fde0746a175705 100644
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -2527,9 +2527,7 @@ bool Type::isSVESizelessBuiltinType() const {
 #define SVE_PREDICATE_TYPE(Name, MangledName, Id, SingletonId)                 
\
   case BuiltinType::Id:                                                        
\
     return true;
-#define AARCH64_VECTOR_TYPE(Name, MangledName, Id, SingletonId)                
\
-  case BuiltinType::Id:                                                        
\
-    return false;
+#define SVE_TYPE(Name, Id, SingletonId)
 #include "clang/Basic/AArch64SVEACLETypes.def"
     default:
       return false;
diff --git a/clang/lib/CodeGen/CodeGenTypes.cpp 
b/clang/lib/CodeGen/CodeGenTypes.cpp
index 950b23f4e13b99..405242e97e75cb 100644
--- a/clang/lib/CodeGen/CodeGenTypes.cpp
+++ b/clang/lib/CodeGen/CodeGenTypes.cpp
@@ -505,15 +505,18 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
   case BuiltinType::Id:
 #define SVE_PREDICATE_TYPE(Name, MangledName, Id, SingletonId)                 
\
   case BuiltinType::Id:
-#define AARCH64_VECTOR_TYPE(Name, MangledName, Id, SingletonId)                
\
-  case BuiltinType::Id:
-#define SVE_OPAQUE_TYPE(Name, MangledName, Id, SingletonId)
+#define SVE_TYPE(Name, Id, SingletonId)
 #include "clang/Basic/AArch64SVEACLETypes.def"
       {
         ASTContext::BuiltinVectorTypeInfo Info =
             Context.getBuiltinVectorTypeInfo(cast<BuiltinType>(Ty));
-        auto VTy =
-            llvm::VectorType::get(ConvertType(Info.ElementType), Info.EC);
+        // The `__mfp8` type maps to `<1 x i8>` which can't be used to build
+        // a <N x i8> vector type, hence bypass the call to `ConvertType` for
+        // the element type and create the vector type directly.
+        auto *EltTy = Info.ElementType->isMFloat8Type()
+                          ? llvm::Type::getInt8Ty(getLLVMContext())
+                          : ConvertType(Info.ElementType);
+        auto *VTy = llvm::VectorType::get(EltTy, Info.EC);
         switch (Info.NumVectors) {
         default:
           llvm_unreachable("Expected 1, 2, 3 or 4 vectors!");
@@ -529,6 +532,9 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
       }
     case BuiltinType::SveCount:
       return llvm::TargetExtType::get(getLLVMContext(), "aarch64.svcount");
+    case BuiltinType::MFloat8:
+      return llvm::VectorType::get(llvm::Type::getInt8Ty(getLLVMContext()), 1,
+                                   false);
 #define PPC_VECTOR_TYPE(Name, Id, Size) \
     case BuiltinType::Id: \
       ResultType = \
@@ -650,9 +656,9 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
     // An ext_vector_type of Bool is really a vector of bits.
     llvm::Type *IRElemTy = VT->isExtVectorBoolType()
                                ? llvm::Type::getInt1Ty(getLLVMContext())
-                               : (VT->getElementType()->isMFloat8Type()
-                                      ? llvm::Type::getInt8Ty(getLLVMContext())
-                                      : ConvertType(VT->getElementType()));
+                           : VT->getElementType()->isMFloat8Type()
+                               ? llvm::Type::getInt8Ty(getLLVMContext())
+                               : ConvertType(VT->getElementType());
     ResultType = llvm::FixedVectorType::get(IRElemTy, VT->getNumElements());
     break;
   }
diff --git a/clang/lib/CodeGen/Targets/AArch64.cpp 
b/clang/lib/CodeGen/Targets/AArch64.cpp
index c702e79ff8eb98..057199c66f5a10 100644
--- a/clang/lib/CodeGen/Targets/AArch64.cpp
+++ b/clang/lib/CodeGen/Targets/AArch64.cpp
@@ -244,6 +244,7 @@ AArch64ABIInfo::convertFixedToScalableVectorType(const 
VectorType *VT) const {
 
     case BuiltinType::SChar:
     case BuiltinType::UChar:
+    case BuiltinType::MFloat8:
       return llvm::ScalableVectorType::get(
           llvm::Type::getInt8Ty(getVMContext()), 16);
 
@@ -776,8 +777,10 @@ bool AArch64ABIInfo::passAsPureScalableType(
     NPred += Info.NumVectors;
   else
     NVec += Info.NumVectors;
-  auto VTy = llvm::ScalableVectorType::get(CGT.ConvertType(Info.ElementType),
-                                           Info.EC.getKnownMinValue());
+  llvm::Type *EltTy = Info.ElementType->isMFloat8Type()
+                          ? llvm::Type::getInt8Ty(getVMContext())
+                          : CGT.ConvertType(Info.ElementType);
+  auto *VTy = llvm::ScalableVectorType::get(EltTy, Info.EC.getKnownMinValue());
 
   if (CoerceToSeq.size() + Info.NumVectors > 12)
     return false;
diff --git a/clang/utils/TableGen/SveEmitter.cpp 
b/clang/utils/TableGen/SveEmitter.cpp
index 0ecbf7cede1daa..687d344163e205 100644
--- a/clang/utils/TableGen/SveEmitter.cpp
+++ b/clang/utils/TableGen/SveEmitter.cpp
@@ -449,7 +449,7 @@ std::string SVEType::builtinBaseType() const {
   case TypeKind::PredicatePattern:
     return "i";
   case TypeKind::Fpm:
-    return "Wi";
+    return "UWi";
   case TypeKind::Predicate:
     return "b";
   case TypeKind::BFloat16:
@@ -457,7 +457,7 @@ std::string SVEType::builtinBaseType() const {
     return "y";
   case TypeKind::MFloat8:
     assert(ElementBitwidth == 8 && "Invalid MFloat8!");
-    return "c";
+    return "m";
   case TypeKind::Float:
     switch (ElementBitwidth) {
     case 16:

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to