This is an automated email from the ASF dual-hosted git repository.

yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ed9aa56b37 [Relax][Analysis] Handle recursive functions in CollectVarUsage (#17224)
ed9aa56b37 is described below

commit ed9aa56b373c60acef151d4defac44e3c2360a0a
Author: Eric Lunderberg <lunderb...@users.noreply.github.com>
AuthorDate: Thu Aug 22 11:26:27 2024 -0500

    [Relax][Analysis] Handle recursive functions in CollectVarUsage (#17224)
    
    * [Relax][Analysis] Handle recursive functions in CollectVarUsage
    
    Prior to this commit, the `relax::analysis::CollectVarUsage` utility
    treated a local function definition as in-scope after visiting the
    body of the local function.  As a result, recursive calls from a local
    function were incorrectly identified as calls to an undefined
    variable.
    
    This commit updates the `CollectVarUsage` utility to treat a local function
    definition as in-scope when inspecting the function body.  This change
    is similar to the change made for structural equality in
    https://github.com/apache/tvm/pull/16756.
    
    * lint fixes
---
 src/relax/analysis/udchain.cc                      | 21 +++++-
 .../relax/test_transform_dead_code_elimination.py  | 81 ++++++++++++++++++++++
 2 files changed, 100 insertions(+), 2 deletions(-)

diff --git a/src/relax/analysis/udchain.cc b/src/relax/analysis/udchain.cc
index d7ab4f1031..65e15a4161 100644
--- a/src/relax/analysis/udchain.cc
+++ b/src/relax/analysis/udchain.cc
@@ -55,6 +55,7 @@ class UDChain : relax::ExprVisitor {
 
  private:
   Map<Var, Expr> bound_values;
+  std::unordered_set<Var> forward_declarations;
   std::unordered_map<Var, support::OrderedSet<Var>> usage_map;
   support::OrderedSet<Var> outputs;
 
@@ -71,9 +72,20 @@ class UDChain : relax::ExprVisitor {
     cur_user_ = cache;
   }
 
+  void VisitBinding_(const VarBindingNode* binding, const FunctionNode* func) override {
+    // A local Relax function may be recursively defined.  References to
+    // `binding->var` that appear within `func` are valid.
+    DefineVar(binding->var);
+    forward_declarations.insert(binding->var);
+    ExprVisitor::VisitBinding_(binding, func);
+  }
+
   void VisitVarDef(const Var& var) override {
-    CHECK(!usage_map.count(var)) << "Variable " << var << " was used before its definition";
-    usage_map[var] = {};
+    if (forward_declarations.count(var)) {
+      forward_declarations.erase(var);
+    } else {
+      DefineVar(var);
+    }
   }
   void VisitExpr_(const VarNode* op) override {
     auto var = GetRef<Var>(op);
@@ -89,6 +101,11 @@ class UDChain : relax::ExprVisitor {
     cur_user_ = nullptr;
     ExprVisitor::VisitExpr_(op);
   }
+
+  void DefineVar(const Var& var) {
+    CHECK(!usage_map.count(var)) << "Variable " << var << " was used before its definition";
+    usage_map[var] = {};
+  }
 };
 
 std::pair<runtime::Map<Var, runtime::Array<Var>>, runtime::Array<Var>> FunctionUseDef(
diff --git a/tests/python/relax/test_transform_dead_code_elimination.py b/tests/python/relax/test_transform_dead_code_elimination.py
index 142faf5160..6546d09777 100644
--- a/tests/python/relax/test_transform_dead_code_elimination.py
+++ b/tests/python/relax/test_transform_dead_code_elimination.py
@@ -658,5 +658,86 @@ def test_well_formed_output_with_restricted_scope():
     tvm.ir.assert_structural_equal(Expected, After)
 
 
+def test_recursively_defined_lambda():
+    """DCE may be applied to recursively-defined functions
+
+    While most expressions may only contain references to
+    previously-defined variables, local Relax function definitions may
+    contain references to themselves.
+
+    This is a regression test.  In previous implementations, the
+    recursive use of `while_loop` resulted in an error, as
+    `while_loop` was not considered in-scope by the `CollectVarUsage`
+    utility until after the body of `while_loop` had been visited.
+
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor:
+            @R.function
+            def while_loop(
+                i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32")
+            ) -> R.Tensor((2, 3), "float32"):
+                cond = R.call_pure_packed(
+                    "test.vm.less", i, R.const(10), sinfo_args=R.Tensor((), dtype="bool")
+                )
+                c = R.const(1, dtype="int32")
+                if cond:
+                    new_i = R.add(i, c)
+                    new_s = R.add(s, x)
+                    r = while_loop(new_i, new_s)
+                else:
+                    r = s
+                return r
+
+            gv = while_loop(R.const(0), x)
+            return gv
+
+    Expected = Before
+
+    verify(Before, Expected)
+
+
+def test_recursively_defined_closure():
+    """DCE may be applied to recursively-defined closures
+
+    This test is identical to `test_recursively_defined_lambda`,
+    except that the threshold for recursion is defined in an enclosed
+    variable outside of the recursive function.
+
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor:
+            threshold = R.const(10)
+
+            @R.function
+            def while_loop(
+                i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32")
+            ) -> R.Tensor((2, 3), "float32"):
+                cond = R.call_pure_packed(
+                    "test.vm.less", i, threshold, sinfo_args=R.Tensor((), dtype="bool")
+                )
+                c = R.const(1, dtype="int32")
+                if cond:
+                    new_i = R.add(i, c)
+                    new_s = R.add(s, x)
+                    r = while_loop(new_i, new_s)
+                else:
+                    r = s
+                return r
+
+            gv = while_loop(R.const(0), x)
+            return gv
+
+    Expected = Before
+
+    verify(Before, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to