This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new eb5492a fix: Compatibility with torch <= 2.8 (#54)
eb5492a is described below
commit eb5492a1d2feaff5f11a51f32cf5ad339c789665
Author: Zihao Ye <[email protected]>
AuthorDate: Thu Sep 25 02:42:09 2025 -0700
fix: Compatibility with torch <= 2.8 (#54)
Note that `Float8_e8m0fnu` was also introduced in torch 2.8, so it must be
guarded by the same version check as `Float4_e2m1fn_x2`.
---
python/tvm_ffi/_optional_torch_c_dlpack.py | 6 ++++--
python/tvm_ffi/cython/dtype.pxi | 5 ++++-
2 files changed, 8 insertions(+), 3 deletions(-)
diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py b/python/tvm_ffi/_optional_torch_c_dlpack.py
index cc40b31..bd88bff 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -120,10 +120,10 @@ DLDataType getDLDataTypeForDLPackv1(const Tensor& t) {
case ScalarType::Float8_e4m3fnuz:
dtype.code = DLDataTypeCode::kDLFloat8_e4m3fnuz;
break;
+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 8
case ScalarType::Float8_e8m0fnu:
dtype.code = DLDataTypeCode::kDLFloat8_e8m0fnu;
break;
-#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 8
case ScalarType::Float4_e2m1fn_x2:
dtype.code = DLDataTypeCode::kDLFloat4_e2m1fn;
dtype.lanes = 2;
@@ -269,11 +269,13 @@ static Device getATenDeviceForDLPackv1(DLDeviceType type, c10::DeviceIndex index
ScalarType toScalarTypeForDLPackv1(const DLDataType& dtype) {
ScalarType stype = ScalarType::Undefined;
+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 8
if (dtype.code != DLDataTypeCode::kDLFloat4_e2m1fn) {
TORCH_CHECK(
dtype.lanes == 1,
"ATen does not support lanes != 1 for dtype code",
std::to_string(dtype.code));
}
+#endif
switch (dtype.code) {
case DLDataTypeCode::kDLUInt:
switch (dtype.bits) {
@@ -405,6 +407,7 @@ ScalarType toScalarTypeForDLPackv1(const DLDataType& dtype) {
false, "Unsupported kDLFloat8_e4m3fnuz bits ",
std::to_string(dtype.bits));
}
break;
+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 8
case DLDataTypeCode::kDLFloat8_e8m0fnu:
switch (dtype.bits) {
case 8:
@@ -415,7 +418,6 @@ ScalarType toScalarTypeForDLPackv1(const DLDataType& dtype) {
false, "Unsupported kDLFloat8_e8m0fnu bits ",
std::to_string(dtype.bits));
}
break;
-#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 8
case DLDataTypeCode::kDLFloat4_e2m1fn:
switch (dtype.bits) {
case 4:
diff --git a/python/tvm_ffi/cython/dtype.pxi b/python/tvm_ffi/cython/dtype.pxi
index 7649885..9d76be6 100644
--- a/python/tvm_ffi/cython/dtype.pxi
+++ b/python/tvm_ffi/cython/dtype.pxi
@@ -172,8 +172,11 @@ if torch is not None:
torch.float8_e4m3fnuz: DLDataType(11, 8, 1),
torch.float8_e5m2: DLDataType(12, 8, 1),
torch.float8_e5m2fnuz: DLDataType(13, 8, 1),
- torch.float8_e8m0fnu: DLDataType(14, 8, 1),
}
+ if hasattr(torch, "float8_e8m0fnu"):
+ TORCH_DTYPE_TO_DTYPE[torch.float8_e8m0fnu] = DLDataType(14, 8, 1)
+ if hasattr(torch, "float4_e2m1fn_x2"):
+ TORCH_DTYPE_TO_DTYPE[torch.float4_e2m1fn_x2] = DLDataType(17, 4, 2)
def _convert_torch_dtype_to_ffi_dtype(torch_dtype):
cdef DLDataType cdtype = TORCH_DTYPE_TO_DTYPE[torch_dtype]