This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 30bf568dd0 [Tests] Check WebGPU volatile allreduce annotation structurally (#19740)
30bf568dd0 is described below

commit 30bf568dd0d1a61e622ac84dda49486292577c92
Author: Shushi Hong <[email protected]>
AuthorDate: Fri Jun 12 00:38:19 2026 -0400

    [Tests] Check WebGPU volatile allreduce annotation structurally (#19740)
    
    This PR updates the WebGPU multi-warp allreduce test to check the
    generated `tirx.volatile` allocation annotation structurally instead of
    matching the exact TVMScript printer output.
    
    The test is intended to verify that `LowerThreadAllreduce` marks the
    generated shared allocation as volatile. It previously checked for the
    exact string:
    
    ```python
    "tirx.volatile": T.bool(True)
    ```
    
    However, the current printer emits the same annotation as:
    
    ```python
    annotations={"tirx.volatile": True}
    ```
    
    The transform behavior is unchanged; only the printer spelling differs.
    This patch walks the generated TIRX body and checks for an `AllocBuffer`
    whose `tirx.volatile` annotation is `True`, which matches the actual
    semantic requirement of the test without depending on bool literal
    formatting.
---
 .../test_s_tir_transform_lower_thread_all_reduce.py        | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py b/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py
index f39ccb6fde..b719416e62 100644
--- a/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py
@@ -23,6 +23,18 @@ from tvm.script import ir as I
 from tvm.script import tirx as T
 
 
+def _has_volatile_alloc_buffer(mod):
+    has_volatile_alloc = False
+
+    def visit(node):
+        nonlocal has_volatile_alloc
+        if isinstance(node, tvm.tirx.AllocBuffer) and "tirx.volatile" in node.annotations:
+            has_volatile_alloc = has_volatile_alloc or node.annotations["tirx.volatile"] is True
+
+    tvm.tirx.stmt_functor.post_order_visit(mod["main"].body, visit)
+    return has_volatile_alloc
+
+
 def test_basic():
     transform = tvm.s_tir.transform.LowerThreadAllreduce()
 
@@ -503,7 +515,7 @@ def test_webgpu_multi_warp_reduce():
     After_script = After.script()
     assert "tvm_warp_shuffle_down" in After_script
     assert "tvm_storage_sync" in After_script
-    assert '"tirx.volatile": T.bool(True)' in After_script
+    assert _has_volatile_alloc_buffer(After)
     assert "T.uint32(" not in After_script
 
 

Reply via email to