This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch refactor-s0
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 66f993be83bb55232dafa23a286511e6c2425ccd
Author: tqchen <[email protected]>
AuthorDate: Sun Mar 9 18:14:26 2025 -0400

    pass expr first stab
---
 ffi/include/tvm/ffi/object.h     | 28 +++++++++-------------
 ffi/include/tvm/ffi/reflection.h |  8 +++++--
 include/tvm/runtime/object.h     | 51 +++++++++++++++++++++++++++++++++++++++-
 src/node/reflection.cc           |  5 ++--
 src/runtime/debug_compile.cc     |  2 ++
 5 files changed, 72 insertions(+), 22 deletions(-)

diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h
index d63b4efbd7..b2189fb7f5 100644
--- a/ffi/include/tvm/ffi/object.h
+++ b/ffi/include/tvm/ffi/object.h
@@ -426,24 +426,15 @@ struct ObjectPtrEqual {
   }
 };
 
-/*!
- * \brief Helper macro to declare list of static checks about object meta-data.
- * \param TypeName The name of the current type.
- * \param ParentType The name of the ParentType
- */
-#define TVM_FFI_OBJECT_STATIC_DEFS(TypeName, ParentType)                       
           \
-  static constexpr int32_t _type_depth = ParentType::_type_depth + 1;          
           \
-  static_assert(!ParentType::_type_final, "ParentType marked as final");       
           \
-  static_assert(TypeName::_type_child_slots == 0 || 
ParentType::_type_child_slots == 0 || \
-                    TypeName::_type_child_slots < 
ParentType::_type_child_slots,          \
-                "Need to set _type_child_slots when parent specifies it.");    
           \
-  static_assert(TypeName::_type_child_slots == 0 || 
ParentType::_type_child_slots == 0 || \
-                    TypeName::_type_child_slots < 
ParentType::_type_child_slots,          \
-                "Need to set _type_child_slots when parent specifies it.")
 
 // If dynamic type is enabled, we still need to register the runtime type of parent
 #define TVM_FFI_REGISTER_STATIC_TYPE_INFO(TypeName, ParentType)                
\
+  static constexpr int32_t _type_depth = ParentType::_type_depth + 1;          
  \
   static int32_t _GetOrAllocRuntimeTypeIndex() {                               
\
+    static_assert(!ParentType::_type_final, "ParentType marked as final");     
             \
+    static_assert(TypeName::_type_child_slots == 0 || 
ParentType::_type_child_slots == 0 || \
+                      TypeName::_type_child_slots < 
ParentType::_type_child_slots,          \
+                  "Need to set _type_child_slots when parent specifies it.");  
             \
     static int32_t tindex = TVMFFIGetOrAllocTypeIndex(                         
\
         TypeName::_type_key, TypeName::_type_index, TypeName::_type_depth,     
\
         TypeName::_type_child_slots, TypeName::_type_child_slots_can_overflow, 
\
@@ -458,9 +449,8 @@ struct ObjectPtrEqual {
  * \param ParentType The name of the ParentType
  */
 #define TVM_FFI_DECLARE_STATIC_OBJECT_INFO(TypeName, ParentType)      \
-  TVM_FFI_REGISTER_STATIC_TYPE_INFO(TypeName, ParentType);            \
   static int32_t RuntimeTypeIndex() { return TypeName::_type_index; } \
-  TVM_FFI_OBJECT_STATIC_DEFS(TypeName, ParentType)
+  TVM_FFI_REGISTER_STATIC_TYPE_INFO(TypeName, ParentType);
 
 
 /*
@@ -524,7 +514,6 @@ struct ObjectPtrEqual {
 
 
 namespace details {
-
 template <typename TargetType>
 TVM_FFI_INLINE bool IsObjectInstance(int32_t object_type_index) {
   static_assert(std::is_base_of_v<Object, TargetType>);
@@ -575,6 +564,11 @@ class ObjectUnsafe {
             
reinterpret_cast<int64_t>(&(static_cast<Object*>(nullptr)->header_)));
   }
 
+  template <typename T>
+  static TVM_FFI_INLINE ObjectPtr<T> ObjectPtrFromObjectRef(const ObjectRef& 
ref) {
+    return tvm::ffi::ObjectPtr<T>(ref.data_.data_);
+  }
+
   template <typename T>
   static TVM_FFI_INLINE ObjectPtr<T> ObjectPtrFromOwned(Object* raw_ptr) {
     tvm::ffi::ObjectPtr<T> ptr;
diff --git a/ffi/include/tvm/ffi/reflection.h b/ffi/include/tvm/ffi/reflection.h
index f3cc580853..d154522b9f 100644
--- a/ffi/include/tvm/ffi/reflection.h
+++ b/ffi/include/tvm/ffi/reflection.h
@@ -154,8 +154,12 @@ class ReflectionFieldGetter {
  * \param ParentType The name of the ParentType
  */
 #define TVM_FFI_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)                 
               \
-  TVM_FFI_OBJECT_STATIC_DEFS(TypeName, ParentType);                            
               \
-  static int32_t _GetOrAllocRuntimeTypeIndex() {                               
               \
+  static constexpr int32_t _type_depth = ParentType::_type_depth + 1;          
               \
+  static int32_t _GetOrAllocRuntimeTypeIndex() {                               
            \
+    static_assert(!ParentType::_type_final, "ParentType marked as final");     
             \
+    static_assert(TypeName::_type_child_slots == 0 || 
ParentType::_type_child_slots == 0 || \
+                      TypeName::_type_child_slots < 
ParentType::_type_child_slots,          \
+                  "Need to set _type_child_slots when parent specifies it.");  
             \
     static int32_t tindex = TVMFFIGetOrAllocTypeIndex(                         
               \
         TypeName::_type_key, -1, TypeName::_type_depth, 
TypeName::_type_child_slots,          \
         TypeName::_type_child_slots_can_overflow, 
ParentType::_GetOrAllocRuntimeTypeIndex()); \
diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h
index ec39461b0c..d956eebf1f 100644
--- a/include/tvm/runtime/object.h
+++ b/include/tvm/runtime/object.h
@@ -128,11 +128,13 @@ class ObjectRef : public tvm::ffi::ObjectRef {
    * \brief Internal helper function get data_ as ObjectPtr of ObjectType.
    * \note only used for internal dev purpose.
    * \tparam ObjectType The corresponding object type.
+   * \param ref The object reference
    * \return the corresponding type.
    */
   template <typename ObjectType>
   static ObjectPtr<ObjectType> GetDataPtr(const ObjectRef& ref) {
-    return ObjectPtr<ObjectType>(ref.data_.data_);
+    // return ObjectPtr<ObjectType>(ref.data_.data_);
+    return ffi::details::ObjectUnsafe::ObjectPtrFromObjectRef<ObjectType>(ref);
   }
 
   // friend classes.
@@ -152,6 +154,53 @@ class ObjectRef : public tvm::ffi::ObjectRef {
   TypeName& operator=(const TypeName& other) = default;   \
   TypeName& operator=(TypeName&& other) = default;
 
+/*!
+ * \brief Define CopyOnWrite function in an ObjectRef.
+ * \param ObjectName The Type of the Node.
+ *
+ *  CopyOnWrite will generate a unique copy of the internal node.
+ *  The node will be copied if it is referenced by multiple places.
+ *  The function returns the raw pointer to the node to allow modification
+ *  of the content.
+ *
+ * \code
+ *
+ *  MyCOWObjectRef ref, ref2;
+ *  ref2 = ref;
+ *  ref.CopyOnWrite()->value = new_value;
+ *  assert(ref2->value == old_value);
+ *  assert(ref->value == new_value);
+ *
+ * \endcode
+ */
+#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName)              \
+static_assert(ObjectName::_type_final,                           \
+              "TVM's CopyOnWrite may only be used for "          \
+              "Object types that are declared as final, "        \
+              "using the TVM_DECLARE_FINAL_OBJECT_INFO macro."); \
+ObjectName* CopyOnWrite() {                                      \
+  ICHECK(data_ != nullptr);                                      \
+  if (!data_.unique()) {                                         \
+    auto n = make_object<ObjectName>(*(operator->()));           \
+    ObjectPtr<Object>(std::move(n)).swap(data_);                 \
+  }                                                              \
+  return static_cast<ObjectName*>(data_.get());                  \
+}
+
+/*
+ * \brief Define object reference methods.
+ * \param TypeName The object type name
+ * \param ParentType The parent type of the objectref
+ * \param ObjectName The type name of the object.
+ */
+#define TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR(TypeName, 
ParentType,        \
+  ObjectName)                  \
+explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : 
ParentType(n) {}    \
+TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName);                             
              \
+const ObjectName* operator->() const { return static_cast<const 
ObjectName*>(data_.get()); } \
+const ObjectName* get() const { return operator->(); }                         
              \
+using ContainerType = ObjectName;
+
 
 #define TVM_DECLARE_BASE_OBJECT_INFO TVM_FFI_DECLARE_BASE_OBJECT_INFO
 #define TVM_DECLARE_FINAL_OBJECT_INFO TVM_FFI_DECLARE_FINAL_OBJECT_INFO
diff --git a/src/node/reflection.cc b/src/node/reflection.cc
index aa572e9965..02e7b2423d 100644
--- a/src/node/reflection.cc
+++ b/src/node/reflection.cc
@@ -153,8 +153,9 @@ ReflectionVTable* ReflectionVTable::Global() {
 
 ObjectPtr<Object> ReflectionVTable::CreateInitObject(const std::string& 
type_key,
                                                      const std::string& 
repr_bytes) const {
-  uint32_t tindex = Object::TypeKey2Index(type_key);
-  if (tindex >= fcreate_.size() || fcreate_[tindex] == nullptr) {
+  int32_t tindex;
+  TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKey2Index(type_key.c_str(), &tindex));
+  if (static_cast<size_t>(tindex) >= fcreate_.size() || fcreate_[tindex] == 
nullptr) {
     LOG(FATAL) << "TypeError: " << type_key << " is not registered via 
TVM_REGISTER_NODE_TYPE";
   }
   return fcreate_[tindex](repr_bytes);
diff --git a/src/runtime/debug_compile.cc b/src/runtime/debug_compile.cc
index 4233035abd..17f86d5f7f 100644
--- a/src/runtime/debug_compile.cc
+++ b/src/runtime/debug_compile.cc
@@ -30,6 +30,8 @@
 #include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/runtime/disco/disco_worker.h>
+#include <tvm/ir/expr.h>
+#include <tvm/tir/expr.h>
 
 namespace tvm {
 namespace debug {

Reply via email to