This is an automated email from the ASF dual-hosted git repository. yongwww pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new ed9aa56b37 [Relax][Analysis] Handle recursive functions in CollectVarUsage (#17224) ed9aa56b37 is described below commit ed9aa56b373c60acef151d4defac44e3c2360a0a Author: Eric Lunderberg <lunderb...@users.noreply.github.com> AuthorDate: Thu Aug 22 11:26:27 2024 -0500 [Relax][Analysis] Handle recursive functions in CollectVarUsage (#17224) * [Relax][Analysis] Handle recursive functions in CollectVarUsage Prior to this commit, the `relax::analysis::CollectVarUsage` utility treated a local function definition as in-scope after visiting the body of the local function. As a result, recursive calls from a local function were incorrectly identified as calls to an undefined variable. This commit updates the `CollectVarUsage` to treat a local function definition as in-scope when inspecting the function body. This change is similar to the change made for structural equality in https://github.com/apache/tvm/pull/16756. * lint fixes --- src/relax/analysis/udchain.cc | 21 +++++- .../relax/test_transform_dead_code_elimination.py | 81 ++++++++++++++++++++++ 2 files changed, 100 insertions(+), 2 deletions(-) diff --git a/src/relax/analysis/udchain.cc b/src/relax/analysis/udchain.cc index d7ab4f1031..65e15a4161 100644 --- a/src/relax/analysis/udchain.cc +++ b/src/relax/analysis/udchain.cc @@ -55,6 +55,7 @@ class UDChain : relax::ExprVisitor { private: Map<Var, Expr> bound_values; + std::unordered_set<Var> forward_declarations; std::unordered_map<Var, support::OrderedSet<Var>> usage_map; support::OrderedSet<Var> outputs; @@ -71,9 +72,20 @@ class UDChain : relax::ExprVisitor { cur_user_ = cache; } + void VisitBinding_(const VarBindingNode* binding, const FunctionNode* func) override { + // A local Relax function may be recursively defined. References to + // `binding->var` that appear within `func` are valid. + DefineVar(binding->var); + forward_declarations.insert(binding->var); + ExprVisitor::VisitBinding_(binding, func); + } + void VisitVarDef(const Var& var) override { - CHECK(!usage_map.count(var)) << "Variable " << var << " was used before its definition"; - usage_map[var] = {}; + if (forward_declarations.count(var)) { + forward_declarations.erase(var); + } else { + DefineVar(var); + } } void VisitExpr_(const VarNode* op) override { auto var = GetRef<Var>(op); @@ -89,6 +101,11 @@ class UDChain : relax::ExprVisitor { cur_user_ = nullptr; ExprVisitor::VisitExpr_(op); } + + void DefineVar(const Var& var) { + CHECK(!usage_map.count(var)) << "Variable " << var << " was used before its definition"; + usage_map[var] = {}; + } }; std::pair<runtime::Map<Var, runtime::Array<Var>>, runtime::Array<Var>> FunctionUseDef( diff --git a/tests/python/relax/test_transform_dead_code_elimination.py b/tests/python/relax/test_transform_dead_code_elimination.py index 142faf5160..6546d09777 100644 --- a/tests/python/relax/test_transform_dead_code_elimination.py +++ b/tests/python/relax/test_transform_dead_code_elimination.py @@ -658,5 +658,86 @@ def test_well_formed_output_with_restricted_scope(): tvm.ir.assert_structural_equal(Expected, After) +def test_recursively_defined_lambda(): + """DCE may be applied to recursively-defined functions + + While most expressions may only contain references to + previously-defined variables, local Relax function definitions may + contain references to themselves. + + This is a regression test. In previous implementations, the + recursive use of `while_loop` resulted in an error, as + `while_loop` was not considered in-scope by the `CollectVarUsage` + utility until after the body of `while_loop` had been visited. + + """ + + @I.ir_module + class Before: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor: + @R.function + def while_loop( + i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32") + ) -> R.Tensor((2, 3), "float32"): + cond = R.call_pure_packed( + "test.vm.less", i, R.const(10), sinfo_args=R.Tensor((), dtype="bool") + ) + c = R.const(1, dtype="int32") + if cond: + new_i = R.add(i, c) + new_s = R.add(s, x) + r = while_loop(new_i, new_s) + else: + r = s + return r + + gv = while_loop(R.const(0), x) + return gv + + Expected = Before + + verify(Before, Expected) + + +def test_recursively_defined_closure(): + """DCE may be applied to recursively-defined closures + + This test is identical to `test_recursively_defined_lambda`, + except that the threshold for recursion is defined in an enclosed + variable outside of the recursive function. + + """ + + @I.ir_module + class Before: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor: + threshold = R.const(10) + + @R.function + def while_loop( + i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32") + ) -> R.Tensor((2, 3), "float32"): + cond = R.call_pure_packed( + "test.vm.less", i, threshold, sinfo_args=R.Tensor((), dtype="bool") + ) + c = R.const(1, dtype="int32") + if cond: + new_i = R.add(i, c) + new_s = R.add(s, x) + r = while_loop(new_i, new_s) + else: + r = s + return r + + gv = while_loop(R.const(0), x) + return gv + + Expected = Before + + verify(Before, Expected) + + if __name__ == "__main__": tvm.testing.main()