This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 148737b1ff [IR] Compact Functor vtable (#17731)
148737b1ff is described below

commit 148737b1ffe194645cb25fb810296c2edc8ef345
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Mar 11 10:40:29 2025 -0400

    [IR] Compact Functor vtable (#17731)
    
    This PR adds a Finalize routine that optionally compacts the functor vtable
    dynamically. It also updates child_slots for key types to make sure the IR
    node type indices stay within range so that the compaction takes effect.
---
 include/tvm/arith/iter_affine_map.h          |  2 +-
 include/tvm/ir/expr.h                        |  4 ++--
 include/tvm/ir/type_functor.h                |  1 +
 include/tvm/node/functor.h                   | 28 +++++++++++++++++++++++++++-
 include/tvm/relax/dataflow_pattern.h         |  2 ++
 include/tvm/relax/dataflow_pattern_functor.h |  2 +-
 include/tvm/relax/expr.h                     |  4 ++--
 include/tvm/relax/expr_functor.h             |  1 +
 include/tvm/relax/struct_info_functor.h      |  1 +
 include/tvm/tir/expr_functor.h               |  1 +
 include/tvm/tir/stmt_functor.h               |  1 +
 src/ir/attr_functor.h                        |  1 +
 src/relax/ir/py_expr_functor.cc              |  3 +++
 src/runtime/object.cc                        | 10 +++++++++-
 14 files changed, 53 insertions(+), 8 deletions(-)

diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h
index 53c5b32dd2..d2a6f9a745 100644
--- a/include/tvm/arith/iter_affine_map.h
+++ b/include/tvm/arith/iter_affine_map.h
@@ -69,7 +69,7 @@ class IterMapExprNode : public PrimExprNode {
   void VisitAttrs(tvm::AttrVisitor* v) {}
 
   static constexpr const char* _type_key = "arith.IterMapExpr";
-  static constexpr const uint32_t _type_child_slots = 3;
+  static constexpr const uint32_t _type_child_slots = 2;
   TVM_DECLARE_BASE_OBJECT_INFO(IterMapExprNode, PrimExprNode);
 };
 
diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index b3b4e8ab32..53af269756 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -58,7 +58,7 @@ class BaseExprNode : public Object {
   static constexpr const char* _type_key = "BaseExpr";
   static constexpr const bool _type_has_method_sequal_reduce = true;
   static constexpr const bool _type_has_method_shash_reduce = true;
-  static constexpr const uint32_t _type_child_slots = 62;
+  static constexpr const uint32_t _type_child_slots = 64;
   TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object);
 };
 
@@ -104,7 +104,7 @@ class PrimExprNode : public BaseExprNode {
   TVM_OBJECT_ENABLE_SCRIPT_PRINTER();
 
   static constexpr const char* _type_key = "PrimExpr";
-  static constexpr const uint32_t _type_child_slots = 38;
+  static constexpr const uint32_t _type_child_slots = 40;
   TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, BaseExprNode);
 };
 
diff --git a/include/tvm/ir/type_functor.h b/include/tvm/ir/type_functor.h
index 2c145e480b..858226354c 100644
--- a/include/tvm/ir/type_functor.h
+++ b/include/tvm/ir/type_functor.h
@@ -93,6 +93,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
     TVM_TYPE_FUNCTOR_DISPATCH(TupleTypeNode);
     TVM_TYPE_FUNCTOR_DISPATCH(PrimTypeNode);
     TVM_TYPE_FUNCTOR_DISPATCH(PointerTypeNode);
+    vtable.Finalize();
     return vtable;
   }
 };
diff --git a/include/tvm/node/functor.h b/include/tvm/node/functor.h
index 58d59c81cb..82ea37566e 100644
--- a/include/tvm/node/functor.h
+++ b/include/tvm/node/functor.h
@@ -26,6 +26,7 @@
 #include <tvm/runtime/logging.h>
 #include <tvm/runtime/object.h>
 
+#include <cstring>
 #include <type_traits>
 #include <utility>
 #include <vector>
@@ -72,6 +73,8 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
   using TSelf = NodeFunctor<R(const ObjectRef& n, Args...)>;
   /*! \brief internal function table */
   std::vector<FPointer> func_;
+  /*! \brief start range of func index */
+  uint32_t begin_type_index_{0};
 
  public:
   /*! \brief the result type of this functor */
@@ -83,6 +86,8 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
    */
   bool can_dispatch(const ObjectRef& n) const {
     uint32_t type_index = n->type_index();
+    if (type_index < begin_type_index_) return false;
+    type_index -= begin_type_index_;
     return type_index < func_.size() && func_[type_index] != nullptr;
   }
   /*!
@@ -94,7 +99,7 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
   R operator()(const ObjectRef& n, Args... args) const {
     ICHECK(can_dispatch(n)) << "NodeFunctor calls un-registered function on type "
                             << n->GetTypeKey();
-    return (*func_[n->type_index()])(n, std::forward<Args>(args)...);
+    return (*func_[n->type_index() - begin_type_index_])(n, std::forward<Args>(args)...);
   }
   /*!
    * \brief set the dispatcher for type TNode
@@ -109,6 +114,7 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
       func_.resize(tindex + 1, nullptr);
     }
     ICHECK(func_[tindex] == nullptr) << "Dispatch for " << TNode::_type_key << " is already set";
+    ICHECK_EQ(begin_type_index_, 0) << " Cannot call set_dispatch after calling Finalize";
     func_[tindex] = f;
     return *this;
   }
@@ -122,9 +128,29 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
   TSelf& clear_dispatch() {  // NOLINT(*)
     uint32_t tindex = TNode::RuntimeTypeIndex();
     ICHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range";
+    ICHECK_EQ(begin_type_index_, 0) << " Cannot call clear_dispatch after calling Finalize";
     func_[tindex] = nullptr;
     return *this;
   }
+  /*!
+   * \brief Finalize the functor after calling sequence of set_dispatch
+   * This function will attempt to find the min type index that is not null
+   * and optimize the space of the func table so it is more compact
+   */
+  void Finalize() {
+    ICHECK_EQ(begin_type_index_, 0) << "Can only call Finalize once";
+    while (begin_type_index_ < func_.size() && func_[begin_type_index_] == nullptr) {
+      ++begin_type_index_;
+    }
+    // shift up the function value
+    size_t new_ftable_size = func_.size() - begin_type_index_;
+    if (begin_type_index_ != 0) {
+      std::memmove(func_.data(), func_.data() + begin_type_index_,
+                   new_ftable_size * sizeof(FPointer));
+    }
+    func_.resize(new_ftable_size);
+    func_.shrink_to_fit();
+  }
 };
 
 #define TVM_REG_FUNC_VAR_DEF(ClsName) static TVM_ATTRIBUTE_UNUSED auto& __make_functor##_##ClsName
diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h
index df9fdcad97..b3bbebd0e0 100644
--- a/include/tvm/relax/dataflow_pattern.h
+++ b/include/tvm/relax/dataflow_pattern.h
@@ -91,6 +91,7 @@ TVM_DLL PatternSeq operator>>(const PatternSeq& lhs, const PatternSeq& rhs);
 class DFPatternNode : public Object {
  public:
   static constexpr const char* _type_key = "DFPatternNode";
+  static constexpr const uint32_t _type_child_slots = 21;
   TVM_DECLARE_BASE_OBJECT_INFO(DFPatternNode, Object);
 };
 
@@ -373,6 +374,7 @@ class VarPatternNode : public DFPatternNode {
   void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name", &name); }
 
   static constexpr const char* _type_key = "relax.dpl.VarPattern";
+  static constexpr const uint32_t _type_child_slots = 1;
   TVM_DECLARE_BASE_OBJECT_INFO(VarPatternNode, DFPatternNode);
 };
 
diff --git a/include/tvm/relax/dataflow_pattern_functor.h b/include/tvm/relax/dataflow_pattern_functor.h
index bbdda44213..fb67f3cc4a 100644
--- a/include/tvm/relax/dataflow_pattern_functor.h
+++ b/include/tvm/relax/dataflow_pattern_functor.h
@@ -135,12 +135,12 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
     RELAX_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);
     RELAX_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode);
     RELAX_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);
-
     RELAX_DFPATTERN_FUNCTOR_DISPATCH(DataflowVarPatternNode);
     RELAX_DFPATTERN_FUNCTOR_DISPATCH(GlobalVarPatternNode);
     RELAX_DFPATTERN_FUNCTOR_DISPATCH(ExternFuncPatternNode);
     RELAX_DFPATTERN_FUNCTOR_DISPATCH(PrimArrPatternNode);
     RELAX_DFPATTERN_FUNCTOR_DISPATCH(UnorderedTuplePatternNode);
+    vtable.Finalize();
     return vtable;
   }
 };
diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h
index fb6f0e40b1..330ff7e8da 100644
--- a/include/tvm/relax/expr.h
+++ b/include/tvm/relax/expr.h
@@ -119,7 +119,7 @@ class StructInfoNode : public Object {
   static constexpr const char* _type_key = "StructInfo";
   static constexpr const bool _type_has_method_sequal_reduce = true;
   static constexpr const bool _type_has_method_shash_reduce = true;
-  static constexpr const uint32_t _type_child_slots = 5;
+  static constexpr const uint32_t _type_child_slots = 7;
   TVM_DECLARE_BASE_OBJECT_INFO(StructInfoNode, Object);
 };
 
@@ -416,7 +416,7 @@ class VarNode : public LeafExprNode {
   static constexpr const char* _type_key = "relax.expr.Var";
   static constexpr const bool _type_has_method_sequal_reduce = true;
   static constexpr const bool _type_has_method_shash_reduce = true;
-  static constexpr const uint32_t _type_child_slots = 2;
+  static constexpr const uint32_t _type_child_slots = 1;
   TVM_DECLARE_BASE_OBJECT_INFO(VarNode, LeafExprNode);
 };
 
diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h
index cdc09c4431..4904b02960 100644
--- a/include/tvm/relax/expr_functor.h
+++ b/include/tvm/relax/expr_functor.h
@@ -176,6 +176,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
     RELAX_EXPR_FUNCTOR_DISPATCH(PrimValueNode);
     RELAX_EXPR_FUNCTOR_DISPATCH(StringImmNode);
     RELAX_EXPR_FUNCTOR_DISPATCH(DataTypeImmNode);
+    vtable.Finalize();
     return vtable;
   }
 };
diff --git a/include/tvm/relax/struct_info_functor.h b/include/tvm/relax/struct_info_functor.h
index 8418b48dc1..2ce5627547 100644
--- a/include/tvm/relax/struct_info_functor.h
+++ b/include/tvm/relax/struct_info_functor.h
@@ -108,6 +108,7 @@ class StructInfoFunctor<R(const StructInfo& n, Args...)> {
     TVM_STRUCT_INFO_FUNCTOR_DISPATCH(distributed::DTensorStructInfoNode);
     TVM_STRUCT_INFO_FUNCTOR_DISPATCH(TupleStructInfoNode);
     TVM_STRUCT_INFO_FUNCTOR_DISPATCH(FuncStructInfoNode);
+    vtable.Finalize();
     return vtable;
   }
 };
diff --git a/include/tvm/tir/expr_functor.h b/include/tvm/tir/expr_functor.h
index 3f66164b42..7a9cf91a65 100644
--- a/include/tvm/tir/expr_functor.h
+++ b/include/tvm/tir/expr_functor.h
@@ -193,6 +193,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
     IR_EXPR_FUNCTOR_DISPATCH(FloatImmNode);
     IR_EXPR_FUNCTOR_DISPATCH(StringImmNode);
     IR_EXPR_FUNCTOR_DISPATCH(AnyNode);
+    vtable.Finalize();
     return vtable;
   }
 };
diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h
index c5b20f8ec0..e9a41468d3 100644
--- a/include/tvm/tir/stmt_functor.h
+++ b/include/tvm/tir/stmt_functor.h
@@ -126,6 +126,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
     IR_STMT_FUNCTOR_DISPATCH(BufferRealizeNode);
     IR_STMT_FUNCTOR_DISPATCH(BlockNode);
     IR_STMT_FUNCTOR_DISPATCH(BlockRealizeNode);
+    vtable.Finalize();
     return vtable;
   }
 };
diff --git a/src/ir/attr_functor.h b/src/ir/attr_functor.h
index 12b4f6f65b..008e63fffc 100644
--- a/src/ir/attr_functor.h
+++ b/src/ir/attr_functor.h
@@ -139,6 +139,7 @@ class AttrFunctor<R(const ObjectRef& n, Args...)> {
     ATTR_FUNCTOR_DISPATCH(CastNode);
     ATTR_FUNCTOR_DISPATCH(CallNode);
     ATTR_FUNCTOR_DISPATCH(SelectNode);
+    vtable.Finalize();
     return vtable;
   }
 };
diff --git a/src/relax/ir/py_expr_functor.cc b/src/relax/ir/py_expr_functor.cc
index a7ac245610..eb286b4ef6 100644
--- a/src/relax/ir/py_expr_functor.cc
+++ b/src/relax/ir/py_expr_functor.cc
@@ -161,6 +161,7 @@ class PyExprVisitorNode : public Object, public ExprVisitor {
     PY_EXPR_VISITOR_DISPATCH(PrimValueNode, f_visit_prim_value_);
     PY_EXPR_VISITOR_DISPATCH(StringImmNode, f_visit_string_imm_);
     PY_EXPR_VISITOR_DISPATCH(DataTypeImmNode, f_visit_data_type_imm_);
+    vtable.Finalize();
     return vtable;
   }
 };
@@ -414,6 +415,7 @@ class PyExprMutatorNode : public Object, public ExprMutator {
     PY_EXPR_MUTATOR_DISPATCH(PrimValueNode, f_visit_prim_value_);
     PY_EXPR_MUTATOR_DISPATCH(StringImmNode, f_visit_string_imm_);
     PY_EXPR_MUTATOR_DISPATCH(DataTypeImmNode, f_visit_data_type_imm_);
+    vtable.Finalize();
     return vtable;
   }
 
@@ -437,6 +439,7 @@ class PyExprMutatorNode : public Object, public ExprMutator {
     PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(PrimValueNode);
     PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(StringImmNode);
     PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(DataTypeImmNode);
+    post_order_vtable.Finalize();
     return post_order_vtable;
   }
 };
diff --git a/src/runtime/object.cc b/src/runtime/object.cc
index 05bfd6d1cf..85ec4f0360 100644
--- a/src/runtime/object.cc
+++ b/src/runtime/object.cc
@@ -170,10 +170,17 @@ class TypeContext {
 
   void Dump(int min_children_count) {
     std::vector<int> num_children(type_table_.size(), 0);
+    // expected child slots compute the expected slots
+    // based on the current child slot setting
+    std::vector<int> expected_child_slots(type_table_.size(), 0);
     // reverse accumulation so we can get total counts in a bottom-up manner.
     for (auto it = type_table_.rbegin(); it != type_table_.rend(); ++it) {
       if (it->index != 0) {
         num_children[it->parent_index] += num_children[it->index] + 1;
+        if (static_cast<uint32_t>(expected_child_slots[it->index] + 1) < it->num_slots) {
+          expected_child_slots[it->index] = it->num_slots - 1;
+        }
+        expected_child_slots[it->parent_index] += expected_child_slots[it->index] + 1;
       }
     }
 
@@ -182,7 +189,8 @@ class TypeContext {
         std::cerr << '[' << info.index << "] " << info.name
                   << "\tparent=" << type_table_[info.parent_index].name
                   << "\tnum_child_slots=" << info.num_slots - 1
-                  << "\tnum_children=" << num_children[info.index] << std::endl;
+                  << "\tnum_children=" << num_children[info.index]
+                  << "\texpected_child_slots=" << expected_child_slots[info.index] << std::endl;
       }
     }
   }

Reply via email to