This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 0f8bf9f Introduce Device Protocol (#179)
0f8bf9f is described below
commit 0f8bf9fc582fff89e838cadf2aacbfb2a5724ddf
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Oct 20 21:07:30 2025 -0700
Introduce Device Protocol (#179)
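This PR adds support for the `__dlpack_device__` protocol in FFI argument
conversion: any object that implements `__dlpack_device__` (but not
`__dlpack__`) can now be passed directly into FFI calls and is converted
to a Device.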
---
python/tvm_ffi/cython/function.pxi | 18 ++++++++++++++++++
tests/python/test_function.py | 16 ++++++++++++++++
2 files changed, 34 insertions(+)
diff --git a/python/tvm_ffi/cython/function.pxi b/python/tvm_ffi/cython/function.pxi
index 1138476..bf84091 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -324,6 +324,19 @@ cdef int TVMFFIPyArgSetterDevice_(
out.v_device = (<Device>arg).cdevice
return 0
+cdef int TVMFFIPyArgSetterDLPackDeviceProtocol_(
+    TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
+    PyObject* py_arg, TVMFFIAny* out
+) except -1:
+    """Setter for objects implementing the DLPack device protocol.
+
+    Calls ``__dlpack_device__()`` on the argument, which must return a
+    tuple whose first two items are the DLPack device type and the
+    device index, and stores them in ``out`` as a Device value.
+    """
+    cdef int32_t device_type
+    cdef int32_t device_index
+    cdef tuple device_pair = (<object>py_arg).__dlpack_device__()
+    out.type_index = kTVMFFIDevice
+    # Convert the device type first, then the index, to 32-bit ints.
+    device_type = <int32_t>device_pair[0]
+    device_index = <int32_t>device_pair[1]
+    out.v_device = TVMFFIDLDeviceFromIntPair(device_type, device_index)
+    return 0
cdef int TVMFFIPyArgSetterStr_(
TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
@@ -716,6 +729,11 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce
# prefer dlpack as it covers all DLDataType struct
out.func = TVMFFIPyArgSetterDLPackDataTypeProtocol_
return 0
+ if hasattr(arg_class, "__dlpack_device__") and not hasattr(arg_class,
"__dlpack__"):
+ # if a class have __dlpack_device__ but not __dlpack__
+ # then it is a DLPack device protocol
+ out.func = TVMFFIPyArgSetterDLPackDeviceProtocol_
+ return 0
if isinstance(arg, Exception):
out.func = TVMFFIPyArgSetterException_
return 0
diff --git a/tests/python/test_function.py b/tests/python/test_function.py
index 6123a8e..8d6dd34 100644
--- a/tests/python/test_function.py
+++ b/tests/python/test_function.py
@@ -328,3 +328,19 @@ def test_function_with_dlpack_data_type_protocol() -> None:
x = DLPackDataTypeProtocol((dtype.type_code, dtype.bits, dtype.lanes))
y = fecho(x)
assert y == dtype
+
+
+def test_function_with_dlpack_device_protocol() -> None:
+ device = tvm_ffi.device("cuda:1")
+
+ class DLPackDeviceProtocol:
+ def __init__(self, device: tvm_ffi.Device) -> None:
+ self.device = device
+
+ def __dlpack_device__(self) -> tuple[int, int]:
+ return (self.device.dlpack_device_type(), self.device.index)
+
+ fecho = tvm_ffi.get_global_func("testing.echo")
+ x = DLPackDeviceProtocol(device)
+ y = fecho(x)
+ assert y == device