[llvm-branch-commits] [llvm] InstCombine: Infer fast math flags for sqrt (PR #176003)

Matt Arsenault via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Jan 14 10:01:06 PST 2026


https://github.com/arsenm created https://github.com/llvm/llvm-project/pull/176003

None

>From 82372b2c36018ae7b2abc9af106e07bbc8bac689 Mon Sep 17 00:00:00 2001
From: Matt Arsenault <Matthew.Arsenault at amd.com>
Date: Wed, 14 Jan 2026 18:40:41 +0100
Subject: [PATCH] InstCombine: Infer fast math flags for sqrt

---
 .../InstCombineSimplifyDemanded.cpp           | 39 +++++++++--
 .../simplify-demanded-fpclass-sqrt.ll         | 65 +++++++++++++++++--
 2 files changed, 95 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index 3bbc4a913ada6..3292d3538b4e3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -2036,9 +2036,9 @@ static Constant *getFPClassConstant(Type *Ty, FPClassTest Mask,
 /// Try to set an inferred no-nans or no-infs in \p FMF. \p ValidResults is a
 /// mask of known valid results for the operator (already computed from the
 /// result, and the known operand inputs in \p Known)
-static FastMathFlags
-inferFastMathValueFlags(FastMathFlags FMF, FPClassTest ValidResults,
-                        ArrayRef<const KnownFPClass> Known) {
+static FastMathFlags inferFastMathValueFlags(FastMathFlags FMF,
+                                             FPClassTest ValidResults,
+                                             ArrayRef<KnownFPClass> Known) {
   if (!FMF.noNaNs() && (ValidResults & fcNan) == fcNone) {
     if (all_of(Known, [](const KnownFPClass KnownSrc) {
           return KnownSrc.isKnownNeverNaN();
@@ -2056,6 +2056,28 @@ inferFastMathValueFlags(FastMathFlags FMF, FPClassTest ValidResults,
   return FMF;
 }
 
+/// Apply epilog fixups to a floating-point intrinsic. See if the result can
+/// fold to a constant, or apply fast math flags.
+static Value *simplifyDemandedFPClassResult(CallInst *FPOp, FastMathFlags FMF,
+                                            FPClassTest DemandedMask,
+                                            KnownFPClass &Known,
+                                            ArrayRef<KnownFPClass> KnownSrcs) {
+  FPClassTest ValidResults = DemandedMask & Known.KnownFPClasses;
+  Constant *SingleVal = getFPClassConstant(FPOp->getType(), ValidResults,
+                                           /*IsCanonicalizing=*/true);
+  if (SingleVal)
+    return SingleVal;
+
+  FastMathFlags InferredFMF =
+      inferFastMathValueFlags(FMF, ValidResults, KnownSrcs);
+  if (InferredFMF != FMF) {
+    FPOp->setFastMathFlags(InferredFMF);
+    return FPOp;
+  }
+
+  return nullptr;
+}
+
 Value *InstCombinerImpl::SimplifyDemandedUseFPClass(Instruction *I,
                                                     FPClassTest DemandedMask,
                                                     KnownFPClass &Known,
@@ -2790,6 +2812,14 @@ Value *InstCombinerImpl::SimplifyDemandedUseFPClass(Instruction *I,
       if (SimplifyDemandedFPClass(I, 0, DemandedSrcMask, KnownSrc, Depth + 1))
         return I;
 
+      // Infer the source cannot be negative if the result cannot be nan.
+      if ((DemandedMask & fcNan) == fcNone)
+        KnownSrc.knownNot((fcNegative & ~fcNegZero) | fcNan);
+
+      // Infer the source cannot be +inf if the result cannot be +inf
+      if ((DemandedMask & fcPosInf) == fcNone)
+        KnownSrc.knownNot(fcPosInf);
+
       Type *EltTy = VTy->getScalarType();
       DenormalMode Mode = F.getDenormalMode(EltTy->getFltSemantics());
 
@@ -2811,7 +2841,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseFPClass(Instruction *I,
         return Copysign;
       }
 
-      return getFPClassConstant(VTy, ValidResults, /*IsCanonicalizing=*/true);
+      return simplifyDemandedFPClassResult(CI, FMF, DemandedMask, Known,
+                                           {KnownSrc});
     }
     case Intrinsic::trunc:
     case Intrinsic::floor:
diff --git a/llvm/test/Transforms/InstCombine/simplify-demanded-fpclass-sqrt.ll b/llvm/test/Transforms/InstCombine/simplify-demanded-fpclass-sqrt.ll
index b09faf0f4c3af..6ec5daa48e125 100644
--- a/llvm/test/Transforms/InstCombine/simplify-demanded-fpclass-sqrt.ll
+++ b/llvm/test/Transforms/InstCombine/simplify-demanded-fpclass-sqrt.ll
@@ -220,7 +220,7 @@ define nofpclass(nan inf zero sub nnorm) float @pnorm_result_demands_pnorm_sourc
 ; CHECK-SAME: i1 [[COND:%.*]], float nofpclass(nan) [[NOT_NAN:%.*]]) {
 ; CHECK-NEXT:    [[ONLY_PNORM:%.*]] = call nofpclass(nan inf zero sub nnorm) float @func()
 ; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[COND]], float [[NOT_NAN]], float [[ONLY_PNORM]]
-; CHECK-NEXT:    [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[SELECT]])
+; CHECK-NEXT:    [[RESULT:%.*]] = call nnan ninf float @llvm.sqrt.f32(float [[SELECT]])
 ; CHECK-NEXT:    ret float [[RESULT]]
 ;
   %only.pnorm = call nofpclass(nan inf nnorm sub zero) float @func()
@@ -234,7 +234,7 @@ define nofpclass(nan inf zero sub nnorm) float @pnorm_result_demands_psub_source
 ; CHECK-SAME: i1 [[COND:%.*]], float nofpclass(nan) [[NOT_NAN:%.*]]) {
 ; CHECK-NEXT:    [[ONLY_PSUB:%.*]] = call nofpclass(nan inf zero nsub norm) float @func()
 ; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[COND]], float [[NOT_NAN]], float [[ONLY_PSUB]]
-; CHECK-NEXT:    [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[SELECT]])
+; CHECK-NEXT:    [[RESULT:%.*]] = call nnan ninf float @llvm.sqrt.f32(float [[SELECT]])
 ; CHECK-NEXT:    ret float [[RESULT]]
 ;
   %only.psub = call nofpclass(nan inf norm nsub zero) float @func()
@@ -258,7 +258,7 @@ define nofpclass(nan inf zero nsub norm) float @psub_result_implies_not_pnorm_so
 define nofpclass(nan) float @ret_no_nan__sqrt(float %x) {
 ; CHECK-LABEL: define nofpclass(nan) float @ret_no_nan__sqrt(
 ; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT:    [[RESULT:%.*]] = call contract float @llvm.sqrt.f32(float [[X]])
+; CHECK-NEXT:    [[RESULT:%.*]] = call nnan contract float @llvm.sqrt.f32(float [[X]])
 ; CHECK-NEXT:    ret float [[RESULT]]
 ;
   %result = call contract float @llvm.sqrt.f32(float %x)
@@ -278,7 +278,7 @@ define nofpclass(snan) float @ret_no_snan__sqrt__no_neg_inputs(float nofpclass(n
 define nofpclass(snan) float @ret_no_snan__sqrt__no_neg_or_nan_inputs(float nofpclass(nan ninf nnorm nsub) %x) {
 ; CHECK-LABEL: define nofpclass(snan) float @ret_no_snan__sqrt__no_neg_or_nan_inputs(
 ; CHECK-SAME: float nofpclass(nan ninf nsub nnorm) [[X:%.*]]) {
-; CHECK-NEXT:    [[RESULT:%.*]] = call contract float @llvm.sqrt.f32(float [[X]])
+; CHECK-NEXT:    [[RESULT:%.*]] = call nnan contract float @llvm.sqrt.f32(float [[X]])
 ; CHECK-NEXT:    ret float [[RESULT]]
 ;
   %result = call contract float @llvm.sqrt.f32(float %x)
@@ -289,12 +289,67 @@ define nofpclass(snan) float @ret_no_snan__sqrt__no_neg_or_nan_inputs(float nofp
 define nofpclass(snan) float @ret_no_snan__noundef_sqrt__no_neg_or_nan_inputs(float nofpclass(nan ninf nnorm nsub) %x) {
 ; CHECK-LABEL: define nofpclass(snan) float @ret_no_snan__noundef_sqrt__no_neg_or_nan_inputs(
 ; CHECK-SAME: float nofpclass(nan ninf nsub nnorm) [[X:%.*]]) {
-; CHECK-NEXT:    [[RESULT:%.*]] = call contract noundef float @llvm.sqrt.f32(float [[X]])
+; CHECK-NEXT:    [[RESULT:%.*]] = call nnan contract noundef float @llvm.sqrt.f32(float [[X]])
 ; CHECK-NEXT:    ret float [[RESULT]]
 ;
   %result = call contract noundef float @llvm.sqrt.f32(float %x)
   ret float %result
 }
 
+define nofpclass(snan) float @ret_no_snan__sqrt__no_pinf_inputs(float nofpclass(pinf) %x) {
+; CHECK-LABEL: define nofpclass(snan) float @ret_no_snan__sqrt__no_pinf_inputs(
+; CHECK-SAME: float nofpclass(pinf) [[X:%.*]]) {
+; CHECK-NEXT:    [[RESULT:%.*]] = call contract float @llvm.sqrt.f32(float [[X]])
+; CHECK-NEXT:    ret float [[RESULT]]
+;
+  %result = call contract float @llvm.sqrt.f32(float %x)
+  ret float %result
+}
+
+; Cannot infer flags. A nan output could still be produced by a -inf
+; input.
+define nofpclass(pinf) float @ret_no_pinf__sqrt(float %x) {
+; CHECK-LABEL: define nofpclass(pinf) float @ret_no_pinf__sqrt(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[RESULT:%.*]] = call contract float @llvm.sqrt.f32(float [[X]])
+; CHECK-NEXT:    ret float [[RESULT]]
+;
+  %result = call contract float @llvm.sqrt.f32(float %x)
+  ret float %result
+}
+
+; Infer nnan and ninf
+define nofpclass(nan pinf) float @ret_no_pinf_no_nan__sqrt(float %x) {
+; CHECK-LABEL: define nofpclass(nan pinf) float @ret_no_pinf_no_nan__sqrt(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT:    [[RESULT:%.*]] = call nnan ninf contract float @llvm.sqrt.f32(float [[X]])
+; CHECK-NEXT:    ret float [[RESULT]]
+;
+  %result = call contract float @llvm.sqrt.f32(float %x)
+  ret float %result
+}
+
+; Infer nnan and ninf
+define nofpclass(nan) float @ret_no_nan__sqrt__no_pinf_inputs(float nofpclass(pinf) %x) {
+; CHECK-LABEL: define nofpclass(nan) float @ret_no_nan__sqrt__no_pinf_inputs(
+; CHECK-SAME: float nofpclass(pinf) [[X:%.*]]) {
+; CHECK-NEXT:    [[RESULT:%.*]] = call nnan ninf contract float @llvm.sqrt.f32(float [[X]])
+; CHECK-NEXT:    ret float [[RESULT]]
+;
+  %result = call contract float @llvm.sqrt.f32(float %x)
+  ret float %result
+}
+
+; Infer nnan and ninf
+define nofpclass(nan) float @ret_no_nan__sqrt__no_inf_inputs(float nofpclass(inf) %x) {
+; CHECK-LABEL: define nofpclass(nan) float @ret_no_nan__sqrt__no_inf_inputs(
+; CHECK-SAME: float nofpclass(inf) [[X:%.*]]) {
+; CHECK-NEXT:    [[RESULT:%.*]] = call nnan ninf contract float @llvm.sqrt.f32(float [[X]])
+; CHECK-NEXT:    ret float [[RESULT]]
+;
+  %result = call contract float @llvm.sqrt.f32(float %x)
+  ret float %result
+}
+
 attributes #0 = { "denormal-fp-math"="preserve-sign,preserve-sign" }
 attributes #1 = { "denormal-fp-math"="dynamic,dynamic" }



More information about the llvm-branch-commits mailing list