This is an automated email from the ASF dual-hosted git repository.

yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new cb08f0d57b [TIR][Driver] Use `BindTarget` to specify target for FP8 legalization (#16767)
cb08f0d57b is described below

commit cb08f0d57b5098a6edadad18ee058523087d81f1
Author: Steven S. Lyubomirsky <slyubomir...@octoml.ai>
AuthorDate: Sun Mar 24 20:26:35 2024 -0400

    [TIR][Driver] Use `BindTarget` to specify target for FP8 legalization (#16767)
    
    * Do not pass target explicitly to FP8 legalization, use BindTarget instead
    
    * Lint: Remove unused import
    
    * Add comment on pass ordering
---
 include/tvm/tir/transform.h                            |  8 ++++----
 python/tvm/tir/transform/transform.py                  | 18 +++++-------------
 src/driver/driver_api.cc                               |  8 ++++----
 src/tir/transforms/unsupported_dtype_legalize.cc       |  6 ++++--
 .../tir-transform/test_tir_transform_fp8_legalize.py   | 15 ++++++++-------
 5 files changed, 25 insertions(+), 30 deletions(-)

diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index e219cc6846..98edbeaceb 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -398,7 +398,6 @@ TVM_DLL Pass ForceNarrowIndexToInt32();
 /*!
  * \brief Legalize bf16 compute Ops. Add a cast to fp32
  *   before Ops, then add a cast back to bf16.
- * \param target The target used for checking native bf16 support
  * \return The pass.
  */
 TVM_DLL Pass BF16ComputeLegalize();
@@ -406,11 +405,11 @@ TVM_DLL Pass BF16ComputeLegalize();
 /*!
  * \brief Legalize fp8 compute Ops. Add a cast to fp16/fp32
  *   before Ops, then add a cast back to fp8.
- * \param target The target used for checking native fp8 support
  * \param promote_dtype_str The data type used for type promotion, defaults to float16
+ * \note Must be run after BindTarget, as it relies on target attributes for PrimFuncs
  * \return The pass.
  */
-TVM_DLL Pass FP8ComputeLegalize(Target target, String promote_dtype_str = "float16");
+TVM_DLL Pass FP8ComputeLegalize(String promote_dtype_str = "float16");
 
 /*!
  * \brief Legalize bf16 storage types to u16.
@@ -420,9 +419,10 @@ TVM_DLL Pass BF16StorageLegalize();
 
 /*!
  * \brief Legalize fp8 storage types to u8.
+ * \note Must be run after BindTarget, as it relies on target attributes for PrimFuncs
  * \return The pass.
  */
-TVM_DLL Pass FP8StorageLegalize(Target target);
+TVM_DLL Pass FP8StorageLegalize();
 
 /*!
  * \brief Inline calls to private functions
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index 9f7f92dbed..c2022b9186 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -19,7 +19,7 @@
 
 
 import enum
-from typing import Any, Callable, Optional
+from typing import Callable, Optional
 
 from . import _ffi_api
 from . import function_pass as _fpass
@@ -323,7 +323,7 @@ def BF16ComputeLegalize():
     return _ffi_api.BF16ComputeLegalize()  # type: ignore
 
 
-def FP8ComputeLegalize(target: Any, promote_dtype_str: str = "float32"):
+def FP8ComputeLegalize(promote_dtype_str: str = "float32"):
     """Legalize fp8 compute Ops.
 
     Parameters
@@ -331,15 +331,12 @@ def FP8ComputeLegalize(target: Any, promote_dtype_str: str = "float32"):
     promote_dtype : str
         The data type we promote fp8 to, options: float16/float32.
 
-    target : tvm.target.Target
-        The legalization target
-
     Returns
     -------
     fpass : tvm.transform.Pass
         The result pass
     """
-    return _ffi_api.FP8ComputeLegalize(target, promote_dtype_str)  # type: ignore
+    return _ffi_api.FP8ComputeLegalize(promote_dtype_str)  # type: ignore
 
 
 def BF16StorageLegalize():
@@ -353,20 +350,15 @@ def BF16StorageLegalize():
     return _ffi_api.BF16StorageLegalize()  # type: ignore
 
 
-def FP8StorageLegalize(target: Any):
+def FP8StorageLegalize():
     """Legalize fp8 storage types to u8.
 
-    Parameters
-    ----------
-    target : tvm.target.Target
-        The legalization target
-
     Returns
     -------
     fpass : tvm.transform.Pass
         The result pass
     """
-    return _ffi_api.FP8StorageLegalize(target)  # type: ignore
+    return _ffi_api.FP8StorageLegalize()  # type: ignore
 
 
 def CommonSubexprElimTIR(enable_cse_tir: bool = True, identify_equiv_terms: bool = False):
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index 33b4514e6b..7ea5032fa0 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -569,15 +569,15 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
 
   Array<Pass> mixed_pass_list;
 
-  mixed_pass_list.push_back(tir::transform::FP8ComputeLegalize(target));
+  // FPComputeLegalize uses the target attrs added by BindTarget, so it must come first
+  mixed_pass_list.push_back(tir::transform::BindTarget(target));
+  mixed_pass_list.push_back(tir::transform::FP8ComputeLegalize());
 
   // VerifyVTCMLimit must occur before LowerVtcmAlloc
   mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(target));
   // LowerVtcmAlloc must occur after any transformations that modify memory allocation locations
   mixed_pass_list.push_back(tir::transform::LowerVtcmAlloc());
 
-  mixed_pass_list.push_back(tir::transform::BindTarget(target));
-
   mixed_pass_list.push_back(tir::transform::VerifyMemory());
 
   mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc());
@@ -620,7 +620,7 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
   } else {
     mixed_pass_list.push_back(tir::transform::MakePackedAPI());
   }
-  mixed_pass_list.push_back(tir::transform::FP8StorageLegalize(target));
+  mixed_pass_list.push_back(tir::transform::FP8StorageLegalize());
   mixed_pass_list.push_back(tir::transform::BF16StorageLegalize());
 
   mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch());
diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc
index c037879074..5537c8a409 100644
--- a/src/tir/transforms/unsupported_dtype_legalize.cc
+++ b/src/tir/transforms/unsupported_dtype_legalize.cc
@@ -727,8 +727,9 @@ Pass BF16StorageLegalize() {
 
 
TVM_REGISTER_GLOBAL("tir.transform.BF16StorageLegalize").set_body_typed(BF16StorageLegalize);
 
-Pass FP8ComputeLegalize(Target target, String promote_dtype_str) {
+Pass FP8ComputeLegalize(String promote_dtype_str) {
   auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+    auto target = f->GetAttr<Target>(tvm::attr::kTarget).value();
     if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_fp8")) {
       return f;
     }
@@ -739,8 +740,9 @@ Pass FP8ComputeLegalize(Target target, String promote_dtype_str) {
 
 
TVM_REGISTER_GLOBAL("tir.transform.FP8ComputeLegalize").set_body_typed(FP8ComputeLegalize);
 
-Pass FP8StorageLegalize(Target target) {
+Pass FP8StorageLegalize() {
   auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+    auto target = f->GetAttr<Target>(tvm::attr::kTarget).value();
     if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_fp8")) {
       return f;
     }
diff --git a/tests/python/tir-transform/test_tir_transform_fp8_legalize.py b/tests/python/tir-transform/test_tir_transform_fp8_legalize.py
index 6e44b53d0c..e1f487c572 100644
--- a/tests/python/tir-transform/test_tir_transform_fp8_legalize.py
+++ b/tests/python/tir-transform/test_tir_transform_fp8_legalize.py
@@ -19,6 +19,7 @@ import tvm.script
 import tvm.testing
 from tvm.target import Target
 from tvm.script import tir as T
+from tvm.tir.transform.transform import BindTarget
 
 # pylint: disable=no-member,invalid-name,unused-variable
 
@@ -206,20 +207,20 @@ promote_dtype = tvm.testing.parameter("float16", "float32")
 
 def test_fp8_compute_legalize(dtype, promote_dtype):
     target = Target("cuda")
-    before = get_before(dtype)
-    expected = get_after_compute_legalize(dtype, promote_dtype)
+    before = BindTarget(target)(get_before(dtype))
+    expected = BindTarget(target)(get_after_compute_legalize(dtype, promote_dtype))
     # run the transform twice to ensure we can afford to deal
     # with this repeative optimizations
-    after = tvm.tir.transform.FP8ComputeLegalize(target, promote_dtype)(before)
-    after = tvm.tir.transform.FP8ComputeLegalize(target, promote_dtype)(after)
+    after = tvm.tir.transform.FP8ComputeLegalize(promote_dtype)(before)
+    after = tvm.tir.transform.FP8ComputeLegalize(promote_dtype)(after)
     tvm.ir.assert_structural_equal(after, expected)
 
 
 def test_fp8_storage_legalize(dtype, promote_dtype):
     target = Target("cuda")
-    before = get_after_compute_legalize(dtype, promote_dtype)
-    after = tvm.tir.transform.FP8StorageLegalize(target)(before)
-    expected = get_after_storage_legalize(dtype, promote_dtype)
+    before = BindTarget(target)(get_after_compute_legalize(dtype, promote_dtype))
+    after = tvm.tir.transform.FP8StorageLegalize()(before)
+    expected = BindTarget(target)(get_after_storage_legalize(dtype, promote_dtype))
     tvm.ir.assert_structural_equal(after, expected)
 
 

Reply via email to