This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 0df4103675 [Bugfix] Restrict CopyOnWrite to _type_final (#17132) 0df4103675 is described below commit 0df4103675a52cc5b9e6356cb003bb17c66bc1a4 Author: Eric Lunderberg <lunderb...@users.noreply.github.com> AuthorDate: Tue Jul 2 10:18:08 2024 -0500 [Bugfix] Restrict CopyOnWrite to _type_final (#17132) Prior to this commit, the `TVM_DEFINE_OBJECT_REF_COW_METHOD` macro could be used in any `ObjectRef` subclass to provide a `CopyOnWrite` method. However, the implementation of this method was invalid if the object's `ContainerType` could itself be subclassed. In that case, calling `obj.CopyOnWrite()` when the object contains a subclass instance, and when a copy is required, would silently convert `obj` to contain an instance of the base class instead. This commit adds a `static_assert` to the `TVM_DEFINE_OBJECT_REF_COW_METHOD` macro, preventing the macro from being used in classes where it would behave incorrectly. Compilation with this change found two classes, `relax::Var` and `relax::BindingBlock`, that were susceptible to this error, and the macro has been removed from these classes. For backwards-compatibility, the `CopyOnWrite` function for these two classes is provided explicitly. --- include/tvm/relax/expr.h | 7 ++++--- include/tvm/runtime/object.h | 20 ++++++++++++-------- src/relax/ir/expr.cc | 38 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 54 insertions(+), 11 deletions(-) diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 401aaa9248..60032c3462 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -427,7 +427,8 @@ class Var : public LeafExpr { TVM_DLL explicit Var(Id vid, Optional<StructInfo> struct_info_annotation, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Var, LeafExpr, VarNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(VarNode); + + VarNode* CopyOnWrite(); }; /*! 
\brief A sub-type of the variable node used to mark dataflow variables from @@ -784,10 +785,10 @@ class BindingBlock : public ObjectRef { public: TVM_DLL explicit BindingBlock(Array<Binding> bindings, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(BindingBlock, ObjectRef, BindingBlockNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(BindingBlockNode); + + BindingBlockNode* CopyOnWrite(); }; -class DataflowBlock; class DataflowBlockNode : public BindingBlockNode { public: bool SEqualReduce(const DataflowBlockNode* other, SEqualReducer equal) const { diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 172316daae..4483867f3c 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -823,14 +823,18 @@ struct ObjectPtrEqual { * * \endcode */ -#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName) \ - ObjectName* CopyOnWrite() { \ - ICHECK(data_ != nullptr); \ - if (!data_.unique()) { \ - auto n = make_object<ObjectName>(*(operator->())); \ - ObjectPtr<Object>(std::move(n)).swap(data_); \ - } \ - return static_cast<ObjectName*>(data_.get()); \ +#define TVM_DEFINE_OBJECT_REF_COW_METHOD(ObjectName) \ + static_assert(ObjectName::_type_final, \ + "TVM's CopyOnWrite may only be used for " \ + "Object types that are declared as final, " \ + "using the TVM_DECLARE_FINAL_OBJECT_INFO macro."); \ + ObjectName* CopyOnWrite() { \ + ICHECK(data_ != nullptr); \ + if (!data_.unique()) { \ + auto n = make_object<ObjectName>(*(operator->())); \ + ObjectPtr<Object>(std::move(n)).swap(data_); \ + } \ + return static_cast<ObjectName*>(data_.get()); \ } // Implementations details below diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 59b6a0aeb7..a14ba1d9aa 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -265,6 +265,25 @@ Var::Var(Id vid, Optional<StructInfo> struct_info_annotation, Span span) { data_ = std::move(n); } +VarNode* Var::CopyOnWrite() { + // The `TVM_DEFINE_OBJECT_REF_COW_METHOD` cannot be used for + // Var, 
because it is the base class for `DataflowBlock`. + // If the `TVM_DEFINE_OBJECT_REF_COW_METHOD` were used, the + // automatic implementation would erroneously convert from a + // `DataflowBlock` to a `Var`. + ICHECK(data_ != nullptr); + if (!data_.unique()) { + ObjectPtr<VarNode> node; + if (auto dataflow_var = as<DataflowVarNode>()) { + node = make_object<DataflowVarNode>(*dataflow_var); + } else { + node = make_object<VarNode>(*(operator->())); + } + ObjectPtr<Object>(std::move(node)).swap(data_); + } + return static_cast<VarNode*>(data_.get()); +} + TVM_REGISTER_GLOBAL("relax.Var") .set_body_typed([](String name_hint, Optional<StructInfo> struct_info_annotation, Span span) { return Var(name_hint, struct_info_annotation, span); @@ -473,6 +492,25 @@ BindingBlock::BindingBlock(Array<Binding> bindings, Span span) { data_ = std::move(n); } +BindingBlockNode* BindingBlock::CopyOnWrite() { + // The `TVM_DEFINE_OBJECT_REF_COW_METHOD` cannot be used for + // BindingBlock, because it is the base class for `DataflowBlock`. + // If the `TVM_DEFINE_OBJECT_REF_COW_METHOD` were used, the + // automatic implementation would erroneously convert from a + // `DataflowBlock` to a `BindingBlock`. + ICHECK(data_ != nullptr); + if (!data_.unique()) { + ObjectPtr<BindingBlockNode> node; + if (auto dataflow_block = as<DataflowBlockNode>()) { + node = make_object<DataflowBlockNode>(*dataflow_block); + } else { + node = make_object<BindingBlockNode>(*(operator->())); + } + ObjectPtr<Object>(std::move(node)).swap(data_); + } + return static_cast<BindingBlockNode*>(data_.get()); +} + TVM_REGISTER_GLOBAL("relax.BindingBlock").set_body_typed([](Array<Binding> bindings, Span span) { return BindingBlock(bindings, span); });