This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s0 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 130f39e5f89304191629eeff10ab5a7dec791b54 Author: tqchen <[email protected]> AuthorDate: Sun Mar 9 17:45:16 2025 -0400 Pass all runtime --- CMakeLists.txt | 6 +++--- cmake/modules/contrib/CUTLASS.cmake | 2 ++ ffi/include/tvm/ffi/cast.h | 15 +++++++++++++++ ffi/include/tvm/ffi/object.h | 20 ++++---------------- include/tvm/runtime/container/array.h | 2 +- include/tvm/runtime/memory/memory_manager.h | 4 ++-- include/tvm/runtime/packed_func.h | 3 ++- include/tvm/runtime/relax_vm/vm.h | 2 -- src/runtime/memory/memory_manager.cc | 18 +++++++++++------- src/runtime/minrpc/rpc_reference.h | 12 +++++++----- src/runtime/ndarray.cc | 4 ++-- src/runtime/relax_vm/kv_state.h | 5 +---- src/runtime/relax_vm/lm_support.cc | 1 - src/runtime/relax_vm/paged_kv_cache.cc | 1 - src/runtime/relax_vm/rnn_state.cc | 4 ++-- src/runtime/rpc/rpc_module.cc | 2 +- src/runtime/rpc/rpc_session.h | 3 ++- 17 files changed, 55 insertions(+), 49 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index b630f12b06..4f3dc2c9e2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -753,9 +753,9 @@ string(APPEND PROJECT_CONFIG_CONTENT "include(\"\${CMAKE_CURRENT_LIST_DIR}/${PROJECT_NAME}Targets.cmake\")") file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/temp_config_file.cmake" ${PROJECT_CONFIG_CONTENT}) -install(EXPORT ${PROJECT_NAME}Targets - NAMESPACE ${PROJECT_NAME}:: - DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/${PROJECT_NAME}) +# install(EXPORT ${PROJECT_NAME}Targets +# NAMESPACE ${PROJECT_NAME}:: +# DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/${PROJECT_NAME}) # Create config for find_package() configure_package_config_file( diff --git a/cmake/modules/contrib/CUTLASS.cmake b/cmake/modules/contrib/CUTLASS.cmake index 75e7f68c88..de3bc860f9 100644 --- a/cmake/modules/contrib/CUTLASS.cmake +++ b/cmake/modules/contrib/CUTLASS.cmake @@ -33,6 +33,8 @@ if(USE_CUDA AND USE_CUTLASS) ${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm ${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm/cutlass/include ) + target_link_libraries(fpA_intB_gemm_tvm PRIVATE tvm_ffi_header) + set(CUTLASS_FPA_INTB_RUNTIME_SRCS "") list(APPEND CUTLASS_FPA_INTB_RUNTIME_SRCS src/runtime/contrib/cutlass/weight_preprocess.cc) add_library(fpA_intB_cutlass_objs OBJECT ${CUTLASS_FPA_INTB_RUNTIME_SRCS}) diff --git a/ffi/include/tvm/ffi/cast.h b/ffi/include/tvm/ffi/cast.h index 3cdb620d22..c7365a2f45 100644 --- a/ffi/include/tvm/ffi/cast.h +++ b/ffi/include/tvm/ffi/cast.h @@ -51,6 +51,21 @@ inline RefType GetRef(const ObjectType* ptr) { const_cast<Object*>(static_cast<const Object*>(ptr)))); } +/*! + * \brief Get an object ptr type from a raw object ptr. + * + * \param ptr The object pointer + * \tparam BaseType The reference type + * \tparam ObjectType The object type + * \return The corresponding RefType + */ +template <typename BaseType, typename ObjectType> +inline ObjectPtr<BaseType> GetObjectPtr(ObjectType* ptr) { + static_assert(std::is_base_of<BaseType, ObjectType>::value, + "Can only cast to the ref of same container type"); + return details::ObjectUnsafe::ObjectPtrFromUnowned<BaseType>(ptr); +} + /*! * \brief Downcast a base reference type to a more specific type. * diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h index 01dca88878..d63b4efbd7 100644 --- a/ffi/include/tvm/ffi/object.h +++ b/ffi/include/tvm/ffi/object.h @@ -320,10 +320,6 @@ class ObjectPtr { template <typename> friend class ObjectPtr; friend class tvm::ffi::details::ObjectUnsafe; - template <typename RelayRefType, typename ObjType> - friend RelayRefType GetRef(const ObjType* ptr); - template <typename BaseType, typename ObjType> - friend ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr); }; // Forward declaration, to prevent circular includes. @@ -410,17 +406,6 @@ class ObjectRef { friend class tvm::ffi::details::ObjectUnsafe; }; -/*! - * \brief Get an object ptr type from a raw object ptr. - * - * \param ptr The object pointer - * \tparam BaseType The reference type - * \tparam ObjectType The object type - * \return The corresponding RefType - */ -template <typename BaseType, typename ObjectType> -inline ObjectPtr<BaseType> GetObjectPtr(ObjectType* ptr); - /*! \brief ObjectRef hash functor */ struct ObjectPtrHash { size_t operator()(const ObjectRef& a) const { return operator()(a.data_); } @@ -673,12 +658,15 @@ class ObjectUnsafe { return ptr; } + static TVM_FFI_INLINE Object** GetObjectRValueRefValue(ObjectRef* ref) { + return const_cast<Object**>(&(ref->data_.data_)); + } + // legacy APIs to support migration and can be moved later static TVM_FFI_INLINE void LegacyClearObjectPtrAfterMove(ObjectRef* src) { src->data_.data_ = nullptr; } }; - } // namespace details } // namespace ffi } // namespace tvm diff --git a/include/tvm/runtime/container/array.h b/include/tvm/runtime/container/array.h index 7513a1299a..f4e825f44f 100644 --- a/include/tvm/runtime/container/array.h +++ b/include/tvm/runtime/container/array.h @@ -247,7 +247,7 @@ class ArrayNode : public Object, public InplaceArrayBase<ArrayNode, ObjectRef> { friend class Array; // To specialize make_object<ArrayNode> - friend ObjectPtr<ArrayNode> make_object<>(); + friend ObjectPtr<ArrayNode> tvm::ffi::make_object<>(); }; /*! \brief Helper struct for type-checking diff --git a/include/tvm/runtime/memory/memory_manager.h b/include/tvm/runtime/memory/memory_manager.h index 754d7b4fe3..a88d05603f 100644 --- a/include/tvm/runtime/memory/memory_manager.h +++ b/include/tvm/runtime/memory/memory_manager.h @@ -171,10 +171,10 @@ class StorageObj : public Object { String scope = "global"); /*! \brief The deleter for an NDArray when allocated from underlying storage. */ - static void ScopedDeleter(Object* ptr); + static void ScopedDeleter(void* ptr); /*! \brief The deleter for an NDArray when allocated from underlying storage. */ - static void Deleter(Object* ptr); + static void Deleter(void* ptr); ~StorageObj() { if (allocator) { diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 168798149a..912a15e976 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -2183,7 +2183,8 @@ inline void TVMArgsSetter::SetObject(size_t i, T&& value) const { // Final fallback, if the ObjectRef has no special cases that must // be expressed within the TVMRetValue. if constexpr (std::is_rvalue_reference_v<decltype(value)>) { - values_[i].v_handle = const_cast<Object**>(&(value.data_.data_)); + // values_[i].v_handle = const_cast<Object**>(&(value.data_.data_)); + values_[i].v_handle = ffi::details::ObjectUnsafe::GetObjectRValueRefValue(&value); type_codes_[i] = kTVMObjectRValueRefArg; } else { // value.data_.data_; diff --git a/include/tvm/runtime/relax_vm/vm.h b/include/tvm/runtime/relax_vm/vm.h index 7bf716ae50..3d2bbc130e 100644 --- a/include/tvm/runtime/relax_vm/vm.h +++ b/include/tvm/runtime/relax_vm/vm.h @@ -76,7 +76,6 @@ class VMClosureObj : public Object { */ PackedFunc impl; - static constexpr const uint32_t _type_index = TypeIndex::kDynamic; static constexpr const char* _type_key = "relax.vm.Closure"; TVM_DECLARE_FINAL_OBJECT_INFO(VMClosureObj, Object); }; @@ -108,7 +107,6 @@ class VMClosure : public ObjectRef { */ class VMExtensionNode : public Object { protected: - static constexpr const uint32_t _type_index = TypeIndex::kDynamic; static constexpr const char* _type_key = "runtime.VMExtension"; TVM_DECLARE_BASE_OBJECT_INFO(VMExtensionNode, Object); }; diff --git a/src/runtime/memory/memory_manager.cc b/src/runtime/memory/memory_manager.cc index a4b8e15943..fb10b7724b 100644 --- a/src/runtime/memory/memory_manager.cc +++ b/src/runtime/memory/memory_manager.cc @@ -34,7 +34,7 @@ namespace tvm { namespace runtime { namespace memory { -static void BufferDeleter(Object* obj) { +static void BufferDeleter(void* obj) { auto* ptr = static_cast<NDArray::Container*>(obj); ICHECK(ptr->manager_ctx != nullptr); Buffer* buffer = reinterpret_cast<Buffer*>(ptr->manager_ctx); @@ -50,7 +50,7 @@ Storage::Storage(Buffer buffer, Allocator* allocator) { data_ = std::move(n); } -void StorageObj::Deleter(Object* obj) { +void StorageObj::Deleter(void* obj) { auto* ptr = static_cast<NDArray::Container*>(obj); // When invoking AllocNDArray we don't own the underlying allocation // and should not delete the buffer, but instead let it be reclaimed @@ -62,7 +62,8 @@ void StorageObj::Deleter(Object* obj) { // We decrement the object allowing for the buffer to release our // reference count from allocation. StorageObj* storage = reinterpret_cast<StorageObj*>(ptr->manager_ctx); - storage->DecRef(); + // storage->DecRef(); + tvm::ffi::details::ObjectUnsafe::DecRefObjectHandle(storage); delete ptr; } @@ -84,13 +85,14 @@ inline size_t GetDataAlignment(const DLTensor& arr) { return align; } -void StorageObj::ScopedDeleter(Object* obj) { +void StorageObj::ScopedDeleter(void* obj) { auto* ptr = static_cast<NDArray::Container*>(obj); StorageObj* storage = reinterpret_cast<StorageObj*>(ptr->manager_ctx); // Let the device handle proper cleanup of view storage->allocator->FreeView(ptr->dl_tensor.device, ptr->dl_tensor.data); - storage->DecRef(); + // storage->DecRef(); + tvm::ffi::details::ObjectUnsafe::DecRefObjectHandle(storage); delete ptr; } @@ -105,7 +107,8 @@ NDArray StorageObj::AllocNDArrayScoped(int64_t offset, ShapeTuple shape, DLDataT container->dl_tensor.byte_offset = offset; container->SetDeleter(StorageObj::ScopedDeleter); size_t needed_size = DeviceAPI::Get(this->buffer.device)->GetDataSize(container->dl_tensor); - this->IncRef(); + // this->IncRef(); + tvm::ffi::details::ObjectUnsafe::IncRefObjectHandle(this); container->manager_ctx = reinterpret_cast<void*>(this); NDArray ret(GetObjectPtr<Object>(container)); // RAII in effect, now run the check. @@ -125,7 +128,8 @@ NDArray StorageObj::AllocNDArray(int64_t offset, ShapeTuple shape, DLDataType dt container->SetDeleter(StorageObj::Deleter); size_t needed_size = DeviceAPI::Get(this->buffer.device)->GetDataSize(container->dl_tensor); - this->IncRef(); + // this->IncRef(); + tvm::ffi::details::ObjectUnsafe::IncRefObjectHandle(this); // The manager context pointer must continue to point to the storage object // which owns the backing memory, and keeps track of the reference count. // diff --git a/src/runtime/minrpc/rpc_reference.h b/src/runtime/minrpc/rpc_reference.h index bb9c5d3c86..ff3c9f22fd 100644 --- a/src/runtime/minrpc/rpc_reference.h +++ b/src/runtime/minrpc/rpc_reference.h @@ -25,10 +25,12 @@ #define TVM_RUNTIME_MINRPC_RPC_REFERENCE_H_ namespace tvm { -namespace runtime { - +namespace ffi { // Forward declare TVM Object to use `Object*` in RPC protocol. -// class Object; +class Object; +} // namespace ffi + +namespace runtime { /*! \brief The current RPC procotol version. */ constexpr const char* kRPCProtocolVer = "0.8.0"; @@ -206,7 +208,7 @@ struct RPCReference { num_bytes_ += sizeof(T) * num; } - void WriteObject(Object* obj) { num_bytes_ += channel_->GetObjectBytes(obj); } + void WriteObject(ffi::Object* obj) { num_bytes_ += channel_->GetObjectBytes(obj); } void ThrowError(RPCServerStatus status) { channel_->ThrowError(status); } @@ -383,7 +385,7 @@ struct RPCReference { break; } case kTVMObjectHandle: { - channel->WriteObject(static_cast<Object*>(value.v_handle)); + channel->WriteObject(static_cast<ffi::Object*>(value.v_handle)); break; } default: { diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index 32f11a9e8a..63087fda6f 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -104,7 +104,7 @@ struct NDArray::Internal { static void DefaultDeleter(void* ptr_obj) { auto* ptr = static_cast<NDArray::Container*>(ptr_obj); if (ptr->manager_ctx != nullptr) { - details::ObjectUnsafe::DecRefObjectHandle( + ffi::details::ObjectUnsafe::DecRefObjectHandle( static_cast<NDArray::Container*>(ptr->manager_ctx) ); } else if (ptr->dl_tensor.data != nullptr) { @@ -179,7 +179,7 @@ struct NDArray::Internal { } // Delete dlpack object. static void NDArrayDLPackDeleter(DLManagedTensor* tensor) { - details::ObjectUnsafe::DecRefObjectHandle(static_cast<NDArray::Container*>(tensor->manager_ctx)); + ffi::details::ObjectUnsafe::DecRefObjectHandle(static_cast<NDArray::Container*>(tensor->manager_ctx)); delete tensor; } }; diff --git a/src/runtime/relax_vm/kv_state.h b/src/runtime/relax_vm/kv_state.h index 1e530b41ec..14e17d26b7 100644 --- a/src/runtime/relax_vm/kv_state.h +++ b/src/runtime/relax_vm/kv_state.h @@ -103,9 +103,8 @@ class KVStateObj : public Object { */ virtual void EndForward() = 0; - static constexpr const uint32_t _type_index = TypeIndex::kDynamic; static constexpr const char* _type_key = "relax.vm.KVState"; - TVM_DECLARE_BASE_OBJECT_INFO(KVStateObj, Object) + TVM_DECLARE_BASE_OBJECT_INFO(KVStateObj, Object); }; class KVState : public ObjectRef { @@ -293,7 +292,6 @@ class AttentionKVCacheObj : public KVStateObj { */ virtual void DebugSetKV(int64_t seq_id, int64_t start_pos, NDArray k_data, NDArray v_data) = 0; - static constexpr const uint32_t _type_index = TypeIndex::kDynamic; static constexpr const char* _type_key = "relax.vm.AttentionKVCache"; TVM_DECLARE_BASE_OBJECT_INFO(AttentionKVCacheObj, KVStateObj); }; @@ -337,7 +335,6 @@ class RNNStateObj : public KVStateObj { */ virtual NDArray DebugGet(int64_t layer_id, int64_t state_id, int64_t seq_id) = 0; - static constexpr const uint32_t _type_index = TypeIndex::kDynamic; static constexpr const char* _type_key = "relax.vm.RNNState"; TVM_DECLARE_BASE_OBJECT_INFO(RNNStateObj, KVStateObj); }; diff --git a/src/runtime/relax_vm/lm_support.cc b/src/runtime/relax_vm/lm_support.cc index 95dca0c6d5..4af4f42db4 100644 --- a/src/runtime/relax_vm/lm_support.cc +++ b/src/runtime/relax_vm/lm_support.cc @@ -225,7 +225,6 @@ class AttentionKVCacheLegacyObj : public Object { this->fill_count += value->shape[0]; } - static constexpr const uint32_t _type_index = TypeIndex::kDynamic; static constexpr const char* _type_key = "relax.vm.AttentionKVCacheLegacy"; TVM_DECLARE_FINAL_OBJECT_INFO(AttentionKVCacheLegacyObj, Object); }; diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index 6b68bb1b1a..0011fa28d1 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -1600,7 +1600,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { ICHECK(false) << "DebugSetKV for PageAttentionKVCache not implemented yet."; } - static constexpr const uint32_t _type_index = TypeIndex::kDynamic; static constexpr const char* _type_key = "relax.vm.PagedAttentionKVCache"; TVM_DECLARE_FINAL_OBJECT_INFO(PagedAttentionKVCacheObj, AttentionKVCacheObj); diff --git a/src/runtime/relax_vm/rnn_state.cc b/src/runtime/relax_vm/rnn_state.cc index 16fe6791b8..c823cfbb98 100644 --- a/src/runtime/relax_vm/rnn_state.cc +++ b/src/runtime/relax_vm/rnn_state.cc @@ -454,8 +454,8 @@ class RNNStateImpObj : public RNNStateObj { } public: - static constexpr const uint32_t _type_index = TypeIndex::kDynamic; - static constexpr const char* _type_key = "relax.vm.RNNStateImp"; + + static constexpr const char* _type_key = "relax.vm.RNNStateImp"; TVM_DECLARE_FINAL_OBJECT_INFO(RNNStateImpObj, RNNStateObj); }; diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index a696005ab8..57f2771501 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -41,7 +41,7 @@ namespace tvm { namespace runtime { // deleter of RPC remote array -static void RemoteNDArrayDeleter(Object* obj) { +static void RemoteNDArrayDeleter(void* obj) { auto* ptr = static_cast<NDArray::Container*>(obj); RemoteSpace* space = static_cast<RemoteSpace*>(ptr->dl_tensor.data); if (ptr->manager_ctx != nullptr) { diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h index f01b571b25..9bb4f31946 100644 --- a/src/runtime/rpc/rpc_session.h +++ b/src/runtime/rpc/rpc_session.h @@ -320,7 +320,8 @@ class RPCObjectRefObj : public Object { static constexpr const uint32_t _type_index = TypeIndex::kRuntimeRPCObjectRef; static constexpr const char* _type_key = "runtime.RPCObjectRef"; - TVM_DECLARE_FINAL_OBJECT_INFO(RPCObjectRefObj, Object); + static const constexpr bool _type_final = true; + TVM_FFI_DECLARE_STATIC_OBJECT_INFO(RPCObjectRefObj, Object); private: // The object handle
