This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 9574e9d feat: support `ml_dtypes<0.5` (#198)
9574e9d is described below
commit 9574e9d0afb50eca58f56cf0cdce9f64f8cd319f
Author: Yichen Yan <[email protected]>
AuthorDate: Tue Oct 28 22:02:32 2025 +0800
feat: support `ml_dtypes<0.5` (#198)
As the title says, this adds support for `ml_dtypes<0.5`; currently the code fails with `ml_dtypes==0.4.x`:
https://github.com/tile-ai/tilelang/pull/1108#issuecomment-3455197314
---
python/tvm_ffi/cython/dtype.pxi | 20 ++++++++++++--------
1 file changed, 12 insertions(+), 8 deletions(-)
diff --git a/python/tvm_ffi/cython/dtype.pxi b/python/tvm_ffi/cython/dtype.pxi
index 3d90346..15f9418 100644
--- a/python/tvm_ffi/cython/dtype.pxi
+++ b/python/tvm_ffi/cython/dtype.pxi
@@ -205,24 +205,28 @@ else:
if ml_dtypes is not None:
MLDTYPES_DTYPE_TO_DTYPE = {
- numpy.dtype(ml_dtypes.int2): DLDataType(0, 2, 1),
numpy.dtype(ml_dtypes.int4): DLDataType(0, 4, 1),
- numpy.dtype(ml_dtypes.uint2): DLDataType(1, 2, 1),
numpy.dtype(ml_dtypes.uint4): DLDataType(1, 4, 1),
numpy.dtype(ml_dtypes.bfloat16): DLDataType(4, 16, 1),
- numpy.dtype(ml_dtypes.float8_e3m4): DLDataType(7, 8, 1),
- numpy.dtype(ml_dtypes.float8_e4m3): DLDataType(8, 8, 1),
numpy.dtype(ml_dtypes.float8_e4m3b11fnuz): DLDataType(9, 8, 1),
numpy.dtype(ml_dtypes.float8_e4m3fn): DLDataType(10, 8, 1),
numpy.dtype(ml_dtypes.float8_e4m3fnuz): DLDataType(11, 8, 1),
numpy.dtype(ml_dtypes.float8_e5m2): DLDataType(12, 8, 1),
numpy.dtype(ml_dtypes.float8_e5m2fnuz): DLDataType(13, 8, 1),
- numpy.dtype(ml_dtypes.float8_e8m0fnu): DLDataType(14, 8, 1),
- numpy.dtype(ml_dtypes.float6_e2m3fn): DLDataType(15, 6, 1),
- numpy.dtype(ml_dtypes.float6_e3m2fn): DLDataType(16, 6, 1),
- numpy.dtype(ml_dtypes.float4_e2m1fn): DLDataType(17, 4, 1),
}
+ if hasattr(ml_dtypes, "int2"): # ml_dtypes >= 0.5.0
+ MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.int2)] = DLDataType(0, 2, 1)
+ MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.uint2)] = DLDataType(1, 2, 1)
+
+ MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float8_e3m4)] = DLDataType(7, 8, 1)
+ MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float8_e4m3)] = DLDataType(8, 8, 1)
+ MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float8_e8m0fnu)] = DLDataType(14, 8, 1)
+ MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float6_e2m3fn)] = DLDataType(15, 6, 1)
+ MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float6_e3m2fn)] = DLDataType(16, 6, 1)
+ MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float4_e2m1fn)] = DLDataType(17, 4, 1)
+
+
if numpy is not None:
NUMPY_DTYPE_TO_DTYPE = {
numpy.dtype(numpy.int8): DLDataType(0, 8, 1),