This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new ee319d9d23 [MetaSchedule] Integration test for CUDA AutoTensorization (#12142)
ee319d9d23 is described below

commit ee319d9d23c80091da9c4fb764b1e6d49d462714
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Thu Jul 28 02:03:27 2022 -0700

    [MetaSchedule] Integration test for CUDA AutoTensorization (#12142)

    * [MetaSchedule] Integration test for CUDA AutoTensorization

    * cleanup

    * fix
---
 python/tvm/meta_schedule/default_config.py         | 52 +++++++++++++++
 src/meta_schedule/schedule_rule/auto_bind.cc       |  3 +
 .../test_meta_schedule_auto_tensorize.py           | 74 ++++++++++++++++++++--
 3 files changed, 123 insertions(+), 6 deletions(-)

diff --git a/python/tvm/meta_schedule/default_config.py b/python/tvm/meta_schedule/default_config.py
index e99dd1383a..dc021e1731 100644
--- a/python/tvm/meta_schedule/default_config.py
+++ b/python/tvm/meta_schedule/default_config.py
@@ -349,3 +349,55 @@ class _DefaultCUDA:
             M.MutateUnroll(): 0.08,
             M.MutateThreadBinding(): 0.02,
         }
+
+
+class _DefaultCUDATensorCore:
+    """Default tuning configuration for CUDA TensorCore."""
+
+    @staticmethod
+    def schedule_rules():
+        from tvm.meta_schedule import schedule_rule as M
+        from tvm.tir.tensor_intrin import get_wmma_intrin_group
+
+        return [
+            M.MultiLevelTilingTensorCore(
+                intrin_groups=[
+                    get_wmma_intrin_group(
+                        store_scope="shared",
+                        in_dtype="float16",
+                        out_dtype="float16",
+                        trans_b=trans_b,
+                    )
+                    for trans_b in [False, True]
+                ],
+                structure="SSSRRSRS",
+                tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"],
+                max_innermost_factor=4,
+                vector_load_lens=[1, 2, 3, 4],
+                reuse_read=M.ReuseType(req="must", levels=[4], scope="shared"),
+                reuse_write=M.ReuseType(
+                    req="must",
+                    levels=[2],
+                    scope="shared",
+                ),
+            ),
+            *_DefaultCUDA.schedule_rules(),
+        ]
+
+    @staticmethod
+    def postprocs() -> List[Postproc]:
+        from tvm.meta_schedule import postproc as M
+
+        return [
+            M.DisallowDynamicLoop(),
+            M.RewriteCooperativeFetch(),
+            M.RewriteUnboundBlock(),
+            M.RewriteParallelVectorizeUnroll(),
+            M.RewriteReductionBlock(),
+            M.RewriteTensorize(),
+            M.VerifyGPUCode(),
+        ]
+
+    @staticmethod
+    def mutator_probs() -> Dict[Mutator, float]:
+        return _DefaultCUDA.mutator_probs()
diff --git a/src/meta_schedule/schedule_rule/auto_bind.cc b/src/meta_schedule/schedule_rule/auto_bind.cc
index a67432ebc5..ff4d26084e 100644
--- a/src/meta_schedule/schedule_rule/auto_bind.cc
+++ b/src/meta_schedule/schedule_rule/auto_bind.cc
@@ -34,6 +34,9 @@ void BindBlockThreadIdx(const tir::Schedule& sch, const tir::BlockRV& block_rv,
   if (block_sref->parent == nullptr) {
     return;
   }
+  if (tir::HasBeenMultiLevelTiled(block_sref)) {
+    return;
+  }
   Array<StmtSRef> loops = tir::GetLoops(block_sref);
   int n = loops.size();
   int i_block_idx = -1;
diff --git a/tests/python/integration/test_meta_schedule_auto_tensorize.py b/tests/python/integration/test_meta_schedule_auto_tensorize.py
index b855dc6fa0..b1525df10e 100644
--- a/tests/python/integration/test_meta_schedule_auto_tensorize.py
+++ b/tests/python/integration/test_meta_schedule_auto_tensorize.py
@@ -27,6 +27,7 @@ from tvm import meta_schedule as ms
 from tvm import relay
 from tvm.meta_schedule import ApplyHistoryBest, postproc, schedule_rule
 from tvm.meta_schedule.relay_integration import extract_task_from_relay
+from tvm.meta_schedule.testing import relay_workload
 from tvm.meta_schedule.testing.tlcbench import load_quantized_bert_base
 from tvm.meta_schedule.tune import tune_extracted_tasks
 from tvm.tir.tensor_intrin import AMDGPU_SDOT4_INTRIN, DP4A_INTRIN
@@ -337,10 +338,71 @@ def test_dp4a_bert_int8():
     # _test_bert_int8("rocm", sch_rules_for_sdot4, postprocs_for_dp4a)
 
 
+@tvm.testing.requires_gpu
+@pytest.mark.skip("Slow on CI")
+@pytest.mark.parametrize(
+    ["model_name", "input_shape"],
+    [("bert_base", (8, 128)), ("resnet_18", (16, 3, 224, 224)), ("resnet_50", (16, 3, 224, 224))],
+)
+def test_cuda_tensor_core(model_name, input_shape):
+    """Integration tests of auto tensorization with CUDA tensor core"""
+    target = tvm.target.Target("nvidia/geforce-rtx-3070")
+    dev = tvm.cuda()
+    if model_name.startswith("bert"):
+        data = tvm.nd.array(np.random.randint(0, 30521, size=input_shape), dev)  # embedding size
+    else:
+        data = tvm.nd.array(np.random.randn(*input_shape).astype("float32"), dev)
+
+    mod, params, (input_name, _, _) = relay_workload.get_network(model_name, input_shape)
+    seq = tvm.transform.Sequential(
+        [
+            relay.transform.ToMixedPrecision(),
+        ]
+    )
+
+    with tvm.transform.PassContext(opt_level=3):
+        mod = seq(mod)
+
+    def convert_layout(mod):
+        seq = tvm.transform.Sequential(
+            [relay.transform.ConvertLayout({"nn.conv2d": ["NHWC", "OHWI"]})]
+        )
+        with tvm.transform.PassContext(opt_level=3):
+            mod = seq(mod)
+        return mod
+
+    with tempfile.TemporaryDirectory() as work_dir:
+        with ms.Profiler() as profiler:
+            rt_mod1: tvm.runtime.Module = ms.tune_relay(
+                mod=convert_layout(mod),
+                params=params,
+                target=target,
+                config=ms.TuneConfig(
+                    num_trials_per_iter=32,
+                    max_trials_per_task=200,
+                    max_trials_global=3000,
+                ),
+                sch_rules=ms.default_config._DefaultCUDATensorCore.schedule_rules,
+                postprocs=ms.default_config._DefaultCUDATensorCore.postprocs,
+                work_dir=work_dir,
+            )
+        print(profiler.table())
+
+    # Compile without meta-scheduler for correctness check
+    with tvm.transform.PassContext(opt_level=0):
+        rt_mod2 = relay.build(mod, target=target, params=params)
+
+    def get_output(data, lib):
+        module = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
+        module.set_input(input_name, data)
+        module.run()
+        return module.get_output(0).numpy()
+
+    # Check correctness
+    actual_output = get_output(data, rt_mod1)
+    expected_output = get_output(data, rt_mod2)
+    assert np.allclose(actual_output, expected_output, rtol=1e-2, atol=2e-2)
+
+
 if __name__ == "__main__":
-    test_vnni_dense()
-    test_vnni_conv2d()
-    test_vnni_bert_int8()
-    test_dp4a_dense()
-    test_dp4a_conv2d()
-    test_dp4a_bert_int8()
+    tvm.testing.main()