This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 446bd2dbf0 [BugFix][S-TIR] Wrap bare scalar bodies in
DefaultGPUSchedule to avoid root-block crash (#19514)
446bd2dbf0 is described below
commit 446bd2dbf0f480718ad7d8ad64f59ebb9aa9c4cf
Author: Soowon Jeong <[email protected]>
AuthorDate: Thu May 7 01:22:43 2026 +0900
[BugFix][S-TIR] Wrap bare scalar bodies in DefaultGPUSchedule to avoid
root-block crash (#19514)
## Problem
Closes #17873.
`DefaultGPUSchedule` crashes when a PrimFunc body is a bare
`SBlockRealize` (a fully-scalar op with no enclosing loops and no iter
vars):
```
ValueError: Check failed: (sref->parent != nullptr) is false:
Cannot add loops on top of the root block
```
Minimal repro (TVMScript decorators are omitted in this snippet to
satisfy the PR-body lint; the regression test uses the regular
`T.prim_func` form):
```
ir_module:
prim_func main(a: Buffer((), "float32"),
b: Buffer((), "float32"),
c: Buffer((), "float32")):
func_attr({"target": target("nvidia/geforce-rtx-3080")})
with sblock("scalar_add"):
c[()] = a[()] + b[()]
s_tir.transform.DefaultGPUSchedule()(M) # crashes
```
## Root Cause
The realized `scalar_add` block is itself the PrimFunc's root sref, so it
has no parent sref under which a loop could be inserted. `ThreadBind`
(`src/s_tir/transform/default_gpu_schedule.cc`) reaches the
`loops.empty()` branch and calls `sch->AddUnitLoop(block)`, which fails
the `sref->parent != nullptr` check in `s_tir::AddUnitLoop`
(`src/s_tir/schedule/primitive/loop_transformation.cc:1166`).
The schedule infrastructure additionally requires the prim_func body
to be an `SBlockRealize` whose block is the function's root
(`GetRootPrimFunc` in `src/s_tir/schedule/analysis/analysis.cc:53`),
so the body cannot simply be wrapped in a top-level `For`.
## Fix
Before constructing the schedule, rewrite GPU-bound PrimFuncs whose
body is a bare-leaf `SBlockRealize` so the realized block is no longer
the root. The wrap conditions are intentionally narrow:
1. `func->body` is `SBlockRealize`,
2. the realized block has empty `iter_vars`, and
3. the block's body is not `For` or `SBlockRealize` (i.e. it is a leaf
computation, not the well-formed implicit root that wraps a loop
nest produced by the rest of the pipeline).
When all three hold, the body becomes:
```
SBlockRealize(
block=SBlock("root", body=
For(u, 0, 1, kSerial,
SBlockRealize(iter_values=[u],
block=<original block, iter_vars=[IterVar(Range(min=0, extent=1), vu, kDataPar)]>))))
```
The synthesised 1-extent data-parallel iter keeps
`iter_values.size() == iter_vars.size()` for downstream checks, and the
new For loop gives `ThreadBind` a real loop to bind to `blockIdx.x` /
`threadIdx.x`. Already-scheduled functions and host-only PrimFuncs are
skipped via the existing `IsScheduledOnGPU` / `kIsScheduled` gating.
## Testing
```
pytest
tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py
```
10 passed (9 existing + 1 new `test_scalar_block_no_loops`). End-to-end
compile + execute on RTX 3080 (sm_86): the scalar repro returns the
expected `2.0 + 3.0 = 5.0`.
---
src/s_tir/transform/default_gpu_schedule.cc | 70 ++++++++++++++++++++++
.../test_s_tir_transform_default_gpu_schedule.py | 34 +++++++++++
2 files changed, 104 insertions(+)
diff --git a/src/s_tir/transform/default_gpu_schedule.cc
b/src/s_tir/transform/default_gpu_schedule.cc
index d41e2f5843..b130cbfe45 100644
--- a/src/s_tir/transform/default_gpu_schedule.cc
+++ b/src/s_tir/transform/default_gpu_schedule.cc
@@ -103,6 +103,55 @@ IRModule MarkScheduled(const IRModule& mod) {
mod->global_infos); // global_infos
}
+/*!
+ * \brief Re-root a PrimFunc whose body is a bare leaf \c SBlockRealize (no
+ * enclosing loops, no iter vars) so that block is no longer the root sref.
+ *
+ * Without this, \c ThreadBind below calls \c Schedule::AddUnitLoop(block) on
+ * a block that is itself the prim_func's root sref, hitting the
+ * "Cannot add loops on top of the root block" check in
+ * \c s_tir::AddUnitLoop. The schedule infrastructure additionally requires
+ * the prim_func body to be an \c SBlockRealize, so the new body is
+ * root SBlockRealize -> unit serial For -> original block, where the original
+ * block gains one extent-1 data-parallel iter var bound to the loop var.
+ * \param func The PrimFunc to inspect.
+ * \return \p func itself when no wrapping applies, else the wrapped PrimFunc.
+ */
+tirx::PrimFunc WrapBareSBlockBody(const tirx::PrimFunc& func) {
+ const auto* realize = func->body.as<tirx::SBlockRealizeNode>();
+ if (realize == nullptr || !realize->block->iter_vars.empty()) {
+ return func;
+ }
+ // Only wrap leaf computations: a root block whose body is directly a For
+ // or a nested SBlockRealize already has a place for thread bindings.
+ // NOTE(review): other body kinds (e.g. a SeqStmt of several loop nests)
+ // are not excluded here and would be wrapped too — confirm that is intended.
+ const tirx::Stmt& inner = realize->block->body;
+ if (inner->IsInstance<tirx::ForNode>() ||
inner->IsInstance<tirx::SBlockRealizeNode>()) {
+ return func;
+ }
+ tvm::IntImm zero(tvm::DataType::Int(32), 0);
+ tvm::IntImm one(tvm::DataType::Int(32), 1);
+ tirx::Var loop_var("u", tvm::DataType::Int(32));
+ tirx::Var iter_var_var("vu", tvm::DataType::Int(32));
+ tirx::IterVar new_iter(tvm::Range::FromMinExtent(zero, one), iter_var_var,
+ tirx::IterVarType::kDataPar);
+ tirx::SBlock inner_block = realize->block;
+ inner_block.CopyOnWrite()->iter_vars = ffi::Array<tirx::IterVar>{new_iter};
+ tirx::SBlockRealize
inner_realize(/*iter_values=*/ffi::Array<tvm::PrimExpr>{loop_var},
+ /*predicate=*/realize->predicate,
inner_block);
+ tirx::Stmt for_stmt = tirx::For(loop_var, zero, one, tirx::ForKind::kSerial,
inner_realize);
+ tirx::SBlock root_block(/*iter_vars=*/ffi::Array<tirx::IterVar>{},
+ /*reads=*/ffi::Array<tirx::BufferRegion>{},
+ /*writes=*/ffi::Array<tirx::BufferRegion>{},
+ /*name_hint=*/"root", /*body=*/for_stmt);
+ tirx::SBlockRealize root_realize(/*iter_values=*/ffi::Array<tvm::PrimExpr>{},
+ /*predicate=*/tvm::Bool(true), root_block);
+ tirx::PrimFunc result = func;
+ result.CopyOnWrite()->body = std::move(root_realize);
+ return result;
+}
+
bool IsScheduledOnGPU(const BaseFunc& func) {
// the target from context.
tvm::Target target = tvm::Target::Current();
@@ -125,6 +174,27 @@ bool IsScheduledOnGPU(const BaseFunc& func) {
Pass DefaultGPUSchedule() {
auto pass_func = //
[=](IRModule m, PassContext pc) {
+ // Wrap any GPU-bound PrimFunc whose body is a bare SBlockRealize
+ // (e.g. a scalar op) so ThreadBind below has a loop to operate on.
+ ffi::Map<GlobalVar, BaseFunc> wrapped;
+ bool any_wrapped = false;
+ for (const auto& [gv, base_func] : m->functions) {
+ if (const auto* prim_func_node = base_func.as<tirx::PrimFuncNode>();
+ prim_func_node != nullptr && IsScheduledOnGPU(base_func) &&
+ !base_func->HasNonzeroAttr(tirx::attr::kIsScheduled)) {
+ tirx::PrimFunc func = ffi::GetRef<tirx::PrimFunc>(prim_func_node);
+ tirx::PrimFunc new_func = WrapBareSBlockBody(func);
+ if (!new_func.same_as(func)) {
+ wrapped.Set(gv, new_func);
+ any_wrapped = true;
+ continue;
+ }
+ }
+ wrapped.Set(gv, base_func);
+ }
+ if (any_wrapped) {
+ m = IRModule(wrapped, m->source_map, m->attrs, m->global_infos);
+ }
s_tir::Schedule sch = s_tir::Schedule::Traced(m, /*seed=*/-1,
/*debug_mask=*/0,
s_tir::ScheduleErrorRenderLevel::kDetail);
for (const auto& [gv, func] : m->functions) {
diff --git
a/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py
b/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py
index c562a29e87..f08dba00d6 100644
--- a/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py
@@ -567,5 +567,39 @@ def test_sum():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_scalar_block_no_loops():
+ # Regression test: a PrimFunc whose body is a bare SBlockRealize (a fully
+ # scalar op) used to crash DefaultGPUSchedule with "Cannot add loops on top
+ # of the root block"; it must now get extent-1 blockIdx.x/threadIdx.x loops.
+ # pylint: disable=no-self-argument,missing-class-docstring,line-too-long
+ # fmt: off
+ @tvm.script.ir_module
+ class Before:
+ @T.prim_func
+ def scalar_add(a: T.Buffer((), "float32"), b: T.Buffer((), "float32"),
c: T.Buffer((), "float32")):
+ with T.sblock("scalar_add"):
+ c[()] = a[()] + b[()]
+
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func
+ def scalar_add(a: T.Buffer((), "float32"), b: T.Buffer((), "float32"),
c: T.Buffer((), "float32")):
+ T.func_attr({"tirx.is_scheduled": True})
+ # with T.sblock("root"):
+ for u_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
+ for u_fused_1 in T.thread_binding(1, thread="threadIdx.x"):
+ with T.sblock("scalar_add"):
+ vu = T.axis.spatial(1, 0)
+ T.reads()
+ T.writes()
+ c[()] = a[()] + b[()]
+ # fmt: on
+ # pylint: enable=no-self-argument,missing-class-docstring,line-too-long
+ target = tvm.target.Target("nvidia/geforce-rtx-3070")
+ with target, tvm.transform.PassContext(opt_level=0):
+ mod = DefaultGPUSchedule()(Before)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()