This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e1f93f361e Fix conflict parameter name promote_dtye in FP8ComputeLegalize (#18334)
e1f93f361e is described below

commit e1f93f361ed80fe8407f7463be503bab656edf42
Author: Qingchao Shen <[email protected]>
AuthorDate: Wed Sep 24 01:10:58 2025 +0800

    Fix conflict parameter name promote_dtye in FP8ComputeLegalize (#18334)
---
 include/tvm/tir/transform.h                      | 4 ++--
 python/tvm/tir/transform/transform.py            | 4 ++--
 src/tir/transforms/unsupported_dtype_legalize.cc | 4 ++--
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index af59db3877..bf100dc49c 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -357,11 +357,11 @@ TVM_DLL Pass BF16ComputeLegalize();
 /*!
  * \brief Legalize fp8 compute Ops. Add a cast to fp16/fp32
  *   before Ops, then add a cast back to fp8.
- * \param promote_dtype_str The data type used for type promotion, defaults to float16
+ * \param promote_dtype The data type used for type promotion, defaults to float16
  * \note Must be run after BindTarget, as it relies on target attributes for PrimFuncs
  * \return The pass.
  */
-TVM_DLL Pass FP8ComputeLegalize(ffi::String promote_dtype_str = "float16");
+TVM_DLL Pass FP8ComputeLegalize(ffi::String promote_dtype = "float16");
 
 /*!
  * \brief Legalize bf16 storage types to u16.
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index de11d30fbc..39105f21a2 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -244,7 +244,7 @@ def BF16ComputeLegalize():
     return _ffi_api.BF16ComputeLegalize()  # type: ignore
 
 
-def FP8ComputeLegalize(promote_dtype_str: str = "float32"):
+def FP8ComputeLegalize(promote_dtype: str = "float32"):
     """Legalize fp8 compute Ops.
 
     Parameters
@@ -257,7 +257,7 @@ def FP8ComputeLegalize(promote_dtype_str: str = "float32"):
     fpass : tvm.transform.Pass
         The result pass
     """
-    return _ffi_api.FP8ComputeLegalize(promote_dtype_str)  # type: ignore
+    return _ffi_api.FP8ComputeLegalize(promote_dtype)  # type: ignore
 
 
 def BF16StorageLegalize():
diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc
index ecdb9883d1..d35caa4db9 100644
--- a/src/tir/transforms/unsupported_dtype_legalize.cc
+++ b/src/tir/transforms/unsupported_dtype_legalize.cc
@@ -780,13 +780,13 @@ TVM_FFI_STATIC_INIT_BLOCK() {
   refl::GlobalDef().def("tir.transform.BF16StorageLegalize", BF16StorageLegalize);
 }
 
-Pass FP8ComputeLegalize(ffi::String promote_dtype_str) {
+Pass FP8ComputeLegalize(ffi::String promote_dtype) {
   auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
     auto target = f->GetAttr<Target>(tvm::attr::kTarget).value();
     if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_fp8")) {
       return f;
     }
-    return FP8ComputeLegalizer(DataType(ffi::StringToDLDataType(promote_dtype_str))).Legalize(f);
+    return FP8ComputeLegalizer(DataType(ffi::StringToDLDataType(promote_dtype))).Legalize(f);
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.FP8ComputeLegalize", {});
 }

Reply via email to