This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 5e648f0 Introduce dtype protocol (#178)
5e648f0 is described below
commit 5e648f052f93a423f7ef3cb20d2675925d64bf96
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Oct 20 11:52:30 2025 -0700
Introduce dtype protocol (#178)
This PR adds support for the `__dlpack_data_type__` protocol, enabling dtypes
to be exchanged through function arguments. Note that the intention is to use
DLDataType to ingest dtypes into TVM FFI, since it is the representation
closest to the metal. We can likely also look into the NumPy dtype protocol
once that materializes.
---
python/tvm_ffi/_dtype.py | 44 +++++++++++++++++++++++++++++---------
python/tvm_ffi/cython/dtype.pxi | 2 +-
python/tvm_ffi/cython/function.pxi | 18 +++++++++++++++-
tests/python/test_dtype.py | 7 ++++++
tests/python/test_function.py | 16 ++++++++++++++
5 files changed, 75 insertions(+), 12 deletions(-)
diff --git a/python/tvm_ffi/_dtype.py b/python/tvm_ffi/_dtype.py
index b712c65..de2d1f1 100644
--- a/python/tvm_ffi/_dtype.py
+++ b/python/tvm_ffi/_dtype.py
@@ -60,15 +60,39 @@ class dtype(str):
"""
- __slots__ = ["__tvm_ffi_dtype__"]
- __tvm_ffi_dtype__: core.DataType
+ __slots__ = ["_tvm_ffi_dtype"]
+ _tvm_ffi_dtype: core.DataType
_NUMPY_DTYPE_TO_STR: ClassVar[dict[Any, str]] = {}
def __new__(cls, content: Any) -> dtype:
content = str(content)
val = str.__new__(cls, content)
- val.__tvm_ffi_dtype__ = core.DataType(content)
+ val._tvm_ffi_dtype = core.DataType(content)
+ return val
+
+ @staticmethod
+ def from_dlpack_data_type(dltype_data_type: tuple[int, int, int]) -> dtype:
+ """Create a dtype from a DLPack data type tuple.
+
+ Parameters
+ ----------
+ dltype_data_type
+ The DLPack data type tuple (type_code, bits, lanes).
+
+ Returns
+ -------
+ The created dtype.
+
+ """
+ cdtype = core._create_dtype_from_tuple(
+ core.DataType,
+ dltype_data_type[0],
+ dltype_data_type[1],
+ dltype_data_type[2],
+ )
+ val = str.__new__(dtype, str(cdtype))
+ val._tvm_ffi_dtype = cdtype
return val
def __repr__(self) -> str:
@@ -90,29 +114,29 @@ class dtype(str):
"""
cdtype = core._create_dtype_from_tuple(
core.DataType,
- self.__tvm_ffi_dtype__.type_code,
- self.__tvm_ffi_dtype__.bits,
+ self._tvm_ffi_dtype.type_code,
+ self._tvm_ffi_dtype.bits,
lanes,
)
val = str.__new__(dtype, str(cdtype))
- val.__tvm_ffi_dtype__ = cdtype
+ val._tvm_ffi_dtype = cdtype
return val
@property
def itemsize(self) -> int:
- return self.__tvm_ffi_dtype__.itemsize
+ return self._tvm_ffi_dtype.itemsize
@property
def type_code(self) -> int:
- return self.__tvm_ffi_dtype__.type_code
+ return self._tvm_ffi_dtype.type_code
@property
def bits(self) -> int:
- return self.__tvm_ffi_dtype__.bits
+ return self._tvm_ffi_dtype.bits
@property
def lanes(self) -> int:
- return self.__tvm_ffi_dtype__.lanes
+ return self._tvm_ffi_dtype.lanes
try:
diff --git a/python/tvm_ffi/cython/dtype.pxi b/python/tvm_ffi/cython/dtype.pxi
index be32976..3d90346 100644
--- a/python/tvm_ffi/cython/dtype.pxi
+++ b/python/tvm_ffi/cython/dtype.pxi
@@ -155,7 +155,7 @@ cdef inline object make_ret_dtype(TVMFFIAny result):
cdtype = DataType.__new__(DataType)
(<DataType>cdtype).cdtype = result.v_dtype
val = str.__new__(_CLASS_DTYPE, cdtype.__str__())
- val.__tvm_ffi_dtype__ = cdtype
+ val._tvm_ffi_dtype = cdtype
return val
diff --git a/python/tvm_ffi/cython/function.pxi b/python/tvm_ffi/cython/function.pxi
index 98fe027..1138476 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -308,7 +308,7 @@ cdef int TVMFFIPyArgSetterDType_(
"""Setter for dtype"""
cdef object arg = <object>py_arg
# dtype is a subclass of str, so this check occur before str
- arg = arg.__tvm_ffi_dtype__
+ arg = arg._tvm_ffi_dtype
out.type_index = kTVMFFIDataType
out.v_dtype = (<DataType>arg).cdtype
return 0
@@ -573,6 +573,18 @@ cdef int TVMFFIPyArgSetterDTypeFromNumpy_(
out.v_dtype = NUMPY_DTYPE_TO_DTYPE[py_obj]
return 0
+cdef int TVMFFIPyArgSetterDLPackDataTypeProtocol_(
+ TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
+ PyObject* py_arg, TVMFFIAny* out
+) except -1:
+ """Setter for dtype protocol"""
+ cdef object arg = <object>py_arg
+ cdef tuple dltype_data_type = arg.__dlpack_data_type__()
+ out.type_index = kTVMFFIDataType
+ out.v_dtype.code = <long long>dltype_data_type[0]
+ out.v_dtype.bits = <long long>dltype_data_type[1]
+ out.v_dtype.lanes = <long long>dltype_data_type[2]
+ return 0
cdef _DISPATCH_TYPE_KEEP_ALIVE = set()
cdef _DISPATCH_TYPE_KEEP_ALIVE_LOCK = threading.Lock()
@@ -700,6 +712,10 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce
if numpy is not None and isinstance(arg, numpy.dtype):
out.func = TVMFFIPyArgSetterDTypeFromNumpy_
return 0
+ if hasattr(arg_class, "__dlpack_data_type__"):
+ # prefer dlpack as it covers all DLDataType struct
+ out.func = TVMFFIPyArgSetterDLPackDataTypeProtocol_
+ return 0
if isinstance(arg, Exception):
out.func = TVMFFIPyArgSetterException_
return 0
diff --git a/tests/python/test_dtype.py b/tests/python/test_dtype.py
index 38e72fb..60d53ef 100644
--- a/tests/python/test_dtype.py
+++ b/tests/python/test_dtype.py
@@ -160,3 +160,10 @@ def test_ml_dtypes_dtype_conversion() -> None:
_check_dtype(np.dtype(ml_dtypes.float6_e2m3fn), 15, 6, 1)
_check_dtype(np.dtype(ml_dtypes.float6_e3m2fn), 16, 6, 1)
_check_dtype(np.dtype(ml_dtypes.float4_e2m1fn), 17, 4, 1)
+
+
+def test_dtype_from_dlpack_data_type() -> None:
+ dtype = tvm_ffi.dtype.from_dlpack_data_type((0, 8, 1))
+ assert dtype.type_code == 0
+ assert dtype.bits == 8
+ assert dtype.lanes == 1
diff --git a/tests/python/test_function.py b/tests/python/test_function.py
index 797e1a4..6123a8e 100644
--- a/tests/python/test_function.py
+++ b/tests/python/test_function.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import ctypes
import gc
@@ -312,3 +313,18 @@ def test_function_with_opaque_ptr_protocol() -> None:
y = fecho(x)
assert isinstance(y, ctypes.c_void_p)
assert y.value == 10
+
+
+def test_function_with_dlpack_data_type_protocol() -> None:
+ class DLPackDataTypeProtocol:
+ def __init__(self, dlpack_data_type: tuple[int, int, int]) -> None:
+ self.dlpack_data_type = dlpack_data_type
+
+ def __dlpack_data_type__(self) -> tuple[int, int, int]:
+ return self.dlpack_data_type
+
+ dtype = tvm_ffi.dtype("float32")
+ fecho = tvm_ffi.get_global_func("testing.echo")
+ x = DLPackDataTypeProtocol((dtype.type_code, dtype.bits, dtype.lanes))
+ y = fecho(x)
+ assert y == dtype