This is an automated email from the ASF dual-hosted git repository.

kparzysz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 694d4bf5ea [tir] Add copy on write to all nodes (#13512)
694d4bf5ea is described below

commit 694d4bf5eaf65df4eaad93188830112c6b139956
Author: driazati <9407960+driaz...@users.noreply.github.com>
AuthorDate: Tue Nov 29 13:58:21 2022 -0800

    [tir] Add copy on write to all nodes (#13512)
    
    This enables copy on write methods for all nodes since some were missing
    it before (see #13012 for more context)
    
    Co-authored-by: driazati <driaz...@users.noreply.github.com>
---
 include/tvm/ir/expr.h  |  2 ++
 include/tvm/tir/expr.h | 30 ++++++++++++++++++++++++++++++
 include/tvm/tir/stmt.h | 12 ++++++++++++
 3 files changed, 44 insertions(+)

diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index 94927b4892..bb4c468f45 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -526,6 +526,7 @@ class IntImm : public PrimExpr {
   TVM_DLL IntImm(DataType dtype, int64_t value, Span span = Span());
 
   TVM_DEFINE_OBJECT_REF_METHODS(IntImm, PrimExpr, IntImmNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(IntImmNode);
 };
 
 /*!
@@ -572,6 +573,7 @@ class FloatImm : public PrimExpr {
   TVM_DLL FloatImm(DataType dtype, double value, Span span = Span());
 
   TVM_DEFINE_OBJECT_REF_METHODS(FloatImm, PrimExpr, FloatImmNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(FloatImmNode);
 };
 
 /*!
diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h
index 674ff0b7f4..689b1c0a17 100644
--- a/include/tvm/tir/expr.h
+++ b/include/tvm/tir/expr.h
@@ -79,6 +79,7 @@ class StringImm : public PrimExpr {
  public:
   TVM_DLL StringImm(String value, Span span = Span());
   TVM_DEFINE_OBJECT_REF_METHODS(StringImm, PrimExpr, StringImmNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(StringImmNode);
 };
 
 /*!
@@ -117,6 +118,7 @@ class Cast : public PrimExpr {
  public:
   TVM_DLL Cast(DataType dtype, PrimExpr value, Span span = Span());
   TVM_DEFINE_OBJECT_REF_METHODS(Cast, PrimExpr, CastNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(CastNode);
 };
 
 /*!
@@ -165,6 +167,7 @@ class Add : public PrimExpr {
  public:
   TVM_DLL Add(PrimExpr a, PrimExpr b, Span span = Span());
   TVM_DEFINE_OBJECT_REF_METHODS(Add, PrimExpr, AddNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(AddNode);
 };
 
 /*! \brief a - b */
@@ -181,6 +184,7 @@ class Sub : public PrimExpr {
  public:
   TVM_DLL Sub(PrimExpr a, PrimExpr b, Span span = Span());
   TVM_DEFINE_OBJECT_REF_METHODS(Sub, PrimExpr, SubNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(SubNode);
 };
 
 /*! \brief a * b */
@@ -197,6 +201,7 @@ class Mul : public PrimExpr {
  public:
   TVM_DLL Mul(PrimExpr a, PrimExpr b, Span span = Span());
   TVM_DEFINE_OBJECT_REF_METHODS(Mul, PrimExpr, MulNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(MulNode);
 };
 
 /*!
@@ -216,6 +221,7 @@ class Div : public PrimExpr {
  public:
   TVM_DLL Div(PrimExpr a, PrimExpr b, Span span = Span());
   TVM_DEFINE_OBJECT_REF_METHODS(Div, PrimExpr, DivNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(DivNode);
 };
 
 /*!
@@ -235,6 +241,7 @@ class Mod : public PrimExpr {
  public:
   TVM_DLL Mod(PrimExpr a, PrimExpr b, Span span = Span());
   TVM_DEFINE_OBJECT_REF_METHODS(Mod, PrimExpr, ModNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(ModNode);
 };
 
 /*! \brief Floor division, floor(a/b) */
@@ -251,6 +258,7 @@ class FloorDiv : public PrimExpr {
  public:
   TVM_DLL FloorDiv(PrimExpr a, PrimExpr b, Span span = Span());
   TVM_DEFINE_OBJECT_REF_METHODS(FloorDiv, PrimExpr, FloorDivNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorDivNode);
 };
 
 /*! \brief The remainder of the floordiv */
@@ -267,6 +275,7 @@ class FloorMod : public PrimExpr {
  public:
   TVM_DLL FloorMod(PrimExpr a, PrimExpr b, Span span = Span());
   TVM_DEFINE_OBJECT_REF_METHODS(FloorMod, PrimExpr, FloorModNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(FloorModNode);
 };
 
 /*! \brief min(a, b) */
@@ -283,6 +292,7 @@ class Min : public PrimExpr {
  public:
   TVM_DLL Min(PrimExpr a, PrimExpr b, Span span = Span());
   TVM_DEFINE_OBJECT_REF_METHODS(Min, PrimExpr, MinNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(MinNode);
 };
 
 /*! \brief max(a, b) */
@@ -299,6 +309,7 @@ class Max : public PrimExpr {
  public:
   TVM_DLL Max(PrimExpr a, PrimExpr b, Span span = Span());
   TVM_DEFINE_OBJECT_REF_METHODS(Max, PrimExpr, MaxNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(MaxNode);
 };
 
 /*!
@@ -347,6 +358,7 @@ class EQ : public PrimExpr {
  public:
   TVM_DLL EQ(PrimExpr a, PrimExpr b, Span span = Span());
   TVM_DEFINE_OBJECT_REF_METHODS(EQ, PrimExpr, EQNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(EQNode);
 };
 
 /*! \brief a != b */
@@ -363,6 +375,7 @@ class NE : public PrimExpr {
  public:
   TVM_DLL NE(PrimExpr a, PrimExpr b, Span span = Span());
   TVM_DEFINE_OBJECT_REF_METHODS(NE, PrimExpr, NENode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(NENode);
 };
 
 /*! \brief a < b */
@@ -379,6 +392,7 @@ class LT : public PrimExpr {
  public:
   TVM_DLL LT(PrimExpr a, PrimExpr b, Span span = Span());
   TVM_DEFINE_OBJECT_REF_METHODS(LT, PrimExpr, LTNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(LTNode);
 };
 
 /*! \brief a <= b */
@@ -395,6 +409,7 @@ class LE : public PrimExpr {
  public:
   TVM_DLL LE(PrimExpr a, PrimExpr b, Span span = Span());
   TVM_DEFINE_OBJECT_REF_METHODS(LE, PrimExpr, LENode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(LENode);
 };
 
 /*! \brief a > b */
@@ -411,6 +426,7 @@ class GT : public PrimExpr {
  public:
   TVM_DLL GT(PrimExpr a, PrimExpr b, Span span = Span());
   TVM_DEFINE_OBJECT_REF_METHODS(GT, PrimExpr, GTNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(GTNode);
 };
 
 /*! \brief a >= b */
@@ -427,6 +443,7 @@ class GE : public PrimExpr {
  public:
   TVM_DLL GE(PrimExpr a, PrimExpr b, Span span = Span());
   TVM_DEFINE_OBJECT_REF_METHODS(GE, PrimExpr, GENode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(GENode);
 };
 
 /*! \brief a && b */
@@ -466,6 +483,7 @@ class And : public PrimExpr {
  public:
   TVM_DLL And(PrimExpr a, PrimExpr b, Span span = Span());
   TVM_DEFINE_OBJECT_REF_METHODS(And, PrimExpr, AndNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(AndNode);
 };
 
 /*! \brief a || b */
@@ -505,6 +523,7 @@ class Or : public PrimExpr {
  public:
   TVM_DLL Or(PrimExpr a, PrimExpr b, Span span = Span());
   TVM_DEFINE_OBJECT_REF_METHODS(Or, PrimExpr, OrNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(OrNode);
 };
 
 /*! \brief !a */
@@ -540,6 +559,7 @@ class Not : public PrimExpr {
  public:
   TVM_DLL Not(PrimExpr a, Span span = Span());
   TVM_DEFINE_OBJECT_REF_METHODS(Not, PrimExpr, NotNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(NotNode);
 };
 
 /*!
@@ -591,6 +611,7 @@ class Select : public PrimExpr {
   TVM_DLL Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span = Span());
 
   TVM_DEFINE_OBJECT_REF_METHODS(Select, PrimExpr, SelectNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(SelectNode);
 };
 
 /*!
@@ -706,6 +727,7 @@ class ProducerLoad : public PrimExpr {
   TVM_DLL explicit ProducerLoad(DataProducer producer, Array<PrimExpr> indices, Span span = Span());
 
   TVM_DEFINE_OBJECT_REF_METHODS(ProducerLoad, PrimExpr, ProducerLoadNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerLoadNode);
 };
 
 /*!
@@ -765,6 +787,7 @@ class Load : public PrimExpr {
   TVM_DLL Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate,
                Span span = Span());
   TVM_DEFINE_OBJECT_REF_METHODS(Load, PrimExpr, LoadNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(LoadNode);
 };
 
 /*!
@@ -817,6 +840,7 @@ class Ramp : public PrimExpr {
  public:
   TVM_DLL Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span = Span());
   TVM_DEFINE_OBJECT_REF_METHODS(Ramp, PrimExpr, RampNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(RampNode);
 };
 
 /*! \brief Create a vector where all the elements are value. */
@@ -856,6 +880,7 @@ class Broadcast : public PrimExpr {
  public:
   TVM_DLL Broadcast(PrimExpr value, int lanes, Span span = Span());
   TVM_DEFINE_OBJECT_REF_METHODS(Broadcast, PrimExpr, BroadcastNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(BroadcastNode);
 };
 
 /*!
@@ -902,6 +927,7 @@ class Let : public PrimExpr {
  public:
   TVM_DLL Let(Var var, PrimExpr value, PrimExpr body, Span span = Span());
   TVM_DEFINE_OBJECT_REF_METHODS(Let, PrimExpr, LetNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(LetNode);
 };
 
 /*!
@@ -948,6 +974,7 @@ class Call : public PrimExpr {
  public:
   TVM_DLL Call(DataType dtype, RelayExpr op, Array<PrimExpr> args, Span span = Span());
   TVM_DEFINE_OBJECT_REF_METHODS(Call, PrimExpr, CallNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode);
 };
 
 /*!
@@ -995,6 +1022,7 @@ class Shuffle : public PrimExpr {
   TVM_DLL static PrimExpr ExtractElement(PrimExpr vector, int index, Span span = Span());
 
   TVM_DEFINE_OBJECT_REF_METHODS(Shuffle, PrimExpr, ShuffleNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(ShuffleNode);
 };
 
 // Reduce operator
@@ -1124,6 +1152,7 @@ class Reduce : public PrimExpr {
                  int value_index, Array<PrimExpr> init, Span span = Span());
 
   TVM_DEFINE_OBJECT_REF_METHODS(Reduce, PrimExpr, ReduceNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(ReduceNode);
 };
 
 /*! \brief Any shape. */
@@ -1159,6 +1188,7 @@ class Any : public PrimExpr {
   TVM_DLL Any(Span span = Span());
 
   TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Any, PrimExpr, AnyNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(AnyNode);
 };
 
 /*
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 6865326b88..5beea44cdb 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -102,6 +102,7 @@ class LetStmt : public Stmt {
   TVM_DLL LetStmt(Var var, PrimExpr value, Stmt body, Span span = Span());
 
   TVM_DEFINE_OBJECT_REF_METHODS(LetStmt, Stmt, LetStmtNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(LetStmtNode);
 };
 
 /*!
@@ -158,6 +159,7 @@ class AttrStmt : public Stmt {
   TVM_DLL AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span = Span());
 
   TVM_DEFINE_OBJECT_REF_METHODS(AttrStmt, Stmt, AttrStmtNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(AttrStmtNode);
 };
 
 /*!
@@ -206,6 +208,7 @@ class AssertStmt : public Stmt {
   TVM_DLL AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span = Span());
 
   TVM_DEFINE_OBJECT_REF_METHODS(AssertStmt, Stmt, AssertStmtNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(AssertStmtNode);
 };
 
 /*!
@@ -271,6 +274,7 @@ class Store : public Stmt {
                 Span span = Span());
 
   TVM_DEFINE_OBJECT_REF_METHODS(Store, Stmt, StoreNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(StoreNode);
 };
 
 /*!
@@ -442,6 +446,7 @@ class ProducerStore : public Stmt {
                         Span span = Span());
 
   TVM_DEFINE_OBJECT_REF_METHODS(ProducerStore, Stmt, ProducerStoreNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerStoreNode);
 };
 
 /*!
@@ -505,6 +510,7 @@ class ProducerRealize : public Stmt {
                           String storage_scope = "", Span span = Span());
 
   TVM_DEFINE_OBJECT_REF_METHODS(ProducerRealize, Stmt, ProducerRealizeNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerRealizeNode);
 };
 
 /*!
@@ -679,6 +685,7 @@ class AllocateConst : public Stmt {
                         Map<String, ObjectRef> annotations = Map<String, ObjectRef>(),
                         Span span = Span());
   TVM_DEFINE_OBJECT_REF_METHODS(AllocateConst, Stmt, AllocateConstNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(AllocateConstNode);
 };
 
 /*! \brief Declare a buffer that can be used in the body */
@@ -812,6 +819,7 @@ class SeqStmt : public Stmt {
   };
 
   TVM_DEFINE_OBJECT_REF_METHODS(SeqStmt, Stmt, SeqStmtNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(SeqStmtNode);
 };
 
 /*!
@@ -898,6 +906,7 @@ class Evaluate : public Stmt {
   explicit Evaluate(int value, Span span = Span()) : Evaluate(PrimExpr(value), span) {}
 
   TVM_DEFINE_OBJECT_REF_METHODS(Evaluate, Stmt, EvaluateNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(EvaluateNode);
 };
 
 /*!
@@ -1055,6 +1064,7 @@ class While : public Stmt {
   TVM_DLL While(PrimExpr condition, Stmt body, Span span = Span());
 
   TVM_DEFINE_OBJECT_REF_METHODS(While, Stmt, WhileNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(WhileNode);
 };
 
 /*!
@@ -1099,6 +1109,7 @@ class Prefetch : public Stmt {
   TVM_DLL explicit Prefetch(Buffer buffer, Array<Range> bounds, Span span = Span());
 
   TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(PrefetchNode);
 };
 
 /*!
@@ -1203,6 +1214,7 @@ class MatchBufferRegion : public ObjectRef {
   TVM_DLL explicit MatchBufferRegion(Buffer buffer, BufferRegion source);
 
   TVM_DEFINE_OBJECT_REF_METHODS(MatchBufferRegion, ObjectRef, MatchBufferRegionNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchBufferRegionNode);
 };
 
 /*!

Reply via email to