This is an automated email from the ASF dual-hosted git repository.

kparzysz pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 694d4bf5ea [tir] Add copy on write to all nodes (#13512) 694d4bf5ea is described below commit 694d4bf5eaf65df4eaad93188830112c6b139956 Author: driazati <9407960+driaz...@users.noreply.github.com> AuthorDate: Tue Nov 29 13:58:21 2022 -0800 [tir] Add copy on write to all nodes (#13512) This enables copy on write methods for all nodes since some were missing it before (see #13012 for more context) Co-authored-by: driazati <driaz...@users.noreply.github.com> --- include/tvm/ir/expr.h | 2 ++ include/tvm/tir/expr.h | 30 ++++++++++++++++++++++++++++++ include/tvm/tir/stmt.h | 12 ++++++++++++ 3 files changed, 44 insertions(+) diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 94927b4892..bb4c468f45 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -526,6 +526,7 @@ class IntImm : public PrimExpr { TVM_DLL IntImm(DataType dtype, int64_t value, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(IntImm, PrimExpr, IntImmNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(IntImmNode); }; /*! @@ -572,6 +573,7 @@ class FloatImm : public PrimExpr { TVM_DLL FloatImm(DataType dtype, double value, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(FloatImm, PrimExpr, FloatImmNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(FloatImmNode); }; /*! diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 674ff0b7f4..689b1c0a17 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -79,6 +79,7 @@ class StringImm : public PrimExpr { public: TVM_DLL StringImm(String value, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(StringImm, PrimExpr, StringImmNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(StringImmNode); }; /*! @@ -117,6 +118,7 @@ class Cast : public PrimExpr { public: TVM_DLL Cast(DataType dtype, PrimExpr value, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Cast, PrimExpr, CastNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(CastNode); }; /*! 
@@ -165,6 +167,7 @@ class Add : public PrimExpr { public: TVM_DLL Add(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Add, PrimExpr, AddNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(AddNode); }; /*! \brief a - b */ @@ -181,6 +184,7 @@ class Sub : public PrimExpr { public: TVM_DLL Sub(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Sub, PrimExpr, SubNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(SubNode); }; /*! \brief a * b */ @@ -197,6 +201,7 @@ class Mul : public PrimExpr { public: TVM_DLL Mul(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Mul, PrimExpr, MulNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(MulNode); }; /*! @@ -216,6 +221,7 @@ class Div : public PrimExpr { public: TVM_DLL Div(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Div, PrimExpr, DivNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(DivNode); }; /*! @@ -235,6 +241,7 @@ class Mod : public PrimExpr { public: TVM_DLL Mod(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Mod, PrimExpr, ModNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ModNode); }; /*! \brief Floor division, floor(a/b) */ @@ -251,6 +258,7 @@ class FloorDiv : public PrimExpr { public: TVM_DLL FloorDiv(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(FloorDiv, PrimExpr, FloorDivNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorDivNode); }; /*! \brief The remainder of the floordiv */ @@ -267,6 +275,7 @@ class FloorMod : public PrimExpr { public: TVM_DLL FloorMod(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(FloorMod, PrimExpr, FloorModNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorModNode); }; /*! \brief min(a, b) */ @@ -283,6 +292,7 @@ class Min : public PrimExpr { public: TVM_DLL Min(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Min, PrimExpr, MinNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(MinNode); }; /*! 
\brief max(a, b) */ @@ -299,6 +309,7 @@ class Max : public PrimExpr { public: TVM_DLL Max(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Max, PrimExpr, MaxNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(MaxNode); }; /*! @@ -347,6 +358,7 @@ class EQ : public PrimExpr { public: TVM_DLL EQ(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(EQ, PrimExpr, EQNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(EQNode); }; /*! \brief a != b */ @@ -363,6 +375,7 @@ class NE : public PrimExpr { public: TVM_DLL NE(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(NE, PrimExpr, NENode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(NENode); }; /*! \brief a < b */ @@ -379,6 +392,7 @@ class LT : public PrimExpr { public: TVM_DLL LT(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(LT, PrimExpr, LTNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(LTNode); }; /*! \brief a <= b */ @@ -395,6 +409,7 @@ class LE : public PrimExpr { public: TVM_DLL LE(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(LE, PrimExpr, LENode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(LENode); }; /*! \brief a > b */ @@ -411,6 +426,7 @@ class GT : public PrimExpr { public: TVM_DLL GT(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(GT, PrimExpr, GTNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(GTNode); }; /*! \brief a >= b */ @@ -427,6 +443,7 @@ class GE : public PrimExpr { public: TVM_DLL GE(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(GE, PrimExpr, GENode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(GENode); }; /*! \brief a && b */ @@ -466,6 +483,7 @@ class And : public PrimExpr { public: TVM_DLL And(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(And, PrimExpr, AndNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(AndNode); }; /*! 
\brief a || b */ @@ -505,6 +523,7 @@ class Or : public PrimExpr { public: TVM_DLL Or(PrimExpr a, PrimExpr b, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Or, PrimExpr, OrNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(OrNode); }; /*! \brief !a */ @@ -540,6 +559,7 @@ class Not : public PrimExpr { public: TVM_DLL Not(PrimExpr a, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Not, PrimExpr, NotNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(NotNode); }; /*! @@ -591,6 +611,7 @@ class Select : public PrimExpr { TVM_DLL Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Select, PrimExpr, SelectNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(SelectNode); }; /*! @@ -706,6 +727,7 @@ class ProducerLoad : public PrimExpr { TVM_DLL explicit ProducerLoad(DataProducer producer, Array<PrimExpr> indices, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(ProducerLoad, PrimExpr, ProducerLoadNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerLoadNode); }; /*! @@ -765,6 +787,7 @@ class Load : public PrimExpr { TVM_DLL Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Load, PrimExpr, LoadNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(LoadNode); }; /*! @@ -817,6 +840,7 @@ class Ramp : public PrimExpr { public: TVM_DLL Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Ramp, PrimExpr, RampNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(RampNode); }; /*! \brief Create a vector where all the elements are value. */ @@ -856,6 +880,7 @@ class Broadcast : public PrimExpr { public: TVM_DLL Broadcast(PrimExpr value, int lanes, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Broadcast, PrimExpr, BroadcastNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(BroadcastNode); }; /*! 
@@ -902,6 +927,7 @@ class Let : public PrimExpr { public: TVM_DLL Let(Var var, PrimExpr value, PrimExpr body, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Let, PrimExpr, LetNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(LetNode); }; /*! @@ -948,6 +974,7 @@ class Call : public PrimExpr { public: TVM_DLL Call(DataType dtype, RelayExpr op, Array<PrimExpr> args, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Call, PrimExpr, CallNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode); }; /*! @@ -995,6 +1022,7 @@ class Shuffle : public PrimExpr { TVM_DLL static PrimExpr ExtractElement(PrimExpr vector, int index, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Shuffle, PrimExpr, ShuffleNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ShuffleNode); }; // Reduce operator @@ -1124,6 +1152,7 @@ class Reduce : public PrimExpr { int value_index, Array<PrimExpr> init, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Reduce, PrimExpr, ReduceNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ReduceNode); }; /*! \brief Any shape. */ @@ -1159,6 +1188,7 @@ class Any : public PrimExpr { TVM_DLL Any(Span span = Span()); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Any, PrimExpr, AnyNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(AnyNode); }; /* diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 6865326b88..5beea44cdb 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -102,6 +102,7 @@ class LetStmt : public Stmt { TVM_DLL LetStmt(Var var, PrimExpr value, Stmt body, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(LetStmt, Stmt, LetStmtNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(LetStmtNode); }; /*! @@ -158,6 +159,7 @@ class AttrStmt : public Stmt { TVM_DLL AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(AttrStmt, Stmt, AttrStmtNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(AttrStmtNode); }; /*! 
@@ -206,6 +208,7 @@ class AssertStmt : public Stmt { TVM_DLL AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(AssertStmt, Stmt, AssertStmtNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(AssertStmtNode); }; /*! @@ -271,6 +274,7 @@ class Store : public Stmt { Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Store, Stmt, StoreNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(StoreNode); }; /*! @@ -442,6 +446,7 @@ class ProducerStore : public Stmt { Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(ProducerStore, Stmt, ProducerStoreNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerStoreNode); }; /*! @@ -505,6 +510,7 @@ class ProducerRealize : public Stmt { String storage_scope = "", Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(ProducerRealize, Stmt, ProducerRealizeNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerRealizeNode); }; /*! @@ -679,6 +685,7 @@ class AllocateConst : public Stmt { Map<String, ObjectRef> annotations = Map<String, ObjectRef>(), Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(AllocateConst, Stmt, AllocateConstNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(AllocateConstNode); }; /*! \brief Declare a buffer that can be used in the body */ @@ -812,6 +819,7 @@ class SeqStmt : public Stmt { }; TVM_DEFINE_OBJECT_REF_METHODS(SeqStmt, Stmt, SeqStmtNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqStmtNode); }; /*! @@ -898,6 +906,7 @@ class Evaluate : public Stmt { explicit Evaluate(int value, Span span = Span()) : Evaluate(PrimExpr(value), span) {} TVM_DEFINE_OBJECT_REF_METHODS(Evaluate, Stmt, EvaluateNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(EvaluateNode); }; /*! @@ -1055,6 +1064,7 @@ class While : public Stmt { TVM_DLL While(PrimExpr condition, Stmt body, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(While, Stmt, WhileNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(WhileNode); }; /*! 
@@ -1099,6 +1109,7 @@ class Prefetch : public Stmt { TVM_DLL explicit Prefetch(Buffer buffer, Array<Range> bounds, Span span = Span()); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(PrefetchNode); }; /*! @@ -1203,6 +1214,7 @@ class MatchBufferRegion : public ObjectRef { TVM_DLL explicit MatchBufferRegion(Buffer buffer, BufferRegion source); TVM_DEFINE_OBJECT_REF_METHODS(MatchBufferRegion, ObjectRef, MatchBufferRegionNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchBufferRegionNode); }; /*!