This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3e201a6aa5 [TIRx] Preserve Triton call_kernel compile options (#19728)
3e201a6aa5 is described below
commit 3e201a6aa50b50942bf44297a669795f9a7c126d
Author: Shushi Hong <[email protected]>
AuthorDate: Thu Jun 11 07:21:14 2026 -0400
[TIRx] Preserve Triton call_kernel compile options (#19728)
Previously `_generate_triton_kernel` overwrote the user-provided kwargs
with the constexpr dict before calling triton.compiler.compile, so
options such as num_warps passed to T.call_kernel were silently dropped.
Pass constexprs to ASTSource and forward the user kwargs as compile
options.
The pre-3.3 compatibility branches are removed in favor of an explicit
minimum-version check: they were never exercised in CI (which does not
install Triton), and Triton >= 3.3 has shipped with PyTorch since 2.7.
The integration test now matches the actual lowering, where constexpr
parameters (BLOCK_SIZE) appear as runtime kernel arguments in
call_packed. To cover the option forwarding, it also passes num_warps=8
and expects a thread extent of 256 (8 warps x 32 threads per warp).
---
python/tvm/tirx/script/builder/triton.py | 25 ++++++++++------------
.../python/contrib/test_tir_triton_integration.py | 10 ++++++++-
2 files changed, 20 insertions(+), 15 deletions(-)
diff --git a/python/tvm/tirx/script/builder/triton.py
b/python/tvm/tirx/script/builder/triton.py
index 5c5d2b5567..14f2d92bab 100644
--- a/python/tvm/tirx/script/builder/triton.py
+++ b/python/tvm/tirx/script/builder/triton.py
@@ -29,6 +29,11 @@ from tvm.topi.utils import get_const_int
from .external_kernel import BaseKernel
+if version.parse(triton.__version__) < version.parse("3.3.0"):
+ raise ImportError(
+ f"TIR Triton integration requires Triton >= 3.3.0, but found Triton
{triton.__version__}"
+ )
+
class TritonKernel(BaseKernel):
"""A kernel from Triton JIT function.
@@ -74,12 +79,9 @@ class TritonKernel(BaseKernel):
: len(grid)
]
launch_args = [num_warps * 32] + list(grid)
- if version.parse(triton.__version__) >= version.parse("3.3.0"):
- kernel_arg_types = [
- arg.dtype if not isinstance(arg, int) else "int64" for arg in
kernel_args
- ]
- else:
- kernel_arg_types = [arg.dtype for arg in kernel_args]
+ kernel_arg_types = [
+ arg.dtype if not isinstance(arg, int) else "int64" for arg in
kernel_args
+ ]
if triton_kernel.metadata.shared > 0:
# Add shared memory size to the launch arguments
launch_param_tags.append("tirx.use_dyn_shared_memory")
@@ -107,9 +109,8 @@ class TritonKernel(BaseKernel):
for i, arg in enumerate(args):
if kernel_params[i].is_constexpr:
constants[kernel_params[i].name] = get_const_int(arg)
- if version.parse(triton.__version__) >= version.parse("3.3.0"):
- signature[kernel_params[i].name] = "constexpr"
- kernel_args.append(arg)
+ signature[kernel_params[i].name] = "constexpr"
+ kernel_args.append(arg)
continue
if arg.dtype == "handle":
assert isinstance(arg, tirx.Var)
@@ -122,10 +123,6 @@ class TritonKernel(BaseKernel):
# TODO: Support default argument in the kernel
# TODO: Add specialization for aligned buffer pointers
- if version.parse(triton.__version__) >= version.parse("3.3.0"):
- kwargs = {"constexprs": constants}
- else:
- kwargs = {"constants": constants}
- source = triton.compiler.ASTSource(fn=func, signature=signature,
**kwargs)
+ source = triton.compiler.ASTSource(fn=func, signature=signature,
constexprs=constants)
compiled = triton.compiler.compile(source, options=kwargs)
return compiled, kernel_args
diff --git a/tests/python/contrib/test_tir_triton_integration.py
b/tests/python/contrib/test_tir_triton_integration.py
index 29fb44adda..33d7962e8f 100644
--- a/tests/python/contrib/test_tir_triton_integration.py
+++ b/tests/python/contrib/test_tir_triton_integration.py
@@ -31,8 +31,12 @@ from tvm.script import tirx as T
try:
import triton
import triton.language as tl
+ from packaging import version
except ImportError:
pytestmark = pytest.skip("Triton is not available",
allow_module_level=True)
+else:
+ if version.parse(triton.__version__) < version.parse("3.3.0"):
+ pytestmark = pytest.skip("Triton >= 3.3.0 is required",
allow_module_level=True)
@tvm.testing.requires_cuda
@@ -76,6 +80,7 @@ def test_tir_triton_integration():
output.data,
m,
BLOCK_SIZE,
+ num_warps=8,
)
@R.function
@@ -86,6 +91,8 @@ def test_tir_triton_integration():
R.output(output)
return output
+ # Constexpr parameters (BLOCK_SIZE) stay in the kernel arguments, and the
+ # thread extent is 256 because the kernel is compiled with num_warps=8.
@I.ir_module(s_tir=True)
class Parsed:
@T.prim_func(s_tir=True)
@@ -103,7 +110,8 @@ def test_tir_triton_integration():
y.data,
output.data,
m,
- 128,
+ 64,
+ 256,
(m + T.int64(64) - T.int64(1)) // T.int64(64),
)