[llvm] [ConstantFPRange] Implement `ConstantFPRange::makeAllowedFCmpRegion` (PR #110082)
Yingwei Zheng via llvm-commits
llvm-commits at lists.llvm.org
Wed Sep 25 22:28:07 PDT 2024
https://github.com/dtcxzyw created https://github.com/llvm/llvm-project/pull/110082
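Note: the return type of `makeExactFCmpRegion` is changed to `std::optional<ConstantFPRange>` because the result of `makeExactFCmpRegion(one, X)` cannot, in general, be represented as a ConstantFPRange: for a finite X it is every non-NaN value that does not compare equal to X, which leaves a hole inside [-inf, +inf].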
>From 203d316d77d216ed931537d0b77514c8e60a317b Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Thu, 26 Sep 2024 12:48:51 +0800
Subject: [PATCH] [ConstantFPRange] Implement
`ConstantFPRange::makeAllowedFCmpRegion`
---
llvm/include/llvm/IR/ConstantFPRange.h | 17 +++-
llvm/lib/IR/ConstantFPRange.cpp | 119 ++++++++++++++++++++--
llvm/unittests/IR/ConstantFPRangeTest.cpp | 45 ++++++++
3 files changed, 169 insertions(+), 12 deletions(-)
diff --git a/llvm/include/llvm/IR/ConstantFPRange.h b/llvm/include/llvm/IR/ConstantFPRange.h
index 67f9f945d748ba..cab3a860eaf4ef 100644
--- a/llvm/include/llvm/IR/ConstantFPRange.h
+++ b/llvm/include/llvm/IR/ConstantFPRange.h
@@ -50,7 +50,6 @@ class [[nodiscard]] ConstantFPRange {
void makeEmpty();
void makeFull();
- bool isNaNOnly() const;
/// Initialize a full or empty set for the specified semantics.
explicit ConstantFPRange(const fltSemantics &Sem, bool IsFullSet);
@@ -78,6 +77,9 @@ class [[nodiscard]] ConstantFPRange {
/// Helper for (-inf, inf) to represent all finite values.
static ConstantFPRange getFinite(const fltSemantics &Sem);
+ /// Helper for [-inf, inf] to represent all non-NaN values.
+ static ConstantFPRange getNonNaN(const fltSemantics &Sem);
+
/// Create a range which doesn't contain NaNs.
static ConstantFPRange getNonNaN(APFloat LowerVal, APFloat UpperVal) {
return ConstantFPRange(std::move(LowerVal), std::move(UpperVal),
@@ -123,8 +125,10 @@ class [[nodiscard]] ConstantFPRange {
/// { x : fcmp op x y is true}'.
///
/// Example: Pred = olt and Other = float 3 returns [-inf, 3)
- static ConstantFPRange makeExactFCmpRegion(FCmpInst::Predicate Pred,
- const APFloat &Other);
+ /// If the exact answer is not representable as a ConstantFPRange, return
+ /// std::nullopt.
+ static std::optional<ConstantFPRange>
+ makeExactFCmpRegion(FCmpInst::Predicate Pred, const APFloat &Other);
/// Does the predicate \p Pred hold between ranges this and \p Other?
/// NOTE: false does not mean that inverse predicate holds!
@@ -139,6 +143,7 @@ class [[nodiscard]] ConstantFPRange {
bool containsNaN() const { return MayBeQNaN || MayBeSNaN; }
bool containsQNaN() const { return MayBeQNaN; }
bool containsSNaN() const { return MayBeSNaN; }
+ bool isNaNOnly() const;
/// Get the semantics of this ConstantFPRange.
const fltSemantics &getSemantics() const { return Lower.getSemantics(); }
@@ -157,10 +162,12 @@ class [[nodiscard]] ConstantFPRange {
bool contains(const ConstantFPRange &CR) const;
/// If this set contains a single element, return it, otherwise return null.
- const APFloat *getSingleElement() const;
+ /// When \p ExcludesNaN is true, possible NaNs are ignored and only the
+ /// non-NaN part of the set has to consist of a single value; e.g. {+0, NaN}
+ /// then yields +0. When false, any possible NaN makes the result null.
+ const APFloat *getSingleElement(bool ExcludesNaN = false) const;
/// Return true if this set contains exactly one member.
- bool isSingleElement() const { return getSingleElement() != nullptr; }
+ /// \p ExcludesNaN is interpreted exactly as for getSingleElement().
+ bool isSingleElement(bool ExcludesNaN = false) const {
+ return getSingleElement(ExcludesNaN) != nullptr;
+ }
/// Return true if the sign bit of all values in this range is 1.
/// Return false if the sign bit of all values in this range is 0.
diff --git a/llvm/lib/IR/ConstantFPRange.cpp b/llvm/lib/IR/ConstantFPRange.cpp
index 957701891c8f37..9f9e4f69a4079d 100644
--- a/llvm/lib/IR/ConstantFPRange.cpp
+++ b/llvm/lib/IR/ConstantFPRange.cpp
@@ -103,11 +103,115 @@ ConstantFPRange ConstantFPRange::getNaNOnly(const fltSemantics &Sem,
MayBeSNaN);
}
+/// Build the range [-inf, +inf] of \p Sem with both NaN flags cleared, i.e.
+/// every value that is not a NaN.
+ConstantFPRange ConstantFPRange::getNonNaN(const fltSemantics &Sem) {
+ APFloat NegInf = APFloat::getInf(Sem, /*Negative=*/true);
+ APFloat PosInf = APFloat::getInf(Sem, /*Negative=*/false);
+ return ConstantFPRange(std::move(NegInf), std::move(PosInf),
+ /*MayBeQNaN=*/false, /*MayBeSNaN=*/false);
+}
+
+/// Return the non-NaN values below \p V: [-inf, V] when \p Pred has its
+/// FCMP_OEQ bit set (equality allowed), [-inf, V) otherwise. The half-open
+/// form is encoded as the closed range [-inf, nextDown(V)].
+static ConstantFPRange makeLessThan(APFloat V, FCmpInst::Predicate Pred) {
+ const fltSemantics &Sem = V.getSemantics();
+ bool AllowsEqual = (Pred & FCmpInst::FCMP_OEQ) != 0;
+ if (!AllowsEqual && V.isNegInfinity())
+ return ConstantFPRange::getEmpty(Sem); // Nothing is strictly below -inf.
+ if (!AllowsEqual)
+ V.next(/*nextDown=*/true); // Largest value strictly less than V.
+ APFloat Low = APFloat::getInf(Sem, /*Negative=*/true);
+ return ConstantFPRange::getNonNaN(std::move(Low), std::move(V));
+}
+
+/// Return the non-NaN values above \p V: [V, +inf] when \p Pred has its
+/// FCMP_OEQ bit set (equality allowed), (V, +inf] otherwise. The half-open
+/// form is encoded as the closed range [nextUp(V), +inf].
+static ConstantFPRange makeGreaterThan(APFloat V, FCmpInst::Predicate Pred) {
+ const fltSemantics &Sem = V.getSemantics();
+ bool AllowsEqual = (Pred & FCmpInst::FCMP_OEQ) != 0;
+ if (!AllowsEqual && V.isPosInfinity())
+ return ConstantFPRange::getEmpty(Sem); // Nothing is strictly above +inf.
+ if (!AllowsEqual)
+ V.next(/*nextDown=*/false); // Smallest value strictly greater than V.
+ APFloat High = APFloat::getInf(Sem, /*Negative=*/false);
+ return ConstantFPRange::getNonNaN(std::move(V), std::move(High));
+}
+
+/// If \p Pred allows equality (its FCMP_OEQ bit is set), widen a +0 lower
+/// bound to -0 and a -0 upper bound to +0, since the two zeros compare equal.
+/// The NaN flags of \p CR are preserved. Otherwise \p CR is returned as is.
+static ConstantFPRange extendZeroIfEqual(const ConstantFPRange &CR,
+ FCmpInst::Predicate Pred) {
+ if ((Pred & FCmpInst::FCMP_OEQ) == 0)
+ return CR;
+
+ const fltSemantics &Sem = CR.getSemantics();
+ APFloat Lo = CR.getLower().isPosZero()
+ ? APFloat::getZero(Sem, /*Negative=*/true)
+ : CR.getLower();
+ APFloat Hi = CR.getUpper().isNegZero()
+ ? APFloat::getZero(Sem, /*Negative=*/false)
+ : CR.getUpper();
+ return ConstantFPRange(std::move(Lo), std::move(Hi), CR.containsQNaN(),
+ CR.containsSNaN());
+}
+
+/// Copy the bounds of \p CR and overwrite both NaN flags: NaNs are included
+/// exactly when FCmpInst::isUnordered(\p Pred) holds.
+static ConstantFPRange setNaNField(const ConstantFPRange &CR,
+ FCmpInst::Predicate Pred) {
+ const bool IncludeNaN = FCmpInst::isUnordered(Pred);
+ return ConstantFPRange(CR.getLower(), CR.getUpper(),
+ /*MayBeQNaN=*/IncludeNaN, /*MayBeSNaN=*/IncludeNaN);
+}
+
ConstantFPRange
ConstantFPRange::makeAllowedFCmpRegion(FCmpInst::Predicate Pred,
const ConstantFPRange &Other) {
- // TODO
- return getFull(Other.getSemantics());
+ // Compute a range covering every x for which "fcmp Pred x, y" can be true
+ // for at least one y in Other.
+ const fltSemantics &Sem = Other.getSemantics();
+ // No candidate y at all: the comparison can never be true.
+ if (Other.isEmptySet())
+ return Other;
+ // An unordered predicate is true for every x once y may be a NaN.
+ if (FCmpInst::isUnordered(Pred) && Other.containsNaN())
+ return getFull(Sem);
+ // An ordered predicate is always false when y can only be a NaN.
+ if (FCmpInst::isOrdered(Pred) && Other.isNaNOnly())
+ return getEmpty(Sem);
+
+ switch (Pred) {
+ case FCmpInst::FCMP_TRUE:
+ return getFull(Sem);
+ case FCmpInst::FCMP_FALSE:
+ return getEmpty(Sem);
+ case FCmpInst::FCMP_ORD:
+ return getNonNaN(Sem);
+ case FCmpInst::FCMP_UNO:
+ return getNaNOnly(Sem, /*MayBeQNaN=*/true, /*MayBeSNaN=*/true);
+ case FCmpInst::FCMP_OEQ:
+ case FCmpInst::FCMP_UEQ:
+ return setNaNField(extendZeroIfEqual(Other, Pred), Pred);
+ case FCmpInst::FCMP_ONE:
+ case FCmpInst::FCMP_UNE: {
+ // Removing one value from [-inf, +inf] is only representable when that
+ // value is an infinity: it trims that end to the largest finite value.
+ const APFloat *Elem = Other.getSingleElement(/*ExcludesNaN=*/true);
+ if (Elem && Elem->isInfinity()) {
+ ConstantFPRange Rest =
+ Elem->isNegative()
+ ? getNonNaN(APFloat::getLargest(Sem, /*Negative=*/true),
+ APFloat::getInf(Sem, /*Negative=*/false))
+ : getNonNaN(APFloat::getInf(Sem, /*Negative=*/true),
+ APFloat::getLargest(Sem, /*Negative=*/false));
+ return setNaNField(Rest, Pred);
+ }
+ // Any other exclusion leaves a hole; over-approximate with the hull.
+ return Pred == FCmpInst::FCMP_ONE ? getNonNaN(Sem) : getFull(Sem);
+ }
+ case FCmpInst::FCMP_OLT:
+ case FCmpInst::FCMP_OLE:
+ case FCmpInst::FCMP_ULT:
+ case FCmpInst::FCMP_ULE: {
+ // Only the largest candidate y constrains "x < y" / "x <= y".
+ ConstantFPRange Below = makeLessThan(Other.getUpper(), Pred);
+ return setNaNField(extendZeroIfEqual(Below, Pred), Pred);
+ }
+ case FCmpInst::FCMP_OGT:
+ case FCmpInst::FCMP_OGE:
+ case FCmpInst::FCMP_UGT:
+ case FCmpInst::FCMP_UGE: {
+ // Only the smallest candidate y constrains "x > y" / "x >= y".
+ ConstantFPRange Above = makeGreaterThan(Other.getLower(), Pred);
+ return setNaNField(extendZeroIfEqual(Above, Pred), Pred);
+ }
+ default:
+ llvm_unreachable("Unexpected predicate");
+ }
}
ConstantFPRange
@@ -117,9 +221,10 @@ ConstantFPRange::makeSatisfyingFCmpRegion(FCmpInst::Predicate Pred,
return getEmpty(Other.getSemantics());
}
-ConstantFPRange ConstantFPRange::makeExactFCmpRegion(FCmpInst::Predicate Pred,
- const APFloat &Other) {
- return makeAllowedFCmpRegion(Pred, ConstantFPRange(Other));
+std::optional<ConstantFPRange>
+ConstantFPRange::makeExactFCmpRegion(FCmpInst::Predicate Pred,
+ const APFloat &Other) {
+ // For a finite, non-NaN Other, "x one Other" / "x une Other" holds for all
+ // values except those comparing equal to Other. That leaves a hole inside
+ // [-inf, +inf], which a ConstantFPRange cannot express.
+ if ((Pred == FCmpInst::FCMP_ONE || Pred == FCmpInst::FCMP_UNE) &&
+ !Other.isNaN() && !Other.isInfinity())
+ return std::nullopt;
+ // In every remaining case the allowed region of the singleton {Other} is
+ // exact. Excluding an infinity only trims one end of the range. A NaN
+ // operand makes ordered predicates empty and unordered ones full.
+ return makeAllowedFCmpRegion(Pred, ConstantFPRange(Other));
}
bool ConstantFPRange::fcmp(FCmpInst::Predicate Pred,
@@ -161,8 +266,8 @@ bool ConstantFPRange::contains(const ConstantFPRange &CR) const {
strictCompare(CR.Upper, Upper) != APFloat::cmpGreaterThan;
}
-const APFloat *ConstantFPRange::getSingleElement() const {
- if (MayBeSNaN || MayBeQNaN)
+const APFloat *ConstantFPRange::getSingleElement(bool ExcludesNaN) const {
+ if (!ExcludesNaN && (MayBeSNaN || MayBeQNaN)) // NaN flags are ignored when ExcludesNaN is set.
return nullptr;
return Lower.bitwiseIsEqual(Upper) ? &Lower : nullptr;
}
diff --git a/llvm/unittests/IR/ConstantFPRangeTest.cpp b/llvm/unittests/IR/ConstantFPRangeTest.cpp
index 722e6566730da5..1fe9231392d622 100644
--- a/llvm/unittests/IR/ConstantFPRangeTest.cpp
+++ b/llvm/unittests/IR/ConstantFPRangeTest.cpp
@@ -161,6 +161,19 @@ static void EnumerateValuesInConstantFPRange(const ConstantFPRange &CR,
}
}
+/// Return true if \p TestFn returns true for at least one value contained in
+/// \p CR. Values are found by brute force over every bit pattern of CR's
+/// semantics, so this is only usable for small (< 32-bit) formats.
+template <typename Fn>
+static bool AnyOfValueInConstantFPRange(const ConstantFPRange &CR, Fn TestFn) {
+ const fltSemantics &Sem = CR.getSemantics();
+ unsigned Bits = APFloat::semanticsSizeInBits(Sem);
+ assert(Bits < 32 && "Too many bits");
+ // Visit all 2^Bits encodings, including the all-ones pattern; Bits < 32
+ // keeps 1U << Bits representable in an unsigned.
+ for (unsigned I = 0, E = 1U << Bits; I != E; ++I) {
+ APFloat V(Sem, APInt(Bits, I));
+ if (CR.contains(V) && TestFn(V))
+ return true;
+ }
+ return false;
+}
+
TEST_F(ConstantFPRangeTest, Basics) {
EXPECT_TRUE(Full.isFullSet());
EXPECT_FALSE(Full.isEmptySet());
@@ -263,12 +276,16 @@ TEST_F(ConstantFPRangeTest, SingleElement) {
EXPECT_EQ(*One.getSingleElement(), APFloat(1.0));
EXPECT_EQ(*PosZero.getSingleElement(), APFloat::getZero(Sem));
EXPECT_EQ(*PosInf.getSingleElement(), APFloat::getInf(Sem));
+ ConstantFPRange PosZeroOrNaN = PosZero.unionWith(NaN);
+ EXPECT_EQ(*PosZeroOrNaN.getSingleElement(/*ExcludesNaN=*/true),
+ APFloat::getZero(Sem));
EXPECT_FALSE(Full.isSingleElement());
EXPECT_FALSE(Empty.isSingleElement());
EXPECT_TRUE(One.isSingleElement());
EXPECT_FALSE(Some.isSingleElement());
EXPECT_FALSE(Zero.isSingleElement());
+ EXPECT_TRUE(PosZeroOrNaN.isSingleElement(/*ExcludesNaN=*/true));
}
TEST_F(ConstantFPRangeTest, ExhaustivelyEnumerate) {
@@ -425,4 +442,32 @@ TEST_F(ConstantFPRangeTest, MismatchedSemantics) {
#endif
#endif
+TEST_F(ConstantFPRangeTest, makeAllowedFCmpRegion) { // Check against a brute-force oracle for every predicate.
+ for (auto Pred : FCmpInst::predicates()) {
+ EnumerateConstantFPRanges(
+ [Pred](const ConstantFPRange &CR) {
+ ConstantFPRange Res =
+ ConstantFPRange::makeAllowedFCmpRegion(Pred, CR);
+ ConstantFPRange Optimal = // Union (via unionWith) of every x with "x Pred y" for some y in CR.
+ ConstantFPRange::getEmpty(CR.getSemantics());
+ EnumerateValuesInConstantFPRange(
+ ConstantFPRange::getFull(CR.getSemantics()),
+ [&](const APFloat &V) {
+ if (AnyOfValueInConstantFPRange(CR, [&](const APFloat &U) {
+ return FCmpInst::compare(V, U, Pred);
+ }))
+ Optimal = Optimal.unionWith(ConstantFPRange(V));
+ });
+
+ ASSERT_TRUE(Res.contains(Optimal)) // Soundness: no allowed value may be missing.
+ << "Wrong result for makeAllowedFCmpRegion(" << Pred << ", " << CR
+ << "). Expected " << Optimal << ", but got " << Res;
+ EXPECT_EQ(Res, Optimal) // Precision: the result should be no wider than the oracle.
+ << "Suboptimal result for makeAllowedFCmpRegion(" << Pred << ", "
+ << CR << ")";
+ },
+ /*Exhaustive=*/false);
+ }
+}
+
} // anonymous namespace
More information about the llvm-commits
mailing list