This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0c9e7cda7d [FFI] Update `load_inline` interface (#18307)
0c9e7cda7d is described below
commit 0c9e7cda7d39cf24bd7676f0e67c3885ae95cff3
Author: Yaoyao Ding <[email protected]>
AuthorDate: Fri Sep 12 16:50:37 2025 -0400
[FFI] Update `load_inline` interface (#18307)
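Update the `load_inline` interface so that, when no `cpp_sources` are given, the exported functions are taken from `cuda_sources` directly; previously they always had to be declared in `cpp_sources`.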
---
ffi/python/tvm_ffi/cpp/load_inline.py | 20 ++++++++++++++------
ffi/tests/python/test_load_inline.py | 3 ---
2 files changed, 14 insertions(+), 9 deletions(-)
diff --git a/ffi/python/tvm_ffi/cpp/load_inline.py b/ffi/python/tvm_ffi/cpp/load_inline.py
index 111dee8d52..3bc0fc4cbc 100644
--- a/ffi/python/tvm_ffi/cpp/load_inline.py
+++ b/ffi/python/tvm_ffi/cpp/load_inline.py
@@ -326,10 +326,12 @@ def load_inline(
cuda_sources: Sequence[str] | str, optional
The CUDA source code. It can be a list of sources or a single source.
functions: Mapping[str, str] | Sequence[str] | str, optional
- The functions in cpp_sources that will be exported to the tvm ffi
module. When a mapping is given, the keys
- are the names of the exported functions, and the values are docstrings
for the functions. When a sequence or a
- single string is given, they are the functions needed to be exported,
and the docstrings are set to empty
- strings. A single function name can also be given as a string.
+ The functions in cpp_sources or cuda_sources that will be exported to
the tvm ffi module. When a mapping is
+ given, the keys are the names of the exported functions, and the
values are docstrings for the functions. When
+ a sequence or a single string is given, they are the functions needed
to be exported, and the docstrings are set
+ to empty strings. A single function name can also be given as a
string. When cpp_sources is given, the functions
+ must be declared (not necessarily defined) in the cpp_sources. When
cpp_sources is not given, the functions
+ must be defined in the cuda_sources. If not specified, no function
will be exported.
extra_cflags: Sequence[str], optional
The extra compiler flags for C++ compilation.
The default flags are:
@@ -369,6 +371,7 @@ def load_inline(
elif isinstance(cuda_sources, str):
cuda_sources = [cuda_sources]
cuda_source = "\n".join(cuda_sources)
+ with_cpp = len(cpp_sources) > 0
with_cuda = len(cuda_sources) > 0
extra_ldflags = extra_ldflags or []
@@ -381,8 +384,13 @@ def load_inline(
functions = {functions: ""}
elif isinstance(functions, Sequence):
functions = {name: "" for name in functions}
- cpp_source = _decorate_with_tvm_ffi(cpp_source, functions)
- cuda_source = _decorate_with_tvm_ffi(cuda_source, {})
+
+ if with_cpp:
+ cpp_source = _decorate_with_tvm_ffi(cpp_source, functions)
+ cuda_source = _decorate_with_tvm_ffi(cuda_source, {})
+ else:
+ cpp_source = _decorate_with_tvm_ffi(cpp_source, {})
+ cuda_source = _decorate_with_tvm_ffi(cuda_source, functions)
# determine the cache dir for the built module
if build_directory is None:
diff --git a/ffi/tests/python/test_load_inline.py b/ffi/tests/python/test_load_inline.py
index 89f00b1f36..2aa01a62ee 100644
--- a/ffi/tests/python/test_load_inline.py
+++ b/ffi/tests/python/test_load_inline.py
@@ -159,9 +159,6 @@ def test_load_inline_cpp_build_dir():
def test_load_inline_cuda():
mod: Module = tvm_ffi.cpp.load_inline(
name="hello",
- cpp_sources=r"""
- void add_one_cuda(DLTensor* x, DLTensor* y);
- """,
cuda_sources=r"""
__global__ void AddOneKernel(float* x, float* y, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;