[llvm] 99ab1dd - InstCombine: Implement SimplifyDemandedFPClass for sqrt (#173883)

via llvm-commits llvm-commits at lists.llvm.org
Sun Jan 11 03:42:45 PST 2026


Author: Matt Arsenault
Date: 2026-01-11T11:42:41Z
New Revision: 99ab1dd145dd4e9a3a149568e65d3901b4245d4b

URL: https://github.com/llvm/llvm-project/commit/99ab1dd145dd4e9a3a149568e65d3901b4245d4b
DIFF: https://github.com/llvm/llvm-project/commit/99ab1dd145dd4e9a3a149568e65d3901b4245d4b.diff

LOG: InstCombine: Implement SimplifyDemandedFPClass for sqrt (#173883)

Added: 
    

Modified: 
    llvm/include/llvm/Support/KnownFPClass.h
    llvm/lib/Analysis/ValueTracking.cpp
    llvm/lib/Support/KnownFPClass.cpp
    llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
    llvm/test/Transforms/InstCombine/simplify-demanded-fpclass-sqrt.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Support/KnownFPClass.h b/llvm/include/llvm/Support/KnownFPClass.h
index 2348927fd9b9a..5a3ed677bd762 100644
--- a/llvm/include/llvm/Support/KnownFPClass.h
+++ b/llvm/include/llvm/Support/KnownFPClass.h
@@ -290,6 +290,10 @@ struct KnownFPClass {
   static LLVM_ABI KnownFPClass
   log(const KnownFPClass &Src, DenormalMode Mode = DenormalMode::getDynamic());
 
+  /// Propagate known class for sqrt.
+  static LLVM_ABI KnownFPClass
+  sqrt(const KnownFPClass &Src, DenormalMode Mode = DenormalMode::getDynamic());
+
   /// Propagate known class for fpext.
   static LLVM_ABI KnownFPClass fpext(const KnownFPClass &KnownSrc,
                                      const fltSemantics &DstTy,

diff  --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 0c76522efdc29..90e23bf81d99e 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -5122,27 +5122,18 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
       computeKnownFPClass(II->getArgOperand(0), DemandedElts, InterestedSrcs,
                           KnownSrc, Q, Depth + 1);
 
-      if (KnownSrc.isKnownNeverPosInfinity())
-        Known.knownNot(fcPosInf);
-      if (KnownSrc.isKnownNever(fcSNan))
-        Known.knownNot(fcSNan);
-
-      // Any negative value besides -0 returns a nan.
-      if (KnownSrc.isKnownNeverNaN() && KnownSrc.cannotBeOrderedLessThanZero())
-        Known.knownNot(fcNan);
-
-      // The only negative value that can be returned is -0 for -0 inputs.
-      Known.knownNot(fcNegInf | fcNegSubnormal | fcNegNormal);
+      DenormalMode Mode = DenormalMode::getDynamic();
 
-      // If the input denormal mode could be PreserveSign, a negative
-      // subnormal input could produce a negative zero output.
-      const Function *F = II->getFunction();
-      const fltSemantics &FltSem =
-          II->getType()->getScalarType()->getFltSemantics();
+      bool HasNSZ = Q.IIQ.hasNoSignedZeros(II);
+      if (!HasNSZ) {
+        const Function *F = II->getFunction();
+        const fltSemantics &FltSem =
+            II->getType()->getScalarType()->getFltSemantics();
+        Mode = F ? F->getDenormalMode(FltSem) : DenormalMode::getDynamic();
+      }
 
-      if (Q.IIQ.hasNoSignedZeros(II) ||
-          (F &&
-           KnownSrc.isKnownNeverLogicalNegZero(F->getDenormalMode(FltSem))))
+      Known = KnownFPClass::sqrt(KnownSrc, Mode);
+      if (HasNSZ)
         Known.knownNot(fcNegZero);
 
       break;

diff  --git a/llvm/lib/Support/KnownFPClass.cpp b/llvm/lib/Support/KnownFPClass.cpp
index 7a60221f7e8f8..e621da9e39ed7 100644
--- a/llvm/lib/Support/KnownFPClass.cpp
+++ b/llvm/lib/Support/KnownFPClass.cpp
@@ -330,6 +330,30 @@ KnownFPClass KnownFPClass::log(const KnownFPClass &KnownSrc,
   return Known;
 }
 
+KnownFPClass KnownFPClass::sqrt(const KnownFPClass &KnownSrc,
+                                DenormalMode Mode) {
+  KnownFPClass Known;
+
+  if (KnownSrc.isKnownNeverPosInfinity())
+    Known.knownNot(fcPosInf);
+  if (KnownSrc.isKnownNever(fcSNan))
+    Known.knownNot(fcSNan);
+
+  // sqrt of a NaN or of any negative value other than -0 returns a NaN.
+  if (KnownSrc.isKnownNeverNaN() && KnownSrc.cannotBeOrderedLessThanZero())
+    Known.knownNot(fcNan);
+
+  // No negative result is possible other than -0, which is handled below.
+  Known.knownNot(fcNegInf | fcNegSubnormal | fcNegNormal);
+
+  // A -0 result requires a logical -0 input. Under a PreserveSign input
+  // denormal mode that includes negative subnormals, treated as -0.
+  if (KnownSrc.isKnownNeverLogicalNegZero(Mode))
+    Known.knownNot(fcNegZero);
+
+  return Known;
+}
+
 KnownFPClass KnownFPClass::fpext(const KnownFPClass &KnownSrc,
                                  const fltSemantics &DstTy,
                                  const fltSemantics &SrcTy) {

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index 92b4bfe4ec54e..22064638bb229 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -2600,6 +2600,44 @@ Value *InstCombinerImpl::SimplifyDemandedUseFPClass(Instruction *I,
       FPClassTest ValidResults = DemandedMask & Known.KnownFPClasses;
       return getFPClassConstant(VTy, ValidResults, /*IsCanonicalizing=*/true);
     }
+    case Intrinsic::sqrt: {
+      FPClassTest DemandedSrcMask =
+          DemandedMask & (fcNegZero | fcPositive | fcNan);
+
+      if (DemandedMask & fcNan)
+        DemandedSrcMask |= (fcNegative & ~fcNegZero);
+
+      // sqrt(max_subnormal) is a normal value
+      if (DemandedMask & fcPosNormal)
+        DemandedSrcMask |= fcPosSubnormal;
+
+      KnownFPClass KnownSrc;
+      if (SimplifyDemandedFPClass(I, 0, DemandedSrcMask, KnownSrc, Depth + 1))
+        return I;
+
+      Type *EltTy = VTy->getScalarType();
+      DenormalMode Mode = F.getDenormalMode(EltTy->getFltSemantics());
+
+      // sqrt(-x) = nan, but be careful of negative subnormals flushed to 0.
+      if (KnownSrc.isKnownNever(fcPositive) &&
+          KnownSrc.isKnownNeverLogicalZero(Mode))
+        return ConstantFP::getQNaN(VTy);
+
+      Known = KnownFPClass::sqrt(KnownSrc, Mode);
+      FPClassTest ValidResults = DemandedMask & Known.KnownFPClasses;
+
+      if (ValidResults == fcZero) {
+        if (FMF.noSignedZeros())
+          return ConstantFP::getZero(VTy);
+
+        Value *Copysign = Builder.CreateCopySign(ConstantFP::getZero(VTy),
+                                                 CI->getArgOperand(0), FMF);
+        Copysign->takeName(CI);
+        return Copysign;
+      }
+
+      return getFPClassConstant(VTy, ValidResults, /*IsCanonicalizing=*/true);
+    }
     case Intrinsic::canonicalize: {
       Type *EltTy = VTy->getScalarType();
 

diff  --git a/llvm/test/Transforms/InstCombine/simplify-demanded-fpclass-sqrt.ll b/llvm/test/Transforms/InstCombine/simplify-demanded-fpclass-sqrt.ll
index 447b3786109ee..234d97c05f3aa 100644
--- a/llvm/test/Transforms/InstCombine/simplify-demanded-fpclass-sqrt.ll
+++ b/llvm/test/Transforms/InstCombine/simplify-demanded-fpclass-sqrt.ll
@@ -18,8 +18,7 @@ declare nofpclass(pinf pnorm psub zero) float @returns_negative_nonzero_or_nan()
 define nofpclass(inf norm sub zero) float @ret_only_nan_sqrt(float %x) {
 ; CHECK-LABEL: define nofpclass(inf zero sub norm) float @ret_only_nan_sqrt(
 ; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT:    [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[X]])
-; CHECK-NEXT:    ret float [[RESULT]]
+; CHECK-NEXT:    ret float 0x7FF8000000000000
 ;
   %result = call float @llvm.sqrt.f32(float %x)
   ret float %result
@@ -30,7 +29,7 @@ define nofpclass(inf norm sub zero) float @ret_only_nan_sqrt(float %x) {
 define nofpclass(inf nan norm sub) float @ret_only_zero_sqrt(float %x) {
 ; CHECK-LABEL: define nofpclass(nan inf sub norm) float @ret_only_zero_sqrt(
 ; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT:    [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[X]])
+; CHECK-NEXT:    [[RESULT:%.*]] = call float @llvm.copysign.f32(float 0.000000e+00, float [[X]])
 ; CHECK-NEXT:    ret float [[RESULT]]
 ;
   %result = call float @llvm.sqrt.f32(float %x)
@@ -40,7 +39,7 @@ define nofpclass(inf nan norm sub) float @ret_only_zero_sqrt(float %x) {
 define nofpclass(inf nan norm sub) float @ret_only_zero_sqrt_preserve_flags(float %x) {
 ; CHECK-LABEL: define nofpclass(nan inf sub norm) float @ret_only_zero_sqrt_preserve_flags(
 ; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT:    [[RESULT:%.*]] = call ninf float @llvm.sqrt.f32(float [[X]])
+; CHECK-NEXT:    [[RESULT:%.*]] = call ninf float @llvm.copysign.f32(float 0.000000e+00, float [[X]])
 ; CHECK-NEXT:    ret float [[RESULT]]
 ;
   %result = call ninf float @llvm.sqrt.f32(float %x)
@@ -59,7 +58,7 @@ define nofpclass(inf nan norm sub) float @ret_only_zero_sqrt_nsz(float %x) {
 define nofpclass(inf nan norm sub) <2 x float> @ret_only_zero_sqrt_vec(<2 x float> %x) {
 ; CHECK-LABEL: define nofpclass(nan inf sub norm) <2 x float> @ret_only_zero_sqrt_vec(
 ; CHECK-SAME: <2 x float> [[X:%.*]]) {
-; CHECK-NEXT:    [[RESULT:%.*]] = call <2 x float> @llvm.sqrt.v2f32(<2 x float> [[X]])
+; CHECK-NEXT:    [[RESULT:%.*]] = call <2 x float> @llvm.copysign.v2f32(<2 x float> zeroinitializer, <2 x float> [[X]])
 ; CHECK-NEXT:    ret <2 x float> [[RESULT]]
 ;
   %result = call <2 x float> @llvm.sqrt.v2f32(<2 x float> %x)
@@ -102,8 +101,7 @@ define nofpclass(ninf) float @ret_only_ninf__sqrt__known_negative() {
 define nofpclass(ninf) float @ret_only_ninf__sqrt__known_negative_nonzero() {
 ; CHECK-LABEL: define nofpclass(ninf) float @ret_only_ninf__sqrt__known_negative_nonzero() {
 ; CHECK-NEXT:    [[KNOWN_NEGATIVE_NONZERO:%.*]] = call float @returns_negative_nonzero()
-; CHECK-NEXT:    [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[KNOWN_NEGATIVE_NONZERO]])
-; CHECK-NEXT:    ret float [[RESULT]]
+; CHECK-NEXT:    ret float 0x7FF8000000000000
 ;
   %known.negative.nonzero = call float @returns_negative_nonzero()
   %result = call float @llvm.sqrt.f32(float %known.negative.nonzero)
@@ -114,8 +112,7 @@ define nofpclass(ninf) float @ret_only_ninf__sqrt__known_negative_nonzero() {
 define nofpclass(ninf) <2 x float> @ret_only_ninf__sqrt__known_negative_nonzero_vec() {
 ; CHECK-LABEL: define nofpclass(ninf) <2 x float> @ret_only_ninf__sqrt__known_negative_nonzero_vec() {
 ; CHECK-NEXT:    [[KNOWN_NEGATIVE_NONZERO:%.*]] = call <2 x float> @returns_negative_nonzero_vec()
-; CHECK-NEXT:    [[RESULT:%.*]] = call <2 x float> @llvm.sqrt.v2f32(<2 x float> [[KNOWN_NEGATIVE_NONZERO]])
-; CHECK-NEXT:    ret <2 x float> [[RESULT]]
+; CHECK-NEXT:    ret <2 x float> splat (float 0x7FF8000000000000)
 ;
   %known.negative.nonzero = call <2 x float> @returns_negative_nonzero_vec()
   %result = call <2 x float> @llvm.sqrt.v2f32(<2 x float> %known.negative.nonzero)
@@ -168,7 +165,7 @@ define nofpclass(inf norm zero) float @ret_only_nan_or_sub__sqrt__select_unknown
 ; CHECK-LABEL: define nofpclass(inf zero norm) float @ret_only_nan_or_sub__sqrt__select_unknown_or_maybe_ninf(
 ; CHECK-SAME: i1 [[COND:%.*]], float nofpclass(nan) [[X:%.*]]) {
 ; CHECK-NEXT:    [[MAYBE_NINF:%.*]] = call nofpclass(nan pinf sub norm) float @func()
-; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[COND]], float [[X]], float [[MAYBE_NINF]]
+; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[COND]], float [[X]], float 0xFFF0000000000000
 ; CHECK-NEXT:    [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[SELECT]])
 ; CHECK-NEXT:    ret float [[RESULT]]
 ;
@@ -212,8 +209,7 @@ define nofpclass(inf norm zero) float @ret_only_nan_or_sub__sqrt__select_unknown
 define nofpclass(pinf) float @no_pinf_result_implies_no_pinf_source(i1 %cond, float %unknown) {
 ; CHECK-LABEL: define nofpclass(pinf) float @no_pinf_result_implies_no_pinf_source(
 ; CHECK-SAME: i1 [[COND:%.*]], float [[UNKNOWN:%.*]]) {
-; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[COND]], float [[UNKNOWN]], float 0x7FF0000000000000
-; CHECK-NEXT:    [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[SELECT]])
+; CHECK-NEXT:    [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[UNKNOWN]])
 ; CHECK-NEXT:    ret float [[RESULT]]
 ;
   %select = select i1 %cond, float %unknown, float 0x7ff0000000000000
@@ -253,8 +249,7 @@ define nofpclass(nan inf zero nsub norm) float @psub_result_implies_not_pnorm_so
 ; CHECK-LABEL: define nofpclass(nan inf zero nsub norm) float @psub_result_implies_not_pnorm_source(
 ; CHECK-SAME: i1 [[COND:%.*]], float nofpclass(nan) [[NOT_NAN:%.*]]) {
 ; CHECK-NEXT:    [[ONLY_PNORM:%.*]] = call nofpclass(nan inf zero sub nnorm) float @func()
-; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[COND]], float [[NOT_NAN]], float [[ONLY_PNORM]]
-; CHECK-NEXT:    [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[SELECT]])
+; CHECK-NEXT:    [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[NOT_NAN]])
 ; CHECK-NEXT:    ret float [[RESULT]]
 ;
   %only.pnorm = call nofpclass(nan inf nnorm sub zero) float @func()


        


More information about the llvm-commits mailing list