This is an automated email from the ASF dual-hosted git repository.

sanirudh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f3fde8121b [TIR] [Schedule] Fix decompose_padding bug with dtypes (#15050)
f3fde8121b is described below

commit f3fde8121bd2fd25066ea9ec1191880b67dee268
Author: Anirudh Sundar Subramaniam <[email protected]>
AuthorDate: Thu Jun 8 12:36:20 2023 +0530

    [TIR] [Schedule] Fix decompose_padding bug with dtypes (#15050)
    
    Ran into a type mismatch error when the primfunc uses int64 dimensions,
    but the extent introduced in a particular case was just int32.
---
 src/tir/schedule/primitive/decompose_padding.cc    |  2 +-
 .../test_tir_schedule_decompose_padding.py         | 41 ++++++++++++++++++++++
 2 files changed, 42 insertions(+), 1 deletion(-)

diff --git a/src/tir/schedule/primitive/decompose_padding.cc b/src/tir/schedule/primitive/decompose_padding.cc
index 1743a34088..50b978f012 100644
--- a/src/tir/schedule/primitive/decompose_padding.cc
+++ b/src/tir/schedule/primitive/decompose_padding.cc
@@ -168,7 +168,7 @@ class PaddingInfoAnalyzer {
     }
     for (const arith::IterSumExpr& sum : res->indices) {
       if (sum->args.empty()) {
-        region.push_back(Range::FromMinExtent(sum->base, 1));
+        region.push_back(Range::FromMinExtent(sum->base, IntImm(sum->base.dtype(), /* value */ 1)));
       } else {
         ICHECK_EQ(sum->args.size(), 1U);
         if (!analyzer_->CanProveEqual(sum->args[0]->scale, 1)) {
diff --git a/tests/python/unittest/test_tir_schedule_decompose_padding.py b/tests/python/unittest/test_tir_schedule_decompose_padding.py
index e33cfdbd34..15ed194328 100644
--- a/tests/python/unittest/test_tir_schedule_decompose_padding.py
+++ b/tests/python/unittest/test_tir_schedule_decompose_padding.py
@@ -41,6 +41,47 @@ def check_decompose_padding(origin, scheduled, expected, check_run=False):
         tvm.testing.assert_allclose(y0.numpy(), y1.numpy())
 
 
+def test_int64_indices_batch_decompose_padding():
+    @T.prim_func
+    def before_decompose(
+        x: T.Buffer((T.int64(1), T.int64(128), T.int64(128)), "int32"),
+        y: T.Buffer((T.int64(1), T.int64(140), T.int64(128)), "int32"),
+    ):
+        for b, i, j in T.grid(T.int64(1), T.int64(140), T.int64(128)):
+            with T.block("block"):
+                vb, vi, vj = T.axis.remap("SSS", [b, i, j])
+                y[vb, vi, vj] = T.if_then_else(vi < T.int64(128), x[vb, vi, vj], 0)
+
+    @T.prim_func
+    def after_decompose(
+        x: T.Buffer((T.int64(1), T.int64(128), T.int64(128)), "int32"),
+        y: T.Buffer((T.int64(1), T.int64(140), T.int64(128)), "int32"),
+    ):
+        # with T.block("root"):
+        for b, i in T.grid(T.int64(1), T.int64(140)):
+            for j in range(T.int64(128)):
+                with T.block("block_pad_const"):
+                    vb = T.axis.spatial(T.int64(1), T.int64(0))
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    T.reads()
+                    T.writes(y[vb, vi, vj])
+                    y[vb, vi, vj] = 0
+            for j in range(T.int64(128)):
+                with T.block("block"):
+                    vb = T.axis.spatial(T.int64(1), T.int64(0))
+                    vi = T.axis.spatial(T.int64(128), i)
+                    vj = T.axis.spatial(T.int64(128), j)
+                    T.where(i < T.int64(128))
+                    T.reads(x[vb, vi, vj])
+                    T.writes(y[vb, vi, vj])
+                    y[vb, vi, vj] = x[vb, vi, vj]
+
+    sch = tir.Schedule(before_decompose, debug_mask="all")
+    block = sch.get_block("block")
+    sch.decompose_padding(block, sch.get_loops(block)[2])
+    check_decompose_padding(before_decompose, sch.mod["main"], after_decompose, check_run=False)
+
+
 def test_1d_decompose_padding():
     @T.prim_func
     def before_decompose(x: T.Buffer(128, "int32"), y: T.Buffer(140, "int32")):

Reply via email to