https://github.com/arsenm updated https://github.com/llvm/llvm-project/pull/173883
>From 0db24a1ecffe789c39c75dc625bcb4f2a781f5ea Mon Sep 17 00:00:00 2001
From: Matt Arsenault <[email protected]>
Date: Tue, 23 Dec 2025 20:36:25 +0100
Subject: [PATCH] InstCombine: Implement SimplifyDemandedFPClass for sqrt

---
 llvm/include/llvm/Support/KnownFPClass.h      |  4 ++
 llvm/lib/Analysis/ValueTracking.cpp           | 29 ++++-------
 llvm/lib/Support/KnownFPClass.cpp             | 24 ++++++++++
 .../InstCombineSimplifyDemanded.cpp           | 48 +++++++++++++++++++
 .../simplify-demanded-fpclass-sqrt.ll         | 18 +++----
 5 files changed, 93 insertions(+), 30 deletions(-)

diff --git a/llvm/include/llvm/Support/KnownFPClass.h b/llvm/include/llvm/Support/KnownFPClass.h
index 07d74f2867089..ae9513bbebe80 100644
--- a/llvm/include/llvm/Support/KnownFPClass.h
+++ b/llvm/include/llvm/Support/KnownFPClass.h
@@ -267,6 +267,10 @@ struct KnownFPClass {
   static LLVM_ABI KnownFPClass
   log(const KnownFPClass &Src, DenormalMode Mode = DenormalMode::getDynamic());
 
+  /// Propagate known class for sqrt
+  static LLVM_ABI KnownFPClass
+  sqrt(const KnownFPClass &Src, DenormalMode Mode = DenormalMode::getDynamic());
+
   void resetAll() { *this = KnownFPClass(); }
 };
 
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index bac863cb3c67c..25664805b016c 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -5131,27 +5131,18 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
       computeKnownFPClass(II->getArgOperand(0), DemandedElts, InterestedSrcs,
                           KnownSrc, Q, Depth + 1);
 
-      if (KnownSrc.isKnownNeverPosInfinity())
-        Known.knownNot(fcPosInf);
-      if (KnownSrc.isKnownNever(fcSNan))
-        Known.knownNot(fcSNan);
-
-      // Any negative value besides -0 returns a nan.
-      if (KnownSrc.isKnownNeverNaN() && KnownSrc.cannotBeOrderedLessThanZero())
-        Known.knownNot(fcNan);
-
-      // The only negative value that can be returned is -0 for -0 inputs.
-      Known.knownNot(fcNegInf | fcNegSubnormal | fcNegNormal);
+      DenormalMode Mode = DenormalMode::getDynamic();
 
-      // If the input denormal mode could be PreserveSign, a negative
-      // subnormal input could produce a negative zero output.
-      const Function *F = II->getFunction();
-      const fltSemantics &FltSem =
-          II->getType()->getScalarType()->getFltSemantics();
+      bool HasNSZ = Q.IIQ.hasNoSignedZeros(II);
+      if (!HasNSZ) {
+        const Function *F = II->getFunction();
+        const fltSemantics &FltSem =
+            II->getType()->getScalarType()->getFltSemantics();
+        Mode = F ? F->getDenormalMode(FltSem) : DenormalMode::getDynamic();
+      }
 
-      if (Q.IIQ.hasNoSignedZeros(II) ||
-          (F &&
-           KnownSrc.isKnownNeverLogicalNegZero(F->getDenormalMode(FltSem))))
+      Known = KnownFPClass::sqrt(KnownSrc, Mode);
+      if (HasNSZ)
         Known.knownNot(fcNegZero);
 
       break;
diff --git a/llvm/lib/Support/KnownFPClass.cpp b/llvm/lib/Support/KnownFPClass.cpp
index ff98908fdb2c4..afa08c1fd047f 100644
--- a/llvm/lib/Support/KnownFPClass.cpp
+++ b/llvm/lib/Support/KnownFPClass.cpp
@@ -243,3 +243,27 @@ KnownFPClass KnownFPClass::log(const KnownFPClass &KnownSrc,
 
   return Known;
 }
+
+KnownFPClass KnownFPClass::sqrt(const KnownFPClass &KnownSrc,
+                                DenormalMode Mode) {
+  KnownFPClass Known;
+
+  if (KnownSrc.isKnownNeverPosInfinity())
+    Known.knownNot(fcPosInf);
+  if (KnownSrc.isKnownNever(fcSNan))
+    Known.knownNot(fcSNan);
+
+  // Any negative value besides -0 returns a nan.
+  if (KnownSrc.isKnownNeverNaN() && KnownSrc.cannotBeOrderedLessThanZero())
+    Known.knownNot(fcNan);
+
+  // The only negative value that can be returned is -0 for -0 inputs.
+  Known.knownNot(fcNegInf | fcNegSubnormal | fcNegNormal);
+
+  // If the input denormal mode could be PreserveSign, a negative
+  // subnormal input could produce a negative zero output.
+  if (KnownSrc.isKnownNeverLogicalNegZero(Mode))
+    Known.knownNot(fcNegZero);
+
+  return Known;
+}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index 3cb4c629b9108..711ef900bf420 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -2239,6 +2239,54 @@ Value *InstCombinerImpl::SimplifyDemandedUseFPClass(Value *V,
       FPClassTest ValidResults = DemandedMask & Known.KnownFPClasses;
       return getFPClassConstant(VTy, ValidResults, /*IsCanonicalizing=*/true);
     }
+    case Intrinsic::sqrt: {
+      Type *EltTy = VTy->getScalarType();
+      DenormalMode Mode = F.getDenormalMode(EltTy->getFltSemantics());
+
+      // sqrt(+/-0) = +/-0 and positive sources produce positive results.
+      FPClassTest DemandedSrcMask =
+          DemandedMask & (fcNegZero | fcPositive | fcNan);
+
+      // Any negative value besides -0 returns a nan.
+      if (DemandedMask & fcNan)
+        DemandedSrcMask |= (fcNegative & ~fcNegZero);
+
+      // sqrt of a positive subnormal is a positive normal, so the source
+      // subnormals matter whenever normal results are demanded.
+      if (DemandedMask & fcPosNormal)
+        DemandedSrcMask |= fcPosSubnormal;
+
+      // If input denormals may be flushed, subnormal sources produce zeros.
+      if ((DemandedMask & fcZero) && Mode.Input != DenormalMode::IEEE)
+        DemandedSrcMask |= fcSubnormal;
+
+      KnownFPClass KnownSrc;
+      if (SimplifyDemandedFPClass(I, 0, DemandedSrcMask, KnownSrc, Depth + 1))
+        return I;
+
+      // sqrt(-x) = nan, but be careful of negative subnormals flushed to 0.
+      if (KnownSrc.isKnownNever(fcPositive) &&
+          KnownSrc.isKnownNeverLogicalZero(Mode))
+        return ConstantFP::getQNaN(VTy);
+
+      Known = KnownFPClass::sqrt(KnownSrc, Mode);
+      FPClassTest ValidResults = DemandedMask & Known.KnownFPClasses;
+
+      // Only +/-0 results are left, and sqrt(+/-0) = +/-0. A negative
+      // subnormal flushed to +0 (positive-zero or dynamic input mode) would
+      // break the sign match, so require that case to be impossible.
+      if (ValidResults == fcZero &&
+          (Mode.Input == DenormalMode::IEEE ||
+           Mode.Input == DenormalMode::PreserveSign ||
+           KnownSrc.isKnownNever(fcNegSubnormal))) {
+        Value *Copysign = Builder.CreateCopySign(ConstantFP::getZero(VTy),
+                                                 CI->getArgOperand(0));
+        Copysign->takeName(CI);
+        return Copysign;
+      }
+
+      return getFPClassConstant(VTy, ValidResults, /*IsCanonicalizing=*/true);
+    }
     case Intrinsic::canonicalize: {
       Type *EltTy = VTy->getScalarType();
 
diff --git a/llvm/test/Transforms/InstCombine/simplify-demanded-fpclass-sqrt.ll b/llvm/test/Transforms/InstCombine/simplify-demanded-fpclass-sqrt.ll
index 9288bb7be3ecd..ad9881d70b5fd 100644
--- a/llvm/test/Transforms/InstCombine/simplify-demanded-fpclass-sqrt.ll
+++ b/llvm/test/Transforms/InstCombine/simplify-demanded-fpclass-sqrt.ll
@@ -18,8 +18,7 @@ declare nofpclass(pinf pnorm psub zero) float @returns_negative_nonzero_or_nan()
 define nofpclass(inf norm sub zero) float @ret_only_nan_sqrt(float %x) {
 ; CHECK-LABEL: define nofpclass(inf zero sub norm) float @ret_only_nan_sqrt(
 ; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT:    [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[X]])
-; CHECK-NEXT:    ret float [[RESULT]]
+; CHECK-NEXT:    ret float 0x7FF8000000000000
 ;
   %result = call float @llvm.sqrt.f32(float %x)
   ret float %result
@@ -30,7 +29,7 @@ define nofpclass(inf norm sub zero) float @ret_only_nan_sqrt(float %x) {
 define nofpclass(inf nan norm sub) float @ret_only_zero_sqrt(float %x) {
 ; CHECK-LABEL: define nofpclass(nan inf sub norm) float @ret_only_zero_sqrt(
 ; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT:    [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[X]])
+; CHECK-NEXT:    [[RESULT:%.*]] = call float @llvm.copysign.f32(float 0.000000e+00, float [[X]])
 ; CHECK-NEXT:    ret float [[RESULT]]
 ;
   %result = call float @llvm.sqrt.f32(float %x)
@@ -40,7 +39,7 @@ define nofpclass(inf nan norm sub) float @ret_only_zero_sqrt(float %x) {
 define nofpclass(inf nan norm sub) <2 x float> @ret_only_zero_sqrt_vec(<2 x float> %x) {
 ; CHECK-LABEL: define nofpclass(nan inf sub norm) <2 x float> @ret_only_zero_sqrt_vec(
 ; CHECK-SAME: <2 x float> [[X:%.*]]) {
-; CHECK-NEXT:    [[RESULT:%.*]] = call <2 x float> @llvm.sqrt.v2f32(<2 x float> [[X]])
+; CHECK-NEXT:    [[RESULT:%.*]] = call <2 x float> @llvm.copysign.v2f32(<2 x float> zeroinitializer, <2 x float> [[X]])
 ; CHECK-NEXT:    ret <2 x float> [[RESULT]]
 ;
   %result = call <2 x float> @llvm.sqrt.v2f32(<2 x float> %x)
@@ -83,8 +82,7 @@ define nofpclass(ninf) float @ret_only_ninf__sqrt__known_negative() {
 define nofpclass(ninf) float @ret_only_ninf__sqrt__known_negative_nonzero() {
 ; CHECK-LABEL: define nofpclass(ninf) float @ret_only_ninf__sqrt__known_negative_nonzero() {
 ; CHECK-NEXT:    [[KNOWN_NEGATIVE_NONZERO:%.*]] = call float @returns_negative_nonzero()
-; CHECK-NEXT:    [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[KNOWN_NEGATIVE_NONZERO]])
-; CHECK-NEXT:    ret float [[RESULT]]
+; CHECK-NEXT:    ret float 0x7FF8000000000000
 ;
   %known.negative.nonzero = call float @returns_negative_nonzero()
   %result = call float @llvm.sqrt.f32(float %known.negative.nonzero)
@@ -95,8 +93,7 @@ define nofpclass(ninf) float @ret_only_ninf__sqrt__known_negative_nonzero() {
 define nofpclass(ninf) <2 x float> @ret_only_ninf__sqrt__known_negative_nonzero_vec() {
 ; CHECK-LABEL: define nofpclass(ninf) <2 x float> @ret_only_ninf__sqrt__known_negative_nonzero_vec() {
 ; CHECK-NEXT:    [[KNOWN_NEGATIVE_NONZERO:%.*]] = call <2 x float> @returns_negative_nonzero_vec()
-; CHECK-NEXT:    [[RESULT:%.*]] = call <2 x float> @llvm.sqrt.v2f32(<2 x float> [[KNOWN_NEGATIVE_NONZERO]])
-; CHECK-NEXT:    ret <2 x float> [[RESULT]]
+; CHECK-NEXT:    ret <2 x float> splat (float 0x7FF8000000000000)
 ;
   %known.negative.nonzero = call <2 x float> @returns_negative_nonzero_vec()
   %result = call <2 x float> @llvm.sqrt.v2f32(<2 x float> %known.negative.nonzero)
@@ -149,7 +146,7 @@ define nofpclass(inf norm zero) float @ret_only_nan_or_sub__sqrt__select_unknown
 ; CHECK-LABEL: define nofpclass(inf zero norm) float @ret_only_nan_or_sub__sqrt__select_unknown_or_maybe_ninf(
 ; CHECK-SAME: i1 [[COND:%.*]], float nofpclass(nan) [[X:%.*]]) {
 ; CHECK-NEXT:    [[MAYBE_NINF:%.*]] = call nofpclass(nan pinf sub norm) float @func()
-; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[COND]], float [[X]], float [[MAYBE_NINF]]
+; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[COND]], float [[X]], float 0xFFF0000000000000
 ; CHECK-NEXT:    [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[SELECT]])
 ; CHECK-NEXT:    ret float [[RESULT]]
 ;
@@ -193,8 +190,7 @@ define nofpclass(inf norm zero) float @ret_only_nan_or_sub__sqrt__select_unknown
 define nofpclass(pinf) float @no_pinf_result_implies_no_pinf_source(i1 %cond, float %unknown) {
 ; CHECK-LABEL: define nofpclass(pinf) float @no_pinf_result_implies_no_pinf_source(
 ; CHECK-SAME: i1 [[COND:%.*]], float [[UNKNOWN:%.*]]) {
-; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[COND]], float [[UNKNOWN]], float 0x7FF0000000000000
-; CHECK-NEXT:    [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[SELECT]])
+; CHECK-NEXT:    [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[UNKNOWN]])
 ; CHECK-NEXT:    ret float [[RESULT]]
 ;
   %select = select i1 %cond, float %unknown, float 0x7ff0000000000000

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
