This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new b9a2b92 fix: More fix of `ml_dtypes<0.5` (#222)
b9a2b92 is described below
commit b9a2b9231aa94464d8228a1b2dfbb725941a0eea
Author: Yichen Yan <[email protected]>
AuthorDate: Wed Nov 5 09:49:24 2025 +0800
fix: More fix of `ml_dtypes<0.5` (#222)
This is a follow-up to #198 that adds a fix missed there: `ml_dtypes.float4_e2m1fn` is only registered when it exists (ml_dtypes >= 0.5.0).
---
python/tvm_ffi/_dtype.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/python/tvm_ffi/_dtype.py b/python/tvm_ffi/_dtype.py
index 5079226..ebc2585 100644
--- a/python/tvm_ffi/_dtype.py
+++ b/python/tvm_ffi/_dtype.py
@@ -314,7 +314,8 @@ try:
dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16"
dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float8_e4m3fn)] =
"float8_e4m3fn"
dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float8_e5m2)] = "float8_e5m2"
- dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float4_e2m1fn)] =
"float4_e2m1fn"
+ if hasattr(ml_dtypes, "float4_e2m1fn"): # ml_dtypes >= 0.5.0
+ dtype._NUMPY_DTYPE_TO_STR[np.dtype(ml_dtypes.float4_e2m1fn)] =
"float4_e2m1fn"
except ImportError:
pass