gemini-code-assist[bot] commented on code in PR #18347:
URL: https://github.com/apache/tvm/pull/18347#discussion_r2383895791
##########
python/tvm/tir/pipeline.py:
##########
@@ -120,6 +120,21 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
tir.transform.LowerDeviceKernelLaunch(),
]
)
+    # Optional GPU verification based on PassContext configuration.
+    # Usage example:
+    #
+    #     with tvm.transform.PassContext(config={
+    #         "tir.verify_gpu_code": True,
+    #         # Optional per-device cap override; for Ampere (e.g., RTX A6000, SM 86),
+    #         # you may choose 96 KB (98304 bytes). The default is a conservative 48 KB.
+    #         # "tir.cuda.max_shared_memory_per_block": 96 * 1024,
+    #     }):
+    #         lib = tvm.tir.build(mod, target="cuda")
+    #
+    # This check is opt-in and does not change defaults.
+    if bool(config.get("tir.verify_gpu_code", False)):
+        cap = int(config.get("tir.cuda.max_shared_memory_per_block", 48 * 1024))
Review Comment:

The `bool()` and `int()` type casts are redundant. `config.get()` with a
default value already returns the correct type, and the value from the first
`get` can be used directly in the `if` condition.
```suggestion
if config.get("tir.verify_gpu_code", False):
cap = config.get("tir.cuda.max_shared_memory_per_block", 48 *
1024)
```
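
For reference, the hunk above stops before the point where `cap` is consumed. One plausible wiring, assuming the PR reuses the existing `tvm.tir.analysis.verify_gpu_code` helper (which accepts a constraints dict such as `{"max_shared_memory_per_block": cap}`), is sketched below; the helper name `_verify_shared_mem_cap` and the error handling are illustrative, not taken from this PR.
```python
# Hedged sketch: apply the opt-in shared-memory check to every PrimFunc in the
# module. Assumes tvm.tir.analysis.verify_gpu_code(func, constraints) returns a
# falsy value when a constraint is violated; the actual wiring in this PR may differ.
from tvm import tir


def _verify_shared_mem_cap(mod, cap):
    for gvar, func in mod.functions.items():
        if not isinstance(func, tir.PrimFunc):
            continue
        ok = tir.analysis.verify_gpu_code(func, {"max_shared_memory_per_block": cap})
        if not ok:
            raise ValueError(
                f"{gvar.name_hint}: shared memory usage exceeds the {cap}-byte cap"
            )
```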
##########
tests/python/relay/test_cuda_shared_mem_overflow.py:
##########
@@ -0,0 +1,45 @@
+"""
+Tests for verifying the shared memory limit using a deterministic TIR kernel.
+
+To enable optional pipeline verification during normal builds (opt-in, off by default):
+
+    with tvm.transform.PassContext(config={
+        "tir.verify_gpu_code": True,
+        # For Ampere (e.g., RTX A6000, SM 86), you may set a higher cap, e.g., 96 KB:
+        # "tir.cuda.max_shared_memory_per_block": 96 * 1024,
+        # By default, leave unset to use a conservative 48 KB.
+    }):
+        lib = tvm.tir.build(mod, target="cuda")
+
+This test avoids schedule/lowering variability by using a direct kernel that allocates a
+64 KB shared buffer and asserts the verifier fails when the cap is 48 KB.
+"""
+
+import pytest
+
+import tvm
+from tvm import tir
+from tvm.script import tir as T
+
+
[email protected]_func
+def _pf_direct_kernel_shared_large(A: T.handle) -> None:
+    T.func_attr({"global_symbol": "_pf_direct_kernel_shared_large", "tir.noalias": True})
+    A_buf = T.match_buffer(A, (1,), dtype="float32")
+    blockIdx_x = T.launch_thread("blockIdx.x", 1)
+    threadIdx_x = T.launch_thread("threadIdx.x", 1)
+    # 16384 float32 elements = 64 KB shared allocation
+    sh = T.allocate([16384], "float32", "shared")
+    s = T.float32(0)
+    for t in T.serial(0, 16384):
+        s = s + T.float32(1)
Review Comment:

The allocated shared memory buffer `sh` is not used in the loop. While this
works for the verifier, using the buffer would make the test's intent clearer
and more robust against potential dead-code elimination if this function were
used in a context with more optimizations. For example:
```python
    for t in T.serial(0, 16384):
        sh[t] = T.float32(1)
        s = s + sh[t]
```
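
As a side note, the hunk is truncated before the actual assertion. A hedged sketch of how such an assertion could look, assuming `tvm.tir.analysis.verify_gpu_code` returns a falsy result when a constraint is violated; the test name and cap handling below are illustrative, not taken from this PR:
```python
# Hedged sketch: check that the 64 KB shared allocation above is rejected under
# a 48 KB per-block cap. Assumes verify_gpu_code(func, constraints) returns a
# falsy value on violation; the real test body is not shown in the hunk above.
def test_shared_mem_overflow_rejected_at_48kb():
    valid = tvm.tir.analysis.verify_gpu_code(
        _pf_direct_kernel_shared_large,
        {"max_shared_memory_per_block": 48 * 1024},
    )
    assert not valid, "a 64 KB shared allocation should exceed a 48 KB cap"
```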
##########
tests/python/relay/test_cuda_shared_mem_overflow.py:
##########
@@ -0,0 +1,45 @@
+"""
Review Comment:

This test appears to be for TIR analysis rather than Relay. To improve
project structure and discoverability, consider moving this file to a more
appropriate location, such as under `tests/python/tir/`.