This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0df4103675 [Bugfix] Restrict CopyOnWrite to _type_final (#17132)
0df4103675 is described below

commit 0df4103675a52cc5b9e6356cb003bb17c66bc1a4
Author: Eric Lunderberg <lunderb...@users.noreply.github.com>
AuthorDate: Tue Jul 2 10:18:08 2024 -0500

    [Bugfix] Restrict CopyOnWrite to _type_final (#17132)
    
    Prior to this commit, the `TVM_DEFINE_OBJECT_REF_COW_METHOD` macro
    could be used in any `ObjectRef` subclass to provide a `CopyOnWrite`
    method.  However, the implementation of this method was invalid if
    the object's `ContainerType` could itself be subclassed.  In that
    case, calling `obj.CopyOnWrite()` when `obj` holds an instance of a
    subclass, and when a copy is required, would silently convert `obj`
    to instead hold an instance of the base class.
    
    This commit adds a `static_assert` to the
    `TVM_DEFINE_OBJECT_REF_COW_METHOD` macro, preventing the macro from being
    used for classes whose `ContainerType` is not declared final.
    
    Compilation with this change found two classes, `relax::Var` and
    `relax::BindingBlock`, that were susceptible to this error, and the macro
    has been removed from these classes.  For backwards-compatibility, the
    `CopyOnWrite` function for these two classes is provided explicitly.
---
 include/tvm/relax/expr.h     |  7 ++++---
 include/tvm/runtime/object.h | 20 ++++++++++++--------
 src/relax/ir/expr.cc         | 38 ++++++++++++++++++++++++++++++++++++++
 3 files changed, 54 insertions(+), 11 deletions(-)

diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h
index 401aaa9248..60032c3462 100644
--- a/include/tvm/relax/expr.h
+++ b/include/tvm/relax/expr.h
@@ -427,7 +427,8 @@ class Var : public LeafExpr {
 
   TVM_DLL explicit Var(Id vid, Optional<StructInfo> struct_info_annotation, Span span = Span());
   TVM_DEFINE_OBJECT_REF_METHODS(Var, LeafExpr, VarNode);
-  TVM_DEFINE_OBJECT_REF_COW_METHOD(VarNode);
+
+  VarNode* CopyOnWrite();
 };
 
 /*! \brief A sub-type of the variable node used to mark dataflow variables from
@@ -784,10 +785,10 @@ class BindingBlock : public ObjectRef {
  public:
   TVM_DLL explicit BindingBlock(Array<Binding> bindings, Span span = Span());
   TVM_DEFINE_OBJECT_REF_METHODS(BindingBlock, ObjectRef, BindingBlockNode);
-  TVM_DEFINE_OBJECT_REF_COW_METHOD(BindingBlockNode);
+
+  BindingBlockNode* CopyOnWrite();
 };
 
-class DataflowBlock;
 class DataflowBlockNode : public BindingBlockNode {
  public:
   bool SEqualReduce(const DataflowBlockNode* other, SEqualReducer equal) const {
diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h
index 172316daae..4483867f3c 100644
--- a/include/tvm/runtime/object.h
+++ b/include/tvm/runtime/object.h
@@ -823,14 +823,18 @@ struct ObjectPtrEqual {
  *
  * \endcode
  */
-#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName)     \
-  ObjectName* CopyOnWrite() {                            \
-    ICHECK(data_ != nullptr);                            \
-    if (!data_.unique()) {                               \
-      auto n = make_object<ObjectName>(*(operator->())); \
-      ObjectPtr<Object>(std::move(n)).swap(data_);       \
-    }                                                    \
-    return static_cast<ObjectName*>(data_.get());        \
+#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName)               \
+  static_assert(ObjectName::_type_final,                           \
+                "TVM's CopyOnWrite may only be used for "          \
+                "Object types that are declared as final, "        \
+                "using the TVM_DECLARE_FINAL_OBJECT_INFO macro."); \
+  ObjectName* CopyOnWrite() {                                      \
+    ICHECK(data_ != nullptr);                                      \
+    if (!data_.unique()) {                                         \
+      auto n = make_object<ObjectName>(*(operator->()));           \
+      ObjectPtr<Object>(std::move(n)).swap(data_);                 \
+    }                                                              \
+    return static_cast<ObjectName*>(data_.get());                  \
   }
 
 // Implementations details below
diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc
index 59b6a0aeb7..a14ba1d9aa 100644
--- a/src/relax/ir/expr.cc
+++ b/src/relax/ir/expr.cc
@@ -265,6 +265,25 @@ Var::Var(Id vid, Optional<StructInfo> struct_info_annotation, Span span) {
   data_ = std::move(n);
 }
 
+VarNode* Var::CopyOnWrite() {
+  // The `TVM_DEFINE_OBJECT_REF_COW_METHOD` cannot be used for
+  // Var, because it is the base class for `DataflowVar`.
+  // If the `TVM_DEFINE_OBJECT_REF_COW_METHOD` were used, the
+  // automatic implementation would erroneously convert from a
+  // `DataflowVar` to a `Var`.
+  ICHECK(data_ != nullptr);
+  if (!data_.unique()) {
+    ObjectPtr<VarNode> node;
+    if (auto dataflow_var = as<DataflowVarNode>()) {
+      node = make_object<DataflowVarNode>(*dataflow_var);
+    } else {
+      node = make_object<VarNode>(*(operator->()));
+    }
+    ObjectPtr<Object>(std::move(node)).swap(data_);
+  }
+  return static_cast<VarNode*>(data_.get());
+}
+
 TVM_REGISTER_GLOBAL("relax.Var")
     .set_body_typed([](String name_hint, Optional<StructInfo> struct_info_annotation, Span span) {
       return Var(name_hint, struct_info_annotation, span);
@@ -473,6 +492,25 @@ BindingBlock::BindingBlock(Array<Binding> bindings, Span span) {
   data_ = std::move(n);
 }
 
+BindingBlockNode* BindingBlock::CopyOnWrite() {
+  // The `TVM_DEFINE_OBJECT_REF_COW_METHOD` macro cannot be used for
+  // BindingBlock, because `BindingBlockNode` is the base class of
+  // `DataflowBlockNode`.  The macro always copies into a new object
+  // of the named node type, which here would erroneously convert a
+  // `DataflowBlock` into a plain `BindingBlock`.
+  ICHECK(data_ != nullptr);
+  if (!data_.unique()) {
+    ObjectPtr<BindingBlockNode> node;
+    if (auto dataflow_block = as<DataflowBlockNode>()) {
+      node = make_object<DataflowBlockNode>(*dataflow_block);
+    } else {
+      node = make_object<BindingBlockNode>(*(operator->()));
+    }
+    ObjectPtr<Object>(std::move(node)).swap(data_);
+  }
+  return static_cast<BindingBlockNode*>(data_.get());
+}
+
 TVM_REGISTER_GLOBAL("relax.BindingBlock").set_body_typed([](Array<Binding> bindings, Span span) {
   return BindingBlock(bindings, span);
 });

Reply via email to