This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ce1fa8908f [TE] Record primitives of Schedule for visualization (#14168)
ce1fa8908f is described below

commit ce1fa8908f626e58f245966dd0a2e2540b75dace
Author: Chun-I Tsai <quic_chu...@quicinc.com>
AuthorDate: Wed Mar 15 04:52:11 2023 +0800

    [TE] Record primitives of Schedule for visualization (#14168)
    
    * [ScheduleVisualization]
    
    - Make each Stage link to its owning Schedule via attach_sch.
    - Add two array attributes, primitive_record and schedule_record, to Schedule.
    - Create a new class, ScheduleContext, to record primitives.
    - Register a pass config option, te.keep_schedule_record, to enable/disable the recording.
    - Add test cases for TEDD, build_module and schedule ops.
    
    * [ScheduleVisualization]
    
    * Fix grammar issues
    * Rewrite unclear comments
    
    * [ScheduleVisualization]
    
    * Remove the incorrect term "rebased" from comments and variables.
    
    ---------
    
    Co-authored-by: Joey Tsai <chu...@qti.qualcomm.com>
---
 include/tvm/te/schedule.h                     | 43 +++++++++++++++++++-
 python/tvm/contrib/tedd.py                    | 27 ++++++++++++-
 src/relay/backend/te_compiler.cc              | 13 +++++-
 src/te/schedule/schedule_dataflow_rewrite.cc  | 14 ++++---
 src/te/schedule/schedule_lang.cc              | 51 ++++++++++++++++++++++-
 tests/python/contrib/test_tedd.py             | 58 ++++++++++++++++++++++++++-
 tests/python/relay/test_build_module.py       | 37 +++++++++++++++++
 tests/python/unittest/test_te_schedule_ops.py | 53 ++++++++++++++++++++++++
 8 files changed, 283 insertions(+), 13 deletions(-)

diff --git a/include/tvm/te/schedule.h b/include/tvm/te/schedule.h
index 5d88793206..1b711a8370 100644
--- a/include/tvm/te/schedule.h
+++ b/include/tvm/te/schedule.h
@@ -62,8 +62,9 @@ class Stage : public ObjectRef {
   /*!
    * \brief create a new schedule for op.
    * \param op The operator in the schedule
+   * \param sch The schedule to which the current stage belongs
    */
-  explicit Stage(Operation op);
+  explicit Stage(Operation op, const ScheduleNode* sch);
   /*!
    * \brief access the internal node container
    * \return the pointer to the internal node container
@@ -445,6 +446,26 @@ class Schedule : public ObjectRef {
   using ContainerType = ScheduleNode;
 };
 
+/*!
+ * \brief Context helper to collect debug information of Schedule.
+ *
+ *  Instantiate With<ScheduleContext>(schedule_instance, primitive_name)
+ *  inside the body of a schedule primitive; when recording is enabled, a
+ *  snapshot of the schedule and the primitive's name are recorded on exit.
+ */
+class ScheduleContext {
+ private:
+  friend class With<ScheduleContext>;
+  ScheduleContext(const ScheduleNode* sch_node, String current_primitive_name);
+  void EnterWithScope();
+  void ExitWithScope();
+
+  /*! \brief Schedule instance to store information for debug */
+  Schedule sch_;
+  /*! \brief String representing which primitive has been applied to sch_ */
+  String current_primitive_name_;
+};
+
 /*!
  * \brief The schedule relation between IterVars
  *  can be Split, Fuse.
@@ -546,6 +567,8 @@ class StageNode : public Object {
   IterVar attach_ivar;
   /*! \brief The stage this node attaches to */
   Stage attach_stage;
+  /*! \brief The schedule that the current stage belongs to */
+  const ScheduleNode* attach_sch;
   /*! \brief The thread storage scope level of the stage */
   std::string scope;
   /*! \brief Whether this is an output stage */
@@ -615,12 +638,30 @@ class ScheduleNode : public Object {
    *  This is created on demand and can be invalidated.
    */
   std::unordered_map<const Object*, Stage> op2stage_cache_;
+  /*!
+   * \brief Recorded schedule snapshots: the initial schedule, then one per applied primitive.
+   * Users can display the optimization strategy step by step via TEDD to check
+   * the order and effect of primitives. Set "te.keep_schedule_record" to true
+   * in the PassContext config to enable recording.
+   */
+  Array<Schedule> schedule_record;
+  /*!
+   * \brief Names of all applied primitives, kept parallel to schedule_record.
+   */
+  Array<String> primitive_record;
+  /*!
+   * \brief Whether schedule_record and primitive_record are populated.
+   */
+  Optional<Bool> keep_schedule_record;
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("outputs", &outputs);
     v->Visit("stages", &stages);
     v->Visit("groups", &groups);
     v->Visit("stage_map", &stage_map);
+    v->Visit("schedule_record", &schedule_record);
+    v->Visit("primitive_record", &primitive_record);
+    v->Visit("keep_schedule_record", &keep_schedule_record);
   }
 
   /*! \brief Initialize temp cache. */
diff --git a/python/tvm/contrib/tedd.py b/python/tvm/contrib/tedd.py
index a65f5e474a..aa423d8964 100644
--- a/python/tvm/contrib/tedd.py
+++ b/python/tvm/contrib/tedd.py
@@ -78,6 +78,27 @@ def insert_dot_id(sch):
     return sch
 
 
+def itervar_equal(iv_a, iv_b):
+    """A helper method that compares the equality of two iterative variables"""
+    # Compare the relevant fields of the two itervars explicitly: the plain
+    # comparison (iv_a == iv_b) might fail after InferBound has changed the
+    # domain bounds.
+    def _var_equal(v_a, v_b):
+        condtions = [
+            v_a.name == v_b.name,
+            v_a.dtype == v_b.dtype,
+            v_a.type_annotation == v_b.type_annotation,
+        ]
+        return all(c for c in condtions)
+
+    condtions = [
+        _var_equal(iv_a.var, iv_b.var),
+        iv_a.iter_type == iv_b.iter_type,
+        iv_a.thread_tag == iv_b.thread_tag,
+    ]
+    return all(c for c in condtions)
+
+
 class ObjectManager:
     """A helper class tracking schedule objects, e.g. stage, IterVar,
     relationship, and tensor, to their DOM path."""
@@ -88,6 +109,10 @@ class ObjectManager:
             self.dict[stage] = [stage_idx]
             for itervar_idx, itervar in enumerate(stage.all_iter_vars):
                 self.dict[itervar] = [stage_idx, itervar_idx]
+                # leaf itervars should also be mapped to their original itervar
+                for leaf_iv in stage.leaf_iter_vars:
+                    if itervar_equal(leaf_iv, itervar):
+                        self.dict[leaf_iv] = [stage_idx, itervar_idx]
             for rel_idx, rel in enumerate(stage.relations):
                 self.dict[rel] = [stage_idx, rel_idx]
             for tensor_idx in range(stage.op.num_outputs):
@@ -289,7 +314,7 @@ def dump_json(sch, need_range):
 
         def get_leaf_itervar_index(itervar, leaf_iv):
             for leaf_index, ivar in enumerate(leaf_iv):
-                if ivar == itervar:
+                if itervar_equal(ivar, itervar):
                     return leaf_index
             return -1
 
diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc
index e20e0c9429..ce47be361e 100644
--- a/src/relay/backend/te_compiler.cc
+++ b/src/relay/backend/te_compiler.cc
@@ -700,7 +700,8 @@ class LowerTensorExprMutator : public 
DeviceAwareExprMutator {
    */
   Expr MakeLoweredCall(const BaseFunc& original_function, const GlobalVar& 
prim_fn_var,
                        Array<Expr> args, Span span, const Target& target,
-                       const Map<GlobalVar, BaseFunc>& lowered_functions) {
+                       const Map<GlobalVar, BaseFunc>& lowered_functions,
+                       const te::Schedule& sch = {}) {
     auto opt_compiler = original_function->GetAttr<String>(attr::kCompiler);
 
     // Add some metadata on top of the *original function* and invoke the 
callback so it can
@@ -730,6 +731,10 @@ class LowerTensorExprMutator : public 
DeviceAwareExprMutator {
       func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", 
prim_fn_var);
       func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", 
prim_fns);
       func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, 
target);
+      // Store the generated schedule of the operator
+      if (sch.defined() && sch->keep_schedule_record) {
+        func_with_metadata = WithAttr(func_with_metadata, "schedule", sch);
+      }
       this->process_fn_(func_with_metadata);
     } else {
       const auto* function_node = original_function.as<FunctionNode>();
@@ -738,6 +743,10 @@ class LowerTensorExprMutator : public 
DeviceAwareExprMutator {
       func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", 
prim_fn_var);
       func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", 
prim_fns);
       func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, 
target);
+      // Store the generated schedule of the operator
+      if (sch.defined() && sch->keep_schedule_record) {
+        func_with_metadata = WithAttr(func_with_metadata, "schedule", sch);
+      }
       this->process_fn_(func_with_metadata);
     }
 
@@ -926,7 +935,7 @@ class LowerTensorExprMutator : public 
DeviceAwareExprMutator {
       CachedFunc cfunc = compiler_->Lower(key);
       ICHECK(cfunc.defined());
       return MakeLoweredCall(primitive_func, cfunc->prim_fn_var, 
std::move(new_args),
-                             call_node->span, target, cfunc->funcs->functions);
+                             call_node->span, target, cfunc->funcs->functions, 
cfunc->schedule);
     }
   }
 
diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc 
b/src/te/schedule/schedule_dataflow_rewrite.cc
index c1741e9e4e..c38c5a5c80 100644
--- a/src/te/schedule/schedule_dataflow_rewrite.cc
+++ b/src/te/schedule/schedule_dataflow_rewrite.cc
@@ -174,10 +174,12 @@ Tensor Schedule::cache_read(const Tensor& tensor, const 
std::string& scope,
   Array<Stage>& stages = (*this)->stages;
   Stage op_stage = operator[](tensor->op);
   size_t pos = FindNodeRef(stages.GetArrayNode(), op_stage);
-  Stage cache_stage = Stage(cache->op);
-  cache_stage.set_scope(scope);
+  Stage cache_stage = Stage(cache->op, this->operator->());
   ICHECK_LT(pos, stages.size());
   stages.insert(stages.begin() + pos + 1, cache_stage);
+  // To record a correct snapshot in schedule_record, apply the
+  // "set_scope" primitive only after the stage has been added.
+  cache_stage.set_scope(scope);
   (*this)->stage_map.Set(cache->op, cache_stage);
   // Update group
   cache_stage->group = op_stage->group;
@@ -266,10 +268,12 @@ Array<Tensor> ReplaceOriginalOp(Schedule sch, Stage 
orig_stage, const std::strin
   // create schedule for new cached stage.
   Array<Stage>& stages = sch->stages;
   size_t pos = FindNodeRef(stages.GetArrayNode(), orig_stage);
-  Stage cache_stage = Stage(cache_op);
-  cache_stage.set_scope(scope);
+  Stage cache_stage = Stage(cache_op, sch.operator->());
   ICHECK_LT(pos, stages.size());
   stages.insert(stages.begin() + pos, cache_stage);
+  // To record a correct snapshot in schedule_record, apply the
+  // "set_scope" primitive only after the stage has been added.
+  cache_stage.set_scope(scope);
   sch->stage_map.Set(cache_op, cache_stage);
   // Update group
   cache_stage->group = orig_stage->group;
@@ -892,7 +896,7 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor, const 
IterVar& axis, int f
   Operation factor_op(n);
   Array<Stage>& stages = (*this)->stages;
   size_t stage_pos = FindNodeRef(stages.GetArrayNode(), reduce_stage);
-  Stage factor_stage = Stage(factor_op);
+  Stage factor_stage = Stage(factor_op, this->operator->());
   factor_stage->relations = rels;
   ICHECK_LT(stage_pos, stages.size());
   stages.insert(stages.begin() + stage_pos, factor_stage);
diff --git a/src/te/schedule/schedule_lang.cc b/src/te/schedule/schedule_lang.cc
index e8f4f65eb6..56fe0cfc65 100644
--- a/src/te/schedule/schedule_lang.cc
+++ b/src/te/schedule/schedule_lang.cc
@@ -21,6 +21,7 @@
  * \file schedule_lang.cc
  */
 #include <dmlc/thread_local.h>
+#include <tvm/ir/transform.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/te/operation.h>
 #include <tvm/te/schedule.h>
@@ -91,7 +92,7 @@ void SplitHelper(StageNode* self, IterVar parent, PrimExpr 
factor, PrimExpr npar
   leaf_vars.insert(leaf_vars.begin() + pos, outer);
 }
 
-Stage::Stage(Operation op) {
+Stage::Stage(Operation op, const ScheduleNode* sch) {
   auto n = make_object<StageNode>();
   n->op = op;
   n->origin_op = op;
@@ -106,6 +107,7 @@ Stage::Stage(Operation op) {
   } else {
     n->leaf_iter_vars = clean;
   }
+  n->attach_sch = sch;
   data_ = std::move(n);
 }
 
@@ -124,11 +126,13 @@ Stage Stage::GetAttachSpec() const {
 }
 
 Stage& Stage::set_scope(std::string scope) {  // NOLINT(*)
+  With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
   (*this)->scope = scope;
   return *this;
 }
 
 Stage& Stage::compute_at(Stage parent, IterVar scope) {  // NOLINT(*)
+  With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
   ICHECK_NE((*this)->attach_type, kScanUpdate) << "Cannot specify compute_at 
for scan updates";
   // Group constraint checking.
   Stage group = (*this)->group;
@@ -156,18 +160,21 @@ Stage& Stage::compute_at(Stage parent, IterVar scope) {  
// NOLINT(*)
 }
 
 Stage& Stage::compute_inline() {  // NOLINT(*)
+  With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
   ICHECK_NE((*this)->attach_type, kScanUpdate) << "Cannot specify compute_at 
for scan updates";
   (*this)->attach_type = kInline;
   return *this;
 }
 
 Stage& Stage::compute_root() {  // NOLINT(*)
+  With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
   ICHECK_NE((*this)->attach_type, kScanUpdate) << "Cannot specify compute_at 
for scan updates";
   (*this)->attach_type = kGroupRoot;
   return *this;
 }
 
 Stage& Stage::bind(IterVar ivar, IterVar thread_ivar) {  // NOLINT(*)
+  With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
   StageNode* self = operator->();
   ICHECK(ivar->iter_type == kDataPar || ivar->iter_type == kCommReduce)
       << "Cannot bind " << IterVarType2String(ivar->iter_type) << " to thread";
@@ -194,6 +201,7 @@ Stage& Stage::bind(IterVar ivar, IterVar thread_ivar) {  // 
NOLINT(*)
 }
 
 Stage& Stage::env_threads(Array<IterVar> threads) {
+  With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
   StageNode* self = operator->();
   ICHECK(self->op.defined() && self->op.as<ScanOpNode>())
       << "env_threads is only valid for composite ops such as ScanOp";
@@ -211,6 +219,7 @@ Stage& Stage::env_threads(Array<IterVar> threads) {
 }
 
 Stage& Stage::set_store_predicate(PrimExpr predicate) {
+  With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
   StageNode* self = operator->();
   self->store_predicate = predicate;
   return *this;
@@ -218,17 +227,20 @@ Stage& Stage::set_store_predicate(PrimExpr predicate) {
 
 Stage& Stage::split(IterVar parent, PrimExpr factor, IterVar* p_outer,
                     IterVar* p_inner) {  // NOLINT(*)
+  With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
   SplitHelper(operator->(), parent, factor, PrimExpr(), p_outer, p_inner);
   return *this;
 }
 
 Stage& Stage::split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* 
p_outer,
                               IterVar* p_inner) {  // NOLINT(*)
+  With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
   SplitHelper(operator->(), parent, PrimExpr(), nparts, p_outer, p_inner);
   return *this;
 }
 
 Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) {  // 
NOLINT(*)
+  With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
   StageNode* self = operator->();
   ICHECK(outer->iter_type == kDataPar || outer->iter_type == kCommReduce ||
          outer->iter_type == kOrdered)
@@ -264,6 +276,7 @@ Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* 
p_target) {  // NOLINT
 }
 
 Stage& Stage::fuse(const Array<IterVar>& axes, IterVar* p_target) {  // 
NOLINT(*)
+  With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
   if (axes.size() != 0) {
     IterVar fused = axes[0];
     for (size_t i = 1; i < axes.size(); ++i) {
@@ -287,6 +300,7 @@ Stage& Stage::fuse(const Array<IterVar>& axes, IterVar* 
p_target) {  // NOLINT(*
 }
 
 Stage& Stage::reorder(const Array<IterVar>& order) {  // NOLINT(*)
+  With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
   std::unordered_set<IterVar> seen_var;
   StageNode* self = operator->();
   for (IterVar iv : order) {
@@ -347,6 +361,7 @@ inline void SetAttrIterType(StageNode* self, IterVar var, 
IterVarType iter_type)
 }
 
 Stage& Stage::vectorize(IterVar var) {  // NOLINT(*)
+  With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
   ICHECK(var->iter_type == kDataPar || var->iter_type == kOpaque || 
var->iter_type == kUnrolled ||
          var->iter_type == kVectorized || var->iter_type == kTensorized ||
          var->iter_type == kParallelized)
@@ -356,6 +371,7 @@ Stage& Stage::vectorize(IterVar var) {  // NOLINT(*)
 }
 
 Stage& Stage::tensorize(IterVar var, TensorIntrin f) {  // NOLINT(*)
+  With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
   UpdateIterVarAttr(operator->(), var, [f](IterVarAttrNode* n) {
     n->iter_type = kTensorized;
     n->tensor_intrin = f;
@@ -364,11 +380,13 @@ Stage& Stage::tensorize(IterVar var, TensorIntrin f) {  
// NOLINT(*)
 }
 
 Stage& Stage::unroll(IterVar var) {  // NOLINT(*)
+  With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
   SetAttrIterType(operator->(), var, kUnrolled);
   return *this;
 }
 
 Stage& Stage::parallel(IterVar var) {  // NOLINT(*)
+  With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
   SetAttrIterType(operator->(), var, kParallelized);
   return *this;
 }
@@ -380,6 +398,7 @@ Stage& Stage::pragma(IterVar var, const std::string& 
pragma_type,
   } else if (pragma_type == "vectorize") {
     this->vectorize(var);
   } else {
+    With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
     UpdateIterVarAttr(operator->(), var, [pragma_type, 
pragma_value](IterVarAttrNode* n) {
       n->pragma_keys.push_back(tir::StringImm(pragma_type));
       n->pragma_values.push_back(pragma_value);
@@ -389,6 +408,7 @@ Stage& Stage::pragma(IterVar var, const std::string& 
pragma_type,
 }
 
 Stage& Stage::prefetch(const Tensor& tensor, IterVar var, PrimExpr offset) {
+  With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
   StageNode* self = operator->();
   ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
   ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
@@ -407,6 +427,7 @@ Stage& Stage::prefetch(const Tensor& tensor, IterVar var, 
PrimExpr offset) {
 }
 
 Stage& Stage::storage_align(IterVar axis, int factor, int offset) {
+  With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
   StageNode* self = operator->();
   UpdateIterVarAttr(
       self, axis,
@@ -419,6 +440,7 @@ Stage& Stage::storage_align(IterVar axis, int factor, int 
offset) {
 }
 
 Stage& Stage::double_buffer() {
+  With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
   StageNode* self = operator->();
   ICHECK(!self->is_output) << "Cannot apply double buffer on output";
   self->double_buffer = true;
@@ -426,6 +448,7 @@ Stage& Stage::double_buffer() {
 }
 
 Stage& Stage::rolling_buffer() {
+  With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
   StageNode* self = operator->();
   ICHECK(!self->is_output) << "Cannot apply rolling buffer on output";
   self->rolling_buffer = true;
@@ -434,6 +457,7 @@ Stage& Stage::rolling_buffer() {
 Stage& Stage::transform_layout(const Array<Var>& initial_indices,
                                const Array<PrimExpr>& final_indices,
                                Array<IterVar>* out_iter_vars) {
+  With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
   StageNode* self = operator->();
   IndexMap map(initial_indices, final_indices);
   self->layout_transforms.push_back(map);
@@ -501,6 +525,7 @@ Stage& Stage::transform_layout(const Array<Var>& 
initial_indices,
 }
 
 Stage& Stage::set_axis_separators(const Array<IntImm>& axis_separators) {
+  With<ScheduleContext> ctx(operator->()->attach_sch, __func__);
   StageNode* self = operator->();
   self->axis_separators = axis_separators;
   return *this;
@@ -630,6 +655,7 @@ Stage Schedule::create_group(const Array<Tensor>& outputs, 
const Array<Tensor>&
   }
   // Create the new group stage.
   Stage gstage(make_object<StageNode>());
+  gstage->attach_sch = this->operator->();
   gstage->group = parent_group;
   if (parent_group.defined()) {
     ++parent_group->num_child_stages;
@@ -718,6 +744,8 @@ bool ScheduleNode::Contain(const Operation& op) const {
   return stage_map.find(op) != stage_map.end();
 }
 
+TVM_REGISTER_PASS_CONFIG_OPTION("te.keep_schedule_record", Bool);
+
 Schedule::Schedule(Array<Operation> ops) {
   auto n = make_object<ScheduleNode>();
   data_ = n;
@@ -730,7 +758,7 @@ Schedule::Schedule(Array<Operation> ops) {
     output_set.insert(x);
   }
   for (Operation op : post_order) {
-    Stage stage(op);
+    Stage stage(op, this->operator->());
     stage->is_output = output_set.count(op) != 0;
     n->stages.push_back(stage);
     n->stage_map.Set(op, stage);
@@ -754,6 +782,25 @@ Schedule::Schedule(Array<Operation> ops) {
       }
     }
   }
+  transform::PassContext pass_ctx = transform::PassContext::Current();
+  n->keep_schedule_record = 
pass_ctx->GetConfig<Bool>("te.keep_schedule_record", Bool(false));
+  if (n->keep_schedule_record.value()) {
+    // push plain schedule as the very first one
+    n->schedule_record.push_back(copy());
+    n->primitive_record.push_back("vanilla");
+  }
+}
+
+ScheduleContext::ScheduleContext(const ScheduleNode* sch_node, String 
current_primitive_name)
+    : sch_(GetRef<Schedule>(sch_node)), 
current_primitive_name_(current_primitive_name) {}
+
+void ScheduleContext::EnterWithScope() {}
+
+void ScheduleContext::ExitWithScope() {
+  if (sch_.defined() && sch_->keep_schedule_record.value()) {
+    sch_->schedule_record.push_back(sch_.copy());
+    sch_->primitive_record.push_back(current_primitive_name_);
+  }
 }
 
 Split::Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, 
PrimExpr nparts) {
diff --git a/tests/python/contrib/test_tedd.py 
b/tests/python/contrib/test_tedd.py
index 373fb14d36..c1af9f6825 100644
--- a/tests/python/contrib/test_tedd.py
+++ b/tests/python/contrib/test_tedd.py
@@ -16,8 +16,12 @@
 # under the License.
 import re
 
+import tvm
 from tvm import te
 from tvm import topi
+from tvm import relay
+from tvm.relay import testing
+from tvm.relay.backend import Runtime, Executor
 
 
 def findany(pattern, str):
@@ -79,8 +83,8 @@ def test_itervar_relationship_graph():
         findany(r"subgraph cluster_Stage_0", str)
         findany(r"subgraph cluster_Stage_1", str)
         # Check itervars and their types
-        findany(r"\(kDataPar\)\<br/\>range\(min=0, ext=n\)", str)
-        findany(r"\(kCommReduce\)\<br/\>range\(min=0, ext=m\)", str)
+        findany(r"\(kDataPar\)\<br/\>T.Range\(0, n\)", str)
+        findany(r"\(kCommReduce\)\<br/\>T.Range\(0, m\)", str)
         # Check the split node
         findany(r"Split_Relation_1_0 +.+\>Split", str)
         # Check all edges to/from the split node
@@ -144,7 +148,57 @@ def test_schedule_tree():
         verify()
 
 
+@tvm.testing.requires_llvm
+def test_tedd_with_schedule_record():
+    """Test to build a nn model and check if all schedules could be 
generated"""
+
+    def check_schedule(executor):
+        from tvm.contrib import tedd
+
+        error = {}
+        for func_name, func_meta in executor.function_metadata.items():
+            # check converted op only
+            if "main" not in func_name:
+                primfunc = list(func_meta.relay_primfuncs.values())[0]
+                schs = primfunc.attrs["schedule"].schedule_record
+                for index in range(len(schs)):
+                    try:
+                        sch = schs[index].normalize()
+                        tedd.viz_dataflow_graph(sch, False, "", True)
+                        tedd.viz_itervar_relationship_graph(sch, False, "", 
True)
+                        tedd.viz_schedule_tree(sch, False, "", True)
+                    except:
+                        if func_name not in error:
+                            error[func_name] = []
+                        error[func_name].append(index)
+
+        assert error == {}, str(error)
+
+    if checkdependency():
+        relay_mod, params = testing.mobilenet.get_workload(batch_size=1, 
dtype="float32")
+        target_llvm = tvm.target.Target("llvm")
+        config = {"te.keep_schedule_record": True}
+
+        with tvm.transform.PassContext(opt_level=3, config=config):
+            aot_executor_factory = relay.build(
+                relay_mod,
+                target_llvm,
+                runtime=Runtime("cpp"),
+                executor=Executor("aot"),
+                params=params,
+            )
+            graph_executor_factory = relay.build(
+                relay_mod,
+                target_llvm,
+                params=params,
+            )
+
+        check_schedule(aot_executor_factory)
+        check_schedule(graph_executor_factory)
+
+
 if __name__ == "__main__":
     test_dfg()
     test_itervar_relationship_graph()
     test_schedule_tree()
+    test_tedd_with_schedule_record()
diff --git a/tests/python/relay/test_build_module.py 
b/tests/python/relay/test_build_module.py
index 5cfc27330a..b1146743ee 100644
--- a/tests/python/relay/test_build_module.py
+++ b/tests/python/relay/test_build_module.py
@@ -21,6 +21,7 @@ import tvm
 import tvm.testing
 from tvm import relay
 from tvm.target.target import Target
+from tvm.relay import testing
 from tvm.relay.backend import Runtime, Executor, graph_executor_codegen
 
 
@@ -62,5 +63,41 @@ def test_build_relay_graph_():
     build_graph(add((1, 8), "float32"), tvm.target.Target("llvm"))
 
 
+@tvm.testing.requires_llvm
+def test_schedule_record():
+    """Test to build a nn model and get schedule_record from build_module"""
+
+    def check_schedule(executor):
+        for func_name, func_meta in executor.function_metadata.items():
+            # check converted op only
+            if "main" not in func_name:
+                primfunc = list(func_meta.relay_primfuncs.values())[0]
+                # make sure schedule is well-stored in function metadata
+                assert "schedule" in primfunc.attrs
+                sch = primfunc.attrs["schedule"]
+                assert len(sch.schedule_record) == len(sch.primitive_record)
+
+    relay_mod, params = testing.mobilenet.get_workload(batch_size=1, 
dtype="float32")
+    target_llvm = tvm.target.Target("llvm")
+    config = {"te.keep_schedule_record": True}
+
+    with tvm.transform.PassContext(opt_level=3, config=config):
+        aot_executor_factory = relay.build(
+            relay_mod,
+            target_llvm,
+            runtime=Runtime("cpp"),
+            executor=Executor("aot"),
+            params=params,
+        )
+        graph_executor_factory = relay.build(
+            relay_mod,
+            target_llvm,
+            params=params,
+        )
+
+    check_schedule(aot_executor_factory)
+    check_schedule(graph_executor_factory)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/unittest/test_te_schedule_ops.py 
b/tests/python/unittest/test_te_schedule_ops.py
index f85cdc6196..1ff0297539 100644
--- a/tests/python/unittest/test_te_schedule_ops.py
+++ b/tests/python/unittest/test_te_schedule_ops.py
@@ -614,6 +614,57 @@ def test_local_stage_predicate2():
     assert any(collect_visit(lowered_body, visit_stmt))
 
 
+def test_schedule_record_gemm():
+    with tvm.transform.PassContext(config={"te.keep_schedule_record": True}):
+        M, K, N = 1024, 1024, 1024
+        k = te.reduce_axis((0, K), "k")
+        A = te.placeholder((M, K), name="A")
+        B = te.placeholder((K, N), name="B")
+        C = te.compute((M, N), lambda m, n: te.sum(A[m, k] * B[k, n], axis=k), 
name="C")
+        s = te.create_schedule(C.op)
+        # currently there are no other applied primitives
+        # size of schedule record is expected to be 1 (vanilla schedule)
+        assert len(s.schedule_record) == 1
+        # apply a sequence of optimization primitives
+        block_size, factor = 32, 8
+        # tile -> split + split + reorder
+        mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], block_size, 
block_size)
+        ko, ki = s[C].split(k, factor=factor)
+        s[C].reorder(mo, ko, no, mi, ki, ni)
+        s[C].vectorize(ni)
+        s[C].parallel(mo)
+        assert len(s.schedule_record) == 8
+        # compare primitive names
+        expected_names = [
+            "vanilla",
+            "split",
+            "split",
+            "reorder",
+            "split",
+            "reorder",
+            "vectorize",
+            "parallel",
+        ]
+        for i in range(len(s.schedule_record)):
+            assert s.primitive_record[i] == expected_names[i]
+
+
+def test_schedule_record_misc():
+    s = te.create_schedule([])
+    # size of schedule record is expected to be 0 (recording is disabled)
+    assert len(s.schedule_record) == 0
+
+    with tvm.transform.PassContext(config={"te.keep_schedule_record": True}):
+        s = te.create_schedule([])
+        # size of schedule record is expected to be 1 (vanilla schedule)
+        assert len(s.schedule_record) == 1
+
+        stg = te.compute((), lambda *args: 0, name="empty_op")
+        s = te.create_schedule(stg.op)
+        # size of schedule record is expected to be 1 (vanilla schedule)
+        assert len(s.schedule_record) == 1
+
+
 if __name__ == "__main__":
     test_loop_dep_reduce()
     test_loop_dep_reduce_cache_write()
@@ -640,3 +691,5 @@ if __name__ == "__main__":
     test_schedule_compute_inline()
     test_local_stage_predicate()
     test_local_stage_predicate2()
+    test_schedule_record_gemm()
+    test_schedule_record_misc()

Reply via email to