This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch refactor-s0 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 66f993be83bb55232dafa23a286511e6c2425ccd Author: tqchen <[email protected]> AuthorDate: Sun Mar 9 18:14:26 2025 -0400 pass expr first stab --- ffi/include/tvm/ffi/object.h | 28 +++++++++------------- ffi/include/tvm/ffi/reflection.h | 8 +++++-- include/tvm/runtime/object.h | 51 +++++++++++++++++++++++++++++++++++++++- src/node/reflection.cc | 5 ++-- src/runtime/debug_compile.cc | 2 ++ 5 files changed, 72 insertions(+), 22 deletions(-) diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h index d63b4efbd7..b2189fb7f5 100644 --- a/ffi/include/tvm/ffi/object.h +++ b/ffi/include/tvm/ffi/object.h @@ -426,24 +426,15 @@ struct ObjectPtrEqual { } }; -/*! - * \brief Helper macro to declare list of static checks about object meta-data. - * \param TypeName The name of the current type. - * \param ParentType The name of the ParentType - */ -#define TVM_FFI_OBJECT_STATIC_DEFS(TypeName, ParentType) \ - static constexpr int32_t _type_depth = ParentType::_type_depth + 1; \ - static_assert(!ParentType::_type_final, "ParentType marked as final"); \ - static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 || \ - TypeName::_type_child_slots < ParentType::_type_child_slots, \ - "Need to set _type_child_slots when parent specifies it."); \ - static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 || \ - TypeName::_type_child_slots < ParentType::_type_child_slots, \ - "Need to set _type_child_slots when parent specifies it.") // If dynamic type is enabled, we still need to register the runtime type of parent #define TVM_FFI_REGISTER_STATIC_TYPE_INFO(TypeName, ParentType) \ + static constexpr int32_t _type_depth = ParentType::_type_depth + 1; \ static int32_t _GetOrAllocRuntimeTypeIndex() { \ + static_assert(!ParentType::_type_final, "ParentType marked as final"); \ + static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 || \ + TypeName::_type_child_slots < 
ParentType::_type_child_slots, \ + "Need to set _type_child_slots when parent specifies it."); \ static int32_t tindex = TVMFFIGetOrAllocTypeIndex( \ TypeName::_type_key, TypeName::_type_index, TypeName::_type_depth, \ TypeName::_type_child_slots, TypeName::_type_child_slots_can_overflow, \ @@ -458,9 +449,8 @@ struct ObjectPtrEqual { * \param ParentType The name of the ParentType */ #define TVM_FFI_DECLARE_STATIC_OBJECT_INFO(TypeName, ParentType) \ - TVM_FFI_REGISTER_STATIC_TYPE_INFO(TypeName, ParentType); \ static int32_t RuntimeTypeIndex() { return TypeName::_type_index; } \ - TVM_FFI_OBJECT_STATIC_DEFS(TypeName, ParentType) + TVM_FFI_REGISTER_STATIC_TYPE_INFO(TypeName, ParentType); /* @@ -524,7 +514,6 @@ struct ObjectPtrEqual { namespace details { - template <typename TargetType> TVM_FFI_INLINE bool IsObjectInstance(int32_t object_type_index) { static_assert(std::is_base_of_v<Object, TargetType>); @@ -575,6 +564,11 @@ class ObjectUnsafe { reinterpret_cast<int64_t>(&(static_cast<Object*>(nullptr)->header_))); } + template <typename T> + static TVM_FFI_INLINE ObjectPtr<T> ObjectPtrFromObjectRef(const ObjectRef& ref) { + return tvm::ffi::ObjectPtr<T>(ref.data_.data_); + } + template <typename T> static TVM_FFI_INLINE ObjectPtr<T> ObjectPtrFromOwned(Object* raw_ptr) { tvm::ffi::ObjectPtr<T> ptr; diff --git a/ffi/include/tvm/ffi/reflection.h b/ffi/include/tvm/ffi/reflection.h index f3cc580853..d154522b9f 100644 --- a/ffi/include/tvm/ffi/reflection.h +++ b/ffi/include/tvm/ffi/reflection.h @@ -154,8 +154,12 @@ class ReflectionFieldGetter { * \param ParentType The name of the ParentType */ #define TVM_FFI_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ - TVM_FFI_OBJECT_STATIC_DEFS(TypeName, ParentType); \ - static int32_t _GetOrAllocRuntimeTypeIndex() { \ + static constexpr int32_t _type_depth = ParentType::_type_depth + 1; \ + static int32_t _GetOrAllocRuntimeTypeIndex() { \ + static_assert(!ParentType::_type_final, "ParentType marked as final"); \ + 
static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 || \ + TypeName::_type_child_slots < ParentType::_type_child_slots, \ + "Need to set _type_child_slots when parent specifies it."); \ static int32_t tindex = TVMFFIGetOrAllocTypeIndex( \ TypeName::_type_key, -1, TypeName::_type_depth, TypeName::_type_child_slots, \ TypeName::_type_child_slots_can_overflow, ParentType::_GetOrAllocRuntimeTypeIndex()); \ diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index ec39461b0c..d956eebf1f 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -128,11 +128,13 @@ class ObjectRef : public tvm::ffi::ObjectRef { * \brief Internal helper function get data_ as ObjectPtr of ObjectType. * \note only used for internal dev purpose. * \tparam ObjectType The corresponding object type. + * \param ref The object reference * \return the corresponding type. */ template <typename ObjectType> static ObjectPtr<ObjectType> GetDataPtr(const ObjectRef& ref) { - return ObjectPtr<ObjectType>(ref.data_.data_); + // return ObjectPtr<ObjectType>(ref.data_.data_); + return ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef<ObjectType>(ref); } // friend classes. @@ -152,6 +154,53 @@ class ObjectRef : public tvm::ffi::ObjectRef { TypeName& operator=(const TypeName& other) = default; \ TypeName& operator=(TypeName&& other) = default; +/*! + * \brief Define CopyOnWrite function in an ObjectRef. + * \param ObjectName The Type of the Node. + * + * CopyOnWrite will generate a unique copy of the internal node. + * The node will be copied if it is referenced by multiple places. + * The function returns the raw pointer to the node to allow modification + * of the content. 
+ * + * \code + * + * MyCOWObjectRef ref, ref2; + * ref2 = ref; + * ref.CopyOnWrite()->value = new_value; + * assert(ref2->value == old_value); + * assert(ref->value == new_value); + * + * \endcode + */ +#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName) \ +static_assert(ObjectName::_type_final, \ + "TVM's CopyOnWrite may only be used for " \ + "Object types that are declared as final, " \ + "using the TVM_DECLARE_FINAL_OBJECT_INFO macro."); \ +ObjectName* CopyOnWrite() { \ + ICHECK(data_ != nullptr); \ + if (!data_.unique()) { \ + auto n = make_object<ObjectName>(*(operator->())); \ + ObjectPtr<Object>(std::move(n)).swap(data_); \ + } \ + return static_cast<ObjectName*>(data_.get()); \ +} + +/* + * \brief Define object reference methods. + * \param TypeName The object type name + * \param ParentType The parent type of the objectref + * \param ObjectName The type name of the object. + */ +#define TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR(TypeName, ParentType, \ + ObjectName) \ +explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \ +TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ +const ObjectName* operator->() const { return static_cast<const ObjectName*>(data_.get()); } \ +const ObjectName* get() const { return operator->(); } \ +using ContainerType = ObjectName; + #define TVM_DECLARE_BASE_OBJECT_INFO TVM_FFI_DECLARE_BASE_OBJECT_INFO #define TVM_DECLARE_FINAL_OBJECT_INFO TVM_FFI_DECLARE_FINAL_OBJECT_INFO diff --git a/src/node/reflection.cc b/src/node/reflection.cc index aa572e9965..02e7b2423d 100644 --- a/src/node/reflection.cc +++ b/src/node/reflection.cc @@ -153,8 +153,9 @@ ReflectionVTable* ReflectionVTable::Global() { ObjectPtr<Object> ReflectionVTable::CreateInitObject(const std::string& type_key, const std::string& repr_bytes) const { - uint32_t tindex = Object::TypeKey2Index(type_key); - if (tindex >= fcreate_.size() || fcreate_[tindex] == nullptr) { + int32_t tindex; + 
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKey2Index(type_key.c_str(), &tindex)); + if (static_cast<size_t>(tindex) >= fcreate_.size() || fcreate_[tindex] == nullptr) { LOG(FATAL) << "TypeError: " << type_key << " is not registered via TVM_REGISTER_NODE_TYPE"; } return fcreate_[tindex](repr_bytes); diff --git a/src/runtime/debug_compile.cc b/src/runtime/debug_compile.cc index 4233035abd..17f86d5f7f 100644 --- a/src/runtime/debug_compile.cc +++ b/src/runtime/debug_compile.cc @@ -30,6 +30,8 @@ #include <tvm/runtime/packed_func.h> #include <tvm/runtime/registry.h> #include <tvm/runtime/disco/disco_worker.h> +#include <tvm/ir/expr.h> +#include <tvm/tir/expr.h> namespace tvm { namespace debug {
