stellaraccident updated this revision to Diff 465230.
stellaraccident added a comment.
Herald added a project: clang.
Herald added a subscriber: cfe-commits.

Add a fix to MicrosoftMangle.cpp for the buildbot failure.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D133823/new/

https://reviews.llvm.org/D133823

Files:
  clang/lib/AST/MicrosoftMangle.cpp
  llvm/include/llvm/ADT/APFloat.h
  llvm/lib/Support/APFloat.cpp
  llvm/unittests/ADT/APFloatTest.cpp
  mlir/include/mlir-c/BuiltinTypes.h
  mlir/include/mlir/IR/Builders.h
  mlir/include/mlir/IR/BuiltinTypes.h
  mlir/include/mlir/IR/BuiltinTypes.td
  mlir/include/mlir/IR/Types.h
  mlir/lib/AsmParser/TokenKinds.def
  mlir/lib/AsmParser/TypeParser.cpp
  mlir/lib/CAPI/IR/BuiltinTypes.cpp
  mlir/lib/IR/AsmPrinter.cpp
  mlir/lib/IR/Builders.cpp
  mlir/lib/IR/BuiltinTypes.cpp
  mlir/lib/IR/MLIRContext.cpp
  mlir/lib/IR/Types.cpp
  mlir/test/IR/attribute.mlir
  mlir/test/lib/Dialect/Test/TestOps.td

Index: mlir/test/lib/Dialect/Test/TestOps.td
===================================================================
--- mlir/test/lib/Dialect/Test/TestOps.td
+++ mlir/test/lib/Dialect/Test/TestOps.td
@@ -193,6 +193,14 @@
   let assemblyFormat = "$attr attr-dict";
 }
 
+def FloatAttrOp : TEST_Op<"float_attrs"> {
+  // TODO: Clean up the OpBase float type and attribute selectors so they
+  // can express all of the types.
+  let arguments = (ins
+    AnyAttr:$float_attr
+  );
+}
+
 def I32Case5:  I32EnumAttrCase<"case5", 5>;
 def I32Case10: I32EnumAttrCase<"case10", 10>;
 
Index: mlir/test/IR/attribute.mlir
===================================================================
--- mlir/test/IR/attribute.mlir
+++ mlir/test/IR/attribute.mlir
@@ -31,6 +31,42 @@
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// Test float attributes
+//===----------------------------------------------------------------------===//
+
+func.func @float_attrs_pass() {
+  "test.float_attrs"() {
+    // CHECK: float_attr = 2.000000e+00 : f8E5M2
+    float_attr = 2. : f8E5M2
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 2.000000e+00 : f16
+    float_attr = 2. : f16
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 2.000000e+00 : bf16
+    float_attr = 2. : bf16
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 2.000000e+00 : f32
+    float_attr = 2. : f32
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 2.000000e+00 : f64
+    float_attr = 2. : f64
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 2.000000e+00 : f80
+    float_attr = 2. : f80
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 2.000000e+00 : f128
+    float_attr = 2. : f128
+  } : () -> ()
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // Test integer attributes
 //===----------------------------------------------------------------------===//
Index: mlir/lib/IR/Types.cpp
===================================================================
--- mlir/lib/IR/Types.cpp
+++ mlir/lib/IR/Types.cpp
@@ -18,6 +18,7 @@
 
 MLIRContext *Type::getContext() const { return getDialect().getContext(); }
 
+bool Type::isFloat8E5M2() const { return isa<Float8E5M2Type>(); }
 bool Type::isBF16() const { return isa<BFloat16Type>(); }
 bool Type::isF16() const { return isa<Float16Type>(); }
 bool Type::isF32() const { return isa<Float32Type>(); }
Index: mlir/lib/IR/MLIRContext.cpp
===================================================================
--- mlir/lib/IR/MLIRContext.cpp
+++ mlir/lib/IR/MLIRContext.cpp
@@ -206,6 +206,7 @@
   StorageUniquer typeUniquer;
 
   /// Cached Type Instances.
+  Float8E5M2Type f8E5M2Ty;
   BFloat16Type bf16Ty;
   Float16Type f16Ty;
   Float32Type f32Ty;
@@ -276,6 +277,7 @@
 
   //// Types.
   /// Floating-point Types.
+  impl->f8E5M2Ty = TypeUniquer::get<Float8E5M2Type>(this);
   impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
   impl->f16Ty = TypeUniquer::get<Float16Type>(this);
   impl->f32Ty = TypeUniquer::get<Float32Type>(this);
@@ -840,6 +842,9 @@
 /// This should not be used directly.
 StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
 
+Float8E5M2Type Float8E5M2Type::get(MLIRContext *context) {
+  return context->getImpl().f8E5M2Ty;
+}
 BFloat16Type BFloat16Type::get(MLIRContext *context) {
   return context->getImpl().bf16Ty;
 }
Index: mlir/lib/IR/BuiltinTypes.cpp
===================================================================
--- mlir/lib/IR/BuiltinTypes.cpp
+++ mlir/lib/IR/BuiltinTypes.cpp
@@ -88,6 +88,8 @@
 //===----------------------------------------------------------------------===//
 
 unsigned FloatType::getWidth() {
+  if (isa<Float8E5M2Type>())
+    return 8;
   if (isa<Float16Type, BFloat16Type>())
     return 16;
   if (isa<Float32Type>())
@@ -103,6 +105,8 @@
 
 /// Returns the floating semantics for the given type.
 const llvm::fltSemantics &FloatType::getFloatSemantics() {
+  if (isa<Float8E5M2Type>())
+    return APFloat::Float8E5M2();
   if (isa<BFloat16Type>())
     return APFloat::BFloat();
   if (isa<Float16Type>())
Index: mlir/lib/IR/Builders.cpp
===================================================================
--- mlir/lib/IR/Builders.cpp
+++ mlir/lib/IR/Builders.cpp
@@ -33,6 +33,10 @@
 // Types.
 //===----------------------------------------------------------------------===//
 
+FloatType Builder::getFloat8E5M2Type() {
+  return FloatType::getFloat8E5M2(context);
+}
+
 FloatType Builder::getBF16Type() { return FloatType::getBF16(context); }
 
 FloatType Builder::getF16Type() { return FloatType::getF16(context); }
Index: mlir/lib/IR/AsmPrinter.cpp
===================================================================
--- mlir/lib/IR/AsmPrinter.cpp
+++ mlir/lib/IR/AsmPrinter.cpp
@@ -2179,6 +2179,7 @@
                            opaqueTy.getTypeData());
       })
       .Case<IndexType>([&](Type) { os << "index"; })
+      .Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; })
       .Case<BFloat16Type>([&](Type) { os << "bf16"; })
       .Case<Float16Type>([&](Type) { os << "f16"; })
       .Case<Float32Type>([&](Type) { os << "f32"; })
Index: mlir/lib/CAPI/IR/BuiltinTypes.cpp
===================================================================
--- mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -68,6 +68,14 @@
 // Floating-point types.
 //===----------------------------------------------------------------------===//
 
+bool mlirTypeIsAFloat8E5M2(MlirType type) {
+  return unwrap(type).isFloat8E5M2();
+}
+
+MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) {
+  return wrap(FloatType::getFloat8E5M2(unwrap(ctx)));
+}
+
 bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
 
 MlirType mlirBF16TypeGet(MlirContext ctx) {
Index: mlir/lib/AsmParser/TypeParser.cpp
===================================================================
--- mlir/lib/AsmParser/TypeParser.cpp
+++ mlir/lib/AsmParser/TypeParser.cpp
@@ -30,6 +30,7 @@
   case Token::kw_tuple:
   case Token::kw_vector:
   case Token::inttype:
+  case Token::kw_f8E5M2:
   case Token::kw_bf16:
   case Token::kw_f16:
   case Token::kw_f32:
@@ -286,6 +287,9 @@
   }
 
   // float-type
+  case Token::kw_f8E5M2:
+    consumeToken(Token::kw_f8E5M2);
+    return builder.getFloat8E5M2Type();
   case Token::kw_bf16:
     consumeToken(Token::kw_bf16);
     return builder.getBF16Type();
Index: mlir/lib/AsmParser/TokenKinds.def
===================================================================
--- mlir/lib/AsmParser/TokenKinds.def
+++ mlir/lib/AsmParser/TokenKinds.def
@@ -93,6 +93,7 @@
 TOK_KEYWORD(f32)
 TOK_KEYWORD(f64)
 TOK_KEYWORD(f80)
+TOK_KEYWORD(f8E5M2)
 TOK_KEYWORD(f128)
 TOK_KEYWORD(false)
 TOK_KEYWORD(floordiv)
Index: mlir/include/mlir/IR/Types.h
===================================================================
--- mlir/include/mlir/IR/Types.h
+++ mlir/include/mlir/IR/Types.h
@@ -123,6 +123,7 @@
   // Convenience predicates.  This is only for floating point types,
   // derived types should use isa/dyn_cast.
   bool isIndex() const;
+  bool isFloat8E5M2() const;
   bool isBF16() const;
   bool isF16() const;
   bool isF32() const;
Index: mlir/include/mlir/IR/BuiltinTypes.td
===================================================================
--- mlir/include/mlir/IR/BuiltinTypes.td
+++ mlir/include/mlir/IR/BuiltinTypes.td
@@ -76,6 +76,27 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Float8E5M2Type
+
+def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2"> {
+  let summary = "8-bit floating point with 2-bit mantissa";
+  let description = [{
+    An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits
+    mantissa. This is not a standard type as defined by IEEE-754, but it
+    follows similar conventions with the following characteristics:
+
+      * bit encoding: S1E5M2
+      * exponent bias: 15
+      * infinities: supported with exponent set to all 1s and mantissa 0s
+      * NaNs: supported with exponent bits set to all 1s and mantissa of
+        (01, 10, or 11)
+      * denormals when exponent is 0
+
+    Described in: https://arxiv.org/abs/2209.05433
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // BFloat16Type
 
Index: mlir/include/mlir/IR/BuiltinTypes.h
===================================================================
--- mlir/include/mlir/IR/BuiltinTypes.h
+++ mlir/include/mlir/IR/BuiltinTypes.h
@@ -46,6 +46,7 @@
   static FloatType getF64(MLIRContext *ctx);
   static FloatType getF80(MLIRContext *ctx);
   static FloatType getF128(MLIRContext *ctx);
+  static FloatType getFloat8E5M2(MLIRContext *ctx);
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(Type type);
@@ -373,8 +374,12 @@
 }
 
 inline bool FloatType::classof(Type type) {
-  return type.isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
-                  Float80Type, Float128Type>();
+  return type.isa<Float8E5M2Type, BFloat16Type, Float16Type, Float32Type,
+                  Float64Type, Float80Type, Float128Type>();
+}
+
+inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) {
+  return Float8E5M2Type::get(ctx);
 }
 
 inline FloatType FloatType::getBF16(MLIRContext *ctx) {
Index: mlir/include/mlir/IR/Builders.h
===================================================================
--- mlir/include/mlir/IR/Builders.h
+++ mlir/include/mlir/IR/Builders.h
@@ -59,6 +59,7 @@
                        Attribute metadata = Attribute());
 
   // Types.
+  FloatType getFloat8E5M2Type();
   FloatType getBF16Type();
   FloatType getF16Type();
   FloatType getF32Type();
Index: mlir/include/mlir-c/BuiltinTypes.h
===================================================================
--- mlir/include/mlir-c/BuiltinTypes.h
+++ mlir/include/mlir-c/BuiltinTypes.h
@@ -67,6 +67,13 @@
 // Floating-point types.
 //===----------------------------------------------------------------------===//
 
+/// Checks whether the given type is an f8E5M2 type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type);
+
+/// Creates an f8E5M2 type in the given context. The type is owned by the
+/// context.
+MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2TypeGet(MlirContext ctx);
+
 /// Checks whether the given type is a bf16 type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type);
 
Index: llvm/unittests/ADT/APFloatTest.cpp
===================================================================
--- llvm/unittests/ADT/APFloatTest.cpp
+++ llvm/unittests/ADT/APFloatTest.cpp
@@ -1752,18 +1752,20 @@
     const unsigned long long bitPattern[2];
     const unsigned bitPatternLength;
   } const GetZeroTest[] = {
-    { &APFloat::IEEEhalf(), false, {0, 0}, 1},
-    { &APFloat::IEEEhalf(), true, {0x8000ULL, 0}, 1},
-    { &APFloat::IEEEsingle(), false, {0, 0}, 1},
-    { &APFloat::IEEEsingle(), true, {0x80000000ULL, 0}, 1},
-    { &APFloat::IEEEdouble(), false, {0, 0}, 1},
-    { &APFloat::IEEEdouble(), true, {0x8000000000000000ULL, 0}, 1},
-    { &APFloat::IEEEquad(), false, {0, 0}, 2},
-    { &APFloat::IEEEquad(), true, {0, 0x8000000000000000ULL}, 2},
-    { &APFloat::PPCDoubleDouble(), false, {0, 0}, 2},
-    { &APFloat::PPCDoubleDouble(), true, {0x8000000000000000ULL, 0}, 2},
-    { &APFloat::x87DoubleExtended(), false, {0, 0}, 2},
-    { &APFloat::x87DoubleExtended(), true, {0, 0x8000ULL}, 2},
+      {&APFloat::IEEEhalf(), false, {0, 0}, 1},
+      {&APFloat::IEEEhalf(), true, {0x8000ULL, 0}, 1},
+      {&APFloat::IEEEsingle(), false, {0, 0}, 1},
+      {&APFloat::IEEEsingle(), true, {0x80000000ULL, 0}, 1},
+      {&APFloat::IEEEdouble(), false, {0, 0}, 1},
+      {&APFloat::IEEEdouble(), true, {0x8000000000000000ULL, 0}, 1},
+      {&APFloat::IEEEquad(), false, {0, 0}, 2},
+      {&APFloat::IEEEquad(), true, {0, 0x8000000000000000ULL}, 2},
+      {&APFloat::PPCDoubleDouble(), false, {0, 0}, 2},
+      {&APFloat::PPCDoubleDouble(), true, {0x8000000000000000ULL, 0}, 2},
+      {&APFloat::x87DoubleExtended(), false, {0, 0}, 2},
+      {&APFloat::x87DoubleExtended(), true, {0, 0x8000ULL}, 2},
+      {&APFloat::Float8E5M2(), false, {0, 0}, 1},
+      {&APFloat::Float8E5M2(), true, {0x80ULL, 0}, 1},
   };
   const unsigned NumGetZeroTests = 12;
   for (unsigned i = 0; i < NumGetZeroTests; ++i) {
@@ -4754,7 +4756,7 @@
   EXPECT_TRUE(ilogb(F) == -1);
 }
 
-TEST(APFloatTest, ToDouble) {
+TEST(APFloatTest, IEEEdoubleToDouble) {
   APFloat DPosZero(0.0);
   APFloat DPosZeroToDouble(DPosZero.convertToDouble());
   EXPECT_TRUE(DPosZeroToDouble.isPosZero());
@@ -4790,7 +4792,9 @@
             DNegInf.convertToDouble());
   APFloat DQNaN = APFloat::getQNaN(APFloat::IEEEdouble());
   EXPECT_TRUE(std::isnan(DQNaN.convertToDouble()));
+}
 
+TEST(APFloatTest, IEEEsingleToDouble) {
   APFloat FPosZero(0.0F);
   APFloat FPosZeroToDouble(FPosZero.convertToDouble());
   EXPECT_TRUE(FPosZeroToDouble.isPosZero());
@@ -4825,7 +4829,9 @@
             FNegInf.convertToDouble());
   APFloat FQNaN = APFloat::getQNaN(APFloat::IEEEsingle());
   EXPECT_TRUE(std::isnan(FQNaN.convertToDouble()));
+}
 
+TEST(APFloatTest, IEEEhalfToDouble) {
   APFloat HPosZero = APFloat::getZero(APFloat::IEEEhalf());
   APFloat HPosZeroToDouble(HPosZero.convertToDouble());
   EXPECT_TRUE(HPosZeroToDouble.isPosZero());
@@ -4867,7 +4873,9 @@
   APFloat BNegZero = APFloat::getZero(APFloat::IEEEhalf(), true);
   APFloat BNegZeroToDouble(BNegZero.convertToDouble());
   EXPECT_TRUE(BNegZeroToDouble.isNegZero());
+}
 
+TEST(APFloatTest, BFloatToDouble) {
   APFloat BOne(APFloat::BFloat(), "1.0");
   EXPECT_EQ(1.0, BOne.convertToDouble());
   APFloat BPosLargest = APFloat::getLargest(APFloat::BFloat(), false);
@@ -4901,7 +4909,35 @@
   EXPECT_TRUE(std::isnan(BQNaN.convertToDouble()));
 }
 
-TEST(APFloatTest, ToFloat) {
+TEST(APFloatTest, Float8E5M2ToDouble) {
+  APFloat One(APFloat::Float8E5M2(), "1.0");
+  EXPECT_EQ(1.0, One.convertToDouble());
+  APFloat Two(APFloat::Float8E5M2(), "2.0");
+  EXPECT_EQ(2.0, Two.convertToDouble());
+  APFloat PosLargest = APFloat::getLargest(APFloat::Float8E5M2(), false);
+  EXPECT_EQ(5.734400e+04, PosLargest.convertToDouble());
+  APFloat NegLargest = APFloat::getLargest(APFloat::Float8E5M2(), true);
+  EXPECT_EQ(-5.734400e+04, NegLargest.convertToDouble());
+  APFloat PosSmallest =
+      APFloat::getSmallestNormalized(APFloat::Float8E5M2(), false);
+  EXPECT_EQ(0x1.p-14, PosSmallest.convertToDouble());
+  APFloat NegSmallest =
+      APFloat::getSmallestNormalized(APFloat::Float8E5M2(), true);
+  EXPECT_EQ(-0x1.p-14, NegSmallest.convertToDouble());
+
+  APFloat SmallestDenorm = APFloat::getSmallest(APFloat::Float8E5M2(), false);
+  EXPECT_TRUE(SmallestDenorm.isDenormal());
+  EXPECT_EQ(0x1p-16, SmallestDenorm.convertToDouble());
+
+  APFloat PosInf = APFloat::getInf(APFloat::Float8E5M2());
+  EXPECT_EQ(std::numeric_limits<double>::infinity(), PosInf.convertToDouble());
+  APFloat NegInf = APFloat::getInf(APFloat::Float8E5M2(), true);
+  EXPECT_EQ(-std::numeric_limits<double>::infinity(), NegInf.convertToDouble());
+  APFloat QNaN = APFloat::getQNaN(APFloat::Float8E5M2());
+  EXPECT_TRUE(std::isnan(QNaN.convertToDouble()));
+}
+
+TEST(APFloatTest, IEEEsingleToFloat) {
   APFloat FPosZero(0.0F);
   APFloat FPosZeroToFloat(FPosZero.convertToFloat());
   EXPECT_TRUE(FPosZeroToFloat.isPosZero());
@@ -4935,7 +4971,9 @@
   EXPECT_EQ(-std::numeric_limits<float>::infinity(), FNegInf.convertToFloat());
   APFloat FQNaN = APFloat::getQNaN(APFloat::IEEEsingle());
   EXPECT_TRUE(std::isnan(FQNaN.convertToFloat()));
+}
 
+TEST(APFloatTest, IEEEhalfToFloat) {
   APFloat HPosZero = APFloat::getZero(APFloat::IEEEhalf());
   APFloat HPosZeroToFloat(HPosZero.convertToFloat());
   EXPECT_TRUE(HPosZeroToFloat.isPosZero());
@@ -4969,7 +5007,9 @@
   EXPECT_EQ(-std::numeric_limits<float>::infinity(), HNegInf.convertToFloat());
   APFloat HQNaN = APFloat::getQNaN(APFloat::IEEEhalf());
   EXPECT_TRUE(std::isnan(HQNaN.convertToFloat()));
+}
 
+TEST(APFloatTest, BFloatToFloat) {
   APFloat BPosZero = APFloat::getZero(APFloat::BFloat());
   APFloat BPosZeroToDouble(BPosZero.convertToFloat());
   EXPECT_TRUE(BPosZeroToDouble.isPosZero());
@@ -5008,4 +5048,41 @@
   APFloat BQNaN = APFloat::getQNaN(APFloat::BFloat());
   EXPECT_TRUE(std::isnan(BQNaN.convertToFloat()));
 }
+
+TEST(APFloatTest, Float8E5M2ToFloat) {
+  APFloat PosZero = APFloat::getZero(APFloat::Float8E5M2());
+  APFloat PosZeroToFloat(PosZero.convertToFloat());
+  EXPECT_TRUE(PosZeroToFloat.isPosZero());
+  APFloat NegZero = APFloat::getZero(APFloat::Float8E5M2(), true);
+  APFloat NegZeroToFloat(NegZero.convertToFloat());
+  EXPECT_TRUE(NegZeroToFloat.isNegZero());
+
+  APFloat One(APFloat::Float8E5M2(), "1.0");
+  EXPECT_EQ(1.0F, One.convertToFloat());
+  APFloat Two(APFloat::Float8E5M2(), "2.0");
+  EXPECT_EQ(2.0F, Two.convertToFloat());
+
+  APFloat PosLargest = APFloat::getLargest(APFloat::Float8E5M2(), false);
+  EXPECT_EQ(5.734400e+04, PosLargest.convertToFloat());
+  APFloat NegLargest = APFloat::getLargest(APFloat::Float8E5M2(), true);
+  EXPECT_EQ(-5.734400e+04, NegLargest.convertToFloat());
+  APFloat PosSmallest =
+      APFloat::getSmallestNormalized(APFloat::Float8E5M2(), false);
+  EXPECT_EQ(0x1.p-14, PosSmallest.convertToFloat());
+  APFloat NegSmallest =
+      APFloat::getSmallestNormalized(APFloat::Float8E5M2(), true);
+  EXPECT_EQ(-0x1.p-14, NegSmallest.convertToFloat());
+
+  APFloat SmallestDenorm = APFloat::getSmallest(APFloat::Float8E5M2(), false);
+  EXPECT_TRUE(SmallestDenorm.isDenormal());
+  EXPECT_EQ(0x1.p-16, SmallestDenorm.convertToFloat());
+
+  APFloat PosInf = APFloat::getInf(APFloat::Float8E5M2());
+  EXPECT_EQ(std::numeric_limits<float>::infinity(), PosInf.convertToFloat());
+  APFloat NegInf = APFloat::getInf(APFloat::Float8E5M2(), true);
+  EXPECT_EQ(-std::numeric_limits<float>::infinity(), NegInf.convertToFloat());
+  APFloat QNaN = APFloat::getQNaN(APFloat::Float8E5M2());
+  EXPECT_TRUE(std::isnan(QNaN.convertToFloat()));
 }
+
+} // namespace
Index: llvm/lib/Support/APFloat.cpp
===================================================================
--- llvm/lib/Support/APFloat.cpp
+++ llvm/lib/Support/APFloat.cpp
@@ -80,6 +80,7 @@
   static const fltSemantics semIEEEsingle = {127, -126, 24, 32};
   static const fltSemantics semIEEEdouble = {1023, -1022, 53, 64};
   static const fltSemantics semIEEEquad = {16383, -16382, 113, 128};
+  static const fltSemantics semFloat8E5M2 = {15, -14, 3, 8};
   static const fltSemantics semX87DoubleExtended = {16383, -16382, 64, 80};
   static const fltSemantics semBogus = {0, 0, 0, 0};
 
@@ -131,12 +132,14 @@
       return IEEEsingle();
     case S_IEEEdouble:
       return IEEEdouble();
-    case S_x87DoubleExtended:
-      return x87DoubleExtended();
     case S_IEEEquad:
       return IEEEquad();
     case S_PPCDoubleDouble:
       return PPCDoubleDouble();
+    case S_Float8E5M2:
+      return Float8E5M2();
+    case S_x87DoubleExtended:
+      return x87DoubleExtended();
     }
     llvm_unreachable("Unrecognised floating semantics");
   }
@@ -151,12 +154,14 @@
       return S_IEEEsingle;
     else if (&Sem == &llvm::APFloat::IEEEdouble())
       return S_IEEEdouble;
-    else if (&Sem == &llvm::APFloat::x87DoubleExtended())
-      return S_x87DoubleExtended;
     else if (&Sem == &llvm::APFloat::IEEEquad())
       return S_IEEEquad;
     else if (&Sem == &llvm::APFloat::PPCDoubleDouble())
       return S_PPCDoubleDouble;
+    else if (&Sem == &llvm::APFloat::Float8E5M2())
+      return S_Float8E5M2;
+    else if (&Sem == &llvm::APFloat::x87DoubleExtended())
+      return S_x87DoubleExtended;
     else
       llvm_unreachable("Unknown floating semantics");
   }
@@ -173,18 +178,15 @@
   const fltSemantics &APFloatBase::IEEEdouble() {
     return semIEEEdouble;
   }
-  const fltSemantics &APFloatBase::IEEEquad() {
-    return semIEEEquad;
+  const fltSemantics &APFloatBase::IEEEquad() { return semIEEEquad; }
+  const fltSemantics &APFloatBase::PPCDoubleDouble() {
+    return semPPCDoubleDouble;
   }
+  const fltSemantics &APFloatBase::Float8E5M2() { return semFloat8E5M2; }
   const fltSemantics &APFloatBase::x87DoubleExtended() {
     return semX87DoubleExtended;
   }
-  const fltSemantics &APFloatBase::Bogus() {
-    return semBogus;
-  }
-  const fltSemantics &APFloatBase::PPCDoubleDouble() {
-    return semPPCDoubleDouble;
-  }
+  const fltSemantics &APFloatBase::Bogus() { return semBogus; }
 
   constexpr RoundingMode APFloatBase::rmNearestTiesToEven;
   constexpr RoundingMode APFloatBase::rmTowardPositive;
@@ -3353,6 +3355,33 @@
                     (mysignificand & 0x3ff)));
 }
 
+APInt IEEEFloat::convertFloat8E5M2APFloatToAPInt() const {
+  assert(semantics == (const llvm::fltSemantics *)&semFloat8E5M2);
+  assert(partCount() == 1);
+  // Layout: sign (bit 7), biased exponent (bits 6-2), mantissa (bits 1-0).
+  uint32_t myexponent, mysignificand;
+
+  if (isFiniteNonZero()) {
+    myexponent = exponent + 15; // bias
+    mysignificand = (uint32_t)*significandParts(); // incl. integer bit
+    if (myexponent == 1 && !(mysignificand & 0x4)) // no integer bit
+      myexponent = 0; // denormal
+  } else if (category == fcZero) {
+    myexponent = 0;
+    mysignificand = 0;
+  } else if (category == fcInfinity) {
+    myexponent = 0x1f;
+    mysignificand = 0;
+  } else {
+    assert(category == fcNaN && "Unknown category!");
+    myexponent = 0x1f;
+    mysignificand = (uint32_t)*significandParts(); // NaN payload bits
+  }
+
+  return APInt(8, (((sign & 1) << 7) | ((myexponent & 0x1f) << 2) |
+                   (mysignificand & 0x3)));
+}
+
 // This function creates an APInt that is just a bit map of the floating
 // point constant as it would appear in memory.  It is not a conversion,
 // and treating the result as a normal integer is unlikely to be useful.
@@ -3376,6 +3405,9 @@
   if (semantics == (const llvm::fltSemantics *)&semPPCDoubleDoubleLegacy)
     return convertPPCDoubleDoubleAPFloatToAPInt();
 
+  if (semantics == (const llvm::fltSemantics *)&semFloat8E5M2)
+    return convertFloat8E5M2APFloatToAPInt();
+
   assert(semantics == (const llvm::fltSemantics*)&semX87DoubleExtended &&
          "unknown format!");
   return convertF80LongDoubleAPFloatToAPInt();
@@ -3603,6 +3635,34 @@
   }
 }
 
+void IEEEFloat::initFromFloat8E5M2APInt(const APInt &api) {
+  uint32_t i = (uint32_t)*api.getRawData();
+  uint32_t myexponent = (i >> 2) & 0x1f; // bits 6-2
+  uint32_t mysignificand = i & 0x3; // bits 1-0
+
+  initialize(&semFloat8E5M2);
+  assert(partCount() == 1);
+
+  sign = i >> 7; // bit 7
+  if (myexponent == 0 && mysignificand == 0) {
+    makeZero(sign);
+  } else if (myexponent == 0x1f && mysignificand == 0) {
+    makeInf(sign);
+  } else if (myexponent == 0x1f && mysignificand != 0) {
+    category = fcNaN;
+    exponent = exponentNaN();
+    *significandParts() = mysignificand;
+  } else {
+    category = fcNormal;
+    exponent = myexponent - 15; // bias
+    *significandParts() = mysignificand;
+    if (myexponent == 0) // denormal
+      exponent = -14; // same exponent as the smallest normal
+    else
+      *significandParts() |= 0x4; // integer bit
+  }
+}
+
 /// Treat api as containing the bits of a floating point number.  Currently
 /// we infer the floating point type from the size of the APInt.  The
 /// isIEEE argument distinguishes between PPC128 and IEEE128 (not meaningful
@@ -3623,6 +3683,8 @@
     return initFromQuadrupleAPInt(api);
   if (Sem == &semPPCDoubleDoubleLegacy)
     return initFromPPCDoubleDoubleAPInt(api);
+  if (Sem == &semFloat8E5M2)
+    return initFromFloat8E5M2APInt(api);
 
   llvm_unreachable(nullptr);
 }
Index: llvm/include/llvm/ADT/APFloat.h
===================================================================
--- llvm/include/llvm/ADT/APFloat.h
+++ llvm/include/llvm/ADT/APFloat.h
@@ -153,10 +153,13 @@
     S_BFloat,
     S_IEEEsingle,
     S_IEEEdouble,
-    S_x87DoubleExtended,
     S_IEEEquad,
     S_PPCDoubleDouble,
-    S_MaxSemantics = S_PPCDoubleDouble
+    // 8-bit floating point number following IEEE-754 conventions with bit
+    // layout S1E5M2 as described in https://arxiv.org/abs/2209.05433
+    S_Float8E5M2,
+    S_x87DoubleExtended,
+    S_MaxSemantics = S_x87DoubleExtended,
   };
 
   static const llvm::fltSemantics &EnumToSemantics(Semantics S);
@@ -168,6 +171,7 @@
   static const fltSemantics &IEEEdouble() LLVM_READNONE;
   static const fltSemantics &IEEEquad() LLVM_READNONE;
   static const fltSemantics &PPCDoubleDouble() LLVM_READNONE;
+  static const fltSemantics &Float8E5M2() LLVM_READNONE;
   static const fltSemantics &x87DoubleExtended() LLVM_READNONE;
 
   /// A Pseudo fltsemantic used to construct APFloats that cannot conflict with
@@ -552,6 +556,7 @@
   APInt convertQuadrupleAPFloatToAPInt() const;
   APInt convertF80LongDoubleAPFloatToAPInt() const;
   APInt convertPPCDoubleDoubleAPFloatToAPInt() const;
+  APInt convertFloat8E5M2APFloatToAPInt() const;
   void initFromAPInt(const fltSemantics *Sem, const APInt &api);
   void initFromHalfAPInt(const APInt &api);
   void initFromBFloatAPInt(const APInt &api);
@@ -560,6 +565,7 @@
   void initFromQuadrupleAPInt(const APInt &api);
   void initFromF80LongDoubleAPInt(const APInt &api);
   void initFromPPCDoubleDoubleAPInt(const APInt &api);
+  void initFromFloat8E5M2APInt(const APInt &api);
 
   void assign(const IEEEFloat &);
   void copySignificand(const IEEEFloat &);
Index: clang/lib/AST/MicrosoftMangle.cpp
===================================================================
--- clang/lib/AST/MicrosoftMangle.cpp
+++ clang/lib/AST/MicrosoftMangle.cpp
@@ -838,6 +838,10 @@
   case APFloat::S_x87DoubleExtended: Out << 'X'; break;
   case APFloat::S_IEEEquad: Out << 'Y'; break;
   case APFloat::S_PPCDoubleDouble: Out << 'Z'; break;
+  // No 'default:' label: keep the switch fully covered so that -Wswitch
+  // flags any APFloat semantics added later.
+  case APFloat::S_Float8E5M2:
+    llvm_unreachable("Tried to mangle unexpected APFloat semantics");
   }
 
   mangleBits(Number.bitcastToAPInt());
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to