This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 3dd7a81  [REFACTOR][FEAT] Introduce generic value protocol (#312)
3dd7a81 is described below

commit 3dd7a8173363bdf79806610818121e83e99b3b56
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu Dec 4 18:57:46 2025 -0500

    [REFACTOR][FEAT] Introduce generic value protocol (#312)
    
    This introduces a generic protocol (`__tvm_ffi_value__`) that lets
    classes declare how they convert to tvm-ffi compatible values.
    It also refactors the Python helpers to consolidate their classes.
---
 python/tvm_ffi/cython/base.pxi                 |   8 +
 python/tvm_ffi/cython/function.pxi             |  19 +++
 python/tvm_ffi/cython/tensor.pxi               |   2 +-
 python/tvm_ffi/cython/tvm_ffi_python_helpers.h | 228 +++++++++++++++----------
 tests/python/test_function.py                  |  19 +++
 tests/python/test_stream.py                    |   3 +
 tests/python/test_tensor.py                    |  23 +++
 tests/scripts/benchmark_dlpack.py              |  44 +++--
 8 files changed, 246 insertions(+), 100 deletions(-)

diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi
index a24a5d8..933bc86 100644
--- a/python/tvm_ffi/cython/base.pxi
+++ b/python/tvm_ffi/cython/base.pxi
@@ -364,10 +364,18 @@ cdef extern from "tvm_ffi_python_helpers.h":
         int* c_api_ret_code
     ) except -1
 
+    int TVMFFIPySetArgumentGenericDispatcher(
+        TVMFFIPyArgSetterFactory setter_factory,
+        TVMFFIPyCallContext* ctx,
+        PyObject* py_arg,
+        TVMFFIAny* out
+    ) except -1
+
     size_t TVMFFIPyGetDispatchMapSize() noexcept
 
     void TVMFFIPyPushTempFFIObject(TVMFFIPyCallContext* ctx, TVMFFIObjectHandle arg) noexcept
     void TVMFFIPyPushTempPyObject(TVMFFIPyCallContext* ctx, PyObject* arg) noexcept
+    # Grows a std::vector and may therefore throw std::bad_alloc; `except +`
+    # turns a C++ exception into a Python exception instead of letting it
+    # escape through the Cython-generated C code.
+    void TVMFFIPyPushExtraTempPyObject(TVMFFIPyCallContext* ctx, PyObject* arg) except +
     # the predefined setters for common POD types
     int TVMFFIPyArgSetterFloat_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, 
PyObject* arg, TVMFFIAny* out) except -1
     int TVMFFIPyArgSetterInt_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*, 
PyObject* arg, TVMFFIAny* out) except -1
diff --git a/python/tvm_ffi/cython/function.pxi 
b/python/tvm_ffi/cython/function.pxi
index 189a6fc..29af699 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -674,6 +674,22 @@ cdef int TVMFFIPyArgSetterFloatProtocol_(
     return 0
 
 
+cdef int TVMFFIPyArgSetterFFIValueProtocol_(
+    TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
+    PyObject* py_arg, TVMFFIAny* out
+) except -1:
+    """Setter for classes that define a ``__tvm_ffi_value__()`` method.
+
+    Calls ``__tvm_ffi_value__()`` on the argument and converts the returned
+    object through the generic dispatcher. The returned object may itself be
+    any supported value, including another object implementing the protocol.
+    """
+    cdef object arg = <object>py_arg
+    cdef object ffi_value_py_obj = arg.__tvm_ffi_value__()
+    cdef PyObject* ffi_value_py_obj_ptr = <PyObject*>ffi_value_py_obj
+    if ffi_value_py_obj is arg:
+        # Re-dispatching the very same object would select this setter again
+        # and recurse without end on the C stack.
+        raise TypeError(
+            f"{type(arg).__name__}.__tvm_ffi_value__() returned the object itself"
+        )
+    # Keep the returned object alive for the whole call: it may be a fresh
+    # temporary. It goes to the extra temp py objects stack because the
+    # one-temp-per-argument budget of the regular temp stack may already be used.
+    TVMFFIPyPushExtraTempPyObject(ctx, ffi_value_py_obj_ptr)
+    return TVMFFIPySetArgumentGenericDispatcher(
+        TVMFFIPyArgSetterFactory_, ctx, ffi_value_py_obj_ptr, out
+    )
+
+
 cdef _DISPATCH_TYPE_KEEP_ALIVE = set()
 cdef _DISPATCH_TYPE_KEEP_ALIVE_LOCK = threading.Lock()
 
@@ -824,6 +840,9 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, 
TVMFFIPyArgSetter* out) exce
     if hasattr(arg_class, "__tvm_ffi_float__"):
         out.func = TVMFFIPyArgSetterFloatProtocol_
         return 0
+    if hasattr(arg_class, "__tvm_ffi_value__"):
+        out.func = TVMFFIPyArgSetterFFIValueProtocol_
+        return 0
     if isinstance(arg, Exception):
         out.func = TVMFFIPyArgSetterException_
         return 0
diff --git a/python/tvm_ffi/cython/tensor.pxi b/python/tvm_ffi/cython/tensor.pxi
index 844f7ef..841687a 100644
--- a/python/tvm_ffi/cython/tensor.pxi
+++ b/python/tvm_ffi/cython/tensor.pxi
@@ -438,7 +438,7 @@ def _dltensor_test_wrapper_exchange_api_ptr():
 cdef class DLTensorTestWrapper:
     """Wrapper of a Tensor that exposes DLPack protocol, only for testing 
purpose.
     """
-    __c_dlpack_exchange_api__: int = _dltensor_test_wrapper_exchange_api_ptr()
+    __c_dlpack_exchange_api__ = _dltensor_test_wrapper_exchange_api_ptr()
 
     cdef Tensor tensor
     cdef dict __dict__
diff --git a/python/tvm_ffi/cython/tvm_ffi_python_helpers.h 
b/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
index dc970d3..88c27d7 100644
--- a/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
+++ b/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
@@ -40,16 +40,43 @@
 #include <exception>
 #include <iostream>
 #include <unordered_map>
+#include <vector>
 
 
///--------------------------------------------------------------------------------
 /// We deliberately designed the data structure and function to be C-style
 //  prefixed with TVMFFIPy so they can be easily invoked through Cython.
 
///--------------------------------------------------------------------------------
 
+/*!
+ * \brief Per-thread storage that backs every TVMFFIPyCallContext.
+ *
+ * Call contexts carve argument slots out of args_stack and give them back in
+ * reverse order; extra temporary Python objects live on a separate growable stack.
+ */
+class TVMFFIPyCallStack {
+ public:
+  /*! \brief Preallocated workspace for packed arguments and per-argument temporaries */
+  std::vector<TVMFFIAny> args_stack;
+  /*! \brief Index of the first unused slot in args_stack */
+  int64_t args_stack_top = 0;
+  /*!
+   * \brief Extra temporary Python objects that do not fit the one-temp-per-argument
+   *  budget; mainly used by the __tvm_ffi_value__ protocol.
+   */
+  std::vector<PyObject*> extra_temp_py_objects_stack;
+
+  /*! \brief Preallocate roughly 4K bytes of argument slots and a small extra-temp capacity. */
+  TVMFFIPyCallStack() : args_stack(4096 / sizeof(TVMFFIAny)) {
+    extra_temp_py_objects_stack.reserve(16);
+  }
+};
+
 /*!
  * \brief Context for each ffi call to track the stream, device and temporary 
arguments.
  */
-struct TVMFFIPyCallContext {
+class TVMFFIPyCallContext {
+ public:
   /*! \brief The workspace for the packed arguments */
   TVMFFIAny* packed_args = nullptr;
   /*! \brief Detected device type, if any */
@@ -58,16 +85,77 @@ struct TVMFFIPyCallContext {
   int device_id = 0;
   /*! \brief Detected stream, if any */
   void* stream = nullptr;
+  /*! \brief the DLPack exchange API, if any */
+  const DLPackExchangeAPI* c_dlpack_exchange_api{nullptr};
+  /*! \brief pointer to the call stack space */
+  TVMFFIPyCallStack* call_stack = nullptr;
   /*! \brief the temporary arguments to be recycled */
   void** temp_ffi_objects = nullptr;
-  /*! \brief the number of temporary arguments */
-  int num_temp_ffi_objects = 0;
   /*! \brief the temporary arguments to be recycled */
   void** temp_py_objects = nullptr;
   /*! \brief the number of temporary arguments */
+  int num_temp_ffi_objects = 0;
+  /*! \brief the number of temporary arguments */
   int num_temp_py_objects = 0;
-  /*! \brief the DLPack exchange API, if any */
-  const DLPackExchangeAPI* c_dlpack_exchange_api{nullptr};
+
+  /*!
+   * \brief Reserve the workspace for a call with \p num_args arguments.
+   *
+   * 2 * num_args TVMFFIAny slots are reserved: the first num_args hold the packed
+   * arguments, the remaining num_args are reinterpreted as two void* arrays
+   * (temp_ffi_objects followed by temp_py_objects, one entry per argument each).
+   * Slots come from the thread's call stack when they fit, otherwise from a heap
+   * buffer that the destructor frees.
+   */
+  TVMFFIPyCallContext(TVMFFIPyCallStack* call_stack, int64_t num_args) : call_stack(call_stack) {
+    // each spare TVMFFIAny slot must be able to hold two pointers
+    static_assert(sizeof(TVMFFIAny) >= (sizeof(void*) * 2));
+    static_assert(alignof(TVMFFIAny) % alignof(void*) == 0);
+    const int64_t slot_count = num_args * 2;
+    const int64_t stack_top = call_stack->args_stack_top;
+    old_args_stack_top_ = stack_top;
+    const bool fits_on_stack =
+        stack_top + slot_count <= static_cast<int64_t>(call_stack->args_stack.size());
+    TVMFFIAny* base = nullptr;
+    if (fits_on_stack) {
+      base = call_stack->args_stack.data() + stack_top;
+      call_stack->args_stack_top = stack_top + slot_count;
+    } else {
+      heap_ptr_ = new TVMFFIAny[slot_count];
+      base = heap_ptr_;
+    }
+    packed_args = base;
+    // co-locate the temporaries with the packed arguments for cache locality
+    temp_ffi_objects = reinterpret_cast<void**>(base + num_args);
+    temp_py_objects = temp_ffi_objects + num_args;
+    old_extra_temp_py_objects_stack_top_ = call_stack->extra_temp_py_objects_stack.size();
+  }
+
+  /*!
+   * \brief Release the temporaries recorded during the call and return the workspace.
+   *
+   * DecRefs every temporary FFI object and Python object pushed into this context,
+   * pops the extra temporary Python objects pushed since construction, then either
+   * rewinds the call stack top or frees the heap fallback buffer.
+   */
+  ~TVMFFIPyCallContext() {
+    try {
+      // recycle the temporary arguments if any
+      for (int i = 0; i < this->num_temp_ffi_objects; ++i) {
+        TVMFFIObjectDecRef(this->temp_ffi_objects[i]);
+      }
+      for (int i = 0; i < this->num_temp_py_objects; ++i) {
+        Py_DecRef(static_cast<PyObject*>(this->temp_py_objects[i]));
+      }
+      // Index-based on purpose: a DecRef may run arbitrary Python code (e.g. __del__)
+      // that makes nested FFI calls, which can temporarily grow this vector.
+      for (size_t i = old_extra_temp_py_objects_stack_top_;
+           i < call_stack->extra_temp_py_objects_stack.size(); ++i) {
+        Py_DecRef(static_cast<PyObject*>(call_stack->extra_temp_py_objects_stack[i]));
+      }
+      call_stack->extra_temp_py_objects_stack.resize(old_extra_temp_py_objects_stack_top_);
+    } catch (const std::exception& ex) {
+      // very rare, catch c++ exception and set python error
+      PyErr_SetString(PyExc_RuntimeError, ex.what());
+    }
+    // now recycle the memory of the call stack
+    if (heap_ptr_ == nullptr) {
+      call_stack->args_stack_top = old_args_stack_top_;
+    } else {
+      delete[] heap_ptr_;
+    }
+  }
+
+  // This context owns heap_ptr_ and a LIFO slice of the shared call stack; a copy
+  // or move would free the buffer / rewind the stack twice, so both are disabled.
+  TVMFFIPyCallContext(const TVMFFIPyCallContext&) = delete;
+  TVMFFIPyCallContext& operator=(const TVMFFIPyCallContext&) = delete;
+  TVMFFIPyCallContext(TVMFFIPyCallContext&&) = delete;
+  TVMFFIPyCallContext& operator=(TVMFFIPyCallContext&&) = delete;
+
+ private:
+  /*! \brief Heap fallback buffer used when the call stack has no room; nullptr otherwise */
+  TVMFFIAny* heap_ptr_ = nullptr;
+  /*!
+   * \brief call_stack->args_stack_top before this context reserved its slots.
+   *  Same type as TVMFFIPyCallStack::args_stack_top to avoid sign conversions.
+   */
+  int64_t old_args_stack_top_ = 0;
+  /*! \brief Size of the extra temp Python object stack when this context was created */
+  size_t old_extra_temp_py_objects_stack_top_ = 0;
 };
 
 /*! \brief Argument setter for a given python argument. */
@@ -173,66 +261,6 @@ class TVMFFIPyCallManager {
     static thread_local TVMFFIPyCallManager inst;
     return &inst;
   }
-  /*!
-   * \brief auxiliary class that manages the call stack in RAII manner.
-   *
-   * In most cases, it will try to allocate from temp_stack,
-   * then allocate from heap if the request goes beyond the stack size.
-   */
-  class CallStack : public TVMFFIPyCallContext {
-   public:
-    CallStack(TVMFFIPyCallManager* manager, int64_t num_args) : 
manager_ptr_(manager) {
-      static_assert(sizeof(TVMFFIAny) >= (sizeof(void*) * 2));
-      static_assert(alignof(TVMFFIAny) % alignof(void*) == 0);
-      old_stack_top_ = manager->stack_top_;
-      int64_t requested_count = num_args * 2;
-      TVMFFIAny* stack_head = manager->temp_stack_.data() + 
manager->stack_top_;
-      if (manager->stack_top_ + requested_count >
-          static_cast<int64_t>(manager->temp_stack_.size())) {
-        // allocate from heap
-        heap_ptr_ = new TVMFFIAny[requested_count];
-        stack_head = heap_ptr_;
-      } else {
-        manager->stack_top_ += requested_count;
-      }
-      this->packed_args = stack_head;
-      this->temp_ffi_objects = reinterpret_cast<void**>(stack_head + num_args);
-      this->temp_py_objects = this->temp_ffi_objects + num_args;
-    }
-
-    ~CallStack() {
-      try {
-        // recycle the temporary arguments if any
-        for (int i = 0; i < this->num_temp_ffi_objects; ++i) {
-          TVMFFIObjectDecRef(this->temp_ffi_objects[i]);
-        }
-        for (int i = 0; i < this->num_temp_py_objects; ++i) {
-          Py_DecRef(static_cast<PyObject*>(this->temp_py_objects[i]));
-        }
-      } catch (const std::exception& ex) {
-        // very rare, catch c++ exception and set python error
-        PyErr_SetString(PyExc_RuntimeError, ex.what());
-      }
-      // now recycle the memory of the call stack
-      if (heap_ptr_ == nullptr) {
-        manager_ptr_->stack_top_ = old_stack_top_;
-      } else {
-        delete[] heap_ptr_;
-      }
-    }
-
-   private:
-    /*!
-     *\brief The manager of the call stack
-     * If stored on stack, must set it to point to parent.
-     */
-    TVMFFIPyCallManager* manager_ptr_ = nullptr;
-    /*! \brief The heap of the call stack */
-    TVMFFIAny* heap_ptr_ = nullptr;
-    /*! \brief The old stack size */
-    int64_t old_stack_top_ = 0;
-  };
-
   /*!
    * \brief Call a function with a variable number of arguments
    * \param setter_factory The factory function to create the setter
@@ -253,7 +281,7 @@ class TVMFFIPyCallManager {
     if (num_args == -1) return -1;
     try {
       // allocate a call stack
-      CallStack ctx(this, num_args);
+      TVMFFIPyCallContext ctx(&call_stack_, num_args);
       // Iterate over the arguments and set them
       for (int64_t i = 0; i < num_args; ++i) {
         PyObject* py_arg = PyTuple_GetItem(py_arg_tuple, i);
@@ -335,7 +363,7 @@ class TVMFFIPyCallManager {
     if (num_args == -1) return -1;
     try {
       // allocate a call stack
-      CallStack ctx(this, num_args);
+      TVMFFIPyCallContext ctx(&call_stack_, num_args);
       // Iterate over the arguments and set them
       for (int64_t i = 0; i < num_args; ++i) {
         PyObject* py_arg = PyTuple_GetItem(py_arg_tuple, i);
@@ -368,7 +396,7 @@ class TVMFFIPyCallManager {
                               TVMFFIFieldSetter field_setter, void* field_ptr, 
PyObject* py_arg,
                               int* c_api_ret_code) {
     try {
-      CallStack ctx(this, 1);
+      TVMFFIPyCallContext ctx(&call_stack_, 1);
       TVMFFIAny* c_arg = ctx.packed_args;
       if (SetArgument(setter_factory, &ctx, py_arg, c_arg) != 0) return -1;
       c_api_ret_code[0] = (*field_setter)(field_ptr, c_arg);
@@ -380,10 +408,10 @@ class TVMFFIPyCallManager {
     }
   }
 
-  int PyObjectToFFIAny(TVMFFIPyArgSetterFactory setter_factory, PyObject* 
py_arg, TVMFFIAny* out,
-                       int* c_api_ret_code) {
+  TVM_FFI_INLINE int PyObjectToFFIAny(TVMFFIPyArgSetterFactory setter_factory, 
PyObject* py_arg,
+                                      TVMFFIAny* out, int* c_api_ret_code) {
     try {
-      CallStack ctx(this, 1);
+      TVMFFIPyCallContext ctx(&call_stack_, 1);
       TVMFFIAny* c_arg = ctx.packed_args;
       if (SetArgument(setter_factory, &ctx, py_arg, c_arg) != 0) return -1;
       c_api_ret_code[0] = TVMFFIAnyViewToOwnedAny(c_arg, out);
@@ -394,20 +422,7 @@ class TVMFFIPyCallManager {
       return -1;
     }
   }
-  /*!
-   * \brief Get the size of the dispatch map
-   * \return The size of the dispatch map
-   */
-  size_t GetDispatchMapSize() const { return dispatch_map_.size(); }
 
- private:
-  TVMFFIPyCallManager() {
-    static constexpr size_t kDefaultDispatchCapacity = 32;
-    // keep it 4K as default stack size so it is page aligned
-    static constexpr size_t kDefaultStackSize = 4096;
-    dispatch_map_.reserve(kDefaultDispatchCapacity);
-    temp_stack_.resize(kDefaultStackSize / sizeof(TVMFFIAny));
-  }
   /*!
    * \brief Set an py_arg to out.
    * \param setter_factory The factory function to create the setter
@@ -443,11 +458,23 @@ class TVMFFIPyCallManager {
     }
     return 0;
   }
+
+  /*!
+   * \brief Number of Python types that currently have a cached argument setter.
+   * \return The size of the dispatch map
+   */
+  size_t GetDispatchMapSize() const {
+    return this->dispatch_map_.size();
+  }
+
+ private:
+  /*! \brief Private: the per-thread instance is obtained through ThreadLocal(). */
+  TVMFFIPyCallManager() {
+    // pre-size the map so the first registrations do not trigger a rehash
+    constexpr size_t kInitialDispatchCapacity = 32;
+    this->dispatch_map_.reserve(kInitialDispatchCapacity);
+  }
+
   // internal dispacher
   std::unordered_map<PyTypeObject*, TVMFFIPyArgSetter> dispatch_map_;
-  // temp call stack
-  std::vector<TVMFFIAny> temp_stack_;
-  int64_t stack_top_ = 0;
+  // call stack
+  TVMFFIPyCallStack call_stack_;
 };
 
 /*!
@@ -514,6 +541,22 @@ TVM_FFI_INLINE int 
TVMFFIPyCallFieldSetter(TVMFFIPyArgSetterFactory setter_facto
                                                       py_arg, c_api_ret_code);
 }
 
+/*!
+ * \brief Convert a Python object into a TVMFFIAny inside an existing call context.
+ *
+ * Routes through SetArgument of the thread-local call manager, so temporaries
+ * created while converting are pushed onto \p ctx. Used by the __tvm_ffi_value__
+ * protocol to re-dispatch the object that __tvm_ffi_value__() returned.
+ *
+ * \param setter_factory The factory function to create the setter
+ * \param ctx The call context that receives any temporaries of this conversion
+ * \param py_arg_tvm_ffi_value The Python object to convert
+ * \param out The output argument
+ * \return 0 on success, nonzero on failure
+ */
+TVM_FFI_INLINE int TVMFFIPySetArgumentGenericDispatcher(TVMFFIPyArgSetterFactory setter_factory,
+                                                        TVMFFIPyCallContext* ctx,
+                                                        PyObject* py_arg_tvm_ffi_value,
+                                                        TVMFFIAny* out) {
+  TVMFFIPyCallManager* manager = TVMFFIPyCallManager::ThreadLocal();
+  return manager->SetArgument(setter_factory, ctx, py_arg_tvm_ffi_value, out);
+}
+
 /*!
  * \brief Convert a Python object to a FFI Any
  * \param setter_factory The factory function to create the setter
@@ -560,6 +603,17 @@ TVM_FFI_INLINE void 
TVMFFIPyPushTempPyObject(TVMFFIPyCallContext* ctx, PyObject*
   ctx->temp_py_objects[ctx->num_temp_py_objects++] = arg;
 }
 
+/*!
+ * \brief Push an extra temporary Python object onto the call stack of \p ctx.
+ *
+ * Unlike TVMFFIPyPushTempPyObject, this is not limited to one temporary per
+ * argument; it is mainly used by the __tvm_ffi_value__ protocol. A new reference
+ * is taken and is released by the TVMFFIPyCallContext destructor.
+ *
+ * \param ctx The call context
+ * \param arg The Python object to push
+ * \note May throw std::bad_alloc when the stack must grow; in that case no
+ *       reference has been taken.
+ */
+TVM_FFI_INLINE void TVMFFIPyPushExtraTempPyObject(TVMFFIPyCallContext* ctx, PyObject* arg) {
+  // Push first: if emplace_back throws, no reference was taken, so nothing leaks.
+  ctx->call_stack->extra_temp_py_objects_stack.emplace_back(arg);
+  Py_IncRef(arg);
+}
+
 //----------------------------------------------------------
 // Helpers for MLIR redirection
 //----------------------------------------------------------
diff --git a/tests/python/test_function.py b/tests/python/test_function.py
index a24395a..8a494fb 100644
--- a/tests/python/test_function.py
+++ b/tests/python/test_function.py
@@ -383,3 +383,22 @@ def test_integral_float_variants_passing() -> None:
     y = fecho(FloatProtocol(10))
     assert isinstance(y, float)
     assert y == 10
+
+
+def test_function_with_value_protocol() -> None:
+    class ValueProtocol:
+        """Minimal object exposing its payload via __tvm_ffi_value__."""
+
+        def __init__(self, value: Any) -> None:
+            self.value = value
+
+        def __tvm_ffi_value__(self) -> Any:
+            return self.value
+
+    fecho = tvm_ffi.get_global_func("testing.echo")
+
+    # scalar payload
+    assert fecho(ValueProtocol(10)) == 10
+    # sequence payloads, including a protocol object nested inside the sequence
+    for payload in ([1, 2, 3], [1, 2, ValueProtocol(3)]):
+        assert tuple(fecho(ValueProtocol(payload))) == (1, 2, 3)
+    # protocol objects wrapping protocol objects
+    assert fecho(ValueProtocol(ValueProtocol(ValueProtocol(10)))) == 10
+    wrapped_items = ValueProtocol([ValueProtocol(item) for item in (1, 2, 3)])
+    assert tuple(fecho(wrapped_items)) == (1, 2, 3)
diff --git a/tests/python/test_stream.py b/tests/python/test_stream.py
index 3b58ccb..fa5c9b3 100644
--- a/tests/python/test_stream.py
+++ b/tests/python/test_stream.py
@@ -139,6 +139,9 @@ def test_torch_graph() -> None:
     device_type = device.dlpack_device_type()
     graph = torch.cuda.CUDAGraph()
     stream = torch.cuda.Stream(device_id)
+    x = torch.zeros(1, device="cuda")
     with tvm_ffi.use_torch_stream(torch.cuda.graph(graph, stream=stream)):
         assert torch.cuda.current_stream() == stream
         mod.check_stream(device_type, device_id, stream.cuda_stream)
+        # avoid cuda graph no capture warning
+        x = x + 1
diff --git a/tests/python/test_tensor.py b/tests/python/test_tensor.py
index 0551d14..9c938a8 100644
--- a/tests/python/test_tensor.py
+++ b/tests/python/test_tensor.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 from types import ModuleType
+from typing import Any, NamedTuple
 
 import pytest
 
@@ -113,6 +114,28 @@ def test_tvm_ffi_tensor_compatible() -> None:
     z = fecho(y)
     assert z.__chandle__() == x.__chandle__()
 
+    class MyNamedTuple(NamedTuple):
+        a: MyTensor
+        b: int
+
+    args = MyNamedTuple(a=y, b=1)
+    z = fecho(args)
+    assert z[0].__chandle__() == x.__chandle__()
+    assert z[1] == 1
+
+    class MyCustom:
+        def __init__(self, a: MyTensor, b: int) -> None:
+            self.a = a
+            self.b = b
+
+        def __tvm_ffi_value__(self) -> Any:
+            """Implement __tvm_ffi_value__ protocol."""
+            return (self.a, self.b)
+
+    z = fecho(MyCustom(a=y, b=2))
+    assert z[0].__chandle__() == x.__chandle__()
+    assert z[1] == 2
+
 
 @pytest.mark.skipif(
     torch is None or not torch.cuda.is_available() or torch.version.hip is 
None,
diff --git a/tests/scripts/benchmark_dlpack.py 
b/tests/scripts/benchmark_dlpack.py
index 23db3a8..4798d00 100644
--- a/tests/scripts/benchmark_dlpack.py
+++ b/tests/scripts/benchmark_dlpack.py
@@ -33,7 +33,7 @@ Summary of some takeaways:
 from __future__ import annotations
 
 import time
-from typing import Any, Callable
+from typing import Any, Callable, NamedTuple
 
 import numpy as np
 import torch
@@ -52,6 +52,14 @@ class TestFFITensor:
         return self._tensor
 
 
+class TestNamedTuple(NamedTuple):
+    """Three tensors bundled into one NamedTuple argument for the FFI benchmark."""
+
+    x: torch.Tensor
+    y: torch.Tensor
+    z: torch.Tensor
+
 def print_speed(name: str, speed: float) -> None:
     print(f"{name:<60} {speed} sec/call")
 
@@ -231,6 +239,20 @@ def bench_tvm_ffi_nop_autodlpack(name: str, x: Any, y: 
Any, z: Any, repeat: int)
     print_speed(name, speed)
 
 
+def bench_tvm_ffi_nop_autodlpack_tuple(name: str, args: TestNamedTuple, repeat: int) -> None:
+    """Measures overhead of passing a NamedTuple of tensors to a nop FFI function.
+
+    The whole NamedTuple is passed as a single argument, so the measured time
+    includes converting the tuple and every tensor it holds.
+    """
+    nop = tvm_ffi.get_global_func("testing.nop")
+    nop(args)  # warm-up call, excluded from timing
+    start = time.time()
+    for _ in range(repeat):
+        nop(args)
+    end = time.time()
+    speed = (end - start) / repeat
+    print_speed(name, speed)
+
 def tvm_ffi_nop_autodlpack_from_torch(
     repeat: int, device: str = "cpu", stream: bool = False
 ) -> None:
@@ -276,18 +298,16 @@ def 
tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat: int, device: str)
     )
 
 
-def tvm_ffi_nop_autodlpack_from_test_ffi_tensor(repeat: int, device: str) -> None:
+def tvm_ffi_nop_autodlpack_from_test_tensor_namedtuple(repeat: int, device: str) -> None:
-    """Measures overhead of running dlpack via auto convert by directly
-    take test wrapper as inputs. This effectively measure DLPack exchange in tvm ffi.
-    """
+    """Measures overhead of passing three torch.Tensor packed in a NamedTuple.
+
+    The NamedTuple is passed as a single argument, so the timing includes
+    converting the tuple and each tensor it holds on ``device``.
+    """
-    x = tvm_ffi.from_dlpack(torch.arange(1, device=device))
-    y = tvm_ffi.from_dlpack(torch.arange(1, device=device))
-    z = tvm_ffi.from_dlpack(torch.arange(1, device=device))
-    x = TestFFITensor(x)
-    y = TestFFITensor(y)
-    z = TestFFITensor(z)
-    bench_tvm_ffi_nop_autodlpack(
-        f"tvm_ffi.nop.autodlpack(TestFFITensor[{device}])", x, y, z, repeat
+    x = torch.arange(1, device=device)
+    y = torch.arange(1, device=device)
+    z = torch.arange(1, device=device)
+    args = TestNamedTuple(x=x, y=y, z=z)
+    bench_tvm_ffi_nop_autodlpack_tuple(
+        f"tvm_ffi.nop.autodlpack(NamedTuple[{device}])", args, repeat
     )
 
 
@@ -414,8 +434,8 @@ def main() -> None:  # noqa: PLR0915
     tvm_ffi_nop_autodlpack_from_numpy(repeat)
     tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat, "cpu")
     tvm_ffi_nop_autodlpack_from_dltensor_test_wrapper(repeat, "cuda")
-    tvm_ffi_nop_autodlpack_from_test_ffi_tensor(repeat, "cpu")
-    tvm_ffi_nop_autodlpack_from_test_ffi_tensor(repeat, "cuda")
+    tvm_ffi_nop_autodlpack_from_test_tensor_namedtuple(repeat, "cpu")
+    tvm_ffi_nop_autodlpack_from_test_tensor_namedtuple(repeat, "cuda")
     tvm_ffi_nop(repeat)
     print("-------------------------------")
     print("Benchmark x.__dlpack__ overhead")

Reply via email to