This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s0 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 5df4dec4b22cea67771be36f895b76138fdb9266 Author: tqchen <[email protected]> AuthorDate: Sun Mar 9 14:42:32 2025 -0400 pass object and runtime compile --- ffi/include/tvm/ffi/function.h | 2 +- src/runtime/object.cc | 203 ++--------------------------------------- src/runtime/object_internal.h | 10 +- 3 files changed, 15 insertions(+), 200 deletions(-) diff --git a/ffi/include/tvm/ffi/function.h b/ffi/include/tvm/ffi/function.h index 680758159a..7ba12c3056 100644 --- a/ffi/include/tvm/ffi/function.h +++ b/ffi/include/tvm/ffi/function.h @@ -74,7 +74,7 @@ namespace ffi { if (ret_code == -2) { \ throw ::tvm::ffi::EnvErrorAlreadySet(); \ } \ - Any error_any; \ + ::tvm::ffi::Any error_any; \ TVMFFIMoveFromLastError(reinterpret_cast<TVMFFIAny*>(&error_any)); \ if (std::optional<tvm::ffi::Error> error = error_any.TryAs<tvm::ffi::Error>()) { \ throw std::move(*error); \ diff --git a/src/runtime/object.cc b/src/runtime/object.cc index 05bfd6d1cf..9c4356ef4c 100644 --- a/src/runtime/object.cc +++ b/src/runtime/object.cc @@ -37,203 +37,10 @@ namespace tvm { namespace runtime { -/*! \brief Type information */ -struct TypeInfo { - /*! \brief The current index. */ - uint32_t index{0}; - /*! \brief Index of the parent in the type hierarchy */ - uint32_t parent_index{0}; - // NOTE: the indices in [index, index + num_reserved_slots) are - // reserved for the child-class of this type. - /*! \brief Total number of slots reserved for the type and its children. */ - uint32_t num_slots{0}; - /*! \brief number of allocated child slots. */ - uint32_t allocated_slots{0}; - /*! \brief Whether child can overflow. */ - bool child_slots_can_overflow{true}; - /*! \brief name of the type. */ - std::string name; - /*! \brief hash of the name */ - size_t name_hash{0}; -}; - -/*! - * \brief Type context that manages the type hierarchy information. - */ -class TypeContext { - public: - // NOTE: this is a relatively slow path for child checking - // Most types are already checked by the fast-path via reserved slot checking. - bool DerivedFrom(uint32_t child_tindex, uint32_t parent_tindex) { - // invariance: child's type index is always bigger than its parent. - if (child_tindex < parent_tindex) return false; - if (child_tindex == parent_tindex) return true; - { - std::lock_guard<std::mutex> lock(mutex_); - ICHECK_LT(child_tindex, type_table_.size()); - while (child_tindex > parent_tindex) { - child_tindex = type_table_[child_tindex].parent_index; - } - } - return child_tindex == parent_tindex; - } - - uint32_t GetOrAllocRuntimeTypeIndex(const std::string& skey, uint32_t static_tindex, - uint32_t parent_tindex, uint32_t num_child_slots, - bool child_slots_can_overflow) { - std::lock_guard<std::mutex> lock(mutex_); - auto it = type_key2index_.find(skey); - if (it != type_key2index_.end()) { - return it->second; - } - // try to allocate from parent's type table. - ICHECK_LT(parent_tindex, type_table_.size()) - << " skey=" << skey << ", static_index=" << static_tindex; - TypeInfo& pinfo = type_table_[parent_tindex]; - ICHECK_EQ(pinfo.index, parent_tindex); - - // if parent cannot overflow, then this class cannot. - if (!pinfo.child_slots_can_overflow) { - child_slots_can_overflow = false; - } - - // total number of slots include the type itself. - uint32_t num_slots = num_child_slots + 1; - uint32_t allocated_tindex; - - if (static_tindex != TypeIndex::kDynamic) { - // statically assigned type - VLOG(3) << "TypeIndex[" << static_tindex << "]: static: " << skey << ", parent " - << type_table_[parent_tindex].name; - allocated_tindex = static_tindex; - ICHECK_LT(static_tindex, type_table_.size()); - ICHECK_EQ(type_table_[allocated_tindex].allocated_slots, 0U) - << "Conflicting static index " << static_tindex << " between " - << type_table_[allocated_tindex].name << " and " << skey; - } else if (pinfo.allocated_slots + num_slots <= pinfo.num_slots) { - // allocate the slot from parent's reserved pool - allocated_tindex = parent_tindex + pinfo.allocated_slots; - VLOG(3) << "TypeIndex[" << allocated_tindex << "]: dynamic: " << skey << ", parent " - << type_table_[parent_tindex].name; - // update parent's state - pinfo.allocated_slots += num_slots; - } else { - VLOG(3) << "TypeIndex[" << type_counter_ << "]: dynamic (overflow): " << skey << ", parent " - << type_table_[parent_tindex].name; - ICHECK(pinfo.child_slots_can_overflow) - << "Reach maximum number of sub-classes for " << pinfo.name; - // allocate new entries. - allocated_tindex = type_counter_; - type_counter_ += num_slots; - ICHECK_LE(type_table_.size(), type_counter_); - type_table_.resize(type_counter_, TypeInfo()); - } - ICHECK_GT(allocated_tindex, parent_tindex); - // initialize the slot. - type_table_[allocated_tindex].index = allocated_tindex; - type_table_[allocated_tindex].parent_index = parent_tindex; - type_table_[allocated_tindex].num_slots = num_slots; - type_table_[allocated_tindex].allocated_slots = 1; - type_table_[allocated_tindex].child_slots_can_overflow = child_slots_can_overflow; - type_table_[allocated_tindex].name = skey; - type_table_[allocated_tindex].name_hash = std::hash<std::string>()(skey); - // update the key2index mapping. - type_key2index_[skey] = allocated_tindex; - return allocated_tindex; - } - - std::string TypeIndex2Key(uint32_t tindex) { - std::lock_guard<std::mutex> lock(mutex_); - if (tindex != 0) { - // always return the right type key for root - // for non-root type nodes, allocated slots should not equal 0 - ICHECK(tindex < type_table_.size() && type_table_[tindex].allocated_slots != 0) - << "Unknown type index " << tindex; - } - return type_table_[tindex].name; - } - - size_t TypeIndex2KeyHash(uint32_t tindex) { - std::lock_guard<std::mutex> lock(mutex_); - ICHECK(tindex < type_table_.size() && type_table_[tindex].allocated_slots != 0) - << "Unknown type index " << tindex; - return type_table_[tindex].name_hash; - } - - uint32_t TypeKey2Index(const std::string& skey) { - auto it = type_key2index_.find(skey); - ICHECK(it != type_key2index_.end()) - << "Cannot find type " << skey - << ". Did you forget to register the node by TVM_REGISTER_NODE_TYPE ?"; - return it->second; - } - - void Dump(int min_children_count) { - std::vector<int> num_children(type_table_.size(), 0); - // reverse accumulation so we can get total counts in a bottom-up manner. - for (auto it = type_table_.rbegin(); it != type_table_.rend(); ++it) { - if (it->index != 0) { - num_children[it->parent_index] += num_children[it->index] + 1; - } - } - - for (const auto& info : type_table_) { - if (info.index != 0 && num_children[info.index] >= min_children_count) { - std::cerr << '[' << info.index << "] " << info.name - << "\tparent=" << type_table_[info.parent_index].name - << "\tnum_child_slots=" << info.num_slots - 1 - << "\tnum_children=" << num_children[info.index] << std::endl; - } - } - } - - static TypeContext* Global() { - static TypeContext inst; - return &inst; - } - - private: - TypeContext() { - type_table_.resize(TypeIndex::kStaticIndexEnd, TypeInfo()); - type_table_[0].name = "runtime.Object"; - } - // mutex to avoid registration from multiple threads. - std::mutex mutex_; - std::atomic<uint32_t> type_counter_{TypeIndex::kStaticIndexEnd}; - std::vector<TypeInfo> type_table_; - std::unordered_map<std::string, uint32_t> type_key2index_; -}; - -uint32_t Object::GetOrAllocRuntimeTypeIndex(const std::string& key, uint32_t static_tindex, - uint32_t parent_tindex, uint32_t num_child_slots, - bool child_slots_can_overflow) { - return TypeContext::Global()->GetOrAllocRuntimeTypeIndex( - key, static_tindex, parent_tindex, num_child_slots, child_slots_can_overflow); -} - -bool Object::DerivedFrom(uint32_t parent_tindex) const { - return TypeContext::Global()->DerivedFrom(this->type_index_, parent_tindex); -} - -std::string Object::TypeIndex2Key(uint32_t tindex) { - return TypeContext::Global()->TypeIndex2Key(tindex); -} - -size_t Object::TypeIndex2KeyHash(uint32_t tindex) { - return TypeContext::Global()->TypeIndex2KeyHash(tindex); -} - -uint32_t Object::TypeKey2Index(const std::string& key) { - return TypeContext::Global()->TypeKey2Index(key); -} - TVM_REGISTER_GLOBAL("runtime.ObjectPtrHash").set_body_typed([](ObjectRef obj) { return static_cast<int64_t>(ObjectPtrHash()(obj)); }); -TVM_REGISTER_GLOBAL("runtime.DumpTypeTable").set_body_typed([](int min_child_count) { - TypeContext::Global()->Dump(min_child_count); -}); } // namespace runtime } // namespace tvm @@ -258,8 +65,14 @@ int TVMObjectFree(TVMObjectHandle obj) { int TVMObjectDerivedFrom(uint32_t child_type_index, uint32_t parent_type_index, int* is_derived) { API_BEGIN(); - *is_derived = - tvm::runtime::TypeContext::Global()->DerivedFrom(child_type_index, parent_type_index); + *is_derived = [&]() { + if (child_type_index == parent_type_index) return true; + if (child_type_index < parent_type_index) return false; + const TVMFFITypeInfo* child_type_info = TVMFFIGetTypeInfo(child_type_index); + const TVMFFITypeInfo* parent_type_info = TVMFFIGetTypeInfo(parent_type_index); + return (child_type_info->type_depth > parent_type_info->type_depth && + child_type_info->type_acenstors[parent_type_info->type_depth] == static_cast<int32_t>(parent_type_index)); + }(); API_END(); } diff --git a/src/runtime/object_internal.h b/src/runtime/object_internal.h index 3cde330f59..a369895434 100644 --- a/src/runtime/object_internal.h +++ b/src/runtime/object_internal.h @@ -64,16 +64,18 @@ class ObjectInternal { * \param type_index The type index of interest. * \return The derivation checking result. */ - static bool DerivedFrom(const Object* obj, uint32_t type_index) { - return obj->DerivedFrom(type_index); - } + // static bool DerivedFrom(const Object* obj, uint32_t type_index) { + // return obj->DerivedFrom(type_index); + // } /*! * \brief Expose TypeKey2Index * \param type_key The original type key. * \return the corresponding index. */ static uint32_t ObjectTypeKey2Index(const std::string& type_key) { - return Object::TypeKey2Index(type_key); + int32_t type_index; + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKey2Index(type_key.c_str(), &type_index)); + return static_cast<uint32_t>(type_index); } /*! * \brief Convert ModuleHandle to module node pointer.
