This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch refactor-s0
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 67986d3369e58ec3b4ccd0c595a02b429d251fd0
Author: tqchen <[email protected]>
AuthorDate: Fri Sep 6 18:57:38 2024 -0400

    [FFI] Improve Any coverage to include ObjectPtr
---
 ffi/include/tvm/ffi/any.h         |  10 +--
 ffi/include/tvm/ffi/memory.h      |   4 +-
 ffi/include/tvm/ffi/object.h      |  29 ++++++---
 ffi/include/tvm/ffi/type_traits.h | 128 ++++++++++++++++++++++++++++++++++++--
 ffi/tests/example/test_any.cc     |  22 +++++++
 5 files changed, 173 insertions(+), 20 deletions(-)

diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h
index 6f2935f8a6..2a2d223c57 100644
--- a/ffi/include/tvm/ffi/any.h
+++ b/ffi/include/tvm/ffi/any.h
@@ -96,6 +96,7 @@ class AnyView {
     }
     TVM_FFI_THROW(TypeError) << "Cannot convert from type `" << 
TypeIndex2TypeKey(data_.type_index)
                              << "` to `" << TypeTraits<T>::TypeStr() << "`";
+    TVM_FFI_UNREACHABLE();
   }
   // The following functions are only used for testing purposes
   /*!
@@ -126,7 +127,7 @@ TVM_FFI_INLINE void InplaceConvertAnyViewToAny(TVMFFIAny* 
data,
                                                [[maybe_unused]] size_t 
extra_any_bytes = 0) {
   // TODO: string conversion.
   if (data->type_index >= TVMFFITypeIndex::kTVMFFIStaticObjectBegin) {
-    details::ObjectInternal::IncRefObjectInAny(data);
+    details::ObjectUnsafe::IncRefObjectInAny(data);
   }
 }
 }  // namespace details
@@ -145,7 +146,7 @@ class Any {
    */
   void reset() {
     if (data_.type_index >= TVMFFITypeIndex::kTVMFFIStaticObjectBegin) {
-      details::ObjectInternal::DecRefObjectInAny(&data_);
+      details::ObjectUnsafe::DecRefObjectInAny(&data_);
     }
     data_.type_index = TVMFFITypeIndex::kTVMFFINone;
   }
@@ -166,7 +167,7 @@ class Any {
   // constructors from Any
   Any(const Any& other) : data_(other.data_) {
     if (data_.type_index >= TypeIndex::kTVMFFIStaticObjectBegin) {
-      details::ObjectInternal::IncRefObjectInAny(&data_);
+      details::ObjectUnsafe::IncRefObjectInAny(&data_);
     }
   }
   Any(Any&& other) : data_(other.data_) { other.data_.type_index = 
TypeIndex::kTVMFFINone; }
@@ -192,7 +193,7 @@ class Any {
   // constructor from general types
   template <typename T, typename = std::enable_if_t<TypeTraits<T>::enabled>>
   Any(T other) {  // NOLINT(*)
-    TypeTraits<T>::MoveToManagedAny(std::move(other), &data_);
+    TypeTraits<T>::MoveToAny(std::move(other), &data_);
   }
   template <typename T, typename = std::enable_if_t<TypeTraits<T>::enabled>>
   Any& operator=(T other) {  // NOLINT(*)
@@ -213,6 +214,7 @@ class Any {
     }
     TVM_FFI_THROW(TypeError) << "Cannot convert from type `" << 
TypeIndex2TypeKey(data_.type_index)
                              << "` to `" << TypeTraits<T>::TypeStr() << "`";
+    TVM_FFI_UNREACHABLE();
   }
 
   // FFI related operations
diff --git a/ffi/include/tvm/ffi/memory.h b/ffi/include/tvm/ffi/memory.h
index ca3021ff78..03cd542d91 100644
--- a/ffi/include/tvm/ffi/memory.h
+++ b/ffi/include/tvm/ffi/memory.h
@@ -74,11 +74,11 @@ class ObjAllocatorBase {
     using Handler = typename Derived::template Handler<T>;
     static_assert(std::is_base_of<Object, T>::value, "make can only be used to 
create Object");
     T* ptr = Handler::New(static_cast<Derived*>(this), 
std::forward<Args>(args)...);
-    TVMFFIObject* ffi_ptr = details::ObjectInternal::GetHeader(ptr);
+    TVMFFIObject* ffi_ptr = details::ObjectUnsafe::GetHeader(ptr);
     // NOTE: ref_counter is initialized in object constructor
     ffi_ptr->type_index = T::RuntimeTypeIndex();
     ffi_ptr->deleter = Handler::Deleter();
-    return details::ObjectInternal::ObjectPtrFromUnowned<T>(ptr);
+    return details::ObjectUnsafe::ObjectPtrFromUnowned<T>(ptr);
   }
 
   /*!
diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h
index a1547b1900..b7a0d1a7c5 100644
--- a/ffi/include/tvm/ffi/object.h
+++ b/ffi/include/tvm/ffi/object.h
@@ -36,8 +36,9 @@ using TypeIndex = TVMFFITypeIndex;
 using TypeInfo = TVMFFITypeInfo;
 
 namespace details {
-// forward declare object internal
-struct ObjectInternal;
+// Helper to perform
+// unsafe operations related to object
+struct ObjectUnsafe;
 
 // Code section that depends on dynamic components
 #if TVM_FFI_ALLOW_DYN_TYPE
@@ -191,7 +192,7 @@ class Object {
   // friend classes
   template <typename>
   friend class ObjectPtr;
-  friend class tvm::ffi::details::ObjectInternal;
+  friend class tvm::ffi::details::ObjectUnsafe;
 };
 
 /*!
@@ -328,7 +329,7 @@ class ObjectPtr {
   friend struct ObjectPtrHash;
   template <typename>
   friend class ObjectPtr;
-  friend class tvm::ffi::details::ObjectInternal;
+  friend class tvm::ffi::details::ObjectUnsafe;
   template <typename RelayRefType, typename ObjType>
   friend RelayRefType GetRef(const ObjType* ptr);
   template <typename BaseType, typename ObjType>
@@ -417,7 +418,7 @@ class ObjectRef {
   Object* get_mutable() const { return data_.get(); }
   // friend classes.
   friend struct ObjectPtrHash;
-  friend class tvm::ffi::details::ObjectInternal;
+  friend class tvm::ffi::details::ObjectUnsafe;
   template <typename SubRef, typename BaseRef>
   friend SubRef Downcast(BaseRef ref);
 };
@@ -574,10 +575,10 @@ TVM_FFI_INLINE bool IsObjectInstance(int32_t 
object_type_index) {
  * \note These functions are only supposed to be used by internal
  * implementations and not external users of the tvm::ffi
  */
-struct ObjectInternal {
+struct ObjectUnsafe {
   // NOTE: get ffi header from an object
-  static TVM_FFI_INLINE TVMFFIObject* GetHeader(Object* src) {
-    return &(src->header_);
+  static TVM_FFI_INLINE TVMFFIObject* GetHeader(const Object* src) {
+    return const_cast<TVMFFIObject*>(&(src->header_));
   }
 
   // Create ObjectPtr from unknowned ptr
@@ -610,6 +611,18 @@ struct ObjectInternal {
     return GetHeader(obj_ptr);
   }
 
+  // Header of the object held by src (nullptr if src is null); refcount is unchanged.
+  template <typename T>
+  static TVM_FFI_INLINE TVMFFIObject* GetTVMFFIObjectPtrFromObjectPtr(const ObjectPtr<T>& src) {
+    return src.data_ != nullptr ? GetHeader(src.data_) : nullptr;
+  }
+  // Take over the reference held by *src and reset *src to null (nullptr if it was null).
+  template <typename T>
+  static TVM_FFI_INLINE TVMFFIObject* MoveTVMFFIObjectPtrFromObjectPtr(ObjectPtr<T>* src) {
+    Object* obj_ptr = src->data_;
+    src->data_ = nullptr;
+    return obj_ptr != nullptr ? GetHeader(obj_ptr) : nullptr;
+  }
   // Create objectptr by moving from an existing address of object and setting 
its
   // address to nullptr
   template <typename T>
diff --git a/ffi/include/tvm/ffi/type_traits.h 
b/ffi/include/tvm/ffi/type_traits.h
index 18c7ab2a35..e1d453fdba 100644
--- a/ffi/include/tvm/ffi/type_traits.h
+++ b/ffi/include/tvm/ffi/type_traits.h
@@ -59,6 +59,29 @@ struct TypeTraits {
 template <typename T>
 using TypeTraitsNoCR = 
TypeTraits<std::remove_const_t<std::remove_reference_t<T>>>;
 
+// None: std::nullptr_t is encoded purely by the kTVMFFINone type index.
+template <>
+struct TypeTraits<std::nullptr_t> {
+  static constexpr bool enabled = true;
+  // Writes only the type index; the value fields of result are not touched.
+  static TVM_FFI_INLINE void ConvertToAnyView(const std::nullptr_t&, TVMFFIAny* result) {
+    result->type_index = TypeIndex::kTVMFFINone;
+  }
+  // Nothing is owned, so the owning conversion is the same as the view conversion.
+  static TVM_FFI_INLINE void MoveToAny(std::nullptr_t, TVMFFIAny* result) {
+    ConvertToAnyView(nullptr, result);
+  }
+  // Yields nullptr when src holds None and std::nullopt for every other type index.
+  static TVM_FFI_INLINE std::optional<std::nullptr_t> TryConvertFromAnyView(const TVMFFIAny* src) {
+    if (src->type_index != TypeIndex::kTVMFFINone) {
+      return std::nullopt;
+    }
+    return nullptr;
+  }
+
+  static TVM_FFI_INLINE std::string TypeStr() { return "None"; }
+};
+
 // Integer POD values
 template <typename Int>
 struct TypeTraits<Int, std::enable_if_t<std::is_integral_v<Int>>> {
@@ -69,7 +92,7 @@ struct TypeTraits<Int, 
std::enable_if_t<std::is_integral_v<Int>>> {
     result->v_int64 = static_cast<int64_t>(src);
   }
 
-  static TVM_FFI_INLINE void MoveToManagedAny(Int src, TVMFFIAny* result) {
+  static TVM_FFI_INLINE void MoveToAny(Int src, TVMFFIAny* result) {
     ConvertToAnyView(src, result);
   }
 
@@ -93,7 +116,7 @@ struct TypeTraits<Float, 
std::enable_if_t<std::is_floating_point_v<Float>>> {
     result->v_float64 = static_cast<double>(src);
   }
 
-  static TVM_FFI_INLINE void MoveToManagedAny(Float src, TVMFFIAny* result) {
+  static TVM_FFI_INLINE void MoveToAny(Float src, TVMFFIAny* result) {
     ConvertToAnyView(src, result);
   }
 
@@ -109,6 +132,30 @@ struct TypeTraits<Float, 
std::enable_if_t<std::is_floating_point_v<Float>>> {
   static TVM_FFI_INLINE std::string TypeStr() { return "float"; }
 };
 
+// Opaque pointer: void* is stored by address under kTVMFFIOpaquePtr; no refcounting.
+template <>
+struct TypeTraits<void*> {
+  static constexpr bool enabled = true;
+  // Records the raw address and tags it as an opaque pointer.
+  static TVM_FFI_INLINE void ConvertToAnyView(void* src, TVMFFIAny* result) {
+    result->v_ptr = src;
+    result->type_index = TypeIndex::kTVMFFIOpaquePtr;
+  }
+  // The pointer is not owned, so the owning conversion reuses the view conversion.
+  static TVM_FFI_INLINE void MoveToAny(void* src, TVMFFIAny* result) {
+    ConvertToAnyView(src, result);
+  }
+  // Only an opaque-pointer value converts back; any other type yields std::nullopt.
+  static TVM_FFI_INLINE std::optional<void*> TryConvertFromAnyView(const TVMFFIAny* src) {
+    if (src->type_index != TypeIndex::kTVMFFIOpaquePtr) {
+      return std::nullopt;
+    }
+    return std::optional<void*>(src->v_ptr);
+  }
+
+  static TVM_FFI_INLINE std::string TypeStr() { return "void*"; }
+};
+
 // Traits for object
 template <typename TObjRef>
 struct TypeTraits<TObjRef, std::enable_if_t<std::is_base_of_v<ObjectRef, 
TObjRef>>> {
@@ -117,13 +164,13 @@ struct TypeTraits<TObjRef, 
std::enable_if_t<std::is_base_of_v<ObjectRef, TObjRef
   static constexpr bool enabled = true;
 
   static TVM_FFI_INLINE void ConvertToAnyView(const TObjRef& src, TVMFFIAny* 
result) {
-    TVMFFIObject* obj_ptr = 
details::ObjectInternal::GetTVMFFIObjectPtrFromObjectRef(src);
+    TVMFFIObject* obj_ptr = 
details::ObjectUnsafe::GetTVMFFIObjectPtrFromObjectRef(src);
     result->type_index = obj_ptr->type_index;
     result->v_obj = obj_ptr;
   }
 
-  static TVM_FFI_INLINE void MoveToManagedAny(TObjRef src, TVMFFIAny* result) {
-    TVMFFIObject* obj_ptr = 
details::ObjectInternal::MoveTVMFFIObjectPtrFromObjectRef(&src);
+  static TVM_FFI_INLINE void MoveToAny(TObjRef src, TVMFFIAny* result) {
+    TVMFFIObject* obj_ptr = 
details::ObjectUnsafe::MoveTVMFFIObjectPtrFromObjectRef(&src);
     result->type_index = obj_ptr->type_index;
     result->v_obj = obj_ptr;
   }
@@ -132,7 +179,7 @@ struct TypeTraits<TObjRef, 
std::enable_if_t<std::is_base_of_v<ObjectRef, TObjRef
     if (src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin) {
 #if TVM_FFI_ALLOW_DYN_TYPE
       if (details::IsObjectInstance<ContainerType>(src->type_index)) {
-        return 
TObjRef(details::ObjectInternal::ObjectPtrFromUnowned<Object>(src->v_obj));
+        return 
TObjRef(details::ObjectUnsafe::ObjectPtrFromUnowned<Object>(src->v_obj));
       }
 #else
       TVM_FFI_THROW(RuntimeError)
@@ -148,6 +195,97 @@ struct TypeTraits<TObjRef, std::enable_if_t<std::is_base_of_v<ObjectRef, TObjRef
   static TVM_FFI_INLINE std::string TypeStr() { return ContainerType::_type_key; }
 };
 
+// Traits for ObjectPtr<T>.
+// A null ObjectPtr is stored as None; otherwise the object header is stored in v_obj.
+template <typename T>
+struct TypeTraits<ObjectPtr<T>> {
+  static constexpr bool enabled = true;
+
+  // Non-owning conversion: the reference count of src is left unchanged.
+  static TVM_FFI_INLINE void ConvertToAnyView(const ObjectPtr<T>& src, TVMFFIAny* result) {
+    if (src.get() == nullptr) {
+      result->type_index = TypeIndex::kTVMFFINone;
+      return;
+    }
+    TVMFFIObject* obj_ptr = details::ObjectUnsafe::GetTVMFFIObjectPtrFromObjectPtr(src);
+    result->type_index = obj_ptr->type_index;
+    result->v_obj = obj_ptr;
+  }
+
+  // Owning conversion: the reference held by src is transferred into result.
+  static TVM_FFI_INLINE void MoveToAny(ObjectPtr<T> src, TVMFFIAny* result) {
+    if (src.get() == nullptr) {
+      result->type_index = TypeIndex::kTVMFFINone;
+      return;
+    }
+    TVMFFIObject* obj_ptr = details::ObjectUnsafe::MoveTVMFFIObjectPtrFromObjectPtr(&src);
+    result->type_index = obj_ptr->type_index;
+    result->v_obj = obj_ptr;
+  }
+
+  // Returns an ObjectPtr holding its own reference if src is an instance of T.
+  // Non-object values (including None) yield std::nullopt.
+  static TVM_FFI_INLINE std::optional<ObjectPtr<T>> TryConvertFromAnyView(const TVMFFIAny* src) {
+    if (src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin) {
+#if TVM_FFI_ALLOW_DYN_TYPE
+      if (details::IsObjectInstance<T>(src->type_index)) {
+        return details::ObjectUnsafe::ObjectPtrFromUnowned<T>(src->v_obj);
+      }
+#else
+      TVM_FFI_THROW(RuntimeError)
+          << "Converting to object requires `TVM_FFI_ALLOW_DYN_TYPE` to be on";
+#endif
+    }
+    return std::nullopt;
+  }
+
+  static TVM_FFI_INLINE std::string TypeStr() { return T::_type_key; }
+};
+
+// Traits for a raw const TObject* (non-owning pointer to an Object subclass).
+// A null pointer is stored as None.
+template <typename TObject>
+struct TypeTraits<const TObject*, std::enable_if_t<std::is_base_of_v<Object, TObject>>> {
+  static constexpr bool enabled = true;
+
+  // Non-owning conversion: the reference count is left unchanged.
+  static TVM_FFI_INLINE void ConvertToAnyView(const TObject* src, TVMFFIAny* result) {
+    if (src == nullptr) {
+      result->type_index = TypeIndex::kTVMFFINone;
+      return;
+    }
+    TVMFFIObject* obj_ptr = details::ObjectUnsafe::GetHeader(src);
+    result->type_index = obj_ptr->type_index;
+    result->v_obj = obj_ptr;
+  }
+
+  // Owning conversion: src is only borrowed, so a new reference is taken for result.
+  static TVM_FFI_INLINE void MoveToAny(const TObject* src, TVMFFIAny* result) {
+    ConvertToAnyView(src, result);
+    if (src != nullptr) {
+      details::ObjectUnsafe::IncRefObjectInAny(result);
+    }
+  }
+
+  // Returns a borrowed (non-owning) pointer if src holds an instance of TObject.
+  // NOTE(review): the reinterpret_cast is only valid if the header is at offset 0 of TObject.
+  static TVM_FFI_INLINE std::optional<const TObject*> TryConvertFromAnyView(const TVMFFIAny* src) {
+    if (src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin) {
+#if TVM_FFI_ALLOW_DYN_TYPE
+      if (details::IsObjectInstance<TObject>(src->type_index)) {
+        return reinterpret_cast<const TObject*>(src->v_obj);
+      }
+#else
+      TVM_FFI_THROW(RuntimeError)
+          << "Converting to object requires `TVM_FFI_ALLOW_DYN_TYPE` to be on";
+#endif
+    }
+    return std::nullopt;
+  }
+
+  static TVM_FFI_INLINE std::string TypeStr() { return TObject::_type_key; }
+};
+
 /*!
  * \brief Get type key from type index
  * \param type_index The input type index
diff --git a/ffi/tests/example/test_any.cc b/ffi/tests/example/test_any.cc
index 791ecfe68c..02d0ad6a23 100644
--- a/ffi/tests/example/test_any.cc
+++ b/ffi/tests/example/test_any.cc
@@ -114,6 +114,28 @@ TEST(Any, Object) {
   AnyView view2 = any1;
   EXPECT_EQ(v1.use_count(), 2);
 
+  // convert to weak raw object ptr
+  const TIntObj* v1_ptr = view2;
+  EXPECT_EQ(v1.use_count(), 2);
+  EXPECT_EQ(v1_ptr->value, 11);
+  Any any2 = v1_ptr;
+  EXPECT_EQ(v1.use_count(), 3);
+  EXPECT_TRUE(any2.TryAs<TInt>().has_value());
+
+  // convert to raw opaque ptr
+  void* raw_v1_ptr = const_cast<TIntObj*>(v1_ptr);
+  any2 = raw_v1_ptr;
+  EXPECT_TRUE(any2.TryAs<void*>().value() == v1_ptr);
+
+  // convert to ObjectPtr
+  ObjectPtr<TNumberObj> v1_obj_ptr = view2;
+  EXPECT_EQ(v1.use_count(), 3);
+  any2 = v1_obj_ptr;
+  EXPECT_EQ(v1.use_count(), 4);
+  EXPECT_TRUE(any2.TryAs<TInt>().has_value());
+  any2.reset();
+  v1_obj_ptr.reset();
+
   // convert that triggers error
   EXPECT_THROW(
       {

Reply via email to