This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 66e18fbe1f [Bugfix][TVMScript] Handle LetStmt for `var1 = var2` expressions (#14320)
66e18fbe1f is described below

commit 66e18fbe1ff03c08cb67dc32fc9b2448a247a60f
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sun Apr 2 18:09:51 2023 -0500

    [Bugfix][TVMScript] Handle LetStmt for `var1 = var2` expressions (#14320)
    
    * [Bugfix][TVMScript] Handle LetStmt for `var1 = var2` expressions
    
    Usually, when TVMScript is used to represent a `PrimFunc`, a variable
    definition `var_name = expr` defines a `LetStmt` with a variable named
    `var_name` bound to the expression `expr`.  However, prior to this
    commit, if `expr` was a `tir::Var`, the TVMScript parser would instead
    silently omit the `LetStmt` and rename all instances of that variable
    to `var_name`.
    
    The root cause was in the `VarTable.exist` check, which erroneously
    returned False in all cases.  This was due to a `value is v` check,
    which compared the value against the entire stack of maybe-shadowing
    values that share the same name, rather than against each element of
    that stack.  Replacing the `value is v` check with a `value in v` check
    resolves this issue.
    
    This bug dates to the initial implementation of the new TVMScript
    parser in https://github.com/apache/tvm/pull/12496.
    
    * Avoid implicit `PrimExpr.__bool__` from `if value in value_stack` by comparing each stack entry with `same_as` instead of relying on `in`
    
    * Use `T.meta_var` where variable renaming is required (the tuple-unpacked `maybe_swap` results in `tensor_intrin/cuda.py`).
---
 python/tvm/script/parser/core/parser.py            |  9 ++++----
 python/tvm/tir/tensor_intrin/cuda.py               |  4 ++--
 .../python/unittest/test_tvmscript_syntax_sugar.py | 25 ++++++++++++++++++++++
 3 files changed, 32 insertions(+), 6 deletions(-)

diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py
index 7c699c42ae..fdccabcd23 100644
--- a/python/tvm/script/parser/core/parser.py
+++ b/python/tvm/script/parser/core/parser.py
@@ -186,10 +186,11 @@ class VarTable:
         res : bool
             The existence of the value.
         """
-        for v in self.name2value.values():
-            if v is value:
-                return True
-        return False
+        return any(
+            value.same_as(known_value)
+            for known_value_stack in self.name2value.values()
+            for known_value in known_value_stack
+        )
 
 
 def _dispatch_wrapper(func: dispatch.ParseMethod) -> dispatch.ParseMethod:
diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py
index da194f885d..3bc16f234f 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -245,7 +245,7 @@ def get_mma_intrin(k_dim, out_dtype, b_transposed):
             for i, j, k in T.grid(M_DIM, N_DIM, k_dim):
                 with T.block("C"):
                     i, j, k = T.axis.remap("SSR", [i, j, k])
-                    b_row_ind, b_col_ind = maybe_swap(k, j)
+                    b_row_ind, b_col_ind = T.meta_var(maybe_swap(k, j))
 
                     thread_id_C, local_id_C = T.meta_var(index_map_C(i, j))
                     thread_id_A, local_id_A = T.meta_var(index_map_A(i, k))
@@ -719,7 +719,7 @@ def get_wmma_sync_intrin(
             for i, j, k in T.grid(m_dim, n_dim, k_dim):
                 with T.block(""):
                     vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
-                    B_index_0, B_index_1 = maybe_swap(vkk, vjj)
+                    B_index_0, B_index_1 = T.meta_var(maybe_swap(vkk, vjj))
                     C[vii, vjj] = C[vii, vjj] + maybe_cast(A[vii, vkk]) * 
maybe_cast(
                         B[B_index_0, B_index_1]
                     )
diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py
index 184722cd36..ac1262b9b5 100644
--- a/tests/python/unittest/test_tvmscript_syntax_sugar.py
+++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py
@@ -399,6 +399,31 @@ def test_implicit_evaluate_call_extern():
     assert_structural_equal(implicit, explicit)
 
 
+def test_preserve_trivial_let_binding():
+    @T.prim_func
+    def explicit(i: T.int32):
+        j = T.int32()
+        T.LetStmt(i, var=j)
+        T.evaluate(j)
+
+    @T.prim_func
+    def implicit(i: T.int32):
+        j = i
+        T.evaluate(j)
+
+    assert_structural_equal(implicit, explicit)
+
+
+def test_preserve_parameter_name():
+    @T.prim_func
+    def func(i: T.int32):
+        j = i
+        T.evaluate(j)
+
+    param_name = func.params[0].name
+    assert param_name == "i"
+
+
 def test_preserve_variable_name():
     """Use variable name when generating tir::LetStmt"""
 

Reply via email to