This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 446bd2dbf0 [BugFix][S-TIR] Wrap bare scalar bodies in DefaultGPUSchedule to avoid root-block crash (#19514)
446bd2dbf0 is described below

commit 446bd2dbf0f480718ad7d8ad64f59ebb9aa9c4cf
Author: Soowon Jeong <[email protected]>
AuthorDate: Thu May 7 01:22:43 2026 +0900

    [BugFix][S-TIR] Wrap bare scalar bodies in DefaultGPUSchedule to avoid root-block crash (#19514)
    
    ## Problem
    
    Closes #17873.
    
    `DefaultGPUSchedule` crashes when a PrimFunc body is a bare
    `SBlockRealize` (a fully-scalar op with no enclosing loops and no iter
    vars):
    
    ```
    ValueError: Check failed: (sref->parent != nullptr) is false:
      Cannot add loops on top of the root block
    ```
    
    Minimal repro (TVMScript decorators are omitted in this snippet to
    satisfy the PR-body lint; the regression test uses the regular
    `T.prim_func` form):
    
    ```
    ir_module M:
      prim_func main(a: Buffer((), "float32"),
                     b: Buffer((), "float32"),
                     c: Buffer((), "float32")):
          func_attr({"target": target("nvidia/geforce-rtx-3080")})
          with sblock("scalar_add"):
              c[()] = a[()] + b[()]
    
    s_tir.transform.DefaultGPUSchedule()(M)  # crashes
    ```
    
    ## Root Cause
    
    The realized `scalar_add` block is itself the prim_func body's root
    sref — it has no parent stmt to mutate. `ThreadBind`
    (`src/s_tir/transform/default_gpu_schedule.cc`) reaches the
    `loops.empty()` branch and calls `sch->AddUnitLoop(block)`, which fails
    the `sref->parent != nullptr` check in `s_tir::AddUnitLoop`
    (`src/s_tir/schedule/primitive/loop_transformation.cc:1166`).
    
    The schedule infrastructure additionally requires the prim_func body
    to be an `SBlockRealize` whose block is the function's root
    (`GetRootPrimFunc` in `src/s_tir/schedule/analysis/analysis.cc:53`),
    so the body cannot simply be wrapped in a top-level `For`.
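
    To make the parent check concrete, here is a toy Python model, not
    TVM code. `SRef` and `add_unit_loop` are hypothetical stand-ins,
    assuming only that adding a unit loop means splicing a new loop node
    between a block and its parent. The bare block fails because it has
    no parent; the same block nested under a synthetic root succeeds,
    which is the shape the fix below produces.

    ```python
    from dataclasses import dataclass, field
    from typing import List, Optional


    @dataclass(eq=False)  # identity equality, so tree links cannot recurse
    class SRef:
        """Hypothetical stand-in for a schedule sref: a node with a parent."""
        name: str
        parent: Optional["SRef"] = None
        children: List["SRef"] = field(default_factory=list)


    def add_unit_loop(block: SRef) -> SRef:
        """Splice a unit loop between `block` and its parent (toy model)."""
        if block.parent is None:
            # Mirrors `sref->parent != nullptr`: the root has no parent
            # statement to rewrite, so the new loop has nowhere to go.
            raise ValueError("Cannot add loops on top of the root block")
        parent = block.parent
        loop = SRef(name="u", parent=parent, children=[block])
        parent.children[parent.children.index(block)] = loop
        block.parent = loop
        return loop


    # A bare scalar block that is itself the root: the check fires.
    try:
        add_unit_loop(SRef("scalar_add"))
    except ValueError as err:
        print(err)  # Cannot add loops on top of the root block

    # The same block nested under a synthetic root: a loop can be added.
    root = SRef("root")
    leaf = SRef("scalar_add", parent=root)
    root.children.append(leaf)
    loop = add_unit_loop(leaf)
    print([c.name for c in root.children], leaf.parent.name)  # ['u'] u
    ```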
    
    ## Fix
    
    Before constructing the schedule, rewrite GPU-bound PrimFuncs whose
    body is a bare-leaf `SBlockRealize` so the realized block is no longer
    the root. The wrap conditions are intentionally narrow:
    
    1. `func->body` is `SBlockRealize`,
    2. the realized block has empty `iter_vars`, and
    3. the block's body is not `For` or `SBlockRealize` (i.e. it is a leaf
       computation, not the well-formed implicit root that wraps a loop
       nest produced by the rest of the pipeline).
    
    When all three hold, the body becomes:
    
    ```
    SBlockRealize(
      block=SBlock("root", body=
        For(u, 0, 1, kSerial,
          SBlockRealize(iter_values=[u],
            block=<original block, iter_vars=[IterVar(0..1, vu, kDataPar)]>))))
    ```
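
    The three wrap conditions and the rewrite can be sketched together in
    plain Python. The frozen dataclasses below are hypothetical toy
    stand-ins for the TIR nodes (they ignore predicates, buffer regions
    and dtypes); `needs_wrap` mirrors the early returns in
    `WrapBareSBlockBody`, and `wrap_bare_block` builds the nested shape
    shown above.

    ```python
    from dataclasses import dataclass, replace


    @dataclass(frozen=True)
    class Store:
        """Leaf computation, e.g. c[()] = a[()] + b[()]."""
        text: str


    @dataclass(frozen=True)
    class For:
        loop_var: str
        min_value: int
        extent: int
        body: object


    @dataclass(frozen=True)
    class SBlock:
        name: str
        body: object
        iter_vars: tuple = ()  # (var_name, extent) pairs


    @dataclass(frozen=True)
    class SBlockRealize:
        block: SBlock
        iter_values: tuple = ()


    def needs_wrap(body) -> bool:
        """Toy version of the three wrap conditions."""
        return (
            isinstance(body, SBlockRealize)                            # 1
            and not body.block.iter_vars                               # 2
            and not isinstance(body.block.body, (For, SBlockRealize))  # 3
        )


    def wrap_bare_block(body):
        """Push a bare leaf block under a synthetic root and a unit loop."""
        if not needs_wrap(body):
            return body  # unchanged, like returning `func` in the C++
        # Give the original block one 1-extent data-parallel iter "vu"
        # and bind it to the new loop var "u", keeping the counts equal.
        inner_block = replace(body.block, iter_vars=(("vu", 1),))
        inner_realize = SBlockRealize(inner_block, iter_values=("u",))
        return SBlockRealize(SBlock("root", For("u", 0, 1, inner_realize)))


    scalar = SBlockRealize(SBlock("scalar_add", Store("c[()] = a[()] + b[()]")))
    wrapped = wrap_bare_block(scalar)
    inner = wrapped.block.body.body
    print(wrapped.block.name, wrapped.block.body.extent, inner.block.name)
    # root 1 scalar_add
    print(len(inner.iter_values) == len(inner.block.iter_vars))  # True
    print(wrap_bare_block(wrapped) is wrapped)  # True
    ```

    Because the wrapped root block now holds a `For`, condition 3 rejects
    it, so applying the rewrite a second time is a no-op.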
    
    The synthesised 1-extent data-parallel iter keeps
    `iter_values.size() == iter_vars.size()` for downstream checks, and the
    new For loop gives `ThreadBind` a real loop to bind to `blockIdx.x` /
    `threadIdx.x`. Already-scheduled functions and host-only PrimFuncs are
    skipped via the existing `IsScheduledOnGPU` / `kIsScheduled` gating.
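
    The module-level gating can be modelled the same way. `Func` and
    `wrap_module` are hypothetical: the two booleans stand in for
    `IsScheduledOnGPU` and the `kIsScheduled` attribute, and identity
    comparison stands in for `same_as`. As in the pass, the module is
    only rebuilt when at least one function actually changed.

    ```python
    from typing import Callable, Dict, NamedTuple


    class Func(NamedTuple):
        """Hypothetical PrimFunc stand-in holding only what the gate reads."""
        body: object
        on_gpu: bool     # stands in for IsScheduledOnGPU(func)
        scheduled: bool  # stands in for the kIsScheduled attribute


    def wrap_module(functions: Dict[str, Func],
                    wrap: Callable[[Func], Func]) -> Dict[str, Func]:
        out: Dict[str, Func] = {}
        any_wrapped = False
        for name, func in functions.items():
            if func.on_gpu and not func.scheduled:
                new_func = wrap(func)
                if new_func is not func:
                    out[name] = new_func
                    any_wrapped = True
                    continue
            out[name] = func
        # Only hand back a new module when something changed.
        return out if any_wrapped else functions


    def toy_wrap(f: Func) -> Func:
        # Pretend only "scalar" bodies need wrapping.
        return f._replace(body=("root", f.body)) if f.body == "scalar" else f


    mod = {
        "gpu_scalar": Func("scalar", on_gpu=True, scheduled=False),
        "host_scalar": Func("scalar", on_gpu=False, scheduled=False),
        "done_scalar": Func("scalar", on_gpu=True, scheduled=True),
        "gpu_loop": Func("loop", on_gpu=True, scheduled=False),
    }
    new_mod = wrap_module(mod, toy_wrap)
    print(new_mod["gpu_scalar"].body)  # ('root', 'scalar')
    print([n for n in mod if new_mod[n] is mod[n]])
    # ['host_scalar', 'done_scalar', 'gpu_loop']
    ```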
    
    ## Testing
    
    ```
    pytest tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py
    ```
    
    10 passed (9 existing + 1 new `test_scalar_block_no_loops`). End-to-end
    compile + execute on RTX 3080 (sm_86): the scalar repro returns the
    expected `2.0 + 3.0 = 5.0`.
---
 src/s_tir/transform/default_gpu_schedule.cc        | 70 ++++++++++++++++++++++
 .../test_s_tir_transform_default_gpu_schedule.py   | 34 +++++++++++
 2 files changed, 104 insertions(+)

diff --git a/src/s_tir/transform/default_gpu_schedule.cc b/src/s_tir/transform/default_gpu_schedule.cc
index d41e2f5843..b130cbfe45 100644
--- a/src/s_tir/transform/default_gpu_schedule.cc
+++ b/src/s_tir/transform/default_gpu_schedule.cc
@@ -103,6 +103,55 @@ IRModule MarkScheduled(const IRModule& mod) {
                   mod->global_infos);  // global_infos
 }
 
+/*!
+ * \brief Wrap a PrimFunc body that is a bare \c SBlockRealize (no enclosing
+ * loops, no iter vars) so the realized block is no longer the function's root
+ * sref.
+ *
+ * Without this, \c ThreadBind below calls \c Schedule::AddUnitLoop(block) on
+ * a block that is itself the prim_func's root sref, hitting the
+ * "Cannot add loops on top of the root block" check in
+ * \c s_tir::AddUnitLoop. The schedule infrastructure additionally requires
+ * the prim_func body to be an \c SBlockRealize, so we keep that shape and
+ * push the original block one level deeper, inside a wrapping root block
+ * that holds a unit serial loop. The synthesised data-parallel iter keeps
+ * iter_values/iter_vars counts consistent for downstream checks.
+ */
+tirx::PrimFunc WrapBareSBlockBody(const tirx::PrimFunc& func) {
+  const auto* realize = func->body.as<tirx::SBlockRealizeNode>();
+  if (realize == nullptr || !realize->block->iter_vars.empty()) {
+    return func;
+  }
+  // Only wrap when the block is a leaf computation. A well-formed PrimFunc
+  // produced by the rest of the pipeline has an implicit root SBlockRealize
+  // whose block body is a For loop (or a nested SBlockRealize) — that case
+  // already has somewhere to put thread bindings, so leave it alone.
+  const tirx::Stmt& inner = realize->block->body;
+  if (inner->IsInstance<tirx::ForNode>() || inner->IsInstance<tirx::SBlockRealizeNode>()) {
+    return func;
+  }
+  tvm::IntImm zero(tvm::DataType::Int(32), 0);
+  tvm::IntImm one(tvm::DataType::Int(32), 1);
+  tirx::Var loop_var("u", tvm::DataType::Int(32));
+  tirx::Var iter_var_var("vu", tvm::DataType::Int(32));
+  tirx::IterVar new_iter(tvm::Range::FromMinExtent(zero, one), iter_var_var,
+                         tirx::IterVarType::kDataPar);
+  tirx::SBlock inner_block = realize->block;
+  inner_block.CopyOnWrite()->iter_vars = ffi::Array<tirx::IterVar>{new_iter};
+  tirx::SBlockRealize inner_realize(/*iter_values=*/ffi::Array<tvm::PrimExpr>{loop_var},
+                                    /*predicate=*/realize->predicate, inner_block);
+  tirx::Stmt for_stmt = tirx::For(loop_var, zero, one, tirx::ForKind::kSerial, inner_realize);
+  tirx::SBlock root_block(/*iter_vars=*/ffi::Array<tirx::IterVar>{},
+                          /*reads=*/ffi::Array<tirx::BufferRegion>{},
+                          /*writes=*/ffi::Array<tirx::BufferRegion>{},
+                          /*name_hint=*/"root", /*body=*/for_stmt);
+  tirx::SBlockRealize root_realize(/*iter_values=*/ffi::Array<tvm::PrimExpr>{},
+                                   /*predicate=*/tvm::Bool(true), root_block);
+  tirx::PrimFunc result = func;
+  result.CopyOnWrite()->body = std::move(root_realize);
+  return result;
+}
+
 bool IsScheduledOnGPU(const BaseFunc& func) {
   // the target from context.
   tvm::Target target = tvm::Target::Current();
@@ -125,6 +174,27 @@ bool IsScheduledOnGPU(const BaseFunc& func) {
 Pass DefaultGPUSchedule() {
   auto pass_func =  //
       [=](IRModule m, PassContext pc) {
+        // Wrap any GPU-bound PrimFunc whose body is a bare SBlockRealize
+        // (e.g. a scalar op) so ThreadBind below has a loop to operate on.
+        ffi::Map<GlobalVar, BaseFunc> wrapped;
+        bool any_wrapped = false;
+        for (const auto& [gv, base_func] : m->functions) {
+          if (const auto* prim_func_node = base_func.as<tirx::PrimFuncNode>();
+              prim_func_node != nullptr && IsScheduledOnGPU(base_func) &&
+              !base_func->HasNonzeroAttr(tirx::attr::kIsScheduled)) {
+            tirx::PrimFunc func = ffi::GetRef<tirx::PrimFunc>(prim_func_node);
+            tirx::PrimFunc new_func = WrapBareSBlockBody(func);
+            if (!new_func.same_as(func)) {
+              wrapped.Set(gv, new_func);
+              any_wrapped = true;
+              continue;
+            }
+          }
+          wrapped.Set(gv, base_func);
+        }
+        if (any_wrapped) {
+          m = IRModule(wrapped, m->source_map, m->attrs, m->global_infos);
+        }
         s_tir::Schedule sch = s_tir::Schedule::Traced(m, /*seed=*/-1, /*debug_mask=*/0,
                                                       s_tir::ScheduleErrorRenderLevel::kDetail);
         for (const auto& [gv, func] : m->functions) {
diff --git a/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py b/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py
index c562a29e87..f08dba00d6 100644
--- a/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py
@@ -567,5 +567,39 @@ def test_sum():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_scalar_block_no_loops():
+    # A PrimFunc whose body is a bare SBlockRealize (e.g. a fully-scalar op)
+    # used to crash DefaultGPUSchedule with "Cannot add loops on top of the
+    # root block" because the realized block was the function's root sref.
+    # pylint: disable=no-self-argument,missing-class-docstring,line-too-long
+    # fmt: off
+    @tvm.script.ir_module
+    class Before:
+        @T.prim_func
+        def scalar_add(a: T.Buffer((), "float32"), b: T.Buffer((), "float32"), c: T.Buffer((), "float32")):
+            with T.sblock("scalar_add"):
+                c[()] = a[()] + b[()]
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def scalar_add(a: T.Buffer((), "float32"), b: T.Buffer((), "float32"), c: T.Buffer((), "float32")):
+            T.func_attr({"tirx.is_scheduled": True})
+            # with T.sblock("root"):
+            for u_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
+                for u_fused_1 in T.thread_binding(1, thread="threadIdx.x"):
+                    with T.sblock("scalar_add"):
+                        vu = T.axis.spatial(1, 0)
+                        T.reads()
+                        T.writes()
+                        c[()] = a[()] + b[()]
+    # fmt: on
+    # pylint: enable=no-self-argument,missing-class-docstring,line-too-long
+    target = tvm.target.Target("nvidia/geforce-rtx-3070")
+    with target, tvm.transform.PassContext(opt_level=0):
+        mod = DefaultGPUSchedule()(Before)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
