This is an automated email from the ASF dual-hosted git repository.
csullivan pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 5d0ef94ae4 [Unity][Transform] Track callees from external functions in
DeadCodeElimination (#15561)
5d0ef94ae4 is described below
commit 5d0ef94ae4fe92cdcd991381ff0b78aec3363541
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Aug 25 11:04:06 2023 -0400
[Unity][Transform] Track callees from external functions in
DeadCodeElimination (#15561)
Prior to this commit, the DeadCodeElimination pass avoided removing
externally-exposed functions, but only checked for callees from
user-specified functions. If an externally-exposed function wasn't
specified as an entry function, the dead-code elimination could remove
one of its callees, resulting in a dangling `GlobalVar`.
This commit updates the dead code elimination to treat
all externally-exposed functions as potential entry points.
User-specified functions can still be provided, and are treated as
additional entry points.
---
python/tvm/relax/transform/transform.py | 2 +-
src/relax/transform/dead_code_elimination.cc | 77 ++++++++++++++++------
.../relax/test_transform_dead_code_elimination.py | 55 +++++++++++++++-
3 files changed, 110 insertions(+), 24 deletions(-)
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index 6c08a6fe68..5f107315a5 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -1039,7 +1039,7 @@ def DeadCodeElimination(entry_functions:
Optional[List[str]] = None) -> tvm.ir.t
The registered pass.
"""
if entry_functions is None:
- entry_functions = ["main"]
+ entry_functions = []
return _ffi_api.DeadCodeElimination(entry_functions) # type: ignore
diff --git a/src/relax/transform/dead_code_elimination.cc
b/src/relax/transform/dead_code_elimination.cc
index fe36eb28ef..494665ec71 100644
--- a/src/relax/transform/dead_code_elimination.cc
+++ b/src/relax/transform/dead_code_elimination.cc
@@ -43,9 +43,9 @@ namespace relax {
/**
* \brief Detects all the functions that can be possibly called by entry
function.
*/
-class CallTracer : ExprVisitor {
+class CallTracer : public ExprVisitor {
public:
- explicit CallTracer(IRModule mod_) : mod_{mod_}, called_funcs_{},
visiting_{} {}
+ explicit CallTracer(IRModule mod) : mod_{mod}, called_funcs_{}, visiting_{}
{}
void VisitExpr_(const GlobalVarNode* op) final {
called_funcs_.insert(GetRef<GlobalVar>(op));
@@ -87,34 +87,71 @@ class CallTracer : ExprVisitor {
std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> visiting_;
};
-IRModule RemoveUnusedFunctions(IRModule mod_, Array<runtime::String>
entry_funcs) {
- auto tracer = CallTracer(mod_);
- for (auto entry : entry_funcs) {
- tracer.Trace(entry);
+IRModule RemoveUnusedFunctions(
+ IRModule mod, const std::unordered_set<GlobalVar, ObjectPtrHash,
ObjectPtrEqual>& entry_funcs) {
+ CallTracer tracer(mod);
+ for (const auto& gvar : entry_funcs) {
+ tracer.VisitExpr(gvar);
}
- auto existing_functions = mod_->functions;
- for (auto f : existing_functions) {
- // If a function has an external linkage type, we do not remove it.
- // Otherwise, we check the function and remove it if it is not used
anywhere.
- if (f.second->GetLinkageType() == LinkageType::kInternal &&
!tracer.check_if_called(f.first)) {
- mod_->Remove(f.first);
+
+ std::vector<GlobalVar> to_remove;
+ for (const auto& kv : mod->functions) {
+ // The tracer contains all user-provided entry functions, all
+ // externally-callable functions, and anything that is directly or
+ // indirectly accessible from an entry function.
+ if (!tracer.check_if_called(kv.first)) {
+ to_remove.push_back(kv.first);
+ }
+ }
+
+ if (to_remove.size()) {
+ auto write_ptr = mod.CopyOnWrite();
+ for (const auto& gvar : to_remove) {
+ write_ptr->Remove(gvar);
}
}
- return mod_;
+
+ return mod;
}
-IRModule DeadCodeElimination(const IRModule& mod, Array<runtime::String>
entry_functions) {
+IRModule DeadCodeElimination(const IRModule& arg_mod, Array<runtime::String>
entry_function_names) {
+ IRModule mod = arg_mod;
+
+ // S0: Make a list of all user-specified entry functions and
+ // externally-visible entry functions.
+ std::unordered_set<GlobalVar, ObjectPtrHash, ObjectPtrEqual> entry_functions;
+ for (const auto& name : entry_function_names) {
+ entry_functions.insert(mod->GetGlobalVar(name));
+ }
+ for (const auto& [gv, func] : mod->functions) {
+ if (func->GetLinkageType() == LinkageType::kExternal) {
+ entry_functions.insert(gv);
+ }
+ }
+
// S1: remove unused functions to reduce the number of functions to be
analyzed.
- IRModule tmp_mod = RemoveUnusedFunctions(mod, entry_functions);
+ mod = RemoveUnusedFunctions(mod, entry_functions);
+
// S2: remove unused variables in each function.
- for (const auto& gv : tmp_mod->GetGlobalVars()) {
- auto func = tmp_mod->Lookup(gv);
- if (func->IsInstance<FunctionNode>()) {
- tmp_mod->Update(gv, RemoveAllUnused(Downcast<Function>(func)));
+ {
+ IRModule updates;
+ for (const auto& [gvar, base_func] : mod->functions) {
+ if (auto opt = base_func.as<Function>()) {
+ auto new_func = RemoveAllUnused(opt.value());
+ if (!new_func.same_as(base_func)) {
+ updates->Add(gvar, new_func);
+ }
+ }
+ }
+ if (updates->functions.size()) {
+ mod.CopyOnWrite()->Update(updates);
}
}
+
// S3: remove unused functions again as some callers may be removed in S2.
- return RemoveUnusedFunctions(tmp_mod, entry_functions);
+ mod = RemoveUnusedFunctions(mod, entry_functions);
+
+ return mod;
}
namespace transform {
diff --git a/tests/python/relax/test_transform_dead_code_elimination.py
b/tests/python/relax/test_transform_dead_code_elimination.py
index 2559eed34e..7b749d6778 100644
--- a/tests/python/relax/test_transform_dead_code_elimination.py
+++ b/tests/python/relax/test_transform_dead_code_elimination.py
@@ -188,10 +188,13 @@ def test_unused_relax_func():
assert not check_if_func_exists(new_mod, "unused_func")
-def test_unused_relax_func_custom_entry_func():
+provide_entry_func_name = tvm.testing.parameter(True, False)
+
+
+def test_unused_relax_func_custom_entry_func(provide_entry_func_name):
@tvm.script.ir_module
class InputModule:
- @T.prim_func
+ @T.prim_func(private=True)
def tir_add(
x: T.Buffer((16, 16), "float32"),
y: T.Buffer((16, 16), "float32"),
@@ -217,8 +220,54 @@ def test_unused_relax_func_custom_entry_func():
mod = InputModule
assert mod
+ if provide_entry_func_name:
+ entry_functions = ["foo"]
+ else:
+ entry_functions = None
+
# Test entry function other than "main".
- new_mod = DeadCodeElimination(entry_functions=["foo"])(mod)
+ new_mod = DeadCodeElimination(entry_functions=entry_functions)(mod)
+ assert check_if_func_exists(new_mod, "foo")
+ assert check_if_func_exists(new_mod, "tir_add")
+ assert not check_if_func_exists(new_mod, "unused_func")
+
+
+def test_tracking_through_externally_exposed_func(provide_entry_func_name):
+ @tvm.script.ir_module
+ class InputModule:
+ @T.prim_func(private=True)
+ def tir_add(
+ x: T.Buffer((16, 16), "float32"),
+ y: T.Buffer((16, 16), "float32"),
+ z: T.Buffer((16, 16), "float32"),
+ ) -> None:
+ for i, j in T.grid(16, 16):
+ with T.block("add"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ z[vi, vj] = x[vi, vj] + y[vi, vj]
+
+ @R.function(private=True)
+ def unused_func(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16,
16), "float32")):
+ gv0 = R.add(x, w)
+ return gv0
+
+ @R.function
+ def foo(
+ x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")
+ ) -> R.Tensor((16, 16), "float32"):
+ gv0 = R.call_tir(InputModule.tir_add, (x, w), R.Tensor((16, 16),
dtype="float32"))
+ return gv0
+
+ @R.function
+ def main(x: R.Tensor((16, 16), "float32")) -> R.Tensor((16, 16),
"float32"):
+ return x
+
+ mod = InputModule
+ assert mod
+
+ # Test tracking of usage through externally-exposed function
+ new_mod = DeadCodeElimination(entry_functions=["main"])(mod)
+ assert check_if_func_exists(new_mod, "main")
assert check_if_func_exists(new_mod, "foo")
assert check_if_func_exists(new_mod, "tir_add")
assert not check_if_func_exists(new_mod, "unused_func")