This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch refactor-s0
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 5d3d586e94bb5ab700d4f4f7cdd91151ffd496ba
Author: tqchen <[email protected]>
AuthorDate: Sun Mar 9 12:51:09 2025 -0400

    pass basic disco session compile
---
 ffi/include/tvm/ffi/object.h                | 49 +++++++++++++++++++++++++++++
 include/tvm/runtime/container/string.h      |  1 +
 include/tvm/runtime/disco/session.h         |  3 +-
 include/tvm/runtime/memory/memory_manager.h |  1 -
 include/tvm/runtime/object.h                |  6 ++++
 include/tvm/runtime/packed_func.h           |  5 +--
 src/runtime/debug_compile.cc                |  1 +
 src/runtime/disco/protocol.h                |  4 +--
 src/runtime/minrpc/rpc_reference.h          |  2 +-
 9 files changed, 65 insertions(+), 7 deletions(-)

diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h
index 8fbd6705fd..01dca88878 100644
--- a/ffi/include/tvm/ffi/object.h
+++ b/ffi/include/tvm/ffi/object.h
@@ -114,6 +114,9 @@ class Object {
     return details::IsObjectInstance<TargetType>(header_.type_index);
   }
 
+  /*! \return The runtime type index recorded in this object's header. */
+  int32_t type_index() const { return this->header_.type_index; }
+
   /*!
    * \return the type key of the object.
    * \note this operation is expensive, can be used for error reporting.
@@ -124,6 +127,16 @@ class Object {
     return type_info->type_key;
   }
 
+  /*!
+   * \brief Look up the type key that the runtime associates with a type index.
+   * \param tindex The runtime type index to look up.
+   * \return A std::string copy of the type key for \p tindex.
+   */
+  static std::string TypeIndex2Key(int32_t tindex) {
+    const TypeInfo* info = TVMFFIGetTypeInfo(tindex);
+    return std::string(info->type_key);
+  }
+
   // Information about the object
   static constexpr const char* _type_key = "object.Object";
 
@@ -464,6 +477,7 @@ struct ObjectPtrEqual {
   static int32_t RuntimeTypeIndex() { return TypeName::_type_index; } \
   TVM_FFI_OBJECT_STATIC_DEFS(TypeName, ParentType)
 
+
 /*
  * \brief Define object reference methods.
  * \param TypeName The object type name
@@ -493,6 +507,37 @@ struct ObjectPtrEqual {
   static constexpr bool _type_is_nullable = false;                             
                \
   using ContainerType = ObjectName
 
+/*
+ * \brief Define methods for a nullable object reference whose content is mutable.
+ * \param TypeName The object type name
+ * \param ParentType The parent type of the objectref
+ * \param ObjectName The type name of the object.
+ * \note We recommend making objects immutable when possible.
+ *       This macro is reserved for objects that store runtime state.
+ */
+#define TVM_DEFINE_MUTABLE_NULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, 
ObjectName)    \
+  TypeName() = default;                                                        
             \
+  TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName);                       
             \
+  explicit TypeName(::tvm::ffi::ObjectPtr<::tvm::ffi::Object> n) : 
ParentType(n) {}         \
+  ObjectName* operator->() const { return 
static_cast<ObjectName*>(data_.get()); }          \
+  using ContainerType = ObjectName;
+
+/*
+ * \brief Define methods for an object reference that is both non-nullable and 
mutable.
+ * \note No default constructor is defined and _type_is_nullable is set to false.
+ * \param TypeName The object type name
+ * \param ParentType The parent type of the objectref
+ * \param ObjectName The type name of the object.
+ */
+#define TVM_FFI_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, 
ParentType, ObjectName) \
+  explicit TypeName(::tvm::ffi::ObjectPtr<::tvm::ffi::Object> n) : 
ParentType(n) {}         \
+  TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName);                       
             \
+  ObjectName* operator->() const { return 
static_cast<ObjectName*>(data_.get()); }          \
+  ObjectName* get() const { return operator->(); }                             
             \
+  static constexpr bool _type_is_nullable = false;                             
             \
+  using ContainerType = ObjectName;
+
+
 namespace details {
 
 template <typename TargetType>
@@ -587,6 +632,10 @@ class ObjectUnsafe {
     reinterpret_cast<Object*>(src->v_obj)->IncRef();
   }
 
+  static TVM_FFI_INLINE Object* GetRawObjectPtrFromObjectRef(const ObjectRef& 
src) {
+    return src.data_.get();
+  }
+
   static TVM_FFI_INLINE TVMFFIObject* GetTVMFFIObjectPtrFromObjectRef(const 
ObjectRef& src) {
     return GetHeader(src.data_.data_);
   }
diff --git a/include/tvm/runtime/container/string.h 
b/include/tvm/runtime/container/string.h
index 661d95bba2..4ee48d9c95 100644
--- a/include/tvm/runtime/container/string.h
+++ b/include/tvm/runtime/container/string.h
@@ -59,6 +59,7 @@ class StringObj : public Object {
 
   static constexpr const uint32_t _type_index = TypeIndex::kRuntimeString;
   static constexpr const char* _type_key = "runtime.String";
+  static constexpr const bool _type_final = true;
   TVM_FFI_DECLARE_STATIC_OBJECT_INFO(StringObj, Object);
 
  private:
diff --git a/include/tvm/runtime/disco/session.h 
b/include/tvm/runtime/disco/session.h
index 9c34f8a2af..1c3af22d83 100644
--- a/include/tvm/runtime/disco/session.h
+++ b/include/tvm/runtime/disco/session.h
@@ -148,7 +148,8 @@ class DRefObj : public Object {
 
   static constexpr const char* _type_key = "runtime.disco.DRef";
   static constexpr const uint32_t _type_index = TypeIndex::kRuntimeDiscoDRef;
-  TVM_DECLARE_FINAL_OBJECT_INFO(DRefObj, Object);
+  static constexpr const bool _type_final = true;
+  TVM_FFI_DECLARE_STATIC_OBJECT_INFO(DRefObj, Object);
 
   /*! \brief The id of the register */
   int64_t reg_id;
diff --git a/include/tvm/runtime/memory/memory_manager.h 
b/include/tvm/runtime/memory/memory_manager.h
index ab1e6b5c9f..754d7b4fe3 100644
--- a/include/tvm/runtime/memory/memory_manager.h
+++ b/include/tvm/runtime/memory/memory_manager.h
@@ -182,7 +182,6 @@ class StorageObj : public Object {
     }
   }
 
-  static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
   static constexpr const char* _type_key = "vm.Storage";
   TVM_DECLARE_FINAL_OBJECT_INFO(StorageObj, Object);
 };
diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h
index 046f0d3948..2f89d11965 100644
--- a/include/tvm/runtime/object.h
+++ b/include/tvm/runtime/object.h
@@ -143,14 +143,20 @@ class ObjectRef : public tvm::ffi::ObjectRef {
   TypeName& operator=(const TypeName& other) = default;   \
   TypeName& operator=(TypeName&& other) = default;
 
+// Legacy TVM_* object-info / object-ref macros forward to their FFI-layer counterparts.
+#define TVM_DECLARE_BASE_OBJECT_INFO TVM_FFI_DECLARE_BASE_OBJECT_INFO
 #define TVM_DECLARE_FINAL_OBJECT_INFO TVM_FFI_DECLARE_FINAL_OBJECT_INFO
 #define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS 
TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS
 
+#define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS 
TVM_DEFINE_MUTABLE_NULLABLE_OBJECT_REF_METHODS
 #define TVM_DEFINE_OBJECT_REF_METHODS 
TVM_FFI_DEFINE_NULLABLE_OBJECT_REF_METHODS
+#define TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS 
TVM_FFI_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS
 
 #define TVM_STR_CONCAT_(__x, __y) __x##__y
 #define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y)
 
+// TVM_REGISTER_OBJECT_TYPE is now a no-op: the macro expands to nothing.
+#define TVM_REGISTER_OBJECT_TYPE(x)
 
 }  // namespace runtime
 }  // namespace tvm
diff --git a/include/tvm/runtime/packed_func.h 
b/include/tvm/runtime/packed_func.h
index 804fd3b98e..0a2461515e 100644
--- a/include/tvm/runtime/packed_func.h
+++ b/include/tvm/runtime/packed_func.h
@@ -2096,7 +2096,7 @@ inline void TVMArgsSetter::SetObject(size_t i, T&& value) 
const {
     return;
   }
 
-  Object* ptr = value.data_.data_;
+  Object* ptr = details::ObjectUnsafe::GetRawObjectPtrFromObjectRef(value);
   if constexpr (std::is_base_of_v<NDArray::ContainerType, ContainerType> ||
                 std::is_base_of_v<ContainerType, NDArray::ContainerType>) {
     if (std::is_base_of_v<NDArray::ContainerType, ContainerType> ||
@@ -2186,7 +2186,8 @@ inline void TVMArgsSetter::SetObject(size_t i, T&& value) 
const {
     values_[i].v_handle = const_cast<Object**>(&(value.data_.data_));
     type_codes_[i] = kTVMObjectRValueRefArg;
   } else {
-    values_[i].v_handle = value.data_.data_;
+    // Pass the object's TVMFFIObject header pointer, obtained via ObjectUnsafe.
+    values_[i].v_handle = 
details::ObjectUnsafe::GetTVMFFIObjectPtrFromObjectRef(value);
     type_codes_[i] = kTVMObjectHandle;
   }
 }
diff --git a/src/runtime/debug_compile.cc b/src/runtime/debug_compile.cc
index 81f1edffc7..4233035abd 100644
--- a/src/runtime/debug_compile.cc
+++ b/src/runtime/debug_compile.cc
@@ -29,6 +29,7 @@
 #include <tvm/runtime/ndarray.h>
 #include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/registry.h>
+#include <tvm/runtime/disco/disco_worker.h>
 
 namespace tvm {
 namespace debug {
diff --git a/src/runtime/disco/protocol.h b/src/runtime/disco/protocol.h
index 50a6b091af..cc741fdbb1 100644
--- a/src/runtime/disco/protocol.h
+++ b/src/runtime/disco/protocol.h
@@ -155,7 +155,7 @@ inline void 
DiscoProtocol<SubClassType>::WriteObject(Object* obj) {
     self->template Write<uint64_t>(shape->size);
     self->template WriteArray<ShapeTupleObj::index_type>(shape->data, 
shape->size);
   } else if (obj->IsInstance<DiscoDebugObject>()) {
-    self->template Write<uint32_t>(TypeIndex::kRoot);
+    self->template Write<uint32_t>(0);  // DiscoDebugObject tag (was TypeIndex::kRoot)
     std::string str = static_cast<DiscoDebugObject*>(obj)->SaveToStr();
     self->template Write<uint64_t>(str.size());
     self->template WriteArray<char>(str.data(), str.size());
@@ -188,7 +188,7 @@ inline void DiscoProtocol<SubClassType>::ReadObject(int* 
tcode, TVMValue* value)
     std::vector<ShapeTupleObj::index_type> data(ndim);
     self->template ReadArray<ShapeTupleObj::index_type>(data.data(), ndim);
     result = ShapeTuple(std::move(data));
-  } else if (type_index == TypeIndex::kRoot) {
+  } else if (type_index == 0) {  // tag WriteObject emits for DiscoDebugObject
     uint64_t size = 0;
     self->template Read<uint64_t>(&size);
     std::string data(size, '\0');
diff --git a/src/runtime/minrpc/rpc_reference.h 
b/src/runtime/minrpc/rpc_reference.h
index 13c1fa4b38..bb9c5d3c86 100644
--- a/src/runtime/minrpc/rpc_reference.h
+++ b/src/runtime/minrpc/rpc_reference.h
@@ -28,7 +28,7 @@ namespace tvm {
 namespace runtime {
 
 // Forward declare TVM Object to use `Object*` in RPC protocol.
-class Object;
+// NOTE(review): forward declaration dropped; remove this line and the stale comment above.
 
 /*! \brief The current RPC procotol version. */
 constexpr const char* kRPCProtocolVer = "0.8.0";

Reply via email to