This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 2702f2b docs: add PaddlePaddle quickstart and load example (#415)
2702f2b is described below
commit 2702f2b08721b1ea45f2f7858e3be88b02378731
Author: Nyakku Shigure <[email protected]>
AuthorDate: Sun Jan 18 12:08:11 2026 +0800
docs: add PaddlePaddle quickstart and load example (#415)
PaddlePaddle's recently released version 3.3.0 includes full TVM FFI
support:
- DLPack 1.2 (or 1.3) support and many DLPack implementation fixes
- C DLPack exchange API (`__c_dlpack_exchange_api__` for tvm_ffi
0.1.0-0.1.4 / `__dlpack_c_exchange_api__` for tvm_ffi 0.1.5+)
- DataType exchange protocol (`__dlpack_data_type__`) and Device
exchange protocol (`__dlpack_device__`)
This PR documents and demonstrates PaddlePaddle's TVM FFI
interoperability introduced in PaddlePaddle 3.3.0.
The new example runs as expected:
<img width="1043" height="273" alt="image"
src="https://github.com/user-attachments/assets/80fe3239-3a7f-4de0-a725-cae0e27567b4"
/>
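For readers unfamiliar with these protocols, here is a minimal sketch of the interop they enable. It is illustrative only: `paddle.to_tensor`, `__dlpack_device__`, and `tvm_ffi.from_dlpack` are the documented entry points, while the `getattr` probe below is an assumption based on the symbol names listed above, not an official API contract.

```python
import paddle
import tvm_ffi

x = paddle.to_tensor([1.0, 2.0, 3.0])

# Device exchange protocol: reports the DLPack (device_type, device_id) pair.
print(x.__dlpack_device__())

# Probe for the C DLPack exchange API; the attribute name depends on the
# installed tvm_ffi version, as noted above (assumption, for illustration).
api = getattr(x, "__dlpack_c_exchange_api__", None) or getattr(
    x, "__c_dlpack_exchange_api__", None
)
print("C DLPack exchange API available:", api is not None)

# Zero-copy import into TVM FFI via the standard DLPack protocol.
t = tvm_ffi.from_dlpack(x)
print(t.shape, t.dtype)
```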
---
README.md | 2 +-
docs/concepts/tensor.rst | 4 ++--
docs/get_started/quickstart.rst | 20 ++++++++++++++----
docs/get_started/stable_c_abi.rst | 2 +-
examples/quickstart/README.md | 1 +
.../{run_all_cuda.sh => load/load_paddle.py} | 24 ++++++++++++----------
examples/quickstart/run_all_cuda.sh | 3 +++
7 files changed, 37 insertions(+), 19 deletions(-)
diff --git a/README.md b/README.md
index 3fbd2cc..0f2f22b 100644
--- a/README.md
+++ b/README.md
@@ -24,7 +24,7 @@ yet flexible open convention with the following systems in mind:
- **Kernel libraries** - ship one wheel to support multiple frameworks, Python versions, and different languages. [[FlashInfer](https://docs.flashinfer.ai/)]
- **Kernel DSLs** - reusable open ABI for JIT and AOT kernel exposure frameworks and runtimes. [[TileLang](https://tilelang.com/)][[cuteDSL](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/cute_dsl_general/compile_with_tvm_ffi.html)]
-- **Frameworks and runtimes** - a uniform extension point for ABI-compliant libraries and DSLs. [[PyTorch](https://tvm.apache.org/ffi/get_started/quickstart.html#ship-to-pytorch)][[JAX](https://tvm.apache.org/ffi/get_started/quickstart.html#ship-to-jax)][[NumPy/CuPy](https://tvm.apache.org/ffi/get_started/quickstart.html#ship-to-numpy)]
+- **Frameworks and runtimes** - a uniform extension point for ABI-compliant libraries and DSLs. [[PyTorch](https://tvm.apache.org/ffi/get_started/quickstart.html#ship-to-pytorch)][[JAX](https://tvm.apache.org/ffi/get_started/quickstart.html#ship-to-jax)][[PaddlePaddle](https://tvm.apache.org/ffi/get_started/quickstart.html#ship-to-paddle)][[NumPy/CuPy](https://tvm.apache.org/ffi/get_started/quickstart.html#ship-to-numpy)]
- **ML infrastructure** - out-of-box bindings and interop across languages. [[Python](https://tvm.apache.org/ffi/get_started/quickstart.html#ship-to-python)][[C++](https://tvm.apache.org/ffi/get_started/quickstart.html#ship-to-cpp)][[Rust](https://tvm.apache.org/ffi/get_started/quickstart.html#ship-to-rust)]
- **Coding agents** - a unified mechanism for shipping generated code in production.
diff --git a/docs/concepts/tensor.rst b/docs/concepts/tensor.rst
index d7f9343..2669f71 100644
--- a/docs/concepts/tensor.rst
+++ b/docs/concepts/tensor.rst
@@ -20,7 +20,7 @@ Tensor and DLPack
At runtime, TVM-FFI often needs to accept tensors from many sources:
-* Frameworks (e.g. PyTorch, JAX) via :py:meth:`array_api.array.__dlpack__`;
+* Frameworks (e.g. PyTorch, JAX, PaddlePaddle) via :py:meth:`array_api.array.__dlpack__`;
* C/C++ callers passing :c:struct:`DLTensor* <DLTensor>`;
* Tensors allocated by a library but managed by TVM-FFI itself.
@@ -115,7 +115,7 @@ PyTorch Interop
On the Python side, :py:class:`tvm_ffi.Tensor` is a managed n-dimensional array that:
-* can be created via :py:func:`tvm_ffi.from_dlpack(ext_tensor, ...) <tvm_ffi.from_dlpack>` to import tensors from external frameworks, e.g., :ref:`PyTorch <ship-to-pytorch>`, :ref:`JAX <ship-to-jax>`, :ref:`NumPy/CuPy <ship-to-numpy>`;
+* can be created via :py:func:`tvm_ffi.from_dlpack(ext_tensor, ...) <tvm_ffi.from_dlpack>` to import tensors from external frameworks, e.g., :ref:`PyTorch <ship-to-pytorch>`, :ref:`JAX <ship-to-jax>`, :ref:`PaddlePaddle <ship-to-paddle>`, :ref:`NumPy/CuPy <ship-to-numpy>`;
* implements the DLPack protocol so it can be passed back to frameworks without copying, e.g., :py:func:`torch.from_dlpack`.
The following example demonstrates a typical round-trip pattern:
diff --git a/docs/get_started/quickstart.rst b/docs/get_started/quickstart.rst
index 6d608e7..f4ded6a 100644
--- a/docs/get_started/quickstart.rst
+++ b/docs/get_started/quickstart.rst
@@ -27,7 +27,7 @@ This guide walks through shipping a minimal ``add_one`` function that computes
TVM-FFI's Open ABI and FFI make it possible to **ship one library** for multiple frameworks and languages.
We can build a single shared library that works across:
-- **ML frameworks**, e.g. PyTorch, JAX, NumPy, CuPy, and others;
+- **ML frameworks**, e.g. PyTorch, JAX, PaddlePaddle, NumPy, CuPy, and others;
- **Languages**, e.g. C++, Python, Rust, and others;
- **Python ABI versions**, e.g. one wheel that supports all Python versions, including free-threaded ones.
@@ -37,7 +37,7 @@ We can build a single shared library that works across:
- Python: 3.9 or newer
- Compiler: C++17-capable toolchain (GCC/Clang/MSVC)
- - Optional ML frameworks for testing: NumPy, PyTorch, JAX, CuPy
+ - Optional ML frameworks for testing: NumPy, PyTorch, JAX, CuPy, PaddlePaddle
- CUDA: Any modern version (if you want to try the CUDA part)
- TVM-FFI installed via:
@@ -90,7 +90,7 @@ it also exports the function's metadata as a symbol ``__tvm_ffi__metadata_add_on
The class :cpp:class:`tvm::ffi::TensorView` enables zero-copy interop with tensors from different ML frameworks:
- NumPy, CuPy,
-- PyTorch, JAX, or
+- PyTorch, JAX, PaddlePaddle, or
- any array type that supports the standard :external+data-api:doc:`DLPack protocol <design_topics/data_interchange>`.
Finally, :cpp:func:`TVMFFIEnvGetStream` can be used in the CUDA code to launch kernels on the caller's stream.
@@ -162,7 +162,7 @@ TVM-FFI integrates with CMake via ``find_package`` as demonstrated below:
- Python version/ABI. They are not compiled or linked with Python and depend only on TVM-FFI's stable C ABI;
- Languages, including C++, Python, Rust, or any other language that can interop with the C ABI;
-- ML frameworks, such as PyTorch, JAX, NumPy, CuPy, or any array library that implements the standard :external+data-api:doc:`DLPack protocol <design_topics/data_interchange>`.
+- ML frameworks, such as PyTorch, JAX, PaddlePaddle, NumPy, CuPy, or any array library that implements the standard :external+data-api:doc:`DLPack protocol <design_topics/data_interchange>`.
.. _sec-use-across-framework:
@@ -228,6 +228,18 @@ After installation, ``add_one_cuda`` can be registered as a target for JAX's ``f
)(x)
print(y)
+.. _ship-to-paddle:
+
+PaddlePaddle
+~~~~~~~~~~~~
+
+PaddlePaddle provides full TVM FFI support since version 3.3.0.
+
+.. literalinclude:: ../../examples/quickstart/load/load_paddle.py
+ :language: python
+ :start-after: [example.begin]
+ :end-before: [example.end]
+
.. _ship-to-numpy:
NumPy/CuPy
diff --git a/docs/get_started/stable_c_abi.rst b/docs/get_started/stable_c_abi.rst
index b8d8195..0f6dbd5 100644
--- a/docs/get_started/stable_c_abi.rst
+++ b/docs/get_started/stable_c_abi.rst
@@ -125,7 +125,7 @@ Stability and Interoperability
**Cross-language.** TVM-FFI implements this calling convention in multiple languages (C, C++, Python, Rust, ...), enabling code written in one language - or generated by a DSL targeting the ABI - to be called from another language.
-**Cross-framework.** TVM-FFI uses standard data structures such as :external+data-api:doc:`DLPack tensors <design_topics/data_interchange>` to represent arrays, so compiled functions can be used from any array framework that implements the DLPack protocol (NumPy, PyTorch, TensorFlow, CuPy, JAX, and others).
+**Cross-framework.** TVM-FFI uses standard data structures such as :external+data-api:doc:`DLPack tensors <design_topics/data_interchange>` to represent arrays, so compiled functions can be used from any array framework that implements the DLPack protocol (NumPy, PyTorch, TensorFlow, CuPy, JAX, PaddlePaddle, and others).
Stable ABI in C Code
diff --git a/examples/quickstart/README.md b/examples/quickstart/README.md
index 1c23eca..093ebb6 100644
--- a/examples/quickstart/README.md
+++ b/examples/quickstart/README.md
@@ -57,6 +57,7 @@ To run library loading examples across ML frameworks (requires CUDA for the CUDA
```bash
python load/load_pytorch.py
+python load/load_paddle.py
python load/load_numpy.py
python load/load_cupy.py
```
diff --git a/examples/quickstart/run_all_cuda.sh b/examples/quickstart/load/load_paddle.py
old mode 100755
new mode 100644
similarity index 67%
copy from examples/quickstart/run_all_cuda.sh
copy to examples/quickstart/load/load_paddle.py
index d27bf0b..1162e15
--- a/examples/quickstart/run_all_cuda.sh
+++ b/examples/quickstart/load/load_paddle.py
@@ -1,4 +1,3 @@
-#!/bin/bash
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -15,14 +14,17 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-set -ex
+# fmt: off
+# ruff: noqa
+# mypy: ignore-errors
+# [example.begin]
+# File: load/load_paddle.py
+import tvm_ffi
+mod = tvm_ffi.load_module("build/add_one_cuda.so")
-# To compile `compile/add_one_cuda.cu` to shared library `build/add_one_cuda.so`
-cmake . -B build -DEXAMPLE_NAME="compile_cuda" -DCMAKE_BUILD_TYPE=RelWithDebInfo
-cmake --build build --config RelWithDebInfo
-
-# To load and run `add_one_cuda.so` in PyTorch
-python load/load_pytorch.py
-
-# To load and run `add_one_cuda.so` in CuPy
-python load/load_cupy.py
+import paddle
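+# Create the input directly on the GPU so the CUDA kernel can operate on it.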
+x = paddle.to_tensor([1, 2, 3, 4, 5], dtype=paddle.float32, place=paddle.CUDAPlace(0))
+y = paddle.empty_like(x)
+mod.add_one_cuda(x, y)
+print(y)
+# [example.end]
diff --git a/examples/quickstart/run_all_cuda.sh b/examples/quickstart/run_all_cuda.sh
index d27bf0b..d8807a5 100755
--- a/examples/quickstart/run_all_cuda.sh
+++ b/examples/quickstart/run_all_cuda.sh
@@ -26,3 +26,6 @@ python load/load_pytorch.py
# To load and run `add_one_cuda.so` in CuPy
python load/load_cupy.py
+
+# To load and run `add_one_cuda.so` in PaddlePaddle
+python load/load_paddle.py