yaoyaoding commented on code in PR #283: URL: https://github.com/apache/tvm-ffi/pull/283#discussion_r2557668685
########## examples/cubin_launcher/example_triton_cubin.py: ########## @@ -0,0 +1,221 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Single-file Triton example: define kernel, compile to CUBIN, load via inline C++. + +This script: +1. Embeds a minimal Triton kernel definition (elementwise square) +2. Compiles it to a CUBIN using the Triton runtime API +3. Defines C++ code inline using tvm_ffi.cpp.load_inline to load the CUBIN +4. Launches the kernel through the TVM-FFI exported function pointer + +Notes: +- Requires `triton` to be installed in the Python environment. + +""" + +from __future__ import annotations + +import sys +import traceback +from pathlib import Path + +import torch +import triton # type: ignore[import-not-found] +import triton.language as tl # type: ignore[import-not-found] +from tvm_ffi import cpp + + +def _compile_triton_to_cubin() -> tuple[bytes, str]: + """Define a Triton kernel in-process and compile it to a CUBIN file. + + The kernel is named `square_kernel` and computes y[i] = x[i] * x[i]. + Returns (cubin_bytes, ptx_source) + """ + + # Define the kernel dynamically + @triton.jit + def square_kernel(X_ptr, Y_ptr, n, BLOCK: tl.constexpr = 1024): # noqa + pid = tl.program_id(0) + start = pid * BLOCK + offsets = start + tl.arange(0, BLOCK) + mask = offsets < n + x = tl.load(X_ptr + offsets, mask=mask, other=0.0) + y = x * x + tl.store(Y_ptr + offsets, y, mask=mask) + + # Trigger kernel compilation by doing a dummy call + x_dummy = torch.ones(1024, dtype=torch.float32, device="cuda") + y_dummy = torch.empty(1024, dtype=torch.float32, device="cuda") + square_kernel[1, 1](x_dummy, y_dummy, 1024) + + # Extract compiled CUBIN from the device cache + device_caches = square_kernel.device_caches + device_id = next(iter(device_caches.keys())) + cache_tuple = device_caches[device_id] + compiled_kernel = next(iter(cache_tuple[0].values())) + + # Get CUBIN bytes and PTX source + cubin_bytes = compiled_kernel.kernel + ptx_source = ( + compiled_kernel.asm.get("ptx", "") + if hasattr(compiled_kernel.asm, "get") + else str(compiled_kernel.asm) + ) + + return cubin_bytes, ptx_source + + +def main() -> int: # noqa: PLR0911,PLR0915 + """Load and launch Triton kernel through TVM-FFI.""" + print("Example: Triton (inline) -> CUBIN -> C++ (inline) -> TVM-FFI") + print("=" * 60) + + if not torch.cuda.is_available(): + print("[ERROR] CUDA is not available") + return 1 + + print(f"CUDA device: {torch.cuda.get_device_name(0)}") + print(f"PyTorch version: {torch.__version__}\n") + + base = Path(__file__).resolve().parent + build_dir = base / "build" + build_dir.mkdir(parents=True, exist_ok=True) + + # Compile Triton kernel to CUBIN + try: + print("Compiling Triton kernel to CUBIN...") + cubin_bytes, ptx_source = _compile_triton_to_cubin() + print(f"Compiled CUBIN: {len(cubin_bytes)} bytes") + print("\n" + "=" * 60) + print("PTX Source:") + print("=" * 60) + print(ptx_source) + print("=" * 60 + "\n") + except Exception as e: + print(f"[ERROR] Failed to compile Triton kernel: {e}") + traceback.print_exc() + return 2 + + # Write CUBIN to file Review Comment: sounds good! Added. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
