This is an automated email from the ASF dual-hosted git repository. syfeng pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new ef784d68e0 [MetaSchedule][Test] Migrate `check_trace` to `check_sketch` (#12764) ef784d68e0 is described below commit ef784d68e04ab4b858ce4c953b2d83b5d5811eda Author: Junru Shao <junrushao1...@gmail.com> AuthorDate: Tue Sep 13 02:20:30 2022 -0700 [MetaSchedule][Test] Migrate `check_trace` to `check_sketch` (#12764) * Migrate AutoBind * Migrate RandomComputeLocation * Migrate CrossThreadReduction * Migrate ParallelVectorizeUnroll --- python/tvm/meta_schedule/testing/schedule_rule.py | 48 +- .../test_meta_schedule_schedule_rule_auto_bind.py | 175 +++--- ...chedule_schedule_rule_cross_thread_reduction.py | 665 ++++++++++++++++----- ...dule_schedule_rule_parallel_vectorize_unroll.py | 111 ++-- ...hedule_schedule_rule_random_compute_location.py | 72 ++- 5 files changed, 718 insertions(+), 353 deletions(-) diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py index b08db0811d..12ca4200d7 100644 --- a/python/tvm/meta_schedule/testing/schedule_rule.py +++ b/python/tvm/meta_schedule/testing/schedule_rule.py @@ -18,28 +18,15 @@ from typing import List, Union from tvm.meta_schedule.schedule_rule import ( - AutoBind, AutoInline, - CrossThreadReduction, MultiLevelTiling, - ParallelizeVectorizeUnroll, - RandomComputeLocation, + MultiLevelTilingTensorCore, ReuseType, ScheduleRule, ) -from tvm.meta_schedule.schedule_rule.multi_level_tiling import ( - MultiLevelTilingTensorCore, -) from tvm.target import Target -def auto_bind(target: Target) -> ScheduleRule: - """Default schedule rules for auto bind""" - if target.kind.name == "cuda": - return AutoBind(max_threadblocks=256, thread_extents=[32, 64, 128, 256, 512, 1024]) - raise NotImplementedError(f"{target.kind.name} is not supported") - - def auto_inline(target: Target) -> ScheduleRule: """Default schedule rules for auto inline""" if target.kind.name == "llvm": @@ -65,13 +52,6 @@ def auto_inline(target: Target) -> ScheduleRule: raise NotImplementedError(f"{target.kind.name} is not supported") -def cross_thread_reduction(target: Target) -> ScheduleRule: - """Default schedule rules for with cross-thread reduction""" - if target.kind.name == "cuda": - return CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512]) - raise NotImplementedError(f"{target.kind.name} is not supported") - - def multi_level_tiling(target: Target) -> ScheduleRule: """Default schedule rules for with multi-level tiling and reuse""" if target.kind.name == "llvm": @@ -154,29 +134,3 @@ def multi_level_tiling_tensor_core( use_software_pipeline=use_software_pipeline, ) raise NotImplementedError(f"{target.kind.name} is not supported") - - -def random_compute_location(target: Target) -> ScheduleRule: - """Default schedule rules for with random-compute-location""" - if target.kind.name == "llvm": - return RandomComputeLocation() - raise NotImplementedError(f"{target.kind.name} is not supported") - - -def parallel_vectorize_unroll(target: Target) -> ScheduleRule: - """Default schedule rules for with parallel-vectorize-unroll""" - if target.kind.name == "llvm": - return ParallelizeVectorizeUnroll( - max_jobs_per_core=16, - max_vectorize_extent=32, - unroll_max_steps=[0, 16, 64, 512], - unroll_explicit=True, - ) - if target.kind.name == "cuda": - return ParallelizeVectorizeUnroll( - max_jobs_per_core=-1, - max_vectorize_extent=-1, - unroll_max_steps=[0, 16, 64, 512, 1024], - unroll_explicit=True, - ) - raise NotImplementedError(f"{target.kind.name} is not supported") diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py index a89cca72e1..21ad04da47 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py @@ -15,10 +15,8 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring -from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply -from tvm.meta_schedule.testing.schedule_rule import auto_bind -from tvm.meta_schedule.testing.space_generation import check_trace -from tvm.meta_schedule.tune_context import TuneContext +from tvm import meta_schedule as ms +from tvm.meta_schedule.testing.space_generation import check_sketches from tvm.script import tir as T from tvm.target import Target @@ -60,83 +58,120 @@ def zero_dim_add( C[()] = A[()] + B[()] -def _create_context(mod, target, rule) -> TuneContext: - ctx = TuneContext( - mod=mod, - target=target, - space_generator=PostOrderApply(), - sch_rules=[rule], - task_name="test", - ) - return ctx - - def test_cuda_element_wise(): - expected = [ - [ - 'b0 = sch.get_block(name="C", func_name="main")', - "l1, l2 = sch.get_loops(block=b0)", - "l3 = sch.fuse(l1, l2, preserve_unit_iters=True)", - "v4 = sch.sample_categorical(candidates=[32, 64, 128, 256, 512, 1024], probs=[0.16666666666666666, 0.16666666666666666, 0.16666666666666666, 0.16666666666666666, 0.16666666666666666, 0.16666666666666666])", - "l5, l6 = sch.split(loop=l3, factors=[None, v4], preserve_unit_iters=True)", - 'sch.bind(loop=l5, thread_axis="blockIdx.x")', - 'sch.bind(loop=l6, thread_axis="threadIdx.x")', - ] + @T.prim_func + def elementwise_0( + A: T.Buffer[(512, 512), "float32"], + B: T.Buffer[(512, 512), "float32"], + ) -> None: + # body + # with T.block("root") + for i_j_fused_0 in T.thread_binding(256, thread="blockIdx.x"): + for i_j_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): + with T.block("C"): + vi = T.axis.spatial(512, (i_j_fused_0 * 1024 + i_j_fused_1) // 512) + vj = T.axis.spatial(512, (i_j_fused_0 * 1024 + i_j_fused_1) % 512) + T.reads(A[vi, vj]) + T.writes(B[vi, vj]) + B[vi, vj] = A[vi, vj] + T.float32(1) + + decision_0 = [ + ("SampleCategorical", 5), ] - target = Target("nvidia/geforce-rtx-3080", host="llvm") - ctx = _create_context( - element_wise, - target=target, - rule=auto_bind(target=target), + mod = element_wise + actual = ms.TuneContext( + mod=mod, + target=Target("nvidia/geforce-rtx-3080", host="llvm"), + space_generator=ms.space_generator.PostOrderApply(), + sch_rules=[ + ms.schedule_rule.AutoBind( + max_threadblocks=256, + thread_extents=[32, 64, 128, 256, 512, 1024], + ) + ], + task_name="test", + ).generate_design_space() + check_sketches( + mod, + sketches=actual, + expected_mods=[elementwise_0], + expected_decisions=[decision_0], ) - spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) - assert len(spaces) == 1 - check_trace(spaces, expected) def test_cuda_reduction_loop_only(): - expected = [ - [ - 'b0 = sch.get_block(name="C", func_name="main")', - "l1, = sch.get_loops(block=b0)", - "l2 = sch.add_unit_loop(block_or_loop=l1)", - "l3 = sch.fuse(l2, preserve_unit_iters=True)", - "l4, l5 = sch.split(loop=l3, factors=[None, 1], preserve_unit_iters=True)", - 'sch.bind(loop=l4, thread_axis="blockIdx.x")', - 'sch.bind(loop=l5, thread_axis="threadIdx.x")', - ] - ] - target = Target("nvidia/geforce-rtx-3080", host="llvm") - ctx = _create_context( - reduction_loop_only, - target=target, - rule=auto_bind(target=target), + @T.prim_func + def reduction_loop_only_0( + A: T.Buffer[2, "float32"], + B: T.Buffer[2, "float32"], + C: T.Buffer[(), "float32"], + ) -> None: + for u_fused_0 in T.thread_binding(1, thread="blockIdx.x"): + for u_fused_1 in T.thread_binding(1, thread="threadIdx.x"): + for i0 in T.serial(2): + with T.block("C"): + k0 = T.axis.reduce(2, i0) + T.reads(A[k0], B[k0]) + T.writes(C[()]) + with T.init(): + C[()] = T.float32(1) + C[()] = T.min(C[()], A[k0] / B[k0]) + + mod = reduction_loop_only + actual = ms.TuneContext( + mod=mod, + target=Target("nvidia/geforce-rtx-3080", host="llvm"), + space_generator=ms.space_generator.PostOrderApply(), + sch_rules=[ + ms.schedule_rule.AutoBind( + max_threadblocks=256, + thread_extents=[32, 64, 128, 256, 512, 1024], + ) + ], + task_name="test", + ).generate_design_space() + check_sketches( + mod, + sketches=actual, + expected_mods=[reduction_loop_only_0], + expected_decisions=[[]], ) - spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) - assert len(spaces) == 1 - check_trace(spaces, expected) def test_cuda_zero_dim_add(): - expected = [ - [ - 'b0 = sch.get_block(name="C", func_name="main")', - "l1 = sch.add_unit_loop(block_or_loop=b0)", - "l2 = sch.fuse(l1, preserve_unit_iters=True)", - "l3, l4 = sch.split(loop=l2, factors=[None, 1], preserve_unit_iters=True)", - 'sch.bind(loop=l3, thread_axis="blockIdx.x")', - 'sch.bind(loop=l4, thread_axis="threadIdx.x")', - ] - ] - target = Target("nvidia/geforce-rtx-3080", host="llvm") - ctx = _create_context( - zero_dim_add, - target=target, - rule=auto_bind(target=target), + @T.prim_func + def zero_dim_add_0( + A: T.Buffer[(), "float32"], + B: T.Buffer[(), "float32"], + C: T.Buffer[(), "float32"], + ) -> None: + for u_fused_0 in T.thread_binding(1, thread="blockIdx.x"): + for u_fused_1 in T.thread_binding(1, thread="threadIdx.x"): + with T.block("C"): + vi = T.axis.spatial(1, 0) + T.reads(A[()], B[()]) + T.writes(C[()]) + C[()] = A[()] + B[()] + + mod = zero_dim_add + actual = ms.TuneContext( + mod=mod, + target=Target("nvidia/geforce-rtx-3080", host="llvm"), + space_generator=ms.space_generator.PostOrderApply(), + sch_rules=[ + ms.schedule_rule.AutoBind( + max_threadblocks=256, + thread_extents=[32, 64, 128, 256, 512, 1024], + ) + ], + task_name="test", + ).generate_design_space() + check_sketches( + mod, + sketches=actual, + expected_mods=[zero_dim_add_0], + expected_decisions=[[]], ) - spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) - assert len(spaces) == 1 - check_trace(spaces, expected) if __name__ == "__main__": diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py index 592d32d624..a0ca47c09a 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py @@ -17,14 +17,12 @@ # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring import tvm -from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply +from tvm import meta_schedule as ms from tvm.meta_schedule.testing import te_workload -from tvm.meta_schedule.testing.schedule_rule import cross_thread_reduction -from tvm.meta_schedule.testing.space_generation import check_trace -from tvm.meta_schedule.tune_context import TuneContext +from tvm.meta_schedule.testing.space_generation import check_sketches from tvm.script import tir as T from tvm.target import Target -from tvm.te.operation import create_prim_func +from tvm.te import create_prim_func @tvm.script.ir_module @@ -59,179 +57,522 @@ class Softmax_mn_after_inline: ) -def _create_context(mod, target, rule) -> TuneContext: - ctx = TuneContext( - mod=mod, - target=target, - space_generator=PostOrderApply(), - sch_rules=[rule], - task_name="test", - ) - return ctx +def test_gpu_softmax_mn(): + @T.prim_func + def softmax_mn_0( + A: T.Buffer[(256, 256), "float32"], + T_softmax_norm: T.Buffer[(256, 256), "float32"], + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + T_softmax_maxelem = T.alloc_buffer([256], dtype="float32") + T_softmax_exp = T.alloc_buffer([256, 256], dtype="float32") + T_softmax_expsum = T.alloc_buffer([256], dtype="float32") + for i0, i1 in T.grid(256, 256): + with T.block("T_softmax_maxelem"): + i0_1, k = T.axis.remap("SR", [i0, i1]) + T.reads(A[i0_1, k]) + T.writes(T_softmax_maxelem[i0_1]) + with T.init(): + T_softmax_maxelem[i0_1] = T.float32(-3.4028234663852886e38) + T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k]) + for i0, i1 in T.grid(256, 256): + with T.block("T_softmax_exp"): + i0_2, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(A[i0_2, i1_1], T_softmax_maxelem[i0_2]) + T.writes(T_softmax_exp[i0_2, i1_1]) + T_softmax_exp[i0_2, i1_1] = T.exp( + A[i0_2, i1_1] - T_softmax_maxelem[i0_2], dtype="float32" + ) + for i0_3, i1 in T.grid(256, 256): + with T.block("T_softmax_expsum"): + i0_4, k = T.axis.remap("SR", [i0_3, i1]) + T.reads(T_softmax_exp[i0_4, k]) + T.writes(T_softmax_expsum[i0_4]) + with T.init(): + T_softmax_expsum[i0_4] = T.float32(0) + T_softmax_expsum[i0_4] = T_softmax_expsum[i0_4] + T_softmax_exp[i0_4, k] + for i0_5, i1 in T.grid(256, 256): + with T.block("T_softmax_norm"): + i0_6, i1_2 = T.axis.remap("SS", [i0_5, i1]) + T.reads(T_softmax_exp[i0_6, i1_2], T_softmax_expsum[i0_6]) + T.writes(T_softmax_norm[i0_6, i1_2]) + T.block_attr({"axis": 1}) + T_softmax_norm[i0_6, i1_2] = T_softmax_exp[i0_6, i1_2] / T_softmax_expsum[i0_6] + @T.prim_func + def softmax_mn_1( + A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"] + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared") + T_softmax_exp = T.alloc_buffer([256, 256], dtype="float32") + T_softmax_expsum = T.alloc_buffer([256], dtype="float32") + for i0 in T.serial(256): + for ax0, ax1_0 in T.grid(1, 1): + for ax1_1 in T.thread_binding(512, thread="threadIdx.x"): + with T.block("T_softmax_maxelem"): + T.where(ax1_0 * 512 + ax1_1 < 256) + i0_1 = T.axis.spatial(256, ax0 + i0) + k = T.axis.reduce(256, ax1_0 * 512 + ax1_1) + T.reads(A[i0_1, k]) + T.writes(T_softmax_maxelem_shared[i0_1]) + with T.init(): + T_softmax_maxelem_shared[i0_1] = T.float32(-3.4028234663852886e38) + T_softmax_maxelem_shared[i0_1] = T.max( + T_softmax_maxelem_shared[i0_1], A[i0_1, k] + ) + for i1_0 in T.serial(1): + for i1_1 in T.thread_binding(512, thread="threadIdx.x"): + with T.block("T_softmax_exp"): + T.where(i1_0 * 512 + i1_1 < 256) + i0_2 = T.axis.spatial(256, i0) + i1 = T.axis.spatial(256, i1_0 * 512 + i1_1) + T.reads(A[i0_2, i1], T_softmax_maxelem_shared[i0_2]) + T.writes(T_softmax_exp[i0_2, i1]) + T_softmax_exp[i0_2, i1] = T.exp( + A[i0_2, i1] - T_softmax_maxelem_shared[i0_2], dtype="float32" + ) + for i0_3, i1 in T.grid(256, 256): + with T.block("T_softmax_expsum"): + i0_4, k = T.axis.remap("SR", [i0_3, i1]) + T.reads(T_softmax_exp[i0_4, k]) + T.writes(T_softmax_expsum[i0_4]) + with T.init(): + T_softmax_expsum[i0_4] = T.float32(0) + T_softmax_expsum[i0_4] = T_softmax_expsum[i0_4] + T_softmax_exp[i0_4, k] + for i0_5, i1 in T.grid(256, 256): + with T.block("T_softmax_norm"): + i0_6, i1_2 = T.axis.remap("SS", [i0_5, i1]) + T.reads(T_softmax_exp[i0_6, i1_2], T_softmax_expsum[i0_6]) + T.writes(T_softmax_norm[i0_6, i1_2]) + T.block_attr({"axis": 1}) + T_softmax_norm[i0_6, i1_2] = T_softmax_exp[i0_6, i1_2] / T_softmax_expsum[i0_6] -def test_gpu_softmax_mn(): - expected = [ - [], - [ - 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")', - "b1, = sch.get_consumers(block=b0)", - "l2, l3 = sch.get_loops(block=b1)", - "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", - "l5, l6 = sch.split(loop=l3, factors=[None, v4], preserve_unit_iters=True)", - 'sch.bind(loop=l6, thread_axis="threadIdx.x")', - "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True, index=-1)", - 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', - "l7, l8, l9 = sch.get_loops(block=b0)", - "l10, l11 = sch.split(loop=l9, factors=[None, v4], preserve_unit_iters=True)", - 'sch.bind(loop=l11, thread_axis="threadIdx.x")', - ], - [ - 'b0 = sch.get_block(name="T_softmax_expsum", func_name="main")', - "b1, = sch.get_consumers(block=b0)", - "l2, l3 = sch.get_loops(block=b1)", - "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", - "l5, l6 = sch.split(loop=l3, factors=[None, v4], preserve_unit_iters=True)", - 'sch.bind(loop=l6, thread_axis="threadIdx.x")', - "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True, index=-1)", - 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', - "l7, l8, l9 = sch.get_loops(block=b0)", - "l10, l11 = sch.split(loop=l9, factors=[None, v4], preserve_unit_iters=True)", - 'sch.bind(loop=l11, thread_axis="threadIdx.x")', - ], - [ - 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")', - 'b1 = sch.get_block(name="T_softmax_expsum", func_name="main")', - "b2, = sch.get_consumers(block=b1)", - "l3, l4 = sch.get_loops(block=b2)", - "v5 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", - "l6, l7 = sch.split(loop=l4, factors=[None, v5], preserve_unit_iters=True)", - 'sch.bind(loop=l7, thread_axis="threadIdx.x")', - "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True, index=-1)", - 'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")', - "l8, l9, l10 = sch.get_loops(block=b1)", - "l11, l12 = sch.split(loop=l10, factors=[None, v5], preserve_unit_iters=True)", - 'sch.bind(loop=l12, thread_axis="threadIdx.x")', - "b13, = sch.get_consumers(block=b0)", - "l14, l15 = sch.get_loops(block=b13)", - "v16 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", - "l17, l18 = sch.split(loop=l15, factors=[None, v16], preserve_unit_iters=True)", - 'sch.bind(loop=l18, thread_axis="threadIdx.x")', - "sch.compute_at(block=b0, loop=l14, preserve_unit_loops=True, index=-1)", - 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', - "l19, l20, l21 = sch.get_loops(block=b0)", - "l22, l23 = sch.split(loop=l21, factors=[None, v16], preserve_unit_iters=True)", - 'sch.bind(loop=l23, thread_axis="threadIdx.x")', - ], + @T.prim_func + def softmax_mn_2( + A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"] + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + T_softmax_maxelem = T.alloc_buffer([256], dtype="float32") + T_softmax_exp = T.alloc_buffer([256, 256], dtype="float32") + T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared") + for i0, i1 in T.grid(256, 256): + with T.block("T_softmax_maxelem"): + i0_1, k = T.axis.remap("SR", [i0, i1]) + T.reads(A[i0_1, k]) + T.writes(T_softmax_maxelem[i0_1]) + with T.init(): + T_softmax_maxelem[i0_1] = T.float32(-3.4028234663852886e38) + T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k]) + for i0, i1 in T.grid(256, 256): + with T.block("T_softmax_exp"): + i0_2, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(A[i0_2, i1_1], T_softmax_maxelem[i0_2]) + T.writes(T_softmax_exp[i0_2, i1_1]) + T_softmax_exp[i0_2, i1_1] = T.exp( + A[i0_2, i1_1] - T_softmax_maxelem[i0_2], dtype="float32" + ) + for i0_3 in T.serial(256): + for ax0, ax1_0 in T.grid(1, 32): + for ax1_1 in T.thread_binding(8, thread="threadIdx.x"): + with T.block("T_softmax_expsum"): + i0_4 = T.axis.spatial(256, ax0 + i0_3) + k = T.axis.reduce(256, ax1_0 * 8 + ax1_1) + T.reads(T_softmax_exp[i0_4, k]) + T.writes(T_softmax_expsum_shared[i0_4]) + with T.init(): + T_softmax_expsum_shared[i0_4] = T.float32(0) + T_softmax_expsum_shared[i0_4] = ( + T_softmax_expsum_shared[i0_4] + T_softmax_exp[i0_4, k] + ) + for i1_0 in T.serial(32): + for i1_1_1 in T.thread_binding(8, thread="threadIdx.x"): + with T.block("T_softmax_norm"): + i0_5 = T.axis.spatial(256, i0_3) + i1 = T.axis.spatial(256, i1_0 * 8 + i1_1_1) + T.reads(T_softmax_exp[i0_5, i1], T_softmax_expsum_shared[i0_5]) + T.writes(T_softmax_norm[i0_5, i1]) + T.block_attr({"axis": 1}) + T_softmax_norm[i0_5, i1] = ( + T_softmax_exp[i0_5, i1] / T_softmax_expsum_shared[i0_5] + ) + + @T.prim_func + def softmax_mn_3( + A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"] + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared") + T_softmax_exp = T.alloc_buffer([256, 256], dtype="float32") + T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared") + for i0 in T.serial(256): + for ax0, ax1_0 in T.grid(1, 1): + for ax1_1 in T.thread_binding(512, thread="threadIdx.x"): + with T.block("T_softmax_maxelem"): + T.where(ax1_0 * 512 + ax1_1 < 256) + i0_1 = T.axis.spatial(256, ax0 + i0) + k = T.axis.reduce(256, ax1_0 * 512 + ax1_1) + T.reads(A[i0_1, k]) + T.writes(T_softmax_maxelem_shared[i0_1]) + with T.init(): + T_softmax_maxelem_shared[i0_1] = T.float32(-3.4028234663852886e38) + T_softmax_maxelem_shared[i0_1] = T.max( + T_softmax_maxelem_shared[i0_1], A[i0_1, k] + ) + for i1_0 in T.serial(1): + for i1_1 in T.thread_binding(512, thread="threadIdx.x"): + with T.block("T_softmax_exp"): + T.where(i1_0 * 512 + i1_1 < 256) + i0_2 = T.axis.spatial(256, i0) + i1 = T.axis.spatial(256, i1_0 * 512 + i1_1) + T.reads(A[i0_2, i1], T_softmax_maxelem_shared[i0_2]) + T.writes(T_softmax_exp[i0_2, i1]) + T_softmax_exp[i0_2, i1] = T.exp( + A[i0_2, i1] - T_softmax_maxelem_shared[i0_2], dtype="float32" + ) + for i0_3 in T.serial(256): + for ax0, ax1_0 in T.grid(1, 32): + for ax1_1 in T.thread_binding(8, thread="threadIdx.x"): + with T.block("T_softmax_expsum"): + i0_4 = T.axis.spatial(256, ax0 + i0_3) + k = T.axis.reduce(256, ax1_0 * 8 + ax1_1) + T.reads(T_softmax_exp[i0_4, k]) + T.writes(T_softmax_expsum_shared[i0_4]) + with T.init(): + T_softmax_expsum_shared[i0_4] = T.float32(0) + T_softmax_expsum_shared[i0_4] = ( + T_softmax_expsum_shared[i0_4] + T_softmax_exp[i0_4, k] + ) + for i1_0 in T.serial(32): + for i1_1 in T.thread_binding(8, thread="threadIdx.x"): + with T.block("T_softmax_norm"): + i0_5 = T.axis.spatial(256, i0_3) + i1 = T.axis.spatial(256, i1_0 * 8 + i1_1) + T.reads(T_softmax_exp[i0_5, i1], T_softmax_expsum_shared[i0_5]) + T.writes(T_softmax_norm[i0_5, i1]) + T.block_attr({"axis": 1}) + T_softmax_norm[i0_5, i1] = ( + T_softmax_exp[i0_5, i1] / T_softmax_expsum_shared[i0_5] + ) + + decision_0 = [] # type: ignore + decision_1 = [ + ("SampleCategorical", 7), + ] + decision_2 = [ + ("SampleCategorical", 1), + ] + decision_3 = [ + ("SampleCategorical", 1), + ("SampleCategorical", 7), ] - target = Target("nvidia/geforce-rtx-3090", host="llvm") - ctx = _create_context( - create_prim_func( - te_workload.softmax_mn( - n=256, - m=256, - ) - ), - target=target, - rule=cross_thread_reduction(target=target), + mod = create_prim_func(te_workload.softmax_mn(n=256, m=256)) + actual = ms.TuneContext( + mod=mod, + target=Target("nvidia/geforce-rtx-3090", host="llvm"), + space_generator=ms.space_generator.PostOrderApply(), + sch_rules=[ + ms.schedule_rule.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512]) + ], + task_name="test", + ).generate_design_space() + check_sketches( + mod, + sketches=actual, + expected_mods=[softmax_mn_0, softmax_mn_1, softmax_mn_2, softmax_mn_3], + expected_decisions=[decision_0, decision_1, decision_2, decision_3], ) - spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) - assert len(spaces) == 4 - check_trace(spaces, expected) def test_gpu_softmax_mn_after_inline(): - expected = [ - [], - [ - 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")', - "v1 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", - "l2, l3 = sch.get_loops(block=b0)", - "l4, l5 = sch.split(loop=l3, factors=[None, v1], preserve_unit_iters=True)", - 'sch.bind(loop=l5, thread_axis="threadIdx.x")', - ], - [ - 'b0 = sch.get_block(name="T_softmax_expsum", func_name="main")', - "b1, = sch.get_consumers(block=b0)", - "l2, l3 = sch.get_loops(block=b1)", - "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", - "l5, l6 = sch.split(loop=l3, factors=[None, v4], preserve_unit_iters=True)", - 'sch.bind(loop=l6, thread_axis="threadIdx.x")', - "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=True, index=-1)", - 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', - "l7, l8, l9 = sch.get_loops(block=b0)", - "l10, l11 = sch.split(loop=l9, factors=[None, v4], preserve_unit_iters=True)", - 'sch.bind(loop=l11, thread_axis="threadIdx.x")', + @T.prim_func + def softmax_mn_after_inline_0( + A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"] + ) -> None: + T_softmax_maxelem = T.alloc_buffer([256], dtype="float32") + T_softmax_expsum = T.alloc_buffer([256], dtype="float32") + for i0, i1 in T.grid(256, 256): + with T.block("T_softmax_maxelem"): + i0_1, k = T.axis.remap("SR", [i0, i1]) + T.reads(A[i0_1, k]) + T.writes(T_softmax_maxelem[i0_1]) + with T.init(): + T_softmax_maxelem[i0_1] = T.float32(-3.4028234663852886e38) + T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k]) + for i0, i1 in T.grid(256, 256): + with T.block("T_softmax_expsum"): + i0_2, k = T.axis.remap("SR", [i0, i1]) + T.reads(A[i0_2, k], T_softmax_maxelem[i0_2]) + T.writes(T_softmax_expsum[i0_2]) + with T.init(): + T_softmax_expsum[i0_2] = T.float32(0) + T_softmax_expsum[i0_2] = T_softmax_expsum[i0_2] + T.exp( + A[i0_2, k] - T_softmax_maxelem[i0_2], dtype="float32" + ) + for i0_3, i1 in T.grid(256, 256): + with T.block("T_softmax_norm"): + i0_4, i1_1 = T.axis.remap("SS", [i0_3, i1]) + T.reads(A[i0_4, i1_1], T_softmax_maxelem[i0_4], T_softmax_expsum[i0_4]) + T.writes(T_softmax_norm[i0_4, i1_1]) + T.block_attr({"axis": 1}) + T_softmax_norm[i0_4, i1_1] = ( + T.exp(A[i0_4, i1_1] - T_softmax_maxelem[i0_4], dtype="float32") + / T_softmax_expsum[i0_4] + ) + + @T.prim_func + def softmax_mn_after_inline_1( + A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"] + ) -> None: + T_softmax_maxelem = T.alloc_buffer([256], dtype="float32") + T_softmax_expsum = T.alloc_buffer([256], dtype="float32") + for i0, i1_0 in T.grid(256, 4): + for i1_1 in T.thread_binding(64, thread="threadIdx.x"): + with T.block("T_softmax_maxelem"): + i0_1 = T.axis.spatial(256, i0) + k = T.axis.reduce(256, i1_0 * 64 + i1_1) + T.reads(A[i0_1, k]) + T.writes(T_softmax_maxelem[i0_1]) + with T.init(): + T_softmax_maxelem[i0_1] = T.float32(-3.4028234663852886e38) + T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k]) + for i0, i1 in T.grid(256, 256): + with T.block("T_softmax_expsum"): + i0_2, k = T.axis.remap("SR", [i0, i1]) + T.reads(A[i0_2, k], T_softmax_maxelem[i0_2]) + T.writes(T_softmax_expsum[i0_2]) + with T.init(): + T_softmax_expsum[i0_2] = T.float32(0) + T_softmax_expsum[i0_2] = T_softmax_expsum[i0_2] + T.exp( + A[i0_2, k] - T_softmax_maxelem[i0_2], dtype="float32" + ) + for i0_3, i1 in T.grid(256, 256): + with T.block("T_softmax_norm"): + i0_4, i1_1 = T.axis.remap("SS", [i0_3, i1]) + T.reads(A[i0_4, i1_1], T_softmax_maxelem[i0_4], T_softmax_expsum[i0_4]) + T.writes(T_softmax_norm[i0_4, i1_1]) + T.block_attr({"axis": 1}) + T_softmax_norm[i0_4, i1_1] = ( + T.exp(A[i0_4, i1_1] - T_softmax_maxelem[i0_4], dtype="float32") + / T_softmax_expsum[i0_4] + ) + + @T.prim_func + def softmax_mn_after_inline_2( + A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"] + ) -> None: + T_softmax_maxelem = T.alloc_buffer([256], dtype="float32") + T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared") + for i0, i1 in T.grid(256, 256): + with T.block("T_softmax_maxelem"): + i0_1, k = T.axis.remap("SR", [i0, i1]) + T.reads(A[i0_1, k]) + T.writes(T_softmax_maxelem[i0_1]) + with T.init(): + T_softmax_maxelem[i0_1] = T.float32(-3.4028234663852886e38) + T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k]) + for i0_3 in T.serial(256): + for ax0, ax1_0 in T.grid(1, 1): + for ax1_1 in T.thread_binding(512, thread="threadIdx.x"): + with T.block("T_softmax_expsum"): + T.where(ax1_0 * 512 + ax1_1 < 256) + i0_2 = T.axis.spatial(256, ax0 + i0_3) + k = T.axis.reduce(256, ax1_0 * 512 + ax1_1) + T.reads(A[i0_2, k], T_softmax_maxelem[i0_2]) + T.writes(T_softmax_expsum_shared[i0_2]) + with T.init(): + T_softmax_expsum_shared[i0_2] = T.float32(0) + T_softmax_expsum_shared[i0_2] = T_softmax_expsum_shared[i0_2] + T.exp( + A[i0_2, k] - T_softmax_maxelem[i0_2], dtype="float32" + ) + for i1_0 in T.serial(1): + for i1_1 in T.thread_binding(512, thread="threadIdx.x"): + with T.block("T_softmax_norm"): + T.where(i1_0 * 512 + i1_1 < 256) + i0_4 = T.axis.spatial(256, i0_3) + i1_1_1 = T.axis.spatial(256, i1_0 * 512 + i1_1) + T.reads( + A[i0_4, i1_1_1], T_softmax_maxelem[i0_4], T_softmax_expsum_shared[i0_4] + ) + T.writes(T_softmax_norm[i0_4, i1_1_1]) + T.block_attr({"axis": 1}) + T_softmax_norm[i0_4, i1_1_1] = ( + T.exp(A[i0_4, i1_1_1] - T_softmax_maxelem[i0_4], dtype="float32") + / T_softmax_expsum_shared[i0_4] + ) + + @T.prim_func + def softmax_mn_after_inline_3( + A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"] + ) -> None: + T_softmax_maxelem_shared = T.alloc_buffer([256], dtype="float32", scope="shared") + T_softmax_expsum_shared = T.alloc_buffer([256], dtype="float32", scope="shared") + for i0_3 in T.serial(256): + for ax0, ax1_0 in T.grid(1, 1): + for ax1_1 in T.thread_binding(512, thread="threadIdx.x"): + with T.block("T_softmax_maxelem"): + T.where(ax1_0 * 512 + ax1_1 < 256) + i0_1 = T.axis.spatial(256, ax0 + i0_3) + k = T.axis.reduce(256, ax1_0 * 512 + ax1_1) + T.reads(A[i0_1, k]) + T.writes(T_softmax_maxelem_shared[i0_1]) + with T.init(): + T_softmax_maxelem_shared[i0_1] = T.float32(-3.4028234663852886e38) + T_softmax_maxelem_shared[i0_1] = T.max( + T_softmax_maxelem_shared[i0_1], A[i0_1, k] + ) + for ax0, ax1_0 in T.grid(1, 1): + for ax1_1 in T.thread_binding(512, thread="threadIdx.x"): + with T.block("T_softmax_expsum"): + T.where(ax1_0 * 512 + ax1_1 < 256) + i0_2 = T.axis.spatial(256, ax0 + i0_3) + k = T.axis.reduce(256, ax1_0 * 512 + ax1_1) + T.reads(A[i0_2, k], T_softmax_maxelem_shared[i0_2]) + T.writes(T_softmax_expsum_shared[i0_2]) + with T.init(): + T_softmax_expsum_shared[i0_2] = T.float32(0) + T_softmax_expsum_shared[i0_2] = T_softmax_expsum_shared[i0_2] + T.exp( + A[i0_2, k] - T_softmax_maxelem_shared[i0_2], dtype="float32" + ) + for i1_0 in T.serial(1): + for i1_1 in T.thread_binding(512, thread="threadIdx.x"): + with T.block("T_softmax_norm"): + T.where(i1_0 * 512 + i1_1 < 256) + i0_4 = T.axis.spatial(256, i0_3) + i1_1_1 = T.axis.spatial(256, i1_0 * 512 + i1_1) + T.reads( + A[i0_4, i1_1_1], + T_softmax_maxelem_shared[i0_4], + T_softmax_expsum_shared[i0_4], + ) + T.writes(T_softmax_norm[i0_4, i1_1_1]) + T.block_attr({"axis": 1}) + T_softmax_norm[i0_4, i1_1_1] = ( + T.exp(A[i0_4, i1_1_1] - T_softmax_maxelem_shared[i0_4], dtype="float32") + / T_softmax_expsum_shared[i0_4] + ) + + decision_0 = [] # type: ignore + decision_1 = [ + ("SampleCategorical", 4), + ] + decision_2 = [ + ("SampleCategorical", 7), + ] + decision_3 = [ + ("SampleCategorical", 7), + ("SampleCategorical", 0), + ] + + mod = Softmax_mn_after_inline + actual = ms.TuneContext( + mod=mod, + target=Target("nvidia/geforce-rtx-3090", host="llvm"), + space_generator=ms.space_generator.PostOrderApply(), + sch_rules=[ + ms.schedule_rule.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512]) ], - [ - 'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")', - 'b1 = sch.get_block(name="T_softmax_expsum", func_name="main")', - "b2, = sch.get_consumers(block=b1)", - "l3, l4 = sch.get_loops(block=b2)", - "v5 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", - "l6, l7 = sch.split(loop=l4, factors=[None, v5], preserve_unit_iters=True)", - 'sch.bind(loop=l7, thread_axis="threadIdx.x")', - "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=True, index=-1)", - 'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")', - "l8, l9, l10 = sch.get_loops(block=b1)", - "l11, l12 = sch.split(loop=l10, factors=[None, v5], preserve_unit_iters=True)", - 'sch.bind(loop=l12, thread_axis="threadIdx.x")', - "b13, b14 = sch.get_consumers(block=b0)", - "l15, l16, l17, l18 = sch.get_loops(block=b13)", - "sch.compute_at(block=b0, loop=l15, preserve_unit_loops=True, index=-1)", - 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', - "l19, l20, l21 = sch.get_loops(block=b0)", - "l22, l23 = sch.split(loop=l21, factors=[None, v5], preserve_unit_iters=True)", - 'sch.bind(loop=l23, thread_axis="threadIdx.x")', + task_name="test", + ).generate_design_space() + check_sketches( + mod, + sketches=actual, + expected_mods=[ + softmax_mn_after_inline_0, + softmax_mn_after_inline_1, + softmax_mn_after_inline_2, + softmax_mn_after_inline_3, ], - ] - target = Target("nvidia/geforce-rtx-3090", host="llvm") - ctx = _create_context( - mod=Softmax_mn_after_inline, - target=target, - rule=cross_thread_reduction(target=target), + expected_decisions=[decision_0, decision_1, decision_2, decision_3], ) - spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) - assert len(spaces) == 4 - check_trace(spaces, expected) def test_gpu_batch_norm_bmn(): - expected = [ - [], - [ - 'b0 = sch.get_block(name="C", func_name="main")', - "b1, = sch.get_consumers(block=b0)", - "l2, = sch.get_loops(block=b1)", - "v3 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])", - "l4, l5 = sch.split(loop=l2, factors=[None, v3], preserve_unit_iters=True)", - 'sch.bind(loop=l5, thread_axis="threadIdx.x")', - "sch.compute_at(block=b0, loop=l4, preserve_unit_loops=True, index=-1)", - 'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")', - "l6, l7, l8, l9 = sch.get_loops(block=b0)", - "l10 = sch.fuse(l8, l9, preserve_unit_iters=True)", - "l11, l12 = sch.split(loop=l10, factors=[None, v3], preserve_unit_iters=True)", - 'sch.bind(loop=l12, thread_axis="threadIdx.x")', - ], + @T.prim_func + def batch_norm_bmn_0(A: T.Buffer[(1, 512, 512), "float32"], D: T.Buffer[1, "float32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + C = T.alloc_buffer([1], dtype="float32") + for i0, i1, i2 in T.grid(1, 512, 512): + with T.block("C"): + b, i, j = T.axis.remap("SRR", [i0, i1, i2]) + T.reads(A[b, i, j]) + T.writes(C[b]) + with T.init(): + C[b] = T.float32(0) + C[b] = C[b] + A[b, i, j] * A[b, i, j] + for i0 in T.serial(1): + with T.block("D"): + b = T.axis.spatial(1, i0) + T.reads(C[b]) + T.writes(D[b]) + D[b] = T.sqrt(C[b], dtype="float32") + + @T.prim_func + def batch_norm_bmn_1(A: T.Buffer[(1, 512, 512), "float32"], D: T.Buffer[1, "float32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + C_shared = T.alloc_buffer([1], dtype="float32", scope="shared") + for i0_0 in T.serial(1): + for ax0, ax1_ax2_fused_0 in T.grid(1, 1024): + for ax1_ax2_fused_1 in T.thread_binding(256, thread="threadIdx.x"): + with T.block("C"): + b = T.axis.spatial(1, ax0) + i = T.axis.reduce(512, (ax1_ax2_fused_0 * 256 + ax1_ax2_fused_1) // 512) + j = T.axis.reduce(512, (ax1_ax2_fused_0 * 256 + ax1_ax2_fused_1) % 512) + T.reads(A[b, i, j]) + T.writes(C_shared[b]) + with T.init(): + C_shared[b] = T.float32(0) + C_shared[b] = C_shared[b] + A[b, i, j] * A[b, i, j] + for i0_1 in T.thread_binding(256, thread="threadIdx.x"): + with T.block("D"): + T.where(i0_0 * 256 + i0_1 < 1) + b = T.axis.spatial(1, i0_0 * 256 + i0_1) + T.reads(C_shared[b]) + T.writes(D[b]) + D[b] = T.sqrt(C_shared[b], dtype="float32") + + decision_0 = [] # type: ignore + decision_1 = [ + ("SampleCategorical", 6), ] - target = Target("nvidia/geforce-rtx-3090", host="llvm") - ctx = _create_context( - create_prim_func( - te_workload.norm_bmn( - B=1, - M=512, - N=512, - ) - ), - target=target, - rule=cross_thread_reduction(target=target), + + mod = create_prim_func(te_workload.norm_bmn(B=1, M=512, N=512)) + actual = ms.TuneContext( + mod=mod, + target=Target("nvidia/geforce-rtx-3090", host="llvm"), + space_generator=ms.space_generator.PostOrderApply(), + sch_rules=[ + ms.schedule_rule.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512]) + ], + task_name="test", + ).generate_design_space() + check_sketches( + mod, + sketches=actual, + expected_mods=[batch_norm_bmn_0, batch_norm_bmn_1], + expected_decisions=[decision_0, decision_1], ) - spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) - assert len(spaces) == 2 - check_trace(spaces, expected) if __name__ == "__main__": - # test_gpu_softmax_mn() - # test_gpu_softmax_mn_after_inline() + test_gpu_softmax_mn() + test_gpu_softmax_mn_after_inline() test_gpu_batch_norm_bmn() diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py b/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py index 02b55350b7..8076fcaa8b 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py @@ -17,10 +17,7 @@ # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring import tvm from tvm import meta_schedule as ms -from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply -from tvm.meta_schedule.testing.schedule_rule import parallel_vectorize_unroll -from tvm.meta_schedule.testing.space_generation import check_trace -from tvm.meta_schedule.tune_context import TuneContext +from tvm.meta_schedule.testing.space_generation import check_sketches from tvm.script import tir as T from tvm.target import Target @@ -68,10 +65,7 @@ class ParallelizeVectorizeUnroll: class PureSpatial: @T.prim_func def main(placeholder: T.Buffer[(1, 13, 13, 3, 85), "float32"], placeholder_1: T.Buffer[(1, 26, 26, 3, 85), "float32"], placeholder_2: T.Buffer[(1, 52, 52, 3, 85), "float32"], T_expand_dims: T.Buffer[(1, 80, 10647), "float32"]) -> None: - # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body - # with T.block("root") T_strided_slice_with_axes = T.alloc_buffer([1, 52, 52, 3, 1], dtype="float32") T_sigmoid = T.alloc_buffer([1, 52, 52, 3, 1], dtype="float32") T_strided_slice_with_axes_1 = T.alloc_buffer([1, 52, 52, 3, 80], dtype="float32") @@ -224,55 +218,80 @@ class PureSpatial: # fmt: on -def _create_context(mod, target, rule): - ctx = TuneContext( - mod=mod, - target=target, - space_generator=PostOrderApply(), - sch_rules=[rule], - task_name="test", - ) - return ctx - - def test_parallel_vectorize_unroll(): - expected = [ - [ - 'b0 = sch.get_block(name="root", func_name="main")', - 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.parallel", ann_val=512)', - 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.vectorize", ann_val=32)', - "v1 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25])", - 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.unroll_explicit", ann_val=v1)', - ] + @T.prim_func + def Matmul_0( + A: T.Buffer[(1024, 1024), "float32"], + B: T.Buffer[(1024, 1024), "float32"], + C: T.Buffer[(1024, 1024), "float32"], + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "main"}) + # body + with T.block("root"): + T.reads() + T.writes() + T.block_attr( + { + "meta_schedule.parallel": 512, + "meta_schedule.unroll_explicit": 16, + "meta_schedule.vectorize": 32, + } + ) + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + T.reads(A[vi, vk], B[vk, vj]) + T.writes(C[vi, vj]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + decision_0 = [ + ("SampleCategorical", 1), ] + mod = Matmul - target = Target("llvm --num-cores=32") - ctx = _create_context( + actual = ms.TuneContext( mod=mod, - target=target, - rule=parallel_vectorize_unroll(target=target), + target=Target("llvm --num-cores=32"), + space_generator=ms.space_generator.PostOrderApply(), + sch_rules=[ + ms.schedule_rule.ParallelizeVectorizeUnroll( + max_jobs_per_core=16, + max_vectorize_extent=32, + unroll_max_steps=[0, 16, 64, 512], + unroll_explicit=True, + ), + ], + task_name="test", + ).generate_design_space() + check_sketches( + mod, + sketches=actual, + expected_mods=[Matmul_0], + expected_decisions=[decision_0], ) - spaces = ctx.space_generator.generate_design_space(mod=mod) - assert len(spaces) == 1 - check_trace(spaces, expected) def test_parallel_vectorize_unroll_spatial(): mod = PureSpatial - target = Target("llvm --num-cores=32") - ctx = _create_context( + actual = ms.TuneContext( mod=mod, - target=target, - rule=ms.schedule_rule.ParallelizeVectorizeUnroll( - max_jobs_per_core=-1, - max_vectorize_extent=-1, - unroll_max_steps=[1, 2, 4, 8, 16, 32, 64], - unroll_explicit=True, - ), - ) - spaces = ctx.space_generator.generate_design_space(mod=mod) - assert len(spaces) == 1 - trace = spaces[0].trace.simplified(remove_postproc=True) + target=Target("llvm --num-cores=32"), + space_generator=ms.space_generator.PostOrderApply(), + sch_rules=[ + ms.schedule_rule.ParallelizeVectorizeUnroll( + max_jobs_per_core=-1, + max_vectorize_extent=-1, + unroll_max_steps=[0, 16, 64, 512], + unroll_explicit=True, + ), + ], + task_name="test", + ).generate_design_space() + assert len(actual) == 1 + trace = actual[0].trace.simplified(remove_postproc=True) assert not trace.insts diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py b/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py index c951a5adf3..fc52aa199c 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_random_compute_location.py @@ -16,10 +16,8 @@ # under the License. # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring import tvm -from tvm.meta_schedule.schedule_rule import RandomComputeLocation -from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply -from tvm.meta_schedule.testing.space_generation import check_trace -from tvm.meta_schedule.tune_context import TuneContext +from tvm import meta_schedule as ms +from tvm.meta_schedule.testing.space_generation import check_sketches from tvm.script import tir as T from tvm.target import Target @@ -55,35 +53,53 @@ class Add: # fmt: on -def _create_context(mod, target, rule): - ctx = TuneContext( - mod=mod, - target=target, - space_generator=PostOrderApply(), - sch_rules=[rule], - task_name="test", - ) - return ctx - - def test_random_compute_location(): - expected = [ - [ - 'b0 = sch.get_block(name="move", func_name="main")', - "l1 = sch.sample_compute_location(block=b0)", - "sch.compute_at(block=b0, loop=l1, preserve_unit_loops=True, index=-1)", - ] + @T.prim_func + def add_0( + A: T.Buffer[(2048, 2048, 2048), "float32"], + B: T.Buffer[(2048, 2048, 2048), "float32"], + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "main"}) + # body + # with T.block("root") + A_cached = T.alloc_buffer([2048, 2048, 2048], dtype="float32") + for i0, j0, i1, j1, k0, i2 in T.grid(128, 64, 4, 4, 64, 4): + for ax0, ax1, ax2 in T.grid(1, 8, 32): + with T.block("move"): + vi = T.axis.spatial(2048, i0 * 16 + i1 * 4 + i2 + ax0) + vj = T.axis.spatial(2048, j0 * 32 + j1 * 8 + ax1) + vk = T.axis.spatial(2048, k0 * 32 + ax2) + T.reads(A[vi, vj, vk]) + T.writes(A_cached[vi, vj, vk]) + A_cached[vi, vj, vk] = A[vi, vj, vk] + for j2, k1 in T.grid(8, 32): + with T.block("add"): + vi = T.axis.spatial(2048, i0 * 16 + i1 * 4 + i2) + vj = T.axis.spatial(2048, j0 * 32 + j1 * 8 + j2) + vk = T.axis.spatial(2048, k0 * 32 + k1) + T.reads(A_cached[vi, vj, vk]) + T.writes(B[vi, vj, vk]) + B[vi, vj, vk] = A_cached[vi, vj, vk] + T.float32(1) + + decision_0 = [ + ("SampleComputeLocation", 5), ] + mod = Add - target = Target("llvm") - ctx = _create_context( + actual = ms.TuneContext( mod=mod, - target=target, - rule=RandomComputeLocation(), + target=Target("llvm"), + space_generator=ms.space_generator.PostOrderApply(), + sch_rules=[ms.schedule_rule.RandomComputeLocation()], + task_name="test", + ).generate_design_space() + check_sketches( + mod, + sketches=actual, + expected_mods=[add_0], + expected_decisions=[decision_0], ) - spaces = ctx.space_generator.generate_design_space(mod=mod) - assert len(spaces) == 1 - check_trace(spaces, expected) if __name__ == "__main__":