This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s0 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit ad6d5fbdf05de9448e0777d3ea27a81447704ffd Author: tqchen <[email protected]> AuthorDate: Sun Mar 9 12:22:06 2025 -0400 pass ndarray compile --- ffi/include/tvm/ffi/object.h | 4 ++++ include/tvm/runtime/object.h | 5 +++++ src/runtime/debug_compile.cc | 1 + src/runtime/ndarray.cc | 23 ++++++++++++++--------- 4 files changed, 24 insertions(+), 9 deletions(-) diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h index a6f85be999..8fbd6705fd 100644 --- a/ffi/include/tvm/ffi/object.h +++ b/ffi/include/tvm/ffi/object.h @@ -575,6 +575,10 @@ class ObjectUnsafe { reinterpret_cast<Object*>(handle)->DecRef(); } + static TVM_FFI_INLINE void IncRefObjectHandle(TVMFFIObjectHandle handle) { + reinterpret_cast<Object*>(handle)->IncRef(); + } + static TVM_FFI_INLINE void DecRefObjectInAny(TVMFFIAny* src) { reinterpret_cast<Object*>(src->v_obj)->DecRef(); } diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index b84cf1914b..046f0d3948 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -26,6 +26,7 @@ #include <tvm/ffi/object.h> #include <tvm/ffi/cast.h> #include <tvm/ffi/reflection.h> +#include <tvm/runtime/c_runtime_api.h> namespace tvm { namespace runtime { @@ -147,6 +148,10 @@ class ObjectRef : public tvm::ffi::ObjectRef { #define TVM_DEFINE_OBJECT_REF_METHODS TVM_FFI_DEFINE_NULLABLE_OBJECT_REF_METHODS +#define TVM_STR_CONCAT_(__x, __y) __x##__y +#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y) + + } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_OBJECT_H_ diff --git a/src/runtime/debug_compile.cc b/src/runtime/debug_compile.cc index e04c0b0003..81f1edffc7 100644 --- a/src/runtime/debug_compile.cc +++ b/src/runtime/debug_compile.cc @@ -28,6 +28,7 @@ #include <tvm/runtime/container/variant.h> #include <tvm/runtime/ndarray.h> #include <tvm/runtime/packed_func.h> +#include <tvm/runtime/registry.h> namespace tvm { namespace debug { diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index 5a328413a1..32f11a9e8a 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -101,10 +101,12 @@ void ArrayCopyToBytes(const DLTensor* handle, void* data, size_t nbytes) { struct NDArray::Internal { // Default deleter for the container - static void DefaultDeleter(Object* ptr_obj) { + static void DefaultDeleter(void* ptr_obj) { auto* ptr = static_cast<NDArray::Container*>(ptr_obj); if (ptr->manager_ctx != nullptr) { - static_cast<NDArray::Container*>(ptr->manager_ctx)->DecRef(); + details::ObjectUnsafe::DecRefObjectHandle( + static_cast<NDArray::Container*>(ptr->manager_ctx) + ); } else if (ptr->dl_tensor.data != nullptr) { tvm::runtime::DeviceAPI::Get(ptr->dl_tensor.device) ->FreeDataSpace(ptr->dl_tensor.device, ptr->dl_tensor.data); @@ -116,7 +118,7 @@ struct NDArray::Internal { // that are not allocated inside of TVM. // This enables us to create NDArray from memory allocated by other // frameworks that are DLPack compatible - static void DLPackDeleter(Object* ptr_obj) { + static void DLPackDeleter(void* ptr_obj) { auto* ptr = static_cast<NDArray::Container*>(ptr_obj); DLManagedTensor* tensor = static_cast<DLManagedTensor*>(ptr->manager_ctx); if (tensor->deleter != nullptr) { @@ -127,7 +129,7 @@ struct NDArray::Internal { // Deleter for NDArray based on external DLTensor // The memory is allocated from outside and it is assumed that // responsibility for its freeing is also outside - static void SelfDeleter(Object* ptr_obj) { + static void SelfDeleter(void* ptr_obj) { auto* ptr = static_cast<NDArray::Container*>(ptr_obj); delete ptr; } @@ -171,13 +173,13 @@ struct NDArray::Internal { DLManagedTensor* ret = new DLManagedTensor(); ret->dl_tensor = from->dl_tensor; ret->manager_ctx = from; - from->IncRef(); + tvm::ffi::details::ObjectUnsafe::IncRefObjectHandle(from); ret->deleter = TVMNDArrayDLPackDeleter; return ret; } // Delete dlpack object. static void NDArrayDLPackDeleter(DLManagedTensor* tensor) { - static_cast<NDArray::Container*>(tensor->manager_ctx)->DecRef(); + details::ObjectUnsafe::DecRefObjectHandle(static_cast<NDArray::Container*>(tensor->manager_ctx)); delete tensor; } }; @@ -224,7 +226,8 @@ NDArray NDArray::CreateView(ShapeTuple shape, DLDataType dtype, uint64_t relativ << ", dtype= " << curr_dl_tensor.dtype << ")."; // increase ref count - get_mutable()->IncRef(); + // get_mutable()->IncRef(); + tvm::ffi::details::ObjectUnsafe::IncRefObjectHandle(get_mutable()); ret.get_mutable()->manager_ctx = get_mutable(); ret.get_mutable()->dl_tensor.data = get_mutable()->dl_tensor.data; ret.get_mutable()->dl_tensor.byte_offset = @@ -352,7 +355,7 @@ bool NDArray::IsAligned(const DLTensor& tensor) { 0); } -TVM_REGISTER_OBJECT_TYPE(NDArray::Container); +// TVM_REGISTER_OBJECT_TYPE(NDArray::Container); } // namespace runtime } // namespace tvm @@ -365,7 +368,9 @@ void TVMNDArrayDLPackDeleter(DLManagedTensor* tensor) { int TVMArrayGetTypeIndex(TVMArrayHandle handle, unsigned* out_tindex) { API_BEGIN(); - *out_tindex = TVMArrayHandleToObjectHandle(handle)->type_index(); + *out_tindex = tvm::ffi::details::ObjectUnsafe::GetHeader( + TVMArrayHandleToObjectHandle(handle) + )->type_index; API_END(); }
