This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 9574e9d  feat: support `ml_dtypes<0.5` (#198)
9574e9d is described below

commit 9574e9d0afb50eca58f56cf0cdce9f64f8cd319f
Author: Yichen Yan <[email protected]>
AuthorDate: Tue Oct 28 22:02:32 2025 +0800

    feat: support `ml_dtypes<0.5` (#198)
    
    As the title says; this currently fails with `ml_dtypes==0.4.x`:
    https://github.com/tile-ai/tilelang/pull/1108#issuecomment-3455197314
---
 python/tvm_ffi/cython/dtype.pxi | 20 ++++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/python/tvm_ffi/cython/dtype.pxi b/python/tvm_ffi/cython/dtype.pxi
index 3d90346..15f9418 100644
--- a/python/tvm_ffi/cython/dtype.pxi
+++ b/python/tvm_ffi/cython/dtype.pxi
@@ -205,24 +205,28 @@ else:
 
 if ml_dtypes is not None:
     MLDTYPES_DTYPE_TO_DTYPE = {
-        numpy.dtype(ml_dtypes.int2): DLDataType(0, 2, 1),
         numpy.dtype(ml_dtypes.int4): DLDataType(0, 4, 1),
-        numpy.dtype(ml_dtypes.uint2): DLDataType(1, 2, 1),
         numpy.dtype(ml_dtypes.uint4): DLDataType(1, 4, 1),
         numpy.dtype(ml_dtypes.bfloat16): DLDataType(4, 16, 1),
-        numpy.dtype(ml_dtypes.float8_e3m4): DLDataType(7, 8, 1),
-        numpy.dtype(ml_dtypes.float8_e4m3): DLDataType(8, 8, 1),
         numpy.dtype(ml_dtypes.float8_e4m3b11fnuz): DLDataType(9, 8, 1),
         numpy.dtype(ml_dtypes.float8_e4m3fn): DLDataType(10, 8, 1),
         numpy.dtype(ml_dtypes.float8_e4m3fnuz): DLDataType(11, 8, 1),
         numpy.dtype(ml_dtypes.float8_e5m2): DLDataType(12, 8, 1),
         numpy.dtype(ml_dtypes.float8_e5m2fnuz): DLDataType(13, 8, 1),
-        numpy.dtype(ml_dtypes.float8_e8m0fnu): DLDataType(14, 8, 1),
-        numpy.dtype(ml_dtypes.float6_e2m3fn): DLDataType(15, 6, 1),
-        numpy.dtype(ml_dtypes.float6_e3m2fn): DLDataType(16, 6, 1),
-        numpy.dtype(ml_dtypes.float4_e2m1fn): DLDataType(17, 4, 1),
     }
 
+    if hasattr(ml_dtypes, "int2"):  # ml_dtypes >= 0.5.0
+        MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.int2)] = DLDataType(0, 2, 1)
+        MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.uint2)] = DLDataType(1, 2, 1)
+
+        MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float8_e3m4)] = DLDataType(7, 8, 1)
+        MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float8_e4m3)] = DLDataType(8, 8, 1)
+        MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float8_e8m0fnu)] = DLDataType(14, 8, 1)
+        MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float6_e2m3fn)] = DLDataType(15, 6, 1)
+        MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float6_e3m2fn)] = DLDataType(16, 6, 1)
+        MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float4_e2m1fn)] = DLDataType(17, 4, 1)
+
+
 if numpy is not None:
     NUMPY_DTYPE_TO_DTYPE = {
         numpy.dtype(numpy.int8): DLDataType(0, 8, 1),

Reply via email to