This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 39f2482580 [Fix] Fix SSA conversion for SizeVar retention (#16924)
39f2482580 is described below

commit 39f2482580b57fa5b1f6c1a1dc0e6f5e823ee4c0
Author: Ruihang Lai <ruiha...@cs.cmu.edu>
AuthorDate: Thu Apr 25 08:11:46 2024 -0400

    [Fix] Fix SSA conversion for SizeVar retention (#16924)
    
    This PR fixes the var construction in IRConvertSSA, which always casts
    SizeVar to Var. This behavior leads to expr not being able to get
    simplified in the LowerIntrin pass later on. Specifically, if not using
    SizeVar, the LowerIntrin pass loses the non-negative var
    information, and cannot simplify a bunch of FloorDiv/FloorMod
    expressions.
    
    One regression test for SplitHostDevice is added to ensure the retention
    of SizeVar. The test is added in SplitHostDevice because this is where
    the SSA conversion is used.
---
 src/tir/transforms/ir_utils.cc                     | 13 +++++++++--
 .../test_tir_transform_split_host_device.py        | 25 ++++++++++++++++++++--
 2 files changed, 34 insertions(+), 4 deletions(-)

diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index 584b3cbf58..c52027acba 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -435,10 +435,19 @@ class IRConvertSSA final : public StmtExprMutator {
  private:
   struct ScopedRedefine {
     ScopedRedefine(IRConvertSSA* parent, Var old_var) : parent(parent), 
old_var(old_var) {
+      bool is_size_var = old_var->IsInstance<SizeVarNode>();
       if (old_var->type_annotation.defined()) {
-        new_var = Var(old_var->name_hint, old_var->type_annotation);
+        if (is_size_var) {
+          new_var = SizeVar(old_var->name_hint, old_var->type_annotation);
+        } else {
+          new_var = Var(old_var->name_hint, old_var->type_annotation);
+        }
       } else {
-        new_var = Var(old_var->name_hint, old_var->dtype);
+        if (is_size_var) {
+          new_var = SizeVar(old_var->name_hint, old_var->dtype);
+        } else {
+          new_var = Var(old_var->name_hint, old_var->dtype);
+        }
       }
       parent->scope_[old_var.get()].push_back(new_var);
     }
diff --git a/tests/python/tir-transform/test_tir_transform_split_host_device.py 
b/tests/python/tir-transform/test_tir_transform_split_host_device.py
index 6adfbeb81d..2d0d8a68d8 100644
--- a/tests/python/tir-transform/test_tir_transform_split_host_device.py
+++ b/tests/python/tir-transform/test_tir_transform_split_host_device.py
@@ -15,9 +15,10 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
-from tvm import te
 import tvm.testing
-from tvm.script import tir as T, ir as I
+from tvm import te
+from tvm.script import ir as I
+from tvm.script import tir as T
 
 
 @tvm.testing.requires_cuda
@@ -345,5 +346,25 @@ def test_dynamic_launch_thread():
     tvm.ir.assert_structural_equal(expected, after)
 
 
+def test_size_var():
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def main(var_A: T.handle, var_B: T.handle):
+            T.func_attr({"target": T.target("cuda")})
+            m = T.int64(is_size_var=True)
+            A = T.match_buffer(var_A, (m,))
+            B = T.match_buffer(var_B, (m,))
+            T.attr(T.target("cuda"), "target", 0)
+            blockIdx_x = T.launch_thread("blockIdx.x", m)
+            B_1 = T.Buffer((m,), data=B.data)
+            A_1 = T.Buffer((m,), data=A.data)
+            B_1[blockIdx_x] = A_1[blockIdx_x]
+
+    after = tvm.tir.transform.SplitHostDevice()(Module)
+    assert len(after["main_kernel"].params) == 3
+    assert isinstance(after["main_kernel"].params[2], tvm.tir.SizeVar)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to