[llvm-branch-commits] [llvm] InstCombine: Infer fast math flags for sqrt (PR #176003)
Matt Arsenault via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Jan 14 11:24:17 PST 2026
https://github.com/arsenm updated https://github.com/llvm/llvm-project/pull/176003
>From 4d6a3911109ece865521de43104e2d7102b41062 Mon Sep 17 00:00:00 2001
From: Matt Arsenault <Matthew.Arsenault at amd.com>
Date: Wed, 14 Jan 2026 18:40:41 +0100
Subject: [PATCH] InstCombine: Infer fast math flags for sqrt
---
.../InstCombineSimplifyDemanded.cpp | 40 ++++++++++--
.../simplify-demanded-fpclass-sqrt.ll | 65 +++++++++++++++++--
2 files changed, 96 insertions(+), 9 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index 87a03a558e5ac..b198e5e824f1f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -2036,9 +2036,9 @@ static Constant *getFPClassConstant(Type *Ty, FPClassTest Mask,
/// Try to set an inferred no-nans or no-infs in \p FMF. \p ValidResults is a
/// mask of known valid results for the operator (already computed from the
/// result, and the known operand inputs in \p Known)
-static FastMathFlags
-inferFastMathValueFlags(FastMathFlags FMF, FPClassTest ValidResults,
- ArrayRef<const KnownFPClass> Known) {
+static FastMathFlags inferFastMathValueFlags(FastMathFlags FMF,
+ FPClassTest ValidResults,
+ ArrayRef<KnownFPClass> Known) {
if (!FMF.noNaNs() && (ValidResults & fcNan) == fcNone) {
if (all_of(Known, [](const KnownFPClass KnownSrc) {
return KnownSrc.isKnownNeverNaN();
@@ -2056,6 +2056,29 @@ inferFastMathValueFlags(FastMathFlags FMF, FPClassTest ValidResults,
return FMF;
}
+/// Apply epilog fixups to a floating-point intrinsic. See if the result can
+/// fold to a constant, or apply fast math flags.
+static Value *simplifyDemandedFPClassResult(CallInst *FPOp, FastMathFlags FMF,
+ FPClassTest DemandedMask,
+ KnownFPClass &Known,
+ ArrayRef<KnownFPClass> KnownSrcs) {
+ FPClassTest ValidResults = DemandedMask & Known.KnownFPClasses;
+ Constant *SingleVal = getFPClassConstant(FPOp->getType(), ValidResults,
+ /*IsCanonicalizing=*/true);
+ if (SingleVal)
+ return SingleVal;
+
+ FastMathFlags InferredFMF =
+ inferFastMathValueFlags(FMF, ValidResults, KnownSrcs);
+ if (InferredFMF != FMF) {
+ FPOp->dropUBImplyingAttrsAndMetadata();
+ FPOp->setFastMathFlags(InferredFMF);
+ return FPOp;
+ }
+
+ return nullptr;
+}
+
static Value *
simplifyDemandedFPClassMinMax(KnownFPClass &Known, Intrinsic::ID IID,
const CallInst *CI, FPClassTest DemandedMask,
@@ -2771,6 +2794,14 @@ Value *InstCombinerImpl::SimplifyDemandedUseFPClass(Instruction *I,
if (SimplifyDemandedFPClass(I, 0, DemandedSrcMask, KnownSrc, Depth + 1))
return I;
+ // Infer the source cannot be negative if the result cannot be nan.
+ if ((DemandedMask & fcNan) == fcNone)
+ KnownSrc.knownNot((fcNegative & ~fcNegZero) | fcNan);
+
+ // Infer the source cannot be +inf if the result is not +inf
+ if ((DemandedMask & fcPosInf) == fcNone)
+ KnownSrc.knownNot(fcPosInf);
+
Type *EltTy = VTy->getScalarType();
DenormalMode Mode = F.getDenormalMode(EltTy->getFltSemantics());
@@ -2792,7 +2823,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseFPClass(Instruction *I,
return Copysign;
}
- return getFPClassConstant(VTy, ValidResults, /*IsCanonicalizing=*/true);
+ return simplifyDemandedFPClassResult(CI, FMF, DemandedMask, Known,
+ {KnownSrc});
}
case Intrinsic::trunc:
case Intrinsic::floor:
diff --git a/llvm/test/Transforms/InstCombine/simplify-demanded-fpclass-sqrt.ll b/llvm/test/Transforms/InstCombine/simplify-demanded-fpclass-sqrt.ll
index b09faf0f4c3af..9b5274b9e7459 100644
--- a/llvm/test/Transforms/InstCombine/simplify-demanded-fpclass-sqrt.ll
+++ b/llvm/test/Transforms/InstCombine/simplify-demanded-fpclass-sqrt.ll
@@ -220,7 +220,7 @@ define nofpclass(nan inf zero sub nnorm) float @pnorm_result_demands_pnorm_sourc
; CHECK-SAME: i1 [[COND:%.*]], float nofpclass(nan) [[NOT_NAN:%.*]]) {
; CHECK-NEXT: [[ONLY_PNORM:%.*]] = call nofpclass(nan inf zero sub nnorm) float @func()
; CHECK-NEXT: [[SELECT:%.*]] = select i1 [[COND]], float [[NOT_NAN]], float [[ONLY_PNORM]]
-; CHECK-NEXT: [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[SELECT]])
+; CHECK-NEXT: [[RESULT:%.*]] = call nnan ninf float @llvm.sqrt.f32(float [[SELECT]])
; CHECK-NEXT: ret float [[RESULT]]
;
%only.pnorm = call nofpclass(nan inf nnorm sub zero) float @func()
@@ -234,7 +234,7 @@ define nofpclass(nan inf zero sub nnorm) float @pnorm_result_demands_psub_source
; CHECK-SAME: i1 [[COND:%.*]], float nofpclass(nan) [[NOT_NAN:%.*]]) {
; CHECK-NEXT: [[ONLY_PSUB:%.*]] = call nofpclass(nan inf zero nsub norm) float @func()
; CHECK-NEXT: [[SELECT:%.*]] = select i1 [[COND]], float [[NOT_NAN]], float [[ONLY_PSUB]]
-; CHECK-NEXT: [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[SELECT]])
+; CHECK-NEXT: [[RESULT:%.*]] = call nnan ninf float @llvm.sqrt.f32(float [[SELECT]])
; CHECK-NEXT: ret float [[RESULT]]
;
%only.psub = call nofpclass(nan inf norm nsub zero) float @func()
@@ -258,7 +258,7 @@ define nofpclass(nan inf zero nsub norm) float @psub_result_implies_not_pnorm_so
define nofpclass(nan) float @ret_no_nan__sqrt(float %x) {
; CHECK-LABEL: define nofpclass(nan) float @ret_no_nan__sqrt(
; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT: [[RESULT:%.*]] = call contract float @llvm.sqrt.f32(float [[X]])
+; CHECK-NEXT: [[RESULT:%.*]] = call nnan contract float @llvm.sqrt.f32(float [[X]])
; CHECK-NEXT: ret float [[RESULT]]
;
%result = call contract float @llvm.sqrt.f32(float %x)
@@ -278,7 +278,7 @@ define nofpclass(snan) float @ret_no_snan__sqrt__no_neg_inputs(float nofpclass(n
define nofpclass(snan) float @ret_no_snan__sqrt__no_neg_or_nan_inputs(float nofpclass(nan ninf nnorm nsub) %x) {
; CHECK-LABEL: define nofpclass(snan) float @ret_no_snan__sqrt__no_neg_or_nan_inputs(
; CHECK-SAME: float nofpclass(nan ninf nsub nnorm) [[X:%.*]]) {
-; CHECK-NEXT: [[RESULT:%.*]] = call contract float @llvm.sqrt.f32(float [[X]])
+; CHECK-NEXT: [[RESULT:%.*]] = call nnan contract float @llvm.sqrt.f32(float [[X]])
; CHECK-NEXT: ret float [[RESULT]]
;
%result = call contract float @llvm.sqrt.f32(float %x)
@@ -289,12 +289,67 @@ define nofpclass(snan) float @ret_no_snan__sqrt__no_neg_or_nan_inputs(float nofp
define nofpclass(snan) float @ret_no_snan__noundef_sqrt__no_neg_or_nan_inputs(float nofpclass(nan ninf nnorm nsub) %x) {
; CHECK-LABEL: define nofpclass(snan) float @ret_no_snan__noundef_sqrt__no_neg_or_nan_inputs(
; CHECK-SAME: float nofpclass(nan ninf nsub nnorm) [[X:%.*]]) {
-; CHECK-NEXT: [[RESULT:%.*]] = call contract noundef float @llvm.sqrt.f32(float [[X]])
+; CHECK-NEXT: [[RESULT:%.*]] = call nnan contract float @llvm.sqrt.f32(float [[X]])
; CHECK-NEXT: ret float [[RESULT]]
;
%result = call contract noundef float @llvm.sqrt.f32(float %x)
ret float %result
}
+define nofpclass(snan) float @ret_no_snan__sqrt__no_pinf_inputs(float nofpclass(pinf) %x) {
+; CHECK-LABEL: define nofpclass(snan) float @ret_no_snan__sqrt__no_pinf_inputs(
+; CHECK-SAME: float nofpclass(pinf) [[X:%.*]]) {
+; CHECK-NEXT: [[RESULT:%.*]] = call contract float @llvm.sqrt.f32(float [[X]])
+; CHECK-NEXT: ret float [[RESULT]]
+;
+ %result = call contract float @llvm.sqrt.f32(float %x)
+ ret float %result
+}
+
+; Cannot infer flags. A nan output could still be produced by a -inf
+; input.
+define nofpclass(pinf) float @ret_no_pinf__sqrt(float %x) {
+; CHECK-LABEL: define nofpclass(pinf) float @ret_no_pinf__sqrt(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT: [[RESULT:%.*]] = call contract float @llvm.sqrt.f32(float [[X]])
+; CHECK-NEXT: ret float [[RESULT]]
+;
+ %result = call contract float @llvm.sqrt.f32(float %x)
+ ret float %result
+}
+
+; Infer nnan and ninf
+define nofpclass(nan pinf) float @ret_no_pinf_no_nan__sqrt(float %x) {
+; CHECK-LABEL: define nofpclass(nan pinf) float @ret_no_pinf_no_nan__sqrt(
+; CHECK-SAME: float [[X:%.*]]) {
+; CHECK-NEXT: [[RESULT:%.*]] = call nnan ninf contract float @llvm.sqrt.f32(float [[X]])
+; CHECK-NEXT: ret float [[RESULT]]
+;
+ %result = call contract float @llvm.sqrt.f32(float %x)
+ ret float %result
+}
+
+; Infer nnan and ninf
+define nofpclass(nan) float @ret_no_nan__sqrt__no_pinf_inputs(float nofpclass(pinf) %x) {
+; CHECK-LABEL: define nofpclass(nan) float @ret_no_nan__sqrt__no_pinf_inputs(
+; CHECK-SAME: float nofpclass(pinf) [[X:%.*]]) {
+; CHECK-NEXT: [[RESULT:%.*]] = call nnan ninf contract float @llvm.sqrt.f32(float [[X]])
+; CHECK-NEXT: ret float [[RESULT]]
+;
+ %result = call contract float @llvm.sqrt.f32(float %x)
+ ret float %result
+}
+
+; Infer nnan and ninf
+define nofpclass(nan) float @ret_no_nan__sqrt__no_inf_inputs(float nofpclass(inf) %x) {
+; CHECK-LABEL: define nofpclass(nan) float @ret_no_nan__sqrt__no_inf_inputs(
+; CHECK-SAME: float nofpclass(inf) [[X:%.*]]) {
+; CHECK-NEXT: [[RESULT:%.*]] = call nnan ninf contract float @llvm.sqrt.f32(float [[X]])
+; CHECK-NEXT: ret float [[RESULT]]
+;
+ %result = call contract float @llvm.sqrt.f32(float %x)
+ ret float %result
+}
+
attributes #0 = { "denormal-fp-math"="preserve-sign,preserve-sign" }
attributes #1 = { "denormal-fp-math"="dynamic,dynamic" }
More information about the llvm-branch-commits
mailing list