This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s0 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 5d3d586e94bb5ab700d4f4f7cdd91151ffd496ba
Author: tqchen <[email protected]>
AuthorDate: Sun Mar 9 12:51:09 2025 -0400

    pass basic disco session compile
---
 ffi/include/tvm/ffi/object.h                | 49 +++++++++++++++++++++++++++++
 include/tvm/runtime/container/string.h      |  1 +
 include/tvm/runtime/disco/session.h         |  3 +-
 include/tvm/runtime/memory/memory_manager.h |  1 -
 include/tvm/runtime/object.h                |  6 ++++
 include/tvm/runtime/packed_func.h           |  5 +--
 src/runtime/debug_compile.cc                |  1 +
 src/runtime/disco/protocol.h                |  4 +--
 src/runtime/minrpc/rpc_reference.h          |  2 +-
 9 files changed, 65 insertions(+), 7 deletions(-)

diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h
index 8fbd6705fd..01dca88878 100644
--- a/ffi/include/tvm/ffi/object.h
+++ b/ffi/include/tvm/ffi/object.h
@@ -114,6 +114,9 @@ class Object {
     return details::IsObjectInstance<TargetType>(header_.type_index);
   }
 
+  /*! \return The internal runtime type index of the object. */
+  int32_t type_index() const { return header_.type_index; }
+
   /*!
    * \return the type key of the object.
    * \note this operation is expensive, can be used for error reporting.
@@ -124,6 +127,16 @@ class Object {
     return type_info->type_key;
   }
 
+  /*!
+   * \brief Get the type key of the corresponding index from runtime.
+   * \param tindex The type index.
+   * \return The type key registered for \p tindex.
+   */
+  static std::string TypeIndex2Key(int32_t tindex) {
+    const TypeInfo* type_info = TVMFFIGetTypeInfo(tindex);
+    return type_info->type_key;
+  }
+
   // Information about the object
   static constexpr const char* _type_key = "object.Object";
 
@@ -464,6 +477,7 @@ struct ObjectPtrEqual {
   static int32_t RuntimeTypeIndex() { return TypeName::_type_index; } \
   TVM_FFI_OBJECT_STATIC_DEFS(TypeName, ParentType)
 
+
 /*
  * \brief Define object reference methods.
  * \param TypeName The object type name
@@ -493,6 +507,37 @@ struct ObjectPtrEqual {
   static constexpr bool _type_is_nullable = false; \
   using ContainerType = ObjectName
 
+/*
+ * \brief Define object reference methods for objects whose content is mutable.
+ * \param TypeName The object type name
+ * \param ParentType The parent type of the objectref
+ * \param ObjectName The type name of the object.
+ * \note We recommend making objects immutable when possible.
+ *       This macro is reserved for objects that store runtime state.
+ */
+#define TVM_FFI_DEFINE_MUTABLE_NULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
+  TypeName() = default; \
+  TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \
+  explicit TypeName(::tvm::ffi::ObjectPtr<::tvm::ffi::Object> n) : ParentType(n) {} \
+  ObjectName* operator->() const { return static_cast<ObjectName*>(data_.get()); } \
+  using ContainerType = ObjectName
+
+/*
+ * \brief Define object reference methods for objects that are both non-nullable and mutable.
+ *
+ * \param TypeName The object type name
+ * \param ParentType The parent type of the objectref
+ * \param ObjectName The type name of the object.
+ */
+#define TVM_FFI_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
+  explicit TypeName(::tvm::ffi::ObjectPtr<::tvm::ffi::Object> n) : ParentType(n) {} \
+  TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \
+  ObjectName* operator->() const { return static_cast<ObjectName*>(data_.get()); } \
+  ObjectName* get() const { return operator->(); } \
+  static constexpr bool _type_is_nullable = false; \
+  using ContainerType = ObjectName
+
+
 
 namespace details {
 
@@ -587,6 +632,10 @@ class ObjectUnsafe {
     reinterpret_cast<Object*>(src->v_obj)->IncRef();
   }
 
+  static TVM_FFI_INLINE Object* GetRawObjectPtrFromObjectRef(const ObjectRef& src) {
+    return src.data_.data_;
+  }
+
   static TVM_FFI_INLINE TVMFFIObject* GetTVMFFIObjectPtrFromObjectRef(const ObjectRef& src) {
     return GetHeader(src.data_.data_);
   }
diff --git a/include/tvm/runtime/container/string.h b/include/tvm/runtime/container/string.h
index 661d95bba2..4ee48d9c95 100644
--- a/include/tvm/runtime/container/string.h
+++ b/include/tvm/runtime/container/string.h
@@ -59,6 +59,7 @@ class StringObj : public Object {
 
   static constexpr const uint32_t _type_index = TypeIndex::kRuntimeString;
   static constexpr const char* _type_key = "runtime.String";
+  static const constexpr bool _type_final = true;
   TVM_FFI_DECLARE_STATIC_OBJECT_INFO(StringObj, Object);
 
  private:
diff --git a/include/tvm/runtime/disco/session.h b/include/tvm/runtime/disco/session.h
index 9c34f8a2af..1c3af22d83 100644
--- a/include/tvm/runtime/disco/session.h
+++ b/include/tvm/runtime/disco/session.h
@@ -148,7 +148,8 @@ class DRefObj : public Object {
 
   static constexpr const char* _type_key = "runtime.disco.DRef";
   static constexpr const uint32_t _type_index = TypeIndex::kRuntimeDiscoDRef;
-  TVM_DECLARE_FINAL_OBJECT_INFO(DRefObj, Object);
+  static const constexpr bool _type_final = true;
+  TVM_FFI_DECLARE_STATIC_OBJECT_INFO(DRefObj, Object);
 
   /*! \brief The id of the register */
   int64_t reg_id;
diff --git a/include/tvm/runtime/memory/memory_manager.h b/include/tvm/runtime/memory/memory_manager.h
index ab1e6b5c9f..754d7b4fe3 100644
--- a/include/tvm/runtime/memory/memory_manager.h
+++ b/include/tvm/runtime/memory/memory_manager.h
@@ -182,7 +182,6 @@ class StorageObj : public Object {
     }
   }
 
-  static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
   static constexpr const char* _type_key = "vm.Storage";
   TVM_DECLARE_FINAL_OBJECT_INFO(StorageObj, Object);
 };
diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h
index 046f0d3948..2f89d11965 100644
--- a/include/tvm/runtime/object.h
+++ b/include/tvm/runtime/object.h
@@ -143,14 +143,20 @@ class ObjectRef : public tvm::ffi::ObjectRef {
   TypeName& operator=(const TypeName& other) = default; \
   TypeName& operator=(TypeName&& other) = default;
 
+
+#define TVM_DECLARE_BASE_OBJECT_INFO TVM_FFI_DECLARE_BASE_OBJECT_INFO
 #define TVM_DECLARE_FINAL_OBJECT_INFO TVM_FFI_DECLARE_FINAL_OBJECT_INFO
 #define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS
+#define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS TVM_FFI_DEFINE_MUTABLE_NULLABLE_OBJECT_REF_METHODS
 #define TVM_DEFINE_OBJECT_REF_METHODS TVM_FFI_DEFINE_NULLABLE_OBJECT_REF_METHODS
+#define TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS TVM_FFI_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS
 
 
 #define TVM_STR_CONCAT_(__x, __y) __x##__y
 #define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y)
 
+// Object register type is now a nop
+#define TVM_REGISTER_OBJECT_TYPE(x)
 
 }  // namespace runtime
 }  // namespace tvm
diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h
index 804fd3b98e..0a2461515e 100644
--- a/include/tvm/runtime/packed_func.h
+++ b/include/tvm/runtime/packed_func.h
@@ -2096,7 +2096,7 @@ inline void TVMArgsSetter::SetObject(size_t i, T&& value) const {
     return;
   }
 
-  Object* ptr = value.data_.data_;
+  Object* ptr = details::ObjectUnsafe::GetRawObjectPtrFromObjectRef(value);
   if constexpr (std::is_base_of_v<NDArray::ContainerType, ContainerType> ||
                 std::is_base_of_v<ContainerType, NDArray::ContainerType>) {
     if (std::is_base_of_v<NDArray::ContainerType, ContainerType> ||
@@ -2186,7 +2186,8 @@ inline void TVMArgsSetter::SetObject(size_t i, T&& value) const {
     values_[i].v_handle = const_cast<Object**>(&(value.data_.data_));
     type_codes_[i] = kTVMObjectRValueRefArg;
   } else {
-    values_[i].v_handle = value.data_.data_;
+    // Pass the TVMFFIObject header pointer; NOTE(review): the rvalue branch above still passes Object**.
+    values_[i].v_handle = details::ObjectUnsafe::GetTVMFFIObjectPtrFromObjectRef(value);
     type_codes_[i] = kTVMObjectHandle;
   }
 }
diff --git a/src/runtime/debug_compile.cc b/src/runtime/debug_compile.cc
index 81f1edffc7..4233035abd 100644
--- a/src/runtime/debug_compile.cc
+++ b/src/runtime/debug_compile.cc
@@ -29,6 +29,7 @@
 #include <tvm/runtime/ndarray.h>
 #include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/registry.h>
+#include <tvm/runtime/disco/disco_worker.h>
 
 namespace tvm {
 namespace debug {
diff --git a/src/runtime/disco/protocol.h b/src/runtime/disco/protocol.h
index 50a6b091af..cc741fdbb1 100644
--- a/src/runtime/disco/protocol.h
+++ b/src/runtime/disco/protocol.h
@@ -155,7 +155,7 @@ inline void DiscoProtocol<SubClassType>::WriteObject(Object* obj) {
     self->template Write<uint64_t>(shape->size);
     self->template WriteArray<ShapeTupleObj::index_type>(shape->data, shape->size);
   } else if (obj->IsInstance<DiscoDebugObject>()) {
-    self->template Write<uint32_t>(TypeIndex::kRoot);
+    self->template Write<uint32_t>(0 /* DiscoDebugObject wire tag, formerly TypeIndex::kRoot */);
     std::string str = static_cast<DiscoDebugObject*>(obj)->SaveToStr();
     self->template Write<uint64_t>(str.size());
     self->template WriteArray<char>(str.data(), str.size());
@@ -188,7 +188,7 @@ inline void DiscoProtocol<SubClassType>::ReadObject(int* tcode, TVMValue* value)
     std::vector<ShapeTupleObj::index_type> data(ndim);
     self->template ReadArray<ShapeTupleObj::index_type>(data.data(), ndim);
     result = ShapeTuple(std::move(data));
-  } else if (type_index == TypeIndex::kRoot) {
+  } else if (type_index == 0 /* DiscoDebugObject wire tag, formerly TypeIndex::kRoot */) {
     uint64_t size = 0;
     self->template Read<uint64_t>(&size);
     std::string data(size, '\0');
diff --git a/src/runtime/minrpc/rpc_reference.h b/src/runtime/minrpc/rpc_reference.h
index 13c1fa4b38..bb9c5d3c86 100644
--- a/src/runtime/minrpc/rpc_reference.h
+++ b/src/runtime/minrpc/rpc_reference.h
@@ -28,7 +28,7 @@ namespace tvm {
 namespace runtime {
 
 // Forward declare TVM Object to use `Object*` in RPC protocol.
-class Object;
+// NOTE(review): `class Object;` forward declaration dropped in the FFI refactor; remove the stale comment above once confirmed.
 
 /*! \brief The current RPC procotol version. */
 constexpr const char* kRPCProtocolVer = "0.8.0";
