This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new b17709a doc: Update Kernel Library Guide (#431)
b17709a is described below
commit b17709aabcc2f1bd776fb001bd05a9c1b4cfe421
Author: Junru Shao <[email protected]>
AuthorDate: Fri Feb 6 12:13:36 2026 -0800
doc: Update Kernel Library Guide (#431)
---
docs/conf.py | 7 +-
docs/guides/kernel_library_guide.rst | 368 ++++++++++++++++++++++----------
examples/kernel_library/load_scale.py | 34 +++
examples/kernel_library/scale_kernel.cu | 66 ++++++
examples/kernel_library/tvm_ffi_utils.h | 75 +++++++
include/tvm/ffi/base_details.h | 16 +-
6 files changed, 446 insertions(+), 120 deletions(-)
diff --git a/docs/conf.py b/docs/conf.py
index 8812dd1..41e61b0 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -92,11 +92,8 @@ PREDEFINED += TVM_FFI_DLL= TVM_FFI_DLL_EXPORT=
TVM_FFI_INLINE= \
TVM_FFI_EXTRA_CXX_API= TVM_FFI_WEAK=
TVM_FFI_DOXYGEN_MODE \
__cplusplus=201703
EXCLUDE_SYMBOLS += *details* *TypeTraits* std \
- *use_default_type_traits_v* *is_optional_type_v* *operator* \
- TVM_FFI_LOG_EXCEPTION_CALL_BEGIN TVM_FFI_LOG_EXCEPTION_CALL_END \
- TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY TVM_FFI_STATIC_INIT_BLOCK \
- TVM_FFI_STATIC_INIT_BLOCK_DEF_ TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN
-EXCLUDE_PATTERNS += *details.h *internal*
+ *use_default_type_traits_v* *is_optional_type_v* *operator*
+EXCLUDE_PATTERNS += */function_details.h */container_details.h
ENABLE_PREPROCESSING = YES
MACRO_EXPANSION = YES
WARNINGS = YES
diff --git a/docs/guides/kernel_library_guide.rst b/docs/guides/kernel_library_guide.rst
index 01faf20..c042486 100644
--- a/docs/guides/kernel_library_guide.rst
+++ b/docs/guides/kernel_library_guide.rst
@@ -15,180 +15,334 @@
.. specific language governing permissions and limitations
.. under the License.
-====================
Kernel Library Guide
====================
-This guide serves as a quick start for shipping kernel libraries with TVM FFI. The shipped kernel libraries are of python version and ML framework agnostic. With the help of TVM FFI, we can connect the kernel libraries to multiple ML framework, such as PyTorch, XLA, JAX, together with the minimal efforts.
+This guide covers shipping C++/CUDA kernel libraries with TVM-FFI. The resulting
+libraries are agnostic to Python version and ML framework — a single ``.so`` works
+with PyTorch, JAX, PaddlePaddle, NumPy, and more.
-Tensor
-======
+.. seealso::
-Almost all kernel libraries are about tensor computation and manipulation. For better adaptation to different ML frameworks, TVM FFI provides a minimal set of data structures to represent tensors from ML frameworks, including the tensor basic attributes and storage pointer.
-To be specific, in TVM FFI, two types of tensor constructs, :cpp:class:`~tvm::ffi::Tensor` and :cpp:class:`~tvm::ffi::TensorView`, can be used to represent a tensor from ML frameworks.
+ - :doc:`../get_started/quickstart`: End-to-end walkthrough of a simpler ``add_one`` kernel
+ - :doc:`../packaging/cpp_tooling`: Build toolchain, CMake integration, and library distribution
+ - All example code in this guide is under
+ `examples/kernel_library/ <https://github.com/apache/tvm-ffi/tree/main/examples/kernel_library>`_.
+ - Production examples:
+ `FlashInfer <https://github.com/flashinfer-ai/flashinfer>`_ ships CUDA kernels via TVM-FFI.
-Tensor and TensorView
----------------------
-Both :cpp:class:`~tvm::ffi::Tensor` and :cpp:class:`~tvm::ffi::TensorView` are designed to represent tensors from ML frameworks that interact with the TVM FFI ABI. They are backed by the `DLTensor` in DLPack in practice. The main difference is whether it is an owning tensor structure.
+Anatomy of a Kernel Function
+-----------------------------
-:cpp:class:`tvm::ffi::Tensor`
- :cpp:class:`~tvm::ffi::Tensor` is a completely owning tensor with reference counting. It can be created on either C++ or Python side and passed between either side. And TVM FFI internally keeps a reference count to track lifetime of the tensors. When the reference count goes to zero, its underlying deleter function will be called to free the tensor storage.
+Every TVM-FFI CUDA kernel follows the same sequence:
-:cpp:class:`tvm::ffi::TensorView`
- :cpp:class:`~tvm::ffi::TensorView` is a non-owning view of an existing tensor, pointing to an existing tensor (e.g., a tensor allocated by PyTorch).
+1. **Validate** inputs (device, dtype, shape, contiguity)
+2. **Set device guard** to match the tensor's device
+3. **Acquire stream** from the host framework
+4. **Dispatch** on dtype and **launch** the kernel
-It is **recommended** to use :cpp:class:`~tvm::ffi::TensorView` when possible, that helps us to support more cases, including cases where only view but not strong reference are passed, like XLA buffer. It is also more lightweight. However, since :cpp:class:`~tvm::ffi::TensorView` is a non-owning view, it is the user's responsibility to ensure the lifetime of underlying tensor.
+Here is a complete ``Scale`` kernel that computes ``y = x * factor``:
-Tensor Attributes
------------------
+.. literalinclude:: ../../examples/kernel_library/scale_kernel.cu
+ :language: cpp
+ :start-after: [function.begin]
+ :end-before: [function.end]
-For convenience, :cpp:class:`~tvm::ffi::TensorView` and :cpp:class:`~tvm::ffi::Tensor` align the following attributes retrieval mehtods to :cpp:class:`at::Tensor` interface, to obtain tensor basic attributes and storage pointer:
-``dim``, ``dtype``, ``sizes``, ``size``, ``strides``, ``stride``, ``numel``, ``data_ptr``, ``device``, ``is_contiguous``
+The CUDA kernel itself is a standard ``__global__`` function:
-Please refer to the documentation of both tensor classes for their details. Here highlight some non-primitive attributes:
+.. literalinclude:: ../../examples/kernel_library/scale_kernel.cu
+ :language: cpp
+ :start-after: [cuda_kernel.begin]
+ :end-before: [cuda_kernel.end]
-:c:struct:`DLDataType`
- The ``dtype`` of the tensor. It's represented by a struct with three fields: code, bits, and lanes, defined by DLPack protocol.
+The following subsections break down each step.
-:c:struct:`DLDevice`
- The ``device`` where the tensor is stored. It is represented by a struct with two fields: device_type and device_id, defined by DLPack protocol.
-:cpp:class:`tvm::ffi::ShapeView`
- The ``sizes`` and ``strides`` attributes retrieval are returned as :cpp:class:`~tvm::ffi::ShapeView`. It is an iterate-able data structure storing the shapes or strides data as ``int64_t`` array.
+Input Validation
+~~~~~~~~~~~~~~~~
+
+Kernel functions should validate inputs early and fail with clear error messages.
+A common pattern is to define reusable ``CHECK_*`` macros on top of
+:c:macro:`TVM_FFI_CHECK` (see :doc:`../concepts/exception_handling`):
+
+.. literalinclude:: ../../examples/kernel_library/tvm_ffi_utils.h
+ :language: cpp
+ :start-after: [check_macros.begin]
+ :end-before: [check_macros.end]
+
+For **user-facing errors** (bad arguments, unsupported dtypes, shape mismatches),
+use :c:macro:`TVM_FFI_THROW` or :c:macro:`TVM_FFI_CHECK` with a specific error kind
+so that callers receive an actionable message:
+
+.. code-block:: cpp
+
+ TVM_FFI_THROW(TypeError) << "Unsupported dtype: " << input.dtype();
+ TVM_FFI_CHECK(input.numel() > 0, ValueError) << "input must be non-empty";
+ TVM_FFI_CHECK(input.numel() == output.numel(), ValueError) << "size mismatch";
+
+For **internal invariants** that indicate bugs in the kernel itself, use
+:c:macro:`TVM_FFI_ICHECK`:
+
+.. code-block:: cpp
+
+ TVM_FFI_ICHECK_GE(n, 0) << "element count must be non-negative";
-Tensor Allocation
------------------
-TVM FFI provides several methods to create or allocate tensors at C++ runtime. Generally, there are two types of tensor creation methods:
+Device Guard and Stream
+~~~~~~~~~~~~~~~~~~~~~~~
-* Allocate a tensor with new storage from scratch, i.e. :cpp:func:`~tvm::ffi::Tensor::FromEnvAlloc` and :cpp:func:`~tvm::ffi::Tensor::FromNDAlloc`. By this types of methods, the shapes, strides, data types, devices and other attributes are required for the allocation.
-* Create a tensor with existing storage following DLPack protocol, i.e. :cpp:func:`~tvm::ffi::Tensor::FromDLPack` and :cpp:func:`~tvm::ffi::Tensor::FromDLPackVersioned`. By this types of methods, the shapes, data types, devices and other attributes can be inferred from the DLPack attributes.
+Before launching a CUDA kernel, two things must happen:
-FromEnvAlloc
-^^^^^^^^^^^^
+1. **Set the CUDA device** to match the tensor's device. :cpp:class:`tvm::ffi::CUDADeviceGuard`
+ is an RAII guard that calls ``cudaSetDevice`` on construction and restores the
+ original device on destruction.
-To better adapt to the ML framework, it is **recommended** to reuse the framework tensor allocator anyway, instead of directly allocating the tensors via CUDA runtime API, like ``cudaMalloc``. Since reusing the framework tensor allocator:
+2. **Acquire the stream** from the host framework via :cpp:func:`TVMFFIEnvGetStream`.
+ When Python code calls a kernel with PyTorch tensors, TVM-FFI automatically
+ captures PyTorch's current stream for the tensor's device.
-* Benefit from the framework's native caching allocator or related allocation mechanism.
-* Help framework tracking memory usage and planning globally.
+A small helper keeps this concise:
-TVM FFI provides :cpp:func:`tvm::ffi::Tensor::FromEnvAlloc` to allocate a tensor with the framework tensor allocator. To determine which framework tensor allocator, TVM FFI infers it from the passed-in framework tensors. For example, when calling the kernel library at Python side, there is an input framework tensor if of type ``torch.Tensor``, TVM FFI will automatically bind the :cpp:func:`at::empty` as the current framework tensor allocator by ``TVMFFIEnvTensorAlloc``. And then the :cpp [...]
+.. literalinclude:: ../../examples/kernel_library/tvm_ffi_utils.h
+ :language: cpp
+ :start-after: [get_stream.begin]
+ :end-before: [get_stream.end]
-.. code-block:: c++
+Every kernel function then follows the same two-line pattern:
- ffi::Tensor tensor = ffi::Tensor::FromEnvAlloc(TVMFFIEnvTensorAlloc, ...);
+.. code-block:: cpp
-which is equivalent to:
+ ffi::CUDADeviceGuard guard(input.device().device_id);
+ cudaStream_t stream = get_cuda_stream(input.device());
-.. code-block:: c++
+See :doc:`../concepts/tensor` for details on stream handling and automatic stream
+context updates.
- at::Tensor tensor = at::empty(...);
-FromNDAlloc
-^^^^^^^^^^^
+Dtype Dispatch
+~~~~~~~~~~~~~~
-:cpp:func:`tvm::ffi::Tensor::FromNDAlloc` can be used to create a tensor with custom memory allocator. It is of simple usage by providing a custom memory allocator and deleter for tensor allocation and free each, rather than relying on any framework tensor allocator.
+Kernels typically support multiple dtypes. Dispatch on :c:struct:`DLDataType` at
+runtime while instantiating templates at compile time:
-However, the tensors allocated by :cpp:func:`tvm::ffi::Tensor::FromNDAlloc` only retain the function pointer to its custom deleter for deconstruction. The custom deleters are all owned by the kernel library still. So it is important to make sure the loaded kernel library, :py:class:`tvm_ffi.Module`, outlives the tensors allocated by :cpp:func:`tvm::ffi::Tensor::FromNDAlloc`. Otherwise, the function pointers to the custom deleter will be invalid. Here a typical approach is to retain the l [...]
+.. code-block:: cpp
-But in the scenarios of linked runtime libraries and c++ applications, the libraries alive globally throughout the entire lifetime of the process. So :cpp:func:`tvm::ffi::Tensor::FromNDAlloc` works well in these scenarios without the use-after-delete issue above. Otherwise, in general, :cpp:func:`tvm::ffi::Tensor::FromEnvAlloc` is free of this issue, which is more **recommended** in practice.
+ constexpr DLDataType dl_float32 = DLDataType{kDLFloat, 32, 1};
+ constexpr DLDataType dl_float16 = DLDataType{kDLFloat, 16, 1};
+ if (input.dtype() == dl_float32) {
+ ScaleKernel<<<blocks, threads, 0, stream>>>(
+ static_cast<float*>(output.data_ptr()), ...);
+ } else if (input.dtype() == dl_float16) {
+ ScaleKernel<<<blocks, threads, 0, stream>>>(
+ static_cast<half*>(output.data_ptr()), ...);
+ } else {
+ TVM_FFI_THROW(TypeError) << "Unsupported dtype: " << input.dtype();
+ }
-FromNDAllocStrided
-^^^^^^^^^^^^^^^^^^
+For libraries that support many dtypes, define dispatch macros
+(see `FlashInfer's tvm_ffi_utils.h <https://github.com/flashinfer-ai/flashinfer/blob/main/csrc/tvm_ffi_utils.h>`_
+for a production example).
-:cpp:func:`tvm::ffi::Tensor::FromNDAllocStrided` can be used to create a tensor with a custom memory allocator and strided layout (e.g. column major layout).
-Note that for tensor memory that will be returned from the kernel library to the caller, we instead recommend using :cpp:func:`tvm::ffi::Tensor::FromEnvAlloc`
-followed by :cpp:func:`tvm::ffi::Tensor::as_strided` to create a strided view of the tensor.
-FromDLPack
-^^^^^^^^^^
+Export and Load
+---------------
-:cpp:func:`tvm::ffi::Tensor::FromDLPack` enables creating :cpp:class:`~tvm::ffi::Tensor` from ``DLManagedTensor*``, working with ``ToDLPack`` for DLPack C Tensor Object ``DLTensor`` exchange protocol. Both are used for DLPack pre V1.0 API. It is used for wrapping the existing framework tensor to :cpp:class:`~tvm::ffi::Tensor`.
+Export and Build
+~~~~~~~~~~~~~~~~
-FromDLPackVersioned
-^^^^^^^^^^^^^^^^^^^
+**Export.** Use :c:macro:`TVM_FFI_DLL_EXPORT_TYPED_FUNC` to create a C symbol
+that follows the :doc:`TVM-FFI calling convention <../concepts/func_module>`:
-:cpp:func:`tvm::ffi::Tensor::FromDLPackVersioned` enables creating :cpp:class:`~tvm::ffi::Tensor` from ``DLManagedTensorVersioned*``, working with ``ToDLPackVersioned`` for DLPack C Tensor Object ``DLTensor`` exchange protocol. Both are used for DLPack post V1.0 API. It is used for wrapping the existing framework tensor to :cpp:class:`~tvm::ffi::Tensor` too.
+.. literalinclude:: ../../examples/kernel_library/scale_kernel.cu
+ :language: cpp
+ :start-after: [export.begin]
+ :end-before: [export.end]
+
+This creates a symbol ``__tvm_ffi_scale`` in the shared library.
+
+**Build.** Compile the kernel into a shared library using GCC/NVCC or CMake
+(see :doc:`../packaging/cpp_tooling` for full details):
+
+.. code-block:: bash
+
+ nvcc -shared -O3 scale_kernel.cu -o build/scale_kernel.so \
+ -Xcompiler -fPIC,-fvisibility=hidden \
+ $(tvm-ffi-config --cxxflags) \
+ $(tvm-ffi-config --ldflags) \
+ $(tvm-ffi-config --libs)
+
+**Optional arguments.** Wrap any argument type with :cpp:class:`tvm::ffi::Optional`
+to accept ``None`` from the Python side:
+
+.. code-block:: cpp
+
+ void MyKernel(TensorView output, TensorView input,
+ Optional<TensorView> bias, Optional<double> scale) {
+ if (bias.has_value()) {
+ // use bias.value().data_ptr()
+ }
+ double s = scale.value_or(1.0);
+ }
+
+.. code-block:: python
-Stream
-======
+ mod.my_kernel(y, x, None, None) # no bias, default scale
+ mod.my_kernel(y, x, bias_tensor, 2.0) # with bias and scale
-Besides of tensors, stream context is another key concept in kernel library, especially for kernel execution. And the kernel library should be able to obtain the current stream context from ML framework via TVM FFI.
-Stream Obtaining
-----------------
+Load from Python
+~~~~~~~~~~~~~~~~
-In practice, TVM FFI maintains a stream context table per device type and index. And kernel libraries can obtain the current stream context on specific device by :cpp:func:`TVMFFIEnvGetStream`. Here is an example:
+Use :py:func:`tvm_ffi.load_module` to load the library and call its functions.
+PyTorch tensors (and other framework tensors) are automatically converted to
+:cpp:class:`~tvm::ffi::TensorView` at the ABI boundary:
-.. code-block:: c++
+.. literalinclude:: ../../examples/kernel_library/load_scale.py
+ :language: python
+ :start-after: [load_and_call.begin]
+ :end-before: [load_and_call.end]
- void func(ffi::TensorView input, ...) {
- ffi::DLDevice device = input.device();
- cudaStream_t stream = reinterpret_cast<cudaStream_t>(TVMFFIEnvGetStream(device.device_type, device.device_id));
- }
+See :doc:`../get_started/quickstart` for examples with JAX, PaddlePaddle,
+NumPy, CuPy, Rust, and pure C++.
-which is equivalent to:
-.. code-block:: c++
+Tensor Handling
+---------------
- void func(at::Tensor input, ...) {
- c10::Device = input.device();
- cudaStream_t stream = reinterpret_cast<cudaStream_t>(c10::cuda::getCurrentCUDAStream(device.index()).stream());
- }
+TensorView vs Tensor
+~~~~~~~~~~~~~~~~~~~~
-Stream Update
--------------
+TVM-FFI provides two tensor types (see :doc:`../concepts/tensor` for full details):
-Corresponding to :cpp:func:`TVMFFIEnvGetStream`, TVM FFI updates the stream context table via interface :cpp:func:`TVMFFIEnvSetStream`. But the updating methods can be implicit and explicit.
+:cpp:class:`~tvm::ffi::TensorView` *(non-owning)*
+ A lightweight view of an existing tensor. **Use this for kernel parameters.**
+ It adds no reference count overhead and works with all framework tensors.
-Implicit Update
-^^^^^^^^^^^^^^^
+:cpp:class:`~tvm::ffi::Tensor` *(owning)*
+ A reference-counted tensor that manages its own lifetime. Use this only when
+ you need to **allocate and return** a tensor from C++.
-Similar to the tensor allocation :ref:`guides/kernel_library_guide:FromNDAlloc`, TVM FFI does the implicit update on stream context table as well. When converting the framework tensors as mentioned above, TVM FFI automatically updates the stream context table, by the device on which the converted framework tensors. For example, if there is an framework tensor as ``torch.Tensor(device="cuda:3")``, TVM FFI would automatically update the current stream of cuda device 3 to torch current cont [...]
+.. important::
-Explicit Update
-^^^^^^^^^^^^^^^
+ Prefer :cpp:class:`~tvm::ffi::TensorView` in kernel signatures. It is more
+ lightweight, supports more use cases (including XLA buffers that only provide
+ views), and avoids unnecessary reference counting.
-Once the devices on which the stream contexts reside cannot be inferred from the tensors, the explicit update on stream context table is necessary. TVM FFI provides :py:func:`tvm_ffi.use_torch_stream` and :py:func:`tvm_ffi.use_raw_stream` for manual stream context update. However, it is **recommended** to use implicit update above, to reduce code complexity.
-Device Guard
-============
+Tensor Metadata
+~~~~~~~~~~~~~~~
-When launching kernels, kernel libraries may require the current device context to be set for a specific device. TVM FFI provides the :cpp:class:`tvm::ffi::CUDADeviceGuard` class to manage this, similar to :cpp:class:`c10::cuda::CUDAGuard`. When a :cpp:class:`tvm::ffi::CUDADeviceGuard` object is constructed with a device index, it saves the original device index (retrieved using ``cudaGetDevice``) and sets the current device to the given index (using ``cudaSetDevice``). Upon destruction [...]
+Both :cpp:class:`~tvm::ffi::TensorView` and :cpp:class:`~tvm::ffi::Tensor` expose
+identical metadata accessors. These are the methods kernel code uses most:
+validating inputs, computing launch parameters, and accessing data pointers.
-.. code-block:: c++
+**Shape and elements.**
+:cpp:func:`~tvm::ffi::TensorView::ndim` returns the number of dimensions,
+:cpp:func:`~tvm::ffi::TensorView::shape` returns the full shape as a
+:cpp:class:`~tvm::ffi::ShapeView` (a lightweight ``span``-like view of
+``int64_t``), and :cpp:func:`~tvm::ffi::TensorView::size` returns the size of a
+single dimension (supports negative indexing, e.g. ``size(-1)`` for the last
+dimension). :cpp:func:`~tvm::ffi::TensorView::numel` returns the total element
+count — use it for computing grid dimensions:
- void func(ffi::TensorView input, ...) {
- // current device index is original device index
- ffi::CUDADeviceGuard device_guard(input.device().device_id);
- // current device index is input device index
- }
+.. code-block:: cpp
-After ``func`` returns, the ``device_guard`` is destructed, and the original device index is restored.
+ int64_t n = input.numel();
+ int threads = 256;
+ int blocks = (n + threads - 1) / threads;
-Function Exporting
-==================
+**Dtype.** :cpp:func:`~tvm::ffi::TensorView::dtype` returns a :c:struct:`DLDataType`
+with three fields: ``code`` (e.g. ``kDLFloat``, ``kDLBfloat``), ``bits``
+(e.g. 16, 32), and ``lanes`` (almost always 1). Compare it against predefined
+constants to dispatch on dtype:
-As we already have our kernel library wrapped with TVM FFI interface, our next and final step is exporting kernel library to Python side. TVM FFI provides macro :c:macro:`TVM_FFI_DLL_EXPORT_TYPED_FUNC` for exporting the kernel functions to the output library files. So that at Python side, it is possible to load the library files and call the kernel functions directly. For example, we export our kernels as:
+.. code-block:: cpp
-.. code-block:: c++
+ constexpr DLDataType dl_float32 = DLDataType{kDLFloat, 32, 1};
+ if (input.dtype() == dl_float32) { ... }
- void func(ffi::TensorView input, ffi::TensorView output);
- TVM_FFI_DLL_EXPORT_TYPED_FUNC(func_name, func);
+**Device.** :cpp:func:`~tvm::ffi::TensorView::device` returns a :c:struct:`DLDevice`
+with ``device_type`` (e.g. ``kDLCUDA``) and ``device_id``. Use these for
+validation and to set the device guard:
-And then we compile the sources into ``lib.so``, or ``lib.dylib`` for macOS, or ``lib.dll`` for Windows. Finally, we can load and call our kernel functions at Python side as:
+.. code-block:: cpp
+
+ TVM_FFI_ICHECK_EQ(input.device().device_type, kDLCUDA);
+ ffi::CUDADeviceGuard guard(input.device().device_id);
+
+**Data pointer.** :cpp:func:`~tvm::ffi::TensorView::data_ptr` returns ``void*``;
+cast it to the appropriate typed pointer before passing it to a kernel:
+
+.. code-block:: cpp
+
+ auto* out = static_cast<float*>(output.data_ptr());
+ auto* in = static_cast<float*>(input.data_ptr());
+
+**Strides and contiguity.**
+:cpp:func:`~tvm::ffi::TensorView::strides` returns the stride array as a
+:cpp:class:`~tvm::ffi::ShapeView`, and
+:cpp:func:`~tvm::ffi::TensorView::stride` returns a single dimension's stride.
+:cpp:func:`~tvm::ffi::TensorView::IsContiguous` checks whether the tensor is
+contiguous in memory. Most kernels require contiguous inputs — the
+``CHECK_CONTIGUOUS`` macro shown above enforces this at the top of each function.
+
+.. tip::
+
+ The API is designed to be familiar to PyTorch developers.
+ ``dim()``, ``sizes()``, ``size(i)``, ``stride(i)``, and ``is_contiguous()``
+ are all available as aliases of their TVM-FFI counterparts.
+ See :doc:`../concepts/tensor` for the full API reference.
+
+
+Tensor Allocation
+~~~~~~~~~~~~~~~~~
+
+**Always pre-allocate output tensors on the Python side** and pass them into the
+kernel as :cpp:class:`~tvm::ffi::TensorView` parameters. Allocating tensors
+inside a kernel function is almost never the right choice:
+
+- it causes **memory fragmentation** from repeated small allocations,
+- it **breaks CUDA graph capture**, which requires deterministic memory addresses, and
+- it **bypasses the framework's allocator** (caching pools, device placement, memory planning).
+
+The pre-allocation pattern is straightforward:
.. code-block:: python
- mod = tvm_ffi.load_module("lib.so")
- x = ...
- y = ...
- mod.func_name(x, y)
+ # Python: pre-allocate output
+ y = torch.empty_like(x)
+ mod.scale(y, x, 2.0)
+
+.. code-block:: cpp
+
+ // C++: kernel writes into pre-allocated output
+ void Scale(TensorView output, TensorView input, double factor);
+
+If C++-side allocation is truly unavoidable — for example, when the output shape
+is data-dependent and cannot be determined before the kernel runs — use
+:cpp:func:`tvm::ffi::Tensor::FromEnvAlloc` to at least reuse the host
+framework's allocator (e.g., ``torch.empty`` under PyTorch):
+
+.. literalinclude:: ../../examples/kernel_library/tvm_ffi_utils.h
+ :language: cpp
+ :start-after: [alloc_tensor.begin]
+ :end-before: [alloc_tensor.end]
+
+For custom allocators (e.g., ``cudaMalloc``/``cudaFree``), use
+:cpp:func:`tvm::ffi::Tensor::FromNDAlloc`. Note that the kernel library must
+outlive any tensors allocated this way, since the custom deleter lives in the
+library. See :doc:`../concepts/tensor` for details.
-``x`` and ``y`` here can be any ML framework tensors, such as ``torch.Tensor``, ``numpy.NDArray``, ``cupy.ndarray``, or other tensors as long as TVM FFI supports. TVM FFI detects the tensor types in arguments and converts them into :cpp:class:`~tvm::ffi::TensorView` or :cpp:class:`~tvm::ffi::Tensor` automatically. So that we do not have to write the specific conversion codes per framework.
-In constrast, if the kernel function returns :cpp:class:`~tvm::ffi::Tensor` instead of ``void`` in the example above. TVM FFI automatically converts the output :cpp:class:`~tvm::ffi::Tensor` to framework tensors also. The output framework is inferred from the input framework tensors. For example, if the input framework tensors are of ``torch.Tensor``, TVM FFI will convert the output tensor to ``torch.Tensor``. And if none of the input tensors are from ML framework, the output tensor will [...]
+Further Reading
+---------------
-Actually, it is **recommended** to pre-allocated input and output tensors from framework at Python side alreadly. So that the return type of kernel functions at C++ side should be ``void`` always.
+- :doc:`../get_started/quickstart`: End-to-end walkthrough shipping ``add_one`` across frameworks and languages
+- :doc:`../packaging/cpp_tooling`: Build toolchain, CMake integration, GCC/NVCC flags, and library distribution
+- :doc:`../packaging/python_packaging`: Packaging kernel libraries as Python wheels
+- :doc:`../concepts/tensor`: Tensor classes, DLPack interop, stream handling, and allocation APIs
+- :doc:`../concepts/func_module`: Function calling convention, modules, and the global registry
+- :doc:`../concepts/exception_handling`: Error handling across language boundaries
+- :doc:`../concepts/abi_overview`: Low-level C ABI details
diff --git a/examples/kernel_library/load_scale.py b/examples/kernel_library/load_scale.py
new file mode 100644
index 0000000..dbbc2a7
--- /dev/null
+++ b/examples/kernel_library/load_scale.py
@@ -0,0 +1,34 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Load and call a scale kernel."""
+
+# [load_and_call.begin]
+import torch
+import tvm_ffi
+
+# Load the compiled shared library
+mod = tvm_ffi.load_module("build/scale_kernel.so")
+
+# Pre-allocate input and output tensors in PyTorch
+x = torch.randn(1024, device="cuda", dtype=torch.float32)
+y = torch.empty_like(x)
+
+# Call the kernel — PyTorch tensors are auto-converted to TensorView
+mod.scale(y, x, 2.0)
+
+assert torch.allclose(y, x * 2.0)
+# [load_and_call.end]
diff --git a/examples/kernel_library/scale_kernel.cu b/examples/kernel_library/scale_kernel.cu
new file mode 100644
index 0000000..063f085
--- /dev/null
+++ b/examples/kernel_library/scale_kernel.cu
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "tvm_ffi_utils.h"
+
+// [cuda_kernel.begin]
+template <typename T>
+__global__ void ScaleKernel(T* out, const T* in, T factor, int64_t n) {
+ int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
+ if (i < n) {
+ out[i] = in[i] * factor;
+ }
+}
+// [cuda_kernel.end]
+
+// [function.begin]
+void Scale(TensorView output, TensorView input, double factor) {
+ // --- 1. Validate inputs ---
+ CHECK_INPUT(input);
+ CHECK_INPUT(output);
+ CHECK_DIM(1, input);
+ CHECK_DEVICE(input, output);
+ TVM_FFI_CHECK(input.dtype() == output.dtype(), ValueError) << "input/output dtype mismatch";
+ TVM_FFI_CHECK(input.numel() == output.numel(), ValueError) << "input/output size mismatch";
+
+ // --- 2. Device guard and stream ---
+ ffi::CUDADeviceGuard guard(input.device().device_id);
+ cudaStream_t stream = get_cuda_stream(input.device());
+
+ // --- 3. Dispatch on dtype and launch ---
+ int64_t n = input.numel();
+ int threads = 256;
+ int blocks = (n + threads - 1) / threads;
+
+ if (input.dtype() == dl_float32) {
+ ScaleKernel<<<blocks, threads, 0, stream>>>(static_cast<float*>(output.data_ptr()),
+ static_cast<float*>(input.data_ptr()),
+ static_cast<float>(factor), n);
+ } else if (input.dtype() == dl_float16) {
+ ScaleKernel<<<blocks, threads, 0, stream>>>(static_cast<half*>(output.data_ptr()),
+ static_cast<half*>(input.data_ptr()),
+ static_cast<half>(factor), n);
+ } else {
+ TVM_FFI_THROW(TypeError) << "Unsupported dtype: " << input.dtype();
+ }
+}
+// [function.end]
+
+// [export.begin]
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(scale, Scale);
+// [export.end]
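In `Scale`, the block count `(n + threads - 1) / threads` is computed from an `int64_t` element count and then stored in an `int`. That is fine for typical sizes, but a checked helper makes the limits explicit. The sketch below is host-only C++; `LaunchConfig1D`, `ComputeGrid1D`, and `kMaxGridDimX` are hypothetical names, and the 2^31 - 1 bound assumes the `gridDim.x` limit of CUDA devices with compute capability 3.0 or newer.

```cpp
#include <cstdint>
#include <stdexcept>

// Hypothetical 1-D launch configuration for a kernel such as ScaleKernel.
struct LaunchConfig1D {
  int64_t blocks;
  int64_t threads;
};

// Assumption: gridDim.x may be at most 2^31 - 1 (compute capability >= 3.0).
constexpr int64_t kMaxGridDimX = (int64_t{1} << 31) - 1;

// Ceil-divides n elements into blocks of `threads`, staying in int64_t so the
// result is never silently narrowed, and rejects grids gridDim.x cannot hold.
// n == 0 yields 0 blocks; callers should return early instead of launching,
// because CUDA rejects a zero-sized grid.
LaunchConfig1D ComputeGrid1D(int64_t n, int64_t threads = 256) {
  if (n < 0 || threads <= 0) {
    throw std::invalid_argument("n must be non-negative and threads positive");
  }
  // ceil(n / threads) without the overflow risk of n + threads - 1
  int64_t blocks = n / threads + (n % threads != 0 ? 1 : 0);
  if (blocks > kMaxGridDimX) {
    throw std::out_of_range("element count needs more blocks than gridDim.x allows");
  }
  return LaunchConfig1D{blocks, threads};
}
```

A related detail in `ScaleKernel`: `blockIdx.x * blockDim.x` is evaluated in 32-bit `unsigned int` before the result is widened to `int64_t`, so the index wraps for inputs beyond roughly 2^32 elements. Writing `static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x` widens before multiplying.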
diff --git a/examples/kernel_library/tvm_ffi_utils.h b/examples/kernel_library/tvm_ffi_utils.h
new file mode 100644
index 0000000..e99d3be
--- /dev/null
+++ b/examples/kernel_library/tvm_ffi_utils.h
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef KERNEL_LIBRARY_TVM_FFI_UTILS_H_
+#define KERNEL_LIBRARY_TVM_FFI_UTILS_H_
+
+#include <tvm/ffi/extra/c_env_api.h>
+#include <tvm/ffi/extra/cuda/device_guard.h>
+#include <tvm/ffi/tvm_ffi.h>
+
+namespace ffi = tvm::ffi;
+using ffi::Optional;
+using ffi::Tensor;
+using ffi::TensorView;
+
+// [check_macros.begin]
+// --- Reusable validation macros ---
+#define CHECK_CUDA(x) \
+ TVM_FFI_CHECK((x).device().device_type == kDLCUDA, ValueError) << #x " must be a CUDA tensor"
+#define CHECK_CONTIGUOUS(x) \
+ TVM_FFI_CHECK((x).IsContiguous(), ValueError) << #x " must be contiguous"
+#define CHECK_INPUT(x) \
+ do { \
+ CHECK_CUDA(x); \
+ CHECK_CONTIGUOUS(x); \
+ } while (0)
+#define CHECK_DIM(d, x) \
+ TVM_FFI_CHECK((x).ndim() == (d), ValueError) << #x " must be a " #d "D tensor"
+#define CHECK_DEVICE(a, b) \
+ do { \
+ TVM_FFI_CHECK((a).device().device_type == (b).device().device_type, ValueError) \
+ << #a " and " #b " must be on the same device type"; \
+ TVM_FFI_CHECK((a).device().device_id == (b).device().device_id, ValueError) \
+ << #a " and " #b " must be on the same device"; \
+ } while (0)
+// [check_macros.end]
+
+// [get_stream.begin]
+// --- Stream helper ---
+inline cudaStream_t get_cuda_stream(DLDevice device) {
+ return static_cast<cudaStream_t>(TVMFFIEnvGetStream(device.device_type, device.device_id));
+}
+// [get_stream.end]
+
+// [alloc_tensor.begin]
+// --- Tensor allocation helper ---
+inline ffi::Tensor alloc_tensor(const ffi::Shape& shape, DLDataType dtype, DLDevice device) {
+ return ffi::Tensor::FromEnvAlloc(TVMFFIEnvTensorAlloc, shape, dtype, device);
+}
+// [alloc_tensor.end]
+
+// [dtype_constants.begin]
+// --- DLPack dtype constants ---
+constexpr DLDataType dl_float16 = DLDataType{kDLFloat, 16, 1};
+constexpr DLDataType dl_float32 = DLDataType{kDLFloat, 32, 1};
+constexpr DLDataType dl_float64 = DLDataType{kDLFloat, 64, 1};
+constexpr DLDataType dl_bfloat16 = DLDataType{kDLBfloat, 16, 1};
+// [dtype_constants.end]
+
+#endif // KERNEL_LIBRARY_TVM_FFI_UTILS_H_
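`CHECK_CONTIGUOUS` above delegates to `IsContiguous()`. To illustrate what such a check typically computes, here is a sketch of the conventional row-major (C-order) rule over shape and stride arrays, with strides counted in elements as DLPack does. `IsRowMajorContiguous` is a hypothetical name, and this diff does not show how TVM-FFI's own `IsContiguous` treats edge cases such as size-1 dimensions or empty tensors, so treat the sketch as an approximation rather than the library's implementation.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Conventional row-major (C-order) contiguity test. Strides are in elements,
// not bytes. Precondition: shape.size() == strides.size().
// Size-1 dimensions are skipped because their stride never changes an address,
// and a tensor with zero elements is treated as contiguous.
bool IsRowMajorContiguous(const std::vector<int64_t>& shape,
                          const std::vector<int64_t>& strides) {
  for (int64_t s : shape) {
    if (s == 0) return true;  // empty tensor: no element is ever addressed
  }
  int64_t expected = 1;  // stride a packed tensor would have at this axis
  for (std::size_t i = shape.size(); i-- > 0;) {  // innermost axis first
    if (shape[i] == 1) continue;
    if (strides[i] != expected) return false;
    expected *= shape[i];
  }
  return true;
}
```

Skipping size-1 dimensions is what lets a tensor of shape `(2, 1, 4)` with an arbitrary stride on the middle axis still count as contiguous: that stride is only ever multiplied by index 0.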
diff --git a/include/tvm/ffi/base_details.h b/include/tvm/ffi/base_details.h
index 7224ac1..acbd652 100644
--- a/include/tvm/ffi/base_details.h
+++ b/include/tvm/ffi/base_details.h
@@ -87,11 +87,9 @@
#define TVM_FFI_FUNC_SIG __func__
#endif
-#if defined(__GNUC__)
-// gcc and clang and attribute constructor
-/// \cond Doxygen_Suppress
-#define TVM_FFI_STATIC_INIT_BLOCK_DEF_(FnName) __attribute__((constructor)) static void FnName()
/// \endcond
+
+#if defined(TVM_FFI_DOXYGEN_MODE)
/*!
* \brief Macro that defines a block that will be called during static initialization.
*
@@ -101,12 +99,14 @@
* }
* \endcode
*/
+#define TVM_FFI_STATIC_INIT_BLOCK()
+#elif defined(__GNUC__)
+// gcc and clang: attribute constructor
+#define TVM_FFI_STATIC_INIT_BLOCK_DEF_(FnName) __attribute__((constructor)) static void FnName()
#define TVM_FFI_STATIC_INIT_BLOCK() \
TVM_FFI_STATIC_INIT_BLOCK_DEF_(TVM_FFI_STR_CONCAT(__TVMFFIStaticInitFunc, __COUNTER__))
-
#else
-/// \cond Doxygen_Suppress
-// for other compilers, use the variable trick
+// other compilers: use the variable trick
#define TVM_FFI_STATIC_INIT_BLOCK_DEF_(FnName, RegVar) \
static void FnName(); \
[[maybe_unused]] static inline int RegVar = []() { \
@@ -118,9 +118,9 @@
#define TVM_FFI_STATIC_INIT_BLOCK() \
TVM_FFI_STATIC_INIT_BLOCK_DEF_(TVM_FFI_STR_CONCAT(__TVMFFIStaticInitFunc, __COUNTER__), \
TVM_FFI_STR_CONCAT(__TVMFFIStaticInitReg, __COUNTER__))
-/// \endcond
#endif
+/// \cond Doxygen_Suppress
/*
* \brief Define the default copy/move constructor and assign operator
* \param TypeName The class typename.
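The `base_details.h` hunk keeps two implementations of `TVM_FFI_STATIC_INIT_BLOCK`: `__attribute__((constructor))` on GCC/Clang, and a "variable trick" for other compilers whose remaining lines fall outside the diff context. The self-contained C++17 sketch below shows how a variable trick of that shape typically works. It is not the verbatim TVM-FFI macro: `Registry`, the `SKETCH_*` macros, and the registered names are hypothetical, and `__COUNTER__` is a widely supported compiler extension rather than ISO C++.

```cpp
#include <string>
#include <vector>

// Hypothetical registry that static-init blocks append to. A function-local
// static is constructed on first use, so it is ready whenever a block runs.
inline std::vector<std::string>& Registry() {
  static std::vector<std::string> names;
  return names;
}

#define SKETCH_CONCAT_IMPL(a, b) a##b
#define SKETCH_CONCAT(a, b) SKETCH_CONCAT_IMPL(a, b)

// "Variable trick": declare the block's function, then define a namespace-scope
// variable whose initializer (an immediately invoked lambda) calls it during
// dynamic initialization. The trailing `static void FnName()` lets the macro
// user supply the function body in braces.
#define SKETCH_STATIC_INIT_BLOCK_DEF(FnName, RegVar) \
  static void FnName();                              \
  [[maybe_unused]] static inline int RegVar = []() { \
    FnName();                                        \
    return 0;                                        \
  }();                                               \
  static void FnName()

// __COUNTER__ (GCC/Clang/MSVC extension) gives each expansion unique names.
#define SKETCH_STATIC_INIT_BLOCK()                                          \
  SKETCH_STATIC_INIT_BLOCK_DEF(SKETCH_CONCAT(SketchInitFunc, __COUNTER__), \
                               SKETCH_CONCAT(SketchInitReg, __COUNTER__))

SKETCH_STATIC_INIT_BLOCK() { Registry().push_back("scale"); }
SKETCH_STATIC_INIT_BLOCK() { Registry().push_back("add_one"); }
```

Each expansion of `SKETCH_STATIC_INIT_BLOCK()` consumes two `__COUNTER__` values, so the generated function and variable names never collide across blocks in one translation unit; within that unit the blocks run in definition order.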