This is an automated email from the ASF dual-hosted git repository. yongwww pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new cb08f0d57b [TIR][Driver] Use `BindTarget` to specify target for FP8 legalization (#16767) cb08f0d57b is described below commit cb08f0d57b5098a6edadad18ee058523087d81f1 Author: Steven S. Lyubomirsky <slyubomir...@octoml.ai> AuthorDate: Sun Mar 24 20:26:35 2024 -0400 [TIR][Driver] Use `BindTarget` to specify target for FP8 legalization (#16767) * Do not pass target explicitly to FP8 legalization, use BindTarget instead * Lint: Remove unused import * Add comment on pass ordering --- include/tvm/tir/transform.h | 8 ++++---- python/tvm/tir/transform/transform.py | 18 +++++------------- src/driver/driver_api.cc | 8 ++++---- src/tir/transforms/unsupported_dtype_legalize.cc | 6 ++++-- .../tir-transform/test_tir_transform_fp8_legalize.py | 15 ++++++++------- 5 files changed, 25 insertions(+), 30 deletions(-) diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index e219cc6846..98edbeaceb 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -398,7 +398,6 @@ TVM_DLL Pass ForceNarrowIndexToInt32(); /*! * \brief Legalize bf16 compute Ops. Add a cast to fp32 * before Ops, then add a cast back to bf16. - * \param target The target used for checking native bf16 support * \return The pass. */ TVM_DLL Pass BF16ComputeLegalize(); @@ -406,11 +405,11 @@ TVM_DLL Pass BF16ComputeLegalize(); /*! * \brief Legalize fp8 compute Ops. Add a cast to fp16/fp32 * before Ops, then add a cast back to fp8. - * \param target The target used for checking native fp8 support * \param promote_dtype_str The data type used for type promotion, defaults to float16 + * \note Must be run after BindTarget, as it relies on target attributes for PrimFuncs * \return The pass. */ -TVM_DLL Pass FP8ComputeLegalize(Target target, String promote_dtype_str = "float16"); +TVM_DLL Pass FP8ComputeLegalize(String promote_dtype_str = "float16"); /*! * \brief Legalize bf16 storage types to u16. 
@@ -420,9 +419,10 @@ TVM_DLL Pass BF16StorageLegalize(); /*! * \brief Legalize fp8 storage types to u8. + * \note Must be run after BindTarget, as it relies on target attributes for PrimFuncs * \return The pass. */ -TVM_DLL Pass FP8StorageLegalize(Target target); +TVM_DLL Pass FP8StorageLegalize(); /*! * \brief Inline calls to private functions diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 9f7f92dbed..c2022b9186 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -19,7 +19,7 @@ import enum -from typing import Any, Callable, Optional +from typing import Callable, Optional from . import _ffi_api from . import function_pass as _fpass @@ -323,7 +323,7 @@ def BF16ComputeLegalize(): return _ffi_api.BF16ComputeLegalize() # type: ignore -def FP8ComputeLegalize(target: Any, promote_dtype_str: str = "float32"): +def FP8ComputeLegalize(promote_dtype_str: str = "float32"): """Legalize fp8 compute Ops. Parameters @@ -331,15 +331,12 @@ def FP8ComputeLegalize(target: Any, promote_dtype_str: str = "float32"): promote_dtype : str The data type we promote fp8 to, options: float16/float32. - target : tvm.target.Target - The legalization target - Returns ------- fpass : tvm.transform.Pass The result pass """ - return _ffi_api.FP8ComputeLegalize(target, promote_dtype_str) # type: ignore + return _ffi_api.FP8ComputeLegalize(promote_dtype_str) # type: ignore def BF16StorageLegalize(): @@ -353,20 +350,15 @@ def BF16StorageLegalize(): return _ffi_api.BF16StorageLegalize() # type: ignore -def FP8StorageLegalize(target: Any): +def FP8StorageLegalize(): """Legalize fp8 storage types to u8. 
- Parameters - ---------- - target : tvm.target.Target - The legalization target - Returns ------- fpass : tvm.transform.Pass The result pass """ - return _ffi_api.FP8StorageLegalize(target) # type: ignore + return _ffi_api.FP8StorageLegalize() # type: ignore def CommonSubexprElimTIR(enable_cse_tir: bool = True, identify_equiv_terms: bool = False): diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 33b4514e6b..7ea5032fa0 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -569,15 +569,15 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) Array<Pass> mixed_pass_list; - mixed_pass_list.push_back(tir::transform::FP8ComputeLegalize(target)); + // FP8ComputeLegalize reads the target attrs added by BindTarget, so BindTarget must run first + mixed_pass_list.push_back(tir::transform::BindTarget(target)); + mixed_pass_list.push_back(tir::transform::FP8ComputeLegalize()); // VerifyVTCMLimit must occur before LowerVtcmAlloc mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(target)); // LowerVtcmAlloc must occur after any transformations that modify memory allocation locations mixed_pass_list.push_back(tir::transform::LowerVtcmAlloc()); - mixed_pass_list.push_back(tir::transform::BindTarget(target)); - mixed_pass_list.push_back(tir::transform::VerifyMemory()); mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc()); @@ -620,7 +620,7 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) } else { mixed_pass_list.push_back(tir::transform::MakePackedAPI()); } - mixed_pass_list.push_back(tir::transform::FP8StorageLegalize(target)); + mixed_pass_list.push_back(tir::transform::FP8StorageLegalize()); mixed_pass_list.push_back(tir::transform::BF16StorageLegalize()); mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch()); diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index c037879074..5537c8a409 100644 
--- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -727,8 +727,9 @@ Pass BF16StorageLegalize() { TVM_REGISTER_GLOBAL("tir.transform.BF16StorageLegalize").set_body_typed(BF16StorageLegalize); -Pass FP8ComputeLegalize(Target target, String promote_dtype_str) { +Pass FP8ComputeLegalize(String promote_dtype_str) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto target = f->GetAttr<Target>(tvm::attr::kTarget).value(); if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_fp8")) { return f; } @@ -739,8 +740,9 @@ Pass FP8ComputeLegalize(Target target, String promote_dtype_str) { TVM_REGISTER_GLOBAL("tir.transform.FP8ComputeLegalize").set_body_typed(FP8ComputeLegalize); -Pass FP8StorageLegalize(Target target) { +Pass FP8StorageLegalize() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + auto target = f->GetAttr<Target>(tvm::attr::kTarget).value(); if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_fp8")) { return f; } diff --git a/tests/python/tir-transform/test_tir_transform_fp8_legalize.py b/tests/python/tir-transform/test_tir_transform_fp8_legalize.py index 6e44b53d0c..e1f487c572 100644 --- a/tests/python/tir-transform/test_tir_transform_fp8_legalize.py +++ b/tests/python/tir-transform/test_tir_transform_fp8_legalize.py @@ -19,6 +19,7 @@ import tvm.script import tvm.testing from tvm.target import Target from tvm.script import tir as T +from tvm.tir.transform.transform import BindTarget # pylint: disable=no-member,invalid-name,unused-variable @@ -206,20 +207,20 @@ promote_dtype = tvm.testing.parameter("float16", "float32") def test_fp8_compute_legalize(dtype, promote_dtype): target = Target("cuda") - before = get_before(dtype) - expected = get_after_compute_legalize(dtype, promote_dtype) + before = BindTarget(target)(get_before(dtype)) + expected = BindTarget(target)(get_after_compute_legalize(dtype, promote_dtype)) # run the transform twice to ensure we 
can afford to deal # with this repeative optimizations - after = tvm.tir.transform.FP8ComputeLegalize(target, promote_dtype)(before) - after = tvm.tir.transform.FP8ComputeLegalize(target, promote_dtype)(after) + after = tvm.tir.transform.FP8ComputeLegalize(promote_dtype)(before) + after = tvm.tir.transform.FP8ComputeLegalize(promote_dtype)(after) tvm.ir.assert_structural_equal(after, expected) def test_fp8_storage_legalize(dtype, promote_dtype): target = Target("cuda") - before = get_after_compute_legalize(dtype, promote_dtype) - after = tvm.tir.transform.FP8StorageLegalize(target)(before) - expected = get_after_storage_legalize(dtype, promote_dtype) + before = BindTarget(target)(get_after_compute_legalize(dtype, promote_dtype)) + after = tvm.tir.transform.FP8StorageLegalize()(before) + expected = BindTarget(target)(get_after_storage_legalize(dtype, promote_dtype)) tvm.ir.assert_structural_equal(after, expected)