https://github.com/arsenm created https://github.com/llvm/llvm-project/pull/176003
None >From 82372b2c36018ae7b2abc9af106e07bbc8bac689 Mon Sep 17 00:00:00 2001 From: Matt Arsenault <[email protected]> Date: Wed, 14 Jan 2026 18:40:41 +0100 Subject: [PATCH] InstCombine: Infer fast math flags for sqrt --- .../InstCombineSimplifyDemanded.cpp | 39 +++++++++-- .../simplify-demanded-fpclass-sqrt.ll | 65 +++++++++++++++++-- 2 files changed, 95 insertions(+), 9 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index 3bbc4a913ada6..3292d3538b4e3 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -2036,9 +2036,9 @@ static Constant *getFPClassConstant(Type *Ty, FPClassTest Mask, /// Try to set an inferred no-nans or no-infs in \p FMF. \p ValidResults is a /// mask of known valid results for the operator (already computed from the /// result, and the known operand inputs in \p Known) -static FastMathFlags -inferFastMathValueFlags(FastMathFlags FMF, FPClassTest ValidResults, - ArrayRef<const KnownFPClass> Known) { +static FastMathFlags inferFastMathValueFlags(FastMathFlags FMF, + FPClassTest ValidResults, + ArrayRef<KnownFPClass> Known) { if (!FMF.noNaNs() && (ValidResults & fcNan) == fcNone) { if (all_of(Known, [](const KnownFPClass KnownSrc) { return KnownSrc.isKnownNeverNaN(); @@ -2056,6 +2056,28 @@ inferFastMathValueFlags(FastMathFlags FMF, FPClassTest ValidResults, return FMF; } +/// Apply epilog fixups to a floating-point intrinsic. See if the result can +/// fold to a constant, or apply fast math flags. 
+static Value *simplifyDemandedFPClassResult(CallInst *FPOp, FastMathFlags FMF, + FPClassTest DemandedMask, + KnownFPClass &Known, + ArrayRef<KnownFPClass> KnownSrcs) { + FPClassTest ValidResults = DemandedMask & Known.KnownFPClasses; + Constant *SingleVal = getFPClassConstant(FPOp->getType(), ValidResults, + /*IsCanonicalizing=*/true); + if (SingleVal) + return SingleVal; + + FastMathFlags InferredFMF = + inferFastMathValueFlags(FMF, ValidResults, KnownSrcs); + if (InferredFMF != FMF) { + FPOp->setFastMathFlags(InferredFMF); + return FPOp; + } + + return nullptr; +} + Value *InstCombinerImpl::SimplifyDemandedUseFPClass(Instruction *I, FPClassTest DemandedMask, KnownFPClass &Known, @@ -2790,6 +2812,14 @@ Value *InstCombinerImpl::SimplifyDemandedUseFPClass(Instruction *I, if (SimplifyDemandedFPClass(I, 0, DemandedSrcMask, KnownSrc, Depth + 1)) return I; + // Infer the source cannot be negative if the result cannot be nan. + if ((DemandedMask & fcNan) == fcNone) + KnownSrc.knownNot((fcNegative & ~fcNegZero) | fcNan); + + // Infer the source cannot be +inf if the result is not +inf + if ((DemandedMask & fcPosInf) == fcNone) + KnownSrc.knownNot(fcPosInf); + Type *EltTy = VTy->getScalarType(); DenormalMode Mode = F.getDenormalMode(EltTy->getFltSemantics()); @@ -2811,7 +2841,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseFPClass(Instruction *I, return Copysign; } - return getFPClassConstant(VTy, ValidResults, /*IsCanonicalizing=*/true); + return simplifyDemandedFPClassResult(CI, FMF, DemandedMask, Known, + {KnownSrc}); } case Intrinsic::trunc: case Intrinsic::floor: diff --git a/llvm/test/Transforms/InstCombine/simplify-demanded-fpclass-sqrt.ll b/llvm/test/Transforms/InstCombine/simplify-demanded-fpclass-sqrt.ll index b09faf0f4c3af..6ec5daa48e125 100644 --- a/llvm/test/Transforms/InstCombine/simplify-demanded-fpclass-sqrt.ll +++ b/llvm/test/Transforms/InstCombine/simplify-demanded-fpclass-sqrt.ll @@ -220,7 +220,7 @@ define nofpclass(nan inf zero sub nnorm) float 
@pnorm_result_demands_pnorm_sourc ; CHECK-SAME: i1 [[COND:%.*]], float nofpclass(nan) [[NOT_NAN:%.*]]) { ; CHECK-NEXT: [[ONLY_PNORM:%.*]] = call nofpclass(nan inf zero sub nnorm) float @func() ; CHECK-NEXT: [[SELECT:%.*]] = select i1 [[COND]], float [[NOT_NAN]], float [[ONLY_PNORM]] -; CHECK-NEXT: [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[SELECT]]) +; CHECK-NEXT: [[RESULT:%.*]] = call nnan ninf float @llvm.sqrt.f32(float [[SELECT]]) ; CHECK-NEXT: ret float [[RESULT]] ; %only.pnorm = call nofpclass(nan inf nnorm sub zero) float @func() @@ -234,7 +234,7 @@ define nofpclass(nan inf zero sub nnorm) float @pnorm_result_demands_psub_source ; CHECK-SAME: i1 [[COND:%.*]], float nofpclass(nan) [[NOT_NAN:%.*]]) { ; CHECK-NEXT: [[ONLY_PSUB:%.*]] = call nofpclass(nan inf zero nsub norm) float @func() ; CHECK-NEXT: [[SELECT:%.*]] = select i1 [[COND]], float [[NOT_NAN]], float [[ONLY_PSUB]] -; CHECK-NEXT: [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[SELECT]]) +; CHECK-NEXT: [[RESULT:%.*]] = call nnan ninf float @llvm.sqrt.f32(float [[SELECT]]) ; CHECK-NEXT: ret float [[RESULT]] ; %only.psub = call nofpclass(nan inf norm nsub zero) float @func() @@ -258,7 +258,7 @@ define nofpclass(nan inf zero nsub norm) float @psub_result_implies_not_pnorm_so define nofpclass(nan) float @ret_no_nan__sqrt(float %x) { ; CHECK-LABEL: define nofpclass(nan) float @ret_no_nan__sqrt( ; CHECK-SAME: float [[X:%.*]]) { -; CHECK-NEXT: [[RESULT:%.*]] = call contract float @llvm.sqrt.f32(float [[X]]) +; CHECK-NEXT: [[RESULT:%.*]] = call nnan contract float @llvm.sqrt.f32(float [[X]]) ; CHECK-NEXT: ret float [[RESULT]] ; %result = call contract float @llvm.sqrt.f32(float %x) @@ -278,7 +278,7 @@ define nofpclass(snan) float @ret_no_snan__sqrt__no_neg_inputs(float nofpclass(n define nofpclass(snan) float @ret_no_snan__sqrt__no_neg_or_nan_inputs(float nofpclass(nan ninf nnorm nsub) %x) { ; CHECK-LABEL: define nofpclass(snan) float @ret_no_snan__sqrt__no_neg_or_nan_inputs( ; CHECK-SAME: float 
nofpclass(nan ninf nsub nnorm) [[X:%.*]]) { -; CHECK-NEXT: [[RESULT:%.*]] = call contract float @llvm.sqrt.f32(float [[X]]) +; CHECK-NEXT: [[RESULT:%.*]] = call nnan contract float @llvm.sqrt.f32(float [[X]]) ; CHECK-NEXT: ret float [[RESULT]] ; %result = call contract float @llvm.sqrt.f32(float %x) @@ -289,12 +289,67 @@ define nofpclass(snan) float @ret_no_snan__sqrt__no_neg_or_nan_inputs(float nofp define nofpclass(snan) float @ret_no_snan__noundef_sqrt__no_neg_or_nan_inputs(float nofpclass(nan ninf nnorm nsub) %x) { ; CHECK-LABEL: define nofpclass(snan) float @ret_no_snan__noundef_sqrt__no_neg_or_nan_inputs( ; CHECK-SAME: float nofpclass(nan ninf nsub nnorm) [[X:%.*]]) { -; CHECK-NEXT: [[RESULT:%.*]] = call contract noundef float @llvm.sqrt.f32(float [[X]]) +; CHECK-NEXT: [[RESULT:%.*]] = call nnan contract noundef float @llvm.sqrt.f32(float [[X]]) ; CHECK-NEXT: ret float [[RESULT]] ; %result = call contract noundef float @llvm.sqrt.f32(float %x) ret float %result } +define nofpclass(snan) float @ret_no_snan__sqrt__no_pinf_inputs(float nofpclass(pinf) %x) { +; CHECK-LABEL: define nofpclass(snan) float @ret_no_snan__sqrt__no_pinf_inputs( +; CHECK-SAME: float nofpclass(pinf) [[X:%.*]]) { +; CHECK-NEXT: [[RESULT:%.*]] = call contract float @llvm.sqrt.f32(float [[X]]) +; CHECK-NEXT: ret float [[RESULT]] +; + %result = call contract float @llvm.sqrt.f32(float %x) + ret float %result +} + +; Cannot infer flags. A nan output could still be produced by a -inf +; input. 
+define nofpclass(pinf) float @ret_no_pinf__sqrt(float %x) { +; CHECK-LABEL: define nofpclass(pinf) float @ret_no_pinf__sqrt( +; CHECK-SAME: float [[X:%.*]]) { +; CHECK-NEXT: [[RESULT:%.*]] = call contract float @llvm.sqrt.f32(float [[X]]) +; CHECK-NEXT: ret float [[RESULT]] +; + %result = call contract float @llvm.sqrt.f32(float %x) + ret float %result +} + +; Infer nnan and ninf +define nofpclass(nan pinf) float @ret_no_pinf_no_nan__sqrt(float %x) { +; CHECK-LABEL: define nofpclass(nan pinf) float @ret_no_pinf_no_nan__sqrt( +; CHECK-SAME: float [[X:%.*]]) { +; CHECK-NEXT: [[RESULT:%.*]] = call nnan ninf contract float @llvm.sqrt.f32(float [[X]]) +; CHECK-NEXT: ret float [[RESULT]] +; + %result = call contract float @llvm.sqrt.f32(float %x) + ret float %result +} + +; Infer nnan and ninf +define nofpclass(nan) float @ret_no_nan__sqrt__no_pinf_inputs(float nofpclass(pinf) %x) { +; CHECK-LABEL: define nofpclass(nan) float @ret_no_nan__sqrt__no_pinf_inputs( +; CHECK-SAME: float nofpclass(pinf) [[X:%.*]]) { +; CHECK-NEXT: [[RESULT:%.*]] = call nnan ninf contract float @llvm.sqrt.f32(float [[X]]) +; CHECK-NEXT: ret float [[RESULT]] +; + %result = call contract float @llvm.sqrt.f32(float %x) + ret float %result +} + +; Infer nnan and ninf +define nofpclass(nan) float @ret_no_nan__sqrt__no_inf_inputs(float nofpclass(inf) %x) { +; CHECK-LABEL: define nofpclass(nan) float @ret_no_nan__sqrt__no_inf_inputs( +; CHECK-SAME: float nofpclass(inf) [[X:%.*]]) { +; CHECK-NEXT: [[RESULT:%.*]] = call nnan ninf contract float @llvm.sqrt.f32(float [[X]]) +; CHECK-NEXT: ret float [[RESULT]] +; + %result = call contract float @llvm.sqrt.f32(float %x) + ret float %result +} + attributes #0 = { "denormal-fp-math"="preserve-sign,preserve-sign" } attributes #1 = { "denormal-fp-math"="dynamic,dynamic" } _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
