[llvm] [ConstantFPRange] Implement `ConstantFPRange::makeAllowedFCmpRegion` (PR #110082)

via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 25 22:28:41 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-ir

Author: Yingwei Zheng (dtcxzyw)

<details>
<summary>Changes</summary>

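Note: the return type of `makeExactFCmpRegion` is changed to `std::optional<ConstantFPRange>` because I realized that we cannot always represent the result of `makeExactFCmpRegion(one, X)` as a ConstantFPRange: for a finite X the exact region is the two disjoint intervals [-inf, X) and (X, +inf], which a single range cannot express.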


---
Full diff: https://github.com/llvm/llvm-project/pull/110082.diff


3 Files Affected:

- (modified) llvm/include/llvm/IR/ConstantFPRange.h (+12-5) 
- (modified) llvm/lib/IR/ConstantFPRange.cpp (+112-7) 
- (modified) llvm/unittests/IR/ConstantFPRangeTest.cpp (+45) 


``````````diff
diff --git a/llvm/include/llvm/IR/ConstantFPRange.h b/llvm/include/llvm/IR/ConstantFPRange.h
index 67f9f945d748ba..cab3a860eaf4ef 100644
--- a/llvm/include/llvm/IR/ConstantFPRange.h
+++ b/llvm/include/llvm/IR/ConstantFPRange.h
@@ -50,7 +50,6 @@ class [[nodiscard]] ConstantFPRange {
 
   void makeEmpty();
   void makeFull();
-  bool isNaNOnly() const;
 
   /// Initialize a full or empty set for the specified semantics.
   explicit ConstantFPRange(const fltSemantics &Sem, bool IsFullSet);
@@ -78,6 +77,9 @@ class [[nodiscard]] ConstantFPRange {
   /// Helper for (-inf, inf) to represent all finite values.
   static ConstantFPRange getFinite(const fltSemantics &Sem);
 
+  /// Helper for [-inf, inf] to represent all non-NaN values.
+  static ConstantFPRange getNonNaN(const fltSemantics &Sem);
+
   /// Create a range which doesn't contain NaNs.
   static ConstantFPRange getNonNaN(APFloat LowerVal, APFloat UpperVal) {
     return ConstantFPRange(std::move(LowerVal), std::move(UpperVal),
@@ -123,8 +125,10 @@ class [[nodiscard]] ConstantFPRange {
   /// { x : fcmp op x y is true}'.
   ///
   /// Example: Pred = olt and Other = float 3 returns [-inf, 3)
-  static ConstantFPRange makeExactFCmpRegion(FCmpInst::Predicate Pred,
-                                             const APFloat &Other);
+  /// If the exact answer is not representable as a ConstantFPRange, return
+  /// std::nullopt.
+  static std::optional<ConstantFPRange>
+  makeExactFCmpRegion(FCmpInst::Predicate Pred, const APFloat &Other);
 
   /// Does the predicate \p Pred hold between ranges this and \p Other?
   /// NOTE: false does not mean that inverse predicate holds!
@@ -139,6 +143,7 @@ class [[nodiscard]] ConstantFPRange {
   bool containsNaN() const { return MayBeQNaN || MayBeSNaN; }
   bool containsQNaN() const { return MayBeQNaN; }
   bool containsSNaN() const { return MayBeSNaN; }
+  bool isNaNOnly() const;
 
   /// Get the semantics of this ConstantFPRange.
   const fltSemantics &getSemantics() const { return Lower.getSemantics(); }
@@ -157,10 +162,12 @@ class [[nodiscard]] ConstantFPRange {
   bool contains(const ConstantFPRange &CR) const;
 
   /// If this set contains a single element, return it, otherwise return null.
-  const APFloat *getSingleElement() const;
+  const APFloat *getSingleElement(bool ExcludesNaN = false) const;
 
   /// Return true if this set contains exactly one member.
-  bool isSingleElement() const { return getSingleElement() != nullptr; }
+  bool isSingleElement(bool ExcludesNaN = false) const {
+    return getSingleElement(ExcludesNaN) != nullptr;
+  }
 
   /// Return true if the sign bit of all values in this range is 1.
   /// Return false if the sign bit of all values in this range is 0.
diff --git a/llvm/lib/IR/ConstantFPRange.cpp b/llvm/lib/IR/ConstantFPRange.cpp
index 957701891c8f37..9f9e4f69a4079d 100644
--- a/llvm/lib/IR/ConstantFPRange.cpp
+++ b/llvm/lib/IR/ConstantFPRange.cpp
@@ -103,11 +103,115 @@ ConstantFPRange ConstantFPRange::getNaNOnly(const fltSemantics &Sem,
                          MayBeSNaN);
 }
 
+ConstantFPRange ConstantFPRange::getNonNaN(const fltSemantics &Sem) {
+  return ConstantFPRange(APFloat::getInf(Sem, /*Negative=*/true),
+                         APFloat::getInf(Sem, /*Negative=*/false),
+                         /*MayBeQNaN=*/false, /*MayBeSNaN=*/false);
+}
+
+/// Return [-inf, V) or [-inf, V]
+static ConstantFPRange makeLessThan(APFloat V, FCmpInst::Predicate Pred) {
+  const fltSemantics &Sem = V.getSemantics();
+  if (!(Pred & FCmpInst::FCMP_OEQ)) {
+    if (V.isNegInfinity())
+      return ConstantFPRange::getEmpty(Sem);
+    V.next(/*nextDown=*/true);
+  }
+  return ConstantFPRange::getNonNaN(APFloat::getInf(Sem, /*Negative=*/true),
+                                    std::move(V));
+}
+
+/// Return (V, +inf] or [V, +inf]
+static ConstantFPRange makeGreaterThan(APFloat V, FCmpInst::Predicate Pred) {
+  const fltSemantics &Sem = V.getSemantics();
+  if (!(Pred & FCmpInst::FCMP_OEQ)) {
+    if (V.isPosInfinity())
+      return ConstantFPRange::getEmpty(Sem);
+    V.next(/*nextDown=*/false);
+  }
+  return ConstantFPRange::getNonNaN(std::move(V),
+                                    APFloat::getInf(Sem, /*Negative=*/false));
+}
+
+/// Make sure that +0/-0 are both included in the range.
+static ConstantFPRange extendZeroIfEqual(const ConstantFPRange &CR,
+                                         FCmpInst::Predicate Pred) {
+  if (!(Pred & FCmpInst::FCMP_OEQ))
+    return CR;
+
+  APFloat Lower = CR.getLower();
+  APFloat Upper = CR.getUpper();
+  if (Lower.isPosZero())
+    Lower = APFloat::getZero(Lower.getSemantics(), /*Negative=*/true);
+  if (Upper.isNegZero())
+    Upper = APFloat::getZero(Upper.getSemantics(), /*Negative=*/false);
+  return ConstantFPRange(std::move(Lower), std::move(Upper), CR.containsQNaN(),
+                         CR.containsSNaN());
+}
+
+static ConstantFPRange setNaNField(const ConstantFPRange &CR,
+                                   FCmpInst::Predicate Pred) {
+  bool ContainsNaN = FCmpInst::isUnordered(Pred);
+  return ConstantFPRange(CR.getLower(), CR.getUpper(),
+                         /*MayBeQNaN=*/ContainsNaN, /*MayBeSNaN=*/ContainsNaN);
+}
+
 ConstantFPRange
 ConstantFPRange::makeAllowedFCmpRegion(FCmpInst::Predicate Pred,
                                        const ConstantFPRange &Other) {
-  // TODO
-  return getFull(Other.getSemantics());
+  if (Other.isEmptySet())
+    return Other;
+  if (Other.containsNaN() && FCmpInst::isUnordered(Pred))
+    return getFull(Other.getSemantics());
+  if (Other.isNaNOnly() && FCmpInst::isOrdered(Pred))
+    return getEmpty(Other.getSemantics());
+
+  switch (Pred) {
+  case FCmpInst::FCMP_TRUE:
+    return getFull(Other.getSemantics());
+  case FCmpInst::FCMP_FALSE:
+    return getEmpty(Other.getSemantics());
+  case FCmpInst::FCMP_ORD:
+    return getNonNaN(Other.getSemantics());
+  case FCmpInst::FCMP_UNO:
+    return getNaNOnly(Other.getSemantics(), /*MayBeQNaN=*/true,
+                      /*MayBeSNaN=*/true);
+  case FCmpInst::FCMP_OEQ:
+  case FCmpInst::FCMP_UEQ:
+    return setNaNField(extendZeroIfEqual(Other, Pred), Pred);
+  case FCmpInst::FCMP_ONE:
+  case FCmpInst::FCMP_UNE:
+    if (const APFloat *SingleElement =
+            Other.getSingleElement(/*ExcludesNaN=*/true)) {
+      const fltSemantics &Sem = SingleElement->getSemantics();
+      if (SingleElement->isPosInfinity())
+        return setNaNField(
+            getNonNaN(APFloat::getInf(Sem, /*Negative=*/true),
+                      APFloat::getLargest(Sem, /*Negative=*/false)),
+            Pred);
+      if (SingleElement->isNegInfinity())
+        return setNaNField(
+            getNonNaN(APFloat::getLargest(Sem, /*Negative=*/true),
+                      APFloat::getInf(Sem, /*Negative=*/false)),
+            Pred);
+    }
+    return Pred == FCmpInst::FCMP_ONE ? getNonNaN(Other.getSemantics())
+                                      : getFull(Other.getSemantics());
+  case FCmpInst::FCMP_OLT:
+  case FCmpInst::FCMP_OLE:
+  case FCmpInst::FCMP_ULT:
+  case FCmpInst::FCMP_ULE:
+    return setNaNField(
+        extendZeroIfEqual(makeLessThan(Other.getUpper(), Pred), Pred), Pred);
+  case FCmpInst::FCMP_OGT:
+  case FCmpInst::FCMP_OGE:
+  case FCmpInst::FCMP_UGT:
+  case FCmpInst::FCMP_UGE:
+    return setNaNField(
+        extendZeroIfEqual(makeGreaterThan(Other.getLower(), Pred), Pred), Pred);
+  default:
+    llvm_unreachable("Unexpected predicate");
+  }
 }
 
 ConstantFPRange
@@ -117,9 +221,10 @@ ConstantFPRange::makeSatisfyingFCmpRegion(FCmpInst::Predicate Pred,
   return getEmpty(Other.getSemantics());
 }
 
-ConstantFPRange ConstantFPRange::makeExactFCmpRegion(FCmpInst::Predicate Pred,
-                                                     const APFloat &Other) {
-  return makeAllowedFCmpRegion(Pred, ConstantFPRange(Other));
+std::optional<ConstantFPRange>
+ConstantFPRange::makeExactFCmpRegion(FCmpInst::Predicate Pred,
+                                     const APFloat &Other) {
+  return std::nullopt;
 }
 
 bool ConstantFPRange::fcmp(FCmpInst::Predicate Pred,
@@ -161,8 +266,8 @@ bool ConstantFPRange::contains(const ConstantFPRange &CR) const {
          strictCompare(CR.Upper, Upper) != APFloat::cmpGreaterThan;
 }
 
-const APFloat *ConstantFPRange::getSingleElement() const {
-  if (MayBeSNaN || MayBeQNaN)
+const APFloat *ConstantFPRange::getSingleElement(bool ExcludesNaN) const {
+  if (!ExcludesNaN && (MayBeSNaN || MayBeQNaN))
     return nullptr;
   return Lower.bitwiseIsEqual(Upper) ? &Lower : nullptr;
 }
diff --git a/llvm/unittests/IR/ConstantFPRangeTest.cpp b/llvm/unittests/IR/ConstantFPRangeTest.cpp
index 722e6566730da5..1fe9231392d622 100644
--- a/llvm/unittests/IR/ConstantFPRangeTest.cpp
+++ b/llvm/unittests/IR/ConstantFPRangeTest.cpp
@@ -161,6 +161,19 @@ static void EnumerateValuesInConstantFPRange(const ConstantFPRange &CR,
   }
 }
 
+template <typename Fn>
+static bool AnyOfValueInConstantFPRange(const ConstantFPRange &CR, Fn TestFn) {
+  const fltSemantics &Sem = CR.getSemantics();
+  unsigned Bits = APFloat::semanticsSizeInBits(Sem);
+  assert(Bits < 32 && "Too many bits");
+  for (unsigned I = 0, E = (1U << Bits) - 1; I != E; ++I) {
+    APFloat V(Sem, APInt(Bits, I));
+    if (CR.contains(V) && TestFn(V))
+      return true;
+  }
+  return false;
+}
+
 TEST_F(ConstantFPRangeTest, Basics) {
   EXPECT_TRUE(Full.isFullSet());
   EXPECT_FALSE(Full.isEmptySet());
@@ -263,12 +276,16 @@ TEST_F(ConstantFPRangeTest, SingleElement) {
   EXPECT_EQ(*One.getSingleElement(), APFloat(1.0));
   EXPECT_EQ(*PosZero.getSingleElement(), APFloat::getZero(Sem));
   EXPECT_EQ(*PosInf.getSingleElement(), APFloat::getInf(Sem));
+  ConstantFPRange PosZeroOrNaN = PosZero.unionWith(NaN);
+  EXPECT_EQ(*PosZeroOrNaN.getSingleElement(/*ExcludesNaN=*/true),
+            APFloat::getZero(Sem));
 
   EXPECT_FALSE(Full.isSingleElement());
   EXPECT_FALSE(Empty.isSingleElement());
   EXPECT_TRUE(One.isSingleElement());
   EXPECT_FALSE(Some.isSingleElement());
   EXPECT_FALSE(Zero.isSingleElement());
+  EXPECT_TRUE(PosZeroOrNaN.isSingleElement(/*ExcludesNaN=*/true));
 }
 
 TEST_F(ConstantFPRangeTest, ExhaustivelyEnumerate) {
@@ -425,4 +442,32 @@ TEST_F(ConstantFPRangeTest, MismatchedSemantics) {
 #endif
 #endif
 
+TEST_F(ConstantFPRangeTest, makeAllowedFCmpRegion) {
+  for (auto Pred : FCmpInst::predicates()) {
+    EnumerateConstantFPRanges(
+        [Pred](const ConstantFPRange &CR) {
+          ConstantFPRange Res =
+              ConstantFPRange::makeAllowedFCmpRegion(Pred, CR);
+          ConstantFPRange Optimal =
+              ConstantFPRange::getEmpty(CR.getSemantics());
+          EnumerateValuesInConstantFPRange(
+              ConstantFPRange::getFull(CR.getSemantics()),
+              [&](const APFloat &V) {
+                if (AnyOfValueInConstantFPRange(CR, [&](const APFloat &U) {
+                      return FCmpInst::compare(V, U, Pred);
+                    }))
+                  Optimal = Optimal.unionWith(ConstantFPRange(V));
+              });
+
+          ASSERT_TRUE(Res.contains(Optimal))
+              << "Wrong result for makeAllowedFCmpRegion(" << Pred << ", " << CR
+              << "). Expected " << Optimal << ", but got " << Res;
+          EXPECT_EQ(Res, Optimal)
+              << "Suboptimal result for makeAllowedFCmpRegion(" << Pred << ", "
+              << CR << ")";
+        },
+        /*Exhaustive=*/false);
+  }
+}
+
 } // anonymous namespace

``````````

</details>


https://github.com/llvm/llvm-project/pull/110082


More information about the llvm-commits mailing list