This is an automated email from the ASF dual-hosted git repository. masahi pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new ce1fa8908f [TE] Record primitives of Schedule for visualization (#14168) ce1fa8908f is described below commit ce1fa8908f626e58f245966dd0a2e2540b75dace Author: Chun-I Tsai <quic_chu...@quicinc.com> AuthorDate: Wed Mar 15 04:52:11 2023 +0800 [TE] Record primitives of Schedule for visualization (#14168) * [ScheduleVisualization] - Make Stage link to its Schedule via attach_sch. - Add two array attributes, primitive_record and schedule_record to Schedule. - Create a new class, ScheduleContext, to record primitives. - Register a pass config variable, keep_schedule_record to enable/disable the recording. - Add test cases for TEDD, build_module and schedule ops. * [ScheduleVisualization] * Fix grammar issues * Rewrite unclear comments * [ScheduleVisualization] * Remove the wrong term, rebased in comments and variables. --------- Co-authored-by: Joey Tsai <chu...@qti.qualcomm.com> --- include/tvm/te/schedule.h | 43 +++++++++++++++++++- python/tvm/contrib/tedd.py | 27 ++++++++++++- src/relay/backend/te_compiler.cc | 13 +++++- src/te/schedule/schedule_dataflow_rewrite.cc | 14 ++++--- src/te/schedule/schedule_lang.cc | 51 ++++++++++++++++++++++- tests/python/contrib/test_tedd.py | 58 ++++++++++++++++++++++++++- tests/python/relay/test_build_module.py | 37 +++++++++++++++++ tests/python/unittest/test_te_schedule_ops.py | 53 ++++++++++++++++++++++++ 8 files changed, 283 insertions(+), 13 deletions(-) diff --git a/include/tvm/te/schedule.h b/include/tvm/te/schedule.h index 5d88793206..1b711a8370 100644 --- a/include/tvm/te/schedule.h +++ b/include/tvm/te/schedule.h @@ -62,8 +62,9 @@ class Stage : public ObjectRef { /*! * \brief create a new schedule for op. * \param op The operator in the schedule + * \param sch The schedule which current stage belongs to */ - explicit Stage(Operation op); + explicit Stage(Operation op, const ScheduleNode* sch); /*! * \brief access the internal node container * \return the pointer to the internal node container @@ -445,6 +446,26 @@ class Schedule : public ObjectRef { using ContainerType = ScheduleNode; }; +/*! + * \brief Context helper to collect debug information of Schedule. + * + * Attach With<ScheduleContext>(schedule_instance, primitive_name) + * inside function body of schedule primitives to collect the + * snapshot of schedule status and corresponding primitive name + */ +class ScheduleContext { + private: + friend class With<ScheduleContext>; + ScheduleContext(const ScheduleNode* sch_node, String current_primitive_name); + void EnterWithScope(); + void ExitWithScope(); + + /*! \brief Schedule instance to store information for debug */ + Schedule sch_; + /*! \brief String representing which primitive has been applied to sch_ */ + String current_primitive_name_; +}; + /*! * \brief The schedule relation between IterVars * can be Split, Fuse. @@ -546,6 +567,8 @@ class StageNode : public Object { IterVar attach_ivar; /*! \brief The stage this node attaches to */ Stage attach_stage; + /*! \brief The schedule current stage is attached to */ + const ScheduleNode* attach_sch; /*! \brief The thread storage scope level of the stage */ std::string scope; /*! \brief Whether this is an output stage */ @@ -615,12 +638,30 @@ class ScheduleNode : public Object { * This is created on demand and can be invalidated. */ std::unordered_map<const Object*, Stage> op2stage_cache_; + /*! + * \brief list of all transformed schedules + * User can display the optimization strategy via TEDD step by step to check + * the order and effect of primitives. Set "te.keep_schedule_record" in + * PassContext config as true to enable recording. + */ + Array<Schedule> schedule_record; + /*! + * \brief list of all applied primitive names. + */ + Array<String> primitive_record; + /*! + * \brief Flag to keep schedule record or not. + */ + Optional<Bool> keep_schedule_record; void VisitAttrs(AttrVisitor* v) { v->Visit("outputs", &outputs); v->Visit("stages", &stages); v->Visit("groups", &groups); v->Visit("stage_map", &stage_map); + v->Visit("schedule_record", &schedule_record); + v->Visit("primitive_record", &primitive_record); + v->Visit("keep_schedule_record", &keep_schedule_record); } /*! \brief Initialize temp cache. */ diff --git a/python/tvm/contrib/tedd.py b/python/tvm/contrib/tedd.py index a65f5e474a..aa423d8964 100644 --- a/python/tvm/contrib/tedd.py +++ b/python/tvm/contrib/tedd.py @@ -78,6 +78,27 @@ def insert_dot_id(sch): return sch +def itervar_equal(iv_a, iv_b): + """A helper method that compares the equality of two iterative variables""" + # Adopt the following method to assure the equality between two itervars. + # The plain comparison might fail (i.e. iv_a == iv_b) after the change of + # domain bounds from InferBound. + def _var_equal(v_a, v_b): + condtions = [ + v_a.name == v_b.name, + v_a.dtype == v_b.dtype, + v_a.type_annotation == v_b.type_annotation, + ] + return all(c for c in condtions) + + condtions = [ + _var_equal(iv_a.var, iv_b.var), + iv_a.iter_type == iv_b.iter_type, + iv_a.thread_tag == iv_b.thread_tag, + ] + return all(c for c in condtions) + + class ObjectManager: """A helper class tracking schedule objects, e.g. stage, IterVar, relationship, and tensor, to their DOM path.""" @@ -88,6 +109,10 @@ class ObjectManager: self.dict[stage] = [stage_idx] for itervar_idx, itervar in enumerate(stage.all_iter_vars): self.dict[itervar] = [stage_idx, itervar_idx] + # the itervars of leaf should also be mapped to the original one + for leaf_iv in stage.leaf_iter_vars: + if itervar_equal(leaf_iv, itervar): + self.dict[leaf_iv] = [stage_idx, itervar_idx] for rel_idx, rel in enumerate(stage.relations): self.dict[rel] = [stage_idx, rel_idx] for tensor_idx in range(stage.op.num_outputs): @@ -289,7 +314,7 @@ def dump_json(sch, need_range): def get_leaf_itervar_index(itervar, leaf_iv): for leaf_index, ivar in enumerate(leaf_iv): - if ivar == itervar: + if itervar_equal(ivar, itervar): return leaf_index return -1 diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index e20e0c9429..ce47be361e 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -700,7 +700,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { */ Expr MakeLoweredCall(const BaseFunc& original_function, const GlobalVar& prim_fn_var, Array<Expr> args, Span span, const Target& target, - const Map<GlobalVar, BaseFunc>& lowered_functions) { + const Map<GlobalVar, BaseFunc>& lowered_functions, + const te::Schedule& sch = {}) { auto opt_compiler = original_function->GetAttr<String>(attr::kCompiler); // Add some metadata on top of the *original function* and invoke the callback so it can @@ -730,6 +731,10 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", prim_fn_var); func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns); func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, target); + // Store generated Schedules of operator + if (sch.defined() && sch->keep_schedule_record) { + func_with_metadata = WithAttr(func_with_metadata, "schedule", sch); + } this->process_fn_(func_with_metadata); } else { const auto* function_node = original_function.as<FunctionNode>(); @@ -738,6 +743,10 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { func_with_metadata = WithAttr(func_with_metadata, "prim_fn_var", prim_fn_var); func_with_metadata = WithAttr(func_with_metadata, "prim_funcs", prim_fns); func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, target); + // Store generated Schedules of operator + if (sch.defined() && sch->keep_schedule_record) { + func_with_metadata = WithAttr(func_with_metadata, "schedule", sch); + } this->process_fn_(func_with_metadata); } @@ -926,7 +935,7 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { CachedFunc cfunc = compiler_->Lower(key); ICHECK(cfunc.defined()); return MakeLoweredCall(primitive_func, cfunc->prim_fn_var, std::move(new_args), - call_node->span, target, cfunc->funcs->functions); + call_node->span, target, cfunc->funcs->functions, cfunc->schedule); } } diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc index c1741e9e4e..c38c5a5c80 100644 --- a/src/te/schedule/schedule_dataflow_rewrite.cc +++ b/src/te/schedule/schedule_dataflow_rewrite.cc @@ -174,10 +174,12 @@ Tensor Schedule::cache_read(const Tensor& tensor, const std::string& scope, Array<Stage>& stages = (*this)->stages; Stage op_stage = operator[](tensor->op); size_t pos = FindNodeRef(stages.GetArrayNode(), op_stage); - Stage cache_stage = Stage(cache->op); - cache_stage.set_scope(scope); + Stage cache_stage = Stage(cache->op, this->operator->()); ICHECK_LT(pos, stages.size()); stages.insert(stages.begin() + pos + 1, cache_stage); + // in order to obtain correct copy on schedule_record, + // make sure "set_scope" primitive is applied after stage being added + cache_stage.set_scope(scope); (*this)->stage_map.Set(cache->op, cache_stage); // Update group cache_stage->group = op_stage->group; @@ -266,10 +268,12 @@ Array<Tensor> ReplaceOriginalOp(Schedule sch, Stage orig_stage, const std::strin // create schedule for new cached stage. Array<Stage>& stages = sch->stages; size_t pos = FindNodeRef(stages.GetArrayNode(), orig_stage); - Stage cache_stage = Stage(cache_op); - cache_stage.set_scope(scope); + Stage cache_stage = Stage(cache_op, sch.operator->()); ICHECK_LT(pos, stages.size()); stages.insert(stages.begin() + pos, cache_stage); + // in order to obtain correct copy on schedule_record, + // make sure "set_scope" primitive is applied after stage being added + cache_stage.set_scope(scope); sch->stage_map.Set(cache_op, cache_stage); // Update group cache_stage->group = orig_stage->group; @@ -892,7 +896,7 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor, const IterVar& axis, int f Operation factor_op(n); Array<Stage>& stages = (*this)->stages; size_t stage_pos = FindNodeRef(stages.GetArrayNode(), reduce_stage); - Stage factor_stage = Stage(factor_op); + Stage factor_stage = Stage(factor_op, this->operator->()); factor_stage->relations = rels; ICHECK_LT(stage_pos, stages.size()); stages.insert(stages.begin() + stage_pos, factor_stage); diff --git a/src/te/schedule/schedule_lang.cc b/src/te/schedule/schedule_lang.cc index e8f4f65eb6..56fe0cfc65 100644 --- a/src/te/schedule/schedule_lang.cc +++ b/src/te/schedule/schedule_lang.cc @@ -21,6 +21,7 @@ * \file schedule_lang.cc */ #include <dmlc/thread_local.h> +#include <tvm/ir/transform.h> #include <tvm/runtime/registry.h> #include <tvm/te/operation.h> #include <tvm/te/schedule.h> @@ -91,7 +92,7 @@ void SplitHelper(StageNode* self, IterVar parent, PrimExpr factor, PrimExpr npar leaf_vars.insert(leaf_vars.begin() + pos, outer); } -Stage::Stage(Operation op) { +Stage::Stage(Operation op, const ScheduleNode* sch) { auto n = make_object<StageNode>(); n->op = op; n->origin_op = op; @@ -106,6 +107,7 @@ Stage::Stage(Operation op) { } else { n->leaf_iter_vars = clean; } + n->attach_sch = sch; data_ = std::move(n); } @@ -124,11 +126,13 @@ Stage Stage::GetAttachSpec() const { } Stage& Stage::set_scope(std::string scope) { // NOLINT(*) + With<ScheduleContext> ctx(operator->()->attach_sch, __func__); (*this)->scope = scope; return *this; } Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*) + With<ScheduleContext> ctx(operator->()->attach_sch, __func__); ICHECK_NE((*this)->attach_type, kScanUpdate) << "Cannot specify compute_at for scan updates"; // Group constraint checking. Stage group = (*this)->group; @@ -156,18 +160,21 @@ Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*) } Stage& Stage::compute_inline() { // NOLINT(*) + With<ScheduleContext> ctx(operator->()->attach_sch, __func__); ICHECK_NE((*this)->attach_type, kScanUpdate) << "Cannot specify compute_at for scan updates"; (*this)->attach_type = kInline; return *this; } Stage& Stage::compute_root() { // NOLINT(*) + With<ScheduleContext> ctx(operator->()->attach_sch, __func__); ICHECK_NE((*this)->attach_type, kScanUpdate) << "Cannot specify compute_at for scan updates"; (*this)->attach_type = kGroupRoot; return *this; } Stage& Stage::bind(IterVar ivar, IterVar thread_ivar) { // NOLINT(*) + With<ScheduleContext> ctx(operator->()->attach_sch, __func__); StageNode* self = operator->(); ICHECK(ivar->iter_type == kDataPar || ivar->iter_type == kCommReduce) << "Cannot bind " << IterVarType2String(ivar->iter_type) << " to thread"; @@ -194,6 +201,7 @@ Stage& Stage::bind(IterVar ivar, IterVar thread_ivar) { // NOLINT(*) } Stage& Stage::env_threads(Array<IterVar> threads) { + With<ScheduleContext> ctx(operator->()->attach_sch, __func__); StageNode* self = operator->(); ICHECK(self->op.defined() && self->op.as<ScanOpNode>()) << "env_threads is only valid for composite ops such as ScanOp"; @@ -211,6 +219,7 @@ Stage& Stage::env_threads(Array<IterVar> threads) { } Stage& Stage::set_store_predicate(PrimExpr predicate) { + With<ScheduleContext> ctx(operator->()->attach_sch, __func__); StageNode* self = operator->(); self->store_predicate = predicate; return *this; @@ -218,17 +227,20 @@ Stage& Stage::set_store_predicate(PrimExpr predicate) { Stage& Stage::split(IterVar parent, PrimExpr factor, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*) + With<ScheduleContext> ctx(operator->()->attach_sch, __func__); SplitHelper(operator->(), parent, factor, PrimExpr(), p_outer, p_inner); return *this; } Stage& Stage::split_by_nparts(IterVar parent, PrimExpr nparts, IterVar* p_outer, IterVar* p_inner) { // NOLINT(*) + With<ScheduleContext> ctx(operator->()->attach_sch, __func__); SplitHelper(operator->(), parent, PrimExpr(), nparts, p_outer, p_inner); return *this; } Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) { // NOLINT(*) + With<ScheduleContext> ctx(operator->()->attach_sch, __func__); StageNode* self = operator->(); ICHECK(outer->iter_type == kDataPar || outer->iter_type == kCommReduce || outer->iter_type == kOrdered) @@ -264,6 +276,7 @@ Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) { // NOLINT } Stage& Stage::fuse(const Array<IterVar>& axes, IterVar* p_target) { // NOLINT(*) + With<ScheduleContext> ctx(operator->()->attach_sch, __func__); if (axes.size() != 0) { IterVar fused = axes[0]; for (size_t i = 1; i < axes.size(); ++i) { @@ -287,6 +300,7 @@ Stage& Stage::fuse(const Array<IterVar>& axes, IterVar* p_target) { // NOLINT(* } Stage& Stage::reorder(const Array<IterVar>& order) { // NOLINT(*) + With<ScheduleContext> ctx(operator->()->attach_sch, __func__); std::unordered_set<IterVar> seen_var; StageNode* self = operator->(); for (IterVar iv : order) { @@ -347,6 +361,7 @@ inline void SetAttrIterType(StageNode* self, IterVar var, IterVarType iter_type) } Stage& Stage::vectorize(IterVar var) { // NOLINT(*) + With<ScheduleContext> ctx(operator->()->attach_sch, __func__); ICHECK(var->iter_type == kDataPar || var->iter_type == kOpaque || var->iter_type == kUnrolled || var->iter_type == kVectorized || var->iter_type == kTensorized || var->iter_type == kParallelized) @@ -356,6 +371,7 @@ Stage& Stage::vectorize(IterVar var) { // NOLINT(*) } Stage& Stage::tensorize(IterVar var, TensorIntrin f) { // NOLINT(*) + With<ScheduleContext> ctx(operator->()->attach_sch, __func__); UpdateIterVarAttr(operator->(), var, [f](IterVarAttrNode* n) { n->iter_type = kTensorized; n->tensor_intrin = f; @@ -364,11 +380,13 @@ Stage& Stage::tensorize(IterVar var, TensorIntrin f) { // NOLINT(*) } Stage& Stage::unroll(IterVar var) { // NOLINT(*) + With<ScheduleContext> ctx(operator->()->attach_sch, __func__); SetAttrIterType(operator->(), var, kUnrolled); return *this; } Stage& Stage::parallel(IterVar var) { // NOLINT(*) + With<ScheduleContext> ctx(operator->()->attach_sch, __func__); SetAttrIterType(operator->(), var, kParallelized); return *this; } @@ -380,6 +398,7 @@ Stage& Stage::pragma(IterVar var, const std::string& pragma_type, } else if (pragma_type == "vectorize") { this->vectorize(var); } else { + With<ScheduleContext> ctx(operator->()->attach_sch, __func__); UpdateIterVarAttr(operator->(), var, [pragma_type, pragma_value](IterVarAttrNode* n) { n->pragma_keys.push_back(tir::StringImm(pragma_type)); n->pragma_values.push_back(pragma_value); @@ -389,6 +408,7 @@ Stage& Stage::pragma(IterVar var, const std::string& pragma_type, } Stage& Stage::prefetch(const Tensor& tensor, IterVar var, PrimExpr offset) { + With<ScheduleContext> ctx(operator->()->attach_sch, __func__); StageNode* self = operator->(); ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); @@ -407,6 +427,7 @@ Stage& Stage::prefetch(const Tensor& tensor, IterVar var, PrimExpr offset) { } Stage& Stage::storage_align(IterVar axis, int factor, int offset) { + With<ScheduleContext> ctx(operator->()->attach_sch, __func__); StageNode* self = operator->(); UpdateIterVarAttr( self, axis, @@ -419,6 +440,7 @@ Stage& Stage::storage_align(IterVar axis, int factor, int offset) { } Stage& Stage::double_buffer() { + With<ScheduleContext> ctx(operator->()->attach_sch, __func__); StageNode* self = operator->(); ICHECK(!self->is_output) << "Cannot apply double buffer on output"; self->double_buffer = true; @@ -426,6 +448,7 @@ Stage& Stage::double_buffer() { } Stage& Stage::rolling_buffer() { + With<ScheduleContext> ctx(operator->()->attach_sch, __func__); StageNode* self = operator->(); ICHECK(!self->is_output) << "Cannot apply rolling buffer on output"; self->rolling_buffer = true; @@ -434,6 +457,7 @@ Stage& Stage::rolling_buffer() { Stage& Stage::transform_layout(const Array<Var>& initial_indices, const Array<PrimExpr>& final_indices, Array<IterVar>* out_iter_vars) { + With<ScheduleContext> ctx(operator->()->attach_sch, __func__); StageNode* self = operator->(); IndexMap map(initial_indices, final_indices); self->layout_transforms.push_back(map); @@ -501,6 +525,7 @@ Stage& Stage::transform_layout(const Array<Var>& initial_indices, } Stage& Stage::set_axis_separators(const Array<IntImm>& axis_separators) { + With<ScheduleContext> ctx(operator->()->attach_sch, __func__); StageNode* self = operator->(); self->axis_separators = axis_separators; return *this; @@ -630,6 +655,7 @@ Stage Schedule::create_group(const Array<Tensor>& outputs, const Array<Tensor>& } // Create the new group stage. Stage gstage(make_object<StageNode>()); + gstage->attach_sch = this->operator->(); gstage->group = parent_group; if (parent_group.defined()) { ++parent_group->num_child_stages; @@ -718,6 +744,8 @@ bool ScheduleNode::Contain(const Operation& op) const { return stage_map.find(op) != stage_map.end(); } +TVM_REGISTER_PASS_CONFIG_OPTION("te.keep_schedule_record", Bool); + Schedule::Schedule(Array<Operation> ops) { auto n = make_object<ScheduleNode>(); data_ = n; @@ -730,7 +758,7 @@ Schedule::Schedule(Array<Operation> ops) { output_set.insert(x); } for (Operation op : post_order) { - Stage stage(op); + Stage stage(op, this->operator->()); stage->is_output = output_set.count(op) != 0; n->stages.push_back(stage); n->stage_map.Set(op, stage); @@ -754,6 +782,25 @@ Schedule::Schedule(Array<Operation> ops) { } } } + transform::PassContext pass_ctx = transform::PassContext::Current(); + n->keep_schedule_record = pass_ctx->GetConfig<Bool>("te.keep_schedule_record", Bool(false)); + if (n->keep_schedule_record.value()) { + // push plain schedule as the very first one + n->schedule_record.push_back(copy()); + n->primitive_record.push_back("vanilla"); + } +} + +ScheduleContext::ScheduleContext(const ScheduleNode* sch_node, String current_primitive_name) + : sch_(GetRef<Schedule>(sch_node)), current_primitive_name_(current_primitive_name) {} + +void ScheduleContext::EnterWithScope() {} + +void ScheduleContext::ExitWithScope() { + if (sch_.defined() && sch_->keep_schedule_record.value()) { + sch_->schedule_record.push_back(sch_.copy()); + sch_->primitive_record.push_back(current_primitive_name_); + } } Split::Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts) { diff --git a/tests/python/contrib/test_tedd.py b/tests/python/contrib/test_tedd.py index 373fb14d36..c1af9f6825 100644 --- a/tests/python/contrib/test_tedd.py +++ b/tests/python/contrib/test_tedd.py @@ -16,8 +16,12 @@ # under the License. import re +import tvm from tvm import te from tvm import topi +from tvm import relay +from tvm.relay import testing +from tvm.relay.backend import Runtime, Executor def findany(pattern, str): @@ -79,8 +83,8 @@ def test_itervar_relationship_graph(): findany(r"subgraph cluster_Stage_0", str) findany(r"subgraph cluster_Stage_1", str) # Check itervars and their types - findany(r"\(kDataPar\)\<br/\>range\(min=0, ext=n\)", str) - findany(r"\(kCommReduce\)\<br/\>range\(min=0, ext=m\)", str) + findany(r"\(kDataPar\)\<br/\>T.Range\(0, n\)", str) + findany(r"\(kCommReduce\)\<br/\>T.Range\(0, m\)", str) # Check the split node findany(r"Split_Relation_1_0 +.+\>Split", str) # Check all edges to/from the split node @@ -144,7 +148,57 @@ def test_schedule_tree(): verify() +@tvm.testing.requires_llvm +def test_tedd_with_schedule_record(): + """Test to build a nn model and check if all schedules could be generated""" + + def check_schedule(executor): + from tvm.contrib import tedd + + error = {} + for func_name, func_meta in executor.function_metadata.items(): + # check converted op only + if "main" not in func_name: + primfunc = list(func_meta.relay_primfuncs.values())[0] + schs = primfunc.attrs["schedule"].schedule_record + for index in range(len(schs)): + try: + sch = schs[index].normalize() + tedd.viz_dataflow_graph(sch, False, "", True) + tedd.viz_itervar_relationship_graph(sch, False, "", True) + tedd.viz_schedule_tree(sch, False, "", True) + except: + if func_name not in error: + error[func_name] = [] + error[func_name].append(index) + + assert error == {}, str(error) + + if checkdependency(): + relay_mod, params = testing.mobilenet.get_workload(batch_size=1, dtype="float32") + target_llvm = tvm.target.Target("llvm") + config = {"te.keep_schedule_record": True} + + with tvm.transform.PassContext(opt_level=3, config=config): + aot_executor_factory = relay.build( + relay_mod, + target_llvm, + runtime=Runtime("cpp"), + executor=Executor("aot"), + params=params, + ) + graph_executor_factory = relay.build( + relay_mod, + target_llvm, + params=params, + ) + + check_schedule(aot_executor_factory) + check_schedule(graph_executor_factory) + + if __name__ == "__main__": test_dfg() test_itervar_relationship_graph() test_schedule_tree() + test_tedd_with_schedule_record() diff --git a/tests/python/relay/test_build_module.py b/tests/python/relay/test_build_module.py index 5cfc27330a..b1146743ee 100644 --- a/tests/python/relay/test_build_module.py +++ b/tests/python/relay/test_build_module.py @@ -21,6 +21,7 @@ import tvm import tvm.testing from tvm import relay from tvm.target.target import Target +from tvm.relay import testing from tvm.relay.backend import Runtime, Executor, graph_executor_codegen @@ -62,5 +63,41 @@ def test_build_relay_graph_(): build_graph(add((1, 8), "float32"), tvm.target.Target("llvm")) +@tvm.testing.requires_llvm +def test_schedule_record(): + """Test to build a nn model and get schedule_record from build_module""" + + def check_schedule(executor): + for func_name, func_meta in executor.function_metadata.items(): + # check converted op only + if "main" not in func_name: + primfunc = list(func_meta.relay_primfuncs.values())[0] + # make sure schedule is well-stored in function metadata + assert "schedule" in primfunc.attrs + sch = primfunc.attrs["schedule"] + assert len(sch.schedule_record) == len(sch.primitive_record) + + relay_mod, params = testing.mobilenet.get_workload(batch_size=1, dtype="float32") + target_llvm = tvm.target.Target("llvm") + config = {"te.keep_schedule_record": True} + + with tvm.transform.PassContext(opt_level=3, config=config): + aot_executor_factory = relay.build( + relay_mod, + target_llvm, + runtime=Runtime("cpp"), + executor=Executor("aot"), + params=params, + ) + graph_executor_factory = relay.build( + relay_mod, + target_llvm, + params=params, + ) + + check_schedule(aot_executor_factory) + check_schedule(graph_executor_factory) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_te_schedule_ops.py b/tests/python/unittest/test_te_schedule_ops.py index f85cdc6196..1ff0297539 100644 --- a/tests/python/unittest/test_te_schedule_ops.py +++ b/tests/python/unittest/test_te_schedule_ops.py @@ -614,6 +614,57 @@ def test_local_stage_predicate2(): assert any(collect_visit(lowered_body, visit_stmt)) +def test_schedule_record_gemm(): + with tvm.transform.PassContext(config={"te.keep_schedule_record": True}): + M, K, N = 1024, 1024, 1024 + k = te.reduce_axis((0, K), "k") + A = te.placeholder((M, K), name="A") + B = te.placeholder((K, N), name="B") + C = te.compute((M, N), lambda m, n: te.sum(A[m, k] * B[k, n], axis=k), name="C") + s = te.create_schedule(C.op) + # currently there are no other applied primitives + # size of schedule record is expected to be 1 (vanilla schedule) + assert len(s.schedule_record) == 1 + # apply sequential optimizatoin primitives + block_size, factor = 32, 8 + # tile -> split + split + reorder + mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], block_size, block_size) + ko, ki = s[C].split(k, factor=factor) + s[C].reorder(mo, ko, no, mi, ki, ni) + s[C].vectorize(ni) + s[C].parallel(mo) + assert len(s.schedule_record) == 8 + # compare primitive names + expected_names = [ + "vanilla", + "split", + "split", + "reorder", + "split", + "reorder", + "vectorize", + "parallel", + ] + for i in range(len(s.schedule_record)): + assert s.primitive_record[i] == expected_names[i] + + +def test_schedule_record_misc(): + s = te.create_schedule([]) + # size of schedule record is expected to be 0 (no storing behavior) + assert len(s.schedule_record) == 0 + + with tvm.transform.PassContext(config={"te.keep_schedule_record": True}): + s = te.create_schedule([]) + # size of schedule record is expected to be 1 (vanilla schedule) + assert len(s.schedule_record) == 1 + + stg = te.compute((), lambda *args: 0, name="empty_op") + s = te.create_schedule(stg.op) + # size of schedule record is expected to be 1 (vanilla schedule) + assert len(s.schedule_record) == 1 + + if __name__ == "__main__": test_loop_dep_reduce() test_loop_dep_reduce_cache_write() @@ -640,3 +691,5 @@ if __name__ == "__main__": test_schedule_compute_inline() test_local_stage_predicate() test_local_stage_predicate2() + test_schedule_record_gemm() + test_schedule_record_misc()