This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5618628586 [TVMScript][Relax] Preserve tir.SizeVar through TVMScript round-trip (#17083)
5618628586 is described below

commit 561862858661aca27ecd6d0d14fb30b03ad9acab
Author: Eric Lunderberg <lunderb...@users.noreply.github.com>
AuthorDate: Thu Jun 13 06:50:20 2024 -0500

    [TVMScript][Relax] Preserve tir.SizeVar through TVMScript round-trip (#17083)
    
    * [TVMScript][Relax] Preserve tir.SizeVar through TVMScript round-trip
    
    Prior to this commit, all symbolic variables were printed identically,
    regardless of whether the underlying variable was a `tir.Var` or
    `tir.SizeVar`.  As a result, numeric simplifications that rely on a
    `tir.SizeVar` being non-negative may be skipped after a round-trip
    through TVMScript.
    
    This commit updates the TVMScript printing and parsing of Relax
    functions to use `var = T.int64(is_size_var=True)` for `tir.SizeVar`,
    matching how `tir.SizeVar` is parsed for TIR functions.  As an added
    benefit, this also allows Relax functions to have `R.Prim` arguments
    with dtypes other than `int64`.  This may be useful in the future,
    for example to specify the fill value for `R.full`.
    
    * Remove the strict=True argument, which is not available until Python 3.10
    
    * lint fix
    
    * Fix breakage in unit tests
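
A minimal sketch of the round-trip this change preserves (assuming a TVM
build that includes this patch; `func.script()` and
`tvm.script.from_source` are the same helpers the round-trip test suite
relies on):

    import tvm
    from tvm import tir
    from tvm.script import relax as R

    N = tir.SizeVar("N", "int64")  # known non-negative by construction

    @R.function
    def func(A: R.Tensor([N], "float16")):
        B: R.Tensor([N], "float16") = A
        return B

    # The printed script now declares `N = T.int64(is_size_var=True)`, so
    # parsing it back recovers a tir.SizeVar rather than a plain tir.Var.
    reparsed = tvm.script.from_source(func.script())
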
---
 python/tvm/script/parser/relax/parser.py           | 46 +++++++++++++++++++---
 src/script/printer/relax/tir.cc                    |  3 +-
 tests/python/tvmscript/test_tvmscript_roundtrip.py | 28 +++++++++++++
 3 files changed, 71 insertions(+), 6 deletions(-)

diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py
index 400c023aa7..08269ddeeb 100644
--- a/python/tvm/script/parser/relax/parser.py
+++ b/python/tvm/script/parser/relax/parser.py
@@ -68,7 +68,14 @@ def bind_assign_value(
                     "Expected the same dtype for TIR vars "
                     f"but got {value.dtype} vs {prev_value.dtype}",
                 )
-            return prev_value
+            if not isinstance(value, type(prev_value)):
+                self.report_error(
+                    node,
+                    f"Expected the same IR type for TIR vars "
+                    f"but existing value {type(value)} is mismatched "
+                    f"to previous {type(prev_value)}",
+                )
+            value = prev_value
         IRBuilder.name(var_name, value)
         return value
 
@@ -144,18 +151,47 @@ def is_recursive(node: doc.FunctionDef) -> bool:
     return False
 
 
+def collect_symbolic_var_from_prelude(
+    self: Parser, node: doc.FunctionDef, symbolic_vars: Dict[str, tir.Var]
+) -> Dict[str, tir.Var]:
+    prelude_vars = {}
+    for stmt in node.body:
+        if isinstance(stmt, doc.Assign) and all(
+            isinstance(target, doc.Name) and target.id in symbolic_vars for target in stmt.targets
+        ):
+            values = self.eval_expr(stmt.value)
+
+            try:
+                iter(values)
+            except TypeError:
+                values = [values]
+
+            assert len(stmt.targets) == len(values)
+            for target, value in zip(stmt.targets, values):
+                name = target.id
+                prelude_vars[name] = value
+
+    return {**symbolic_vars, **prelude_vars}
+
+
 def collect_symbolic_var_from_params(self: Parser, node: doc.FunctionDef) -> None:
     # Collect symbolic vars from parameters
-    symbolic_vars = set()
+    symbolic_vars = {}
     for arg in node.args.args:
         if arg.annotation is None:
             self.report_error(arg, "Type annotation is required for function parameters.")
         param_sinfo_proxy = eval_struct_info_proxy(self, arg.annotation)
-        symbolic_vars.update(param_sinfo_proxy.get_symbolic_vars())
+
+        for var_name in param_sinfo_proxy.get_symbolic_vars():
+            if var_name not in symbolic_vars:
+                symbolic_vars[var_name] = tir.Var(var_name, "int64")
+
+    # Update symbolic vars based on assignments in the function prelude
+    symbolic_vars = collect_symbolic_var_from_prelude(self, node, symbolic_vars)
 
     # Define symbolic vars to the current var_table frame
-    for var_name in symbolic_vars:
-        self.var_table.add(var_name, tir.Var(var_name, "int64"), allow_shadowing=False)
+    for var_name, var in symbolic_vars.items():
+        self.var_table.add(var_name, var, allow_shadowing=False)
 
 
 @dispatch.register(token="relax", type_name="FunctionDef")
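
For reference, a hedged sketch of the prelude form the updated parser
accepts: an assignment at the top of the function body re-binds a symbolic
variable name to the evaluated value (here a `tir.SizeVar` via
`is_size_var=True`), replacing the default `int64` `tir.Var` created from
the parameter annotations:

    from tvm.script import relax as R, tir as T

    @R.function
    def func(A: R.Tensor(["N"], "float16")):
        N = T.int64(is_size_var=True)  # prelude assignment; N becomes a SizeVar
        B: R.Tensor([N], "float16") = A
        return B
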
diff --git a/src/script/printer/relax/tir.cc b/src/script/printer/relax/tir.cc
index 1a9c5d0546..6f9a8cbf89 100644
--- a/src/script/printer/relax/tir.cc
+++ b/src/script/printer/relax/tir.cc
@@ -18,6 +18,7 @@
  */
 #include <tvm/ir/expr.h>
 
+#include "../tir/utils.h"
 #include "./utils.h"
 
 namespace tvm {
@@ -59,7 +60,7 @@ Doc PrintTIRVar(tir::Var n, ObjectPath n_p, IRDocsifier d) {
     }
    IdDoc var = d->Define(n, GetRef<Frame>(f), n->name_hint.empty() ? "v" : n->name_hint);
     var->source_paths.push_back(n_p);
-    f->stmts.push_back(AssignDoc(var, TIR(d, DType2Str(n->dtype))->Call({}), NullOpt));
+    f->stmts.push_back(AssignDoc(var, PrintVarCreation(n, n_p, d), NullOpt));
   }
   if (Optional<ExprDoc> doc = d->GetVarDoc(n)) {
     return doc.value();
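
The practical effect on the printed Relax script (illustrative output; the
`is_size_var=True` form is the one named in the commit message above):

    N = T.int64()                  # emitted for a plain tir.Var
    N = T.int64(is_size_var=True)  # emitted for a tir.SizeVar
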
diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py
index ee404f08ef..f81a80de6d 100644
--- a/tests/python/tvmscript/test_tvmscript_roundtrip.py
+++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py
@@ -4088,6 +4088,32 @@ def relax_match_cast_struct_info_proxy():
         yield make_ir_generator(subclass)
 
 
+def relax_symbolic_size_var():
+    """Relax symbolic variables may be SizeVar"""
+    N = tvm.tir.SizeVar("N", "int64")
+
+    @R.function
+    def func(A: R.Tensor([N], "float16")):
+        B: R.Tensor([N], "float16") = A
+        return B
+
+    return func
+
+
+def relax_float_symbolic_var():
+    """Relax symbolic variables may hold any dtype"""
+
+    @R.function
+    def func(A: R.Tensor(["N"], "float16"), _: R.Prim(value="threshold")):
+        N = T.int64()
+        threshold = T.float16()
+
+        B = A >= R.prim_value(threshold / T.cast(N, "float16"))
+        return B
+
+    return func
+
+
 ir_generator = tvm.testing.parameter(
     launch_env_thread,
     opt_gemm_normalize,
@@ -4174,6 +4200,8 @@ ir_generator = tvm.testing.parameter(
     return_zero_private_with_attr,
     *op_of_literal(),
     *relax_match_cast_struct_info_proxy(),
+    relax_symbolic_size_var,
+    relax_float_symbolic_var,
 )
 
 relax_ir_generator = tvm.testing.parameter(
