This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 48f346bb07 [RFC][CodeGen][CUDA]: Gate fast math intrinsic lowering behind target option (#19565)
48f346bb07 is described below
commit 48f346bb0720185dd907121b3b86f49131c101bc
Author: ConvolutedDog <[email protected]>
AuthorDate: Tue May 19 10:32:37 2026 +0800
[RFC][CodeGen][CUDA]: Gate fast math intrinsic lowering behind target option (#19565)
Fix CUDA lowering of standard TIR math intrinsics so that, by default,
they use the precise CUDA math functions instead of the fast-math `__*f`
intrinsics. This fixes the default behavior reported in #19546, where
operators such as `tirx.exp` were lowered to `__expf` even though fast
math had not been explicitly requested.
This change adds a CUDA target attribute, `enable_fast_math`, which
defaults to `false`. When the attribute is unset or false, standard math
intrinsics are lowered through the regular `cuda.FLowerIntrinsic` rules,
for example to `expf`, `logf`, `sinf`, `cosf`, `powf`, and `rsqrtf` for
`float32`. When users explicitly set the attribute to true on the target,
the lowering pass also checks the `cuda.fastmath.FLowerIntrinsic` rules,
giving them priority over the regular CUDA lowering rules.
Users can opt in to fast math by setting the attribute when constructing
a CUDA target, either directly or on top of a target tag:
```py
target = tvm.target.Target({"kind": "cuda", "enable_fast_math": True})
target = tvm.target.Target({
    "tag": "nvidia/nvidia-a100",
    "enable_fast_math": True,
})
```
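The fast-math lowering path currently covers the CUDA math operators
registered with `cuda.fastmath.FLowerIntrinsic`: `tirx.exp`,
`tirx.exp10`, `tirx.log`, `tirx.log2`, `tirx.log10`, `tirx.tan`,
`tirx.cos`, `tirx.sin`, `tirx.tanh`, and `tirx.pow`. For `tirx.tan`, the
fast-math and default rules currently produce the same function (`tanf`
for `float32`). The tests expect the fast-math variants to differ from
the defaults only for `float32`.
`tirx.rsqrt` now also has a regular `cuda.FLowerIntrinsic` rule, so it
maps to the CUDA reciprocal-square-root function instead of being
legalized as `1 / sqrt(x)`.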
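Add CUDA codegen tests in
`tests/python/codegen/test_target_codegen_cuda_fastmath.py`. They check
the lowered IR, the generated CUDA source, and the runtime results of the
operators listed above for `float16`, `bfloat16`, `float32`, and
`float64`, with both default and fast-math targets (`tirx.pow` is skipped
for `bfloat16`). Existing target tests are updated to check that
`enable_fast_math` defaults to false and can be overridden on a tag.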
---
python/tvm/target/detect_target.py | 1 +
python/tvm/target/tag_registry/cuda.py | 5 +-
src/target/cuda/intrin_rule_cuda.cc | 30 ++-
src/target/target_kind.cc | 9 +
src/tirx/transform/lower_intrin.cc | 20 +-
.../codegen/test_target_codegen_cuda_fastmath.py | 298 +++++++++++++++++++++
tests/python/relax/test_frontend_onnx.py | 2 +-
tests/python/relax/test_frontend_onnx_backend.py | 4 +-
tests/python/target/test_target_target.py | 14 +-
9 files changed, 364 insertions(+), 19 deletions(-)
diff --git a/python/tvm/target/detect_target.py
b/python/tvm/target/detect_target.py
index 81accfed12..f7d79ba434 100644
--- a/python/tvm/target/detect_target.py
+++ b/python/tvm/target/detect_target.py
@@ -41,6 +41,7 @@ def _detect_cuda(dev: Device) -> Target:
"max_threads_per_block": dev.max_threads_per_block,
"thread_warp_size": dev.warp_size,
"arch": "sm_" + dev.compute_version.replace(".", ""),
+ "enable_fast_math": False,
}
)
diff --git a/python/tvm/target/tag_registry/cuda.py
b/python/tvm/target/tag_registry/cuda.py
index 6b1bd9e8a8..d3740cb515 100644
--- a/python/tvm/target/tag_registry/cuda.py
+++ b/python/tvm/target/tag_registry/cuda.py
@@ -28,12 +28,14 @@ def _register_cuda_tag(name, arch, shared_mem=49152,
regs=65536, **extra):
"max_threads_per_block": 1024,
"thread_warp_size": 32,
"registers_per_block": regs,
+ # Default to disable fast math
+ "enable_fast_math": False,
}
config.update(extra)
register_tag(name, config)
-def _register_jetson_tag(name, arch, mcpu, num_cores, regs=65536):
+def _register_jetson_tag(name, arch, mcpu, num_cores, regs=65536,
enable_fast_math=False):
register_tag(
name,
{
@@ -49,6 +51,7 @@ def _register_jetson_tag(name, arch, mcpu, num_cores,
regs=65536):
"mcpu": mcpu,
"num-cores": num_cores,
},
+ "enable_fast_math": enable_fast_math,
},
)
diff --git a/src/target/cuda/intrin_rule_cuda.cc
b/src/target/cuda/intrin_rule_cuda.cc
index 39d01cf1b0..89bc154365 100644
--- a/src/target/cuda/intrin_rule_cuda.cc
+++ b/src/target/cuda/intrin_rule_cuda.cc
@@ -176,37 +176,46 @@ TVM_REGISTER_OP("tirx.nearbyint")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
DispatchPureExtern<CUDAMath>);
TVM_REGISTER_OP("tirx.exp")
- .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
DispatchPureExtern<CUDAFastMath>);
+ .set_attr<FLowerIntrinsic>("cuda.fastmath.FLowerIntrinsic",
DispatchPureExtern<CUDAFastMath>)
+ .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
DispatchPureExtern<CUDAMath>);
TVM_REGISTER_OP("tirx.exp2")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
DispatchPureExtern<CUDAMath>);
TVM_REGISTER_OP("tirx.exp10")
- .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
DispatchPureExtern<CUDAFastMath>);
+ .set_attr<FLowerIntrinsic>("cuda.fastmath.FLowerIntrinsic",
DispatchPureExtern<CUDAFastMath>)
+ .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
DispatchPureExtern<CUDAMath>);
TVM_REGISTER_OP("tirx.erf")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
DispatchPureExtern<CUDAMath>);
TVM_REGISTER_OP("tirx.log")
- .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
DispatchPureExtern<CUDAFastMath>);
+ .set_attr<FLowerIntrinsic>("cuda.fastmath.FLowerIntrinsic",
DispatchPureExtern<CUDAFastMath>)
+ .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
DispatchPureExtern<CUDAMath>);
TVM_REGISTER_OP("tirx.log2")
- .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
DispatchPureExtern<CUDAFastMath>);
+ .set_attr<FLowerIntrinsic>("cuda.fastmath.FLowerIntrinsic",
DispatchPureExtern<CUDAFastMath>)
+ .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
DispatchPureExtern<CUDAMath>);
TVM_REGISTER_OP("tirx.log10")
- .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
DispatchPureExtern<CUDAFastMath>);
+ .set_attr<FLowerIntrinsic>("cuda.fastmath.FLowerIntrinsic",
DispatchPureExtern<CUDAFastMath>)
+ .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
DispatchPureExtern<CUDAMath>);
TVM_REGISTER_OP("tirx.tan")
- .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
DispatchPureExtern<CUDAFastMathTan>);
+ // Now the fast math version of tan and the default version of tan are
same.
+ .set_attr<FLowerIntrinsic>("cuda.fastmath.FLowerIntrinsic",
DispatchPureExtern<CUDAFastMathTan>)
+ .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
DispatchPureExtern<CUDAMath>);
TVM_REGISTER_OP("tirx.cos")
- .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
DispatchPureExtern<CUDAFastMath>);
+ .set_attr<FLowerIntrinsic>("cuda.fastmath.FLowerIntrinsic",
DispatchPureExtern<CUDAFastMath>)
+ .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
DispatchPureExtern<CUDAMath>);
TVM_REGISTER_OP("tirx.cosh")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
DispatchPureExtern<CUDAMath>);
TVM_REGISTER_OP("tirx.sin")
- .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
DispatchPureExtern<CUDAFastMath>);
+ .set_attr<FLowerIntrinsic>("cuda.fastmath.FLowerIntrinsic",
DispatchPureExtern<CUDAFastMath>)
+ .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
DispatchPureExtern<CUDAMath>);
TVM_REGISTER_OP("tirx.sinh")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
DispatchPureExtern<CUDAMath>);
@@ -215,12 +224,17 @@ TVM_REGISTER_OP("tirx.atan")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
DispatchPureExtern<CUDAMath>);
TVM_REGISTER_OP("tirx.tanh")
+ .set_attr<FLowerIntrinsic>("cuda.fastmath.FLowerIntrinsic",
DispatchPureExtern<CUDAFastMath>)
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
DispatchPureExtern<CUDAMath>);
TVM_REGISTER_OP("tirx.sqrt")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
DispatchPureExtern<CUDAMath>);
+TVM_REGISTER_OP("tirx.rsqrt")
+ .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
DispatchPureExtern<CUDAMath>);
+
TVM_REGISTER_OP("tirx.pow")
+ .set_attr<FLowerIntrinsic>("cuda.fastmath.FLowerIntrinsic",
DispatchPureExtern<CUDAFastMath>)
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
DispatchPureExtern<CUDAMath>);
TVM_REGISTER_OP("tirx.popcount")
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index 5779b4da0e..d6a8d30c4f 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -188,6 +188,14 @@ ffi::Map<ffi::String, ffi::Any>
UpdateCUDAAttrs(ffi::Map<ffi::String, ffi::Any>
target.Set("arch", ffi::String("sm_") + std::to_string(archInt));
}
}
+ // Update enable_fast_math
+ if (target.count("enable_fast_math")) {
+ // If enable_fast_math has been specified, validate that enable_fast_math
is a bool
+ Downcast<bool>(target.at("enable_fast_math"));
+ } else {
+ // If enable_fast_math has not been specified, default to false
+ target.Set("enable_fast_math", false);
+ }
return target;
}
@@ -372,6 +380,7 @@ TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA)
.add_attr_option<int64_t>("l2_cache_size_bytes")
.add_attr_option<int64_t>("max_num_threads",
refl::DefaultValue(1024)) // TODO(@zxybazh):
deprecate it
+ .add_attr_option<bool>("enable_fast_math")
.set_default_keys({"cuda", "gpu"})
.set_target_canonicalizer(UpdateCUDAAttrs);
diff --git a/src/tirx/transform/lower_intrin.cc
b/src/tirx/transform/lower_intrin.cc
index 981615b0d1..7f4b1aa30b 100644
--- a/src/tirx/transform/lower_intrin.cc
+++ b/src/tirx/transform/lower_intrin.cc
@@ -46,11 +46,21 @@ class IntrinInjecter : public
tvm::arith::IRMutatorWithAnalyzer {
using IRMutatorWithAnalyzer::VisitStmt_;
using FLowerGeneral = ffi::TypedFunction<PrimExpr(PrimExpr)>;
- IntrinInjecter(arith::Analyzer* analyzer, std::string target, std::string
mtriple = "")
- : IRMutatorWithAnalyzer(analyzer) {
+ IntrinInjecter(arith::Analyzer* analyzer, const Target& tgt) :
IRMutatorWithAnalyzer(analyzer) {
+ std::string target = tgt->kind->name;
+ ffi::String mtriple = tgt->GetAttr<ffi::String>("mtriple").value_or("");
+
std::vector<std::string> patterns;
+ // For CUDA targets, we need to add the fast math patterns if
enable_fast_math is true.
+ // The priority of the fast math patterns is higher than the normal
patterns.
+ bool is_fast_math = tgt->GetAttr<bool>("enable_fast_math").value_or(false);
+ if (is_fast_math) {
+ patterns.push_back(target + ".fastmath.FLowerIntrinsic");
+ patterns.push_back(target + ".fastmath.FLegalize");
+ }
patterns.push_back(target + ".FLowerIntrinsic");
patterns.push_back(target + ".FLegalize");
+
bool is_llvm_aarch64 = (mtriple.find("aarch64") != std::string::npos);
if (is_llvm_aarch64) {
patterns.push_back(target + ".aarch64.FLowerIntrinsic");
@@ -354,7 +364,7 @@ class IntrinInjecter : public
tvm::arith::IRMutatorWithAnalyzer {
Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) {
arith::Analyzer analyzer;
- return IntrinInjecter(&analyzer, target)(std::move(stmt));
+ return IntrinInjecter(&analyzer,
Target(ffi::String(target)))(std::move(stmt));
}
namespace transform {
@@ -365,9 +375,7 @@ Pass LowerIntrin() {
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
TVM_FFI_ICHECK(target.defined()) << "LowerIntrin: Require the target
attribute";
arith::Analyzer analyzer;
- auto mtriple = target.value()->GetAttr<ffi::String>("mtriple", "");
- n->body =
- IntrinInjecter(&analyzer, target.value()->kind->name,
mtriple.value())(std::move(n->body));
+ n->body = IntrinInjecter(&analyzer, target.value())(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tirx.LowerIntrin", {});
diff --git a/tests/python/codegen/test_target_codegen_cuda_fastmath.py
b/tests/python/codegen/test_target_codegen_cuda_fastmath.py
new file mode 100644
index 0000000000..84cac4361e
--- /dev/null
+++ b/tests/python/codegen/test_target_codegen_cuda_fastmath.py
@@ -0,0 +1,298 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import re
+from collections.abc import Callable
+from dataclasses import dataclass
+
+import numpy as np
+import pytest
+
+import tvm
+import tvm.testing
+import tvm.tirx as tirx
+from tvm.contrib.nvcc import have_fp16
+from tvm.ir.module import IRModule
+from tvm.runtime.executable import Executable
+from tvm.script import tirx as T
+
+VECTOR_N_INPUTS = 8
+
+
+def make_prim_func(
+ name: str,
+ dtype: str,
+ num_inputs: int,
+ op: Callable[[tirx.PrimExpr, ...], tirx.PrimExpr],
+) -> tirx.PrimFunc:
+ """Make a primitive function that applies the given operation to the input
buffer."""
+ if num_inputs == 1:
+
+ @T.prim_func
+ def kernel(
+ A: T.Buffer((VECTOR_N_INPUTS,), dtype),
+ B: T.Buffer((VECTOR_N_INPUTS,), dtype),
+ ):
+ T.func_attr({"global_symbol": name + "_kernel", "tirx.noalias":
True})
+ for i in T.thread_binding(VECTOR_N_INPUTS, thread="threadIdx.x"):
+ B[i] = op(A[i])
+
+ return kernel
+ elif num_inputs == 2:
+
+ @T.prim_func
+ def kernel(
+ A: T.Buffer((VECTOR_N_INPUTS,), dtype),
+ E: T.Buffer((VECTOR_N_INPUTS,), dtype),
+ B: T.Buffer((VECTOR_N_INPUTS,), dtype),
+ ):
+ T.func_attr({"global_symbol": name + "_kernel", "tirx.noalias":
True})
+ for i in T.thread_binding(VECTOR_N_INPUTS, thread="threadIdx.x"):
+ B[i] = op(A[i], E[i])
+
+ return kernel
+ else:
+ raise ValueError(f"Unsupported number of inputs: {num_inputs}")
+
+
+@dataclass(frozen=True)
+class MathCase:
+ name: str
+ op: Callable[[tirx.PrimExpr, ...], tirx.PrimExpr]
+ num_inputs: int
+ default_intrinsic_f16: str
+ default_intrinsic_bf16: str
+ default_intrinsic_f32: str
+ default_intrinsic_f64: str
+ fast_math_intrinsic_f32: str
+ np_ref: object
+ rtol: float = 1e-5
+ atol: float = 1e-6
+
+
+MATH_CASES = [
+ MathCase(
+ "exp_case",
+ T.exp,
+ 1,
+ "hexp",
+ "hexp",
+ "expf",
+ "exp",
+ "__expf",
+ lambda x: np.exp(x),
+ ),
+ MathCase(
+ "exp10_case",
+ T.exp10,
+ 1,
+ "hexp10",
+ "hexp10",
+ "exp10f",
+ "exp10",
+ "__exp10f",
+ lambda x: np.power(10.0, x),
+ ),
+ MathCase(
+ "log_case",
+ T.log,
+ 1,
+ "hlog",
+ "hlog",
+ "logf",
+ "log",
+ "__logf",
+ lambda x: np.log(x),
+ ),
+ MathCase(
+ "log2_case",
+ T.log2,
+ 1,
+ "hlog2",
+ "hlog2",
+ "log2f",
+ "log2",
+ "__log2f",
+ lambda x: np.log2(x),
+ ),
+ MathCase(
+ "log10_case",
+ T.log10,
+ 1,
+ "hlog10",
+ "hlog10",
+ "log10f",
+ "log10",
+ "__log10f",
+ lambda x: np.log10(x),
+ ),
+ MathCase(
+ "tan_case",
+ T.tan,
+ 1,
+ "htan",
+ "htan",
+ "tanf",
+ "tan",
+ "tanf",
+ lambda x: np.tan(x),
+ ),
+ MathCase(
+ "cos_case",
+ T.cos,
+ 1,
+ "hcos",
+ "hcos",
+ "cosf",
+ "cos",
+ "__cosf",
+ lambda x: np.cos(x),
+ ),
+ MathCase(
+ "sin_case",
+ T.sin,
+ 1,
+ "hsin",
+ "hsin",
+ "sinf",
+ "sin",
+ "__sinf",
+ lambda x: np.sin(x),
+ ),
+ MathCase(
+ "tanh_case",
+ T.tanh,
+ 1,
+ "htanh",
+ "htanh",
+ "tanhf",
+ "tanh",
+ "__tanhf",
+ lambda x: np.tanh(x),
+ ),
+ MathCase(
+ "pow_case",
+ T.pow,
+ 2,
+ "hpow",
+ "hpow",
+ "powf",
+ "pow",
+ "__powf",
+ lambda x, y: np.power(x, y),
+ ),
+]
+
+
+def make_mod(
+ dtype: str, case: MathCase, enable_fast_math: bool
+) -> tuple[tvm.target.Target, tvm.IRModule]:
+ """Make a module for the given dtype and case."""
+ target = tvm.target.Target({"kind": "cuda", "enable_fast_math":
enable_fast_math})
+ prim_func = make_prim_func(case.name, dtype, case.num_inputs, case.op)
+ return target, tvm.IRModule.from_expr(prim_func.with_attr("target",
target))
+
+
+def expected_intrinsic(dtype: str, case: MathCase, enable_fast_math: bool) ->
str:
+ """Get the expected intrinsic for the given dtype and case."""
+ if dtype == "float16":
+ return case.default_intrinsic_f16
+ elif dtype == "bfloat16":
+ return case.default_intrinsic_bf16
+ elif dtype == "float32":
+ return case.fast_math_intrinsic_f32 if enable_fast_math else
case.default_intrinsic_f32
+ elif dtype == "float64":
+ return case.default_intrinsic_f64
+ else:
+ raise ValueError(f"Unsupported dtype: {dtype}")
+
+
+def check_lowered_ir(
+ dtype: str, case: MathCase, enable_fast_math: bool
+) -> tuple[tvm.target.Target, IRModule]:
+ """Check the lowered IR for the given dtype and case."""
+ target, mod = make_mod(dtype, case, enable_fast_math)
+ lowered_mod = tvm.tirx.transform.LowerIntrin()(mod)
+ script = lowered_mod.script(show_meta=False)
+ expected = expected_intrinsic(dtype, case, enable_fast_math)
+ assert re.search(rf"""["']{re.escape(expected)}["']""", script)
+ return target, lowered_mod
+
+
+def check_cuda_source(
+ target: tvm.target.Target,
+ mod: IRModule,
+ dtype: str,
+ case: MathCase,
+ enable_fast_math: bool,
+) -> Executable:
+ """Check the CUDA source for the given dtype and case."""
+ executable = tvm.compile(mod, target=target)
+ source = executable.mod.imports[0].inspect_source()
+ expected = expected_intrinsic(dtype, case, enable_fast_math)
+ assert re.search(rf"(?<!_)\b{re.escape(expected)}\s*\(", source)
+ return executable
+
+
+def make_numpy_inputs(dtype: str, case: MathCase):
+ """Make the numpy inputs for the given dtype and case."""
+ lhs = np.array([0.25, 0.5, 1.0, 2.0, 4.0, 9.0, 16.0, 10.0], dtype=dtype)
+ if case.num_inputs == 1:
+ return [lhs]
+ elif case.num_inputs == 2:
+ rhs = np.array([2.0, 3.0, 0.5, 1.5, 0.25, 0.5, 2.0, 1.0], dtype=dtype)
+ return [lhs, rhs]
+ else:
+ raise ValueError(f"Unsupported number of inputs: {case.num_inputs}")
+
+
+def check_runtime(dtype: str, case: MathCase, executable: Executable):
+ """Check the runtime for the given dtype and case."""
+ dev = tvm.cuda(0)
+
+ np_inputs = make_numpy_inputs(dtype, case)
+ expected = case.np_ref(*[arr.astype(dtype) for arr in
np_inputs]).astype(dtype)
+
+ tvm_inputs = [tvm.runtime.tensor(arr, device=dev) for arr in np_inputs]
+ output = tvm.runtime.empty((VECTOR_N_INPUTS,), dtype, dev)
+
+ executable(*tvm_inputs, output)
+ dev.sync()
+
+ actual = output.numpy()
+
+ np.testing.assert_allclose(actual, expected, rtol=case.rtol,
atol=case.atol)
+
+
[email protected]_gpu
[email protected]_cuda
[email protected](
+ "dtype",
+ ["float16", "bfloat16", "float32", "float64"],
+)
[email protected]("case", MATH_CASES, ids=lambda case: f"{case.name}")
[email protected]("enable_fast_math", [False, True], ids=["default",
"fast_math"])
+def test_cuda_math_intrinsic_lowering_source_and_runtime(dtype, case,
enable_fast_math):
+ if dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version):
+ pytest.skip("GPU does not support float16")
+ if dtype == "bfloat16" and case.name.startswith("pow_"):
+ pytest.skip("pow_argnames=case is only supported for float")
+
+ target, lowered_mod = check_lowered_ir(dtype, case, enable_fast_math)
+ executable = check_cuda_source(target, lowered_mod, dtype, case,
enable_fast_math)
+ check_runtime(dtype, case, executable)
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index d73ec5bae5..ca05a6492f 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -4943,7 +4943,7 @@ def
test_nms_max_output_boxes_per_class_zero(with_explicit_max: bool):
check_correctness(model, inputs=inputs, opset=11)
tvm_out = run_in_tvm(model, inputs=inputs, opset=11)
- tvm_selected = tvm_out[0].numpy() if isinstance(tvm_out, list | tuple)
else tvm_out.numpy()
+ tvm_selected = tvm_out[0].numpy() if isinstance(tvm_out, (list, tuple))
else tvm_out.numpy() # noqa: UP038
assert tvm_selected.shape == (0, 3)
diff --git a/tests/python/relax/test_frontend_onnx_backend.py
b/tests/python/relax/test_frontend_onnx_backend.py
index 301b95f640..ad3d490df7 100644
--- a/tests/python/relax/test_frontend_onnx_backend.py
+++ b/tests/python/relax/test_frontend_onnx_backend.py
@@ -77,9 +77,9 @@ class TVMRelaxBackendRep(BackendRep):
self._vm.invoke_stateful("main")
output = self._vm.get_outputs("main")
- if isinstance(output, tvm.runtime.Tensor | np.ndarray):
+ if isinstance(output, (tvm.runtime.Tensor, np.ndarray)): # noqa: UP038
return (output.numpy() if hasattr(output, "numpy") else output,)
- if isinstance(output, tuple | list):
+ if isinstance(output, (tuple, list)): # noqa: UP038
return tuple(o.numpy() if hasattr(o, "numpy") else np.array(o) for
o in output)
return (np.array(output),)
diff --git a/tests/python/target/test_target_target.py
b/tests/python/target/test_target_target.py
index c037fcadd2..94706ee8d8 100644
--- a/tests/python/target/test_target_target.py
+++ b/tests/python/target/test_target_target.py
@@ -148,6 +148,7 @@ def test_target_tag_0():
assert tgt.attrs["max_threads_per_block"] == 1024
assert tgt.attrs["thread_warp_size"] == 32
assert tgt.attrs["registers_per_block"] == 65536
+ assert not tgt.attrs["enable_fast_math"]
def test_target_tag_1():
@@ -158,15 +159,19 @@ def test_target_tag_1():
assert tgt.attrs["max_threads_per_block"] == 1024
assert tgt.attrs["thread_warp_size"] == 32
assert tgt.attrs["registers_per_block"] == 32768
+ assert not tgt.attrs["enable_fast_math"]
def test_target_tag_override():
"""Test creating a target from a tag with attribute overrides."""
- tgt = tvm.target.Target({"tag": "nvidia/nvidia-a100",
"l2_cache_size_bytes": 12345})
+ tgt = tvm.target.Target(
+ {"tag": "nvidia/nvidia-a100", "l2_cache_size_bytes": 12345,
"enable_fast_math": True}
+ )
assert tgt.kind.name == "cuda"
assert tgt.attrs["arch"] == "sm_80"
# Override should take effect
assert int(tgt.attrs["l2_cache_size_bytes"]) == 12345
+ assert tgt.attrs["enable_fast_math"]
# Base tag fields should be preserved
assert tgt.attrs["max_shared_memory_per_block"] == 49152
assert tgt.attrs["thread_warp_size"] == 32
@@ -189,12 +194,14 @@ def test_target_host_tags():
assert tgt.attrs["max_threads_per_block"] == 1024
assert tgt.attrs["thread_warp_size"] == 32
assert tgt.attrs["registers_per_block"] == 32768
+ assert not tgt.attrs["enable_fast_math"]
assert tgt.host.kind.name == "cuda"
assert tgt.host.attrs["arch"] == "sm_75"
assert tgt.host.attrs["max_shared_memory_per_block"] == 49152
assert tgt.host.attrs["max_threads_per_block"] == 1024
assert tgt.host.attrs["thread_warp_size"] == 32
assert tgt.host.attrs["registers_per_block"] == 65536
+ assert not tgt.host.attrs["enable_fast_math"]
def test_target_host_tag_dict():
@@ -205,6 +212,7 @@ def test_target_host_tag_dict():
assert tgt.attrs["max_threads_per_block"] == 1024
assert tgt.attrs["thread_warp_size"] == 32
assert tgt.attrs["registers_per_block"] == 32768
+ assert not tgt.attrs["enable_fast_math"]
assert tgt.host.kind.name == "llvm"
@@ -217,6 +225,7 @@ def test_target_host_single_dict():
assert tgt.host.attrs["max_threads_per_block"] == 1024
assert tgt.host.attrs["thread_warp_size"] == 32
assert tgt.host.attrs["registers_per_block"] == 32768
+ assert not tgt.host.attrs["enable_fast_math"]
def test_target_host_single_string():
@@ -234,6 +243,7 @@ def test_target_host_single_string_with_tag():
assert tgt.host.attrs["max_threads_per_block"] == 1024
assert tgt.host.attrs["thread_warp_size"] == 32
assert tgt.host.attrs["registers_per_block"] == 32768
+ assert not tgt.host.attrs["enable_fast_math"]
def test_target_host_merge_0():
@@ -245,6 +255,7 @@ def test_target_host_merge_0():
assert tgt.host.attrs["max_threads_per_block"] == 1024
assert tgt.host.attrs["thread_warp_size"] == 32
assert tgt.host.attrs["registers_per_block"] == 32768
+ assert not tgt.host.attrs["enable_fast_math"]
def test_target_host_merge_1():
@@ -295,6 +306,7 @@ def test_target_with_host():
assert tgt.host.attrs["max_threads_per_block"] == 1024
assert tgt.host.attrs["thread_warp_size"] == 32
assert tgt.host.attrs["registers_per_block"] == 32768
+ assert not tgt.host.attrs["enable_fast_math"]
def test_target_attr_bool_value():