https://github.com/arsenm updated https://github.com/llvm/llvm-project/pull/173883

>From 46adc117b0c0d4bcb0037a151f93ca0020108a85 Mon Sep 17 00:00:00 2001
From: Matt Arsenault <[email protected]>
Date: Tue, 23 Dec 2025 20:36:25 +0100
Subject: [PATCH] InstCombine: Implement SimplifyDemandedFPClass for sqrt

---
 llvm/include/llvm/Support/KnownFPClass.h      |  4 +++
 llvm/lib/Analysis/ValueTracking.cpp           | 29 ++++++-----------
 llvm/lib/Support/KnownFPClass.cpp             | 24 ++++++++++++++
 .../InstCombineSimplifyDemanded.cpp           | 31 +++++++++++++++++++
 .../simplify-demanded-fpclass-sqrt.ll         | 18 +++++------
 5 files changed, 76 insertions(+), 30 deletions(-)

diff --git a/llvm/include/llvm/Support/KnownFPClass.h 
b/llvm/include/llvm/Support/KnownFPClass.h
index 07d74f2867089..ae9513bbebe80 100644
--- a/llvm/include/llvm/Support/KnownFPClass.h
+++ b/llvm/include/llvm/Support/KnownFPClass.h
@@ -267,6 +267,10 @@ struct KnownFPClass {
   static LLVM_ABI KnownFPClass
   log(const KnownFPClass &Src, DenormalMode Mode = DenormalMode::getDynamic());
 
+  /// Propagate known class for sqrt, given input denormal mode \p Mode.
+  static LLVM_ABI KnownFPClass
+  sqrt(const KnownFPClass &Src, DenormalMode Mode = 
DenormalMode::getDynamic());
+
   void resetAll() { *this = KnownFPClass(); }
 };
 
diff --git a/llvm/lib/Analysis/ValueTracking.cpp 
b/llvm/lib/Analysis/ValueTracking.cpp
index 4fbbfd1a0cf12..543a2992f9a46 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -5131,27 +5131,18 @@ void computeKnownFPClass(const Value *V, const APInt 
&DemandedElts,
       computeKnownFPClass(II->getArgOperand(0), DemandedElts, InterestedSrcs,
                           KnownSrc, Q, Depth + 1);
 
-      if (KnownSrc.isKnownNeverPosInfinity())
-        Known.knownNot(fcPosInf);
-      if (KnownSrc.isKnownNever(fcSNan))
-        Known.knownNot(fcSNan);
-
-      // Any negative value besides -0 returns a nan.
-      if (KnownSrc.isKnownNeverNaN() && KnownSrc.cannotBeOrderedLessThanZero())
-        Known.knownNot(fcNan);
-
-      // The only negative value that can be returned is -0 for -0 inputs.
-      Known.knownNot(fcNegInf | fcNegSubnormal | fcNegNormal);
+      DenormalMode Mode = DenormalMode::getDynamic();
 
-      // If the input denormal mode could be PreserveSign, a negative
-      // subnormal input could produce a negative zero output.
-      const Function *F = II->getFunction();
-      const fltSemantics &FltSem =
-          II->getType()->getScalarType()->getFltSemantics();
+      bool HasNSZ = Q.IIQ.hasNoSignedZeros(II);
+      if (!HasNSZ) {
+        const Function *F = II->getFunction();
+        const fltSemantics &FltSem =
+            II->getType()->getScalarType()->getFltSemantics();
+        Mode = F ? F->getDenormalMode(FltSem) : DenormalMode::getDynamic();
+      }
 
-      if (Q.IIQ.hasNoSignedZeros(II) ||
-          (F &&
-           KnownSrc.isKnownNeverLogicalNegZero(F->getDenormalMode(FltSem))))
+      Known = KnownFPClass::sqrt(KnownSrc, Mode);
+      if (HasNSZ)
         Known.knownNot(fcNegZero);
 
       break;
diff --git a/llvm/lib/Support/KnownFPClass.cpp 
b/llvm/lib/Support/KnownFPClass.cpp
index ff98908fdb2c4..afa08c1fd047f 100644
--- a/llvm/lib/Support/KnownFPClass.cpp
+++ b/llvm/lib/Support/KnownFPClass.cpp
@@ -243,3 +243,27 @@ KnownFPClass KnownFPClass::log(const KnownFPClass 
&KnownSrc,
 
   return Known;
 }
+
+KnownFPClass KnownFPClass::sqrt(const KnownFPClass &KnownSrc,
+                                DenormalMode Mode) {
+  KnownFPClass Known;
+  // Only sqrt(+inf) is +inf, and an snan result requires an snan source.
+  if (KnownSrc.isKnownNeverPosInfinity())
+    Known.knownNot(fcPosInf);
+  if (KnownSrc.isKnownNever(fcSNan))
+    Known.knownNot(fcSNan);
+
+  // Any negative value besides -0 returns a nan.
+  if (KnownSrc.isKnownNeverNaN() && KnownSrc.cannotBeOrderedLessThanZero())
+    Known.knownNot(fcNan);
+
+  // Apart from -0 (handled below), sqrt never returns a negative value.
+  Known.knownNot(fcNegInf | fcNegSubnormal | fcNegNormal);
+
+  // A -0 result needs a -0 input, or a negative subnormal input flushed to
+  // -0 when the input denormal mode could be PreserveSign.
+  if (KnownSrc.isKnownNeverLogicalNegZero(Mode))
+    Known.knownNot(fcNegZero);
+
+  return Known;
+}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp 
b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index a98da1b35ae71..addc1e2126b2d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -2244,6 +2244,55 @@ Value *InstCombinerImpl::SimplifyDemandedUseFPClass(Value *V,
       FPClassTest ValidResults = DemandedMask & Known.KnownFPClasses;
       return getFPClassConstant(VTy, ValidResults, /*IsCanonicalizing=*/true);
     }
+    case Intrinsic::sqrt: {
+      Type *EltTy = VTy->getScalarType();
+      DenormalMode Mode = F.getDenormalMode(EltTy->getFltSemantics());
+
+      // Non-nan results keep the sign of the source (+x -> +y, -0 -> -0), so
+      // start by demanding the source classes that map onto demanded results.
+      FPClassTest DemandedSrcMask =
+          DemandedMask & (fcNegZero | fcPositive | fcNan);
+
+      // Any negative value besides -0 produces a nan.
+      if (DemandedMask & fcNan)
+        DemandedSrcMask |= (fcNegative & ~fcNegZero);
+
+      // sqrt of a positive subnormal is a positive normal value.
+      if (DemandedMask & fcPosNormal)
+        DemandedSrcMask |= fcPosSubnormal;
+
+      // If input denormals may be flushed, a subnormal source can be read as
+      // a zero (of either sign) and so produce a zero result.
+      if ((DemandedMask & fcZero) && Mode.Input != DenormalMode::IEEE)
+        DemandedSrcMask |= fcSubnormal;
+
+      KnownFPClass KnownSrc;
+      if (SimplifyDemandedFPClass(I, 0, DemandedSrcMask, KnownSrc, Depth + 1))
+        return I;
+
+      // sqrt(-x) = nan, but be careful of negative subnormals flushed to 0.
+      if (KnownSrc.isKnownNever(fcPositive) &&
+          KnownSrc.isKnownNeverLogicalZero(Mode))
+        return ConstantFP::getQNaN(VTy);
+
+      Known = KnownFPClass::sqrt(KnownSrc, Mode);
+      FPClassTest ValidResults = DemandedMask & Known.KnownFPClasses;
+
+      // When only zeros are valid results, the only inputs that matter are
+      // +/-0 (or subnormals flushed to zero), for which sqrt(x) equals
+      // copysign(0, x). That fails if a negative subnormal may be flushed to
+      // +0 (positive-zero or dynamic input denormal mode), so skip it then.
+      if (ValidResults == fcZero &&
+          (Mode.Input == DenormalMode::IEEE ||
+           Mode.Input == DenormalMode::PreserveSign)) {
+        Value *Copysign = Builder.CreateCopySign(ConstantFP::getZero(VTy),
+                                                 CI->getArgOperand(0));
+        Copysign->takeName(CI);
+        return Copysign;
+      }
+
+      return getFPClassConstant(VTy, ValidResults, /*IsCanonicalizing=*/true);
+    }
     case Intrinsic::canonicalize: {
       Type *EltTy = VTy->getScalarType();
 
diff --git a/llvm/test/Transforms/InstCombine/simplify-demanded-fpclass-sqrt.ll 
b/llvm/test/Transforms/InstCombine/simplify-demanded-fpclass-sqrt.ll
index 9288bb7be3ecd..ad9881d70b5fd 100644
--- a/llvm/test/Transforms/InstCombine/simplify-demanded-fpclass-sqrt.ll
+++ b/llvm/test/Transforms/InstCombine/simplify-demanded-fpclass-sqrt.ll
@@ -18,8 +18,7 @@ declare nofpclass(pinf pnorm psub zero) float 
@returns_negative_nonzero_or_nan()
 define nofpclass(inf norm sub zero) float @ret_only_nan_sqrt(float %x) {
 ; CHECK-LABEL: define nofpclass(inf zero sub norm) float @ret_only_nan_sqrt(
 ; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT:    [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[X]])
-; CHECK-NEXT:    ret float [[RESULT]]
+; CHECK-NEXT:    ret float 0x7FF8000000000000
 ;
   %result = call float @llvm.sqrt.f32(float %x)
   ret float %result
@@ -30,7 +29,7 @@ define nofpclass(inf norm sub zero) float 
@ret_only_nan_sqrt(float %x) {
 define nofpclass(inf nan norm sub) float @ret_only_zero_sqrt(float %x) {
 ; CHECK-LABEL: define nofpclass(nan inf sub norm) float @ret_only_zero_sqrt(
 ; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT:    [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[X]])
+; CHECK-NEXT:    [[RESULT:%.*]] = call float @llvm.copysign.f32(float 
0.000000e+00, float [[X]])
 ; CHECK-NEXT:    ret float [[RESULT]]
 ;
   %result = call float @llvm.sqrt.f32(float %x)
@@ -40,7 +39,7 @@ define nofpclass(inf nan norm sub) float 
@ret_only_zero_sqrt(float %x) {
 define nofpclass(inf nan norm sub) <2 x float> @ret_only_zero_sqrt_vec(<2 x 
float> %x) {
 ; CHECK-LABEL: define nofpclass(nan inf sub norm) <2 x float> 
@ret_only_zero_sqrt_vec(
 ; CHECK-SAME: <2 x float> [[X:%.*]]) {
-; CHECK-NEXT:    [[RESULT:%.*]] = call <2 x float> @llvm.sqrt.v2f32(<2 x 
float> [[X]])
+; CHECK-NEXT:    [[RESULT:%.*]] = call <2 x float> @llvm.copysign.v2f32(<2 x 
float> zeroinitializer, <2 x float> [[X]])
 ; CHECK-NEXT:    ret <2 x float> [[RESULT]]
 ;
   %result = call <2 x float> @llvm.sqrt.v2f32(<2 x float> %x)
@@ -83,8 +82,7 @@ define nofpclass(ninf) float 
@ret_only_ninf__sqrt__known_negative() {
 define nofpclass(ninf) float @ret_only_ninf__sqrt__known_negative_nonzero() {
 ; CHECK-LABEL: define nofpclass(ninf) float 
@ret_only_ninf__sqrt__known_negative_nonzero() {
 ; CHECK-NEXT:    [[KNOWN_NEGATIVE_NONZERO:%.*]] = call float 
@returns_negative_nonzero()
-; CHECK-NEXT:    [[RESULT:%.*]] = call float @llvm.sqrt.f32(float 
[[KNOWN_NEGATIVE_NONZERO]])
-; CHECK-NEXT:    ret float [[RESULT]]
+; CHECK-NEXT:    ret float 0x7FF8000000000000
 ;
   %known.negative.nonzero = call float @returns_negative_nonzero()
   %result = call float @llvm.sqrt.f32(float %known.negative.nonzero)
@@ -95,8 +93,7 @@ define nofpclass(ninf) float 
@ret_only_ninf__sqrt__known_negative_nonzero() {
 define nofpclass(ninf) <2 x float> 
@ret_only_ninf__sqrt__known_negative_nonzero_vec() {
 ; CHECK-LABEL: define nofpclass(ninf) <2 x float> 
@ret_only_ninf__sqrt__known_negative_nonzero_vec() {
 ; CHECK-NEXT:    [[KNOWN_NEGATIVE_NONZERO:%.*]] = call <2 x float> 
@returns_negative_nonzero_vec()
-; CHECK-NEXT:    [[RESULT:%.*]] = call <2 x float> @llvm.sqrt.v2f32(<2 x 
float> [[KNOWN_NEGATIVE_NONZERO]])
-; CHECK-NEXT:    ret <2 x float> [[RESULT]]
+; CHECK-NEXT:    ret <2 x float> splat (float 0x7FF8000000000000)
 ;
   %known.negative.nonzero = call <2 x float> @returns_negative_nonzero_vec()
   %result = call <2 x float> @llvm.sqrt.v2f32(<2 x float> 
%known.negative.nonzero)
@@ -149,7 +146,7 @@ define nofpclass(inf norm zero) float 
@ret_only_nan_or_sub__sqrt__select_unknown
 ; CHECK-LABEL: define nofpclass(inf zero norm) float 
@ret_only_nan_or_sub__sqrt__select_unknown_or_maybe_ninf(
 ; CHECK-SAME: i1 [[COND:%.*]], float nofpclass(nan) [[X:%.*]]) {
 ; CHECK-NEXT:    [[MAYBE_NINF:%.*]] = call nofpclass(nan pinf sub norm) float 
@func()
-; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[COND]], float [[X]], float 
[[MAYBE_NINF]]
+; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[COND]], float [[X]], float 
0xFFF0000000000000
 ; CHECK-NEXT:    [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[SELECT]])
 ; CHECK-NEXT:    ret float [[RESULT]]
 ;
@@ -193,8 +190,7 @@ define nofpclass(inf norm zero) float 
@ret_only_nan_or_sub__sqrt__select_unknown
 define nofpclass(pinf) float @no_pinf_result_implies_no_pinf_source(i1 %cond, 
float %unknown) {
 ; CHECK-LABEL: define nofpclass(pinf) float 
@no_pinf_result_implies_no_pinf_source(
 ; CHECK-SAME: i1 [[COND:%.*]], float [[UNKNOWN:%.*]]) {
-; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[COND]], float [[UNKNOWN]], float 
0x7FF0000000000000
-; CHECK-NEXT:    [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[SELECT]])
+; CHECK-NEXT:    [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[UNKNOWN]])
 ; CHECK-NEXT:    ret float [[RESULT]]
 ;
   %select = select i1 %cond, float %unknown, float 0x7ff0000000000000

_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to