This is an automated email from the ASF dual-hosted git repository.

csullivan pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 5d0ef94ae4 [Unity][Transform] Track callees from external functions in 
DeadCodeElimination (#15561)
5d0ef94ae4 is described below

commit 5d0ef94ae4fe92cdcd991381ff0b78aec3363541
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Aug 25 11:04:06 2023 -0400

    [Unity][Transform] Track callees from external functions in 
DeadCodeElimination (#15561)
    
    Prior to this commit, the DeadCodeElimination pass avoided removing
    externally-exposed functions, but only checked for callees of
    user-specified entry functions.  If an externally-exposed function
    wasn't specified as an entry function, dead-code elimination could
    remove one of its callees, resulting in a dangling `GlobalVar`.
    
    This commit updates dead-code elimination to treat all
    externally-exposed functions as potential entry points.
    User-specified functions can still be provided, and are treated as
    additional entry points.
---
 python/tvm/relax/transform/transform.py            |  2 +-
 src/relax/transform/dead_code_elimination.cc       | 77 ++++++++++++++++------
 .../relax/test_transform_dead_code_elimination.py  | 55 +++++++++++++++-
 3 files changed, 110 insertions(+), 24 deletions(-)

diff --git a/python/tvm/relax/transform/transform.py 
b/python/tvm/relax/transform/transform.py
index 6c08a6fe68..5f107315a5 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -1039,7 +1039,7 @@ def DeadCodeElimination(entry_functions: 
Optional[List[str]] = None) -> tvm.ir.t
         The registered pass.
     """
     if entry_functions is None:
-        entry_functions = ["main"]
+        entry_functions = []
     return _ffi_api.DeadCodeElimination(entry_functions)  # type: ignore
 
 
diff --git a/src/relax/transform/dead_code_elimination.cc 
b/src/relax/transform/dead_code_elimination.cc
index fe36eb28ef..494665ec71 100644
--- a/src/relax/transform/dead_code_elimination.cc
+++ b/src/relax/transform/dead_code_elimination.cc
@@ -43,9 +43,9 @@ namespace relax {
 /**
  * \brief Detects all the functions that can be possibly called by entry 
function.
  */
-class CallTracer : ExprVisitor {
+class CallTracer : public ExprVisitor {
  public:
-  explicit CallTracer(IRModule mod_) : mod_{mod_}, called_funcs_{}, 
visiting_{} {}
+  explicit CallTracer(IRModule mod) : mod_{mod}, called_funcs_{}, visiting_{} 
{}
 
   void VisitExpr_(const GlobalVarNode* op) final {
     called_funcs_.insert(GetRef<GlobalVar>(op));
@@ -87,34 +87,71 @@ class CallTracer : ExprVisitor {
   std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> visiting_;
 };
 
-IRModule RemoveUnusedFunctions(IRModule mod_, Array<runtime::String> 
entry_funcs) {
-  auto tracer = CallTracer(mod_);
-  for (auto entry : entry_funcs) {
-    tracer.Trace(entry);
+IRModule RemoveUnusedFunctions(
+    IRModule mod, const std::unordered_set<GlobalVar, ObjectPtrHash, 
ObjectPtrEqual>& entry_funcs) {
+  CallTracer tracer(mod);
+  for (const auto& gvar : entry_funcs) {
+    tracer.VisitExpr(gvar);
   }
-  auto existing_functions = mod_->functions;
-  for (auto f : existing_functions) {
-    // If a function has an external linkage type, we do not remove it.
-    // Otherwise, we check the function and remove it if it is not used 
anywhere.
-    if (f.second->GetLinkageType() == LinkageType::kInternal && 
!tracer.check_if_called(f.first)) {
-      mod_->Remove(f.first);
+
+  std::vector<GlobalVar> to_remove;
+  for (const auto& kv : mod->functions) {
+    // The tracer contains all user-provided entry functions, all
+    // externally-callable functions, and anything that is directly or
+    // indirectly accessible from an entry function.
+    if (!tracer.check_if_called(kv.first)) {
+      to_remove.push_back(kv.first);
+    }
+  }
+
+  if (to_remove.size()) {
+    auto write_ptr = mod.CopyOnWrite();
+    for (const auto& gvar : to_remove) {
+      write_ptr->Remove(gvar);
     }
   }
-  return mod_;
+
+  return mod;
 }
 
-IRModule DeadCodeElimination(const IRModule& mod, Array<runtime::String> 
entry_functions) {
+IRModule DeadCodeElimination(const IRModule& arg_mod, Array<runtime::String> 
entry_function_names) {
+  IRModule mod = arg_mod;
+
+  // S0: Make a list of all user-specified entry functions and
+  // externally-visible entry functions.
+  std::unordered_set<GlobalVar, ObjectPtrHash, ObjectPtrEqual> entry_functions;
+  for (const auto& name : entry_function_names) {
+    entry_functions.insert(mod->GetGlobalVar(name));
+  }
+  for (const auto& [gv, func] : mod->functions) {
+    if (func->GetLinkageType() == LinkageType::kExternal) {
+      entry_functions.insert(gv);
+    }
+  }
+
   // S1: remove unused functions to reduce the number of functions to be 
analyzed.
-  IRModule tmp_mod = RemoveUnusedFunctions(mod, entry_functions);
+  mod = RemoveUnusedFunctions(mod, entry_functions);
+
   // S2: remove unused variables in each function.
-  for (const auto& gv : tmp_mod->GetGlobalVars()) {
-    auto func = tmp_mod->Lookup(gv);
-    if (func->IsInstance<FunctionNode>()) {
-      tmp_mod->Update(gv, RemoveAllUnused(Downcast<Function>(func)));
+  {
+    IRModule updates;
+    for (const auto& [gvar, base_func] : mod->functions) {
+      if (auto opt = base_func.as<Function>()) {
+        auto new_func = RemoveAllUnused(opt.value());
+        if (!new_func.same_as(base_func)) {
+          updates->Add(gvar, new_func);
+        }
+      }
+    }
+    if (updates->functions.size()) {
+      mod.CopyOnWrite()->Update(updates);
     }
   }
+
   // S3: remove unused functions again as some callers may be removed in S2.
-  return RemoveUnusedFunctions(tmp_mod, entry_functions);
+  mod = RemoveUnusedFunctions(mod, entry_functions);
+
+  return mod;
 }
 
 namespace transform {
diff --git a/tests/python/relax/test_transform_dead_code_elimination.py 
b/tests/python/relax/test_transform_dead_code_elimination.py
index 2559eed34e..7b749d6778 100644
--- a/tests/python/relax/test_transform_dead_code_elimination.py
+++ b/tests/python/relax/test_transform_dead_code_elimination.py
@@ -188,10 +188,13 @@ def test_unused_relax_func():
     assert not check_if_func_exists(new_mod, "unused_func")
 
 
-def test_unused_relax_func_custom_entry_func():
+provide_entry_func_name = tvm.testing.parameter(True, False)
+
+
+def test_unused_relax_func_custom_entry_func(provide_entry_func_name):
     @tvm.script.ir_module
     class InputModule:
-        @T.prim_func
+        @T.prim_func(private=True)
         def tir_add(
             x: T.Buffer((16, 16), "float32"),
             y: T.Buffer((16, 16), "float32"),
@@ -217,8 +220,54 @@ def test_unused_relax_func_custom_entry_func():
     mod = InputModule
     assert mod
 
+    if provide_entry_func_name:
+        entry_functions = ["foo"]
+    else:
+        entry_functions = None
+
     # Test entry function other than "main".
-    new_mod = DeadCodeElimination(entry_functions=["foo"])(mod)
+    new_mod = DeadCodeElimination(entry_functions=entry_functions)(mod)
+    assert check_if_func_exists(new_mod, "foo")
+    assert check_if_func_exists(new_mod, "tir_add")
+    assert not check_if_func_exists(new_mod, "unused_func")
+
+
+def test_tracking_through_externally_exposed_func(provide_entry_func_name):
+    @tvm.script.ir_module
+    class InputModule:
+        @T.prim_func(private=True)
+        def tir_add(
+            x: T.Buffer((16, 16), "float32"),
+            y: T.Buffer((16, 16), "float32"),
+            z: T.Buffer((16, 16), "float32"),
+        ) -> None:
+            for i, j in T.grid(16, 16):
+                with T.block("add"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    z[vi, vj] = x[vi, vj] + y[vi, vj]
+
+        @R.function(private=True)
+        def unused_func(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 
16), "float32")):
+            gv0 = R.add(x, w)
+            return gv0
+
+        @R.function
+        def foo(
+            x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")
+        ) -> R.Tensor((16, 16), "float32"):
+            gv0 = R.call_tir(InputModule.tir_add, (x, w), R.Tensor((16, 16), 
dtype="float32"))
+            return gv0
+
+        @R.function
+        def main(x: R.Tensor((16, 16), "float32")) -> R.Tensor((16, 16), 
"float32"):
+            return x
+
+    mod = InputModule
+    assert mod
+
+    # Test tracking of usage through externally-exposed function
+    new_mod = DeadCodeElimination(entry_functions=["main"])(mod)
+    assert check_if_func_exists(new_mod, "main")
     assert check_if_func_exists(new_mod, "foo")
     assert check_if_func_exists(new_mod, "tir_add")
     assert not check_if_func_exists(new_mod, "unused_func")

Reply via email to