This is an automated email from the ASF dual-hosted git repository. lunderberg pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 5618628586 [TVMScript][Relax] Preserve tir.SizeVar through TVMScript round-trip (#17083) 5618628586 is described below commit 561862858661aca27ecd6d0d14fb30b03ad9acab Author: Eric Lunderberg <lunderb...@users.noreply.github.com> AuthorDate: Thu Jun 13 06:50:20 2024 -0500 [TVMScript][Relax] Preserve tir.SizeVar through TVMScript round-trip (#17083) * [TVMScript][Relax] Preserve tir.SizeVar through TVMScript round-trip Prior to this commit, all symbolic variables were printed identically, regardless of whether the underlying variable was a `tir.Var` or `tir.SizeVar`. As a result, numeric simplifications that rely on a `tir.SizeVar` being non-negative may be skipped after a round-trip through TVMScript. This commit updates the TVMScript printing and parsing of Relax functions to use `var = T.int64(is_size_var=True)` for `tir.SizeVar`, matching how `tir.SizeVar` is parsed for TIR functions. As an added benefit, this also allows Relax functions to accept `R.Prim` arguments with dtypes other than `int64`. This may be useful in the future, for example to specify the fill value for `R.full`. 
* Remove strict=True argument, not available until python 3.10 * lint fix * Fix breakage in unit tests --- python/tvm/script/parser/relax/parser.py | 46 +++++++++++++++++++--- src/script/printer/relax/tir.cc | 3 +- tests/python/tvmscript/test_tvmscript_roundtrip.py | 28 +++++++++++++ 3 files changed, 71 insertions(+), 6 deletions(-) diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py index 400c023aa7..08269ddeeb 100644 --- a/python/tvm/script/parser/relax/parser.py +++ b/python/tvm/script/parser/relax/parser.py @@ -68,7 +68,14 @@ def bind_assign_value( "Expected the same dtype for TIR vars " f"but got {value.dtype} vs {prev_value.dtype}", ) - return prev_value + if not isinstance(value, type(prev_value)): + self.report_error( + node, + f"Expected the same IR type for TIR vars " + f"but existing value {type(value)} is mismatched " + f"to previous {type(prev_value)}", + ) + value = prev_value IRBuilder.name(var_name, value) return value @@ -144,18 +151,47 @@ def is_recursive(node: doc.FunctionDef) -> bool: return False +def collect_symbolic_var_from_prelude( + self: Parser, node: doc.FunctionDef, symbolic_vars: Dict[str, tir.Var] +) -> Dict[str, tir.Var]: + prelude_vars = {} + for stmt in node.body: + if isinstance(stmt, doc.Assign) and all( + isinstance(target, doc.Name) and target.id in symbolic_vars for target in stmt.targets + ): + values = self.eval_expr(stmt.value) + + try: + iter(values) + except TypeError: + values = [values] + + assert len(stmt.targets) == len(values) + for target, value in zip(stmt.targets, values): + name = target.id + prelude_vars[name] = value + + return {**symbolic_vars, **prelude_vars} + + def collect_symbolic_var_from_params(self: Parser, node: doc.FunctionDef) -> None: # Collect symbolic vars from parameters - symbolic_vars = set() + symbolic_vars = {} for arg in node.args.args: if arg.annotation is None: self.report_error(arg, "Type annotation is required for function parameters.") 
param_sinfo_proxy = eval_struct_info_proxy(self, arg.annotation) - symbolic_vars.update(param_sinfo_proxy.get_symbolic_vars()) + + for var_name in param_sinfo_proxy.get_symbolic_vars(): + if var_name not in symbolic_vars: + symbolic_vars[var_name] = tir.Var(var_name, "int64") + + # Update symbolic vars based on + symbolic_vars = collect_symbolic_var_from_prelude(self, node, symbolic_vars) # Define symbolic vars to the current var_table frame - for var_name in symbolic_vars: - self.var_table.add(var_name, tir.Var(var_name, "int64"), allow_shadowing=False) + for var_name, var in symbolic_vars.items(): + self.var_table.add(var_name, var, allow_shadowing=False) @dispatch.register(token="relax", type_name="FunctionDef") diff --git a/src/script/printer/relax/tir.cc b/src/script/printer/relax/tir.cc index 1a9c5d0546..6f9a8cbf89 100644 --- a/src/script/printer/relax/tir.cc +++ b/src/script/printer/relax/tir.cc @@ -18,6 +18,7 @@ */ #include <tvm/ir/expr.h> +#include "../tir/utils.h" #include "./utils.h" namespace tvm { @@ -59,7 +60,7 @@ Doc PrintTIRVar(tir::Var n, ObjectPath n_p, IRDocsifier d) { } IdDoc var = d->Define(n, GetRef<Frame>(f), n->name_hint.empty() ? 
"v" : n->name_hint); var->source_paths.push_back(n_p); - f->stmts.push_back(AssignDoc(var, TIR(d, DType2Str(n->dtype))->Call({}), NullOpt)); + f->stmts.push_back(AssignDoc(var, PrintVarCreation(n, n_p, d), NullOpt)); } if (Optional<ExprDoc> doc = d->GetVarDoc(n)) { return doc.value(); diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index ee404f08ef..f81a80de6d 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -4088,6 +4088,32 @@ def relax_match_cast_struct_info_proxy(): yield make_ir_generator(subclass) +def relax_symbolic_size_var(): + """Relax symbolic variables may be SizeVar""" + N = tvm.tir.SizeVar("N", "int64") + + @R.function + def func(A: R.Tensor([N], "float16")): + B: R.Tensor([N], "float16") = A + return B + + return func + + +def relax_float_symbolic_var(): + """Relax symbolic variables may hold any dtype""" + + @R.function + def func(A: R.Tensor(["N"], "float16"), _: R.Prim(value="threshold")): + N = T.int64() + threshold = T.float16() + + B = A >= R.prim_value(threshold / T.cast(N, "float16")) + return B + + return func + + ir_generator = tvm.testing.parameter( launch_env_thread, opt_gemm_normalize, @@ -4174,6 +4200,8 @@ ir_generator = tvm.testing.parameter( return_zero_private_with_attr, *op_of_literal(), *relax_match_cast_struct_info_proxy(), + relax_symbolic_size_var, + relax_float_symbolic_var, ) relax_ir_generator = tvm.testing.parameter(