This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 8626100  doc: Polish quickstart doc and add Examples (#170)
8626100 is described below

commit 86261005000c8a3a912d953582ea654dfbcc3a38
Author: Junru Shao <[email protected]>
AuthorDate: Sat Oct 18 17:51:10 2025 -0700

    doc: Polish quickstart doc and add Examples (#170)
---
 .pre-commit-config.yaml                            |   5 +-
 docs/get_started/quickstart.rst                    | 308 +++++++--------------
 examples/quick_start/CMakeLists.txt                |  67 -----
 examples/quick_start/README.md                     |  91 ------
 examples/quick_start/run_example.py                |  93 -------
 examples/quick_start/src/add_one_c.c               |  72 -----
 examples/quick_start/src/run_example.cc            |  53 ----
 examples/quick_start/src/run_example_cuda.cc       |  94 -------
 examples/quickstart/CMakeLists.txt                 |  75 +++++
 examples/quickstart/README.md                      |  55 ++++
 .../src => quickstart/compile}/add_one_cpu.cc      |  28 +-
 .../src => quickstart/compile}/add_one_cuda.cu     |  36 +--
 examples/quickstart/load/load_cpp.cc               |  82 ++++++
 .../load/load_cupy.py}                             |  35 +--
 .../load/load_numpy.py}                            |  35 +--
 .../load/load_pytorch.py}                          |  37 +--
 examples/quickstart/raw_compile.sh                 |  61 ++++
 .../run_example.sh => quickstart/run_all_cpu.sh}   |  27 +-
 .../run_example.sh => quickstart/run_all_cuda.sh}  |  25 +-
 pyproject.toml                                     |   2 +
 20 files changed, 461 insertions(+), 820 deletions(-)
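
At a glance, the reworked quickstart reduces to the following flow. This is a condensed, illustrative Python sketch of what the new `raw_compile.sh` and `load/load_numpy.py` below do — it is not part of the commit, and paths assume the `examples/quickstart` layout:

```python
# Sketch only: condenses raw_compile.sh + load/load_numpy.py from this commit.
import subprocess
import numpy as np
import tvm_ffi

def cfg(flag: str) -> list[str]:
    # Mirrors `$(tvm-ffi-config --cxxflags)` etc. in raw_compile.sh
    return subprocess.check_output(["tvm-ffi-config", flag], text=True).split()

# Compile compile/add_one_cpu.cc into build/add_one_cpu.so
subprocess.check_call(
    ["g++", "-shared", "-O3", "compile/add_one_cpu.cc",
     "-fPIC", "-fvisibility=hidden",
     *cfg("--cxxflags"), *cfg("--ldflags"), *cfg("--libs"),
     "-o", "build/add_one_cpu.so"])

# Load the library and call the exported function on a NumPy array
mod = tvm_ffi.load_module("build/add_one_cpu.so")
x = np.array([1, 2, 3, 4, 5], dtype=np.float32)
y = np.empty_like(x)
mod.add_one_cpu(x, y)
print(y)  # [2. 3. 4. 5. 6.]
```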

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 8737cff..be42c58 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -15,10 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 
-# TODO(@junrushao): adding a few extra hooks:
-# - Python type checking via mypy or ty
-# - CMake linters
-# - Conventional commits
 default_install_hook_types:
   - pre-commit
 repos:
@@ -96,6 +92,7 @@ repos:
     rev: v3.12.0-2
     hooks:
       - id: shfmt
+        args: [--indent=2]
   - repo: https://github.com/shellcheck-py/shellcheck-py
     rev: v0.10.0.1
     hooks:
diff --git a/docs/get_started/quickstart.rst b/docs/get_started/quickstart.rst
index 3eb02d1..6f1ce00 100644
--- a/docs/get_started/quickstart.rst
+++ b/docs/get_started/quickstart.rst
@@ -18,14 +18,18 @@
 Quick Start
 ===========
 
+.. note::
+
+  All the code in this tutorial can be found under `examples/quickstart <https://github.com/apache/tvm-ffi/tree/main/examples/quickstart>`_ in the repository.
+
 This guide walks through shipping a minimal ``add_one`` function that computes
 ``y = x + 1`` in C++ and CUDA.
 TVM-FFI's Open ABI and FFI make it possible to **ship one library** for multiple frameworks and languages.
 We can build a single shared library that works across:
 
 - **ML frameworks**, e.g. PyTorch, JAX, NumPy, CuPy, etc., and
-- **languages**, e.g. C++, Python, Rust, etc.
-- **Python ABI versions**, e.g. ship one wheel to support multiple Python versions, including free-threaded Python.
+- **Languages**, e.g. C++, Python, Rust, etc.,
+- **Python ABI versions**, e.g. ship one wheel to support all Python versions, including free-threaded ones.
 
 .. admonition:: Prerequisite
    :class: hint
@@ -56,66 +60,23 @@ Suppose we implement a C++ function ``AddOne`` that performs elementwise ``y = x
 
   .. group-tab:: C++
 
-    .. code-block:: cpp
+    .. literalinclude:: ../../examples/quickstart/compile/add_one_cpu.cc
+      :language: cpp
       :emphasize-lines: 8, 17
-
-      // File: add_one_cpu.cc
-      #include <tvm/ffi/container/tensor.h>
-      #include <tvm/ffi/function.h>
-
-      namespace tvm_ffi_example_cpp {
-
-      /*! \brief Perform vector add one: y = x + 1 (1-D float32) */
-      void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
-        int64_t n = x.shape()[0];
-        float* x_data = static_cast<float *>(x.data_ptr());
-        float* y_data = static_cast<float *>(y.data_ptr());
-        for (int64_t i = 0; i < n; ++i) {
-          y_data[i] = x_data[i] + 1;
-        }
-      }
-
-      TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cpu, tvm_ffi_example_cpp::AddOne);
-      }
-
+      :start-after: [example.begin]
+      :end-before: [example.end]
 
   .. group-tab:: CUDA
 
-    .. code-block:: cpp
+    .. literalinclude:: ../../examples/quickstart/compile/add_one_cuda.cu
+      :language: cpp
       :emphasize-lines: 15, 22, 26
-
-      // File: main.cu
-      #include <tvm/ffi/container/tensor.h>
-      #include <tvm/ffi/extra/c_env_api.h>
-      #include <tvm/ffi/function.h>
-
-      namespace tvm_ffi_example_cuda {
-
-      __global__ void AddOneKernel(float* x, float* y, int n) {
-        int idx = blockIdx.x * blockDim.x + threadIdx.x;
-        if (idx < n) {
-          y[idx] = x[idx] + 1;
-        }
-      }
-
-      void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
-        int64_t n = x.shape()[0];
-        float* x_data = static_cast<float *>(x.data_ptr());
-        float* y_data = static_cast<float *>(y.data_ptr());
-        int64_t threads = 256;
-        int64_t blocks = (n + threads - 1) / threads;
-        cudaStream_t stream = static_cast<cudaStream_t>(
-          TVMFFIEnvGetStream(x.device().device_type, x.device().device_id));
-        AddOneKernel<<<blocks, threads, 0, stream>>>(x_data, y_data, n);
-      }
-
-      TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cuda, tvm_ffi_example_cuda::AddOne);
-      }
-
+      :start-after: [example.begin]
+      :end-before: [example.end]
 
 
 The macro :c:macro:`TVM_FFI_DLL_EXPORT_TYPED_FUNC` exports the C++ function ``AddOne``
-as a TVM FFI compatible symbol with the name ``add_one_cpu`` or ``add_one_cuda`` in the resulting library.
+as a TVM FFI compatible symbol with the name ``__tvm_ffi_add_one_cpu/cuda`` in the resulting library.
 
 The class :cpp:class:`tvm::ffi::TensorView` allows zero-copy interop with tensors from different ML frameworks:
 
@@ -136,31 +97,28 @@ Compile with TVM-FFI
 
   .. group-tab:: C++
 
-    .. code-block:: bash
-
-      g++ -shared -O3 add_one_cpu.cc                   \
-          -fPIC -fvisibility=hidden             \
-          `tvm-ffi-config --cxxflags`           \
-          `tvm-ffi-config --ldflags`            \
-          `tvm-ffi-config --libs`               \
-          -o add_one_cpu.so
+    .. literalinclude:: ../../examples/quickstart/raw_compile.sh
+      :language: bash
+      :start-after: [cpp_compile.begin]
+      :end-before: [cpp_compile.end]
 
   .. group-tab:: CUDA
 
-    .. code-block:: bash
-
-      nvcc -shared -O3 add_one_cuda.cu                  \
-        --compiler-options -fPIC                \
-        --compiler-options -fvisibility=hidden  \
-        `tvm-ffi-config --cxxflags`             \
-        `tvm-ffi-config --ldflags`              \
-        `tvm-ffi-config --libs`                 \
-        -o add_one_cuda.so
+    .. literalinclude:: ../../examples/quickstart/raw_compile.sh
+      :language: bash
+      :start-after: [cuda_compile.begin]
+      :end-before: [cuda_compile.end]
 
 This step produces shared libraries ``add_one_cpu.so`` and ``add_one_cuda.so`` that can be used across languages and frameworks.
 
-**CMake.** As the preferred approach for building across platforms,
-CMake relies on the CMake package ``tvm_ffi``, which can be found via ``tvm-ffi-config --cmakedir``.
+.. hint::
+
+   For a single-file C++/CUDA project, a convenient method :py:func:`tvm_ffi.cpp.load_inline`
+   is provided to minimize boilerplate code in compilation, linking, and loading.
+
+
+**CMake.** CMake is the preferred approach for building across platforms.
+TVM-FFI natively integrates with CMake via ``find_package`` as demonstrated below:
 
 .. tabs::
 
@@ -168,51 +126,36 @@ CMake relies on the CMake package ``tvm_ffi``, which can be found via ``tvm-ffi-
 
     .. code-block:: cmake
 
+      # Run `tvm-ffi-config --cmakedir` to set `tvm_ffi_DIR`
       find_package(Python COMPONENTS Interpreter REQUIRED)
-      # Run `tvm_ffi.config --cmakedir` to find tvm-ffi targets
-      execute_process(
-        COMMAND "${Python_EXECUTABLE}" -m tvm_ffi.config --cmakedir
-        OUTPUT_STRIP_TRAILING_WHITESPACE
-        OUTPUT_VARIABLE tvm_ffi_ROOT
-      )
+      execute_process(COMMAND "${Python_EXECUTABLE}" -m tvm_ffi.config --cmakedir OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE tvm_ffi_ROOT)
       find_package(tvm_ffi CONFIG REQUIRED)
-      # Create C++ target `add_one_cpu`
-      add_library(add_one_cpu SHARED add_one_cpu.cc)
+
+      # Link C++ target to `tvm_ffi_header` and `tvm_ffi_shared`
+      add_library(add_one_cpu SHARED compile/add_one_cpu.cc)
       target_link_libraries(add_one_cpu PRIVATE tvm_ffi_header)
       target_link_libraries(add_one_cpu PRIVATE tvm_ffi_shared)
-      # show as add_one_cpu.so
-      set_target_properties(add_one_cpu PROPERTIES PREFIX "" SUFFIX ".so")
 
   .. group-tab:: CUDA
 
     .. code-block:: cmake
 
+      enable_language(CUDA)
+      # Run `tvm-ffi-config --cmakedir` to set `tvm_ffi_DIR`
       find_package(Python COMPONENTS Interpreter REQUIRED)
-      # Run `tvm_ffi.config --cmakedir` to find tvm-ffi targets
-      execute_process(
-        COMMAND "${Python_EXECUTABLE}" -m tvm_ffi.config --cmakedir
-        OUTPUT_STRIP_TRAILING_WHITESPACE
-        OUTPUT_VARIABLE tvm_ffi_ROOT
-      )
+      execute_process(COMMAND "${Python_EXECUTABLE}" -m tvm_ffi.config --cmakedir OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE tvm_ffi_ROOT)
       find_package(tvm_ffi CONFIG REQUIRED)
-      # Create C++ target `add_one_cuda`
-      enable_language(CUDA)
-      add_library(add_one_cuda SHARED add_one_cuda.cu)
+
+      # Link CUDA target to `tvm_ffi_header` and `tvm_ffi_shared`
+      add_library(add_one_cuda SHARED compile/add_one_cuda.cu)
       target_link_libraries(add_one_cuda PRIVATE tvm_ffi_header)
       target_link_libraries(add_one_cuda PRIVATE tvm_ffi_shared)
-      # show as add_one_cuda.so
-      set_target_properties(add_one_cuda PROPERTIES PREFIX "" SUFFIX ".so")
 
-.. hint::
+**Artifact.** The resulting ``add_one_cpu.so`` and ``add_one_cuda.so`` are minimal libraries that are agnostic to:
 
-   For a single-file C++/CUDA project, a convenient method :py:func:`tvm_ffi.cpp.load_inline`
-   is provided to minimize boilerplate code in compilation, linking, and loading.
-
-The resulting ``add_one_cpu.so`` and ``add_one_cuda.so`` are minimal libraries that are agnostic to:
-
-- Python version/ABI, because it is pure C++ and not compiled or linked against Python
-- C++ ABI, because TVM-FFI interacts with the artifact only via stable C APIs
-- Languages, which can be C++, Rust or Python.
+- Python version/ABI. It is not compiled/linked with Python and depends only on TVM-FFI's stable C ABI;
+- Languages, including C++, Python, Rust, or any other language that can interop with the C ABI;
+- ML frameworks, such as PyTorch, JAX, NumPy, CuPy, or anything with the standard `DLPack protocol <https://data-apis.org/array-api/2024.12/design_topics/data_interchange.html>`_.
 
 .. _sec-use-across-framework:
 
@@ -229,82 +172,72 @@ the ``add_one_cpu.so`` or ``add_one_cuda.so`` into :py:class:`tvm_ffi.Module`.
    func : tvm_ffi.Function = mod.add_one_cpu
 
 ``mod.add_one_cpu`` retrieves a callable :py:class:`tvm_ffi.Function` that accepts tensors from host frameworks
-directly, which can be zero-copy incorporated into all popular ML frameworks. This process is done seamlessly
-without any boilerplate code and with extremely low latency.
-We can then use these functions in the following ways:
+directly. This process is zero-copy, without any boilerplate code, and with extremely low latency.
 
+We can then use these functions in the following ways:
 
 .. tab-set::
 
     .. tab-item:: PyTorch
 
-        .. code-block:: python
+        .. literalinclude:: ../../examples/quickstart/load/load_pytorch.py
+          :language: python
+          :start-after: [example.begin]
+          :end-before: [example.end]
 
-          import torch
-          # cpu also works by changing the module to add_one_cpu.so and device to "cpu"
-          mod = tvm_ffi.load_module("add_one_cuda.so")
-          device = "cuda"
-          x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32, device=device)
-          y = torch.empty_like(x)
-          mod.add_one_cuda(x, y)
-          print(y)
+    .. tab-item:: JAX
 
+        Support via `nvidia/jax-tvm-ffi <https://github.com/nvidia/jax-tvm-ffi>`_. This can be installed via
 
-    .. tab-item:: JAX
+        .. code-block:: bash
+
+          pip install jax-tvm-ffi
 
-        Support via `jax-tvm-ffi <https://github.com/nvidia/jax-tvm-ffi>`_
+        After installation, ``add_one_cuda`` can be registered as a target for JAX's ``ffi_call``.
 
         .. code-block:: python
 
-          import jax
-          import jax.numpy as jnp
-          import jax_tvm_ffi
+          # Step 1. Load `build/add_one_cuda.so`
           import tvm_ffi
+          mod = tvm_ffi.load_module("build/add_one_cuda.so")
 
-          mod = tvm_ffi.load_module("add_one_cuda.so")
+          # Step 2. Register `mod.add_one_cuda` into JAX
+          import jax_tvm_ffi
+          jax_tvm_ffi.register_ffi_target("add_one", mod.add_one_cuda, platform="gpu")
 
-          # Register the function with JAX
-          jax_tvm_ffi.register_ffi_target("add_one_cuda", mod.add_one_cuda, platform="cuda")
-          x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
+          # Step 3. Run `mod.add_one_cuda` with JAX
+          import jax
+          import jax.numpy as jnp
+          jax_device, *_ = jax.devices("gpu")
+          x = jnp.array([1, 2, 3, 4, 5], dtype=jnp.float32, device=jax_device)
           y = jax.ffi.ffi_call(
-              "add_one_cuda",
-              jax.ShapeDtypeStruct(x.shape, x.dtype),
+              "add_one",  # name of the registered function
+              jax.ShapeDtypeStruct(x.shape, x.dtype),  # shape and dtype of the output
               vmap_method="broadcast_all",
           )(x)
           print(y)
 
-    .. tab-item:: NumPy (CPU)
-
-        .. code-block:: python
-
-          import numpy as np
-
-          mod = tvm_ffi.load_module("add_one_cpu.so")
-          x = np.array([1, 2, 3, 4, 5], dtype=np.float32)
-          y = np.empty_like(x)
-          mod.add_one_cpu(x, y)
-          print(y)
-
-    .. tab-item:: CuPy (CUDA)
+    .. tab-item:: NumPy
 
-        .. code-block:: python
+        .. literalinclude:: ../../examples/quickstart/load/load_numpy.py
+          :language: python
+          :start-after: [example.begin]
+          :end-before: [example.end]
 
-          import cupy as cp
+    .. tab-item:: CuPy
 
-          mod = tvm_ffi.load_module("add_one_cuda.so")
-          x = cp.array([1, 2, 3, 4, 5], dtype=cp.float32)
-          y = cp.empty_like(x)
-          mod.add_one_cuda(x, y)
-          print(y)
+        .. literalinclude:: ../../examples/quickstart/load/load_cupy.py
+          :language: python
+          :start-after: [example.begin]
+          :end-before: [example.end]
 
 
 Ship Across Languages
 ---------------------
 
 TVM-FFI's core loading mechanism is ABI stable and works across language boundaries.
-A single artifact can be loaded in every language TVM-FFI supports,
-without having to recompile different artifacts targeting different ABIs or languages.
-
+A single library can be loaded in every language TVM-FFI supports,
+without having to recompile different libraries targeting different ABIs or languages.
 
 Python
 ~~~~~~
@@ -319,79 +252,42 @@ C++
 TVM-FFI's C++ API :cpp:func:`tvm::ffi::Module::LoadFromFile` loads ``add_one_cpu.so`` or ``add_one_cuda.so`` and
 can be used directly in C/C++ with no Python dependency.
 
-.. code-block:: cpp
-
-  // File: run_example.cc
-  #include <tvm/ffi/container/tensor.h>
-  #include <tvm/ffi/extra/module.h>
-
-  namespace ffi = tvm::ffi;
-  struct CPUNDAlloc {
-    void AllocData(DLTensor* tensor) { tensor->data = malloc(ffi::GetDataSize(*tensor)); }
-    void FreeData(DLTensor* tensor) { free(tensor->data); }
-  };
-
-  inline ffi::Tensor Empty(ffi::Shape shape, DLDataType dtype, DLDevice device) {
-    return ffi::Tensor::FromNDAlloc(CPUNDAlloc(), shape, dtype, device);
-  }
+.. literalinclude:: ../../examples/quickstart/load/load_cpp.cc
+   :language: cpp
+   :start-after: [example.begin]
+   :end-before: [example.end]
 
-  int main() {
-    // load the module
-    ffi::Module mod = ffi::Module::LoadFromFile("add_one_cpu.so");
+Compile and run it with:
 
-    // create an Tensor, alternatively, one can directly pass in a DLTensor*
-    ffi::Tensor x = Empty({5}, DLDataType({kDLFloat, 32, 1}), DLDevice({kDLCPU, 0}));
-    for (int i = 0; i < 5; ++i) {
-      reinterpret_cast<float*>(x.data_ptr())[i] = static_cast<float>(i);
-    }
-
-    ffi::Function add_one_cpu = mod->GetFunction("add_one_cpu").value();
-    add_one_cpu(x, x);
-
-    std::cout << "x after add_one_cpu(x, x)" << std::endl;
-    for (int i = 0; i < 5; ++i) {
-      std::cout << reinterpret_cast<float*>(x.data_ptr())[i] << " ";
-    }
-    std::cout << std::endl;
-    return 0;
-  }
-
-Compile it with:
+.. literalinclude:: ../../examples/quickstart/raw_compile.sh
+   :language: bash
+   :start-after: [load_cpp.begin]
+   :end-before: [load_cpp.end]
 
-.. code-block:: bash
+.. note::
 
-    g++ -fvisibility=hidden -O3               \
-        run_example.cc                        \
-        `tvm-ffi-config --cxxflags`           \
-        `tvm-ffi-config --ldflags`            \
-        `tvm-ffi-config --libs`               \
-        -Wl,-rpath,`tvm-ffi-config --libdir`  \
-        -o run_example
+  Don't like loading shared libraries? Static linking is also supported.
 
-    ./run_example
-
-.. hint::
-
-  Sometimes it may be desirable to directly bundle the exported module into the same binary as the main program.
   In such cases, we can use :cpp:func:`tvm::ffi::Function::FromExternC` to create a
   :cpp:class:`tvm::ffi::Function` from the exported symbol, or directly use
-  :cpp:func:`tvm::ffi::Function::InvokeExternC` to invoke the function. This feature can be useful
-  when the exported module is generated by another DSL compiler matching the ABI.
+  :cpp:func:`tvm::ffi::Function::InvokeExternC` to invoke the function.
+
+  This feature can be useful on iOS, or when the exported module is generated by another DSL compiler matching the ABI.
 
   .. code-block:: cpp
 
-      // File: test_bundle.cc, link with libmain.o
+      // Linked with `add_one_cpu.o` or `add_one_cuda.o`
       #include <tvm/ffi/function.h>
       #include <tvm/ffi/container/tensor.h>
 
       // declare reference to the exported symbol
-      extern "C" int __tvm_ffi_add_one(void*, const TVMFFIAny*, int32_t, TVMFFIAny*);
+      extern "C" int __tvm_ffi_add_one_cpu(void*, const TVMFFIAny*, int32_t, TVMFFIAny*);
 
       namespace ffi = tvm::ffi;
 
       int bundle_add_one(ffi::TensorView x, ffi::TensorView y) {
         void* closure_handle = nullptr;
-        ffi::Function::InvokeExternC(closure_handle, __tvm_ffi_add_one, x, y);
+        ffi::Function::InvokeExternC(closure_handle, __tvm_ffi_add_one_cpu, x, y);
         return 0;
       }
 
@@ -406,8 +302,8 @@ This procedure is identical to those in C++ and Python:
 
     fn run_add_one(x: &Tensor, y: &Tensor) -> Result<()> {
         let module = tvm_ffi::Module::load_from_file("add_one_cpu.so")?;
-        let fn = module.get_function("add_one_cpu")?;
-        let typed_fn = into_typed_fn!(fn, Fn(&Tensor, &Tensor) -> Result<()>);
+        let func = module.get_function("add_one_cpu")?;
+        let typed_fn = into_typed_fn!(func, Fn(&Tensor, &Tensor) -> Result<()>);
         typed_fn(x, y)?;
         Ok(())
     }
@@ -423,5 +319,5 @@ Troubleshooting
 ---------------
 
 - ``OSError: cannot open shared object file``: Add an rpath (Linux/macOS) or ensure the DLL is on ``PATH`` (Windows). Example run-path: ``-Wl,-rpath,`tvm-ffi-config --libdir```.
-- ``undefined symbol: __tvm_ffi_add_one``: Ensure you used ``TVM_FFI_DLL_EXPORT_TYPED_FUNC`` and compiled with default symbol visibility (``-fvisibility=hidden`` is fine; the macro ensures export).
+- ``undefined symbol: __tvm_ffi_add_one_cpu``: Ensure you used :c:macro:`TVM_FFI_DLL_EXPORT_TYPED_FUNC` and compiled with default symbol visibility (``-fvisibility=hidden`` is fine; the macro ensures export).
 - ``CUDA error: invalid device function``: Rebuild with the correct ``-arch=sm_XX`` for your GPU, or include multiple ``-gencode`` entries.
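
The quickstart above references :py:func:`tvm_ffi.cpp.load_inline` twice without showing it. A rough sketch of how such a call might look follows — note that the keyword names `name` and `cpp_sources` are assumptions modeled on `torch.utils.cpp_extension.load_inline`, not confirmed by this diff; consult the `tvm_ffi.cpp` API docs for the actual signature:

```python
# Hypothetical sketch: parameter names (`name`, `cpp_sources`) are assumed,
# not taken from this commit; check the tvm_ffi.cpp docs for the real API.
import numpy as np
import tvm_ffi.cpp

mod = tvm_ffi.cpp.load_inline(
    name="add_one_inline",
    cpp_sources=r"""
        #include <tvm/ffi/container/tensor.h>
        #include <tvm/ffi/function.h>
        void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
          for (int64_t i = 0; i < x.size(0); ++i)
            static_cast<float*>(y.data_ptr())[i] =
                static_cast<float*>(x.data_ptr())[i] + 1;
        }
        TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one, AddOne);
    """,
)

x = np.array([1, 2, 3], dtype=np.float32)
y = np.empty_like(x)
mod.add_one(x, y)  # compile, link, and load handled in one call
```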
diff --git a/examples/quick_start/CMakeLists.txt b/examples/quick_start/CMakeLists.txt
deleted file mode 100644
index 0ba983d..0000000
--- a/examples/quick_start/CMakeLists.txt
+++ /dev/null
@@ -1,67 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-cmake_minimum_required(VERSION 3.18)
-project(tvm_ffi_example)
-
-# Discover the Python interpreter so we can query tvm-ffi for its CMake package path.
-find_package(
-  Python
-  COMPONENTS Interpreter
-  REQUIRED
-)
-
-# Run `tvm_ffi.config --cmakedir` to find tvm-ffi targets
-execute_process(
-  COMMAND "${Python_EXECUTABLE}" -m tvm_ffi.config --cmakedir
-  OUTPUT_STRIP_TRAILING_WHITESPACE
-  OUTPUT_VARIABLE tvm_ffi_ROOT
-)
-
-find_package(tvm_ffi CONFIG REQUIRED)
-
-# Build the CPU and C versions of the simple "add one" function that the examples call.
-add_library(add_one_cpu SHARED src/add_one_cpu.cc)
-add_library(add_one_c SHARED src/add_one_c.c)
-target_link_libraries(add_one_cpu tvm_ffi_header)
-target_link_libraries(add_one_cpu tvm_ffi_shared)
-target_link_libraries(add_one_c tvm_ffi_shared)
-# show as add_one_cpu.so
-set_target_properties(add_one_cpu PROPERTIES PREFIX "" SUFFIX ".so")
-set_target_properties(add_one_c PROPERTIES PREFIX "" SUFFIX ".so")
-
-# Optionally build the CUDA variant if the CUDA toolkit is present.
-if (NOT WIN32)
-  find_package(CUDAToolkit QUIET)
-  if (CUDAToolkit_FOUND)
-    enable_language(CUDA)
-
-    add_library(add_one_cuda SHARED src/add_one_cuda.cu)
-    target_link_libraries(add_one_cuda PRIVATE tvm_ffi_shared)
-
-    set_target_properties(add_one_cuda PROPERTIES PREFIX "" SUFFIX ".so")
-
-    add_executable(run_example_cuda src/run_example_cuda.cc)
-    set_target_properties(run_example_cuda PROPERTIES CXX_STANDARD 17)
-    target_link_libraries(run_example_cuda PRIVATE tvm_ffi_shared CUDA::cudart)
-  endif ()
-endif ()
-
-# CPU-only C++ driver used in the quick start guide.
-add_executable(run_example src/run_example.cc)
-set_target_properties(run_example PROPERTIES CXX_STANDARD 17)
-target_link_libraries(run_example tvm_ffi_shared)
diff --git a/examples/quick_start/README.md b/examples/quick_start/README.md
deleted file mode 100644
index 14e62e8..0000000
--- a/examples/quick_start/README.md
+++ /dev/null
@@ -1,91 +0,0 @@
-<!--- Licensed to the Apache Software Foundation (ASF) under one -->
-<!--- or more contributor license agreements.  See the NOTICE file -->
-<!--- distributed with this work for additional information -->
-<!--- regarding copyright ownership.  The ASF licenses this file -->
-<!--- to you under the Apache License, Version 2.0 (the -->
-<!--- "License"); you may not use this file except in compliance -->
-<!--- with the License.  You may obtain a copy of the License at -->
-
-<!---   http://www.apache.org/licenses/LICENSE-2.0 -->
-
-<!--- Unless required by applicable law or agreed to in writing, -->
-<!--- software distributed under the License is distributed on an -->
-<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
-<!--- KIND, either express or implied.  See the License for the -->
-<!--- specific language governing permissions and limitations -->
-<!--- under the License. -->
-
-# Getting Started with TVM FFI
-
-This example demonstrates how to use tvm-ffi to expose a universal function
-that can be loaded in different environments.
-
-The example implements a simple "add one" operation that adds 1 to each element
-of an input tensor, showing how to create C++ functions callable from Python.
-
-## Prerequisites
-
-Before running the quick start, ensure you have:
-
-- tvm-ffi installed locally (editable installs are convenient while iterating):
-- Installation guide: [Installation guide](https://tvm.apache.org/ffi/get_started/install.html)
-
-```bash
-# From the quick_start directory
-# install and include test dependency(this will install torch and numpy)
-pip install -ve "../..[test]"
-```
-
-## Run the Quick Start
-
-From `examples/quick_start` you can build and run everything with the helper script:
-
-```bash
-cd examples/quick_start
-./run_example.sh
-```
-
-The script picks an available CMake generator (preferring Ninja), configures a build in `build/`, compiles the C++ libraries and examples,
-and finally runs the Python and C++ demos. If the CUDA toolkit is detected it will also build and execute `run_example_cuda`.
-
-If you prefer to drive the build manually, run the following instead:
-
-```bash
-# configure (omit -G Ninja if Ninja is not installed)
-cmake -G Ninja -B build -S .
-
-# compile the example targets
-cmake --build build --parallel
-
-# run the demos
-python run_example.py
-./build/run_example
-./build/run_example_cuda  # optional, requires CUDA toolkit
-```
-
-At a high level, the `TVM_FFI_DLL_EXPORT_TYPED_FUNC` macro helps to expose
-a C++ function into the TVM FFI C ABI convention for functions.
-Then the function can be accessed by different environments and languages
-that interface with the TVM FFI. The current example shows how to do so
-in Python and C++.
-
-## Key Files
-
-- `src/add_one_cpu.cc` - CPU implementation of the add_one function
-- `src/add_one_c.c` - C implementation showing the C ABI workflow
-- `src/add_one_cuda.cu` - CUDA implementation for GPU operations
-- `src/run_example.cc` - C++ example showing how to call the functions
-- `src/run_example_cuda.cc` - C++ example showing how to call the CUDA functions
-- `run_example.py` - Python example showing how to call the functions
-- `run_example.sh` - Convenience script that builds and runs all examples
-
-## Compile without CMake
-
-You can also compile the modules directly using
-flags provided by the `tvm-ffi-config` tool.
-
-```bash
-gcc -shared -fPIC `tvm-ffi-config --cflags`  \
-    src/add_one_c.c -o build/add_one_c.so \
-    `tvm-ffi-config --ldflags` `tvm-ffi-config --libs`
-```
diff --git a/examples/quick_start/run_example.py b/examples/quick_start/run_example.py
deleted file mode 100644
index 65c188d..0000000
--- a/examples/quick_start/run_example.py
+++ /dev/null
@@ -1,93 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Quick start script to run tvm-ffi examples from prebuilt libraries."""
-
-import numpy
-import torch
-import tvm_ffi
-
-
-def run_add_one_cpu() -> None:
-    """Load the add_one_cpu module and call the add_one_cpu function."""
-    mod = tvm_ffi.load_module("build/add_one_cpu.so")
-
-    x = numpy.array([1, 2, 3, 4, 5], dtype=numpy.float32)
-    y = numpy.empty_like(x)
-    # tvm-ffi automatically handles DLPack compatible tensors
-    # torch tensors can be viewed as ffi::TensorView
-    # in the background
-    mod.add_one_cpu(x, y)
-    print("numpy.result after add_one(x, y)")
-    print(x)
-
-    x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
-    y = torch.empty_like(x)
-    # tvm-ffi automatically handles DLPack compatible tensors
-    # torch tensors can be viewed as ffi::TensorView
-    # in the background
-    mod.add_one_cpu(x, y)
-    print("torch.result after add_one(x, y)")
-    print(y)
-
-
-def run_add_one_c() -> None:
-    """Load the add_one_c module and call the add_one_c function."""
-    mod = tvm_ffi.load_module("build/add_one_c.so")
-
-    x = numpy.array([1, 2, 3, 4, 5], dtype=numpy.float32)
-    y = numpy.empty_like(x)
-    mod.add_one_c(x, y)
-    print("numpy.result after add_one_c(x, y)")
-    print(x)
-
-    x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
-    y = torch.empty_like(x)
-    mod.add_one_c(x, y)
-    print("torch.result after add_one_c(x, y)")
-    print(y)
-
-
-def run_add_one_cuda() -> None:
-    """Load the add_one_cuda module and call the add_one_cuda function."""
-    if not torch.cuda.is_available():
-        return
-
-    mod = tvm_ffi.load_module("build/add_one_cuda.so")
-    x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32, device="cuda")
-    y = torch.empty_like(x)
-
-    stream = torch.cuda.Stream()
-    with torch.cuda.stream(stream):
-        # tvm-ffi automatically handles DLPack compatible tensors
-        # it also handles interactions with torch runtime
-        # torch.cuda.current_stream() will be set and available via TVMFFIEnvGetStream
-        # when calling the function
-        mod.add_one_cuda(x, y)
-    stream.synchronize()
-    print("torch.result after mod.add_one_cuda(x, y)")
-    print(y)
-
-
-def main() -> None:
-    """Run the quick start example."""
-    run_add_one_cpu()
-    run_add_one_c()
-    run_add_one_cuda()
-
-
-if __name__ == "__main__":
-    main()
diff --git a/examples/quick_start/src/add_one_c.c b/examples/quick_start/src/add_one_c.c
deleted file mode 100644
index 9997027..0000000
--- a/examples/quick_start/src/add_one_c.c
+++ /dev/null
@@ -1,72 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-#include <tvm/ffi/c_api.h>
-#include <tvm/ffi/extra/c_env_api.h>
-
-// This is a raw C variant of the add_one_cpu function
-// it is used to demonstrate how low-level mechanism works
-// to construct a tvm ffi compatible function
-//
-// This function can also serve as a reference for how to implement
-// a compiler codegen to target tvm ffi
-//
-// if you are looking for a more high-level way to construct a tvm ffi compatible function,
-// please refer to the add_one_cpu.cc instead
-/*!
- * \brief Helper code to read DLTensor from TVMFFIAny, can be inlined into generated code
- * \param value The TVMFFIAny to read from
- * \param out The DLTensor to read into
- * \return 0 on success, -1 on error
- */
-int ReadDLTensorPtr(const TVMFFIAny* value, DLTensor** out) {
-  if (value->type_index == kTVMFFIDLTensorPtr) {
-    *out = (DLTensor*)(value->v_ptr);
-    return 0;
-  }
-  if (value->type_index != kTVMFFITensor) {
-    // Use TVMFFIErrorSetRaisedFromCStr or TVMFFIErrorSetRaisedFromCStrParts to set an
-    // error which will be propagated to the caller
-    TVMFFIErrorSetRaisedFromCStr("ValueError", "Expects a Tensor input");
-    return -1;
-  }
-  *out = (DLTensor*)((char*)(value->v_obj) + sizeof(TVMFFIObject));
-  return 0;
-}
-
-// FFI function implementing add_one operation
-int __tvm_ffi_add_one_c(                                                      //
-    void* handle, const TVMFFIAny* args, int32_t num_args, TVMFFIAny* result  //
-) {
-  DLTensor *x, *y;
-  // Extract tensor arguments
-  // return -1 for error, error is set through TVMFFIErrorSetRaisedFromCStr
-  if (ReadDLTensorPtr(&args[0], &x) == -1) return -1;
-  if (ReadDLTensorPtr(&args[1], &y) == -1) return -1;
-
-  // Get current stream for device synchronization (e.g., CUDA)
-  // not needed for CPU, just keep here for demonstration purpose
-  void* stream = TVMFFIEnvGetStream(x->device.device_type, x->device.device_id);
-
-  // perform the actual operation
-  for (int i = 0; i < x->shape[0]; ++i) {
-    ((float*)(y->data))[i] = ((float*)(x->data))[i] + 1;
-  }
-  // return 0 for success run
-  return 0;
-}
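
For reference, the removed run_example.py above drove this raw-C module through the same loading path as the C++ variant; reconstructed from that deleted script (with the output tensor printed):

```python
# Reconstructed from the removed run_example.py: the raw C symbol
# __tvm_ffi_add_one_c is reached through the same tvm_ffi.load_module path.
import numpy as np
import tvm_ffi

mod = tvm_ffi.load_module("build/add_one_c.so")
x = np.array([1, 2, 3, 4, 5], dtype=np.float32)
y = np.empty_like(x)
mod.add_one_c(x, y)  # dispatches through the hand-written C ABI entry point
print(y)
```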
diff --git a/examples/quick_start/src/run_example.cc b/examples/quick_start/src/run_example.cc
deleted file mode 100644
index 4b38343..0000000
--- a/examples/quick_start/src/run_example.cc
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-#include <tvm/ffi/container/tensor.h>
-#include <tvm/ffi/extra/module.h>
-
-// This file shows how to load the same compiled module and interact with it in C++
-namespace ffi = tvm::ffi;
-
-struct CPUNDAlloc {
-  void AllocData(DLTensor* tensor) { tensor->data = malloc(ffi::GetDataSize(*tensor)); }
-  void FreeData(DLTensor* tensor) { free(tensor->data); }
-};
-
-inline ffi::Tensor Empty(ffi::Shape shape, DLDataType dtype, DLDevice device) {
-  return ffi::Tensor::FromNDAlloc(CPUNDAlloc(), shape, dtype, device);
-}
-
-int main() {
-  // load the module
-  ffi::Module mod = ffi::Module::LoadFromFile("build/add_one_cpu.so");
-
-  // create an Tensor, alternatively, one can directly pass in a DLTensor*
-  ffi::Tensor x = Empty({5}, DLDataType({kDLFloat, 32, 1}), DLDevice({kDLCPU, 0}));
-  for (int i = 0; i < 5; ++i) {
-    reinterpret_cast<float*>(x.data_ptr())[i] = static_cast<float>(i);
-  }
-
-  ffi::Function add_one_cpu = mod->GetFunction("add_one_cpu").value();
-  add_one_cpu(x, x);
-
-  std::cout << "x after add_one_cpu(x, x)" << std::endl;
-  for (int i = 0; i < 5; ++i) {
-    std::cout << reinterpret_cast<float*>(x.data_ptr())[i] << " ";
-  }
-  std::cout << std::endl;
-  return 0;
-}
diff --git a/examples/quick_start/src/run_example_cuda.cc b/examples/quick_start/src/run_example_cuda.cc
deleted file mode 100644
index 21e7f49..0000000
--- a/examples/quick_start/src/run_example_cuda.cc
+++ /dev/null
@@ -1,94 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#include <cuda_runtime.h>
-#include <tvm/ffi/container/tensor.h>
-#include <tvm/ffi/error.h>
-#include <tvm/ffi/extra/module.h>
-
-#include <iostream>
-#include <vector>
-
-namespace ffi = tvm::ffi;
-
-// This example mirrors run_example.cc but keeps all data on the GPU by allocating
-// CUDA tensors, invoking the add_one_cuda FFI function, and copying the result back
-// to host memory so users can inspect the output.
-struct CUDANDAlloc {
-  void AllocData(DLTensor* tensor) {
-    size_t data_size = ffi::GetDataSize(*tensor);
-    void* ptr = nullptr;
-    cudaError_t err = cudaMalloc(&ptr, data_size);
-    TVM_FFI_ICHECK_EQ(err, cudaSuccess) << "cudaMalloc failed: " << cudaGetErrorString(err);
-    tensor->data = ptr;
-  }
-
-  void FreeData(DLTensor* tensor) {
-    if (tensor->data != nullptr) {
-      cudaError_t err = cudaFree(tensor->data);
-      TVM_FFI_ICHECK_EQ(err, cudaSuccess) << "cudaFree failed: " << cudaGetErrorString(err);
-      tensor->data = nullptr;
-    }
-  }
-};
-
-inline ffi::Tensor Empty(ffi::Shape shape, DLDataType dtype, DLDevice device) {
-  return ffi::Tensor::FromNDAlloc(CUDANDAlloc(), shape, dtype, device);
-}
-
-int main() {
-  // Load the CUDA implementation that run_example.cu exports during the CMake build.
-  ffi::Module mod = ffi::Module::LoadFromFile("build/add_one_cuda.so");
-
-  DLDataType f32_dtype{kDLFloat, 32, 1};
-  DLDevice cuda_device{kDLCUDA, 0};
-
-  constexpr int ARRAY_SIZE = 5;
-
-  ffi::Tensor x = Empty({ARRAY_SIZE}, f32_dtype, cuda_device);
-  ffi::Tensor y = Empty({ARRAY_SIZE}, f32_dtype, cuda_device);
-
-  std::vector<float> host_x(ARRAY_SIZE);
-  for (int i = 0; i < ARRAY_SIZE; ++i) {
-    host_x[i] = static_cast<float>(i);
-  }
-
-  size_t nbytes = host_x.size() * sizeof(float);
-  cudaError_t err = cudaMemcpy(x.data_ptr(), host_x.data(), nbytes, cudaMemcpyHostToDevice);
-  TVM_FFI_ICHECK_EQ(err, cudaSuccess)
-      << "cudaMemcpy host to device failed: " << cudaGetErrorString(err);
-
-  // Call into the FFI function; tensors remain on device because they carry a
-  // kDLCUDA device tag.
-  ffi::Function add_one_cuda = mod->GetFunction("add_one_cuda").value();
-  add_one_cuda(x, y);
-
-  std::vector<float> host_y(host_x.size());
-  err = cudaMemcpy(host_y.data(), y.data_ptr(), nbytes, cudaMemcpyDeviceToHost);
-  TVM_FFI_ICHECK_EQ(err, cudaSuccess)
-      << "cudaMemcpy device to host failed: " << cudaGetErrorString(err);
-
-  std::cout << "y after add_one_cuda(x, y)" << std::endl;
-  for (float value : host_y) {
-    std::cout << value << " ";
-  }
-  std::cout << std::endl;
-
-  return 0;
-}
diff --git a/examples/quickstart/CMakeLists.txt b/examples/quickstart/CMakeLists.txt
new file mode 100644
index 0000000..71b1b27
--- /dev/null
+++ b/examples/quickstart/CMakeLists.txt
@@ -0,0 +1,75 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+cmake_minimum_required(VERSION 3.20)
+project(tvm_ffi_example)
+
+option(EXAMPLE_NAME "Which example to build ('compile_cpu'/'compile_cuda'/'load_cpp')"
+       "compile_cpu"
+)
+message(STATUS "Building example: ${EXAMPLE_NAME}")
+
+# Run `tvm_ffi.config --cmakedir` to find tvm-ffi package
+find_package(
+  Python
+  COMPONENTS Interpreter
+  REQUIRED
+)
+execute_process(
+  COMMAND "${Python_EXECUTABLE}" -m tvm_ffi.config --cmakedir
+  OUTPUT_STRIP_TRAILING_WHITESPACE
+  OUTPUT_VARIABLE tvm_ffi_ROOT
+)
+find_package(tvm_ffi CONFIG REQUIRED)
+
+if (EXAMPLE_NAME STREQUAL "compile_cpu")
+  # Example 1. C++ `add_one`
+  add_library(add_one_cpu SHARED compile/add_one_cpu.cc)
+  target_link_libraries(add_one_cpu PRIVATE tvm_ffi_header)
+  target_link_libraries(add_one_cpu PRIVATE tvm_ffi_shared)
+  set_target_properties(
+    add_one_cpu
+    PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/"
+               PREFIX ""
+               SUFFIX ".so"
+  )
+elseif (EXAMPLE_NAME STREQUAL "compile_cuda")
+  # Example 2. CUDA `add_one`
+  enable_language(CUDA)
+  add_library(add_one_cuda SHARED compile/add_one_cuda.cu)
+  target_link_libraries(add_one_cuda PRIVATE tvm_ffi_shared)
+  set_target_properties(
+    add_one_cuda
+    PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/"
+               PREFIX ""
+               SUFFIX ".so"
+  )
+elseif (EXAMPLE_NAME STREQUAL "load_cpp")
+  # Example 3. Load C++ shared library
+  add_executable(load_cpp load/load_cpp.cc)
+  target_link_libraries(load_cpp PRIVATE tvm_ffi_header)
+  target_link_libraries(load_cpp PRIVATE tvm_ffi_shared)
+  set_target_properties(
+    load_cpp
+    PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/"
+               PREFIX ""
+               SUFFIX ""
+  )
+else ()
+  message(FATAL_ERROR "Unknown EXAMPLE_NAME option: ${EXAMPLE_NAME}. "
+                      "Expected: 'compile_cpu', 'compile_cuda', 'load_cpp'."
+  )
+endif ()
diff --git a/examples/quickstart/README.md b/examples/quickstart/README.md
new file mode 100644
index 0000000..4e515cc
--- /dev/null
+++ b/examples/quickstart/README.md
@@ -0,0 +1,55 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements.  See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership.  The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License.  You may obtain a copy of the License at -->
+
+<!---   http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied.  See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+
+# Quick Start Code Example
+
+This directory contains all the source code for the [tutorial](https://tvm.apache.org/ffi/get_started/quickstart.html).
+
+## Compile and Distribute `add_one_*`
+
+To compile the C++ Example:
+
+```bash
+cmake . -B build -DEXAMPLE_NAME="compile_cpu" -DCMAKE_BUILD_TYPE=RelWithDebInfo
+cmake --build build --config RelWithDebInfo
+```
+
+To compile CUDA Example:
+
+```bash
+cmake . -B build -DEXAMPLE_NAME="compile_cuda" -DCMAKE_BUILD_TYPE=RelWithDebInfo
+cmake --build build --config RelWithDebInfo
+```
+
+## Load the Distributed `add_one_*`
+
+To run library loading examples across ML frameworks:
+
+```bash
+python load/load_pytorch.py
+python load/load_jax.py
+python load/load_numpy.py
+python load/load_cupy.py
+```
+
+To run library loading example in C++:
+
+```bash
+cmake . -B build -DEXAMPLE_NAME="load_cpp" -DCMAKE_BUILD_TYPE=RelWithDebInfo
+cmake --build build --config RelWithDebInfo
+build/load_cpp
+```
diff --git a/examples/quick_start/src/add_one_cpu.cc b/examples/quickstart/compile/add_one_cpu.cc
similarity index 54%
rename from examples/quick_start/src/add_one_cpu.cc
rename to examples/quickstart/compile/add_one_cpu.cc
index abc188e..22ce5e0 100644
--- a/examples/quick_start/src/add_one_cpu.cc
+++ b/examples/quickstart/compile/add_one_cpu.cc
@@ -16,26 +16,24 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+
+// [example.begin]
+// File: compile/add_one_cpu.cc
 #include <tvm/ffi/container/tensor.h>
-#include <tvm/ffi/dtype.h>
-#include <tvm/ffi/error.h>
 #include <tvm/ffi/function.h>
 
-namespace tvm_ffi_example {
+namespace tvm_ffi_example_cpu {
 
+/*! \brief Perform vector add one: y = x + 1 (1-D float32) */
 void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
-  // implementation of a library function
-  TVM_FFI_ICHECK(x.ndim() == 1) << "x must be a 1D tensor";
-  DLDataType f32_dtype{kDLFloat, 32, 1};
-  TVM_FFI_ICHECK(x.dtype() == f32_dtype) << "x must be a float tensor";
-  TVM_FFI_ICHECK(y.ndim() == 1) << "y must be a 1D tensor";
-  TVM_FFI_ICHECK(y.dtype() == f32_dtype) << "y must be a float tensor";
-  TVM_FFI_ICHECK(x.size(0) == y.size(0)) << "x and y must have the same shape";
-  for (int i = 0; i < x.size(0); ++i) {
-    static_cast<float*>(y.data_ptr())[i] = static_cast<float*>(x.data_ptr())[i] + 1;
+  int64_t n = x.size(0);
+  float* x_data = static_cast<float*>(x.data_ptr());
+  float* y_data = static_cast<float*>(y.data_ptr());
+  for (int64_t i = 0; i < n; ++i) {
+    y_data[i] = x_data[i] + 1;
   }
 }
 
-// Expose global symbol `add_one_cpu` that follows tvm-ffi abi
-TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cpu, tvm_ffi_example::AddOne);
-}  // namespace tvm_ffi_example
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cpu, tvm_ffi_example_cpu::AddOne);
+}  // namespace tvm_ffi_example_cpu
+// [example.end]
diff --git a/examples/quick_start/src/add_one_cuda.cu b/examples/quickstart/compile/add_one_cuda.cu
similarity index 50%
rename from examples/quick_start/src/add_one_cuda.cu
rename to examples/quickstart/compile/add_one_cuda.cu
index 07acfdb..2b743e8 100644
--- a/examples/quick_start/src/add_one_cuda.cu
+++ b/examples/quickstart/compile/add_one_cuda.cu
@@ -16,13 +16,14 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+
+// [example.begin]
+// File: compile/add_one_cuda.cu
 #include <tvm/ffi/container/tensor.h>
-#include <tvm/ffi/dtype.h>
-#include <tvm/ffi/error.h>
 #include <tvm/ffi/extra/c_env_api.h>
 #include <tvm/ffi/function.h>
 
-namespace tvm_ffi_example {
+namespace tvm_ffi_example_cuda {
 
 __global__ void AddOneKernel(float* x, float* y, int n) {
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -31,28 +32,17 @@ __global__ void AddOneKernel(float* x, float* y, int n) {
   }
 }
 
-void AddOneCUDA(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
-  // implementation of a library function
-  TVM_FFI_ICHECK(x.ndim() == 1) << "x must be a 1D tensor";
-  DLDataType f32_dtype{kDLFloat, 32, 1};
-  TVM_FFI_ICHECK(x.dtype() == f32_dtype) << "x must be a float tensor";
-  TVM_FFI_ICHECK(y.ndim() == 1) << "y must be a 1D tensor";
-  TVM_FFI_ICHECK(y.dtype() == f32_dtype) << "y must be a float tensor";
-  TVM_FFI_ICHECK(x.size(0) == y.size(0)) << "x and y must have the same shape";
-
+void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
   int64_t n = x.size(0);
-  int64_t nthread_per_block = 256;
-  int64_t nblock = (n + nthread_per_block - 1) / nthread_per_block;
-  // Obtain the current stream from the environment
-  // it will be set to torch.cuda.current_stream() when calling the function
-  // with torch.Tensors
+  float* x_data = static_cast<float*>(x.data_ptr());
+  float* y_data = static_cast<float*>(y.data_ptr());
+  int64_t threads = 256;
+  int64_t blocks = (n + threads - 1) / threads;
   cudaStream_t stream =
       static_cast<cudaStream_t>(TVMFFIEnvGetStream(x.device().device_type, x.device().device_id));
-  // launch the kernel
-  AddOneKernel<<<nblock, nthread_per_block, 0, stream>>>(static_cast<float*>(x.data_ptr()),
-                                                         static_cast<float*>(y.data_ptr()), n);
+  AddOneKernel<<<blocks, threads, 0, stream>>>(x_data, y_data, n);
 }
 
-// Expose global symbol `add_one_cpu` that follows tvm-ffi abi
-TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cuda, tvm_ffi_example::AddOneCUDA);
-}  // namespace tvm_ffi_example
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cuda, tvm_ffi_example_cuda::AddOne);
+}  // namespace tvm_ffi_example_cuda
+// [example.end]
diff --git a/examples/quickstart/load/load_cpp.cc b/examples/quickstart/load/load_cpp.cc
new file mode 100644
index 0000000..afa4343
--- /dev/null
+++ b/examples/quickstart/load/load_cpp.cc
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+// [example.begin]
+// File: load/load_cpp.cc
+#include <tvm/ffi/container/tensor.h>
+#include <tvm/ffi/extra/module.h>
+
+#include <cstdlib>           // std::malloc / std::free used in CPUAllocator
+#include <initializer_list>  // std::initializer_list parameter of Alloc1DTensor
+#include <iostream>          // std::cout in main()
+
+namespace {
+namespace ffi = tvm::ffi;
+
+/************* Main logic *************/
+
+/*!
+ * \brief Main logic: load the shared library and call the exported function.
+ * \param x The input tensor.
+ * \param y The output tensor.
+ */
+void Run(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
+  // Load shared library `build/add_one_cpu.so`
+  ffi::Module mod = ffi::Module::LoadFromFile("build/add_one_cpu.so");
+  // Look up `add_one_cpu` function
+  ffi::Function add_one_cpu = mod->GetFunction("add_one_cpu").value();
+  // Call the function
+  add_one_cpu(x, y);
+}
+
+/************* Auxiliary logic *************/
+
+/*!
+ * \brief Allocate a 1D float32 `tvm::ffi::Tensor` on CPU from a braced initializer list.
+ * \param data The input data.
+ * \return The allocated Tensor.
+ */
+ffi::Tensor Alloc1DTensor(std::initializer_list<float> data) {
+  struct CPUAllocator {
+    void AllocData(DLTensor* tensor) {
+      tensor->data = std::malloc(tensor->shape[0] * sizeof(float));
+    }
+    void FreeData(DLTensor* tensor) { std::free(tensor->data); }
+  };
+  DLDataType f32 = DLDataType({kDLFloat, 32, 1});
+  DLDevice cpu = DLDevice({kDLCPU, 0});
+  int64_t n = static_cast<int64_t>(data.size());
+  ffi::Tensor x = ffi::Tensor::FromNDAlloc(CPUAllocator(), {n}, f32, cpu);
+  float* x_data = static_cast<float*>(x.data_ptr());
+  for (float v : data) {
+    *x_data++ = v;
+  }
+  return x;
+}
+
+}  // namespace
+
+int main() {
+  ffi::Tensor x = Alloc1DTensor({1, 2, 3, 4, 5});
+  ffi::Tensor y = Alloc1DTensor({0, 0, 0, 0, 0});
+  Run(x, y);
+  std::cout << "[ ";
+  const float* y_data = static_cast<const float*>(y.data_ptr());
+  for (int i = 0; i < 5; ++i) {
+    std::cout << y_data[i] << " ";
+  }
+  std::cout << "]" << std::endl;
+  return 0;
+}
+// [example.end]
diff --git a/examples/quick_start/run_example.sh b/examples/quickstart/load/load_cupy.py
old mode 100755
new mode 100644
similarity index 64%
copy from examples/quick_start/run_example.sh
copy to examples/quickstart/load/load_cupy.py
index e6ada48..581dba1
--- a/examples/quick_start/run_example.sh
+++ b/examples/quickstart/load/load_cupy.py
@@ -1,4 +1,3 @@
-#!/bin/bash
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -15,25 +14,17 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-set -ex
+# fmt: off
+# ruff: noqa
+# mypy: ignore-errors
+# [example.begin]
+# File: load/load_cupy.py
+import tvm_ffi
+mod = tvm_ffi.load_module("build/add_one_cuda.so")
 
-if command -v ninja >/dev/null 2>&1; then
-       generator="Ninja"
-else
-       echo "Ninja not found, falling back to Unix Makefiles" >&2
-       generator="Unix Makefiles"
-fi
-
-rm -rf build/CMakeCache.txt
-cmake -G "$generator" -B build -S .
-cmake --build build --parallel
-
-# running python example
-python run_example.py
-
-# running c++ example
-./build/run_example
-
-if [ -x ./build/run_example_cuda ]; then
-       ./build/run_example_cuda
-fi
+import cupy as cp
+x = cp.array([1, 2, 3, 4, 5], dtype=cp.float32)
+y = cp.empty_like(x)
+mod.add_one_cuda(x, y)
+print(y)
+# [example.end]
diff --git a/examples/quick_start/run_example.sh b/examples/quickstart/load/load_numpy.py
old mode 100755
new mode 100644
similarity index 64%
copy from examples/quick_start/run_example.sh
copy to examples/quickstart/load/load_numpy.py
index e6ada48..f176a31
--- a/examples/quick_start/run_example.sh
+++ b/examples/quickstart/load/load_numpy.py
@@ -1,4 +1,3 @@
-#!/bin/bash
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -15,25 +14,17 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-set -ex
+# fmt: off
+# ruff: noqa
+# mypy: ignore-errors
+# [example.begin]
+# File: load/load_numpy.py
+import tvm_ffi
+mod = tvm_ffi.load_module("build/add_one_cpu.so")
 
-if command -v ninja >/dev/null 2>&1; then
-       generator="Ninja"
-else
-       echo "Ninja not found, falling back to Unix Makefiles" >&2
-       generator="Unix Makefiles"
-fi
-
-rm -rf build/CMakeCache.txt
-cmake -G "$generator" -B build -S .
-cmake --build build --parallel
-
-# running python example
-python run_example.py
-
-# running c++ example
-./build/run_example
-
-if [ -x ./build/run_example_cuda ]; then
-       ./build/run_example_cuda
-fi
+import numpy as np
+x = np.array([1, 2, 3, 4, 5], dtype=np.float32)
+y = np.empty_like(x)
+mod.add_one_cpu(x, y)
+print(y)
+# [example.end]
diff --git a/examples/quick_start/run_example.sh b/examples/quickstart/load/load_pytorch.py
old mode 100755
new mode 100644
similarity index 64%
copy from examples/quick_start/run_example.sh
copy to examples/quickstart/load/load_pytorch.py
index e6ada48..5d9c8a0
--- a/examples/quick_start/run_example.sh
+++ b/examples/quickstart/load/load_pytorch.py
@@ -1,4 +1,3 @@
-#!/bin/bash
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -15,25 +14,19 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-set -ex
+# fmt: off
+# ruff: noqa
+# mypy: ignore-errors
+# [example.begin]
+# File: load/load_pytorch.py
+# Step 1. Load `build/add_one_cuda.so`
+import tvm_ffi
+mod = tvm_ffi.load_module("build/add_one_cuda.so")
 
-if command -v ninja >/dev/null 2>&1; then
-       generator="Ninja"
-else
-       echo "Ninja not found, falling back to Unix Makefiles" >&2
-       generator="Unix Makefiles"
-fi
-
-rm -rf build/CMakeCache.txt
-cmake -G "$generator" -B build -S .
-cmake --build build --parallel
-
-# running python example
-python run_example.py
-
-# running c++ example
-./build/run_example
-
-if [ -x ./build/run_example_cuda ]; then
-       ./build/run_example_cuda
-fi
+# Step 2. Run `mod.add_one_cuda` with PyTorch
+import torch
+x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32, device="cuda")
+y = torch.empty_like(x)
+mod.add_one_cuda(x, y)
+print(y)
+# [example.end]
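
One behavior the removed run_example.py demonstrated that this slimmer load_pytorch.py drops: when the function is invoked under a non-default torch stream, the kernel launches on that stream, since torch.cuda.current_stream() is what `TVMFFIEnvGetStream` returns inside `add_one_cuda.cu`. Reconstructed from the deleted script:

```python
# Reconstructed from the removed run_example.py: torch.cuda.current_stream()
# is surfaced to the kernel via TVMFFIEnvGetStream when the function is called.
import torch
import tvm_ffi

mod = tvm_ffi.load_module("build/add_one_cuda.so")
x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32, device="cuda")
y = torch.empty_like(x)

stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
    mod.add_one_cuda(x, y)  # launches on `stream`, not the default stream
stream.synchronize()
print(y)
```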
diff --git a/examples/quickstart/raw_compile.sh b/examples/quickstart/raw_compile.sh
new file mode 100755
index 0000000..d0fcbcc
--- /dev/null
+++ b/examples/quickstart/raw_compile.sh
@@ -0,0 +1,61 @@
+#!/bin/bash
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# shellcheck disable=SC2046
+set -ex
+
+BUILD_DIR=build
+mkdir -p $BUILD_DIR
+
+# Example 1. Compile C++ `add_one_cpu.cc` to shared library `add_one_cpu.so`
+# [cpp_compile.begin]
+g++ -shared -O3 compile/add_one_cpu.cc  \
+    -fPIC -fvisibility=hidden           \
+    $(tvm-ffi-config --cxxflags)        \
+    $(tvm-ffi-config --ldflags)         \
+    $(tvm-ffi-config --libs)            \
+    -o $BUILD_DIR/add_one_cpu.so
+# [cpp_compile.end]
+
+# Example 2. Compile CUDA `add_one_cuda.cu` to shared library `add_one_cuda.so`
+
+if command -v nvcc >/dev/null 2>&1; then
+# [cuda_compile.begin]
+nvcc -shared -O3 compile/add_one_cuda.cu    \
+    -Xcompiler -fPIC,-fvisibility=hidden    \
+    $(tvm-ffi-config --cxxflags)            \
+    $(tvm-ffi-config --ldflags)             \
+    $(tvm-ffi-config --libs)                \
+    -o $BUILD_DIR/add_one_cuda.so
+# [cuda_compile.end]
+fi
+
+# Example 3. Load and run `add_one_cpu.so` in C++
+
+if [ -f "$BUILD_DIR/add_one_cpu.so" ]; then
+# [load_cpp.begin]
+g++ -fvisibility=hidden -O3                 \
+    load/load_cpp.cc                        \
+    $(tvm-ffi-config --cxxflags)            \
+    $(tvm-ffi-config --ldflags)             \
+    $(tvm-ffi-config --libs)                \
+    -Wl,-rpath,$(tvm-ffi-config --libdir)   \
+    -o build/load_cpp
+
+build/load_cpp
+# [load_cpp.end]
+fi
diff --git a/examples/quick_start/run_example.sh b/examples/quickstart/run_all_cpu.sh
similarity index 64%
copy from examples/quick_start/run_example.sh
copy to examples/quickstart/run_all_cpu.sh
index e6ada48..577280b 100755
--- a/examples/quick_start/run_example.sh
+++ b/examples/quickstart/run_all_cpu.sh
@@ -17,23 +17,14 @@
 # under the License.
 set -ex
 
-if command -v ninja >/dev/null 2>&1; then
-       generator="Ninja"
-else
-       echo "Ninja not found, falling back to Unix Makefiles" >&2
-       generator="Unix Makefiles"
-fi
+# To compile `compile/add_one_cpu.cc` to shared library `build/add_one_cpu.so`
+cmake . -B build -DEXAMPLE_NAME="compile_cpu" -DCMAKE_BUILD_TYPE=RelWithDebInfo
+cmake --build build --config RelWithDebInfo
 
-rm -rf build/CMakeCache.txt
-cmake -G "$generator" -B build -S .
-cmake --build build --parallel
+# To load and run `add_one_cpu.so` in NumPy
+python load/load_numpy.py
 
-# running python example
-python run_example.py
-
-# running c++ example
-./build/run_example
-
-if [ -x ./build/run_example_cuda ]; then
-       ./build/run_example_cuda
-fi
+# To load and run `add_one_cpu.so` in C++
+cmake . -B build -DEXAMPLE_NAME="load_cpp" -DCMAKE_BUILD_TYPE=RelWithDebInfo
+cmake --build build --config RelWithDebInfo
+build/load_cpp
diff --git a/examples/quick_start/run_example.sh b/examples/quickstart/run_all_cuda.sh
similarity index 66%
rename from examples/quick_start/run_example.sh
rename to examples/quickstart/run_all_cuda.sh
index e6ada48..d27bf0b 100755
--- a/examples/quick_start/run_example.sh
+++ b/examples/quickstart/run_all_cuda.sh
@@ -17,23 +17,12 @@
 # under the License.
 set -ex
 
-if command -v ninja >/dev/null 2>&1; then
-       generator="Ninja"
-else
-       echo "Ninja not found, falling back to Unix Makefiles" >&2
-       generator="Unix Makefiles"
-fi
+# To compile `compile/add_one_cuda.cu` to shared library `build/add_one_cuda.so`
+cmake . -B build -DEXAMPLE_NAME="compile_cuda" -DCMAKE_BUILD_TYPE=RelWithDebInfo
+cmake --build build --config RelWithDebInfo
 
-rm -rf build/CMakeCache.txt
-cmake -G "$generator" -B build -S .
-cmake --build build --parallel
+# To load and run `add_one_cuda.so` in PyTorch
+python load/load_pytorch.py
 
-# running python example
-python run_example.py
-
-# running c++ example
-./build/run_example
-
-if [ -x ./build/run_example_cuda ]; then
-       ./build/run_example_cuda
-fi
+# To load and run `add_one_cuda.so` in CuPy
+python load/load_cupy.py
diff --git a/pyproject.toml b/pyproject.toml
index 013566b..df528ee 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -265,3 +265,5 @@ ignore_missing_imports = true
 
 [tool.uv.dependency-groups]
 docs = { requires-python = ">=3.13" }
+
+[tool.setuptools_scm]
