This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7b9d3d9dc6 [Relax] Move GetUsedVars to analysis module (#18632)
7b9d3d9dc6 is described below
commit 7b9d3d9dc6c88412c2442983d5b172c280972536
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Mon Jan 5 10:24:21 2026 +0800
[Relax] Move GetUsedVars to analysis module (#18632)
## Why
The GetUsedVars function was defined locally in binding_rewrite.cc with
a TODO comment suggesting it should be moved to the analysis module.
This refactoring improves code organization by placing the utility
function alongside other variable analysis functions.
## How
- Move GetUsedVars implementation to analysis module
- Add FFI registration and Python wrapper
- Add parametrized test
---
include/tvm/relax/analysis.h | 12 ++++++++++++
python/tvm/relax/analysis/__init__.py | 1 +
python/tvm/relax/analysis/analysis.py | 20 ++++++++++++++++++++
src/relax/analysis/udchain.cc | 21 ++++++++++++++++++++-
src/relax/ir/binding_rewrite.cc | 12 +-----------
tests/python/relax/test_analysis.py | 23 +++++++++++++++++++++++
6 files changed, 77 insertions(+), 12 deletions(-)
diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h
index 73d1a3dbeb..b2f9463289 100644
--- a/include/tvm/relax/analysis.h
+++ b/include/tvm/relax/analysis.h
@@ -34,6 +34,7 @@
#include <tvm/tir/index_map.h>
#include <functional>
+#include <set>
#include <utility>
namespace tvm {
@@ -494,6 +495,17 @@ struct VarUsageInfo {
*/
VarUsageInfo CollectVarUsage(const Expr& expr);
+/*!
+ * \brief Get the used variables in an expression.
+ *
+ * This function collects all variables that are referenced within the given expression.
+ *
+ * \param expr The expression to analyze
+ *
+ * \return A set of variable nodes that are used in the expression
+ */
+TVM_DLL std::set<const VarNode*> GetUsedVars(const Expr& expr);
+
/*!
* \brief Remove unused statements inside DataflowBlocks.
*
diff --git a/python/tvm/relax/analysis/__init__.py b/python/tvm/relax/analysis/__init__.py
index 592e3bb5db..7e267b0f78 100644
--- a/python/tvm/relax/analysis/__init__.py
+++ b/python/tvm/relax/analysis/__init__.py
@@ -32,6 +32,7 @@ from .analysis import (
free_symbolic_vars,
free_vars,
get_static_type,
+ used_vars,
get_var2val,
has_reshape_pattern,
name_to_binding,
diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py
index af0772ea6c..8d40d3d427 100644
--- a/python/tvm/relax/analysis/analysis.py
+++ b/python/tvm/relax/analysis/analysis.py
@@ -312,6 +312,26 @@ def all_vars(expr: Expr) -> List[Var]:
return _ffi_api.all_vars(expr)
+def used_vars(expr: Expr) -> List[Var]:
+ """
+ Return all variables used in an expression.
+
+ This function collects all variable references within the given expression,
+ which is useful for analyzing variable dependencies.
+
+ Parameters
+ ----------
+ expr: Expr
+ The expression to analyze.
+
+ Returns
+ -------
+ ret: List[Var]
+ List of the distinct variables used in the expression, in no guaranteed order.
+ """
+ return _ffi_api.used_vars(expr) # type: ignore
+
+
def all_global_vars(expr: Expr) -> List[GlobalVar]:
"""
Return all global variables from expression expr.
diff --git a/src/relax/analysis/udchain.cc b/src/relax/analysis/udchain.cc
index bbdbb7b644..fcd628f606 100644
--- a/src/relax/analysis/udchain.cc
+++ b/src/relax/analysis/udchain.cc
@@ -121,10 +121,29 @@ ffi::Map<Var, ffi::Array<Var>> DataflowBlockUseDef(const DataflowBlock& dfb) {
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("relax.analysis.udchain", DataflowBlockUseDef);
+ refl::GlobalDef()
+ .def("relax.analysis.udchain", DataflowBlockUseDef)
+ .def("relax.analysis.used_vars", [](const Expr& expr) {
+ auto used_vars = GetUsedVars(expr);
+ ffi::Array<Var> result;
+ for (const VarNode* var_node : used_vars) {  // std::set: pointer-ordered, deduplicated
+ result.push_back(ffi::GetRef<Var>(var_node));
+ }
+ return result;
+ });
}
VarUsageInfo CollectVarUsage(const Expr& expr) { return UDChain::Collect(expr); }
+std::set<const VarNode*> GetUsedVars(const Expr& expr) {
+ class UsedVars : public ExprVisitor {
+ public:
+ std::set<const VarNode*> used_vars;  // one entry per distinct VarNode visited
+ void VisitExpr_(const VarNode* op) override { used_vars.insert(op); }
+ } visitor;
+ visitor.VisitExpr(expr);
+ return std::move(visitor.used_vars);  // a member is not implicitly moved on return
+}
+
} // namespace relax
} // namespace tvm
diff --git a/src/relax/ir/binding_rewrite.cc b/src/relax/ir/binding_rewrite.cc
index 0bbfef31b8..a8dcf78155 100644
--- a/src/relax/ir/binding_rewrite.cc
+++ b/src/relax/ir/binding_rewrite.cc
@@ -23,6 +23,7 @@
*/
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/relax/analysis.h>
#include <tvm/relax/binding_rewrite.h>
#include <tvm/relax/block_builder.h>
#include <tvm/relax/expr.h>
@@ -134,17 +135,6 @@ class UpdateDFB : public ExprMutator {
}
};
-// TODO(masahi): Consider moving this to analysis
-std::set<const VarNode*> GetUsedVars(Expr val) {
- class UsedVars : public ExprVisitor {
- public:
- std::set<const VarNode*> used_vars;
- void VisitExpr_(const VarNode* op) override { used_vars.insert(op); }
- } uvar{};
- uvar.VisitExpr(val);
- return std::move(uvar.used_vars);
-}
-
void DataflowBlockRewriteNode::Add(Binding binding) {
auto [var, val] = [binding] {
if (auto vb = binding.as<VarBindingNode>()) {
diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py
index 9f5c200cde..2845622bbe 100644
--- a/tests/python/relax/test_analysis.py
+++ b/tests/python/relax/test_analysis.py
@@ -17,6 +17,7 @@
from typing import List, Set, Union
+import pytest
import tvm
import tvm.testing
from tvm import relax as rx
@@ -26,6 +27,7 @@ from tvm.relax.analysis import (
all_vars,
bound_vars,
free_vars,
+ used_vars,
has_reshape_pattern,
name_to_binding,
remove_all_unused,
@@ -61,6 +63,27 @@ def test_use_def():
assert set(udc[gv0]) == set()
+@pytest.mark.parametrize(
+    "expr_fn, expected_var_names",
+    [
+        (lambda x, y, z: rx.op.add(x, y), {"x", "y"}),
+        (lambda x, y, z: rx.op.multiply(x, x), {"x"}),
+        (lambda x, y, z: rx.Tuple([x, y, z]), {"x", "y", "z"}),
+    ],
+    ids=["binary_op", "self_reference", "tuple"],
+)
+def test_used_vars(expr_fn, expected_var_names):
+    m = tir.Var("m", "int64")
+    n = tir.Var("n", "int64")
+    x = rx.Var("x", R.Tensor([m, n], "float16"))
+    y = rx.Var("y", R.Tensor([n], "float16"))
+    z = rx.Var("z", R.Tensor([m], "float16"))
+    # Build the expression under test, then compare the names of the vars it references.
+    expr = expr_fn(x, y, z)
+    result = used_vars(expr)
+    assert var_name_set(result) == expected_var_names
+
+
def test_chained_remove_all_unused():
@tvm.script.ir_module
class IdentityUnused: