This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 729108cfc4 [REFACTOR][RELAX] Fold CalleeCollector into relax
DeadCodeElimination (#19603)
729108cfc4 is described below
commit 729108cfc49f18ceb0c009ec037a85980985d751
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon May 25 17:45:51 2026 -0400
[REFACTOR][RELAX] Fold CalleeCollector into relax DeadCodeElimination
(#19603)
## Summary
The cross-IR `CalleeCollector` abstraction in `include/tvm/ir/analysis.h`
had a single consumer (relax `DeadCodeElimination`), yet it forced its
per-language visitors to live in separate `analysis/` files registered
via a runtime vtable. This PR folds both visitors (relax + tirx)
directly into `src/relax/transform/dead_code_elimination.cc` as
anonymous-namespace helpers and deletes the now-dead abstraction,
including the `ir.analysis.CollectCallMap` FFI function, its Python
wrapper `tvm.ir.analysis.collect_call_map`, and the associated test.
The indirection only pays off when multiple unrelated passes share the
visitor; with a single consumer, the cross-TU vtable adds compile cost and
spreads the implementation across a header and three source files. Inlining
the visitors improves locality without making the consumer more complex.
---
include/tvm/ir/analysis.h | 63 ---------------
python/tvm/ir/__init__.py | 1 -
python/tvm/ir/_ffi_analysis_api.py | 21 -----
python/tvm/ir/analysis.py | 43 ----------
src/ir/analysis.cc | 53 -------------
src/relax/analysis/collect_call_map.cc | 60 --------------
src/relax/transform/dead_code_elimination.cc | 62 ++++++++++++++-
src/relax/transform/replace_global_vars.cc | 1 -
src/tirx/analysis/collect_call_map.cc | 57 --------------
tests/python/ir/analysis/test_collect_call_map.py | 96 -----------------------
10 files changed, 60 insertions(+), 397 deletions(-)
diff --git a/include/tvm/ir/analysis.h b/include/tvm/ir/analysis.h
deleted file mode 100644
index 3b6d4e5501..0000000000
--- a/include/tvm/ir/analysis.h
+++ /dev/null
@@ -1,63 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tvm/ir/analysis.h
- *
- * Analysis routines that must function across multiple IR types for
- * correctness. For example, identifying unused functions, when both TIR
- *
- */
-#ifndef TVM_IR_ANALYSIS_H_
-#define TVM_IR_ANALYSIS_H_
-
-#include <tvm/ffi/container/array.h>
-#include <tvm/ir/expr.h>
-#include <tvm/ir/module.h>
-#include <tvm/ir/node_functor.h>
-
-namespace tvm {
-namespace ir {
-
-class CalleeCollector {
- public:
- /* \brief Functor to be registered for IR types
- *
- * Should be implemented for each `BaseFunc` subclass.
- * Implementation should call `CalleeCollector::Mark` for each
- * `GlobalVar` in the function.
- */
- using FType = NodeFunctor<void(const ffi::ObjectRef&, CalleeCollector*)>;
- TVM_DLL static FType& vtable() {
- static FType inst;
- return inst;
- }
-
- virtual ~CalleeCollector() {}
-
- /* \brief Collect the GlobalVar in a function */
- virtual void Mark(GlobalVar gvar) = 0;
-};
-
-ffi::Map<GlobalVar, ffi::Array<GlobalVar>> CollectCallMap(const IRModule& mod);
-
-} // namespace ir
-} // namespace tvm
-
-#endif // TVM_IR_ANALYSIS_H_
diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py
index f721080a93..50073a942a 100644
--- a/python/tvm/ir/__init__.py
+++ b/python/tvm/ir/__init__.py
@@ -39,5 +39,4 @@ from .module import IRModule
from .op import Op, register_intrin_lowering, register_op_attr
from .type import FuncType, PointerType, PrimType, TupleType, Type
-from . import analysis
from tvm_ffi import Array, Map
diff --git a/python/tvm/ir/_ffi_analysis_api.py
b/python/tvm/ir/_ffi_analysis_api.py
deleted file mode 100644
index 6fe16a4e15..0000000000
--- a/python/tvm/ir/_ffi_analysis_api.py
+++ /dev/null
@@ -1,21 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""FFI APIs for tvm.ir.analysis"""
-
-import tvm_ffi
-
-tvm_ffi.init_ffi_api("ir.analysis", __name__)
diff --git a/python/tvm/ir/analysis.py b/python/tvm/ir/analysis.py
deleted file mode 100644
index 2baf41f8e0..0000000000
--- a/python/tvm/ir/analysis.py
+++ /dev/null
@@ -1,43 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# pylint: disable=unused-import
-
-"""Common analysis across all IR variants."""
-
-import tvm
-
-from . import _ffi_analysis_api as _ffi
-
-
-def collect_call_map(
- module: "tvm.ir.IRModule",
-) -> dict["tvm.ir.GlobalVar", list["tvm.ir.GlobalVar"]]:
- """Collect the call map of a module
-
- Parameters
- ----------
- module: tvm.ir.IRModule
- The module to inspect
-
- Returns
- -------
- call_map: Dict[tvm.ir.GlobalVar, List[tvm.ir.GlobalVar]]
- A map from functions to the subroutines they call.
-
- """
- return _ffi.CollectCallMap(module)
diff --git a/src/ir/analysis.cc b/src/ir/analysis.cc
deleted file mode 100644
index da35d87b25..0000000000
--- a/src/ir/analysis.cc
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file src/ir/analysis.cc
- * \brief Analysis functions that must span multiple IR types
- */
-#include <tvm/ffi/reflection/registry.h>
-#include <tvm/ir/analysis.h>
-
-#include "../support/ordered_set.h"
-
-namespace tvm {
-namespace ir {
-
-ffi::Map<GlobalVar, ffi::Array<GlobalVar>> CollectCallMap(const IRModule& mod)
{
- struct CalleeCollectorImpl : CalleeCollector {
- void Mark(GlobalVar gvar) override { gvars.push_back(gvar); }
- support::OrderedSet<GlobalVar, ffi::ObjectPtrHash, ffi::ObjectPtrEqual>
gvars;
- };
-
- ffi::Map<GlobalVar, ffi::Array<GlobalVar>> call_map;
- for (const auto& [gvar, base_func] : mod->functions) {
- CalleeCollectorImpl collector;
- CalleeCollector::vtable()(base_func, &collector);
- call_map.Set(gvar, ffi::Array<GlobalVar>{collector.gvars.begin(),
collector.gvars.end()});
- }
- return call_map;
-}
-
-TVM_FFI_STATIC_INIT_BLOCK() {
- namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("ir.analysis.CollectCallMap", CollectCallMap);
-}
-
-} // namespace ir
-} // namespace tvm
diff --git a/src/relax/analysis/collect_call_map.cc
b/src/relax/analysis/collect_call_map.cc
deleted file mode 100644
index 0e72e4bca8..0000000000
--- a/src/relax/analysis/collect_call_map.cc
+++ /dev/null
@@ -1,60 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- *
- * \file src/relax/analysis/collect_call_map.cc
- *
- * \brief Collect cross-IR call graph
- */
-
-#include <tvm/ffi/cast.h>
-#include <tvm/ir/analysis.h>
-#include <tvm/relax/analysis.h>
-#include <tvm/relax/expr_functor.h>
-#include <tvm/tirx/expr_functor.h>
-
-namespace tvm {
-namespace relax {
-
-namespace {
-using ir::CalleeCollector;
-
-struct Visitor : ExprVisitor {
- explicit Visitor(CalleeCollector* collector) : collector(collector) {}
- CalleeCollector* collector;
- void VisitExpr_(const GlobalVarNode* node) override {
- collector->Mark(ffi::GetRef<GlobalVar>(node));
- }
-};
-
-} // namespace
-
-TVM_STATIC_IR_FUNCTOR(CalleeCollector, vtable)
- .set_dispatch<relax::FunctionNode>([](const ffi::ObjectRef& func,
CalleeCollector* collector) {
- Visitor visitor{collector};
- visitor(Downcast<Function>(func));
- });
-
-TVM_STATIC_IR_FUNCTOR(CalleeCollector, vtable)
- .set_dispatch<relax::ExternFuncNode>([](const ffi::ObjectRef& func,
- CalleeCollector* collector) {});
-
-} // namespace relax
-} // namespace tvm
diff --git a/src/relax/transform/dead_code_elimination.cc
b/src/relax/transform/dead_code_elimination.cc
index fbb077ddf9..c9869e1509 100644
--- a/src/relax/transform/dead_code_elimination.cc
+++ b/src/relax/transform/dead_code_elimination.cc
@@ -33,19 +33,77 @@
*/
#include <tvm/ffi/reflection/registry.h>
-#include <tvm/ir/analysis.h>
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>
+#include <tvm/tirx/expr_functor.h>
+#include <tvm/tirx/function.h>
+#include <tvm/tirx/stmt_functor.h>
+
+#include <unordered_set>
#include "utils.h"
namespace tvm {
namespace relax {
+namespace {
+
+struct RelaxCalleeCollector : relax::ExprVisitor {
+ std::vector<GlobalVar>* callees;
+ explicit RelaxCalleeCollector(std::vector<GlobalVar>* out) : callees(out) {}
+ void VisitExpr_(const GlobalVarNode* node) final {
+ callees->push_back(ffi::GetRef<GlobalVar>(node));
+ }
+};
+
+struct TIRxCalleeCollector : tirx::StmtExprVisitor {
+ std::vector<GlobalVar>* callees;
+ explicit TIRxCalleeCollector(std::vector<GlobalVar>* out) : callees(out) {}
+ void VisitExpr_(const tirx::CallNode* node) final {
+ tirx::StmtExprVisitor::VisitExpr_(node);
+ if (auto opt_gvar = node->op.as<GlobalVar>()) {
+ callees->push_back(opt_gvar.value());
+ }
+ }
+};
+
+// Collect the GlobalVars directly called by `func`. Dedups while
+// preserving first-encounter order (same semantics the old
+// support::OrderedSet path provided).
+ffi::Array<GlobalVar> CollectCallees(const BaseFunc& func) {
+ std::vector<GlobalVar> raw;
+ if (auto opt = func.as<relax::Function>()) {
+ RelaxCalleeCollector visitor(&raw);
+ visitor(opt.value());
+ } else if (func.as<relax::ExternFunc>()) {
+ // no callees
+ } else if (auto opt = func.as<tirx::PrimFunc>()) {
+ TIRxCalleeCollector visitor(&raw);
+ visitor(opt.value()->body);
+ }
+ // dedup preserving order
+ ffi::Array<GlobalVar> result;
+ std::unordered_set<GlobalVar, ffi::ObjectPtrHash, ffi::ObjectPtrEqual> seen;
+ for (const auto& gv : raw) {
+ if (seen.insert(gv).second) result.push_back(gv);
+ }
+ return result;
+}
+
+ffi::Map<GlobalVar, ffi::Array<GlobalVar>> CollectCallMap(const IRModule& mod)
{
+ ffi::Map<GlobalVar, ffi::Array<GlobalVar>> call_map;
+ for (const auto& [gvar, base_func] : mod->functions) {
+ call_map.Set(gvar, CollectCallees(base_func));
+ }
+ return call_map;
+}
+
+} // namespace
+
IRModule RemoveUnusedFunctions(IRModule mod, const
std::unordered_set<GlobalVar>& entry_funcs) {
- auto call_map = ir::CollectCallMap(mod);
+ auto call_map = CollectCallMap(mod);
std::unordered_set<GlobalVar> reachable = entry_funcs;
std::vector<GlobalVar> to_visit(entry_funcs.begin(), entry_funcs.end());
diff --git a/src/relax/transform/replace_global_vars.cc
b/src/relax/transform/replace_global_vars.cc
index 6291663496..f895cd50eb 100644
--- a/src/relax/transform/replace_global_vars.cc
+++ b/src/relax/transform/replace_global_vars.cc
@@ -25,7 +25,6 @@
*/
#include <tvm/ffi/cast.h>
-#include <tvm/ir/analysis.h>
#include <tvm/ir/replace_global_vars.h>
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr_functor.h>
diff --git a/src/tirx/analysis/collect_call_map.cc
b/src/tirx/analysis/collect_call_map.cc
deleted file mode 100644
index 210bc5aa92..0000000000
--- a/src/tirx/analysis/collect_call_map.cc
+++ /dev/null
@@ -1,57 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- *
- * \file src/tirx/analysis/collect_call_map.cc
- *
- * \brief Collect cross-IR call graph
- */
-
-#include <tvm/ir/analysis.h>
-#include <tvm/tirx/function.h>
-#include <tvm/tirx/stmt_functor.h>
-
-namespace tvm {
-namespace tirx {
-
-namespace {
-using ir::CalleeCollector;
-
-struct Visitor : StmtExprVisitor {
- explicit Visitor(CalleeCollector* collector) : collector(collector) {}
- CalleeCollector* collector;
- void VisitExpr_(const CallNode* node) override {
- StmtExprVisitor::VisitExpr_(node);
- if (auto opt_gvar = node->op.as<GlobalVar>()) {
- collector->Mark(opt_gvar.value());
- }
- }
-};
-
-} // namespace
-
-TVM_STATIC_IR_FUNCTOR(CalleeCollector, vtable)
- .set_dispatch<tirx::PrimFuncNode>([](const ffi::ObjectRef& func,
CalleeCollector* collector) {
- Visitor visitor{collector};
- visitor(Downcast<tirx::PrimFunc>(func)->body);
- });
-
-} // namespace tirx
-} // namespace tvm
diff --git a/tests/python/ir/analysis/test_collect_call_map.py
b/tests/python/ir/analysis/test_collect_call_map.py
deleted file mode 100644
index 215842bbf9..0000000000
--- a/tests/python/ir/analysis/test_collect_call_map.py
+++ /dev/null
@@ -1,96 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-
-import tvm
-import tvm.testing
-from tvm.ir import GlobalVar
-from tvm.ir.analysis import collect_call_map
-from tvm.script import ir as I
-from tvm.script import relax as R
-from tvm.script import tirx as T
-
-
-def _build_str_map(call_map: dict[GlobalVar, list[GlobalVar]]) -> dict[str,
list[str]]:
- return {
- caller.name_hint: [callee.name_hint for callee in callees]
- for caller, callees in call_map.items()
- }
-
-
-def test_collect_relax_to_relax():
- @I.ir_module
- class Module:
- @R.function
- def main():
- return Module.subroutine()
-
- @R.function
- def subroutine():
- return R.tuple()
-
- call_map = collect_call_map(Module)
- str_map = _build_str_map(call_map)
- expected = {
- "main": ["subroutine"],
- "subroutine": [],
- }
- assert str_map == expected
-
-
-def test_collect_relax_to_tir():
- @I.ir_module
- class Module:
- @R.function
- def main() -> R.Prim("int32"):
- return Module.subroutine(R.prim_value(T.int32(42)))
-
- @T.prim_func(s_tir=True)
- def subroutine(i: T.int32) -> T.int32:
- return i + 1
-
- call_map = collect_call_map(Module)
- str_map = _build_str_map(call_map)
- expected = {
- "main": ["subroutine"],
- "subroutine": [],
- }
- assert str_map == expected
-
-
-def test_collect_tir_to_tir():
- @I.ir_module
- class Module:
- @T.prim_func(s_tir=True)
- def main() -> T.int32:
- return Module.subroutine(42)
-
- @T.prim_func(s_tir=True)
- def subroutine(i: T.int32) -> T.int32:
- return i + 1
-
- call_map = collect_call_map(Module)
- str_map = _build_str_map(call_map)
- expected = {
- "main": ["subroutine"],
- "subroutine": [],
- }
- assert str_map == expected
-
-
-if __name__ == "__main__":
- tvm.testing.main()