This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git

The following commit(s) were added to refs/heads/main by this push:
     new aeda760e5e  [TIR] Disallow unused rhs vars in GetAutoTensorizeMapping (#12225)
aeda760e5e is described below

commit aeda760e5e29eddd0a7ddb22c7031f9607440770
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Fri Jul 29 01:00:21 2022 -0700

    [TIR] Disallow unused rhs vars in GetAutoTensorizeMapping (#12225)
---
 src/tir/schedule/analysis/analysis.cc               |  5 +++++
 .../python/unittest/test_tir_schedule_analysis.py   | 24 ++++++++++++++++++++++
 2 files changed, 29 insertions(+)

diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index 569259d061..72b8c12fea 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -2460,6 +2460,7 @@ class AutoTensorizeMappingProposer {
     }
 
     // Step 3: Fuse LHS iters mapped to the same RHS iter
+    std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> used_rhs_vars;
     for (size_t i = 0; i < extractor_->lhs_iters_.size(); ++i) {
       const Var& lhs_iter_var = extractor_->lhs_iters_[i]->var;
       const VarSet& rhs_candidates = lhs_feasible_vars_[lhs_iter_var];
@@ -2472,12 +2473,16 @@
         PrimExpr updated_fused_lhs =
             fused_lhs * lhs_iter_extents.at(lhs_iter_var) + index_map_src[i];
         fused_lhs_iters.Set(rhs_var, updated_fused_lhs);
+        used_rhs_vars.insert(rhs_var);
       } else {
         // non-unique mapping is not supported
         return {};
       }
     }
     for (const auto& iter : extractor_->rhs_iters_) {
+      if (!used_rhs_vars.count(iter->var)) {
+        return {};
+      }
       index_map_tgt.push_back(analyzer_->Simplify(fused_lhs_iters[iter->var]));
     }
     // At most one mapping is supported.
diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py
index 625343f740..d3e6033e88 100644
--- a/tests/python/unittest/test_tir_schedule_analysis.py
+++ b/tests/python/unittest/test_tir_schedule_analysis.py
@@ -265,6 +265,9 @@ def check_index_map(workload, block_name, intrin_name, expected_index_map):
     block = s.get_block(block_name)
     desc_func = TensorIntrin.get(intrin_name).desc
     info = get_auto_tensorize_mapping_info(s, block, desc_func)
+    if expected_index_map is None:
+        assert info is None
+        return
     assert len(info.mappings) == 1
     assert IndexMap.from_func(expected_index_map).is_equivalent_to(info.mappings[0])
 
@@ -304,5 +307,26 @@ def test_get_auto_tensorize_mapping_info_batch_matmul(b, m, n, k):
     )
 
 
+@pytest.mark.parametrize(
+    "n,m,k,expected",
+    [
+        (
+            512,
+            512,
+            512,
+            lambda n, m, k: (
+                n,
+                m,
+                k,
+            ),
+        ),
+        (1, 32, 32, None),
+    ],
+)
+def test_get_auto_tensorize_mapping_info_matmul(n, m, k, expected):
+    matmul = create_prim_func(te_workload.matmul(n, m, k, in_dtype="float16", out_dtype="float32"))
+    check_index_map(matmul, "C", WMMA_SYNC_16x16x16_f16f16f32_INTRIN, expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
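
For context on the user-visible effect, here is a minimal sketch (not part of the commit) mirroring the new (1, 32, 32, None) test case. The import paths for te_workload, Schedule, TensorIntrin, create_prim_func, and WMMA_SYNC_16x16x16_f16f16f32_INTRIN are assumptions taken from the modified test file, not something this change introduces.

# Illustrative sketch only; imports assumed to match test_tir_schedule_analysis.py.
from tvm.te import create_prim_func
from tvm.meta_schedule.testing import te_workload
from tvm.tir import Schedule, TensorIntrin
from tvm.tir.schedule.analysis import get_auto_tensorize_mapping_info
from tvm.tir.tensor_intrin.cuda import WMMA_SYNC_16x16x16_f16f16f32_INTRIN

# A 1x32x32 fp16 matmul cannot cover every iteration var of the 16x16x16 WMMA
# intrinsic, so one rhs var stays unused. With this commit the analysis rejects
# such a case and returns None instead of proposing a broken mapping.
matmul = create_prim_func(te_workload.matmul(1, 32, 32, in_dtype="float16", out_dtype="float32"))
sch = Schedule(matmul)
block = sch.get_block("C")
desc = TensorIntrin.get(WMMA_SYNC_16x16x16_f16f16f32_INTRIN).desc
info = get_auto_tensorize_mapping_info(sch, block, desc)
assert info is None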