https://github.com/arsenm updated https://github.com/llvm/llvm-project/pull/173883
>From 421ac120df137dd03d440e322147893ab4b18a7f Mon Sep 17 00:00:00 2001 From: Matt Arsenault <[email protected]> Date: Tue, 23 Dec 2025 20:36:25 +0100 Subject: [PATCH] InstCombine: Implement SimplifyDemandedFPClass for sqrt --- llvm/include/llvm/Support/KnownFPClass.h | 4 ++ llvm/lib/Analysis/ValueTracking.cpp | 29 +++++--------- llvm/lib/Support/KnownFPClass.cpp | 24 ++++++++++++ .../InstCombineSimplifyDemanded.cpp | 38 +++++++++++++++++++ .../simplify-demanded-fpclass-sqrt.ll | 23 +++++------ 5 files changed, 85 insertions(+), 33 deletions(-) diff --git a/llvm/include/llvm/Support/KnownFPClass.h b/llvm/include/llvm/Support/KnownFPClass.h index c97f40b5252f7..27bad20561606 100644 --- a/llvm/include/llvm/Support/KnownFPClass.h +++ b/llvm/include/llvm/Support/KnownFPClass.h @@ -269,6 +269,10 @@ struct KnownFPClass { static LLVM_ABI KnownFPClass log(const KnownFPClass &Src, DenormalMode Mode = DenormalMode::getDynamic()); + /// Propagate known class for sqrt + static LLVM_ABI KnownFPClass + sqrt(const KnownFPClass &Src, DenormalMode Mode = DenormalMode::getDynamic()); + void resetAll() { *this = KnownFPClass(); } }; diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index 5ee408304f9bb..5e20a40854620 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -5126,27 +5126,18 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts, computeKnownFPClass(II->getArgOperand(0), DemandedElts, InterestedSrcs, KnownSrc, Q, Depth + 1); - if (KnownSrc.isKnownNeverPosInfinity()) - Known.knownNot(fcPosInf); - if (KnownSrc.isKnownNever(fcSNan)) - Known.knownNot(fcSNan); - - // Any negative value besides -0 returns a nan. - if (KnownSrc.isKnownNeverNaN() && KnownSrc.cannotBeOrderedLessThanZero()) - Known.knownNot(fcNan); - - // The only negative value that can be returned is -0 for -0 inputs. - Known.knownNot(fcNegInf | fcNegSubnormal | fcNegNormal); + DenormalMode Mode = DenormalMode::getDynamic(); - // If the input denormal mode could be PreserveSign, a negative - // subnormal input could produce a negative zero output. - const Function *F = II->getFunction(); - const fltSemantics &FltSem = - II->getType()->getScalarType()->getFltSemantics(); + bool HasNSZ = Q.IIQ.hasNoSignedZeros(II); + if (!HasNSZ) { + const Function *F = II->getFunction(); + const fltSemantics &FltSem = + II->getType()->getScalarType()->getFltSemantics(); + Mode = F ? F->getDenormalMode(FltSem) : DenormalMode::getDynamic(); + } - if (Q.IIQ.hasNoSignedZeros(II) || - (F && - KnownSrc.isKnownNeverLogicalNegZero(F->getDenormalMode(FltSem)))) + Known = KnownFPClass::sqrt(KnownSrc, Mode); + if (HasNSZ) Known.knownNot(fcNegZero); break; diff --git a/llvm/lib/Support/KnownFPClass.cpp b/llvm/lib/Support/KnownFPClass.cpp index ff98908fdb2c4..afa08c1fd047f 100644 --- a/llvm/lib/Support/KnownFPClass.cpp +++ b/llvm/lib/Support/KnownFPClass.cpp @@ -243,3 +243,27 @@ KnownFPClass KnownFPClass::log(const KnownFPClass &KnownSrc, return Known; } + +KnownFPClass KnownFPClass::sqrt(const KnownFPClass &KnownSrc, + DenormalMode Mode) { + KnownFPClass Known; + + if (KnownSrc.isKnownNeverPosInfinity()) + Known.knownNot(fcPosInf); + if (KnownSrc.isKnownNever(fcSNan)) + Known.knownNot(fcSNan); + + // Any negative value besides -0 returns a nan. + if (KnownSrc.isKnownNeverNaN() && KnownSrc.cannotBeOrderedLessThanZero()) + Known.knownNot(fcNan); + + // The only negative value that can be returned is -0 for -0 inputs. + Known.knownNot(fcNegInf | fcNegSubnormal | fcNegNormal); + + // If the input denormal mode could be PreserveSign, a negative + // subnormal input could produce a negative zero output. + if (KnownSrc.isKnownNeverLogicalNegZero(Mode)) + Known.knownNot(fcNegZero); + + return Known; +} diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index 9f0d0fc36a825..9508788a7ba28 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -2385,6 +2385,44 @@ Value *InstCombinerImpl::SimplifyDemandedUseFPClass(Value *V, FPClassTest ValidResults = DemandedMask & Known.KnownFPClasses; return getFPClassConstant(VTy, ValidResults, /*IsCanonicalizing=*/true); } + case Intrinsic::sqrt: { + FPClassTest DemandedSrcMask = + DemandedMask & (fcNegZero | fcPositive | fcNan); + + if (DemandedMask & fcNan) + DemandedSrcMask |= (fcNegative & ~fcNegZero); + + // sqrt(max_subnormal) is a normal value + if (DemandedMask & fcPosNormal) + DemandedSrcMask |= fcPosSubnormal; + + KnownFPClass KnownSrc; + if (SimplifyDemandedFPClass(I, 0, DemandedSrcMask, KnownSrc, Depth + 1)) + return I; + + Type *EltTy = VTy->getScalarType(); + DenormalMode Mode = F.getDenormalMode(EltTy->getFltSemantics()); + + // sqrt(-x) = nan, but be careful of negative subnormals flushed to 0. + if (KnownSrc.isKnownNever(fcPositive) && + KnownSrc.isKnownNeverLogicalZero(Mode)) + return ConstantFP::getQNaN(VTy); + + Known = KnownFPClass::sqrt(KnownSrc, Mode); + FPClassTest ValidResults = DemandedMask & Known.KnownFPClasses; + + if (ValidResults == fcZero) { + if (FMF.noSignedZeros()) + return ConstantFP::getZero(VTy); + + Value *Copysign = Builder.CreateCopySign(ConstantFP::getZero(VTy), + CI->getArgOperand(0), FMF); + Copysign->takeName(CI); + return Copysign; + } + + return getFPClassConstant(VTy, ValidResults, /*IsCanonicalizing=*/true); + } case Intrinsic::canonicalize: { Type *EltTy = VTy->getScalarType(); diff --git a/llvm/test/Transforms/InstCombine/simplify-demanded-fpclass-sqrt.ll b/llvm/test/Transforms/InstCombine/simplify-demanded-fpclass-sqrt.ll index 447b3786109ee..234d97c05f3aa 100644 --- a/llvm/test/Transforms/InstCombine/simplify-demanded-fpclass-sqrt.ll +++ b/llvm/test/Transforms/InstCombine/simplify-demanded-fpclass-sqrt.ll @@ -18,8 +18,7 @@ declare nofpclass(pinf pnorm psub zero) float @returns_negative_nonzero_or_nan() define nofpclass(inf norm sub zero) float @ret_only_nan_sqrt(float %x) { ; CHECK-LABEL: define nofpclass(inf zero sub norm) float @ret_only_nan_sqrt( ; CHECK-SAME: float [[X:%.*]]) { -; CHECK-NEXT: [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[X]]) -; CHECK-NEXT: ret float [[RESULT]] +; CHECK-NEXT: ret float 0x7FF8000000000000 ; %result = call float @llvm.sqrt.f32(float %x) ret float %result @@ -30,7 +29,7 @@ define nofpclass(inf norm sub zero) float @ret_only_nan_sqrt(float %x) { define nofpclass(inf nan norm sub) float @ret_only_zero_sqrt(float %x) { ; CHECK-LABEL: define nofpclass(nan inf sub norm) float @ret_only_zero_sqrt( ; CHECK-SAME: float [[X:%.*]]) { -; CHECK-NEXT: [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[X]]) +; CHECK-NEXT: [[RESULT:%.*]] = call float @llvm.copysign.f32(float 0.000000e+00, float [[X]]) ; CHECK-NEXT: ret float [[RESULT]] ; %result = call float @llvm.sqrt.f32(float %x) @@ -40,7 +39,7 @@ define nofpclass(inf nan norm sub) float @ret_only_zero_sqrt(float %x) { define nofpclass(inf nan norm sub) float @ret_only_zero_sqrt_preserve_flags(float %x) { ; CHECK-LABEL: define nofpclass(nan inf sub norm) float @ret_only_zero_sqrt_preserve_flags( ; CHECK-SAME: float [[X:%.*]]) { -; CHECK-NEXT: [[RESULT:%.*]] = call ninf float @llvm.sqrt.f32(float [[X]]) +; CHECK-NEXT: [[RESULT:%.*]] = call ninf float @llvm.copysign.f32(float 0.000000e+00, float [[X]]) ; CHECK-NEXT: ret float [[RESULT]] ; %result = call ninf float @llvm.sqrt.f32(float %x) @@ -59,7 +58,7 @@ define nofpclass(inf nan norm sub) float @ret_only_zero_sqrt_nsz(float %x) { define nofpclass(inf nan norm sub) <2 x float> @ret_only_zero_sqrt_vec(<2 x float> %x) { ; CHECK-LABEL: define nofpclass(nan inf sub norm) <2 x float> @ret_only_zero_sqrt_vec( ; CHECK-SAME: <2 x float> [[X:%.*]]) { -; CHECK-NEXT: [[RESULT:%.*]] = call <2 x float> @llvm.sqrt.v2f32(<2 x float> [[X]]) +; CHECK-NEXT: [[RESULT:%.*]] = call <2 x float> @llvm.copysign.v2f32(<2 x float> zeroinitializer, <2 x float> [[X]]) ; CHECK-NEXT: ret <2 x float> [[RESULT]] ; %result = call <2 x float> @llvm.sqrt.v2f32(<2 x float> %x) @@ -102,8 +101,7 @@ define nofpclass(ninf) float @ret_only_ninf__sqrt__known_negative() { define nofpclass(ninf) float @ret_only_ninf__sqrt__known_negative_nonzero() { ; CHECK-LABEL: define nofpclass(ninf) float @ret_only_ninf__sqrt__known_negative_nonzero() { ; CHECK-NEXT: [[KNOWN_NEGATIVE_NONZERO:%.*]] = call float @returns_negative_nonzero() -; CHECK-NEXT: [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[KNOWN_NEGATIVE_NONZERO]]) -; CHECK-NEXT: ret float [[RESULT]] +; CHECK-NEXT: ret float 0x7FF8000000000000 ; %known.negative.nonzero = call float @returns_negative_nonzero() %result = call float @llvm.sqrt.f32(float %known.negative.nonzero) @@ -114,8 +112,7 @@ define nofpclass(ninf) float @ret_only_ninf__sqrt__known_negative_nonzero() { define nofpclass(ninf) <2 x float> @ret_only_ninf__sqrt__known_negative_nonzero_vec() { ; CHECK-LABEL: define nofpclass(ninf) <2 x float> @ret_only_ninf__sqrt__known_negative_nonzero_vec() { ; CHECK-NEXT: [[KNOWN_NEGATIVE_NONZERO:%.*]] = call <2 x float> @returns_negative_nonzero_vec() -; CHECK-NEXT: [[RESULT:%.*]] = call <2 x float> @llvm.sqrt.v2f32(<2 x float> [[KNOWN_NEGATIVE_NONZERO]]) -; CHECK-NEXT: ret <2 x float> [[RESULT]] +; CHECK-NEXT: ret <2 x float> splat (float 0x7FF8000000000000) ; %known.negative.nonzero = call <2 x float> @returns_negative_nonzero_vec() %result = call <2 x float> @llvm.sqrt.v2f32(<2 x float> %known.negative.nonzero) @@ -168,7 +165,7 @@ define nofpclass(inf norm zero) float @ret_only_nan_or_sub__sqrt__select_unknown ; CHECK-LABEL: define nofpclass(inf zero norm) float @ret_only_nan_or_sub__sqrt__select_unknown_or_maybe_ninf( ; CHECK-SAME: i1 [[COND:%.*]], float nofpclass(nan) [[X:%.*]]) { ; CHECK-NEXT: [[MAYBE_NINF:%.*]] = call nofpclass(nan pinf sub norm) float @func() -; CHECK-NEXT: [[SELECT:%.*]] = select i1 [[COND]], float [[X]], float [[MAYBE_NINF]] +; CHECK-NEXT: [[SELECT:%.*]] = select i1 [[COND]], float [[X]], float 0xFFF0000000000000 ; CHECK-NEXT: [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[SELECT]]) ; CHECK-NEXT: ret float [[RESULT]] ; @@ -212,8 +209,7 @@ define nofpclass(inf norm zero) float @ret_only_nan_or_sub__sqrt__select_unknown define nofpclass(pinf) float @no_pinf_result_implies_no_pinf_source(i1 %cond, float %unknown) { ; CHECK-LABEL: define nofpclass(pinf) float @no_pinf_result_implies_no_pinf_source( ; CHECK-SAME: i1 [[COND:%.*]], float [[UNKNOWN:%.*]]) { -; CHECK-NEXT: [[SELECT:%.*]] = select i1 [[COND]], float [[UNKNOWN]], float 0x7FF0000000000000 -; CHECK-NEXT: [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[SELECT]]) +; CHECK-NEXT: [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[UNKNOWN]]) ; CHECK-NEXT: ret float [[RESULT]] ; %select = select i1 %cond, float %unknown, float 0x7ff0000000000000 @@ -253,8 +249,7 @@ define nofpclass(nan inf zero nsub norm) float @psub_result_implies_not_pnorm_so ; CHECK-LABEL: define nofpclass(nan inf zero nsub norm) float @psub_result_implies_not_pnorm_source( ; CHECK-SAME: i1 [[COND:%.*]], float nofpclass(nan) [[NOT_NAN:%.*]]) { ; CHECK-NEXT: [[ONLY_PNORM:%.*]] = call nofpclass(nan inf zero sub nnorm) float @func() -; CHECK-NEXT: [[SELECT:%.*]] = select i1 [[COND]], float [[NOT_NAN]], float [[ONLY_PNORM]] -; CHECK-NEXT: [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[SELECT]]) +; CHECK-NEXT: [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[NOT_NAN]]) ; CHECK-NEXT: ret float [[RESULT]] ; %only.pnorm = call nofpclass(nan inf nnorm sub zero) float @func() _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
