This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new eb5492a  fix: Compatibility with torch <= 2.8 (#54)
eb5492a is described below

commit eb5492a1d2feaff5f11a51f32cf5ad339c789665
Author: Zihao Ye <[email protected]>
AuthorDate: Thu Sep 25 02:42:09 2025 -0700

    fix: Compatibility with torch <= 2.8 (#54)
    
    Note that `Float8_e8m0fnu` was also introduced in torch 2.8 and should be
    guarded.
---
 python/tvm_ffi/_optional_torch_c_dlpack.py | 6 ++++--
 python/tvm_ffi/cython/dtype.pxi            | 5 ++++-
 2 files changed, 8 insertions(+), 3 deletions(-)
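
For context: `torch.float8_e8m0fnu` (and its C++ counterpart
`ScalarType::Float8_e8m0fnu`) only exists from torch 2.8 onward, so an
unconditional reference fails at import time on older builds. A minimal
sketch of the Python-side failure this commit guards against (plain
Python; the exact error message may vary across torch versions):

    import torch

    # On torch < 2.8 this raises AttributeError ("module 'torch' has no
    # attribute 'float8_e8m0fnu'"), which is why the dtype table must be
    # built conditionally rather than as a single dict literal.
    try:
        e8m0 = torch.float8_e8m0fnu
    except AttributeError:
        e8m0 = None  # older torch: the FFI table simply omits this entry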

diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py b/python/tvm_ffi/_optional_torch_c_dlpack.py
index cc40b31..bd88bff 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -120,10 +120,10 @@ DLDataType getDLDataTypeForDLPackv1(const Tensor& t) {
     case ScalarType::Float8_e4m3fnuz:
       dtype.code = DLDataTypeCode::kDLFloat8_e4m3fnuz;
       break;
+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 8
     case ScalarType::Float8_e8m0fnu:
       dtype.code = DLDataTypeCode::kDLFloat8_e8m0fnu;
       break;
-#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 8
     case ScalarType::Float4_e2m1fn_x2:
       dtype.code = DLDataTypeCode::kDLFloat4_e2m1fn;
       dtype.lanes = 2;
@@ -269,11 +269,13 @@ static Device getATenDeviceForDLPackv1(DLDeviceType type, c10::DeviceIndex index
 
 ScalarType toScalarTypeForDLPackv1(const DLDataType& dtype) {
   ScalarType stype = ScalarType::Undefined;
+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 8
   if (dtype.code != DLDataTypeCode::kDLFloat4_e2m1fn) {
     TORCH_CHECK(
         dtype.lanes == 1,
         "ATen does not support lanes != 1 for dtype code", 
std::to_string(dtype.code));
   }
+#endif
   switch (dtype.code) {
     case DLDataTypeCode::kDLUInt:
       switch (dtype.bits) {
@@ -405,6 +407,7 @@ ScalarType toScalarTypeForDLPackv1(const DLDataType& dtype) {
               false, "Unsupported kDLFloat8_e4m3fnuz bits ", 
std::to_string(dtype.bits));
       }
       break;
+#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 8
     case DLDataTypeCode::kDLFloat8_e8m0fnu:
       switch (dtype.bits) {
         case 8:
@@ -415,7 +418,6 @@ ScalarType toScalarTypeForDLPackv1(const DLDataType& dtype) {
               false, "Unsupported kDLFloat8_e8m0fnu bits ", 
std::to_string(dtype.bits));
       }
       break;
-#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 8
     case DLDataTypeCode::kDLFloat4_e2m1fn:
       switch (dtype.bits) {
         case 4:
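
The C++ hunks above move the `Float8_e8m0fnu` cases inside the existing
`#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 8` block so the
enumerator is never referenced when building against older torch. For
reference, a rough runtime analogue of that compile-time predicate (a
sketch; it assumes `torch.__version__` starts with "major.minor"):

    import torch

    # Parse e.g. "2.8.0+cu128" -> (2, 8); only the leading two numeric
    # components are used.
    _major, _minor = (int(p) for p in torch.__version__.split(".")[:2])

    # Tuple comparison mirrors the intent of the preprocessor guard and
    # also holds for a hypothetical 3.x release.
    HAS_TORCH_28_DTYPES = (_major, _minor) >= (2, 8)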
diff --git a/python/tvm_ffi/cython/dtype.pxi b/python/tvm_ffi/cython/dtype.pxi
index 7649885..9d76be6 100644
--- a/python/tvm_ffi/cython/dtype.pxi
+++ b/python/tvm_ffi/cython/dtype.pxi
@@ -172,8 +172,11 @@ if torch is not None:
         torch.float8_e4m3fnuz: DLDataType(11, 8, 1),
         torch.float8_e5m2: DLDataType(12, 8, 1),
         torch.float8_e5m2fnuz: DLDataType(13, 8, 1),
-        torch.float8_e8m0fnu: DLDataType(14, 8, 1),
     }
+    if hasattr(torch, "float8_e8m0fnu"):
+        TORCH_DTYPE_TO_DTYPE[torch.float8_e8m0fnu] = DLDataType(14, 8, 1)
+    if hasattr(torch, "float4_e2m1fn_x2"):
+        TORCH_DTYPE_TO_DTYPE[torch.float4_e2m1fn_x2] = DLDataType(17, 4, 2)
 
     def _convert_torch_dtype_to_ffi_dtype(torch_dtype):
         cdef DLDataType cdtype = TORCH_DTYPE_TO_DTYPE[torch_dtype]

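Putting the dtype.pxi change together, a plain-Python sketch of the
guarded table and a lookup (an illustrative subset, not the full Cython
table; it assumes a torch build that provides `float8_e5m2`):

    import torch

    # (type_code, bits, lanes) triples following the DLPack convention
    # used in dtype.pxi.
    TORCH_DTYPE_TO_DTYPE = {
        torch.float8_e5m2: (12, 8, 1),
    }
    # Register the torch-2.8 dtypes only when this build exposes them.
    if hasattr(torch, "float8_e8m0fnu"):
        TORCH_DTYPE_TO_DTYPE[torch.float8_e8m0fnu] = (14, 8, 1)
    if hasattr(torch, "float4_e2m1fn_x2"):
        TORCH_DTYPE_TO_DTYPE[torch.float4_e2m1fn_x2] = (17, 4, 2)

    def to_ffi_dtype(torch_dtype):
        # A missing entry means the running torch predates the dtype.
        if torch_dtype not in TORCH_DTYPE_TO_DTYPE:
            raise TypeError(f"{torch_dtype} not supported by this torch build")
        return TORCH_DTYPE_TO_DTYPE[torch_dtype]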