This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 4bc8925 [CYTHON] Introduce __tvm_ffi_tensor__ protocol (#108)
4bc8925 is described below
commit 4bc892542b937074e1d813ae7e7f8ea204be4aac
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Oct 13 14:04:10 2025 -0400
[CYTHON] Introduce __tvm_ffi_tensor__ protocol (#108)
Sometimes it is helpful for another tensor class to hold an ffi.Tensor
as a member. This PR introduces the `__tvm_ffi_tensor__()` protocol,
which lets such classes expose the underlying tensor so FFI calls can
fetch it quickly.
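For example, a wrapper class only needs to return the ffi.Tensor it
holds (a minimal sketch mirroring the new test below; `MyTensor` is
illustrative):

    import numpy as np
    import tvm_ffi

    class MyTensor:
        def __init__(self, tensor: tvm_ffi.Tensor) -> None:
            self._tensor = tensor

        def __tvm_ffi_tensor__(self) -> tvm_ffi.Tensor:
            # hand the underlying ffi.Tensor straight to the FFI
            # call path, skipping the DLPack conversion round trip
            return self._tensor

    x = tvm_ffi.from_dlpack(np.zeros((4, 2), dtype="int32"))
    fecho = tvm_ffi.get_global_func("testing.echo")
    y = fecho(MyTensor(x))  # accepted wherever a Tensor is expected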
---
python/tvm_ffi/cython/base.pxi | 2 +-
python/tvm_ffi/cython/function.pxi | 28 +++++++++++++++++++++++++++-
tests/python/test_tensor.py | 20 ++++++++++++++++++++
tests/scripts/benchmark_dlpack.py | 29 +++++++++++++++++++++++++++++
4 files changed, 77 insertions(+), 2 deletions(-)
diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi
index a86ea00..4272a20 100644
--- a/python/tvm_ffi/cython/base.pxi
+++ b/python/tvm_ffi/cython/base.pxi
@@ -19,7 +19,7 @@ from libc.stdint cimport int32_t, int64_t, uint64_t, uint32_t, uint8_t, int16_t
from libc.string cimport memcpy
from libcpp.vector cimport vector
from cpython.bytes cimport PyBytes_AsStringAndSize, PyBytes_FromStringAndSize, PyBytes_AsString
-from cpython cimport Py_INCREF, Py_DECREF
+from cpython cimport Py_INCREF, Py_DECREF, Py_REFCNT
from cpython cimport PyErr_CheckSignals, PyGILState_Ensure, PyGILState_Release, PyObject
from cpython cimport pycapsule, PyCapsule_Destructor
from cpython cimport PyErr_SetNone
diff --git a/python/tvm_ffi/cython/function.pxi b/python/tvm_ffi/cython/function.pxi
index 7dcab4b..0305fec 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -267,6 +267,26 @@ cdef int TVMFFIPyArgSetterDLPack_(
return 0
+cdef int TVMFFIPyArgSetterFFITensorCompatible_(
+ TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
+ PyObject* py_arg, TVMFFIAny* out
+) except -1:
+ """Setter for objects that implement the `__tvm_ffi_tensor__` protocol."""
+ cdef object arg = <object>py_arg
+ cdef TVMFFIObjectHandle temp_chandle
+ cdef Tensor tensor = arg.__tvm_ffi_tensor__()
+ cdef long ref_count = Py_REFCNT(tensor)
+ temp_chandle = tensor.chandle
+ out.type_index = kTVMFFITensor
+ out.v_ptr = temp_chandle
+ if ref_count == 1:
+ # keep the tensor alive: it is a temporary that
+ # would otherwise be freed once we return
+ TVMFFIObjectIncRef(temp_chandle)
+ TVMFFIPyPushTempFFIObject(ctx, temp_chandle)
+ return 0
+
+
cdef int TVMFFIPyArgSetterDType_(
TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
PyObject* py_arg, TVMFFIAny* out
@@ -570,10 +590,16 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce
if isinstance(arg, ObjectRValueRef):
out.func = TVMFFIPyArgSetterObjectRValueRef_
return 0
+ arg_class = type(arg)
+ if hasattr(arg_class, "__tvm_ffi_tensor__"):
+ # can directly map to a tvm ffi tensor;
+ # usually used for classes that take ffi.Tensor
+ # as a member variable
+ out.func = TVMFFIPyArgSetterFFITensorCompatible_
+ return 0
if os.environ.get("TVM_FFI_SKIP_c_dlpack_from_pyobject", "0") != "1":
# Check for DLPackExchangeAPI struct (new approach)
# This is checked on the CLASS, not the instance
- arg_class = type(arg)
if hasattr(arg_class, "__c_dlpack_exchange_api__"):
out.func = TVMFFIPyArgSetterDLPackExchangeAPI_
temp_ptr = arg_class.__c_dlpack_exchange_api__
diff --git a/tests/python/test_tensor.py b/tests/python/test_tensor.py
index 9b73ce9..48d933f 100644
--- a/tests/python/test_tensor.py
+++ b/tests/python/test_tensor.py
@@ -92,3 +92,23 @@ def test_tensor_class_override() -> None:
y = fecho(x)
assert isinstance(y, MyTensor)
tvm_ffi.core._set_class_tensor(old_tensor)
+
+
+def test_tvm_ffi_tensor_compatible() -> None:
+ class MyTensor:
+ def __init__(self, tensor: tvm_ffi.Tensor) -> None:
+ """Initialize the MyTensor."""
+ self._tensor = tensor
+
+ def __tvm_ffi_tensor__(self) -> tvm_ffi.Tensor:
+ """Implement __tvm_ffi_tensor__ protocol."""
+ return self._tensor
+
+ data = np.zeros((10, 8, 4, 2), dtype="int32")
+ if not hasattr(data, "__dlpack__"):
+ return
+ x = tvm_ffi.from_dlpack(data)
+ y = MyTensor(x)
+ fecho = tvm_ffi.get_global_func("testing.echo")
+ z = fecho(y)
+ assert z.__chandle__() == x.__chandle__()
diff --git a/tests/scripts/benchmark_dlpack.py b/tests/scripts/benchmark_dlpack.py
index 4366b58..6c36246 100644
--- a/tests/scripts/benchmark_dlpack.py
+++ b/tests/scripts/benchmark_dlpack.py
@@ -38,6 +38,18 @@ import torch
import tvm_ffi
+class TestFFITensor:
+ """Test FFI Tensor that exposes __tvm_ffi_tensor__ protocol."""
+
+ def __init__(self, tensor: tvm_ffi.Tensor) -> None:
+ """Initialize the TestFFITensor."""
+ self._tensor = tensor
+
+ def __tvm_ffi_tensor__(self) -> tvm_ffi.Tensor:
+ """Implement __tvm_ffi_tensor__ protocol."""
+ return self._tensor
+
+
def print_speed(name: str, speed: float) -> None:
print(f"{name:<60} {speed} sec/call")
@@ -262,6 +274,21 @@ def tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat: int, device: str)
)
+def tvm_ffi_nop_autodlpack_from_test_ffi_tensor(repeat: int, device: str) -> None:
+ """Measures the overhead of passing test wrappers directly as inputs.
+ This effectively measures the __tvm_ffi_tensor__ fast path in tvm ffi.
+ """
+ x = tvm_ffi.from_dlpack(torch.arange(1, device=device))
+ y = tvm_ffi.from_dlpack(torch.arange(1, device=device))
+ z = tvm_ffi.from_dlpack(torch.arange(1, device=device))
+ x = TestFFITensor(x)
+ y = TestFFITensor(y)
+ z = TestFFITensor(z)
+ bench_tvm_ffi_nop_autodlpack(
+ f"tvm_ffi.nop.autodlpack(TestFFITensor[{device}])", x, y, z, repeat
+ )
+
+
def bench_to_dlpack(x: Any, name: str, repeat: int) -> None:
x.__dlpack__()
start = time.time()
@@ -372,6 +399,8 @@ def main() -> None: # noqa: PLR0915
tvm_ffi_nop_autodlpack_from_numpy(repeat)
tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat, "cpu")
tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat, "cuda")
+ tvm_ffi_nop_autodlpack_from_test_ffi_tensor(repeat, "cpu")
+ tvm_ffi_nop_autodlpack_from_test_ffi_tensor(repeat, "cuda")
tvm_ffi_nop(repeat)
print("-------------------------------")
print("Benchmark x.__dlpack__ overhead")