This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e1f93f361e Fix conflicting parameter name promote_dtype in FP8ComputeLegalize (#18334)
e1f93f361e is described below
commit e1f93f361ed80fe8407f7463be503bab656edf42
Author: Qingchao Shen <[email protected]>
AuthorDate: Wed Sep 24 01:10:58 2025 +0800
Fix conflicting parameter name promote_dtype in FP8ComputeLegalize (#18334)
---
include/tvm/tir/transform.h | 4 ++--
python/tvm/tir/transform/transform.py | 4 ++--
src/tir/transforms/unsupported_dtype_legalize.cc | 4 ++--
3 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index af59db3877..bf100dc49c 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -357,11 +357,11 @@ TVM_DLL Pass BF16ComputeLegalize();
/*!
* \brief Legalize fp8 compute Ops. Add a cast to fp16/fp32
* before Ops, then add a cast back to fp8.
- * \param promote_dtype_str The data type used for type promotion, defaults to float16
+ * \param promote_dtype The data type used for type promotion, defaults to float16
* \note Must be run after BindTarget, as it relies on target attributes for PrimFuncs
* \return The pass.
*/
-TVM_DLL Pass FP8ComputeLegalize(ffi::String promote_dtype_str = "float16");
+TVM_DLL Pass FP8ComputeLegalize(ffi::String promote_dtype = "float16");
/*!
* \brief Legalize bf16 storage types to u16.
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index de11d30fbc..39105f21a2 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -244,7 +244,7 @@ def BF16ComputeLegalize():
return _ffi_api.BF16ComputeLegalize() # type: ignore
-def FP8ComputeLegalize(promote_dtype_str: str = "float32"):
+def FP8ComputeLegalize(promote_dtype: str = "float32"):
"""Legalize fp8 compute Ops.
Parameters
@@ -257,7 +257,7 @@ def FP8ComputeLegalize(promote_dtype_str: str = "float32"):
fpass : tvm.transform.Pass
The result pass
"""
- return _ffi_api.FP8ComputeLegalize(promote_dtype_str) # type: ignore
+ return _ffi_api.FP8ComputeLegalize(promote_dtype) # type: ignore
def BF16StorageLegalize():
diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc
index ecdb9883d1..d35caa4db9 100644
--- a/src/tir/transforms/unsupported_dtype_legalize.cc
+++ b/src/tir/transforms/unsupported_dtype_legalize.cc
@@ -780,13 +780,13 @@ TVM_FFI_STATIC_INIT_BLOCK() {
refl::GlobalDef().def("tir.transform.BF16StorageLegalize",
BF16StorageLegalize);
}
-Pass FP8ComputeLegalize(ffi::String promote_dtype_str) {
+Pass FP8ComputeLegalize(ffi::String promote_dtype) {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto target = f->GetAttr<Target>(tvm::attr::kTarget).value();
if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_fp8")) {
return f;
}
- return FP8ComputeLegalizer(DataType(ffi::StringToDLDataType(promote_dtype_str))).Legalize(f);
+ return FP8ComputeLegalizer(DataType(ffi::StringToDLDataType(promote_dtype))).Legalize(f);
};
return CreatePrimFuncPass(pass_func, 0, "tir.FP8ComputeLegalize", {});
}