This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new aeda760e5e [TIR] Disallow unused rhs vars in GetAutoTensorizeMapping (#12225)
aeda760e5e is described below

commit aeda760e5e29eddd0a7ddb22c7031f9607440770
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Fri Jul 29 01:00:21 2022 -0700

    [TIR] Disallow unused rhs vars in GetAutoTensorizeMapping (#12225)
---
 src/tir/schedule/analysis/analysis.cc              |  5 +++++
 .../python/unittest/test_tir_schedule_analysis.py  | 24 ++++++++++++++++++++++
 2 files changed, 29 insertions(+)

diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index 569259d061..72b8c12fea 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -2460,6 +2460,7 @@ class AutoTensorizeMappingProposer {
     }
 
     // Step 3: Fuse LHS iters mapped to the same RHS iter
+    std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> used_rhs_vars;
     for (size_t i = 0; i < extractor_->lhs_iters_.size(); ++i) {
       const Var& lhs_iter_var = extractor_->lhs_iters_[i]->var;
       const VarSet& rhs_candidates = lhs_feasible_vars_[lhs_iter_var];
@@ -2472,12 +2473,16 @@ class AutoTensorizeMappingProposer {
         PrimExpr updated_fused_lhs =
             fused_lhs * lhs_iter_extents.at(lhs_iter_var) + index_map_src[i];
         fused_lhs_iters.Set(rhs_var, updated_fused_lhs);
+        used_rhs_vars.insert(rhs_var);
       } else {
         // non-unique mapping is not supported
         return {};
       }
     }
     for (const auto& iter : extractor_->rhs_iters_) {
+      if (!used_rhs_vars.count(iter->var)) {
+        return {};
+      }
       index_map_tgt.push_back(analyzer_->Simplify(fused_lhs_iters[iter->var]));
     }
     // At most one mapping is supported.
diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py
index 625343f740..d3e6033e88 100644
--- a/tests/python/unittest/test_tir_schedule_analysis.py
+++ b/tests/python/unittest/test_tir_schedule_analysis.py
@@ -265,6 +265,9 @@ def check_index_map(workload, block_name, intrin_name, expected_index_map):
     block = s.get_block(block_name)
     desc_func = TensorIntrin.get(intrin_name).desc
     info = get_auto_tensorize_mapping_info(s, block, desc_func)
+    if expected_index_map is None:
+        assert info is None
+        return
     assert len(info.mappings) == 1
     assert IndexMap.from_func(expected_index_map).is_equivalent_to(info.mappings[0])
 
@@ -304,5 +307,26 @@ def test_get_auto_tensorize_mapping_info_batch_matmul(b, m, n, k):
     )
 
 
+@pytest.mark.parametrize(
+    "n,m,k,expected",
+    [
+        (
+            512,
+            512,
+            512,
+            lambda n, m, k: (
+                n,
+                m,
+                k,
+            ),
+        ),
+        (1, 32, 32, None),
+    ],
+)
+def test_get_auto_tensorize_mapping_info_matmul(n, m, k, expected):
+    matmul = create_prim_func(te_workload.matmul(n, m, k, in_dtype="float16", out_dtype="float32"))
+    check_index_map(matmul, "C", WMMA_SYNC_16x16x16_f16f16f32_INTRIN, expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to