This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch refactor-s0
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 6c16f46c9a1f524980aa332087902ec5d1620cdf
Author: tqchen <[email protected]>
AuthorDate: Sun Mar 9 11:25:30 2025 -0400

    Move over container on new object
---
 include/tvm/runtime/container/array.h | 13 +++---
 include/tvm/runtime/container/map.h   |  4 +-
 include/tvm/runtime/object.h          | 78 +++++++++++++++++++++++++----------
 src/runtime/debug_compile.cc          |  4 ++
 4 files changed, 70 insertions(+), 29 deletions(-)

diff --git a/include/tvm/runtime/container/array.h b/include/tvm/runtime/container/array.h
index ba8fdfac55..7513a1299a 100644
--- a/include/tvm/runtime/container/array.h
+++ b/include/tvm/runtime/container/array.h
@@ -121,7 +121,7 @@ class ArrayNode : public Object, public InplaceArrayBase<ArrayNode, ObjectRef> {
 
   static constexpr const uint32_t _type_index = TypeIndex::kRuntimeArray;
   static constexpr const char* _type_key = "Array";
-  TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Object);
+  TVM_FFI_DECLARE_STATIC_OBJECT_INFO(ArrayNode, Object);
 
  private:
   /*! \return Size of initialized memory, used by InplaceArrayBase. */
@@ -903,14 +903,15 @@ inline Array<T> Concat(Array<T> lhs, const Array<T>& rhs) {
   return std::move(lhs);
 }
 
+}  // namespace runtime
+
+namespace ffi{
 // Specialize make_object<ArrayNode> to make sure it is correct.
 template <>
-inline ObjectPtr<ArrayNode> make_object() {
-  return ArrayNode::Empty();
+inline ObjectPtr<tvm::runtime::ArrayNode> make_object() {
+  return tvm::runtime::ArrayNode::Empty();
+}
 }
-
-}  // namespace runtime
-
 // expose the functions to the root namespace.
 using runtime::Array;
 using runtime::ArrayNode;
diff --git a/include/tvm/runtime/container/map.h b/include/tvm/runtime/container/map.h
index 1e4015aa23..a24d61d4de 100644
--- a/include/tvm/runtime/container/map.h
+++ b/include/tvm/runtime/container/map.h
@@ -68,7 +68,7 @@ class MapNode : public Object {
 
   static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeMap;
   static constexpr const char* _type_key = "Map";
-  TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object);
+  TVM_FFI_DECLARE_STATIC_OBJECT_INFO(MapNode, Object);
 
   /*!
    * \brief Number of elements in the SmallMapNode
@@ -187,7 +187,7 @@ class MapNode : public Object {
 
   static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeMap;
   static constexpr const char* _type_key = "Map";
-  TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object);
+  TVM_FFI_DECLARE_STATIC_OBJECT_INFO(MapNode, Object);
 
   /*!
    * \brief Number of elements in the SmallMapNode
diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h
index c2b3681bc4..8fc296519f 100644
--- a/include/tvm/runtime/object.h
+++ b/include/tvm/runtime/object.h
@@ -30,6 +30,8 @@
 
 namespace tvm {
 namespace runtime {
+template<typename T>
+class Optional;
 using namespace tvm::ffi;
 
 /*!
@@ -38,27 +40,27 @@ using namespace tvm::ffi;
  *       the constant, but still able to use enum.
  */
 enum TypeIndex : int32_t {
-    // Standard static index assignments,
-    // Frontends can take benefit of these constants.
-    kRuntimeString = TVMFFITypeIndex::kTVMFFIStr,
-    kRuntimeMap = TVMFFITypeIndex::kTVMFFIMap,
-    kRuntimeArray = TVMFFITypeIndex::kTVMFFIArray,
-    /*! \brief runtime::Module. */
-    kRuntimeModule = TVMFFITypeIndex::kTVMFFIRuntimeModule,
-    /*! \brief runtime::NDArray. */
-    kRuntimeNDArray = TVMFFITypeIndex::kTVMFFINDArray,
-    /*! \brief runtime::ShapeTuple. */
-    kRuntimeShapeTuple = TVMFFITypeIndex::kTVMFFIShapeTuple,
-    // Extra builtin static index here
-    kCustomStaticIndex = TVMFFITypeIndex::kTVMFFIStaticObjectEnd,
-    /*! \brief runtime::PackedFunc. */
-    kRuntimePackedFunc = kCustomStaticIndex + 1,
-    /*! \brief runtime::DRef for disco distributed runtime */
-    kRuntimeDiscoDRef = kCustomStaticIndex + 2,
-    /*! \brief runtime::RPCObjectRef */
-    kRuntimeRPCObjectRef = kCustomStaticIndex + 3,
-    // static assignments that may subject to change.
-    kStaticIndexEnd
+  // Standard static index assignments,
+  // Frontends can take benefit of these constants.
+  kRuntimeString = TVMFFITypeIndex::kTVMFFIStr,
+  kRuntimeMap = TVMFFITypeIndex::kTVMFFIMap,
+  kRuntimeArray = TVMFFITypeIndex::kTVMFFIArray,
+  /*! \brief runtime::Module. */
+  kRuntimeModule = TVMFFITypeIndex::kTVMFFIRuntimeModule,
+  /*! \brief runtime::NDArray. */
+  kRuntimeNDArray = TVMFFITypeIndex::kTVMFFINDArray,
+  /*! \brief runtime::ShapeTuple. */
+  kRuntimeShapeTuple = TVMFFITypeIndex::kTVMFFIShapeTuple,
+  // Extra builtin static index here
+  kCustomStaticIndex = TVMFFITypeIndex::kTVMFFIStaticObjectEnd,
+  /*! \brief runtime::PackedFunc. */
+  kRuntimePackedFunc = kCustomStaticIndex + 1,
+  /*! \brief runtime::DRef for disco distributed runtime */
+  kRuntimeDiscoDRef = kCustomStaticIndex + 2,
+  /*! \brief runtime::RPCObjectRef */
+  kRuntimeRPCObjectRef = kCustomStaticIndex + 3,
+  // static assignments that may subject to change.
+  kStaticIndexEnd,
 };
 
 class ObjectRef : public tvm::ffi::ObjectRef {
@@ -68,6 +70,29 @@ class ObjectRef : public tvm::ffi::ObjectRef {
   /*! \brief Constructor from existing object ptr */
   explicit ObjectRef(ObjectPtr<Object> data) : tvm::ffi::ObjectRef(data) {}
 
+  using tvm::ffi::ObjectRef::as;
+
+  /*!
+   * \brief Try to downcast the ObjectRef to a
+   * Optional<T> of the requested type.
+   *
+   * The function will return a NullOpt if the cast failed.
+   *
+   *      if (Optional<Add> opt = node_ref.as<Add>()) {
+   *        // This is an add node
+   *      }
+   *
+   * \note While this method is declared in <tvm/runtime/object.h>,
+   * the implementation is in <tvm/runtime/container/optional.h> to
+   * prevent circular includes.  This additional include file is only
+   * required in compilation units that uses this method.
+   *
+   * \tparam ObjectRefType the target type, must be a subtype of ObjectRef
+   */
+  template <typename ObjectRefType,
+            typename = std::enable_if_t<std::is_base_of_v<ObjectRef, ObjectRefType>>>
+  inline Optional<ObjectRefType> as() const;
+
  protected:
   /*! \return return a mutable internal ptr, can be used by sub-classes. */
   Object* get_mutable() const { return data_.get(); }
@@ -99,6 +124,7 @@ class ObjectRef : public tvm::ffi::ObjectRef {
   static ObjectPtr<ObjectType> GetDataPtr(const ObjectRef& ref) {
     return ObjectPtr<ObjectType>(ref.data_.data_);
   }
+
   // friend classes.
   friend struct ObjectPtrHash;
   friend class TVMRetValue;
@@ -106,6 +132,16 @@ class ObjectRef : public tvm::ffi::ObjectRef {
   friend class ObjectInternal;
 };
 
+/*
+ * \brief Define the default copy/move constructor and assign operator
+ * \param TypeName The class typename.
+ */
+#define TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \
+  TypeName(const TypeName& other) = default; \
+  TypeName(TypeName&& other) = default; \
+  TypeName& operator=(const TypeName& other) = default; \
+  TypeName& operator=(TypeName&& other) = default;
+
 #define TVM_DECLARE_FINAL_OBJECT_INFO TVM_FFI_DECLARE_FINAL_OBJECT_INFO
 
 #define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS
diff --git a/src/runtime/debug_compile.cc b/src/runtime/debug_compile.cc
index 5f9f6b10d1..214302d80f 100644
--- a/src/runtime/debug_compile.cc
+++ b/src/runtime/debug_compile.cc
@@ -22,6 +22,10 @@
  * \brief File used for debug migration
  */
 #include <tvm/runtime/container/string.h>
+#include <tvm/runtime/container/optional.h>
+#include <tvm/runtime/container/array.h>
+#include <tvm/runtime/container/map.h>
+#include <tvm/runtime/container/variant.h>
 
 namespace tvm {
 namespace debug {
