This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new 95a2fff959 [MetaSchedule] Fix mma default rule and disable tuning abort (#15437)
95a2fff959 is described below

commit 95a2fff959a7f9003d4a0d092089ffb6ec5daf76
Author: Tian Xia <2630737...@qq.com>
AuthorDate: Mon Jul 31 13:34:51 2023 -0700

    [MetaSchedule] Fix mma default rule and disable tuning abort (#15437)

    fix
---
 .../multi_level_tiling_tensor_core.cc            |  5 ++++-
 src/meta_schedule/schedule_rule/schedule_rule.cc | 23 +++++++++++++++++++---
 2 files changed, 24 insertions(+), 4 deletions(-)

diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
index 18bd58510d..d519187d30 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
@@ -336,6 +336,10 @@ std::vector<State> MultiLevelTilingTensorCoreNode::MMATileLoopNest(TensorCoreSta
   const BlockRV& block_rv = state->block_rv;
   // Step 1. Assuming trivial binding, pair the loops and their iter-var-types
   Array<LoopRV> loops = sch->GetLoops(block_rv);
+  if (!(loops.size() == 3 || !state->is_mma)) {
+    LOG(DEBUG) << "The MMA tensor core only supports SSR loops now";
+    return {};
+  }
   std::vector<IterVarType> iter_types = GetBlockVarTypes(sch->GetSRef(state->block_rv));
   ICHECK_EQ(loops.size(), iter_types.size());
   // Step 2. For each loop axis, tile it
@@ -344,7 +348,6 @@ std::vector<State> MultiLevelTilingTensorCoreNode::MMATileLoopNest(TensorCoreSta
   state->tile_factors.resize(tiles.size());
   std::vector<Array<tir::ExprRV>> tile_factors;
   tile_factors.resize(tiles.size());
-  ICHECK(loops.size() == 3 || !state->is_mma) << "The MMA tensor core only supports SSR loops now";
   for (int i = 0, n = loops.size(); i < n; ++i) {
     LoopRV loop = loops[i];
     const std::vector<int>* idx = nullptr;
diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc
index 2a5efcd760..3be2643324 100644
--- a/src/meta_schedule/schedule_rule/schedule_rule.cc
+++ b/src/meta_schedule/schedule_rule/schedule_rule.cc
@@ -171,7 +171,7 @@ Array<ScheduleRule> ScheduleRule::DefaultCUDA() {
 }
 
 Array<ScheduleRule> ScheduleRule::DefaultCUDATensorCore() {
-  Array<Map<String, String>> intrin_groups = {
+  Array<Map<String, String>> wmma_intrin_groups = {
       // Tensor Cores f32 += f16 * f16
       {
           {"init", "wmma_fill_16x16x16_f32"},
@@ -217,6 +217,8 @@ Array<ScheduleRule> ScheduleRule::DefaultCUDATensorCore() {
           {"compute", "wmma_sync_16x16x16_s8s8s32_trans"},
           {"store", "wmma_store_16x16x16_s32_shared_dyn"},
       },
+  };
+  Array<Map<String, String>> mma_intrin_groups = {
       // Tensor Core MMA
       {
           {"init", "mma_init_m16n8k8_f16"},
@@ -236,7 +238,7 @@ Array<ScheduleRule> ScheduleRule::DefaultCUDATensorCore() {
   Array<ScheduleRule> results{
       ScheduleRule::ApplyCustomRule(),
       ScheduleRule::MultiLevelTilingTensorCore(
-          /*intrin_groups=*/intrin_groups,
+          /*intrin_groups=*/wmma_intrin_groups,
           /*structure=*/"SSSRRSRS",
           /*tile_binds=*/Array<String>{"blockIdx.y", "blockIdx.x", "threadIdx.y"},
           /*max_innermost_factor=*/Integer(4),
@@ -249,7 +251,22 @@ Array<ScheduleRule> ScheduleRule::DefaultCUDATensorCore() {
           Map<String, ObjectRef>{{"req", String("must")},
                                  {"levels", Array<Integer>{2}},  //
                                  {"scope", String("shared.dyn")}},
-          /*use_software_pipeline=*/false)  //
+          /*use_software_pipeline=*/false),  //
+      ScheduleRule::MultiLevelTilingTensorCore(
+          /*intrin_groups=*/mma_intrin_groups,
+          /*structure=*/"SSSRRSRS",
+          /*tile_binds=*/Array<String>{"blockIdx.y", "blockIdx.x", "threadIdx.y"},
+          /*max_innermost_factor=*/Integer(4),
+          /*vector_load_lens=*/Array<Integer>{1, 2, 3, 4, 8, 16},
+          /*reuse_read=*/
+          Map<String, ObjectRef>{{"req", String("must")},
+                                 {"levels", Array<Integer>{4}},  //
+                                 {"scope", String("shared.dyn")}},
+          /*reuse_write=*/
+          Map<String, ObjectRef>{{"req", String("no")},
+                                 {"levels", Array<Integer>{2}},  //
+                                 {"scope", String("shared.dyn")}},
+          /*use_software_pipeline=*/true)  //
   };
   Array<ScheduleRule> append = ScheduleRule::DefaultCUDA();
   results.insert(results.end(), append.begin() + 1, append.end());
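
Note on the first hunk, for readers skimming the diff: the hard ICHECK that previously killed the whole tuning run is replaced by a DEBUG log plus an early return of no candidates. A minimal sketch of the new behavior follows; the names are condensed from the hunk above, and the simplified condition is logically equivalent (by De Morgan) to the committed `!(loops.size() == 3 || !state->is_mma)`.

    // Inside MultiLevelTilingTensorCoreNode::MMATileLoopNest (simplified sketch).
    Array<LoopRV> loops = sch->GetLoops(block_rv);
    if (state->is_mma && loops.size() != 3) {
      // Not a three-loop (S, S, R) nest: log and produce no tiling candidates,
      // so this rule is skipped instead of aborting tuning via ICHECK.
      LOG(DEBUG) << "The MMA tensor core only supports SSR loops now";
      return {};
    }

In the second file, the default CUDA tensor-core rules are split so that MultiLevelTilingTensorCore is registered twice: once with the WMMA intrin groups (software pipelining off) and once with the new MMA intrin groups (software pipelining on).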