This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7b9d3d9dc6 [Relax] Move GetUsedVars to analysis module (#18632)
7b9d3d9dc6 is described below

commit 7b9d3d9dc6c88412c2442983d5b172c280972536
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Mon Jan 5 10:24:21 2026 +0800

    [Relax] Move GetUsedVars to analysis module (#18632)
    
    ## Why
    
    The GetUsedVars function was defined locally in binding_rewrite.cc with
    a TODO comment suggesting it should be moved to the analysis module.
    This refactoring improves code organization by placing the utility
    function alongside other variable analysis functions.
    
    ## How
    - Move the GetUsedVars implementation to the analysis module
    - Add an FFI registration (`relax.analysis.used_vars`) and a Python wrapper
    - Add a parametrized test
---
 include/tvm/relax/analysis.h          | 12 ++++++++++++
 python/tvm/relax/analysis/__init__.py |  1 +
 python/tvm/relax/analysis/analysis.py | 20 ++++++++++++++++++++
 src/relax/analysis/udchain.cc         | 21 ++++++++++++++++++++-
 src/relax/ir/binding_rewrite.cc       | 12 +-----------
 tests/python/relax/test_analysis.py   | 23 +++++++++++++++++++++++
 6 files changed, 77 insertions(+), 12 deletions(-)

diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h
index 73d1a3dbeb..b2f9463289 100644
--- a/include/tvm/relax/analysis.h
+++ b/include/tvm/relax/analysis.h
@@ -34,6 +34,7 @@
 #include <tvm/tir/index_map.h>
 
 #include <functional>
+#include <set>
 #include <utility>
 
 namespace tvm {
@@ -494,6 +495,17 @@ struct VarUsageInfo {
  */
 VarUsageInfo CollectVarUsage(const Expr& expr);
 
+/*!
+ * \brief Get the used variables in an expression.
+ *
+ * This function collects all variables that are referenced within the given expression.
+ *
+ * \param expr The expression to analyze
+ *
+ * \return A set of variable nodes that are used in the expression
+ */
+TVM_DLL std::set<const VarNode*> GetUsedVars(const Expr& expr);
+
 /*!
  * \brief Remove unused statements inside DataflowBlocks.
  *
diff --git a/python/tvm/relax/analysis/__init__.py b/python/tvm/relax/analysis/__init__.py
index 592e3bb5db..7e267b0f78 100644
--- a/python/tvm/relax/analysis/__init__.py
+++ b/python/tvm/relax/analysis/__init__.py
@@ -32,6 +32,7 @@ from .analysis import (
     free_symbolic_vars,
     free_vars,
     get_static_type,
+    used_vars,
     get_var2val,
     has_reshape_pattern,
     name_to_binding,
diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py
index af0772ea6c..8d40d3d427 100644
--- a/python/tvm/relax/analysis/analysis.py
+++ b/python/tvm/relax/analysis/analysis.py
@@ -312,6 +312,26 @@ def all_vars(expr: Expr) -> List[Var]:
     return _ffi_api.all_vars(expr)
 
 
+def used_vars(expr: Expr) -> List[Var]:
+    """
+    Return all variables used in an expression.
+
+    This function collects all variable references within the given expression,
+    which is useful for analyzing variable dependencies.
+
+    Parameters
+    ----------
+    expr: Expr
+        The expression to analyze.
+
+    Returns
+    -------
+    ret: List[Var]
+        List of variables used in the expression.
+    """
+    return _ffi_api.used_vars(expr)  # type: ignore
+
+
 def all_global_vars(expr: Expr) -> List[GlobalVar]:
     """
     Return all global variables from expression expr.
diff --git a/src/relax/analysis/udchain.cc b/src/relax/analysis/udchain.cc
index bbdbb7b644..fcd628f606 100644
--- a/src/relax/analysis/udchain.cc
+++ b/src/relax/analysis/udchain.cc
@@ -121,10 +121,29 @@ ffi::Map<Var, ffi::Array<Var>> DataflowBlockUseDef(const DataflowBlock& dfb) {
 
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("relax.analysis.udchain", DataflowBlockUseDef);
+  refl::GlobalDef()
+      .def("relax.analysis.udchain", DataflowBlockUseDef)
+      .def("relax.analysis.used_vars", [](const Expr& expr) {
+        auto used_vars = GetUsedVars(expr);
+        ffi::Array<Var> result;
+        for (const VarNode* var_node : used_vars) {
+          result.push_back(ffi::GetRef<Var>(var_node));
+        }
+        return result;
+      });
 }
 
 VarUsageInfo CollectVarUsage(const Expr& expr) { return UDChain::Collect(expr); }
 
+std::set<const VarNode*> GetUsedVars(const Expr& expr) {
+  class UsedVars : public ExprVisitor {
+   public:
+    std::set<const VarNode*> used_vars;
+    void VisitExpr_(const VarNode* op) override { used_vars.insert(op); }
+  } visitor;
+  visitor.VisitExpr(expr);
+  return std::move(visitor.used_vars);
+}
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/ir/binding_rewrite.cc b/src/relax/ir/binding_rewrite.cc
index 0bbfef31b8..a8dcf78155 100644
--- a/src/relax/ir/binding_rewrite.cc
+++ b/src/relax/ir/binding_rewrite.cc
@@ -23,6 +23,7 @@
  */
 
 #include <tvm/ffi/reflection/registry.h>
+#include <tvm/relax/analysis.h>
 #include <tvm/relax/binding_rewrite.h>
 #include <tvm/relax/block_builder.h>
 #include <tvm/relax/expr.h>
@@ -134,17 +135,6 @@ class UpdateDFB : public ExprMutator {
   }
 };
 
-// TODO(masahi): Consider moving this to analysis
-std::set<const VarNode*> GetUsedVars(Expr val) {
-  class UsedVars : public ExprVisitor {
-   public:
-    std::set<const VarNode*> used_vars;
-    void VisitExpr_(const VarNode* op) override { used_vars.insert(op); }
-  } uvar{};
-  uvar.VisitExpr(val);
-  return std::move(uvar.used_vars);
-}
-
 void DataflowBlockRewriteNode::Add(Binding binding) {
   auto [var, val] = [binding] {
     if (auto vb = binding.as<VarBindingNode>()) {
diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py
index 9f5c200cde..2845622bbe 100644
--- a/tests/python/relax/test_analysis.py
+++ b/tests/python/relax/test_analysis.py
@@ -17,6 +17,7 @@
 
 from typing import List, Set, Union
 
+import pytest
 import tvm
 import tvm.testing
 from tvm import relax as rx
@@ -26,6 +27,7 @@ from tvm.relax.analysis import (
     all_vars,
     bound_vars,
     free_vars,
+    used_vars,
     has_reshape_pattern,
     name_to_binding,
     remove_all_unused,
@@ -61,6 +63,27 @@ def test_use_def():
     assert set(udc[gv0]) == set()
 
 
+@pytest.mark.parametrize(
+    "expr_fn, expected_var_names",
+    [
+        (lambda x, y, z: rx.op.add(x, y), {"x", "y"}),
+        (lambda x, y, z: rx.op.multiply(x, x), {"x"}),
+        (lambda x, y, z: rx.Tuple([x, y, z]), {"x", "y", "z"}),
+    ],
+    ids=["binary_op", "self_reference", "tuple"],
+)
+def test_used_vars(expr_fn, expected_var_names):
+    m = tir.Var("m", "int64")
+    n = tir.Var("n", "int64")
+    x = rx.Var("x", R.Tensor([m, n], "float16"))
+    y = rx.Var("y", R.Tensor([n], "float16"))
+    z = rx.Var("z", R.Tensor([m], "float16"))
+
+    expr = expr_fn(x, y, z)
+    result = used_vars(expr)
+    assert var_name_set(result) == expected_var_names
+
+
 def test_chained_remove_all_unused():
     @tvm.script.ir_module
     class IdentityUnused:

Reply via email to