[llvm-branch-commits] [llvm] InstCombine: Implement SimplifyDemandedFPClass for sqrt (PR #173883)
Matt Arsenault via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Dec 30 10:46:35 PST 2025
https://github.com/arsenm updated https://github.com/llvm/llvm-project/pull/173883
From 16a8055724792ff059017b93d01f9de209c756f6 Mon Sep 17 00:00:00 2001
From: Matt Arsenault <Matthew.Arsenault at amd.com>
Date: Tue, 23 Dec 2025 20:36:25 +0100
Subject: [PATCH] InstCombine: Implement SimplifyDemandedFPClass for sqrt
---
llvm/include/llvm/Support/KnownFPClass.h | 4 +++
llvm/lib/Analysis/ValueTracking.cpp | 29 ++++++-----------
llvm/lib/Support/KnownFPClass.cpp | 24 ++++++++++++++
.../InstCombineSimplifyDemanded.cpp | 31 +++++++++++++++++++
.../simplify-demanded-fpclass-sqrt.ll | 18 +++++------
5 files changed, 76 insertions(+), 30 deletions(-)
diff --git a/llvm/include/llvm/Support/KnownFPClass.h b/llvm/include/llvm/Support/KnownFPClass.h
index 07d74f2867089..ae9513bbebe80 100644
--- a/llvm/include/llvm/Support/KnownFPClass.h
+++ b/llvm/include/llvm/Support/KnownFPClass.h
@@ -267,6 +267,10 @@ struct KnownFPClass {
static LLVM_ABI KnownFPClass
log(const KnownFPClass &Src, DenormalMode Mode = DenormalMode::getDynamic());
+ /// Propagate known class for sqrt
+ static LLVM_ABI KnownFPClass
+ sqrt(const KnownFPClass &Src, DenormalMode Mode = DenormalMode::getDynamic());
+
void resetAll() { *this = KnownFPClass(); }
};
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 4fbbfd1a0cf12..543a2992f9a46 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -5131,27 +5131,18 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
computeKnownFPClass(II->getArgOperand(0), DemandedElts, InterestedSrcs,
KnownSrc, Q, Depth + 1);
- if (KnownSrc.isKnownNeverPosInfinity())
- Known.knownNot(fcPosInf);
- if (KnownSrc.isKnownNever(fcSNan))
- Known.knownNot(fcSNan);
-
- // Any negative value besides -0 returns a nan.
- if (KnownSrc.isKnownNeverNaN() && KnownSrc.cannotBeOrderedLessThanZero())
- Known.knownNot(fcNan);
-
- // The only negative value that can be returned is -0 for -0 inputs.
- Known.knownNot(fcNegInf | fcNegSubnormal | fcNegNormal);
+ DenormalMode Mode = DenormalMode::getDynamic();
- // If the input denormal mode could be PreserveSign, a negative
- // subnormal input could produce a negative zero output.
- const Function *F = II->getFunction();
- const fltSemantics &FltSem =
- II->getType()->getScalarType()->getFltSemantics();
+ bool HasNSZ = Q.IIQ.hasNoSignedZeros(II);
+ if (!HasNSZ) {
+ const Function *F = II->getFunction();
+ const fltSemantics &FltSem =
+ II->getType()->getScalarType()->getFltSemantics();
+ Mode = F ? F->getDenormalMode(FltSem) : DenormalMode::getDynamic();
+ }
- if (Q.IIQ.hasNoSignedZeros(II) ||
- (F &&
- KnownSrc.isKnownNeverLogicalNegZero(F->getDenormalMode(FltSem))))
+ Known = KnownFPClass::sqrt(KnownSrc, Mode);
+ if (HasNSZ)
Known.knownNot(fcNegZero);
break;
diff --git a/llvm/lib/Support/KnownFPClass.cpp b/llvm/lib/Support/KnownFPClass.cpp
index ff98908fdb2c4..afa08c1fd047f 100644
--- a/llvm/lib/Support/KnownFPClass.cpp
+++ b/llvm/lib/Support/KnownFPClass.cpp
@@ -243,3 +243,27 @@ KnownFPClass KnownFPClass::log(const KnownFPClass &KnownSrc,
return Known;
}
+
+KnownFPClass KnownFPClass::sqrt(const KnownFPClass &KnownSrc,
+ DenormalMode Mode) {
+ KnownFPClass Known;
+
+ if (KnownSrc.isKnownNeverPosInfinity())
+ Known.knownNot(fcPosInf);
+ if (KnownSrc.isKnownNever(fcSNan))
+ Known.knownNot(fcSNan);
+
+ // Any negative value besides -0 returns a nan.
+ if (KnownSrc.isKnownNeverNaN() && KnownSrc.cannotBeOrderedLessThanZero())
+ Known.knownNot(fcNan);
+
+ // The only negative value that can be returned is -0 for -0 inputs.
+ Known.knownNot(fcNegInf | fcNegSubnormal | fcNegNormal);
+
+ // If the input denormal mode could be PreserveSign, a negative
+ // subnormal input could produce a negative zero output.
+ if (KnownSrc.isKnownNeverLogicalNegZero(Mode))
+ Known.knownNot(fcNegZero);
+
+ return Known;
+}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index 39e8cd5a5c6c5..77fb794d160ba 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -2239,6 +2239,37 @@ Value *InstCombinerImpl::SimplifyDemandedUseFPClass(Value *V,
FPClassTest ValidResults = DemandedMask & Known.KnownFPClasses;
return getFPClassConstant(VTy, ValidResults, /*IsCanonicalizing=*/true);
}
+ case Intrinsic::sqrt: {
+ FPClassTest DemandedSrcMask =
+ DemandedMask & (fcNegZero | fcPositive | fcNan);
+
+ if (DemandedMask & fcNan)
+ DemandedSrcMask |= (fcNegative & ~fcNegZero);
+
+ KnownFPClass KnownSrc;
+ if (SimplifyDemandedFPClass(I, 0, DemandedSrcMask, KnownSrc, Depth + 1))
+ return I;
+
+ Type *EltTy = VTy->getScalarType();
+ DenormalMode Mode = F.getDenormalMode(EltTy->getFltSemantics());
+
+ // sqrt(-x) = nan, but be careful of negative subnormals flushed to 0.
+ if (KnownSrc.isKnownNever(fcPositive) &&
+ KnownSrc.isKnownNeverLogicalZero(Mode))
+ return ConstantFP::getQNaN(VTy);
+
+ Known = KnownFPClass::sqrt(KnownSrc, Mode);
+ FPClassTest ValidResults = DemandedMask & Known.KnownFPClasses;
+
+ if (ValidResults == fcZero) {
+ Value *Copysign = Builder.CreateCopySign(ConstantFP::getZero(VTy),
+ CI->getArgOperand(0));
+ Copysign->takeName(CI);
+ return Copysign;
+ }
+
+ return getFPClassConstant(VTy, ValidResults, /*IsCanonicalizing=*/true);
+ }
case Intrinsic::canonicalize: {
Type *EltTy = VTy->getScalarType();
diff --git a/llvm/test/Transforms/InstCombine/simplify-demanded-fpclass-sqrt.ll b/llvm/test/Transforms/InstCombine/simplify-demanded-fpclass-sqrt.ll
index 9288bb7be3ecd..ad9881d70b5fd 100644
--- a/llvm/test/Transforms/InstCombine/simplify-demanded-fpclass-sqrt.ll
+++ b/llvm/test/Transforms/InstCombine/simplify-demanded-fpclass-sqrt.ll
@@ -18,8 +18,7 @@ declare nofpclass(pinf pnorm psub zero) float @returns_negative_nonzero_or_nan()
define nofpclass(inf norm sub zero) float @ret_only_nan_sqrt(float %x) {
; CHECK-LABEL: define nofpclass(inf zero sub norm) float @ret_only_nan_sqrt(
; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT: [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[X]])
-; CHECK-NEXT: ret float [[RESULT]]
+; CHECK-NEXT: ret float 0x7FF8000000000000
;
%result = call float @llvm.sqrt.f32(float %x)
ret float %result
@@ -30,7 +29,7 @@ define nofpclass(inf norm sub zero) float @ret_only_nan_sqrt(float %x) {
define nofpclass(inf nan norm sub) float @ret_only_zero_sqrt(float %x) {
; CHECK-LABEL: define nofpclass(nan inf sub norm) float @ret_only_zero_sqrt(
; CHECK-SAME: float [[X:%.*]]) {
-; CHECK-NEXT: [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[X]])
+; CHECK-NEXT: [[RESULT:%.*]] = call float @llvm.copysign.f32(float 0.000000e+00, float [[X]])
; CHECK-NEXT: ret float [[RESULT]]
;
%result = call float @llvm.sqrt.f32(float %x)
@@ -40,7 +39,7 @@ define nofpclass(inf nan norm sub) float @ret_only_zero_sqrt(float %x) {
define nofpclass(inf nan norm sub) <2 x float> @ret_only_zero_sqrt_vec(<2 x float> %x) {
; CHECK-LABEL: define nofpclass(nan inf sub norm) <2 x float> @ret_only_zero_sqrt_vec(
; CHECK-SAME: <2 x float> [[X:%.*]]) {
-; CHECK-NEXT: [[RESULT:%.*]] = call <2 x float> @llvm.sqrt.v2f32(<2 x float> [[X]])
+; CHECK-NEXT: [[RESULT:%.*]] = call <2 x float> @llvm.copysign.v2f32(<2 x float> zeroinitializer, <2 x float> [[X]])
; CHECK-NEXT: ret <2 x float> [[RESULT]]
;
%result = call <2 x float> @llvm.sqrt.v2f32(<2 x float> %x)
@@ -83,8 +82,7 @@ define nofpclass(ninf) float @ret_only_ninf__sqrt__known_negative() {
define nofpclass(ninf) float @ret_only_ninf__sqrt__known_negative_nonzero() {
; CHECK-LABEL: define nofpclass(ninf) float @ret_only_ninf__sqrt__known_negative_nonzero() {
; CHECK-NEXT: [[KNOWN_NEGATIVE_NONZERO:%.*]] = call float @returns_negative_nonzero()
-; CHECK-NEXT: [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[KNOWN_NEGATIVE_NONZERO]])
-; CHECK-NEXT: ret float [[RESULT]]
+; CHECK-NEXT: ret float 0x7FF8000000000000
;
%known.negative.nonzero = call float @returns_negative_nonzero()
%result = call float @llvm.sqrt.f32(float %known.negative.nonzero)
@@ -95,8 +93,7 @@ define nofpclass(ninf) float @ret_only_ninf__sqrt__known_negative_nonzero() {
define nofpclass(ninf) <2 x float> @ret_only_ninf__sqrt__known_negative_nonzero_vec() {
; CHECK-LABEL: define nofpclass(ninf) <2 x float> @ret_only_ninf__sqrt__known_negative_nonzero_vec() {
; CHECK-NEXT: [[KNOWN_NEGATIVE_NONZERO:%.*]] = call <2 x float> @returns_negative_nonzero_vec()
-; CHECK-NEXT: [[RESULT:%.*]] = call <2 x float> @llvm.sqrt.v2f32(<2 x float> [[KNOWN_NEGATIVE_NONZERO]])
-; CHECK-NEXT: ret <2 x float> [[RESULT]]
+; CHECK-NEXT: ret <2 x float> splat (float 0x7FF8000000000000)
;
%known.negative.nonzero = call <2 x float> @returns_negative_nonzero_vec()
%result = call <2 x float> @llvm.sqrt.v2f32(<2 x float> %known.negative.nonzero)
@@ -149,7 +146,7 @@ define nofpclass(inf norm zero) float @ret_only_nan_or_sub__sqrt__select_unknown
; CHECK-LABEL: define nofpclass(inf zero norm) float @ret_only_nan_or_sub__sqrt__select_unknown_or_maybe_ninf(
; CHECK-SAME: i1 [[COND:%.*]], float nofpclass(nan) [[X:%.*]]) {
; CHECK-NEXT: [[MAYBE_NINF:%.*]] = call nofpclass(nan pinf sub norm) float @func()
-; CHECK-NEXT: [[SELECT:%.*]] = select i1 [[COND]], float [[X]], float [[MAYBE_NINF]]
+; CHECK-NEXT: [[SELECT:%.*]] = select i1 [[COND]], float [[X]], float 0xFFF0000000000000
; CHECK-NEXT: [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[SELECT]])
; CHECK-NEXT: ret float [[RESULT]]
;
@@ -193,8 +190,7 @@ define nofpclass(inf norm zero) float @ret_only_nan_or_sub__sqrt__select_unknown
define nofpclass(pinf) float @no_pinf_result_implies_no_pinf_source(i1 %cond, float %unknown) {
; CHECK-LABEL: define nofpclass(pinf) float @no_pinf_result_implies_no_pinf_source(
; CHECK-SAME: i1 [[COND:%.*]], float [[UNKNOWN:%.*]]) {
-; CHECK-NEXT: [[SELECT:%.*]] = select i1 [[COND]], float [[UNKNOWN]], float 0x7FF0000000000000
-; CHECK-NEXT: [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[SELECT]])
+; CHECK-NEXT: [[RESULT:%.*]] = call float @llvm.sqrt.f32(float [[UNKNOWN]])
; CHECK-NEXT: ret float [[RESULT]]
;
%select = select i1 %cond, float %unknown, float 0x7ff0000000000000
More information about the llvm-branch-commits mailing list