This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new f6303b2 [CYTHON] Enable ffi function creation from MLIR Packed
function (#100)
f6303b2 is described below
commit f6303b23fd97909b59f6ff67b85f2203371f5db1
Author: Tianqi Chen <[email protected]>
AuthorDate: Sat Oct 11 10:42:40 2025 -0400
[CYTHON] Enable ffi function creation from MLIR Packed function (#100)
The MLIR execution engine generates functions that correspond to the packed
signature. As of now, it is hard to access the raw extern C function
pointer of SafeCall directly when we declare the signature in the LLVM
dialect.
This helper enables us to create an ffi.Function from safe call function
pointers that are exposed through the MLIR packed function pointer calling
convention. It can help facilitate DSLs that leverage the MLIR execution
engine to do JIT.
Note that in theory, the MLIR execution engine should also be able to support
some form of "extern C" feature that directly exposes the function
pointers of C-compatible functions with an attribute tag. So we keep
this feature in the Python helper layer for now in case the MLIR execution
engine supports it in the future.
---
python/tvm_ffi/core.pyi | 26 +++++++-
python/tvm_ffi/cython/base.pxi | 4 ++
python/tvm_ffi/cython/function.pxi | 73 +++++++++++++++++++---
python/tvm_ffi/cython/tvm_ffi_python_helpers.h | 84 ++++++++++++++++++++++++++
src/ffi/extra/testing.cc | 19 +++++-
tests/python/test_function.py | 24 +++++++-
6 files changed, 217 insertions(+), 13 deletions(-)
diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
index 54f292c..092515a 100644
--- a/python/tvm_ffi/core.pyi
+++ b/python/tvm_ffi/core.pyi
@@ -529,7 +529,7 @@ class Function(Object):
"""
@staticmethod
- def __from_extern_c__(c_symbol: int, keep_alive_object: Any | None = None)
-> Function:
+ def __from_extern_c__(c_symbol: int, *, keep_alive_object: Any | None =
None) -> Function:
"""Construct a ``Function`` from a C symbol and keep_alive_object.
Parameters
@@ -540,8 +540,10 @@ class Function(Object):
which is the function handle
keep_alive_object : object
- optional closure to be captured and kept alive
+ optional object to be captured and kept alive
Usually can be the execution engine that JITed the function
+ to ensure we keep the execution environment alive
+ as long as the function is alive
Returns
-------
@@ -550,6 +552,26 @@ class Function(Object):
"""
+ @staticmethod
+ def __from_mlir_packed_safe_call__(
+ mlir_packed_symbol: int, *, keep_alive_object: Any | None = None
+ ) -> Function:
+ """Construct a ``Function`` from an MLIR packed safe call function pointer.
+
+ Parameters
+ ----------
+ mlir_packed_symbol : int
+ function pointer to the MLIR packed call function pointer
+ that represents a safe call function
+
+ keep_alive_object : object
+ optional object to be captured and kept alive
+ Usually can be the execution engine that JITed the function
+ to ensure we keep the execution environment alive
+ as long as the function is alive
+
+ Returns
+ -------
+ Function
+ The function object
+
+ """
+
def _register_global_func(
name: str, pyfunc: Callable[..., Any] | Function, override: bool
) -> Function: ...
diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi
index 0478575..dcc338a 100644
--- a/python/tvm_ffi/cython/base.pxi
+++ b/python/tvm_ffi/cython/base.pxi
@@ -339,6 +339,10 @@ cdef extern from "tvm_ffi_python_helpers.h":
int TVMFFIPyArgSetterInt_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*,
PyObject* arg, TVMFFIAny* out) except -1
int TVMFFIPyArgSetterBool_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*,
PyObject* arg, TVMFFIAny* out) except -1
int TVMFFIPyArgSetterNone_(TVMFFIPyArgSetter*, TVMFFIPyCallContext*,
PyObject* arg, TVMFFIAny* out) except -1
+ # MLIRPackedSafeCall
+ void* TVMFFIPyMLIRPackedSafeCallCreate(void
(*mlir_packed_safe_call)(void**) noexcept, PyObject* keep_alive_object)
+ int TVMFFIPyMLIRPackedSafeCallInvoke(void* self, const TVMFFIAny* args,
int32_t num_args, TVMFFIAny* rv)
+ void TVMFFIPyMLIRPackedSafeCallDeleter(void* self)
# deleter for python objects
void TVMFFIPyObjectDeleter(void* py_obj) noexcept nogil
diff --git a/python/tvm_ffi/cython/function.pxi
b/python/tvm_ffi/cython/function.pxi
index 3988039..a6cc3be 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -679,24 +679,30 @@ cdef class Function(Object):
raise move_from_last_error().py_error()
@staticmethod
- def __from_extern_c__(c_symbol: int, keep_alive_object: object = None) ->
"Function":
- """Convert a function from extern C address and closure object
+ def __from_extern_c__(
+ c_symbol: int,
+ *,
+ keep_alive_object: object = None
+ ) -> Function:
+ """Convert a function from extern C address.
Parameters
----------
c_symbol : int
- function pointer to the safe call function
+ Function pointer to the safe call function.
The function pointer must ignore the first argument,
- which is the function handle
+ which is the function handle.
keep_alive_object : object
- optional closure to be captured and kept alive
- Usually can be the execution engine that JITed the function
+ Optional object to be captured and kept alive.
+ Usually this can be the execution engine that JIT-compiled the
function
+ to ensure we keep the execution environment alive
+ as long as the function is alive.
Returns
-------
Function
- The function object
+ The function object.
"""
cdef TVMFFIObjectHandle chandle
# must first convert to int64_t
@@ -725,6 +731,59 @@ cdef class Function(Object):
(<Object>func).chandle = chandle
return func
+ @staticmethod
+ def __from_mlir_packed_safe_call__(
+ mlir_packed_symbol: int,
+ *,
+ keep_alive_object: object = None
+ ) -> Function:
+ """Convert an MLIR packed safe call function pointer into a ``Function``.
+
+ The pointed-to function is invoked as ``void(void**)``: it receives an
+ array of five pointers to the ``handle``, ``args``, ``num_args``, ``rv``
+ and ``ret_code`` slots, and reports its status by writing the
+ ``ret_code`` slot. The ``handle`` slot is always passed as NULL.
+
+ Parameters
+ ----------
+ mlir_packed_symbol : int
+ Function pointer to the MLIR packed call function
+ that represents a safe call function.
+
+ keep_alive_object : object
+ Optional object to be captured and kept alive.
+ Usually this can be the execution engine that JIT-compiled the function
+ to ensure we keep the execution environment alive
+ as long as the function is alive.
+
+ Returns
+ -------
+ Function
+ The function object.
+ """
+ cdef TVMFFIObjectHandle chandle
+ # must first convert to int64_t
+ # (the Python int is narrowed to a C integer before it is cast to a pointer)
+ cdef int64_t c_symbol_as_long_long = mlir_packed_symbol
+ cdef void* packed_call_addr_ptr = <void*>c_symbol_as_long_long
+ cdef PyObject* keepalive_py_obj
+ # None maps to NULL so that the helper skips reference counting entirely
+ if keep_alive_object is None:
+ keepalive_py_obj = NULL
+ else:
+ keepalive_py_obj = <PyObject*>keep_alive_object
+
+ # Heap-allocate the helper that stores the raw pointer; the helper takes
+ # its own reference to the keep-alive object when one is given.
+ cdef void* mlir_packed_safe_call = TVMFFIPyMLIRPackedSafeCallCreate(
+ <void (*)(void**) noexcept>packed_call_addr_ptr,
+ keepalive_py_obj
+ )
+ cdef int ret_code
+ # Pass the helper together with its invoke and deleter callbacks;
+ # the created function handle is written to chandle.
+ ret_code = TVMFFIFunctionCreate(
+ mlir_packed_safe_call,
+ TVMFFIPyMLIRPackedSafeCallInvoke,
+ TVMFFIPyMLIRPackedSafeCallDeleter,
+ &chandle
+ )
+ if ret_code != 0:
+ # cleanup during error handling
+ # (frees the helper and drops its keep-alive reference before reporting)
+ TVMFFIPyMLIRPackedSafeCallDeleter(mlir_packed_safe_call)
+ CHECK_CALL(ret_code)
+ func = Function.__new__(Function)
+ (<Object>func).chandle = chandle
+ return func
_register_object_by_index(kTVMFFIFunction, Function)
diff --git a/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
b/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
index e507f03..ea60d51 100644
--- a/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
+++ b/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
@@ -592,6 +592,90 @@ TVM_FFI_INLINE void
TVMFFIPyPushTempPyObject(TVMFFIPyCallContext* ctx, PyObject*
ctx->temp_py_objects[ctx->num_temp_py_objects++] = arg;
}
+//----------------------------------------------------------
+// Helpers for MLIR redirection
+//----------------------------------------------------------
+/*!
+ * \brief Function specialization that leverages MLIR packed safe call definitions.
+ *
+ * The MLIR execution engine generates functions that correspond to the packed signature.
+ * As of now, it is hard to access the raw extern C function pointer of SafeCall
+ * directly when we declare the signature in LLVM dialect.
+ *
+ * Note that in theory, the MLIR execution engine should be able to support
+ * some form of "extern C" feature that directly exposes the function pointers
+ * of C-compatible functions with an attribute tag. So we keep this feature
+ * in the Python helper layer for now in case the MLIR execution engine supports it in the future.
+ *
+ * This helper enables us to create ffi::Function from the MLIR packed
+ * safe call function pointer by following the redirection pattern
+ * in `TVMFFIPyMLIRPackedSafeCall::Invoke`.
+ *
+ * The object holds one strong reference to the optional keep-alive object for
+ * its whole lifetime. It is non-copyable, because a copy would release that
+ * single reference a second time.
+ *
+ * \sa TVMFFIPyMLIRPackedSafeCall::Invoke
+ */
+class TVMFFIPyMLIRPackedSafeCall {
+ public:
+  /*!
+   * \brief Store the packed function pointer and take a reference to the keep-alive object.
+   * \param mlir_packed_safe_call The MLIR packed safe call function pointer.
+   * \param keep_alive_object Optional Python object to keep alive, can be nullptr.
+   */
+  TVMFFIPyMLIRPackedSafeCall(void (*mlir_packed_safe_call)(void**), PyObject* keep_alive_object)
+      : mlir_packed_safe_call_(mlir_packed_safe_call), keep_alive_object_(keep_alive_object) {
+    if (keep_alive_object_) {
+      Py_IncRef(keep_alive_object_);
+    }
+  }
+
+  // The destructor releases exactly one reference, so copying must be forbidden
+  // to avoid a double Py_DecRef on the same keep-alive object.
+  TVMFFIPyMLIRPackedSafeCall(const TVMFFIPyMLIRPackedSafeCall&) = delete;
+  TVMFFIPyMLIRPackedSafeCall& operator=(const TVMFFIPyMLIRPackedSafeCall&) = delete;
+
+  /*! \brief Release the reference taken in the constructor, if any. */
+  ~TVMFFIPyMLIRPackedSafeCall() {
+    if (keep_alive_object_) {
+      // NOTE(review): Py_DecRef requires the calling thread to hold the GIL;
+      // confirm the deleter is never run from a thread that does not hold it.
+      Py_DecRef(keep_alive_object_);
+    }
+  }
+
+  /*!
+   * \brief Safe call entry point that redirects to the MLIR packed function.
+   *
+   * Builds the packed argument array {&handle, &args, &num_args, &rv, &ret_code},
+   * where every element points to the slot holding the value, and calls the
+   * stored function pointer with it. The handle slot is always nullptr.
+   *
+   * \param func Pointer to the TVMFFIPyMLIRPackedSafeCall instance.
+   * \param args The arguments.
+   * \param num_args The number of arguments.
+   * \param rv The result.
+   * \return The value the callee wrote to the ret_code slot (0 if it wrote nothing).
+   */
+  static int Invoke(void* func, const TVMFFIAny* args, int32_t num_args, TVMFFIAny* rv) {
+    TVMFFIPyMLIRPackedSafeCall* self = reinterpret_cast<TVMFFIPyMLIRPackedSafeCall*>(func);
+    int ret_code = 0;
+    void* handle = nullptr;
+    void* mlir_args[] = {&handle, const_cast<TVMFFIAny**>(&args), &num_args, &rv, &ret_code};
+    (*self->mlir_packed_safe_call_)(mlir_args);
+    return ret_code;
+  }
+
+  /*! \brief Destroy an instance previously allocated with new. */
+  static void Deleter(void* self) { delete static_cast<TVMFFIPyMLIRPackedSafeCall*>(self); }
+
+ private:
+  /*! \brief The MLIR packed safe call function pointer. */
+  void (*mlir_packed_safe_call_)(void**);
+  /*! \brief Optional keep-alive object, owned reference when not nullptr. */
+  PyObject* keep_alive_object_;
+};
+
+/*!
+ * \brief Create a TVMFFIPyMLIRPackedSafeCall handle
+ *
+ * Declared inline because this definition lives in a header.
+ *
+ * \param mlir_packed_safe_call The MLIR packed safe call function
+ * \param keep_alive_object The keep alive object, can be nullptr
+ * \return The heap-allocated TVMFFIPyMLIRPackedSafeCall object, to be released
+ *         with TVMFFIPyMLIRPackedSafeCallDeleter
+ */
+inline void* TVMFFIPyMLIRPackedSafeCallCreate(void (*mlir_packed_safe_call)(void**),
+                                              PyObject* keep_alive_object) {
+  return new TVMFFIPyMLIRPackedSafeCall(mlir_packed_safe_call, keep_alive_object);
+}
+
+/*!
+ * \brief Call the MLIR packed safe call function
+ *
+ * Declared inline because this definition lives in a header.
+ *
+ * \param self The TVMFFIPyMLIRPackedSafeCall object
+ * \param args The arguments
+ * \param num_args The number of arguments
+ * \param rv The result
+ * \return The return code
+ */
+inline int TVMFFIPyMLIRPackedSafeCallInvoke(void* self, const TVMFFIAny* args, int32_t num_args,
+                                            TVMFFIAny* rv) {
+  return TVMFFIPyMLIRPackedSafeCall::Invoke(self, args, num_args, rv);
+}
+
+/*!
+ * \brief Delete the TVMFFIPyMLIRPackedSafeCall object
+ *
+ * Declared inline because this definition lives in a header.
+ *
+ * \param self The TVMFFIPyMLIRPackedSafeCall object
+ */
+inline void TVMFFIPyMLIRPackedSafeCallDeleter(void* self) {
+  TVMFFIPyMLIRPackedSafeCall::Deleter(self);
+}
+
//------------------------------------------------------------------------------------
// Helpers for free-threaded python
//------------------------------------------------------------------------------------
diff --git a/src/ffi/extra/testing.cc b/src/ffi/extra/testing.cc
index b5b1202..13f0044 100644
--- a/src/ffi/extra/testing.cc
+++ b/src/ffi/extra/testing.cc
@@ -180,6 +180,15 @@ int __add_one_c_symbol(void*, const TVMFFIAny* args,
int32_t num_args, TVMFFIAny
TVM_FFI_SAFE_CALL_END();
}
+// Testing helper that wraps __add_one_c_symbol behind a packed void(void**) entry point.
+// packed_args[0..3] each point to the slot holding handle, args, num_args and rv;
+// packed_args[4] points to the int that receives the return code.
+void _mlir_add_one_c_symbol(void** packed_args) {
+ void* handle = *reinterpret_cast<void**>(packed_args[0]);
+ const TVMFFIAny* args = *reinterpret_cast<const TVMFFIAny**>(packed_args[1]);
+ int32_t num_args = *reinterpret_cast<int32_t*>(packed_args[2]);
+ TVMFFIAny* rv = *reinterpret_cast<TVMFFIAny**>(packed_args[3]);
+ // the return code slot is written in place instead of being returned
+ int* ret_code = reinterpret_cast<int*>(packed_args[4]);
+ *ret_code = __add_one_c_symbol(handle, args, num_args, rv);
+}
+
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
@@ -250,9 +259,13 @@ TVM_FFI_STATIC_INIT_BLOCK() {
.def("testing.object_use_count", [](const Object* obj) { return
obj->use_count(); })
.def("testing.make_unregistered_object",
[]() { return ObjectRef(make_object<TestUnregisteredObject>(41,
42)); })
- .def("testing.get_add_one_c_symbol", []() {
- TVMFFISafeCallType symbol = __add_one_c_symbol;
- return reinterpret_cast<int64_t>(reinterpret_cast<void*>(symbol));
+ .def("testing.get_add_one_c_symbol",
+ []() {
+ TVMFFISafeCallType symbol = __add_one_c_symbol;
+ return reinterpret_cast<int64_t>(reinterpret_cast<void*>(symbol));
+ })
+ .def("testing.get_mlir_add_one_c_symbol", []() {
+ return
reinterpret_cast<int64_t>(reinterpret_cast<void*>(_mlir_add_one_c_symbol));
});
}
diff --git a/tests/python/test_function.py b/tests/python/test_function.py
index d7f26c4..71aecbc 100644
--- a/tests/python/test_function.py
+++ b/tests/python/test_function.py
@@ -235,7 +235,29 @@ def test_function_from_c_symbol() -> None:
keep_alive = [1, 2, 3]
base_ref_count = sys.getrefcount(keep_alive)
- fadd_one = tvm_ffi.Function.__from_extern_c__(add_one_c_symbol, keep_alive)
+ fadd_one = tvm_ffi.Function.__from_extern_c__(add_one_c_symbol,
keep_alive_object=keep_alive)
+ assert fadd_one(1) == 2
+ assert fadd_one(2) == 3
+ assert sys.getrefcount(keep_alive) == base_ref_count + 1
+ fadd_one = None
+ assert sys.getrefcount(keep_alive) == base_ref_count
+
+
+def test_function_from_mlir_packed_safe_call() -> None:
+ add_one_c_symbol =
tvm_ffi.get_global_func("testing.get_mlir_add_one_c_symbol")()
+ fadd_one =
tvm_ffi.Function.__from_mlir_packed_safe_call__(add_one_c_symbol)
+ assert fadd_one(1) == 2
+ assert fadd_one(2) == 3
+
+ keep_alive = [1, 2, 3]
+ base_ref_count = sys.getrefcount(keep_alive)
+ fadd_one = tvm_ffi.Function.__from_mlir_packed_safe_call__(
+ add_one_c_symbol, keep_alive_object=keep_alive
+ )
+
+ with pytest.raises(TypeError):
+ fadd_one(None)
+
assert fadd_one(1) == 2
assert fadd_one(2) == 3
assert sys.getrefcount(keep_alive) == base_ref_count + 1