This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch refactor-s0
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 89ccf3f9677d0e8443f379c26fe220a4fbd91d78
Author: tqchen <[email protected]>
AuthorDate: Wed Aug 14 14:53:14 2024 -0400

    [FFI] Object type hierarchy cast and check support
    
    Co-authored-by: Junru Shao <[email protected]>
---
 ffi/include/tvm/ffi/c_api.h      |  26 +++-
 ffi/include/tvm/ffi/error.h      |   2 +-
 ffi/include/tvm/ffi/object.h     | 201 +++++++++++++++++++++--------
 ffi/src/ffi/object.cc            | 265 +++++++++++++++++++++------------------
 ffi/tests/example/test_error.cc  |   1 -
 ffi/tests/example/test_object.cc |  93 +++++++++++++-
 6 files changed, 410 insertions(+), 178 deletions(-)

diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h
index 2616ae1648..e0597d0dda 100644
--- a/ffi/include/tvm/ffi/c_api.h
+++ b/ffi/include/tvm/ffi/c_api.h
@@ -85,8 +85,7 @@ typedef enum {
   // [Section] Dynamic Boxed: [kTVMFFIDynObjectBegin, +oo)
   // kTVMFFIDynObject is used to indicate that the type index
   // is dynamic and needs to be looked up at runtime
-  kTVMFFIDynObject = 128,
-  kTVMFFIDynObjectBegin = 129
+  kTVMFFIDynObjectBegin = 128
 #ifdef __cplusplus
 };
 #else
@@ -142,6 +141,29 @@ typedef struct {
   const char* bytes;
 } TVMFFIByteArray;
 
+/*!
+ * \brief Runtime type information for object type checking.
+ */
+typedef struct {
+  /*!
+   * \brief The runtime type index.
+   * It can be allocated during runtime if the type is dynamic.
+   */
+  int32_t type_index;
+  /*! \brief number of parent types in the type hierarchy. */
+  int32_t type_depth;
+  /*! \brief the unique type key to identify the type. */
+  const char* type_key;
+  /*! \brief Cached hash value of the type key, used for consistent structural hashing. */
+  uint64_t type_key_hash;
+  /*!
+   * \brief type_acenstors[depth] stores the type_index of the ancestor at that depth level
+   * \note To keep things simple, we do not allow multiple inheritance so the
+   *       hierarchy stays a tree
+   */
+  const int32_t* type_acenstors;
+} TVMFFITypeInfo;
+
 #ifdef __cplusplus
 }  // TVM_FFI_EXTERN_C
 #endif
diff --git a/ffi/include/tvm/ffi/error.h b/ffi/include/tvm/ffi/error.h
index 250206ac8b..f894166ef3 100644
--- a/ffi/include/tvm/ffi/error.h
+++ b/ffi/include/tvm/ffi/error.h
@@ -95,7 +95,7 @@ class Error :
     return get()->what_str.c_str();
   }
 
-  TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Error, ObjectRef, ErrorObj)
+  TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Error, ObjectRef, ErrorObj);
 };
 
 namespace details {
diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h
index 03ba7d7dbf..5a6b552e1c 100644
--- a/ffi/include/tvm/ffi/object.h
+++ b/ffi/include/tvm/ffi/object.h
@@ -33,10 +33,44 @@ namespace tvm {
 namespace ffi {
 
 using TypeIndex = TVMFFITypeIndex;
+using TypeInfo = TVMFFITypeInfo;
 
 namespace details {
 // forward declare object internal
 struct ObjectInternal;
+
+// Code section that depends on dynamic components
+#if TVM_FFI_ALLOW_DYN_TYPE
+/*!
+ * \brief Initialize the type info during runtime.
+ *
+ *  On the first call for a given type key, the type is registered in the
+ *  runtime type table; subsequent calls return the previously assigned index.
+ *
+ *  If static_type_index is negative, a new type index is allocated at runtime.
+ *  Otherwise, the static index is registered in the type table and returned.
+ *
+ * \param type_key The type key.
+ * \param static_type_index Static type index if any, can be -1, which means this is a dynamic index
+ * \param type_depth Depth of the type in the hierarchy, i.e. its number of ancestors.
+ * \param num_child_slots Number of slots reserved for its children.
+ * \param child_slots_can_overflow Whether to allow child to overflow the slots.
+ * \param parent_type_index Parent type index, pass in -1 if it is root.
+ *
+ * \return The allocated type index
+ */
+TVM_FFI_DLL int32_t ObjectGetOrAllocTypeIndex(const char* type_key, int32_t static_type_index,
+                                              int32_t type_depth, int32_t num_child_slots,
+                                              bool child_slots_can_overflow,
+                                              int32_t parent_type_index);
+
+/*!
+ * \brief Get Type information from type index.
+ * \param type_index The type index
+ * \return The type information; an internal check fails if type_index is not registered.
+ */
+TVM_FFI_DLL const TypeInfo* ObjectGetTypeInfo(int32_t type_index);
+#endif  // TVM_FFI_ALLOW_DYN_TYPE
 }  // namespace details
 
 /*!
@@ -89,6 +123,47 @@ class Object {
     header_.deleter = nullptr;
   }
 
+  /*!
+   * \brief Check if the object is an instance of TargetType or of a subclass of it.
+   * \tparam TargetType The target type to be checked.
+   * \return Whether the object is an instance of TargetType.
+   */
+  template <typename TargetType>
+  bool IsInstance() const {
+    // Everything is a subclass of object.
+    if constexpr (std::is_same<TargetType, Object>::value) return true;
+
+    if constexpr (TargetType::_type_final) {
+      // if the target type is a final type
+      // then we only need to check the equivalence.
+      return header_.type_index == TargetType::RuntimeTypeIndex();
+    }
+
+    // if target type is a non-leaf type
+    // Check if type index falls into the range of reserved slots.
+    int32_t target_type_index = TargetType::RuntimeTypeIndex();
+    int32_t begin = target_type_index;
+    // Reserved range is [begin, begin + _type_child_slots]: the type itself plus its children.
+    if constexpr (TargetType::_type_child_slots != 0) {
+      int32_t end = begin + static_cast<int32_t>(TargetType::_type_child_slots) + 1;
+      if (header_.type_index >= begin && header_.type_index < end) return true;
+    } else {
+      if (header_.type_index == begin) return true;
+    }
+    if (!TargetType::_type_child_slots_can_overflow) return false;
+    // Invariance: parent index is always smaller than the child.
+    if (header_.type_index < target_type_index) return false;
+    // Do a runtime lookup of type information
+#if TVM_FFI_ALLOW_DYN_TYPE
+    // the function checks that the info exists
+    const TypeInfo* type_info = details::ObjectGetTypeInfo(header_.type_index);
+    return (type_info->type_depth > TargetType::_type_depth &&
+            type_info->type_acenstors[TargetType::_type_depth] == target_type_index);
+#else
+    return false;
+#endif
+  }
+
   // Information about the object
   static constexpr const char* _type_key = "runtime.Object";
 
@@ -96,11 +171,10 @@ class Object {
   static constexpr bool _type_final = false;
   static constexpr uint32_t _type_child_slots = 0;
   static constexpr bool _type_child_slots_can_overflow = true;
-  // NOTE: the following field is not type index of Object
-  // but was intended to be used by sub-classes as default value.
-  // The type index of Object is TypeIndex::kRoot
+  // NOTE: static type index field of the class
   static constexpr int32_t _type_index = TypeIndex::kTVMFFIObject;
-
+  // the static type depth of the class
+  static constexpr int32_t _type_depth = 0;
   // The following functions are provided by macro
   // TVM_FFI_DECLARE_BASE_OBJECT_INFO and TVM_DECLARE_FINAL_OBJECT_INFO
   /*!
@@ -110,7 +184,6 @@ class Object {
   static int32_t RuntimeTypeIndex() { return TypeIndex::kTVMFFIObject; }
   /*!
    * \brief Internal function to get or allocate a runtime index.
-   * \note
    */
   static int32_t _GetOrAllocRuntimeTypeIndex() { return 
TypeIndex::kTVMFFIObject; }
 
@@ -342,7 +415,13 @@ class ObjectRef {
    * \tparam ObjectType the target type, must be a subtype of Object
    */
   template <typename ObjectType, typename = 
std::enable_if_t<std::is_base_of_v<Object, ObjectType>>>
-  inline const ObjectType* as() const;
+  const ObjectType* as() const {
+    if (data_ != nullptr && data_->IsInstance<ObjectType>()) {
+      return static_cast<ObjectType*>(data_.get());
+    } else {
+      return nullptr;
+    }
+  }
 
   /*! \brief type indicate the container type. */
   using ContainerType = Object;
@@ -377,14 +456,30 @@ inline ObjectPtr<BaseType> GetObjectPtr(ObjectType* ptr);
  * \param TypeName The name of the current type.
  * \param ParentType The name of the ParentType
  */
-#define TVM_FFI_OBJECT_STATIC_CHECKS(TypeName, ParentType)                     
           \
+#define TVM_FFI_OBJECT_STATIC_DEFS(TypeName, ParentType)                       
           \
+  static constexpr int32_t _type_depth = ParentType::_type_depth + 1;          
           \
   static_assert(!ParentType::_type_final, "ParentType marked as final");       
           \
   static_assert(TypeName::_type_child_slots == 0 || 
ParentType::_type_child_slots == 0 || \
                     TypeName::_type_child_slots < 
ParentType::_type_child_slots,          \
                 "Need to set _type_child_slots when parent specifies it.");    
           \
   static_assert(TypeName::_type_child_slots == 0 || 
ParentType::_type_child_slots == 0 || \
                     TypeName::_type_child_slots < 
ParentType::_type_child_slots,          \
-                "Need to set _type_child_slots when parent specifies it.");
+                "Need to set _type_child_slots when parent specifies it.")
+
+// If dynamic type is enabled, we still need to register the runtime type of 
parent
+#if TVM_FFI_ALLOW_DYN_TYPE
+#define TVM_FFI_REGISTER_STATIC_TYPE_INFO(TypeName, ParentType)                
\
+  static int32_t _GetOrAllocRuntimeTypeIndex() {                               
\
+    static int32_t tindex = ::tvm::ffi::details::ObjectGetOrAllocTypeIndex(    
\
+        TypeName::_type_key, TypeName::_type_index, TypeName::_type_depth,     
\
+        TypeName::_type_child_slots, TypeName::_type_child_slots_can_overflow, 
\
+        ParentType::_GetOrAllocRuntimeTypeIndex());                            
\
+    return tindex;                                                             
\
+  }                                                                            
\
+  static inline int32_t _register_type_index = _GetOrAllocRuntimeTypeIndex()
+#else
+#define TVM_FFI_REGISTER_STATIC_TYPE_INFO(TypeName, ParentType)
+#endif
 
 /*!
  * \brief Helper macro to declare a object that comes with static type index.
@@ -392,25 +487,38 @@ inline ObjectPtr<BaseType> GetObjectPtr(ObjectType* ptr);
  * \param ParentType The name of the ParentType
  */
 #define TVM_FFI_DECLARE_STATIC_OBJECT_INFO(TypeName, ParentType) \
-  TVM_FFI_OBJECT_STATIC_CHECKS(TypeName, ParentType)             \
-  static int32_t RuntimeTypeIndex() { return TypeName::_type_index; }
+  TVM_FFI_REGISTER_STATIC_TYPE_INFO(TypeName, ParentType);        \
+  static int32_t RuntimeTypeIndex() { return TypeName::_type_index; }\
+  TVM_FFI_OBJECT_STATIC_DEFS(TypeName, ParentType)
 
 /*!
  * \brief helper macro to declare a base object type that can be inherited.
  * \param TypeName The name of the current type.
  * \param ParentType The name of the ParentType
  */
-#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)                     
 \
-  static_assert(TVM_FFI_ALLOW_DYN_TYPE,                                        
 \
-                "Dynamic object depend on TVM_FFI_ALLOW_DYN_TYPE cd set to 
1"); \
-  TVM_FFI_OBJECT_STATIC_CHECKS(TypaName, ParentType)                           
 \
-  static inline int32_t _type_index = _GetOrAllocRuntimeTypeIndex();           
 \
-  static int32_t RuntimeTypeIndex() { return TypeName::_type_index; }          
 \
-  static int32_t _GetOrAllocRuntimeTypeIndex() {                               
 \
-    return ::tvm::ffi::details::ObjectGetOrAllocTypeIndex(                     
 \
-        TypeName::_type_key, -1, ParentType::_GetOrAllocRuntimeTypeIndex(),    
 \
-        TypeName::_type_child_slots, 
TypeName::_type_child_slots_can_overflow); \
-  }
+#define TVM_FFI_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)                                \
+  static_assert(TVM_FFI_ALLOW_DYN_TYPE,                                                       \
+                "Dynamic object depends on TVM_FFI_ALLOW_DYN_TYPE being set to 1");          \
+  TVM_FFI_OBJECT_STATIC_DEFS(TypeName, ParentType);                                           \
+  static int32_t _GetOrAllocRuntimeTypeIndex() {                                              \
+    static int32_t tindex = ::tvm::ffi::details::ObjectGetOrAllocTypeIndex(                   \
+        TypeName::_type_key, -1, TypeName::_type_depth, TypeName::_type_child_slots,          \
+        TypeName::_type_child_slots_can_overflow, ParentType::_GetOrAllocRuntimeTypeIndex()); \
+    return tindex;                                                                            \
+  }                                                                                           \
+  static int32_t RuntimeTypeIndex() { return _GetOrAllocRuntimeTypeIndex(); }                 \
+  static inline int32_t _type_index = _GetOrAllocRuntimeTypeIndex()
+
+/*!
+ * \brief helper macro to declare type information in a final class.
+ *  A final class reserves no child slots, so IsInstance can use an exact index comparison.
+ * \param TypeName The name of the current type.
+ * \param ParentType The name of the ParentType
+ */
+#define TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \
+  static constexpr uint32_t _type_child_slots = 0;              \
+  static constexpr bool _type_final = true;                     \
+  TVM_FFI_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
 
 /*
  * \brief Define object reference methods.
@@ -424,7 +532,7 @@ inline ObjectPtr<BaseType> GetObjectPtr(ObjectType* ptr);
   TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName)                        
                \
   const ObjectName* operator->() const { return static_cast<const 
ObjectName*>(data_.get()); } \
   const ObjectName* get() const { return operator->(); }                       
                \
-  using ContainerType = ObjectName;
+  using ContainerType = ObjectName
 
 /*
  * \brief Define object reference methods that is not nullable.
@@ -439,9 +547,25 @@ inline ObjectPtr<BaseType> GetObjectPtr(ObjectType* ptr);
   const ObjectName* operator->() const { return static_cast<const 
ObjectName*>(data_.get()); } \
   const ObjectName* get() const { return operator->(); }                       
                \
   static constexpr bool _type_is_nullable = false;                             
                \
-  using ContainerType = ObjectName;
+  using ContainerType = ObjectName
 
 namespace details {
+
+// auxiliary class to enable static type info table at depth
+template <int depth>
+struct TypeInfoAtDepth : public TypeInfo {
+  /*! \brief ancestor storage; at least one entry since zero-length arrays are not ISO C++ */
+  int32_t _type_acenstors[depth > 0 ? depth : 1] = {};
+
+  TypeInfoAtDepth(const char* type_key, int32_t static_type_index) {
+    this->type_key = type_key;
+    this->type_key_hash = 0;
+    this->type_index = static_type_index;
+    this->type_depth = depth;
+    this->type_acenstors = _type_acenstors;
+  }
+};
+
 /*!
  * \brief Namespace to internally manipulate object class.
  * \note These functions are only supposed to be used by internal
@@ -469,37 +593,6 @@ struct ObjectInternal {
   }
 };
 
-// Code section that depends on dynamic components
-#if TVM_FFI_ALLOW_DYN_TYPE
-/*!
- * \brief Get the type index using type key.
- *
- *  When the function is first time called for a type,
- *  it will register the type to the type table in the runtime.
- *  If the static_tindex is TypeIndex::kDynamic, the function will
- *  allocate a runtime type index.
- *  Otherwise, we will populate the type table and return the static index.
- *
- * \param type_key the type key.
- * \param static_tindex Static type index if any, can be -1, which means this 
is a dynamic index
- * \param parent_tindex The index of the parent.
- * \param type_child_slots Number of slots reserved for its children.
- * \param type_child_slots_can_overflow Whether to allow child to overflow the 
slots.
- *
- * \return The allocated type index
- */
-TVM_FFI_DLL int ObjectGetOrAllocTypeIndex(const char* type_key, int32_t 
static_tindex,
-                                          int32_t parent_tindex, int32_t 
type_child_slots,
-                                          bool type_child_slots_can_overflow);
-
-/*!
- * \brief Check whether child type is derived from parent type.
- * \param child_type_index The candidate child type index.
- * \param parent_type_index The candidate parent type index.
- * \return the Check result.
- */
-TVM_FFI_DLL bool ObjectDerivedFrom(int32_t child_type_index, int32_t 
parent_type_index);
-#endif  // TVM_FFI_ALLOW_DYN_TYPE
 }  // namespace details
 }  // namespace ffi
 }  // namespace tvm
diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc
index ac8e04ebb9..f4e2ff3e30 100644
--- a/ffi/src/ffi/object.cc
+++ b/ffi/src/ffi/object.cc
@@ -23,6 +23,7 @@
 #include <tvm/ffi/c_api.h>
 #include <tvm/ffi/error.h>
 
+#include <memory>
 #include <string>
 #include <unordered_map>
 #include <utility>
@@ -31,26 +32,6 @@
 namespace tvm {
 namespace ffi {
 
-/*! \brief Type information */
-struct TypeInfo {
-  /*! \brief The current index. */
-  int32_t index{0};
-  /*! \brief Index of the parent in the type hierarchy */
-  int32_t parent_index{0};
-  // NOTE: the indices in [index, index + num_reserved_slots) are
-  // reserved for the child-class of this type.
-  /*! \brief Total number of slots reserved for the type and its children. */
-  int32_t num_slots{0};
-  /*! \brief number of allocated child slots. */
-  int32_t allocated_slots{0};
-  /*! \brief Whether child can overflow. */
-  bool child_slots_can_overflow{true};
-  /*! \brief name of the type. */
-  std::string name;
-  /*! \brief hash of the name */
-  size_t name_hash{0};
-};
-
 /*!
  * \brief Type context that manages the type hierarchy information.
  *
@@ -61,150 +42,196 @@ struct TypeInfo {
  *
  * Then the followup code will leverage the information
  */
-class TypeContext {
+class TypeTable {
  public:
-  // NOTE: this is a relatively slow path for child checking
-  // Most types are already checked by the fast-path via reserved slot 
checking.
-  bool DerivedFrom(int32_t child_tindex, int32_t parent_tindex) {
-    // invariance: child's type index is always bigger than its parent.
-    if (child_tindex < parent_tindex) return false;
-    if (child_tindex == parent_tindex) return true;
-    TVM_FFI_ICHECK_LT(child_tindex, type_table_.size());
-    while (child_tindex > parent_tindex) {
-      child_tindex = type_table_[child_tindex].parent_index;
+  /*! \brief Type information */
+  struct Entry : public TypeInfo {
+    /*! \brief stored type key */
+    std::string type_key_data;
+    /*! \brief acenstor information */
+    std::vector<int32_t> type_acenstors_data;
+    // NOTE: the indices in [index, index + num_reserved_slots) are
+    // reserved for the child-class of this type.
+    /*! \brief Total number of slots reserved for the type and its children. */
+    int32_t num_slots;
+    /*! \brief number of allocated child slots. */
+    int32_t allocated_slots;
+    /*! \brief Whether child can overflow. */
+    bool child_slots_can_overflow{true};
+
+    Entry(int32_t type_index, int32_t type_depth, std::string type_key, 
int32_t num_slots,
+          bool child_slots_can_overflow, const Entry* parent) {
+      // setup fields in the class
+      this->type_key_data = std::move(type_key);
+      this->num_slots = num_slots;
+      this->allocated_slots = 1;
+      this->child_slots_can_overflow = child_slots_can_overflow;
+      // set up type acenstors information
+      if (type_depth != 0) {
+        TVM_FFI_ICHECK_NOTNULL(parent);
+        TVM_FFI_ICHECK_EQ(type_depth, parent->type_depth + 1);
+        type_acenstors_data.resize(type_depth);
+        // copy over parent's type information
+        for (int32_t i = 0; i < parent->type_depth; ++i) {
+          type_acenstors_data[i] = parent->type_acenstors[i];
+        }
+        // set last type information to be parent
+        type_acenstors_data[parent->type_depth] = parent->type_index;
+      }
+      // initialize type info: no change to type_key and type_acenstors fields
+      // after this line
+      this->type_index = type_index;
+      this->type_depth = type_depth;
+      this->type_key = this->type_key_data.c_str();
+      this->type_key_hash = std::hash<std::string>()(this->type_key_data);
+      this->type_acenstors = type_acenstors_data.data();
     }
-    return child_tindex == parent_tindex;
-  }
+  };
 
-  int32_t GetOrAllocRuntimeTypeIndex(const std::string& skey, int32_t 
static_tindex,
-                                     int32_t parent_tindex, int32_t 
num_child_slots,
-                                     bool child_slots_can_overflow) {
-    auto it = type_key2index_.find(skey);
+  int32_t GetOrAllocTypeIndex(std::string type_key, int32_t static_type_index, 
int32_t type_depth,
+                              int32_t num_child_slots, bool 
child_slots_can_overflow,
+                              int32_t parent_type_index) {
+    auto it = type_key2index_.find(type_key);
     if (it != type_key2index_.end()) {
-      return it->second;
+      return type_table_[it->second]->type_index;
     }
-    // try to allocate from parent's type table.
-    TVM_FFI_ICHECK_LT(parent_tindex, type_table_.size())
-        << " skey=" << skey << ", static_index=" << static_tindex;
 
-    TypeInfo& pinfo = type_table_[parent_tindex];
-    TVM_FFI_ICHECK_EQ(pinfo.index, parent_tindex);
+    // get parent's entry
+    Entry* parent = [&]() -> Entry* {
+      if (parent_type_index < 0) return nullptr;
+      // try to allocate from parent's type table.
+      TVM_FFI_ICHECK_LT(parent_type_index, type_table_.size())
+          << " type_key=" << type_key << ", static_index=" << 
static_type_index;
+      return type_table_[parent_type_index].get();
+    }();
+
+    // get allocated index
+    int32_t allocated_tindex = [&]() {
+      // Step 0: static allocation
+      if (static_type_index >= 0) {
+        TVM_FFI_ICHECK_LT(static_type_index, type_table_.size());
+        TVM_FFI_ICHECK(type_table_[static_type_index] == nullptr)
+            << "Conflicting static index " << static_type_index << " between "
+            << type_table_[static_type_index]->type_key << " and " << type_key;
+        return static_type_index;
+      }
+      TVM_FFI_ICHECK_NOTNULL(parent);
+      int num_slots = num_child_slots + 1;
+      if (parent->allocated_slots + num_slots <= parent->num_slots) {
+        // allocate the slot from parent's reserved pool
+        int32_t allocated_tindex = parent->type_index + 
parent->allocated_slots;
+        // update parent's state
+        parent->allocated_slots += num_slots;
+        return allocated_tindex;
+      }
+      // Step 2: allocate from overflow
+      TVM_FFI_ICHECK(parent->child_slots_can_overflow)
+          << "Reach maximum number of sub-classes for " << parent->type_key;
+      // allocate new entries.
+      int32_t allocated_tindex = type_counter_;
+      type_counter_ += num_slots;
+      TVM_FFI_ICHECK_LE(type_table_.size(), type_counter_);
+      type_table_.reserve(type_counter_);
+      // resize type table
+      while (static_cast<int32_t>(type_table_.size()) < type_counter_) {
+        type_table_.emplace_back(nullptr);
+      }
+      return allocated_tindex;
+    }();
 
     // if parent cannot overflow, then this class cannot.
-    if (!pinfo.child_slots_can_overflow) {
+    if (parent != nullptr && !(parent->child_slots_can_overflow)) {
       child_slots_can_overflow = false;
     }
-
     // total number of slots include the type itself.
-    int32_t num_slots = num_child_slots + 1;
-    int32_t allocated_tindex;
-
-    if (static_tindex > 0) {
-      // statically assigned type
-      allocated_tindex = static_tindex;
-      TVM_FFI_ICHECK_LT(static_tindex, type_table_.size());
-      TVM_FFI_ICHECK_EQ(type_table_[allocated_tindex].allocated_slots, 0U)
-          << "Conflicting static index " << static_tindex << " between "
-          << type_table_[allocated_tindex].name << " and " << skey;
-    } else if (pinfo.allocated_slots + num_slots <= pinfo.num_slots) {
-      // allocate the slot from parent's reserved pool
-      allocated_tindex = parent_tindex + pinfo.allocated_slots;
-      // update parent's state
-      pinfo.allocated_slots += num_slots;
-    } else {
-      TVM_FFI_ICHECK(pinfo.child_slots_can_overflow)
-          << "Reach maximum number of sub-classes for " << pinfo.name;
-      // allocate new entries.
-      allocated_tindex = type_counter_;
-      type_counter_ += num_slots;
-      TVM_FFI_ICHECK_LE(type_table_.size(), type_counter_);
-      type_table_.resize(type_counter_, TypeInfo());
+
+    if (parent != nullptr) {
+      TVM_FFI_ICHECK_GT(allocated_tindex, parent->type_index);
     }
-    TVM_FFI_ICHECK_GT(allocated_tindex, parent_tindex);
-    // initialize the slot.
-    type_table_[allocated_tindex].index = allocated_tindex;
-    type_table_[allocated_tindex].parent_index = parent_tindex;
-    type_table_[allocated_tindex].num_slots = num_slots;
-    type_table_[allocated_tindex].allocated_slots = 1;
-    type_table_[allocated_tindex].child_slots_can_overflow = 
child_slots_can_overflow;
-    type_table_[allocated_tindex].name = skey;
-    type_table_[allocated_tindex].name_hash = std::hash<std::string>()(skey);
+
+    type_table_[allocated_tindex] =
+        std::make_unique<Entry>(allocated_tindex, type_depth, type_key, 
num_child_slots + 1,
+                                child_slots_can_overflow, parent);
     // update the key2index mapping.
-    type_key2index_[skey] = allocated_tindex;
+    type_key2index_[type_key] = allocated_tindex;
     return allocated_tindex;
   }
 
-  const std::string& TypeIndex2Key(int32_t tindex) {
-    if (tindex != 0) {
-      // always return the right type key for root
-      // for non-root type nodes, allocated slots should not equal 0
-      TVM_FFI_ICHECK(tindex < static_cast<int32_t>(type_table_.size()) &&
-                     type_table_[tindex].allocated_slots != 0)
-          << "Unknown type index " << tindex;
-    }
-    return type_table_[tindex].name;
-  }
-
-  size_t TypeIndex2KeyHash(int32_t tindex) {
-    TVM_FFI_ICHECK(tindex < static_cast<int32_t>(type_table_.size()) &&
-                   type_table_[tindex].allocated_slots != 0)
-        << "Unknown type index " << tindex;
-    return type_table_[tindex].name_hash;
+  int32_t TypeKey2Index(const std::string& type_key) {
+    auto it = type_key2index_.find(type_key);
+    TVM_FFI_ICHECK(it != type_key2index_.end()) << "Cannot find type " << 
type_key;
+    return it->second;
   }
 
-  int32_t TypeKey2Index(const std::string& skey) {
-    auto it = type_key2index_.find(skey);
-    TVM_FFI_ICHECK(it != type_key2index_.end()) << "Cannot find type " << skey;
-    return it->second;
+  const TypeInfo* GetTypeInfo(int32_t type_index) {
+    const TypeInfo* info = nullptr;
+    if (type_index >= 0 && static_cast<size_t>(type_index) < 
type_table_.size()) {
+      info = type_table_[type_index].get();
+    }
+    TVM_FFI_ICHECK(info != nullptr) << "Cannot find type info for type_index=" 
<< type_index;
+    return info;
   }
 
   void Dump(int min_children_count) {
     std::vector<int> num_children(type_table_.size(), 0);
     // reverse accumulation so we can get total counts in a bottom-up manner.
     for (auto it = type_table_.rbegin(); it != type_table_.rend(); ++it) {
-      if (it->index != 0) {
-        num_children[it->parent_index] += num_children[it->index] + 1;
+      const Entry* ptr = it->get();
+      if (ptr != nullptr && ptr->type_depth != 0) {
+        int parent_index = ptr->type_acenstors[ptr->type_depth - 1];
+        num_children[parent_index] += num_children[ptr->type_index] + 1;
       }
     }
 
-    for (const auto& info : type_table_) {
-      if (info.index != 0 && num_children[info.index] >= min_children_count) {
-        std::cerr << '[' << info.index << "] " << info.name
-                  << "\tparent=" << type_table_[info.parent_index].name
-                  << "\tnum_child_slots=" << info.num_slots - 1
-                  << "\tnum_children=" << num_children[info.index] << 
std::endl;
+    for (const auto& ptr : type_table_) {
+      if (ptr != nullptr && num_children[ptr->type_index] >= 
min_children_count) {
+        std::cerr << '[' << ptr->type_index << "]\t" << ptr->type_key;
+        if (ptr->type_depth != 0) {
+          int32_t parent_index = ptr->type_acenstors[ptr->type_depth - 1];
+          std::cerr << "\tparent=" << type_table_[parent_index]->type_key;
+        } else {
+          std::cerr << "\tparent=root";
+        }
+        std::cerr << "\tnum_child_slots=" << ptr->num_slots - 1
+                  << "\tnum_children=" << num_children[ptr->type_index] << 
std::endl;
       }
     }
   }
 
-  static TypeContext* Global() {
-    static TypeContext inst;
+  static TypeTable* Global() {
+    static TypeTable inst;
     return &inst;
   }
 
  private:
-  TypeContext() {
-    type_table_.resize(TypeIndex::kTVMFFIDynObjectBegin, TypeInfo());
-    type_table_[0].name = "runtime.Object";
+  TypeTable() {
+    type_table_.reserve(TypeIndex::kTVMFFIDynObjectBegin);
+    for (int32_t i = 0; i < TypeIndex::kTVMFFIDynObjectBegin; ++i) {
+      type_table_.emplace_back(nullptr);
+    }
+    // initialize the entry for object
+    this->GetOrAllocTypeIndex(Object::_type_key, Object::_type_index, 
Object::_type_depth,
+                              Object::_type_child_slots, 
Object::_type_child_slots_can_overflow,
+                              -1);
   }
 
   int32_t type_counter_{TypeIndex::kTVMFFIDynObjectBegin};
-  std::vector<TypeInfo> type_table_;
+  std::vector<std::unique_ptr<Entry>> type_table_;
   std::unordered_map<std::string, int32_t> type_key2index_;
 };
 
 namespace details {
 
-int32_t ObjectGetOrAllocTypeIndex(const char* type_key, int32_t static_tindex,
-                                  int32_t parent_tindex, int32_t 
type_child_slots,
-                                  bool type_child_slots_can_overflow) {
-  return tvm::ffi::TypeContext::Global()->GetOrAllocRuntimeTypeIndex(
-      type_key, static_tindex, parent_tindex, type_child_slots, 
type_child_slots_can_overflow != 0);
+int32_t ObjectGetOrAllocTypeIndex(const char* type_key, int32_t 
static_type_index,
+                                  int32_t type_depth, int32_t num_child_slots,
+                                  bool child_slots_can_overflow, int32_t 
parent_index) {
+  return tvm::ffi::TypeTable::Global()->GetOrAllocTypeIndex(type_key, 
static_type_index, type_depth,
+                                                            num_child_slots,
+                                                            
child_slots_can_overflow, parent_index);
 }
 
-bool ObjectDerivedFrom(int32_t child_type_index, int32_t parent_type_index) {
-  return static_cast<int>(
-      tvm::ffi::TypeContext::Global()->DerivedFrom(child_type_index, 
parent_type_index));
+const TypeInfo* ObjectGetTypeInfo(int32_t type_index) {
+  return tvm::ffi::TypeTable::Global()->GetTypeInfo(type_index);
 }
 }  // namespace details
 }  // namespace ffi
diff --git a/ffi/tests/example/test_error.cc b/ffi/tests/example/test_error.cc
index 4f208b6999..015551c953 100644
--- a/ffi/tests/example/test_error.cc
+++ b/ffi/tests/example/test_error.cc
@@ -33,7 +33,6 @@ TEST(CheckError, Traceback) {
           TVM_FFI_ICHECK_GT(2, 3);
         } catch (const Error& error) {
           EXPECT_EQ(error->kind, "InternalError");
-          std::cout << error.what();
           std::string what = error.what();
           EXPECT_NE(what.find("line"), std::string::npos);
           EXPECT_NE(what.find("2 > 3"), std::string::npos);
diff --git a/ffi/tests/example/test_object.cc b/ffi/tests/example/test_object.cc
index 959246fda0..4395fccf52 100644
--- a/ffi/tests/example/test_object.cc
+++ b/ffi/tests/example/test_object.cc
@@ -6,11 +6,55 @@ namespace {
 
 using namespace tvm::ffi;
 
-class IntObj : public Object {
+class NumberObj : public Object {
+ public:
+  // declare as one slot, with float as overflow
+  static constexpr uint32_t _type_child_slots = 1;
+  static constexpr const char* _type_key = "test.Number";
+  TVM_FFI_DECLARE_BASE_OBJECT_INFO(NumberObj, Object);
+};
+
+class Number : public ObjectRef {
+ public:
+  TVM_FFI_DEFINE_NULLABLE_OBJECT_REF_METHODS(Number, ObjectRef, NumberObj);
+};
+
+class IntObj : public NumberObj {
  public:
   int64_t value;
 
   IntObj(int64_t value) : value(value) {}
+
+  static constexpr const char* _type_key = "test.Int";
+  TVM_FFI_DECLARE_FINAL_OBJECT_INFO(IntObj, NumberObj);
+};
+
+class Int : public Number {
+ public:
+  explicit Int(int64_t value) {
+    data_ = make_object<IntObj>(value);
+  }
+
+  TVM_FFI_DEFINE_NULLABLE_OBJECT_REF_METHODS(Int, Number, IntObj);
+};
+
+class FloatObj : public NumberObj {
+ public:
+  double value;
+
+  FloatObj(double value) : value(value) {}
+
+  static constexpr const char* _type_key = "test.Float";
+  TVM_FFI_DECLARE_FINAL_OBJECT_INFO(FloatObj, NumberObj);
+};
+
+class Float : public Number {
+ public:
+  explicit Float(double value) {
+    data_ = make_object<FloatObj>(value);
+  }
+
+  TVM_FFI_DEFINE_NULLABLE_OBJECT_REF_METHODS(Float, Number, FloatObj);
 };
 
 TEST(Object, RefCounter) {
@@ -32,4 +76,51 @@ TEST(Object, RefCounter) {
   EXPECT_EQ(c->value, 11);
 }
 
+TEST(Object, TypeInfo) {
+  const TypeInfo* info = tvm::ffi::details::ObjectGetTypeInfo(IntObj::RuntimeTypeIndex());
+  ASSERT_TRUE(info != nullptr);
+  EXPECT_EQ(info->type_index, IntObj::RuntimeTypeIndex());
+  EXPECT_EQ(info->type_depth, 2);
+  EXPECT_EQ(info->type_acenstors[0], Object::_type_index);
+  EXPECT_EQ(info->type_acenstors[1], NumberObj::_type_index);
+  EXPECT_GE(info->type_index, TypeIndex::kTVMFFIDynObjectBegin);
+}
+
+TEST(Object, InstanceCheck) {
+  ObjectPtr<Object> a = make_object<IntObj>(11);
+  ObjectPtr<Object> b = make_object<FloatObj>(11);
+
+  EXPECT_TRUE(a->IsInstance<Object>());
+  EXPECT_TRUE(a->IsInstance<NumberObj>());
+  EXPECT_TRUE(a->IsInstance<IntObj>());
+  EXPECT_TRUE(!a->IsInstance<FloatObj>());
+  // b is a FloatObj: an Object and a NumberObj, but not an IntObj.
+  EXPECT_TRUE(b->IsInstance<Object>());
+  EXPECT_TRUE(b->IsInstance<NumberObj>());
+  EXPECT_TRUE(!b->IsInstance<IntObj>());
+  EXPECT_TRUE(b->IsInstance<FloatObj>());
+}
+
+TEST(ObjectRef, as) {
+  ObjectRef a = Int(10);
+  ObjectRef b = Float(20);
+  // nullable object
+  ObjectRef c(nullptr);
+
+  EXPECT_TRUE(a.as<IntObj>() != nullptr);
+  EXPECT_TRUE(a.as<FloatObj>() == nullptr);
+  EXPECT_TRUE(a.as<NumberObj>() != nullptr);
+
+  EXPECT_TRUE(b.as<IntObj>() == nullptr);
+  EXPECT_TRUE(b.as<FloatObj>() != nullptr);
+  EXPECT_TRUE(b.as<NumberObj>() != nullptr);
+
+  EXPECT_TRUE(c.as<IntObj>() == nullptr);
+  EXPECT_TRUE(c.as<FloatObj>() == nullptr);
+  EXPECT_TRUE(c.as<NumberObj>() == nullptr);
+
+  EXPECT_EQ(a.as<IntObj>()->value, 10);
+  EXPECT_EQ(b.as<FloatObj>()->value, 20);
+}
+
 }  // namespace

Reply via email to