This is an automated email from the ASF dual-hosted git repository.
bohan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ab25b49225 [TIR] Fix InjectPTXLDG32 segfaults and skip non-CUDA targets (#18671)
ab25b49225 is described below
commit ab25b49225cfc4b91171f111578bfb5906cabae1
Author: YinHanke <[email protected]>
AuthorDate: Wed Jan 28 00:51:17 2026 +0800
[TIR] Fix InjectPTXLDG32 segfaults and skip non-CUDA targets (#18671)
### Motivation
InjectPTXLDG32 rewrites BufferStore when it encounters if_then_else, but it only initializes its temporary buffers when an Allocate node exists. For functions without an Allocate, this leads to uninitialized buffers and a hard segfault during compilation.

In addition, the PTX-only pass can run on CPU/LLVM targets when tir.ptx_ldg32=1, injecting PTX intrinsics that are invalid for non-CUDA codegen.

This PR ensures temporary buffers are created even when no Allocate exists, and skips InjectPTXLDG32 on non-CUDA targets, preventing segfaults and invalid PTX intrinsics on CPU.
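For reference, a minimal sketch of the failing pattern, adapted from the new test added in this PR: a PrimFunc that contains an if_then_else rewrite site but no Allocate node.

```python
import tvm
from tvm.script import tir as T


@T.prim_func
def where_no_alloc(A: T.Buffer((4,), "float32"), C: T.Buffer((4,), "float32")) -> None:
    T.func_attr({"global_symbol": "main", "tir.noalias": True, "target": T.target("cuda")})
    for i in range(4):
        # if_then_else triggers the BufferStore rewrite, but there is no
        # Allocate node for the pass's AllocateNode visitor to hook into.
        C[i] = T.if_then_else(A[i] > T.float32(0), A[i], T.float32(0))


mod = tvm.IRModule.from_expr(where_no_alloc)
# Before this fix, the rewrite used uninitialized addr/predicate buffers
# here and could segfault during compilation.
mod = tvm.tir.transform.InjectPTXLDG32()(mod)
```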
### Changes
- Ensure temp buffers are created when the rewrite path is taken without Allocate
- Insert allocations at the function level when needed
- Guard InjectPTXLDG32 so it only runs on CUDA targets (sketched below)
- Add tests for CUDA (insertion) and CPU (skip) behavior
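As a sketch of the guard (reusing `where_no_alloc` from the snippet above and mirroring the new CPU test), retargeting the same function to llvm now makes the pass a no-op:

```python
# With the guard, the pass leaves any PrimFunc alone unless its "target"
# attribute is a CUDA target.
llvm_func = where_no_alloc.with_attr("target", tvm.target.Target("llvm"))
llvm_mod = tvm.IRModule({"main": llvm_func})
llvm_mod = tvm.tir.transform.InjectPTXLDG32()(llvm_mod)
# llvm_mod["main"] comes back unchanged: no tir.ptx_ldg32 calls and no
# injected temp-buffer allocations.
```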
### Testing
tests/python/tir-transform/test_tir_transform_inject_ptx_ldg32.py (new file; runs under pytest)
### Fixes
- [#18612](https://github.com/apache/tvm/issues/18612)
- [#18617](https://github.com/apache/tvm/issues/18617)
- [#18599](https://github.com/apache/tvm/issues/18599)
---
src/tir/transforms/inject_ptx_ldg32.cc | 44 +++++++++---
.../test_tir_transform_inject_ptx_ldg32.py | 80 ++++++++++++++++++++++
2 files changed, 115 insertions(+), 9 deletions(-)
diff --git a/src/tir/transforms/inject_ptx_ldg32.cc b/src/tir/transforms/inject_ptx_ldg32.cc
index 8cdef1be44..f52539fa77 100644
--- a/src/tir/transforms/inject_ptx_ldg32.cc
+++ b/src/tir/transforms/inject_ptx_ldg32.cc
@@ -35,16 +35,22 @@ namespace tir {
 class PTXRewriter : public StmtMutator {
  public:
-  Stmt VisitStmt_(const AllocateNode* allocate) final {
-    if (!has_buffer_1) {
-      has_buffer_1 = true;
-      // addr[0] -> global_addr / addr[1] -> local_addr
-      addr_buffer = decl_buffer({IntImm(DataType::Int(32), 2)}, DataType::Int(32), "addr", "local");
-      predicate_buffer =
-          decl_buffer({IntImm(DataType::Int(32), 1)}, DataType::Bool(), "predicate", "local");
+  Stmt AddAllocationsIfNeeded(Stmt body) {
+    if (!needs_buffer || has_buffer_2) {
+      return body;
     }
+    EnsureBuffers();
+    body = Allocate(addr_buffer->data, addr_buffer->dtype, addr_buffer->shape, Bool(true), body);
+    body = Allocate(predicate_buffer->data, predicate_buffer->dtype, predicate_buffer->shape,
+                    Bool(true), body);
+    has_buffer_2 = true;
+    return body;
+  }
+
+  Stmt VisitStmt_(const AllocateNode* allocate) final {
     Stmt result = StmtMutator::VisitStmt_(allocate);
-    if (!has_buffer_2) {
+    if (needs_buffer && !has_buffer_2) {
+      EnsureBuffers();
       has_buffer_2 = true;
       result =
           Allocate(addr_buffer->data, addr_buffer->dtype, addr_buffer->shape, Bool(true), result);
@@ -82,6 +88,8 @@ class PTXRewriter : public StmtMutator {
     if (ramp != nullptr) {
       return result;
     }
+    EnsureBuffers();
+    needs_buffer = true;
     local_addr = store->indices[0];
     BufferStore addr_store(addr_buffer, global_addr, {IntImm(DataType::Int(32), 0)});
     BufferStore local_addr_store(addr_buffer, local_addr, {IntImm(DataType::Int(32), 1)});
@@ -104,7 +112,19 @@ class PTXRewriter : public StmtMutator {
     return result;
   }
 
+  void EnsureBuffers() {
+    if (has_buffer_1) {
+      return;
+    }
+    has_buffer_1 = true;
+    // addr[0] -> global_addr / addr[1] -> local_addr
+    addr_buffer = decl_buffer({IntImm(DataType::Int(32), 2)}, DataType::Int(32), "addr", "local");
+    predicate_buffer =
+        decl_buffer({IntImm(DataType::Int(32), 1)}, DataType::Bool(), "predicate", "local");
+  }
+
   bool has_buffer_1 = false, has_buffer_2 = false;
+  bool needs_buffer = false;
   Buffer addr_buffer, predicate_buffer;
 };
@@ -113,8 +133,14 @@ namespace transform {
 Pass InjectPTXLDG32(bool enable_inject_ptx_intrin) {
   auto pass_func = [enable_inject_ptx_intrin](PrimFunc f, IRModule m, PassContext ctx) {
     if (enable_inject_ptx_intrin) {
+      auto target = f->GetAttr<Target>("target");
+      if (!target.defined() || target.value()->kind->name != "cuda") {
+        return f;
+      }
       auto* n = f.CopyOnWrite();
-      n->body = PTXRewriter()(n->body);
+      PTXRewriter rewriter;
+      Stmt body = rewriter(n->body);
+      n->body = rewriter.AddAllocationsIfNeeded(body);
       // inject ptx
     }
     return f;
diff --git a/tests/python/tir-transform/test_tir_transform_inject_ptx_ldg32.py b/tests/python/tir-transform/test_tir_transform_inject_ptx_ldg32.py
new file mode 100644
index 0000000000..55099f252c
--- /dev/null
+++ b/tests/python/tir-transform/test_tir_transform_inject_ptx_ldg32.py
@@ -0,0 +1,80 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm.script import tir as T
+
+
+def _count_alloc(stmt):
+    num_alloc = [0]
+
+    def visit(n):
+        if isinstance(n, tvm.tir.Allocate):
+            num_alloc[0] += 1
+
+    tvm.tir.stmt_functor.post_order_visit(stmt, visit)
+    return num_alloc[0]
+
+
+def _count_ptx_ldg32(stmt):
+    num_call = [0]
+
+    def visit(n):
+        if isinstance(n, tvm.tir.Call) and n.op.name == "tir.ptx_ldg32":
+            num_call[0] += 1
+
+    tvm.tir.stmt_functor.post_order_visit(stmt, visit)
+    return num_call[0]
+
+
[email protected]_func
+def where_no_alloc(A: T.Buffer((4,), "float32"), C: T.Buffer((4,), "float32"))
-> None:
+ T.func_attr({"global_symbol": "main", "tir.noalias": True, "target":
T.target("cuda")})
+ for i in range(4):
+ C[i] = T.if_then_else(A[i] > T.float32(0), A[i], T.float32(0))
+
+
[email protected]_func
+def where_no_alloc_cpu(A: T.Buffer((4,), "float32"), C: T.Buffer((4,),
"float32")) -> None:
+ T.func_attr({"global_symbol": "main", "tir.noalias": True, "target":
T.target("llvm")})
+ for i in range(4):
+ C[i] = T.if_then_else(A[i] > T.float32(0), A[i], T.float32(0))
+
+
+def test_inject_ptx_ldg32_inserts_alloc_for_no_alloc_func():
+    mod = tvm.IRModule.from_expr(where_no_alloc)
+    assert _count_alloc(mod["main"].body) == 0
+
+    mod = tvm.tir.transform.InjectPTXLDG32()(mod)
+    assert _count_alloc(mod["main"].body) > 0
+    assert _count_ptx_ldg32(mod["main"].body) == 1
+
+
+def test_inject_ptx_ldg32_skip_non_cuda_target():
+    mod = tvm.IRModule.from_expr(where_no_alloc_cpu)
+    cpu_target = tvm.target.Target("llvm")
+    mod = tvm.IRModule({"main": mod["main"].with_attr("target", cpu_target)})
+    assert _count_alloc(mod["main"].body) == 0
+
+    mod = tvm.tir.transform.InjectPTXLDG32()(mod)
+    assert _count_alloc(mod["main"].body) == 0
+    assert _count_ptx_ldg32(mod["main"].body) == 0
+
+
+if __name__ == "__main__":
+    tvm.testing.main()