tqchen commented on code in PR #283:
URL: https://github.com/apache/tvm-ffi/pull/283#discussion_r2553121177


##########
include/tvm/ffi/extra/cubin_launcher.h:
##########
@@ -0,0 +1,306 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/ffi/extra/cubin_launcher.h
+ * \brief CUDA CUBIN launcher utility for loading and executing CUDA kernels.
+ *
+ * This header provides a lightweight C++ wrapper around CUDA Driver API
+ * for loading CUBIN modules and launching kernels. It supports:
+ * - Loading CUBIN from memory (embedded data) or files
+ * - Multi-GPU execution using CUDA primary contexts
+ * - Kernel parameter management and launch configuration
+ */
+#ifndef TVM_FFI_EXTRA_CUBIN_LAUNCHER_H_
+#define TVM_FFI_EXTRA_CUBIN_LAUNCHER_H_
+
+#include <cuda.h>
+#include <tvm/ffi/error.h>
+#include <tvm/ffi/extra/c_env_api.h>
+
+#include <cstdint>
+#include <cstring>
+#include <fstream>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace ffi {
+
+/*!
+ * \brief Macro for checking CUDA driver API errors.
+ *
+ * This macro checks the return value of CUDA driver API calls and throws
+ * a RuntimeError with detailed error information if the call fails.
+ *
+ * \param stmt The CUDA driver API call to check.
+ */
+#define TVM_FFI_CHECK_CUDA_DRIVER_ERROR(stmt)                                  
                    \
+  do {                                                                         
                    \
+    CUresult __err = (stmt);                                                   
                    \
+    if (__err != CUDA_SUCCESS) {                                               
                    \
+      const char* __err_name = nullptr;                                        
                    \
+      const char* __err_str = nullptr;                                         
                    \
+      cuGetErrorName(__err, &__err_name);                                      
                    \
+      cuGetErrorString(__err, &__err_str);                                     
                    \
+      TVM_FFI_THROW(RuntimeError) << "CUDA Driver Error: "                     
                    \
+                                  << (__err_name ? __err_name : "UNKNOWN") << 
" ("                 \
+                                  << static_cast<int>(__err)                   
                    \
+                                  << "): " << (__err_str ? __err_str : "No 
description") << " at " \
+                                  << __FILE__ << ":" << __LINE__;              
                    \
+    }                                                                          
                    \
+  } while (0)
+
+/*!
+ * \brief A simple 3D dimension type for CUDA kernel launch configuration.
+ *
+ * This struct mimics the behavior of dim3 from CUDA Runtime API, but works
+ * with the CUDA Driver API. It can be constructed from 1, 2, or 3 dimensions.
+ */
+struct dim3 {
+  /*! \brief X dimension (number of blocks in x-direction or threads in 
x-direction) */
+  unsigned int x;
+  /*! \brief Y dimension (number of blocks in y-direction or threads in 
y-direction) */
+  unsigned int y;
+  /*! \brief Z dimension (number of blocks in z-direction or threads in 
z-direction) */
+  unsigned int z;
+
+  /*! \brief Default constructor initializes to (1, 1, 1) */
+  dim3() : x(1), y(1), z(1) {}
+
+  /*! \brief Construct with x dimension, y and z default to 1 */
+  explicit dim3(unsigned int x_) : x(x_), y(1), z(1) {}
+
+  /*! \brief Construct with x and y dimensions, z defaults to 1 */
+  dim3(unsigned int x_, unsigned int y_) : x(x_), y(y_), z(1) {}
+
+  /*! \brief Construct with all three dimensions */
+  dim3(unsigned int x_, unsigned int y_, unsigned int z_) : x(x_), y(y_), 
z(z_) {}
+};
+
+// Forward declaration
+class CubinKernel;
+
+/*!
+ * \brief CUDA CUBIN module loader and manager.
+ *
+ * This class provides a RAII wrapper around CUDA driver API's library 
management.
+ * It loads a CUBIN module from memory or file and manages the library handle.
+ * Supports multi-GPU execution using CUDA primary contexts.
+ */
+class CubinModule {
+ public:
+  /*!
+   * \brief Load CUBIN module from memory.
+   *
+   * \param data Pointer to CUBIN binary data in memory.
+   * \param size Size of the CUBIN binary data in bytes.
+   * \note Calls cuInit(0) to ensure CUDA is initialized.
+   */
+  CubinModule(const void* data, uint64_t size) {
+    TVM_FFI_CHECK_CUDA_DRIVER_ERROR(cuInit(0));
+    TVM_FFI_CHECK_CUDA_DRIVER_ERROR(
+        cuLibraryLoadData(&library_, data, nullptr, nullptr, 0, nullptr, 
nullptr, 0));
+  }
+
+  /*!
+   * \brief Load CUBIN module from file.
+   *
+   * \param filename Path to the CUBIN file.
+   * \note This reads the entire file into memory and then loads it.
+   */
+  explicit CubinModule(const char* filename) {

Review Comment:
   To keep dependencies minimal, let us consider always loading from an inline binary (rather than from a file), and adding a constructor that takes a tvm::ffi::Bytes.



##########
include/tvm/ffi/extra/cubin_launcher.h:
##########
@@ -0,0 +1,306 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/ffi/extra/cubin_launcher.h
+ * \brief CUDA CUBIN launcher utility for loading and executing CUDA kernels.
+ *
+ * This header provides a lightweight C++ wrapper around CUDA Driver API
+ * for loading CUBIN modules and launching kernels. It supports:
+ * - Loading CUBIN from memory (embedded data) or files
+ * - Multi-GPU execution using CUDA primary contexts
+ * - Kernel parameter management and launch configuration
+ */
+#ifndef TVM_FFI_EXTRA_CUBIN_LAUNCHER_H_
+#define TVM_FFI_EXTRA_CUBIN_LAUNCHER_H_
+
+#include <cuda.h>

Review Comment:
   Let us move this header to extra/cuda/cubin_launcher.h.



##########
examples/cubin_launcher/example_triton_cubin.py:
##########
@@ -0,0 +1,221 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Single-file Triton example: define kernel, compile to CUBIN, load via 
inline C++.
+
+This script:
+1. Embeds a minimal Triton kernel definition (elementwise square)
+2. Compiles it to a CUBIN using the Triton runtime API
+3. Defines C++ code inline using tvm_ffi.cpp.load_inline to load the CUBIN
+4. Launches the kernel through the TVM-FFI exported function pointer
+
+Notes:
+- Requires `triton` to be installed in the Python environment.
+
+"""
+
+from __future__ import annotations
+
+import sys
+import traceback
+from pathlib import Path
+
+import torch
+import triton  # type: ignore[import-not-found]
+import triton.language as tl  # type: ignore[import-not-found]
+from tvm_ffi import cpp
+
+
+def _compile_triton_to_cubin() -> tuple[bytes, str]:
+    """Define a Triton kernel in-process and compile it to a CUBIN file.
+
+    The kernel is named `square_kernel` and computes y[i] = x[i] * x[i].
+    Returns (cubin_bytes, ptx_source)
+    """
+
+    # Define the kernel dynamically
+    @triton.jit
+    def square_kernel(X_ptr, Y_ptr, n, BLOCK: tl.constexpr = 1024):  # noqa
+        pid = tl.program_id(0)
+        start = pid * BLOCK
+        offsets = start + tl.arange(0, BLOCK)
+        mask = offsets < n
+        x = tl.load(X_ptr + offsets, mask=mask, other=0.0)
+        y = x * x
+        tl.store(Y_ptr + offsets, y, mask=mask)
+
+    # Trigger kernel compilation by doing a dummy call
+    x_dummy = torch.ones(1024, dtype=torch.float32, device="cuda")
+    y_dummy = torch.empty(1024, dtype=torch.float32, device="cuda")
+    square_kernel[1, 1](x_dummy, y_dummy, 1024)
+
+    # Extract compiled CUBIN from the device cache
+    device_caches = square_kernel.device_caches
+    device_id = next(iter(device_caches.keys()))
+    cache_tuple = device_caches[device_id]
+    compiled_kernel = next(iter(cache_tuple[0].values()))
+
+    # Get CUBIN bytes and PTX source
+    cubin_bytes = compiled_kernel.kernel
+    ptx_source = (
+        compiled_kernel.asm.get("ptx", "")
+        if hasattr(compiled_kernel.asm, "get")
+        else str(compiled_kernel.asm)
+    )
+
+    return cubin_bytes, ptx_source
+
+
+def main() -> int:  # noqa: PLR0911,PLR0915
+    """Load and launch Triton kernel through TVM-FFI."""
+    print("Example: Triton (inline) -> CUBIN -> C++ (inline) -> TVM-FFI")
+    print("=" * 60)
+
+    if not torch.cuda.is_available():
+        print("[ERROR] CUDA is not available")
+        return 1
+
+    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
+    print(f"PyTorch version: {torch.__version__}\n")
+
+    base = Path(__file__).resolve().parent
+    build_dir = base / "build"
+    build_dir.mkdir(parents=True, exist_ok=True)
+
+    # Compile Triton kernel to CUBIN
+    try:
+        print("Compiling Triton kernel to CUBIN...")
+        cubin_bytes, ptx_source = _compile_triton_to_cubin()
+        print(f"Compiled CUBIN: {len(cubin_bytes)} bytes")
+        print("\n" + "=" * 60)
+        print("PTX Source:")
+        print("=" * 60)
+        print(ptx_source)
+        print("=" * 60 + "\n")
+    except Exception as e:
+        print(f"[ERROR] Failed to compile Triton kernel: {e}")
+        traceback.print_exc()
+        return 2
+
+    # Write CUBIN to file

Review Comment:
   It would be nice to pass the CUBIN directly instead of writing it to a file first, e.g. 
`tvm_ffi.cpp.load_inline(cpp_source, embed_cubin={"env": triton_cubin})`.
   



##########
include/tvm/ffi/extra/cubin_launcher.h:
##########
@@ -0,0 +1,306 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/ffi/extra/cubin_launcher.h
+ * \brief CUDA CUBIN launcher utility for loading and executing CUDA kernels.
+ *
+ * This header provides a lightweight C++ wrapper around CUDA Driver API
+ * for loading CUBIN modules and launching kernels. It supports:
+ * - Loading CUBIN from memory (embedded data) or files
+ * - Multi-GPU execution using CUDA primary contexts
+ * - Kernel parameter management and launch configuration
+ */
+#ifndef TVM_FFI_EXTRA_CUBIN_LAUNCHER_H_
+#define TVM_FFI_EXTRA_CUBIN_LAUNCHER_H_
+
+#include <cuda.h>
+#include <tvm/ffi/error.h>
+#include <tvm/ffi/extra/c_env_api.h>
+
+#include <cstdint>
+#include <cstring>
+#include <fstream>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace ffi {
+
+/*!
+ * \brief Macro for checking CUDA driver API errors.
+ *
+ * This macro checks the return value of CUDA driver API calls and throws
+ * a RuntimeError with detailed error information if the call fails.
+ *
+ * \param stmt The CUDA driver API call to check.
+ */
+#define TVM_FFI_CHECK_CUDA_DRIVER_ERROR(stmt)                                  
                    \
+  do {                                                                         
                    \
+    CUresult __err = (stmt);                                                   
                    \
+    if (__err != CUDA_SUCCESS) {                                               
                    \
+      const char* __err_name = nullptr;                                        
                    \
+      const char* __err_str = nullptr;                                         
                    \
+      cuGetErrorName(__err, &__err_name);                                      
                    \
+      cuGetErrorString(__err, &__err_str);                                     
                    \
+      TVM_FFI_THROW(RuntimeError) << "CUDA Driver Error: "                     
                    \
+                                  << (__err_name ? __err_name : "UNKNOWN") << 
" ("                 \
+                                  << static_cast<int>(__err)                   
                    \
+                                  << "): " << (__err_str ? __err_str : "No 
description") << " at " \
+                                  << __FILE__ << ":" << __LINE__;              
                    \

Review Comment:
   No need to append __FILE__ and __LINE__ here, since TVM_FFI_THROW already records the file and line.



##########
include/tvm/ffi/extra/cubin_launcher.h:
##########
@@ -0,0 +1,306 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/ffi/extra/cubin_launcher.h
+ * \brief CUDA CUBIN launcher utility for loading and executing CUDA kernels.
+ *
+ * This header provides a lightweight C++ wrapper around CUDA Driver API
+ * for loading CUBIN modules and launching kernels. It supports:
+ * - Loading CUBIN from memory (embedded data) or files
+ * - Multi-GPU execution using CUDA primary contexts
+ * - Kernel parameter management and launch configuration
+ */
+#ifndef TVM_FFI_EXTRA_CUBIN_LAUNCHER_H_
+#define TVM_FFI_EXTRA_CUBIN_LAUNCHER_H_
+
+#include <cuda.h>
+#include <tvm/ffi/error.h>
+#include <tvm/ffi/extra/c_env_api.h>
+
+#include <cstdint>
+#include <cstring>
+#include <fstream>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace ffi {
+
+/*!
+ * \brief Macro for checking CUDA driver API errors.
+ *
+ * This macro checks the return value of CUDA driver API calls and throws
+ * a RuntimeError with detailed error information if the call fails.
+ *
+ * \param stmt The CUDA driver API call to check.
+ */
+#define TVM_FFI_CHECK_CUDA_DRIVER_ERROR(stmt)                                  
                    \
+  do {                                                                         
                    \
+    CUresult __err = (stmt);                                                   
                    \
+    if (__err != CUDA_SUCCESS) {                                               
                    \
+      const char* __err_name = nullptr;                                        
                    \
+      const char* __err_str = nullptr;                                         
                    \
+      cuGetErrorName(__err, &__err_name);                                      
                    \
+      cuGetErrorString(__err, &__err_str);                                     
                    \
+      TVM_FFI_THROW(RuntimeError) << "CUDA Driver Error: "                     
                    \
+                                  << (__err_name ? __err_name : "UNKNOWN") << 
" ("                 \
+                                  << static_cast<int>(__err)                   
                    \
+                                  << "): " << (__err_str ? __err_str : "No 
description") << " at " \
+                                  << __FILE__ << ":" << __LINE__;              
                    \
+    }                                                                          
                    \
+  } while (0)
+
+/*!
+ * \brief A simple 3D dimension type for CUDA kernel launch configuration.
+ *
+ * This struct mimics the behavior of dim3 from CUDA Runtime API, but works
+ * with the CUDA Driver API. It can be constructed from 1, 2, or 3 dimensions.
+ */
+struct dim3 {
+  /*! \brief X dimension (number of blocks in x-direction or threads in 
x-direction) */
+  unsigned int x;
+  /*! \brief Y dimension (number of blocks in y-direction or threads in 
y-direction) */
+  unsigned int y;
+  /*! \brief Z dimension (number of blocks in z-direction or threads in 
z-direction) */
+  unsigned int z;
+
+  /*! \brief Default constructor initializes to (1, 1, 1) */
+  dim3() : x(1), y(1), z(1) {}
+
+  /*! \brief Construct with x dimension, y and z default to 1 */
+  explicit dim3(unsigned int x_) : x(x_), y(1), z(1) {}
+
+  /*! \brief Construct with x and y dimensions, z defaults to 1 */
+  dim3(unsigned int x_, unsigned int y_) : x(x_), y(y_), z(1) {}
+
+  /*! \brief Construct with all three dimensions */
+  dim3(unsigned int x_, unsigned int y_, unsigned int z_) : x(x_), y(y_), 
z(z_) {}
+};
+
+// Forward declaration
+class CubinKernel;
+
+/*!
+ * \brief CUDA CUBIN module loader and manager.
+ *
+ * This class provides a RAII wrapper around CUDA driver API's library 
management.
+ * It loads a CUBIN module from memory or file and manages the library handle.
+ * Supports multi-GPU execution using CUDA primary contexts.
+ */
+class CubinModule {
+ public:
+  /*!
+   * \brief Load CUBIN module from memory.
+   *
+   * \param data Pointer to CUBIN binary data in memory.
+   * \param size Size of the CUBIN binary data in bytes.
+   * \note Calls cuInit(0) to ensure CUDA is initialized.
+   */
+  CubinModule(const void* data, uint64_t size) {
+    TVM_FFI_CHECK_CUDA_DRIVER_ERROR(cuInit(0));
+    TVM_FFI_CHECK_CUDA_DRIVER_ERROR(
+        cuLibraryLoadData(&library_, data, nullptr, nullptr, 0, nullptr, 
nullptr, 0));

Review Comment:
   We should cross-check whether the data needs to remain alive in memory after cuLibraryLoadData 
returns, and document this accordingly.



##########
examples/cubin_launcher/src/lib_embedded.cc:
##########
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file examples/cubin_launcher/src/lib_embedded.cc
+ * \brief TVM-FFI library with embedded CUBIN kernels.
+ *
+ * This library exports TVM-FFI functions to launch CUDA kernels from
+ * embedded CUBIN data.
+ */
+
+#include <tvm/ffi/container/tensor.h>
+#include <tvm/ffi/error.h>
+#include <tvm/ffi/extra/c_env_api.h>
+#include <tvm/ffi/extra/cubin_launcher.h>
+#include <tvm/ffi/function.h>
+
+#include <cstdint>
+#include <memory>
+
+// External symbols for embedded CUBIN data (linked via objcopy)
+extern "C" const char __cubin_data[];
+extern "C" const char __cubin_data_end[];
+
+// Calculate size from the symbols
+static const uint64_t cubin_data_size =
+    reinterpret_cast<const char*>(&__cubin_data_end) - reinterpret_cast<const 
char*>(&__cubin_data);
+
+// Global CUBIN module and kernels (initialized on first use)
+static std::unique_ptr<tvm::ffi::CubinModule> g_cubin_module;
+static std::unique_ptr<tvm::ffi::CubinKernel> g_add_one_kernel;
+static std::unique_ptr<tvm::ffi::CubinKernel> g_mul_two_kernel;
+
+// Initialize the CUBIN module and kernels
+void InitializeCubinModule() {
+  if (g_cubin_module == nullptr) {
+    g_cubin_module = std::make_unique<tvm::ffi::CubinModule>(__cubin_data, 
cubin_data_size);
+    g_add_one_kernel = 
std::make_unique<tvm::ffi::CubinKernel>((*g_cubin_module)["add_one_cuda"]);
+    g_mul_two_kernel = 
std::make_unique<tvm::ffi::CubinKernel>((*g_cubin_module)["mul_two_cuda"]);
+  }
+}
+
+namespace cubin_embedded {
+
+/*!
+ * \brief Launch add_one_cuda kernel on input tensor.
+ * \param x Input tensor (float32, 1D)
+ * \param y Output tensor (float32, 1D, same shape as x)
+ */
+void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
+  InitializeCubinModule();

Review Comment:
   Consider using a different pattern based on a function-local static singleton:
   
   ```c++
   // maybe as a macro: TVM_FFI_CUBIN_EMBED(env);
   // use env as the key, in case we want to embed other cubins
   // (extern "C" cannot be combined with static; localize the symbols after linking instead)
   extern "C" const char __tvm_ffi__cubin_env[];
   extern "C" const char __tvm_ffi__cubin_env_end[];
   
   namespace {
   struct CubinModule_env {
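       // CubinModule takes (data, size), so compute the size from the begin/end symbols
       tvm::ffi::CubinModule mod{
           __tvm_ffi__cubin_env,
           static_cast<uint64_t>(__tvm_ffi__cubin_env_end - __tvm_ffi__cubin_env)};

       static CubinModule_env* Global() {
         static CubinModule_env inst;
         return &inst;
       }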
   };
   }  // anonymous namespace to avoid symbol conflict
   
   void MulTwo(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
      // maybe as a macro:
      //   static auto kernel = TVM_FFI_CUBIN_GET_KERNEL(env, "mul_two_cuda");
      static tvm::ffi::CubinKernel kernel = CubinModule_env::Global()->mod["mul_two_cuda"];
      kernel.launch(...);
   }
   ```



##########
examples/cubin_launcher/src/lib_embedded.cc:
##########
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file examples/cubin_launcher/src/lib_embedded.cc
+ * \brief TVM-FFI library with embedded CUBIN kernels.
+ *
+ * This library exports TVM-FFI functions to launch CUDA kernels from
+ * embedded CUBIN data.
+ */
+
+#include <tvm/ffi/container/tensor.h>
+#include <tvm/ffi/error.h>
+#include <tvm/ffi/extra/c_env_api.h>
+#include <tvm/ffi/extra/cubin_launcher.h>
+#include <tvm/ffi/function.h>
+
+#include <cstdint>
+#include <memory>
+
+// External symbols for embedded CUBIN data (linked via objcopy)
+extern "C" const char __cubin_data[];
+extern "C" const char __cubin_data_end[];

Review Comment:
   One thing to consider is how to avoid symbol conflicts when we link multiple such embedded files.
   
   Here is what Gemini suggests, done in the following steps. Maybe we can provide a Python tool, 
`tvm_ffi.cpp.embed_cubin(output_object, input_object, cubin, key="env")`, to do this. We can also 
provide a CMake macro for those who prefer to use the tvm-ffi CMake integration.
   
   - Step 1: Compile the C++ source (source.o). Compile your C++ code normally. Declare the 
symbols as `extern` so the compiler emits an undefined reference (a hole to be filled later): 
`g++ -c source.cc -o source.o`
   - Step 2: Convert the binary to an object (blob_raw.o). Use objcopy to wrap the raw binary file 
into a linkable object file; this creates global symbols by default: 
`objcopy -I binary -O elf64-x86-64 kernel.cubin blob_raw.o`
   - Step 3: Rename the symbols (blob_renamed.o). Change the auto-generated names (e.g. 
`_binary_kernel_cubin_start` for kernel.cubin) to the names your C++ code expects 
(`__tvm_ffi__...`): `objcopy --redefine-sym old_name=new_name blob_raw.o blob_renamed.o`
   - Step 4: Partial link / merge (merged.o). Use `ld -r` to fuse the code (source.o) and the data 
(blob_renamed.o) together; this resolves the undefined references: 
`ld -r source.o blob_renamed.o -o merged.o`
   - Step 5: Localize the symbols (final.o). Crucial last step: now that the code and data are in 
the same file, use objcopy to change the symbols from global to local. This hides them from the 
outside world (internal linkage). Localize both the start and the end symbols: 
`objcopy --localize-symbol=__tvm_ffi__cubin_data --localize-symbol=__tvm_ffi__cubin_data_end merged.o final.o`
   
   ```make
   # ==========================================
   # Configuration
   # ==========================================
   
   # Files
   BINARY_FILE := kernel.cubin
   SOURCE_FILE := source.cc
   OUTPUT_OBJ  := final_module.o
   
   # The symbol names your C++ code uses (extern "C")
   SYM_NAME      := __tvm_ffi__cubin_data
   SYM_NAME_END  := __tvm_ffi__cubin_data_end
   
   # Compiler settings
   CXX      := g++
   CXXFLAGS := -O2 -Wall -fPIC
   LD       := ld
   OBJCOPY  := objcopy
   
   # ------------------------------------------
   # Internal Calculation for objcopy default names
   # objcopy converts "kernel.cubin" -> "_binary_kernel_cubin_start"
   # objcopy replaces non-alphanumeric characters in the file name with underscores;
   # here we only handle dots and slashes, which covers typical file names.
   # ------------------------------------------
   BINARY_FLAT   := $(subst /,_,$(subst .,_,$(BINARY_FILE)))
   DEFAULT_START := _binary_$(BINARY_FLAT)_start
   DEFAULT_END   := _binary_$(BINARY_FLAT)_end
   DEFAULT_SIZE  := _binary_$(BINARY_FLAT)_size
   
   # ==========================================
   # Rules
   # ==========================================
   
   .PHONY: all clean check
   
   all: $(OUTPUT_OBJ)
   
   # 1. Compile the C++ source into an object file.
   #    (Contains undefined references to the symbols)
   source.o: $(SOURCE_FILE)
        @echo "[1/5] Compiling Source..."
        $(CXX) $(CXXFLAGS) -c $< -o $@
   
   # 2. Convert the raw binary into an ELF object file.
   #    (Symbols are Global and named _binary_kernel_cubin_start)
   blob_raw.o: $(BINARY_FILE)
        @echo "[2/5] Converting Binary to Object..."
        $(OBJCOPY) -I binary -O elf64-x86-64 $< $@
   
   # 3. Rename the symbols to match your C++ declaration.
   #    (Still Global, but names match __tvm_ffi__...)
   blob_renamed.o: blob_raw.o
        @echo "[3/5] Renaming Symbols..."
        $(OBJCOPY) \
                --redefine-sym $(DEFAULT_START)=$(SYM_NAME) \
                --redefine-sym $(DEFAULT_END)=$(SYM_NAME_END) \
                --strip-symbol=$(DEFAULT_SIZE) \
                $< $@
   
   # 4. Partial Link (Merge).
   #    (Fuses source.o and blob_renamed.o, so the code can now see the data.)
   merged.o: source.o blob_renamed.o
        @echo "[4/5] Linking (Partial Merge)..."
        $(LD) -r source.o blob_renamed.o -o $@
   
   # 5. Localize Symbols.
   #    (Hides the symbols from the outside world. Global D -> Local d)
   $(OUTPUT_OBJ): merged.o
        @echo "[5/5] Finalizing: Hiding Symbols..."
        $(OBJCOPY) \
                --localize-symbol=$(SYM_NAME) \
                --localize-symbol=$(SYM_NAME_END) \
                $< $@
        @echo "Success! Created $(OUTPUT_OBJ)"
   
   # ==========================================
   # Utilities
   # ==========================================
   
   # Helper to prove the symbols are local
   check: $(OUTPUT_OBJ)
        @echo "Checking symbol visibility in $(OUTPUT_OBJ)..."
        @echo "Look for lowercase 'd' (local data) or 'r' (local read-only):"
        @nm $(OUTPUT_OBJ) | grep __tvm_ffi__
   
   clean:
        rm -f *.o
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to