This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s2 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit ea7a1dd421603cda511865219395b6c6cf8d4860 Author: tqchen <[email protected]> AuthorDate: Mon Apr 21 18:41:01 2025 -0400 [FFI] DType support for dlpack v1.1 --- ffi/include/tvm/ffi/dtype.h | 114 ++++++++++++++++++++++++++++++++++---------- ffi/tests/cpp/test_dtype.cc | 27 ++++++++++- 2 files changed, 115 insertions(+), 26 deletions(-) diff --git a/ffi/include/tvm/ffi/dtype.h b/ffi/include/tvm/ffi/dtype.h index 804edacd46..257b6bc158 100644 --- a/ffi/include/tvm/ffi/dtype.h +++ b/ffi/include/tvm/ffi/dtype.h @@ -43,12 +43,7 @@ namespace ffi { * * TOTO(tvm-team): update to latest DLPack types. */ -enum DLExtDataTypeCode { - kDLExtFloat8_e4m3fn = 6, - kDLExtFloat8_e5m2 = 7, - kDLExtFloat4_e2m1fn = 8, - kDLExtCustomBegin = 129 -}; +enum DLExtDataTypeCode { kDLExtCustomBegin = 129 }; namespace details { /*! @@ -121,15 +116,47 @@ inline void PrintDLDataTypeCodeAsStr(std::ostream& os, DLDataTypeCode type_code) os << "bfloat"; break; } - case kDLExtFloat8_e4m3fn: { + case kDLFloat8_e3m4: { + os << "float8_e3m4"; + break; + } + case kDLFloat8_e4m3: { + os << "float8_e4m3"; + break; + } + case kDLFloat8_e4m3b11fnuz: { + os << "float8_e4m3b11fnuz"; + break; + } + case kDLFloat8_e4m3fn: { os << "float8_e4m3fn"; break; } - case kDLExtFloat8_e5m2: { + case kDLFloat8_e4m3fnuz: { + os << "float8_e4m3fnuz"; + break; + } + case kDLFloat8_e5m2: { os << "float8_e5m2"; break; } - case kDLExtFloat4_e2m1fn: { + case kDLFloat8_e5m2fnuz: { + os << "float8_e5m2fnuz"; + break; + } + case kDLFloat8_e8m0fnu: { + os << "float8_e8m0fnu"; + break; + } + case kDLFloat6_e2m3fn: { + os << "float6_e2m3fn"; + break; + } + case kDLFloat6_e3m2fn: { + os << "float6_e3m2fn"; + break; + } + case kDLFloat4_e2m1fn: { os << "float4_e2m1fn"; break; } @@ -164,8 +191,7 @@ inline std::ostream& operator<<(std::ostream& os, DLDataType dtype) { // NOLINT details::PrintDLDataTypeCodeAsStr(os, static_cast<DLDataTypeCode>(dtype.code)); if (dtype.code == kDLOpaqueHandle) return os; int16_t lanes = static_cast<int16_t>(dtype.lanes); - if (dtype.code != kDLExtFloat8_e4m3fn && dtype.code != kDLExtFloat8_e5m2 && - dtype.code != kDLExtFloat4_e2m1fn) { + if (dtype.code < kDLFloat8_e3m4) { os << static_cast<int>(dtype.bits); } if (lanes > 1) { @@ -223,22 +249,60 @@ inline DLDataType StringToDLDataType(const std::string& str) { return dtype; }; - if (str.substr(0, 3) == "int") { + if (str.compare(0, 3, "int") == 0) { dtype.code = kDLInt; scan = str.c_str() + 3; - } else if (str.substr(0, 4) == "uint") { + } else if (str.compare(0, 4, "uint") == 0) { dtype.code = kDLUInt; scan = str.c_str() + 4; - } else if (str.substr(0, 13) == "float4_e2m1fn") { - return parse_float(str, 13, DLExtDataTypeCode::kDLExtFloat4_e2m1fn, 4); - } else if (str.substr(0, 13) == "float8_e4m3fn") { - return parse_float(str, 13, DLExtDataTypeCode::kDLExtFloat8_e4m3fn, 8); - } else if (str.substr(0, 11) == "float8_e5m2") { - return parse_float(str, 11, DLExtDataTypeCode::kDLExtFloat8_e5m2, 8); - } else if (str.substr(0, 5) == "float") { - dtype.code = kDLFloat; - scan = str.c_str() + 5; - } else if (str.substr(0, 6) == "handle") { + } else if (str.compare(0, 5, "float") == 0) { + if (str.compare(5, 2, "8_") == 0) { + if (str.compare(7, 4, "e3m4") == 0) { + return parse_float(str, 11, kDLFloat8_e3m4, 8); + } else if (str.compare(7, 4, "e4m3") == 0) { + if (str.compare(11, 7, "b11fnuz") == 0) { + return parse_float(str, 18, kDLFloat8_e4m3b11fnuz, 8); + } else if (str.compare(11, 2, "fn") == 0) { + if (str.compare(13, 2, "uz") == 0) { + return parse_float(str, 15, kDLFloat8_e4m3fnuz, 8); + } else { + return parse_float(str, 13, kDLFloat8_e4m3fn, 8); + } + } else { + return parse_float(str, 11, kDLFloat8_e4m3, 8); + } + } else if (str.compare(7, 8, "e5m2fnuz") == 0) { + return parse_float(str, 15, kDLFloat8_e5m2fnuz, 8); + } else if (str.compare(7, 4, "e5m2") == 0) { + return parse_float(str, 11, kDLFloat8_e5m2, 8); + } else if (str.compare(7, 7, "e8m0fnu") == 0) { + return parse_float(str, 14, kDLFloat8_e8m0fnu, 8); + } else { + TVM_FFI_THROW(ValueError) << "unknown float8 type `" << str << '`'; + TVM_FFI_UNREACHABLE(); + } + } else if (str.compare(5, 2, "6_") == 0) { + if (str.compare(7, 6, "e2m3fn") == 0) { + return parse_float(str, 13, kDLFloat6_e2m3fn, 6); + } else if (str.compare(7, 6, "e3m2fn") == 0) { + return parse_float(str, 13, kDLFloat6_e3m2fn, 6); + } else { + TVM_FFI_THROW(ValueError) << "unknown float6 type `" << str << '`'; + TVM_FFI_UNREACHABLE(); + } + } else if (str.compare(5, 2, "4_") == 0) { + // kFloat4_e2m1fn + if (str.compare(7, 6, "e2m1fn") == 0) { + return parse_float(str, 13, kDLFloat4_e2m1fn, 4); + } else { + TVM_FFI_THROW(ValueError) << "unknown float4 type `" << str << '`'; + TVM_FFI_UNREACHABLE(); + } + } else { + dtype.code = kDLFloat; + scan = str.c_str() + 5; + } + } else if (str.compare(0, 6, "handle") == 0) { dtype.code = kDLOpaqueHandle; dtype.bits = 64; // handle uses 64 bit by default. scan = str.c_str() + 6; @@ -247,11 +311,11 @@ inline DLDataType StringToDLDataType(const std::string& str) { dtype.bits = 1; dtype.lanes = 1; return dtype; - } else if (str.substr(0, 6) == "bfloat") { + } else if (str.compare(0, 6, "bfloat") == 0) { dtype.code = kDLBfloat; dtype.bits = 16; scan = str.c_str() + 6; - } else if (str.substr(0, 6) == "custom") { + } else if (str.compare(0, 6, "custom") == 0) { dtype.code = details::ParseCustomDataTypeCode(str, &scan); } else { scan = str.c_str(); diff --git a/ffi/tests/cpp/test_dtype.cc b/ffi/tests/cpp/test_dtype.cc index ad769b740a..3e3e43430e 100644 --- a/ffi/tests/cpp/test_dtype.cc +++ b/ffi/tests/cpp/test_dtype.cc @@ -44,11 +44,36 @@ TEST(DType, StringConversion) { EXPECT_EQ(StringToDLDataType("bfloat16x2"), dtype); // test float8 - dtype = DLDataType{kDLExtFloat8_e4m3fn, 8, 2}; + dtype = DLDataType{kDLFloat8_e4m3fn, 8, 2}; EXPECT_EQ(DLDataTypeToString(dtype), "float8_e4m3fnx2"); EXPECT_EQ(StringToDLDataType("float8_e4m3fnx2"), dtype); } +TEST(DType, StringConversionAllDLPackTypes) { + std::vector<std::pair<DLDataType, std::string>> test_cases = { + {DLDataType{kDLFloat, 32, 1}, "float32"}, + {DLDataType{kDLInt, 16, 1}, "int16"}, + {DLDataType{kDLUInt, 16, 1}, "uint16"}, + {DLDataType{kDLBfloat, 16, 1}, "bfloat16"}, + {DLDataType{kDLFloat8_e3m4, 8, 1}, "float8_e3m4"}, + {DLDataType{kDLFloat8_e4m3, 8, 1}, "float8_e4m3"}, + {DLDataType{kDLFloat8_e4m3b11fnuz, 8, 1}, "float8_e4m3b11fnuz"}, + {DLDataType{kDLFloat8_e4m3fn, 8, 1}, "float8_e4m3fn"}, + {DLDataType{kDLFloat8_e4m3fnuz, 8, 1}, "float8_e4m3fnuz"}, + {DLDataType{kDLFloat8_e5m2, 8, 1}, "float8_e5m2"}, + {DLDataType{kDLFloat8_e5m2fnuz, 8, 1}, "float8_e5m2fnuz"}, + {DLDataType{kDLFloat8_e8m0fnu, 8, 1}, "float8_e8m0fnu"}, + {DLDataType{kDLFloat6_e2m3fn, 6, 1}, "float6_e2m3fn"}, + {DLDataType{kDLFloat6_e3m2fn, 6, 1}, "float6_e3m2fn"}, + {DLDataType{kDLFloat4_e2m1fn, 4, 1}, "float4_e2m1fn"}, + }; + + for (const auto& [dtype, str] : test_cases) { + EXPECT_EQ(DLDataTypeToString(dtype), str); + EXPECT_EQ(StringToDLDataType(str), dtype); + } +} + TEST(DataType, AnyConversion) { AnyView view0; EXPECT_EQ(view0.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFINone);
