This is an automated email from the ASF dual-hosted git repository. lunderberg pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 299ca267e7 [TIR] Update region min/extent in ReplaceBufferMutator (#12725) 299ca267e7 is described below commit 299ca267e7641b5fa6e78dd131d0574e310f9a13 Author: Eric Lunderberg <lunderb...@users.noreply.github.com> AuthorDate: Thu Sep 8 09:35:58 2022 -0700 [TIR] Update region min/extent in ReplaceBufferMutator (#12725) Prior to this commit, `ReplaceBufferMutator` only checked `BufferRegionNode::buffer` to determine whether a `BufferRegion` needed to be replaced, and did not check `BufferRegionNode::region`. As a result, updating `T.reads(A[B[i]])` would fail to replace `B`. This commit also checks `BufferRegionNode::region` for buffer usage to resolve this issue. --- src/tir/schedule/transform.cc | 27 ++++++++++++++++++++++++--- .../test_tir_schedule_set_axis_separator.py | 24 ++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 3 deletions(-) diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index 1ebaf202d4..c11fa656d6 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -138,9 +138,30 @@ Stmt ReplaceBufferMutator::VisitStmt_(const BlockNode* block) { return this->VisitMatchBufferRegion(match_buffer); }; auto f_mutate_read_write_region = [this](const BufferRegion& buffer_region) { - auto it = buffer_var_map_.find(buffer_region->buffer->data.get()); - return it == buffer_var_map_.end() ? 
buffer_region - : BufferRegion(it->second, buffer_region->region); + auto region = MutateArray(buffer_region->region, [this](const Range& range) { + PrimExpr min = VisitExpr(range->min); + PrimExpr extent = VisitExpr(range->extent); + if (min.same_as(range->min) && extent.same_as(range->extent)) { + return range; + } else { + return Range::FromMinExtent(min, extent); + } + }); + + Buffer buf = [&]() { + auto it = buffer_var_map_.find(buffer_region->buffer->data.get()); + if (it == buffer_var_map_.end()) { + return buffer_region->buffer; + } else { + return it->second; + } + }(); + + if (buf.same_as(buffer_region->buffer) && region.same_as(buffer_region->region)) { + return buffer_region; + } else { + return BufferRegion(buf, region); + } }; auto f_mutate_alloc_buffers = [this](const Buffer& buffer) { auto it = buffer_var_map_.find(buffer->data.get()); diff --git a/tests/python/unittest/test_tir_schedule_set_axis_separator.py b/tests/python/unittest/test_tir_schedule_set_axis_separator.py index 9502da1829..b432fbb610 100644 --- a/tests/python/unittest/test_tir_schedule_set_axis_separator.py +++ b/tests/python/unittest/test_tir_schedule_set_axis_separator.py @@ -154,6 +154,30 @@ def test_set_axis_separator_subregion(use_sugared_transform): tvm.ir.assert_structural_equal(element_wise_subregion_match_set_axis_separator, s.mod["main"]) verify_trace_roundtrip(sch=s, mod=func) +class TestIndexedLookup(tvm.testing.CompareBeforeAfter): + def transform(self): + def func(mod): + sch = tir.Schedule(mod) + sch.set_axis_separator('block', 'B', [1]) + return sch.mod + return func + + @T.prim_func + def before(): + A = T.alloc_buffer([4,4], dtype="int32") + B = T.alloc_buffer([1,1], dtype="int32") + for j in T.serial(4): + with T.block('block'): + A[B[0,0],j] = 0 + + @T.prim_func + def expected(): + A = T.alloc_buffer([4,4], dtype="int32") + B = T.alloc_buffer([1,1], dtype="int32", axis_separators=[1]) + for j in T.serial(4): + with T.block('block'): + A[B[0,0],j] = 0 + if 
__name__ == "__main__": tvm.testing.main()