This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 66e18fbe1f [Bugfix][TVMScript] Handle LetStmt for `var1 = var2`
expressions (#14320)
66e18fbe1f is described below
commit 66e18fbe1ff03c08cb67dc32fc9b2448a247a60f
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sun Apr 2 18:09:51 2023 -0500
[Bugfix][TVMScript] Handle LetStmt for `var1 = var2` expressions (#14320)
* [Bugfix][TVMScript] Handle LetStmt for `var1 = var2` expressions
Usually, when using TVMScript to represent a `PrimFunc`, a variable
definition `var_name = expr` defines a `LetStmt` with a variable named
`var_name` bound to the expression `expr`. However, prior to this
commit, if `expr` was a `tir::Var`, the TVMScript parser would instead
silently omit the `LetStmt` and rename all instances of that variable
to `var_name`.
The root cause was in the `VarTable.exist` check, which erroneously
returned False in all cases. This was due to a `value is v` check,
which compared the value against the entire stack of possibly-shadowing
values that share the same name, rather than against the individual
values in that stack. Replacing the `value is v` check with a
`value in v` check resolves this issue.
This bug dates to the initial implementation of the new TVMScript
parser in https://github.com/apache/tvm/pull/12496.
* Avoid the implicit `PrimExpr.__bool__` call triggered by
`if value in value_stack`; compare against each stacked value with
`same_as` instead.
* Use `T.meta_var` where variable renaming is required.
---
python/tvm/script/parser/core/parser.py | 9 ++++----
python/tvm/tir/tensor_intrin/cuda.py | 4 ++--
.../python/unittest/test_tvmscript_syntax_sugar.py | 25 ++++++++++++++++++++++
3 files changed, 32 insertions(+), 6 deletions(-)
diff --git a/python/tvm/script/parser/core/parser.py
b/python/tvm/script/parser/core/parser.py
index 7c699c42ae..fdccabcd23 100644
--- a/python/tvm/script/parser/core/parser.py
+++ b/python/tvm/script/parser/core/parser.py
@@ -186,10 +186,11 @@ class VarTable:
res : bool
The existence of the value.
"""
- for v in self.name2value.values():
- if v is value:
- return True
- return False
+ return any(
+ value.same_as(known_value)
+ for known_value_stack in self.name2value.values()
+ for known_value in known_value_stack
+ )
def _dispatch_wrapper(func: dispatch.ParseMethod) -> dispatch.ParseMethod:
diff --git a/python/tvm/tir/tensor_intrin/cuda.py
b/python/tvm/tir/tensor_intrin/cuda.py
index da194f885d..3bc16f234f 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -245,7 +245,7 @@ def get_mma_intrin(k_dim, out_dtype, b_transposed):
for i, j, k in T.grid(M_DIM, N_DIM, k_dim):
with T.block("C"):
i, j, k = T.axis.remap("SSR", [i, j, k])
- b_row_ind, b_col_ind = maybe_swap(k, j)
+ b_row_ind, b_col_ind = T.meta_var(maybe_swap(k, j))
thread_id_C, local_id_C = T.meta_var(index_map_C(i, j))
thread_id_A, local_id_A = T.meta_var(index_map_A(i, k))
@@ -719,7 +719,7 @@ def get_wmma_sync_intrin(
for i, j, k in T.grid(m_dim, n_dim, k_dim):
with T.block(""):
vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
- B_index_0, B_index_1 = maybe_swap(vkk, vjj)
+ B_index_0, B_index_1 = T.meta_var(maybe_swap(vkk, vjj))
C[vii, vjj] = C[vii, vjj] + maybe_cast(A[vii, vkk]) *
maybe_cast(
B[B_index_0, B_index_1]
)
diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py
b/tests/python/unittest/test_tvmscript_syntax_sugar.py
index 184722cd36..ac1262b9b5 100644
--- a/tests/python/unittest/test_tvmscript_syntax_sugar.py
+++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py
@@ -399,6 +399,31 @@ def test_implicit_evaluate_call_extern():
assert_structural_equal(implicit, explicit)
+def test_preserve_trivial_let_binding():
+ @T.prim_func
+ def explicit(i: T.int32):
+ j = T.int32()
+ T.LetStmt(i, var=j)
+ T.evaluate(j)
+
+ @T.prim_func
+ def implicit(i: T.int32):
+ j = i
+ T.evaluate(j)
+
+ assert_structural_equal(implicit, explicit)
+
+
+def test_preserve_parameter_name():
+ @T.prim_func
+ def func(i: T.int32):
+ j = i
+ T.evaluate(j)
+
+ param_name = func.params[0].name
+ assert param_name == "i"
+
+
def test_preserve_variable_name():
"""Use variable name when generating tir::LetStmt"""