This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s0 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 67986d3369e58ec3b4ccd0c595a02b429d251fd0 Author: tqchen <[email protected]> AuthorDate: Fri Sep 6 18:57:38 2024 -0400 [FFI] Improve Any coverage to include ObjectPtr --- ffi/include/tvm/ffi/any.h | 10 +-- ffi/include/tvm/ffi/memory.h | 4 +- ffi/include/tvm/ffi/object.h | 29 ++++++--- ffi/include/tvm/ffi/type_traits.h | 128 ++++++++++++++++++++++++++++++++++++-- ffi/tests/example/test_any.cc | 22 +++++++ 5 files changed, 173 insertions(+), 20 deletions(-) diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h index 6f2935f8a6..2a2d223c57 100644 --- a/ffi/include/tvm/ffi/any.h +++ b/ffi/include/tvm/ffi/any.h @@ -96,6 +96,7 @@ class AnyView { } TVM_FFI_THROW(TypeError) << "Cannot convert from type `" << TypeIndex2TypeKey(data_.type_index) << "` to `" << TypeTraits<T>::TypeStr() << "`"; + TVM_FFI_UNREACHABLE(); } // The following functions are only used for testing purposes /*! @@ -126,7 +127,7 @@ TVM_FFI_INLINE void InplaceConvertAnyViewToAny(TVMFFIAny* data, [[maybe_unused]] size_t extra_any_bytes = 0) { // TODO: string conversion. 
if (data->type_index >= TVMFFITypeIndex::kTVMFFIStaticObjectBegin) { - details::ObjectInternal::IncRefObjectInAny(data); + details::ObjectUnsafe::IncRefObjectInAny(data); } } } // namespace details @@ -145,7 +146,7 @@ class Any { */ void reset() { if (data_.type_index >= TVMFFITypeIndex::kTVMFFIStaticObjectBegin) { - details::ObjectInternal::DecRefObjectInAny(&data_); + details::ObjectUnsafe::DecRefObjectInAny(&data_); } data_.type_index = TVMFFITypeIndex::kTVMFFINone; } @@ -166,7 +167,7 @@ class Any { // constructors from Any Any(const Any& other) : data_(other.data_) { if (data_.type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { - details::ObjectInternal::IncRefObjectInAny(&data_); + details::ObjectUnsafe::IncRefObjectInAny(&data_); } } Any(Any&& other) : data_(other.data_) { other.data_.type_index = TypeIndex::kTVMFFINone; } @@ -192,7 +193,7 @@ class Any { // constructor from general types template <typename T, typename = std::enable_if_t<TypeTraits<T>::enabled>> Any(T other) { // NOLINT(*) - TypeTraits<T>::MoveToManagedAny(std::move(other), &data_); + TypeTraits<T>::MoveToAny(std::move(other), &data_); } template <typename T, typename = std::enable_if_t<TypeTraits<T>::enabled>> Any& operator=(T other) { // NOLINT(*) @@ -213,6 +214,7 @@ class Any { } TVM_FFI_THROW(TypeError) << "Cannot convert from type `" << TypeIndex2TypeKey(data_.type_index) << "` to `" << TypeTraits<T>::TypeStr() << "`"; + TVM_FFI_UNREACHABLE(); } // FFI related operations diff --git a/ffi/include/tvm/ffi/memory.h b/ffi/include/tvm/ffi/memory.h index ca3021ff78..03cd542d91 100644 --- a/ffi/include/tvm/ffi/memory.h +++ b/ffi/include/tvm/ffi/memory.h @@ -74,11 +74,11 @@ class ObjAllocatorBase { using Handler = typename Derived::template Handler<T>; static_assert(std::is_base_of<Object, T>::value, "make can only be used to create Object"); T* ptr = Handler::New(static_cast<Derived*>(this), std::forward<Args>(args)...); - TVMFFIObject* ffi_ptr = details::ObjectInternal::GetHeader(ptr); + 
TVMFFIObject* ffi_ptr = details::ObjectUnsafe::GetHeader(ptr); // NOTE: ref_counter is initialized in object constructor ffi_ptr->type_index = T::RuntimeTypeIndex(); ffi_ptr->deleter = Handler::Deleter(); - return details::ObjectInternal::ObjectPtrFromUnowned<T>(ptr); + return details::ObjectUnsafe::ObjectPtrFromUnowned<T>(ptr); } /*! diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h index a1547b1900..b7a0d1a7c5 100644 --- a/ffi/include/tvm/ffi/object.h +++ b/ffi/include/tvm/ffi/object.h @@ -36,8 +36,9 @@ using TypeIndex = TVMFFITypeIndex; using TypeInfo = TVMFFITypeInfo; namespace details { -// forward declare object internal -struct ObjectInternal; +// Helper to perform +// unsafe operations related to object +struct ObjectUnsafe; // Code section that depends on dynamic components #if TVM_FFI_ALLOW_DYN_TYPE @@ -191,7 +192,7 @@ class Object { // friend classes template <typename> friend class ObjectPtr; - friend class tvm::ffi::details::ObjectInternal; + friend class tvm::ffi::details::ObjectUnsafe; }; /*! @@ -328,7 +329,7 @@ class ObjectPtr { friend struct ObjectPtrHash; template <typename> friend class ObjectPtr; - friend class tvm::ffi::details::ObjectInternal; + friend class tvm::ffi::details::ObjectUnsafe; template <typename RelayRefType, typename ObjType> friend RelayRefType GetRef(const ObjType* ptr); template <typename BaseType, typename ObjType> @@ -417,7 +418,7 @@ class ObjectRef { Object* get_mutable() const { return data_.get(); } // friend classes. 
friend struct ObjectPtrHash; - friend class tvm::ffi::details::ObjectInternal; + friend class tvm::ffi::details::ObjectUnsafe; template <typename SubRef, typename BaseRef> friend SubRef Downcast(BaseRef ref); }; @@ -574,10 +575,10 @@ TVM_FFI_INLINE bool IsObjectInstance(int32_t object_type_index) { * \note These functions are only supposed to be used by internal * implementations and not external users of the tvm::ffi */ -struct ObjectInternal { +struct ObjectUnsafe { // NOTE: get ffi header from an object - static TVM_FFI_INLINE TVMFFIObject* GetHeader(Object* src) { - return &(src->header_); + static TVM_FFI_INLINE TVMFFIObject* GetHeader(const Object* src) { + return const_cast<TVMFFIObject*>(&(src->header_)); } // Create ObjectPtr from unknowned ptr @@ -610,6 +611,18 @@ struct ObjectInternal { return GetHeader(obj_ptr); } + template <typename T> + static TVM_FFI_INLINE TVMFFIObject* GetTVMFFIObjectPtrFromObjectPtr(const ObjectPtr<T>& src) { + return GetHeader(src.data_); + } + + template <typename T> + static TVM_FFI_INLINE TVMFFIObject* MoveTVMFFIObjectPtrFromObjectPtr(ObjectPtr<T>* src) { + Object* obj_ptr = src->data_; + src->data_ = nullptr; + return GetHeader(obj_ptr); + } + // Create objectptr by moving from an existing address of object and setting its // address to nullptr template <typename T> diff --git a/ffi/include/tvm/ffi/type_traits.h b/ffi/include/tvm/ffi/type_traits.h index 18c7ab2a35..e1d453fdba 100644 --- a/ffi/include/tvm/ffi/type_traits.h +++ b/ffi/include/tvm/ffi/type_traits.h @@ -59,6 +59,29 @@ struct TypeTraits { template <typename T> using TypeTraitsNoCR = TypeTraits<std::remove_const_t<std::remove_reference_t<T>>>; +// None +template <> +struct TypeTraits<std::nullptr_t> { + static constexpr bool enabled = true; + + static TVM_FFI_INLINE void ConvertToAnyView(const std::nullptr_t&, TVMFFIAny* result) { + result->type_index = TypeIndex::kTVMFFINone; + } + + static TVM_FFI_INLINE void MoveToAny(std::nullptr_t, TVMFFIAny* result) { + 
result->type_index = TypeIndex::kTVMFFINone; + } + + static TVM_FFI_INLINE std::optional<std::nullptr_t> TryConvertFromAnyView(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFINone) { + return nullptr; + } + return std::nullopt; + } + + static TVM_FFI_INLINE std::string TypeStr() { return "None"; } +}; + // Integer POD values template <typename Int> struct TypeTraits<Int, std::enable_if_t<std::is_integral_v<Int>>> { @@ -69,7 +92,7 @@ struct TypeTraits<Int, std::enable_if_t<std::is_integral_v<Int>>> { result->v_int64 = static_cast<int64_t>(src); } - static TVM_FFI_INLINE void MoveToManagedAny(Int src, TVMFFIAny* result) { + static TVM_FFI_INLINE void MoveToAny(Int src, TVMFFIAny* result) { ConvertToAnyView(src, result); } @@ -93,7 +116,7 @@ struct TypeTraits<Float, std::enable_if_t<std::is_floating_point_v<Float>>> { result->v_float64 = static_cast<double>(src); } - static TVM_FFI_INLINE void MoveToManagedAny(Float src, TVMFFIAny* result) { + static TVM_FFI_INLINE void MoveToAny(Float src, TVMFFIAny* result) { ConvertToAnyView(src, result); } @@ -109,6 +132,30 @@ struct TypeTraits<Float, std::enable_if_t<std::is_floating_point_v<Float>>> { static TVM_FFI_INLINE std::string TypeStr() { return "float"; } }; +// void* +template <> +struct TypeTraits<void*> { + static constexpr bool enabled = true; + + static TVM_FFI_INLINE void ConvertToAnyView(void* src, TVMFFIAny* result) { + result->type_index = TypeIndex::kTVMFFIOpaquePtr; + result->v_ptr = src; + } + + static TVM_FFI_INLINE void MoveToAny(void* src, TVMFFIAny* result) { + ConvertToAnyView(src, result); + } + + static TVM_FFI_INLINE std::optional<void*> TryConvertFromAnyView(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFIOpaquePtr) { + return std::make_optional<void*>(src->v_ptr); + } + return std::nullopt; + } + + static TVM_FFI_INLINE std::string TypeStr() { return "void*"; } +}; + // Traits for object template <typename TObjRef> struct TypeTraits<TObjRef, 
std::enable_if_t<std::is_base_of_v<ObjectRef, TObjRef>>> { @@ -117,13 +164,13 @@ struct TypeTraits<TObjRef, std::enable_if_t<std::is_base_of_v<ObjectRef, TObjRef static constexpr bool enabled = true; static TVM_FFI_INLINE void ConvertToAnyView(const TObjRef& src, TVMFFIAny* result) { - TVMFFIObject* obj_ptr = details::ObjectInternal::GetTVMFFIObjectPtrFromObjectRef(src); + TVMFFIObject* obj_ptr = details::ObjectUnsafe::GetTVMFFIObjectPtrFromObjectRef(src); result->type_index = obj_ptr->type_index; result->v_obj = obj_ptr; } - static TVM_FFI_INLINE void MoveToManagedAny(TObjRef src, TVMFFIAny* result) { - TVMFFIObject* obj_ptr = details::ObjectInternal::MoveTVMFFIObjectPtrFromObjectRef(&src); + static TVM_FFI_INLINE void MoveToAny(TObjRef src, TVMFFIAny* result) { + TVMFFIObject* obj_ptr = details::ObjectUnsafe::MoveTVMFFIObjectPtrFromObjectRef(&src); result->type_index = obj_ptr->type_index; result->v_obj = obj_ptr; } @@ -132,7 +179,7 @@ struct TypeTraits<TObjRef, std::enable_if_t<std::is_base_of_v<ObjectRef, TObjRef if (src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { #if TVM_FFI_ALLOW_DYN_TYPE if (details::IsObjectInstance<ContainerType>(src->type_index)) { - return TObjRef(details::ObjectInternal::ObjectPtrFromUnowned<Object>(src->v_obj)); + return TObjRef(details::ObjectUnsafe::ObjectPtrFromUnowned<Object>(src->v_obj)); } #else TVM_FFI_THROW(RuntimeError) @@ -148,6 +195,75 @@ struct TypeTraits<TObjRef, std::enable_if_t<std::is_base_of_v<ObjectRef, TObjRef static TVM_FFI_INLINE std::string TypeStr() { return ContainerType::_type_key; } }; +// Traits for object +template <typename T> +struct TypeTraits<ObjectPtr<T>> { + static constexpr bool enabled = true; + + static TVM_FFI_INLINE void ConvertToAnyView(const ObjectPtr<T>& src, TVMFFIAny* result) { + TVMFFIObject* obj_ptr = details::ObjectUnsafe::GetTVMFFIObjectPtrFromObjectPtr(src); + result->type_index = obj_ptr->type_index; + result->v_obj = obj_ptr; + } + + static TVM_FFI_INLINE void 
MoveToAny(ObjectPtr<T> src, TVMFFIAny* result) { + TVMFFIObject* obj_ptr = details::ObjectUnsafe::MoveTVMFFIObjectPtrFromObjectPtr(&src); + result->type_index = obj_ptr->type_index; + result->v_obj = obj_ptr; + } + + static TVM_FFI_INLINE std::optional<ObjectPtr<T>> TryConvertFromAnyView(const TVMFFIAny* src) { + if (src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { +#if TVM_FFI_ALLOW_DYN_TYPE + if (details::IsObjectInstance<T>(src->type_index)) { + return details::ObjectUnsafe::ObjectPtrFromUnowned<T>(src->v_obj); + } +#else + TVM_FFI_THROW(RuntimeError) + << "Converting to object requires `TVM_FFI_ALLOW_DYN_TYPE` to be on"; +#endif + } + return std::nullopt; + } + + static TVM_FFI_INLINE std::string TypeStr() { return T::_type_key; } +}; + +// Traits for raw const pointer to object: the pointer itself is non-owning (MoveToAny increments the refcount); src must be non-null. +template <typename TObject> +struct TypeTraits<const TObject*, std::enable_if_t<std::is_base_of_v<Object, TObject>>> { + static constexpr bool enabled = true; + + static TVM_FFI_INLINE void ConvertToAnyView(const TObject* src, TVMFFIAny* result) { + TVMFFIObject* obj_ptr = details::ObjectUnsafe::GetHeader(src); + result->type_index = obj_ptr->type_index; + result->v_obj = obj_ptr; + } + + static TVM_FFI_INLINE void MoveToAny(const TObject* src, TVMFFIAny* result) { + TVMFFIObject* obj_ptr = details::ObjectUnsafe::GetHeader(src); + result->type_index = obj_ptr->type_index; + result->v_obj = obj_ptr; + details::ObjectUnsafe::IncRefObjectInAny(result); + } + + static TVM_FFI_INLINE std::optional<const TObject*> TryConvertFromAnyView(const TVMFFIAny* src) { + if (src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { +#if TVM_FFI_ALLOW_DYN_TYPE + if (details::IsObjectInstance<TObject>(src->type_index)) { + return reinterpret_cast<const TObject*>(src->v_obj); + } +#else + TVM_FFI_THROW(RuntimeError) + << "Converting to object requires `TVM_FFI_ALLOW_DYN_TYPE` to be on"; +#endif + } + return std::nullopt; + } + + static TVM_FFI_INLINE std::string TypeStr() { return TObject::_type_key; } +}; + /*!
* \brief Get type key from type index * \param type_index The input type index diff --git a/ffi/tests/example/test_any.cc b/ffi/tests/example/test_any.cc index 791ecfe68c..02d0ad6a23 100644 --- a/ffi/tests/example/test_any.cc +++ b/ffi/tests/example/test_any.cc @@ -114,6 +114,28 @@ TEST(Any, Object) { AnyView view2 = any1; EXPECT_EQ(v1.use_count(), 2); + // convert to weak raw object ptr + const TIntObj* v1_ptr = view2; + EXPECT_EQ(v1.use_count(), 2); + EXPECT_EQ(v1_ptr->value, 11); + Any any2 = v1_ptr; + EXPECT_EQ(v1.use_count(), 3); + EXPECT_TRUE(any2.TryAs<TInt>().has_value()); + + // convert to raw opaque ptr + void* raw_v1_ptr = const_cast<TIntObj*>(v1_ptr); + any2 = raw_v1_ptr; + EXPECT_TRUE(any2.TryAs<void*>().value() == v1_ptr); + + // convert to ObjectPtr + ObjectPtr<TNumberObj> v1_obj_ptr = view2; + EXPECT_EQ(v1.use_count(), 3); + any2 = v1_obj_ptr; + EXPECT_EQ(v1.use_count(), 4); + EXPECT_TRUE(any2.TryAs<TInt>().has_value()); + any2.reset(); + v1_obj_ptr.reset(); + // convert that triggers error EXPECT_THROW( {
