This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ea0e29f425 [MetaSchedule][Test] Add unittests for GRP (#12246)
ea0e29f425 is described below

commit ea0e29f425e10add9d12c362c738452fbe890ba6
Author: Junru Shao <junrushao1...@gmail.com>
AuthorDate: Sun Jul 31 03:02:47 2022 -0700

    [MetaSchedule][Test] Add unittests for GRP (#12246)
---
 .../unittest/test_meta_schedule_space_cpu.py       | 175 +++++++++++++++++++++
 .../unittest/test_meta_schedule_space_cuda.py      |  90 +++++++++++
 2 files changed, 265 insertions(+)

diff --git a/tests/python/unittest/test_meta_schedule_space_cpu.py b/tests/python/unittest/test_meta_schedule_space_cpu.py
index 7d601a7b0b..cb8be2999f 100644
--- a/tests/python/unittest/test_meta_schedule_space_cpu.py
+++ b/tests/python/unittest/test_meta_schedule_space_cpu.py
@@ -1201,6 +1201,180 @@ def test_cpu_gmm():
     )
 
 
+def test_cpu_grp():
+    # fmt: off
+    @T.prim_func
+    def grp_0(inputs: T.Buffer[(1, 56, 56, 64), "float32"], weight: 
T.Buffer[(3, 3, 16, 128), "float32"], conv2d_nhwc: T.Buffer[(1, 28, 28, 128), 
"float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        with T.block("root"):
+            T.reads()
+            T.writes()
+            T.block_attr({"meta_schedule.parallel":288, 
"meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64})
+            PadInput = T.alloc_buffer([1, 58, 58, 64], dtype="float32")
+            conv2d_nhwc_global = T.alloc_buffer([1, 28, 28, 128], 
dtype="float32")
+            for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 7, 1, 2):
+                for ax0, ax1, ax2, ax3 in T.grid(1, 9, 57, 32):
+                    with T.block("PadInput"):
+                        i0 = T.axis.spatial(1, ax0)
+                        i1 = T.axis.spatial(58, i1_0 * 8 + ax1)
+                        i2 = T.axis.spatial(58, ax2)
+                        i3 = T.axis.spatial(64, i3_0 * 32 + ax3)
+                        T.reads(inputs[i0, i1 - 1, i2 - 1, i3])
+                        T.writes(PadInput[i0, i1, i2, i3])
+                        PadInput[i0, i1, i2, i3] = T.if_then_else(1 <= i1 and 
i1 < 57 and 1 <= i2 and i2 < 57, inputs[i0, i1 - 1, i2 - 1, i3], T.float32(0), 
dtype="float32")
+                for i0_1, i1_1, i2_1, i3_1 in T.grid(1, 4, 1, 1):
+                    for i4_0, i5_0, i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, 
i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 3, 8, 1, 1, 4, 4, 3, 1, 2, 1, 1, 7, 
16):
+                        with T.block("conv2d_nhwc"):
+                            n = T.axis.spatial(1, i0_3 + i0_0 + i0_1 + i0_2)
+                            h = T.axis.spatial(28, i1_0 * 4 + i1_1 + i1_2 + 
i1_3)
+                            w = T.axis.spatial(28, i2_0 * 28 + i2_1 * 28 + 
i2_2 * 7 + i2_3)
+                            co = T.axis.spatial(128, i3_0 * 64 + i3_1 * 64 + 
i3_2 * 16 + i3_3)
+                            rh = T.axis.reduce(3, i4_0 * 3 + i4_1)
+                            rw = T.axis.reduce(3, i5_0 + i5_1)
+                            rc = T.axis.reduce(16, i6_0 * 2 + i6_1)
+                            T.reads(PadInput[n, h * 2 + rh, w * 2 + rw, co // 
32 * 16 + rc], weight[rh, rw, rc, co])
+                            T.writes(conv2d_nhwc_global[n, h, w, co])
+                            
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+                            with T.init():
+                                conv2d_nhwc_global[n, h, w, co] = T.float32(0)
+                            conv2d_nhwc_global[n, h, w, co] = 
conv2d_nhwc_global[n, h, w, co] + PadInput[n, h * 2 + rh, w * 2 + rw, co // 32 
* 16 + rc] * weight[rh, rw, rc, co]
+                    for ax0, ax1, ax2, ax3 in T.grid(1, 1, 28, 64):
+                        with T.block("conv2d_nhwc_global"):
+                            v0 = T.axis.spatial(1, ax0)
+                            v1 = T.axis.spatial(28, i1_0 * 4 + i1_1 + ax1)
+                            v2 = T.axis.spatial(28, ax2)
+                            v3 = T.axis.spatial(128, i3_0 * 64 + ax3)
+                            T.reads(conv2d_nhwc_global[v0, v1, v2, v3])
+                            T.writes(conv2d_nhwc[v0, v1, v2, v3])
+                            conv2d_nhwc[v0, v1, v2, v3] = 
conv2d_nhwc_global[v0, v1, v2, v3]
+    @T.prim_func
+    def grp_1(inputs: T.Buffer[(1, 56, 56, 64), "float32"], weight: 
T.Buffer[(3, 3, 16, 128), "float32"], conv2d_nhwc: T.Buffer[(1, 28, 28, 128), 
"float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        with T.block("root"):
+            T.reads()
+            T.writes()
+            T.block_attr({"meta_schedule.parallel":288, 
"meta_schedule.unroll_explicit":512, "meta_schedule.vectorize":64})
+            PadInput = T.alloc_buffer([1, 58, 58, 64], dtype="float32")
+            conv2d_nhwc_global = T.alloc_buffer([1, 28, 28, 128], 
dtype="float32")
+            for i0, i1, i2, i3 in T.grid(1, 58, 58, 64):
+                with T.block("PadInput"):
+                    i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, 
i3])
+                    T.reads(inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1])
+                    T.writes(PadInput[i0_1, i1_1, i2_1, i3_1])
+                    PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(1 <= 
i1_1 and i1_1 < 57 and 1 <= i2_1 and i2_1 < 57, inputs[i0_1, i1_1 - 1, i2_1 - 
1, i3_1], T.float32(0), dtype="float32")
+            for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 7, 1, 2):
+                for i0_1_1, i1_1_1, i2_1_1, i3_1_1, i4_0, i5_0, i6_0, i0_2, 
i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 4, 1, 
1, 1, 3, 8, 1, 1, 4, 4, 3, 1, 2, 1, 1, 7, 16):
+                    with T.block("conv2d_nhwc"):
+                        n = T.axis.spatial(1, i0_3 + i0_0 + i0_1_1 + i0_2)
+                        h = T.axis.spatial(28, i1_0 * 4 + i1_1_1 + i1_2 + i1_3)
+                        w = T.axis.spatial(28, i2_0 * 28 + i2_1_1 * 28 + i2_2 
* 7 + i2_3)
+                        co = T.axis.spatial(128, i3_0 * 64 + i3_1_1 * 64 + 
i3_2 * 16 + i3_3)
+                        rh = T.axis.reduce(3, i4_0 * 3 + i4_1)
+                        rw = T.axis.reduce(3, i5_0 + i5_1)
+                        rc = T.axis.reduce(16, i6_0 * 2 + i6_1)
+                        T.reads(PadInput[n, h * 2 + rh, w * 2 + rw, co // 32 * 
16 + rc], weight[rh, rw, rc, co])
+                        T.writes(conv2d_nhwc_global[n, h, w, co])
+                        
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+                        with T.init():
+                            conv2d_nhwc_global[n, h, w, co] = T.float32(0)
+                        conv2d_nhwc_global[n, h, w, co] = 
conv2d_nhwc_global[n, h, w, co] + PadInput[n, h * 2 + rh, w * 2 + rw, co // 32 
* 16 + rc] * weight[rh, rw, rc, co]
+                for ax0, ax1, ax2, ax3 in T.grid(1, 4, 28, 64):
+                    with T.block("conv2d_nhwc_global"):
+                        v0 = T.axis.spatial(1, ax0)
+                        v1 = T.axis.spatial(28, i1_0 * 4 + ax1)
+                        v2 = T.axis.spatial(28, ax2)
+                        v3 = T.axis.spatial(128, i3_0 * 64 + ax3)
+                        T.reads(conv2d_nhwc_global[v0, v1, v2, v3])
+                        T.writes(conv2d_nhwc[v0, v1, v2, v3])
+                        conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, 
v1, v2, v3]
+    @T.prim_func
+    def grp_2(inputs: T.Buffer[(1, 56, 56, 64), "float32"], weight: 
T.Buffer[(3, 3, 16, 128), "float32"], conv2d_nhwc: T.Buffer[(1, 28, 28, 128), 
"float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        with T.block("root"):
+            T.reads()
+            T.writes()
+            T.block_attr({"meta_schedule.parallel":288, 
"meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64})
+            PadInput = T.alloc_buffer([1, 58, 58, 64], dtype="float32")
+            for i0_0, i1_0, i2_0, i3_0, i0_1, i1_1, i2_1, i3_1, i4_0, i5_0 in 
T.grid(1, 7, 1, 2, 1, 4, 1, 1, 1, 3):
+                for ax0, ax1, ax2, ax3 in T.grid(1, 3, 55, 32):
+                    with T.block("PadInput"):
+                        i0 = T.axis.spatial(1, ax0)
+                        i1 = T.axis.spatial(58, i1_0 * 8 + i1_1 * 2 + ax1)
+                        i2 = T.axis.spatial(58, i5_0 + ax2)
+                        i3 = T.axis.spatial(64, i3_0 * 32 + ax3)
+                        T.reads(inputs[i0, i1 - 1, i2 - 1, i3])
+                        T.writes(PadInput[i0, i1, i2, i3])
+                        PadInput[i0, i1, i2, i3] = T.if_then_else(1 <= i1 and 
i1 < 57 and 1 <= i2 and i2 < 57, inputs[i0, i1 - 1, i2 - 1, i3], T.float32(0), 
dtype="float32")
+                for i6_0, i0_2, i1_2, i2_2, i3_2, i4_1, i5_1, i6_1, i0_3, 
i1_3, i2_3, i3_3 in T.grid(8, 1, 1, 4, 4, 3, 1, 2, 1, 1, 7, 16):
+                    with T.block("conv2d_nhwc"):
+                        n = T.axis.spatial(1, i0_3 + i0_0 + i0_1 + i0_2)
+                        h = T.axis.spatial(28, i1_0 * 4 + i1_1 + i1_2 + i1_3)
+                        w = T.axis.spatial(28, i2_0 * 28 + i2_1 * 28 + i2_2 * 
7 + i2_3)
+                        co = T.axis.spatial(128, i3_0 * 64 + i3_1 * 64 + i3_2 
* 16 + i3_3)
+                        rh = T.axis.reduce(3, i4_0 * 3 + i4_1)
+                        rw = T.axis.reduce(3, i5_0 + i5_1)
+                        rc = T.axis.reduce(16, i6_0 * 2 + i6_1)
+                        T.reads(PadInput[n, h * 2 + rh, w * 2 + rw, co // 32 * 
16 + rc], weight[rh, rw, rc, co])
+                        T.writes(conv2d_nhwc[n, h, w, co])
+                        
T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+                        with T.init():
+                            conv2d_nhwc[n, h, w, co] = T.float32(0)
+                        conv2d_nhwc[n, h, w, co] = conv2d_nhwc[n, h, w, co] + 
PadInput[n, h * 2 + rh, w * 2 + rw, co // 32 * 16 + rc] * weight[rh, rw, rc, co]
+    # fmt: on
+    decision_0 = [
+        ("SamplePerfectTile", [1, 1, 1, 1]),
+        ("SamplePerfectTile", [7, 4, 1, 1]),
+        ("SamplePerfectTile", [1, 1, 4, 7]),
+        ("SamplePerfectTile", [2, 1, 4, 16]),
+        ("SamplePerfectTile", [1, 3]),
+        ("SamplePerfectTile", [3, 1]),
+        ("SamplePerfectTile", [8, 2]),
+        ("SampleCategorical", 1),
+        ("SampleComputeLocation", 3),
+    ]
+    decision_1 = [
+        ("SamplePerfectTile", [1, 1, 1, 1]),
+        ("SamplePerfectTile", [7, 4, 1, 1]),
+        ("SamplePerfectTile", [1, 1, 4, 7]),
+        ("SamplePerfectTile", [2, 1, 4, 16]),
+        ("SamplePerfectTile", [1, 3]),
+        ("SamplePerfectTile", [3, 1]),
+        ("SamplePerfectTile", [8, 2]),
+        ("SampleCategorical", 3),
+        ("SampleComputeLocation", -1),
+    ]
+    decision_2 = [
+        ("SamplePerfectTile", [1, 1, 1, 1]),
+        ("SamplePerfectTile", [7, 4, 1, 1]),
+        ("SamplePerfectTile", [1, 1, 4, 7]),
+        ("SamplePerfectTile", [2, 1, 4, 16]),
+        ("SamplePerfectTile", [1, 3]),
+        ("SamplePerfectTile", [3, 1]),
+        ("SamplePerfectTile", [8, 2]),
+        ("SampleCategorical", 1),
+        ("SampleComputeLocation", 9),
+    ]
+    mod = create_te_workload("GRP", 0)
+    actual = ms.TuneContext(
+        mod=mod,
+        target=_target(),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules="default",
+    ).generate_design_space()
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[grp_0, grp_1, grp_2],
+        expected_decisions=[decision_0, decision_1, decision_2],
+    )
+
+
 if __name__ == "__main__":
     test_cpu_c1d()
     test_cpu_c2d()
@@ -1209,3 +1383,4 @@ if __name__ == "__main__":
     test_cpu_dep()
     test_cpu_dil()
     test_cpu_gmm()
+    test_cpu_grp()  # grouped-conv2d design-space test defined in this diff
diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py b/tests/python/unittest/test_meta_schedule_space_cuda.py
index 3bf2666cdc..81281d5d38 100644
--- a/tests/python/unittest/test_meta_schedule_space_cuda.py
+++ b/tests/python/unittest/test_meta_schedule_space_cuda.py
@@ -653,6 +653,95 @@ def test_cuda_gmm():
     )
 
 
+def test_cuda_grp():
+    # fmt: off
+    @T.prim_func
+    def grp_0(inputs: T.Buffer[(1, 56, 56, 64), "float32"], weight: 
T.Buffer[(3, 3, 16, 128), "float32"], conv2d_nhwc: T.Buffer[(1, 28, 28, 128), 
"float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        with T.block("root"):
+            T.reads()
+            T.writes()
+            T.block_attr({"meta_schedule.unroll_explicit":16})
+            conv2d_nhwc_local = T.alloc_buffer([1, 28, 28, 128], 
dtype="float32", scope="local")
+            PadInput_shared = T.alloc_buffer([1, 58, 58, 64], dtype="float32", 
scope="shared")
+            weight_shared = T.alloc_buffer([3, 3, 16, 128], dtype="float32", 
scope="shared")
+            for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(2, 
thread="blockIdx.x"):
+                for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(1, 
thread="vthread.x"):
+                    for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(112, 
thread="threadIdx.x"):
+                        for i4_0, i5_0, i6_0 in T.grid(3, 3, 1):
+                            for ax0_ax1_ax2_ax3_fused in T.serial(95040):
+                                with T.block("PadInput_shared"):
+                                    v0 = T.axis.spatial(1, 0)
+                                    v1 = T.axis.spatial(58, 
i0_0_i1_0_i2_0_i3_0_fused * 28 + i4_0 + ax0_ax1_ax2_ax3_fused % 95040 // 3520)
+                                    v2 = T.axis.spatial(58, i5_0 + 
ax0_ax1_ax2_ax3_fused % 3520 // 64)
+                                    v3 = T.axis.spatial(64, 
ax0_ax1_ax2_ax3_fused % 64)
+                                    T.reads(inputs[v0, v1 - 1, v2 - 1, v3])
+                                    T.writes(PadInput_shared[v0, v1, v2, v3])
+                                    
T.block_attr({"meta_schedule.cooperative_fetch":2})
+                                    PadInput_shared[v0, v1, v2, v3] = 
T.if_then_else(1 <= v1 and v1 < 57 and 1 <= v2 and v2 < 57, inputs[v0, v1 - 1, 
v2 - 1, v3], T.float32(0), dtype="float32")
+                            for ax0_ax1_ax2_ax3_fused in T.serial(2048):
+                                with T.block("weight_shared"):
+                                    v0, v1 = T.axis.remap("SS", [i4_0, i5_0])
+                                    v2 = T.axis.spatial(16, 
ax0_ax1_ax2_ax3_fused // 128)
+                                    v3 = T.axis.spatial(128, 
ax0_ax1_ax2_ax3_fused % 128)
+                                    T.reads(weight[v0, v1, v2, v3])
+                                    T.writes(weight_shared[v0, v1, v2, v3])
+                                    
T.block_attr({"meta_schedule.cooperative_fetch":1})
+                                    weight_shared[v0, v1, v2, v3] = weight[v0, 
v1, v2, v3]
+                            for i4_1, i5_1, i6_1, i0_3, i1_3, i2_3, i3_3, 
i4_2, i5_2, i6_2, i0_4, i1_4, i2_4, i3_4 in T.grid(1, 1, 2, 1, 2, 1, 2, 1, 1, 
8, 1, 7, 4, 4):
+                                with T.block("conv2d_nhwc"):
+                                    n = T.axis.spatial(1, i0_3 + i0_4)
+                                    h = T.axis.spatial(28, 
i0_0_i1_0_i2_0_i3_0_fused * 14 + i1_3 * 7 + i1_4)
+                                    w = T.axis.spatial(28, 
i0_2_i1_2_i2_2_i3_2_fused // 16 * 4 + i2_3 * 4 + i2_4)
+                                    co = T.axis.spatial(128, 
i0_2_i1_2_i2_2_i3_2_fused % 16 * 8 + i3_3 * 4 + i3_4)
+                                    rh = T.axis.reduce(3, i4_0 + i4_1 + i4_2)
+                                    rw = T.axis.reduce(3, i5_2 + i5_0 + i5_1)
+                                    rc = T.axis.reduce(16, i6_0 * 16 + i6_1 * 
8 + i6_2)
+                                    T.reads(PadInput_shared[n, h * 2 + rh, w * 
2 + rw, co // 32 * 16 + rc], weight_shared[rh, rw, rc, co])
+                                    T.writes(conv2d_nhwc_local[n, h, w, co])
+                                    
T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, 
"meta_schedule.thread_extent_low_inclusive":32, 
"meta_schedule.tiling_structure":"SSSRRSRS"})
+                                    with T.init():
+                                        conv2d_nhwc_local[n, h, w, co] = 
T.float32(0)
+                                    conv2d_nhwc_local[n, h, w, co] = 
conv2d_nhwc_local[n, h, w, co] + PadInput_shared[n, h * 2 + rh, w * 2 + rw, co 
// 32 * 16 + rc] * weight_shared[rh, rw, rc, co]
+                        for ax0, ax1, ax2, ax3 in T.grid(1, 14, 4, 8):
+                            with T.block("conv2d_nhwc_local"):
+                                v0 = T.axis.spatial(1, ax0)
+                                v1 = T.axis.spatial(28, 
i0_0_i1_0_i2_0_i3_0_fused * 14 + ax1)
+                                v2 = T.axis.spatial(28, 
i0_2_i1_2_i2_2_i3_2_fused // 16 * 4 + ax2)
+                                v3 = T.axis.spatial(128, 
i0_2_i1_2_i2_2_i3_2_fused % 16 * 8 + ax3)
+                                T.reads(conv2d_nhwc_local[v0, v1, v2, v3])
+                                T.writes(conv2d_nhwc[v0, v1, v2, v3])
+                                conv2d_nhwc[v0, v1, v2, v3] = 
conv2d_nhwc_local[v0, v1, v2, v3]
+    # fmt: on
+    decision_0 = [
+        ("SamplePerfectTile", [1, 1, 1, 1, 1]),
+        ("SamplePerfectTile", [2, 1, 1, 2, 7]),
+        ("SamplePerfectTile", [1, 1, 7, 1, 4]),
+        ("SamplePerfectTile", [1, 1, 16, 2, 4]),
+        ("SamplePerfectTile", [3, 1, 1]),
+        ("SamplePerfectTile", [3, 1, 1]),
+        ("SamplePerfectTile", [1, 2, 8]),
+        ("SampleCategorical", 1),
+        ("SampleCategorical", 0),
+        ("SampleCategorical", 1),
+    ]
+    mod = create_te_workload("GRP", 0)
+    actual = ms.TuneContext(
+        mod=mod,
+        target=_target(),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules="default",
+    ).generate_design_space()
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[grp_0],
+        expected_decisions=[decision_0],
+    )
+
+
 if __name__ == "__main__":
     test_cuda_c1d()
     test_cuda_c2d()
@@ -661,3 +750,4 @@ if __name__ == "__main__":
     test_cuda_dep()
     test_cuda_dil()
     test_cuda_gmm()
+    test_cuda_grp()  # grouped-conv2d design-space test defined in this diff

Reply via email to