[llvm] [llvm] add GenericFloatingPointPredicateUtils (PR #140254)

Tim Gymnich via llvm-commits llvm-commits at lists.llvm.org
Fri May 16 06:59:24 PDT 2025


https://github.com/tgymnich created https://github.com/llvm/llvm-project/pull/140254

- add `GenericFloatingPointPredicateUtils` in order to generalize effects of floating point comparisons on `KnownFPClass` for both IR and MIR.

>From a3b886cdf6928206b50acfd8d4bbef1680c7433d Mon Sep 17 00:00:00 2001
From: Tim Gymnich <tim at gymni.ch>
Date: Thu, 15 May 2025 23:22:34 -0400
Subject: [PATCH] [llvm] add GenericFloatingPointPredicateUtils

---
 .../ADT/GenericFloatingPointPredicateUtils.h  | 479 ++++++++++++++++++
 .../Analysis/FloatingPointPredicateUtils.h    |  74 +++
 llvm/include/llvm/Analysis/ValueTracking.h    |  43 --
 .../MachineFloatingPointPredicateUtils.h      |  52 ++
 llvm/lib/Analysis/CMakeLists.txt              |   1 +
 .../Analysis/FloatingPointPredicateUtils.cpp  |  42 ++
 llvm/lib/Analysis/InstructionSimplify.cpp     |   1 +
 llvm/lib/Analysis/ValueTracking.cpp           | 424 +---------------
 llvm/lib/CodeGen/CMakeLists.txt               |   1 +
 llvm/lib/CodeGen/CodeGenPrepare.cpp           |   1 +
 .../MachineFloatingPointPredicateUtils.cpp    |  49 ++
 .../InstCombine/InstCombineAndOrXor.cpp       |   1 +
 llvm/unittests/Analysis/ValueTrackingTest.cpp |   1 +
 13 files changed, 704 insertions(+), 465 deletions(-)
 create mode 100644 llvm/include/llvm/ADT/GenericFloatingPointPredicateUtils.h
 create mode 100644 llvm/include/llvm/Analysis/FloatingPointPredicateUtils.h
 create mode 100644 llvm/include/llvm/CodeGen/MachineFloatingPointPredicateUtils.h
 create mode 100644 llvm/lib/Analysis/FloatingPointPredicateUtils.cpp
 create mode 100644 llvm/lib/CodeGen/MachineFloatingPointPredicateUtils.cpp

diff --git a/llvm/include/llvm/ADT/GenericFloatingPointPredicateUtils.h b/llvm/include/llvm/ADT/GenericFloatingPointPredicateUtils.h
new file mode 100644
index 0000000000000..49c5fe0aed6e1
--- /dev/null
+++ b/llvm/include/llvm/ADT/GenericFloatingPointPredicateUtils.h
@@ -0,0 +1,479 @@
+//===- llvm/ADT/GenericFloatingPointPredicateUtils.h ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// Utilities for computing the floating-point classes implied by a
+/// floating-point comparison, generic over the IR and MIR representations.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ADT_GENERICFLOATINGPOINTPREDICATEUTILS_H
+#define LLVM_ADT_GENERICFLOATINGPOINTPREDICATEUTILS_H
+
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/FloatingPointMode.h"
+#include "llvm/IR/Instructions.h"
+#include <optional>
+
+namespace llvm {
+
+template <typename ContextT> class GenericFloatingPointPredicateUtils {
+  using ValueRefT = typename ContextT::ValueRefT;
+  using FunctionT = typename ContextT::FunctionT;
+
+  constexpr static ValueRefT Invalid = {};
+
+private:
+  static DenormalMode queryDenormalMode(const FunctionT &F, ValueRefT Val);
+
+  static bool lookThroughFAbs(const FunctionT &F, ValueRefT LHS,
+                              ValueRefT &Src);
+
+  static std::optional<APFloat> matchConstantFloat(const FunctionT &F,
+                                                   ValueRefT Val);
+
+  /// Return the return value for fcmpImpliesClass for a compare that produces
+  /// an exact class test.
+  static std::tuple<ValueRefT, FPClassTest, FPClassTest>
+  exactClass(ValueRefT V, FPClassTest M) {
+    return {V, M, ~M};
+  }
+
+public:
+  /// Returns a pair of values, which if passed to llvm.is.fpclass, returns the
+  /// same result as an fcmp with the given operands.
+  static std::pair<ValueRefT, FPClassTest>
+  fcmpToClassTest(FCmpInst::Predicate Pred, const FunctionT &F, ValueRefT LHS,
+                  ValueRefT RHS, bool LookThroughSrc) {
+    std::optional<APFloat> ConstRHS = matchConstantFloat(F, RHS);
+    if (!ConstRHS)
+      return {Invalid, fcAllFlags};
+
+    return fcmpToClassTest(Pred, F, LHS, *ConstRHS, LookThroughSrc);
+  }
+
+  static std::pair<ValueRefT, FPClassTest>
+  fcmpToClassTest(FCmpInst::Predicate Pred, const FunctionT &F, ValueRefT LHS,
+                  const APFloat &ConstRHS, bool LookThroughSrc) {
+
+    auto [Src, ClassIfTrue, ClassIfFalse] =
+        fcmpImpliesClass(Pred, F, LHS, ConstRHS, LookThroughSrc);
+
+    if (Src && ClassIfTrue == ~ClassIfFalse)
+      return {Src, ClassIfTrue};
+
+    return {Invalid, fcAllFlags};
+  }
+
+  /// Compute the possible floating-point classes that \p LHS could be based on
+  /// fcmp \p Pred \p LHS, \p RHS.
+  ///
+  /// \returns { TestedValue, ClassesIfTrue, ClassesIfFalse }
+  ///
+  /// If the compare returns an exact class test, ClassesIfTrue ==
+  /// ~ClassesIfFalse
+  ///
+  /// This is a less exact version of fcmpToClassTest (e.g. fcmpToClassTest will
+  /// only succeed for a test of x > 0 implies positive, but not x > 1).
+  ///
+  /// If \p LookThroughSrc is true, consider the input value when computing the
+  /// mask. This may look through sign bit operations.
+  ///
+  /// If \p LookThroughSrc is false, ignore the source value (i.e. the first
+  /// tuple element will always be LHS).
+  ///
+  static std::tuple<ValueRefT, FPClassTest, FPClassTest>
+  fcmpImpliesClass(CmpInst::Predicate Pred, const FunctionT &F, ValueRefT LHS,
+                   FPClassTest RHSClass, bool LookThroughSrc) {
+    assert(RHSClass != fcNone);
+    ValueRefT Src = LHS;
+
+    if (Pred == FCmpInst::FCMP_TRUE)
+      return exactClass(Src, fcAllFlags);
+
+    if (Pred == FCmpInst::FCMP_FALSE)
+      return exactClass(Src, fcNone);
+
+    const FPClassTest OrigClass = RHSClass;
+
+    const bool IsNegativeRHS = (RHSClass & fcNegative) == RHSClass;
+    const bool IsPositiveRHS = (RHSClass & fcPositive) == RHSClass;
+    const bool IsNaN = (RHSClass & ~fcNan) == fcNone;
+
+    if (IsNaN) {
+      // fcmp o__ x, nan -> false
+      // fcmp u__ x, nan -> true
+      return exactClass(Src, CmpInst::isOrdered(Pred) ? fcNone : fcAllFlags);
+    }
+
+    // fcmp ord x, zero|normal|subnormal|inf -> ~fcNan
+    if (Pred == FCmpInst::FCMP_ORD)
+      return exactClass(Src, ~fcNan);
+
+    // fcmp uno x, zero|normal|subnormal|inf -> fcNan
+    if (Pred == FCmpInst::FCMP_UNO)
+      return exactClass(Src, fcNan);
+
+    const bool IsFabs = LookThroughSrc && lookThroughFAbs(F, LHS, Src);
+    if (IsFabs)
+      RHSClass = llvm::inverse_fabs(RHSClass);
+
+    const bool IsZero = (OrigClass & fcZero) == OrigClass;
+    if (IsZero) {
+      assert(Pred != FCmpInst::FCMP_ORD && Pred != FCmpInst::FCMP_UNO);
+      // Compares with zero are only exactly equal to an fcZero class test if
+      // input denormals are not flushed.
+      // TODO: Handle DAZ by expanding masks to cover subnormal cases.
+      DenormalMode Mode = queryDenormalMode(F, LHS);
+      if (Mode.Input != DenormalMode::IEEE)
+        return {Invalid, fcAllFlags, fcAllFlags};
+
+      switch (Pred) {
+      case FCmpInst::FCMP_OEQ: // Match x == 0.0
+        return exactClass(Src, fcZero);
+      case FCmpInst::FCMP_UEQ: // Match isnan(x) || (x == 0.0)
+        return exactClass(Src, fcZero | fcNan);
+      case FCmpInst::FCMP_UNE: // Match (x != 0.0)
+        return exactClass(Src, ~fcZero);
+      case FCmpInst::FCMP_ONE: // Match !isnan(x) && x != 0.0
+        return exactClass(Src, ~fcNan & ~fcZero);
+      case FCmpInst::FCMP_ORD:
+        // Canonical form of ord/uno is with a zero. We could also handle
+        // non-canonical other non-NaN constants or LHS == RHS.
+        return exactClass(Src, ~fcNan);
+      case FCmpInst::FCMP_UNO:
+        return exactClass(Src, fcNan);
+      case FCmpInst::FCMP_OGT: // x > 0
+        return exactClass(Src, fcPosSubnormal | fcPosNormal | fcPosInf);
+      case FCmpInst::FCMP_UGT: // isnan(x) || x > 0
+        return exactClass(Src, fcPosSubnormal | fcPosNormal | fcPosInf | fcNan);
+      case FCmpInst::FCMP_OGE: // x >= 0
+        return exactClass(Src, fcPositive | fcNegZero);
+      case FCmpInst::FCMP_UGE: // isnan(x) || x >= 0
+        return exactClass(Src, fcPositive | fcNegZero | fcNan);
+      case FCmpInst::FCMP_OLT: // x < 0
+        return exactClass(Src, fcNegSubnormal | fcNegNormal | fcNegInf);
+      case FCmpInst::FCMP_ULT: // isnan(x) || x < 0
+        return exactClass(Src, fcNegSubnormal | fcNegNormal | fcNegInf | fcNan);
+      case FCmpInst::FCMP_OLE: // x <= 0
+        return exactClass(Src, fcNegative | fcPosZero);
+      case FCmpInst::FCMP_ULE: // isnan(x) || x <= 0
+        return exactClass(Src, fcNegative | fcPosZero | fcNan);
+      default:
+        llvm_unreachable("all compare types are handled");
+      }
+
+      return {Invalid, fcAllFlags, fcAllFlags};
+    }
+
+    const bool IsDenormalRHS = (OrigClass & fcSubnormal) == OrigClass;
+
+    const bool IsInf = (OrigClass & fcInf) == OrigClass;
+    if (IsInf) {
+      FPClassTest Mask = fcAllFlags;
+
+      switch (Pred) {
+      case FCmpInst::FCMP_OEQ:
+      case FCmpInst::FCMP_UNE: {
+        // Match __builtin_isinf patterns
+        //
+        //   fcmp oeq x, +inf -> is_fpclass x, fcPosInf
+        //   fcmp oeq fabs(x), +inf -> is_fpclass x, fcInf
+        //   fcmp oeq x, -inf -> is_fpclass x, fcNegInf
+        //   fcmp oeq fabs(x), -inf -> is_fpclass x, 0 -> false
+        //
+        //   fcmp une x, +inf -> is_fpclass x, ~fcPosInf
+        //   fcmp une fabs(x), +inf -> is_fpclass x, ~fcInf
+        //   fcmp une x, -inf -> is_fpclass x, ~fcNegInf
+        //   fcmp une fabs(x), -inf -> is_fpclass x, fcAllFlags -> true
+        if (IsNegativeRHS) {
+          Mask = fcNegInf;
+          if (IsFabs)
+            Mask = fcNone;
+        } else {
+          Mask = fcPosInf;
+          if (IsFabs)
+            Mask |= fcNegInf;
+        }
+        break;
+      }
+      case FCmpInst::FCMP_ONE:
+      case FCmpInst::FCMP_UEQ: {
+        // Match __builtin_isinf patterns
+        //   fcmp one x, -inf -> is_fpclass x, ~fcNegInf & ~fcNan
+        //   fcmp one fabs(x), -inf -> is_fpclass x, ~fcNan
+        //   fcmp one x, +inf -> is_fpclass x, ~fcPosInf & ~fcNan
+        //   fcmp one fabs(x), +inf -> is_fpclass x, ~fcInf & ~fcNan
+        //
+        //   fcmp ueq x, +inf -> is_fpclass x, fcPosInf|fcNan
+        //   fcmp ueq (fabs x), +inf -> is_fpclass x, fcInf|fcNan
+        //   fcmp ueq x, -inf -> is_fpclass x, fcNegInf|fcNan
+        //   fcmp ueq fabs(x), -inf -> is_fpclass x, fcNan
+        if (IsNegativeRHS) {
+          Mask = ~fcNegInf & ~fcNan;
+          if (IsFabs)
+            Mask = ~fcNan;
+        } else {
+          Mask = ~fcPosInf & ~fcNan;
+          if (IsFabs)
+            Mask &= ~fcNegInf;
+        }
+
+        break;
+      }
+      case FCmpInst::FCMP_OLT:
+      case FCmpInst::FCMP_UGE: {
+        if (IsNegativeRHS) {
+          // No value is ordered and less than negative infinity.
+          // All values are unordered with or at least negative infinity.
+          // fcmp olt x, -inf -> false
+          // fcmp uge x, -inf -> true
+          Mask = fcNone;
+          break;
+        }
+
+        // fcmp olt fabs(x), +inf -> fcFinite
+        // fcmp uge fabs(x), +inf -> ~fcFinite
+        // fcmp olt x, +inf -> fcFinite|fcNegInf
+        // fcmp uge x, +inf -> ~(fcFinite|fcNegInf)
+        Mask = fcFinite;
+        if (!IsFabs)
+          Mask |= fcNegInf;
+        break;
+      }
+      case FCmpInst::FCMP_OGE:
+      case FCmpInst::FCMP_ULT: {
+        if (IsNegativeRHS) {
+          // fcmp oge x, -inf -> ~fcNan
+          // fcmp oge fabs(x), -inf -> ~fcNan
+          // fcmp ult x, -inf -> fcNan
+          // fcmp ult fabs(x), -inf -> fcNan
+          Mask = ~fcNan;
+          break;
+        }
+
+        // fcmp oge fabs(x), +inf -> fcInf
+        // fcmp oge x, +inf -> fcPosInf
+        // fcmp ult fabs(x), +inf -> ~fcInf
+        // fcmp ult x, +inf -> ~fcPosInf
+        Mask = fcPosInf;
+        if (IsFabs)
+          Mask |= fcNegInf;
+        break;
+      }
+      case FCmpInst::FCMP_OGT:
+      case FCmpInst::FCMP_ULE: {
+        if (IsNegativeRHS) {
+          // fcmp ogt x, -inf -> fcmp one x, -inf
+          // fcmp ogt fabs(x), -inf -> fcmp ord x, x
+          // fcmp ule x, -inf -> fcmp ueq x, -inf
+          // fcmp ule fabs(x), -inf -> fcmp uno x, x
+          Mask = IsFabs ? ~fcNan : ~(fcNegInf | fcNan);
+          break;
+        }
+
+        // No value is ordered and greater than infinity.
+        Mask = fcNone;
+        break;
+      }
+      case FCmpInst::FCMP_OLE:
+      case FCmpInst::FCMP_UGT: {
+        if (IsNegativeRHS) {
+          Mask = IsFabs ? fcNone : fcNegInf;
+          break;
+        }
+
+        // fcmp ole x, +inf -> fcmp ord x, x
+        // fcmp ole fabs(x), +inf -> fcmp ord x, x
+        // fcmp ole x, -inf -> fcmp oeq x, -inf
+        // fcmp ole fabs(x), -inf -> false
+        Mask = ~fcNan;
+        break;
+      }
+      default:
+        llvm_unreachable("all compare types are handled");
+      }
+
+      // Invert the comparison for the unordered cases.
+      if (FCmpInst::isUnordered(Pred))
+        Mask = ~Mask;
+
+      return exactClass(Src, Mask);
+    }
+
+    if (Pred == FCmpInst::FCMP_OEQ)
+      return {Src, RHSClass, fcAllFlags};
+
+    if (Pred == FCmpInst::FCMP_UEQ) {
+      FPClassTest Class = RHSClass | fcNan;
+      return {Src, Class, ~fcNan};
+    }
+
+    if (Pred == FCmpInst::FCMP_ONE)
+      return {Src, ~fcNan, RHSClass | fcNan};
+
+    if (Pred == FCmpInst::FCMP_UNE)
+      return {Src, fcAllFlags, RHSClass};
+
+    assert((RHSClass == fcNone || RHSClass == fcPosNormal ||
+            RHSClass == fcNegNormal || RHSClass == fcNormal ||
+            RHSClass == fcPosSubnormal || RHSClass == fcNegSubnormal ||
+            RHSClass == fcSubnormal) &&
+           "should have been recognized as an exact class test");
+
+    if (IsNegativeRHS) {
+      // TODO: Handle fneg(fabs)
+      if (IsFabs) {
+        // fabs(x) o> -k -> fcmp ord x, x
+        // fabs(x) u> -k -> true
+        // fabs(x) o< -k -> false
+        // fabs(x) u< -k -> fcmp uno x, x
+        switch (Pred) {
+        case FCmpInst::FCMP_OGT:
+        case FCmpInst::FCMP_OGE:
+          return {Src, ~fcNan, fcNan};
+        case FCmpInst::FCMP_UGT:
+        case FCmpInst::FCMP_UGE:
+          return {Src, fcAllFlags, fcNone};
+        case FCmpInst::FCMP_OLT:
+        case FCmpInst::FCMP_OLE:
+          return {Src, fcNone, fcAllFlags};
+        case FCmpInst::FCMP_ULT:
+        case FCmpInst::FCMP_ULE:
+          return {Src, fcNan, ~fcNan};
+        default:
+          break;
+        }
+
+        return {Invalid, fcAllFlags, fcAllFlags};
+      }
+
+      FPClassTest ClassesLE = fcNegInf | fcNegNormal;
+      FPClassTest ClassesGE = fcPositive | fcNegZero | fcNegSubnormal;
+
+      if (IsDenormalRHS)
+        ClassesLE |= fcNegSubnormal;
+      else
+        ClassesGE |= fcNegNormal;
+
+      switch (Pred) {
+      case FCmpInst::FCMP_OGT:
+      case FCmpInst::FCMP_OGE:
+        return {Src, ClassesGE, ~ClassesGE | RHSClass};
+      case FCmpInst::FCMP_UGT:
+      case FCmpInst::FCMP_UGE:
+        return {Src, ClassesGE | fcNan, ~(ClassesGE | fcNan) | RHSClass};
+      case FCmpInst::FCMP_OLT:
+      case FCmpInst::FCMP_OLE:
+        return {Src, ClassesLE, ~ClassesLE | RHSClass};
+      case FCmpInst::FCMP_ULT:
+      case FCmpInst::FCMP_ULE:
+        return {Src, ClassesLE | fcNan, ~(ClassesLE | fcNan) | RHSClass};
+      default:
+        break;
+      }
+    } else if (IsPositiveRHS) {
+      FPClassTest ClassesGE = fcPosNormal | fcPosInf;
+      FPClassTest ClassesLE = fcNegative | fcPosZero | fcPosSubnormal;
+      if (IsDenormalRHS)
+        ClassesGE |= fcPosSubnormal;
+      else
+        ClassesLE |= fcPosNormal;
+
+      if (IsFabs) {
+        ClassesGE = llvm::inverse_fabs(ClassesGE);
+        ClassesLE = llvm::inverse_fabs(ClassesLE);
+      }
+
+      switch (Pred) {
+      case FCmpInst::FCMP_OGT:
+      case FCmpInst::FCMP_OGE:
+        return {Src, ClassesGE, ~ClassesGE | RHSClass};
+      case FCmpInst::FCMP_UGT:
+      case FCmpInst::FCMP_UGE:
+        return {Src, ClassesGE | fcNan, ~(ClassesGE | fcNan) | RHSClass};
+      case FCmpInst::FCMP_OLT:
+      case FCmpInst::FCMP_OLE:
+        return {Src, ClassesLE, ~ClassesLE | RHSClass};
+      case FCmpInst::FCMP_ULT:
+      case FCmpInst::FCMP_ULE:
+        return {Src, ClassesLE | fcNan, ~(ClassesLE | fcNan) | RHSClass};
+      default:
+        break;
+      }
+    }
+
+    return {Invalid, fcAllFlags, fcAllFlags};
+  }
+
+  static std::tuple<ValueRefT, FPClassTest, FPClassTest>
+  fcmpImpliesClass(CmpInst::Predicate Pred, const FunctionT &F, ValueRefT LHS,
+                   const APFloat &ConstRHS, bool LookThroughSrc) {
+    // We can refine checks against smallest normal / largest denormal to an
+    // exact class test.
+    if (!ConstRHS.isNegative() && ConstRHS.isSmallestNormalized()) {
+      ValueRefT Src = LHS;
+      const bool IsFabs = LookThroughSrc && lookThroughFAbs(F, LHS, Src);
+
+      FPClassTest Mask;
+      // Match pattern that's used in __builtin_isnormal.
+      switch (Pred) {
+      case FCmpInst::FCMP_OLT:
+      case FCmpInst::FCMP_UGE: {
+        // fcmp olt x, smallest_normal ->
+        //   fcNegInf|fcNegNormal|fcSubnormal|fcZero
+        // fcmp olt fabs(x), smallest_normal -> fcSubnormal|fcZero
+        // fcmp uge x, smallest_normal -> fcNan|fcPosNormal|fcPosInf
+        // fcmp uge fabs(x), smallest_normal -> ~(fcSubnormal|fcZero)
+        Mask = fcZero | fcSubnormal;
+        if (!IsFabs)
+          Mask |= fcNegNormal | fcNegInf;
+
+        break;
+      }
+      case FCmpInst::FCMP_OGE:
+      case FCmpInst::FCMP_ULT: {
+        // fcmp oge x, smallest_normal -> fcPosNormal | fcPosInf
+        // fcmp oge fabs(x), smallest_normal -> fcInf | fcNormal
+        // fcmp ult x, smallest_normal -> ~(fcPosNormal | fcPosInf)
+        // fcmp ult fabs(x), smallest_normal -> ~(fcInf | fcNormal)
+        Mask = fcPosInf | fcPosNormal;
+        if (IsFabs)
+          Mask |= fcNegInf | fcNegNormal;
+        break;
+      }
+      default:
+        return fcmpImpliesClass(Pred, F, LHS, ConstRHS.classify(),
+                                LookThroughSrc);
+      }
+
+      // Invert the comparison for the unordered cases.
+      if (FCmpInst::isUnordered(Pred))
+        Mask = ~Mask;
+
+      return exactClass(Src, Mask);
+    }
+
+    return fcmpImpliesClass(Pred, F, LHS, ConstRHS.classify(), LookThroughSrc);
+  }
+
+  static std::tuple<ValueRefT, FPClassTest, FPClassTest>
+  fcmpImpliesClass(CmpInst::Predicate Pred, const FunctionT &F, ValueRefT LHS,
+                   ValueRefT RHS, bool LookThroughSrc) {
+    std::optional<APFloat> ConstRHS = matchConstantFloat(F, RHS);
+    if (!ConstRHS)
+      return {Invalid, fcAllFlags, fcAllFlags};
+
+    // TODO: Just call computeKnownFPClass for RHS to handle non-constants.
+    return fcmpImpliesClass(Pred, F, LHS, *ConstRHS, LookThroughSrc);
+  }
+};
+
+} // namespace llvm
+
+#endif // LLVM_ADT_GENERICFLOATINGPOINTPREDICATEUTILS_H
diff --git a/llvm/include/llvm/Analysis/FloatingPointPredicateUtils.h b/llvm/include/llvm/Analysis/FloatingPointPredicateUtils.h
new file mode 100644
index 0000000000000..68c9876988fbc
--- /dev/null
+++ b/llvm/include/llvm/Analysis/FloatingPointPredicateUtils.h
@@ -0,0 +1,74 @@
+//===- llvm/Analysis/FloatingPointPredicateUtils.h --------------*- C++ -*-===//
+//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ANALYSIS_FLOATINGPOINTPREDICATEUTILS_H
+#define LLVM_ANALYSIS_FLOATINGPOINTPREDICATEUTILS_H
+
+#include "llvm/ADT/GenericFloatingPointPredicateUtils.h"
+#include "llvm/IR/SSAContext.h"
+
+namespace llvm {
+
+using FloatingPointPredicateUtils =
+    GenericFloatingPointPredicateUtils<SSAContext>;
+
+/// Returns a pair of values, which if passed to llvm.is.fpclass, returns the
+/// same result as an fcmp with the given operands.
+///
+/// If \p LookThroughSrc is true, consider the input value when computing the
+/// mask.
+///
+/// If \p LookThroughSrc is false, ignore the source value (i.e. the first pair
+/// element will always be LHS).
+inline std::pair<Value *, FPClassTest>
+fcmpToClassTest(FCmpInst::Predicate Pred, const Function &F, Value *LHS,
+                Value *RHS, bool LookThroughSrc = true) {
+  return FloatingPointPredicateUtils::fcmpToClassTest(Pred, F, LHS, RHS,
+                                                      LookThroughSrc);
+}
+
+/// Returns a pair of values, which if passed to llvm.is.fpclass, returns the
+/// same result as an fcmp with the given operands.
+///
+/// If \p LookThroughSrc is true, consider the input value when computing the
+/// mask.
+///
+/// If \p LookThroughSrc is false, ignore the source value (i.e. the first pair
+/// element will always be LHS).
+inline std::pair<Value *, FPClassTest>
+fcmpToClassTest(FCmpInst::Predicate Pred, const Function &F, Value *LHS,
+                const APFloat *ConstRHS, bool LookThroughSrc = true) {
+  return FloatingPointPredicateUtils::fcmpToClassTest(Pred, F, LHS, *ConstRHS,
+                                                      LookThroughSrc);
+}
+
+inline std::tuple<Value *, FPClassTest, FPClassTest>
+fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
+                 FPClassTest RHSClass, bool LookThroughSrc = true) {
+  return FloatingPointPredicateUtils::fcmpImpliesClass(Pred, F, LHS, RHSClass,
+                                                       LookThroughSrc);
+}
+
+inline std::tuple<Value *, FPClassTest, FPClassTest>
+fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
+                 const APFloat &ConstRHS, bool LookThroughSrc = true) {
+  return FloatingPointPredicateUtils::fcmpImpliesClass(Pred, F, LHS, ConstRHS,
+                                                       LookThroughSrc);
+}
+
+inline std::tuple<Value *, FPClassTest, FPClassTest>
+fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
+                 Value *RHS, bool LookThroughSrc = true) {
+  return FloatingPointPredicateUtils::fcmpImpliesClass(Pred, F, LHS, RHS,
+                                                       LookThroughSrc);
+}
+
+} // namespace llvm
+
+#endif // LLVM_ANALYSIS_FLOATINGPOINTPREDICATEUTILS_H
diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index 61dbb07e7128e..919e575ea0236 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -213,49 +213,6 @@ Intrinsic::ID getIntrinsicForCallSite(const CallBase &CB,
 bool isSignBitCheck(ICmpInst::Predicate Pred, const APInt &RHS,
                     bool &TrueIfSigned);
 
-/// Returns a pair of values, which if passed to llvm.is.fpclass, returns the
-/// same result as an fcmp with the given operands.
-///
-/// If \p LookThroughSrc is true, consider the input value when computing the
-/// mask.
-///
-/// If \p LookThroughSrc is false, ignore the source value (i.e. the first pair
-/// element will always be LHS.
-std::pair<Value *, FPClassTest> fcmpToClassTest(CmpInst::Predicate Pred,
-                                                const Function &F, Value *LHS,
-                                                Value *RHS,
-                                                bool LookThroughSrc = true);
-std::pair<Value *, FPClassTest> fcmpToClassTest(CmpInst::Predicate Pred,
-                                                const Function &F, Value *LHS,
-                                                const APFloat *ConstRHS,
-                                                bool LookThroughSrc = true);
-
-/// Compute the possible floating-point classes that \p LHS could be based on
-/// fcmp \Pred \p LHS, \p RHS.
-///
-/// \returns { TestedValue, ClassesIfTrue, ClassesIfFalse }
-///
-/// If the compare returns an exact class test, ClassesIfTrue == ~ClassesIfFalse
-///
-/// This is a less exact version of fcmpToClassTest (e.g. fcmpToClassTest will
-/// only succeed for a test of x > 0 implies positive, but not x > 1).
-///
-/// If \p LookThroughSrc is true, consider the input value when computing the
-/// mask. This may look through sign bit operations.
-///
-/// If \p LookThroughSrc is false, ignore the source value (i.e. the first pair
-/// element will always be LHS.
-///
-std::tuple<Value *, FPClassTest, FPClassTest>
-fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
-                 Value *RHS, bool LookThroughSrc = true);
-std::tuple<Value *, FPClassTest, FPClassTest>
-fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
-                 FPClassTest RHS, bool LookThroughSrc = true);
-std::tuple<Value *, FPClassTest, FPClassTest>
-fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
-                 const APFloat &RHS, bool LookThroughSrc = true);
-
 /// Determine which floating-point classes are valid for \p V, and return them
 /// in KnownFPClass bit sets.
 ///
diff --git a/llvm/include/llvm/CodeGen/MachineFloatingPointPredicateUtils.h b/llvm/include/llvm/CodeGen/MachineFloatingPointPredicateUtils.h
new file mode 100644
index 0000000000000..36a7690e18564
--- /dev/null
+++ b/llvm/include/llvm/CodeGen/MachineFloatingPointPredicateUtils.h
@@ -0,0 +1,52 @@
+//===- llvm/CodeGen/MachineFloatingPointPredicateUtils.h --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// \file
+///
+/// This file declares the MIR specialization of the
+/// GenericFloatingPointPredicateUtils template.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CODEGEN_MACHINEFLOATINGPOINTPREDICATEUTILS_H
+#define LLVM_CODEGEN_MACHINEFLOATINGPOINTPREDICATEUTILS_H
+
+#include "llvm/ADT/GenericFloatingPointPredicateUtils.h"
+#include "llvm/CodeGen/MachineSSAContext.h"
+
+namespace llvm {
+
+using MachineFloatingPointPredicateUtils =
+    GenericFloatingPointPredicateUtils<MachineSSAContext>;
+
+/// Compute the possible floating-point classes that \p LHS could be based on
+/// fcmp \p Pred \p LHS, \p RHS.
+///
+/// \returns { TestedValue, ClassesIfTrue, ClassesIfFalse }
+///
+/// If the compare returns an exact class test, ClassesIfTrue ==
+/// ~ClassesIfFalse
+///
+/// This is a less exact version of fcmpToClassTest (e.g. fcmpToClassTest will
+/// only succeed for a test of x > 0 implies positive, but not x > 1).
+///
+/// If \p LookThroughSrc is true, consider the input value when computing the
+/// mask. This may look through sign bit operations.
+///
+/// If \p LookThroughSrc is false, ignore the source value (i.e. the first
+/// tuple element will always be LHS).
+///
+inline std::tuple<Register, FPClassTest, FPClassTest>
+fcmpImpliesClass(CmpInst::Predicate Pred, const MachineFunction &MF,
+                 Register LHS, Register RHS, bool LookThroughSrc = true) {
+  return MachineFloatingPointPredicateUtils::fcmpImpliesClass(
+      Pred, MF, LHS, RHS, LookThroughSrc);
+}
+
+} // namespace llvm
+
+#endif // LLVM_CODEGEN_MACHINEFLOATINGPOINTPREDICATEUTILS_H
diff --git a/llvm/lib/Analysis/CMakeLists.txt b/llvm/lib/Analysis/CMakeLists.txt
index a17a75e6fbcac..e884f11f0f758 100644
--- a/llvm/lib/Analysis/CMakeLists.txt
+++ b/llvm/lib/Analysis/CMakeLists.txt
@@ -74,6 +74,7 @@ add_llvm_component_library(LLVMAnalysis
   DXILResource.cpp
   DXILMetadataAnalysis.cpp
   EphemeralValuesCache.cpp
+  FloatingPointPredicateUtils.cpp
   FunctionPropertiesAnalysis.cpp
   GlobalsModRef.cpp
   GuardUtils.cpp
diff --git a/llvm/lib/Analysis/FloatingPointPredicateUtils.cpp b/llvm/lib/Analysis/FloatingPointPredicateUtils.cpp
new file mode 100644
index 0000000000000..2a53380597123
--- /dev/null
+++ b/llvm/lib/Analysis/FloatingPointPredicateUtils.cpp
@@ -0,0 +1,42 @@
+//===- FloatingPointPredicateUtils.cpp ------------------------------------===//
+//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/FloatingPointPredicateUtils.h"
+#include "llvm/IR/PatternMatch.h"
+#include <optional>
+
+namespace llvm {
+
+using namespace PatternMatch;
+
+template <>
+DenormalMode FloatingPointPredicateUtils::queryDenormalMode(const Function &F,
+                                                            Value *Val) {
+  Type *Ty = Val->getType()->getScalarType();
+  return F.getDenormalMode(Ty->getFltSemantics());
+}
+
+template <>
+bool FloatingPointPredicateUtils::lookThroughFAbs(const Function &F, Value *LHS,
+                                                  Value *&Src) {
+  return match(LHS, m_FAbs(m_Value(Src)));
+}
+
+template <>
+std::optional<APFloat>
+FloatingPointPredicateUtils::matchConstantFloat(const Function &F, Value *Val) {
+  const APFloat *ConstVal;
+
+  if (!match(Val, m_APFloatAllowPoison(ConstVal)))
+    return std::nullopt;
+
+  return *ConstVal;
+}
+
+} // namespace llvm
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 85e3be9cc45c3..23e147ba8c6a1 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -26,6 +26,7 @@
 #include "llvm/Analysis/CaptureTracking.h"
 #include "llvm/Analysis/CmpInstAnalysis.h"
 #include "llvm/Analysis/ConstantFolding.h"
+#include "llvm/Analysis/FloatingPointPredicateUtils.h"
 #include "llvm/Analysis/InstSimplifyFolder.h"
 #include "llvm/Analysis/Loads.h"
 #include "llvm/Analysis/LoopAnalysisManager.h"
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 8405678aa9680..5b9338b500ecc 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -15,6 +15,7 @@
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/FloatingPointMode.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallPtrSet.h"
@@ -27,6 +28,7 @@
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/DomConditionCache.h"
+#include "llvm/Analysis/FloatingPointPredicateUtils.h"
 #include "llvm/Analysis/GuardUtils.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/Loads.h"
@@ -4498,13 +4500,6 @@ Intrinsic::ID llvm::getIntrinsicForCallSite(const CallBase &CB,
   return Intrinsic::not_intrinsic;
 }
 
-/// Return true if it's possible to assume IEEE treatment of input denormals in
-/// \p F for \p Val.
-static bool inputDenormalIsIEEE(const Function &F, const Type *Ty) {
-  Ty = Ty->getScalarType();
-  return F.getDenormalMode(Ty->getFltSemantics()).Input == DenormalMode::IEEE;
-}
-
 static bool outputDenormalIsIEEEOrPosZero(const Function &F, const Type *Ty) {
   Ty = Ty->getScalarType();
   DenormalMode Mode = F.getDenormalMode(Ty->getFltSemantics());
@@ -4550,421 +4545,6 @@ bool llvm::isSignBitCheck(ICmpInst::Predicate Pred, const APInt &RHS,
   }
 }
 
-/// Returns a pair of values, which if passed to llvm.is.fpclass, returns the
-/// same result as an fcmp with the given operands.
-std::pair<Value *, FPClassTest> llvm::fcmpToClassTest(FCmpInst::Predicate Pred,
-                                                      const Function &F,
-                                                      Value *LHS, Value *RHS,
-                                                      bool LookThroughSrc) {
-  const APFloat *ConstRHS;
-  if (!match(RHS, m_APFloatAllowPoison(ConstRHS)))
-    return {nullptr, fcAllFlags};
-
-  return fcmpToClassTest(Pred, F, LHS, ConstRHS, LookThroughSrc);
-}
-
-std::pair<Value *, FPClassTest>
-llvm::fcmpToClassTest(FCmpInst::Predicate Pred, const Function &F, Value *LHS,
-                      const APFloat *ConstRHS, bool LookThroughSrc) {
-
-  auto [Src, ClassIfTrue, ClassIfFalse] =
-      fcmpImpliesClass(Pred, F, LHS, *ConstRHS, LookThroughSrc);
-  if (Src && ClassIfTrue == ~ClassIfFalse)
-    return {Src, ClassIfTrue};
-  return {nullptr, fcAllFlags};
-}
-
-/// Return the return value for fcmpImpliesClass for a compare that produces an
-/// exact class test.
-static std::tuple<Value *, FPClassTest, FPClassTest> exactClass(Value *V,
-                                                                FPClassTest M) {
-  return {V, M, ~M};
-}
-
-std::tuple<Value *, FPClassTest, FPClassTest>
-llvm::fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
-                       FPClassTest RHSClass, bool LookThroughSrc) {
-  assert(RHSClass != fcNone);
-  Value *Src = LHS;
-
-  if (Pred == FCmpInst::FCMP_TRUE)
-    return exactClass(Src, fcAllFlags);
-
-  if (Pred == FCmpInst::FCMP_FALSE)
-    return exactClass(Src, fcNone);
-
-  const FPClassTest OrigClass = RHSClass;
-
-  const bool IsNegativeRHS = (RHSClass & fcNegative) == RHSClass;
-  const bool IsPositiveRHS = (RHSClass & fcPositive) == RHSClass;
-  const bool IsNaN = (RHSClass & ~fcNan) == fcNone;
-
-  if (IsNaN) {
-    // fcmp o__ x, nan -> false
-    // fcmp u__ x, nan -> true
-    return exactClass(Src, CmpInst::isOrdered(Pred) ? fcNone : fcAllFlags);
-  }
-
-  // fcmp ord x, zero|normal|subnormal|inf -> ~fcNan
-  if (Pred == FCmpInst::FCMP_ORD)
-    return exactClass(Src, ~fcNan);
-
-  // fcmp uno x, zero|normal|subnormal|inf -> fcNan
-  if (Pred == FCmpInst::FCMP_UNO)
-    return exactClass(Src, fcNan);
-
-  const bool IsFabs = LookThroughSrc && match(LHS, m_FAbs(m_Value(Src)));
-  if (IsFabs)
-    RHSClass = llvm::inverse_fabs(RHSClass);
-
-  const bool IsZero = (OrigClass & fcZero) == OrigClass;
-  if (IsZero) {
-    assert(Pred != FCmpInst::FCMP_ORD && Pred != FCmpInst::FCMP_UNO);
-    // Compares with fcNone are only exactly equal to fcZero if input denormals
-    // are not flushed.
-    // TODO: Handle DAZ by expanding masks to cover subnormal cases.
-    if (!inputDenormalIsIEEE(F, LHS->getType()))
-      return {nullptr, fcAllFlags, fcAllFlags};
-
-    switch (Pred) {
-    case FCmpInst::FCMP_OEQ: // Match x == 0.0
-      return exactClass(Src, fcZero);
-    case FCmpInst::FCMP_UEQ: // Match isnan(x) || (x == 0.0)
-      return exactClass(Src, fcZero | fcNan);
-    case FCmpInst::FCMP_UNE: // Match (x != 0.0)
-      return exactClass(Src, ~fcZero);
-    case FCmpInst::FCMP_ONE: // Match !isnan(x) && x != 0.0
-      return exactClass(Src, ~fcNan & ~fcZero);
-    case FCmpInst::FCMP_ORD:
-      // Canonical form of ord/uno is with a zero. We could also handle
-      // non-canonical other non-NaN constants or LHS == RHS.
-      return exactClass(Src, ~fcNan);
-    case FCmpInst::FCMP_UNO:
-      return exactClass(Src, fcNan);
-    case FCmpInst::FCMP_OGT: // x > 0
-      return exactClass(Src, fcPosSubnormal | fcPosNormal | fcPosInf);
-    case FCmpInst::FCMP_UGT: // isnan(x) || x > 0
-      return exactClass(Src, fcPosSubnormal | fcPosNormal | fcPosInf | fcNan);
-    case FCmpInst::FCMP_OGE: // x >= 0
-      return exactClass(Src, fcPositive | fcNegZero);
-    case FCmpInst::FCMP_UGE: // isnan(x) || x >= 0
-      return exactClass(Src, fcPositive | fcNegZero | fcNan);
-    case FCmpInst::FCMP_OLT: // x < 0
-      return exactClass(Src, fcNegSubnormal | fcNegNormal | fcNegInf);
-    case FCmpInst::FCMP_ULT: // isnan(x) || x < 0
-      return exactClass(Src, fcNegSubnormal | fcNegNormal | fcNegInf | fcNan);
-    case FCmpInst::FCMP_OLE: // x <= 0
-      return exactClass(Src, fcNegative | fcPosZero);
-    case FCmpInst::FCMP_ULE: // isnan(x) || x <= 0
-      return exactClass(Src, fcNegative | fcPosZero | fcNan);
-    default:
-      llvm_unreachable("all compare types are handled");
-    }
-
-    return {nullptr, fcAllFlags, fcAllFlags};
-  }
-
-  const bool IsDenormalRHS = (OrigClass & fcSubnormal) == OrigClass;
-
-  const bool IsInf = (OrigClass & fcInf) == OrigClass;
-  if (IsInf) {
-    FPClassTest Mask = fcAllFlags;
-
-    switch (Pred) {
-    case FCmpInst::FCMP_OEQ:
-    case FCmpInst::FCMP_UNE: {
-      // Match __builtin_isinf patterns
-      //
-      //   fcmp oeq x, +inf -> is_fpclass x, fcPosInf
-      //   fcmp oeq fabs(x), +inf -> is_fpclass x, fcInf
-      //   fcmp oeq x, -inf -> is_fpclass x, fcNegInf
-      //   fcmp oeq fabs(x), -inf -> is_fpclass x, 0 -> false
-      //
-      //   fcmp une x, +inf -> is_fpclass x, ~fcPosInf
-      //   fcmp une fabs(x), +inf -> is_fpclass x, ~fcInf
-      //   fcmp une x, -inf -> is_fpclass x, ~fcNegInf
-      //   fcmp une fabs(x), -inf -> is_fpclass x, fcAllFlags -> true
-      if (IsNegativeRHS) {
-        Mask = fcNegInf;
-        if (IsFabs)
-          Mask = fcNone;
-      } else {
-        Mask = fcPosInf;
-        if (IsFabs)
-          Mask |= fcNegInf;
-      }
-      break;
-    }
-    case FCmpInst::FCMP_ONE:
-    case FCmpInst::FCMP_UEQ: {
-      // Match __builtin_isinf patterns
-      //   fcmp one x, -inf -> is_fpclass x, fcNegInf
-      //   fcmp one fabs(x), -inf -> is_fpclass x, ~fcNegInf & ~fcNan
-      //   fcmp one x, +inf -> is_fpclass x, ~fcNegInf & ~fcNan
-      //   fcmp one fabs(x), +inf -> is_fpclass x, ~fcInf & fcNan
-      //
-      //   fcmp ueq x, +inf -> is_fpclass x, fcPosInf|fcNan
-      //   fcmp ueq (fabs x), +inf -> is_fpclass x, fcInf|fcNan
-      //   fcmp ueq x, -inf -> is_fpclass x, fcNegInf|fcNan
-      //   fcmp ueq fabs(x), -inf -> is_fpclass x, fcNan
-      if (IsNegativeRHS) {
-        Mask = ~fcNegInf & ~fcNan;
-        if (IsFabs)
-          Mask = ~fcNan;
-      } else {
-        Mask = ~fcPosInf & ~fcNan;
-        if (IsFabs)
-          Mask &= ~fcNegInf;
-      }
-
-      break;
-    }
-    case FCmpInst::FCMP_OLT:
-    case FCmpInst::FCMP_UGE: {
-      if (IsNegativeRHS) {
-        // No value is ordered and less than negative infinity.
-        // All values are unordered with or at least negative infinity.
-        // fcmp olt x, -inf -> false
-        // fcmp uge x, -inf -> true
-        Mask = fcNone;
-        break;
-      }
-
-      // fcmp olt fabs(x), +inf -> fcFinite
-      // fcmp uge fabs(x), +inf -> ~fcFinite
-      // fcmp olt x, +inf -> fcFinite|fcNegInf
-      // fcmp uge x, +inf -> ~(fcFinite|fcNegInf)
-      Mask = fcFinite;
-      if (!IsFabs)
-        Mask |= fcNegInf;
-      break;
-    }
-    case FCmpInst::FCMP_OGE:
-    case FCmpInst::FCMP_ULT: {
-      if (IsNegativeRHS) {
-        // fcmp oge x, -inf -> ~fcNan
-        // fcmp oge fabs(x), -inf -> ~fcNan
-        // fcmp ult x, -inf -> fcNan
-        // fcmp ult fabs(x), -inf -> fcNan
-        Mask = ~fcNan;
-        break;
-      }
-
-      // fcmp oge fabs(x), +inf -> fcInf
-      // fcmp oge x, +inf -> fcPosInf
-      // fcmp ult fabs(x), +inf -> ~fcInf
-      // fcmp ult x, +inf -> ~fcPosInf
-      Mask = fcPosInf;
-      if (IsFabs)
-        Mask |= fcNegInf;
-      break;
-    }
-    case FCmpInst::FCMP_OGT:
-    case FCmpInst::FCMP_ULE: {
-      if (IsNegativeRHS) {
-        // fcmp ogt x, -inf -> fcmp one x, -inf
-        // fcmp ogt fabs(x), -inf -> fcmp ord x, x
-        // fcmp ule x, -inf -> fcmp ueq x, -inf
-        // fcmp ule fabs(x), -inf -> fcmp uno x, x
-        Mask = IsFabs ? ~fcNan : ~(fcNegInf | fcNan);
-        break;
-      }
-
-      // No value is ordered and greater than infinity.
-      Mask = fcNone;
-      break;
-    }
-    case FCmpInst::FCMP_OLE:
-    case FCmpInst::FCMP_UGT: {
-      if (IsNegativeRHS) {
-        Mask = IsFabs ? fcNone : fcNegInf;
-        break;
-      }
-
-      // fcmp ole x, +inf -> fcmp ord x, x
-      // fcmp ole fabs(x), +inf -> fcmp ord x, x
-      // fcmp ole x, -inf -> fcmp oeq x, -inf
-      // fcmp ole fabs(x), -inf -> false
-      Mask = ~fcNan;
-      break;
-    }
-    default:
-      llvm_unreachable("all compare types are handled");
-    }
-
-    // Invert the comparison for the unordered cases.
-    if (FCmpInst::isUnordered(Pred))
-      Mask = ~Mask;
-
-    return exactClass(Src, Mask);
-  }
-
-  if (Pred == FCmpInst::FCMP_OEQ)
-    return {Src, RHSClass, fcAllFlags};
-
-  if (Pred == FCmpInst::FCMP_UEQ) {
-    FPClassTest Class = RHSClass | fcNan;
-    return {Src, Class, ~fcNan};
-  }
-
-  if (Pred == FCmpInst::FCMP_ONE)
-    return {Src, ~fcNan, RHSClass | fcNan};
-
-  if (Pred == FCmpInst::FCMP_UNE)
-    return {Src, fcAllFlags, RHSClass};
-
-  assert((RHSClass == fcNone || RHSClass == fcPosNormal ||
-          RHSClass == fcNegNormal || RHSClass == fcNormal ||
-          RHSClass == fcPosSubnormal || RHSClass == fcNegSubnormal ||
-          RHSClass == fcSubnormal) &&
-         "should have been recognized as an exact class test");
-
-  if (IsNegativeRHS) {
-    // TODO: Handle fneg(fabs)
-    if (IsFabs) {
-      // fabs(x) o> -k -> fcmp ord x, x
-      // fabs(x) u> -k -> true
-      // fabs(x) o< -k -> false
-      // fabs(x) u< -k -> fcmp uno x, x
-      switch (Pred) {
-      case FCmpInst::FCMP_OGT:
-      case FCmpInst::FCMP_OGE:
-        return {Src, ~fcNan, fcNan};
-      case FCmpInst::FCMP_UGT:
-      case FCmpInst::FCMP_UGE:
-        return {Src, fcAllFlags, fcNone};
-      case FCmpInst::FCMP_OLT:
-      case FCmpInst::FCMP_OLE:
-        return {Src, fcNone, fcAllFlags};
-      case FCmpInst::FCMP_ULT:
-      case FCmpInst::FCMP_ULE:
-        return {Src, fcNan, ~fcNan};
-      default:
-        break;
-      }
-
-      return {nullptr, fcAllFlags, fcAllFlags};
-    }
-
-    FPClassTest ClassesLE = fcNegInf | fcNegNormal;
-    FPClassTest ClassesGE = fcPositive | fcNegZero | fcNegSubnormal;
-
-    if (IsDenormalRHS)
-      ClassesLE |= fcNegSubnormal;
-    else
-      ClassesGE |= fcNegNormal;
-
-    switch (Pred) {
-    case FCmpInst::FCMP_OGT:
-    case FCmpInst::FCMP_OGE:
-      return {Src, ClassesGE, ~ClassesGE | RHSClass};
-    case FCmpInst::FCMP_UGT:
-    case FCmpInst::FCMP_UGE:
-      return {Src, ClassesGE | fcNan, ~(ClassesGE | fcNan) | RHSClass};
-    case FCmpInst::FCMP_OLT:
-    case FCmpInst::FCMP_OLE:
-      return {Src, ClassesLE, ~ClassesLE | RHSClass};
-    case FCmpInst::FCMP_ULT:
-    case FCmpInst::FCMP_ULE:
-      return {Src, ClassesLE | fcNan, ~(ClassesLE | fcNan) | RHSClass};
-    default:
-      break;
-    }
-  } else if (IsPositiveRHS) {
-    FPClassTest ClassesGE = fcPosNormal | fcPosInf;
-    FPClassTest ClassesLE = fcNegative | fcPosZero | fcPosSubnormal;
-    if (IsDenormalRHS)
-      ClassesGE |= fcPosSubnormal;
-    else
-      ClassesLE |= fcPosNormal;
-
-    if (IsFabs) {
-      ClassesGE = llvm::inverse_fabs(ClassesGE);
-      ClassesLE = llvm::inverse_fabs(ClassesLE);
-    }
-
-    switch (Pred) {
-    case FCmpInst::FCMP_OGT:
-    case FCmpInst::FCMP_OGE:
-      return {Src, ClassesGE, ~ClassesGE | RHSClass};
-    case FCmpInst::FCMP_UGT:
-    case FCmpInst::FCMP_UGE:
-      return {Src, ClassesGE | fcNan, ~(ClassesGE | fcNan) | RHSClass};
-    case FCmpInst::FCMP_OLT:
-    case FCmpInst::FCMP_OLE:
-      return {Src, ClassesLE, ~ClassesLE | RHSClass};
-    case FCmpInst::FCMP_ULT:
-    case FCmpInst::FCMP_ULE:
-      return {Src, ClassesLE | fcNan, ~(ClassesLE | fcNan) | RHSClass};
-    default:
-      break;
-    }
-  }
-
-  return {nullptr, fcAllFlags, fcAllFlags};
-}
-
-std::tuple<Value *, FPClassTest, FPClassTest>
-llvm::fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
-                       const APFloat &ConstRHS, bool LookThroughSrc) {
-  // We can refine checks against smallest normal / largest denormal to an
-  // exact class test.
-  if (!ConstRHS.isNegative() && ConstRHS.isSmallestNormalized()) {
-    Value *Src = LHS;
-    const bool IsFabs = LookThroughSrc && match(LHS, m_FAbs(m_Value(Src)));
-
-    FPClassTest Mask;
-    // Match pattern that's used in __builtin_isnormal.
-    switch (Pred) {
-    case FCmpInst::FCMP_OLT:
-    case FCmpInst::FCMP_UGE: {
-      // fcmp olt x, smallest_normal -> fcNegInf|fcNegNormal|fcSubnormal|fcZero
-      // fcmp olt fabs(x), smallest_normal -> fcSubnormal|fcZero
-      // fcmp uge x, smallest_normal -> fcNan|fcPosNormal|fcPosInf
-      // fcmp uge fabs(x), smallest_normal -> ~(fcSubnormal|fcZero)
-      Mask = fcZero | fcSubnormal;
-      if (!IsFabs)
-        Mask |= fcNegNormal | fcNegInf;
-
-      break;
-    }
-    case FCmpInst::FCMP_OGE:
-    case FCmpInst::FCMP_ULT: {
-      // fcmp oge x, smallest_normal -> fcPosNormal | fcPosInf
-      // fcmp oge fabs(x), smallest_normal -> fcInf | fcNormal
-      // fcmp ult x, smallest_normal -> ~(fcPosNormal | fcPosInf)
-      // fcmp ult fabs(x), smallest_normal -> ~(fcInf | fcNormal)
-      Mask = fcPosInf | fcPosNormal;
-      if (IsFabs)
-        Mask |= fcNegInf | fcNegNormal;
-      break;
-    }
-    default:
-      return fcmpImpliesClass(Pred, F, LHS, ConstRHS.classify(),
-                              LookThroughSrc);
-    }
-
-    // Invert the comparison for the unordered cases.
-    if (FCmpInst::isUnordered(Pred))
-      Mask = ~Mask;
-
-    return exactClass(Src, Mask);
-  }
-
-  return fcmpImpliesClass(Pred, F, LHS, ConstRHS.classify(), LookThroughSrc);
-}
-
-std::tuple<Value *, FPClassTest, FPClassTest>
-llvm::fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
-                       Value *RHS, bool LookThroughSrc) {
-  const APFloat *ConstRHS;
-  if (!match(RHS, m_APFloatAllowPoison(ConstRHS)))
-    return {nullptr, fcAllFlags, fcAllFlags};
-
-  // TODO: Just call computeKnownFPClass for RHS to handle non-constants.
-  return fcmpImpliesClass(Pred, F, LHS, *ConstRHS, LookThroughSrc);
-}
-
 static void computeKnownFPClassFromCond(const Value *V, Value *Cond,
                                         unsigned Depth, bool CondIsTrue,
                                         const Instruction *CxtI,
diff --git a/llvm/lib/CodeGen/CMakeLists.txt b/llvm/lib/CodeGen/CMakeLists.txt
index 5dd6413431255..dcf7c08d499e1 100644
--- a/llvm/lib/CodeGen/CMakeLists.txt
+++ b/llvm/lib/CodeGen/CMakeLists.txt
@@ -119,6 +119,7 @@ add_llvm_component_library(LLVMCodeGen
   MachineCycleAnalysis.cpp
   MachineDebugify.cpp
   MachineDomTreeUpdater.cpp
+  MachineFloatingPointPredicateUtils.cpp
   MachineDominanceFrontier.cpp
   MachineDominators.cpp
   MachineFrameInfo.cpp
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index 76f27623c8656..52263026d6cea 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -24,6 +24,7 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/BlockFrequencyInfo.h"
 #include "llvm/Analysis/BranchProbabilityInfo.h"
+#include "llvm/Analysis/FloatingPointPredicateUtils.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/ProfileSummaryInfo.h"
diff --git a/llvm/lib/CodeGen/MachineFloatingPointPredicateUtils.cpp b/llvm/lib/CodeGen/MachineFloatingPointPredicateUtils.cpp
new file mode 100644
index 0000000000000..c4c5ad61c0162
--- /dev/null
+++ b/llvm/lib/CodeGen/MachineFloatingPointPredicateUtils.cpp
@@ -0,0 +1,49 @@
+//===- MachineFloatingPointPredicateUtils.cpp -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGen/MachineFloatingPointPredicateUtils.h"
+#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
+#include "llvm/CodeGen/LowLevelTypeUtils.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/MachineSSAContext.h"
+#include "llvm/IR/Constants.h"
+#include <optional>
+
+namespace llvm {
+
+using namespace MIPatternMatch;
+
+/// Returns the denormal mode \p MF reports for the float semantics of the
+/// scalar type of register \p Val.
+template <>
+DenormalMode MachineFloatingPointPredicateUtils::queryDenormalMode(
+    const MachineFunction &MF, Register Val) {
+  const LLT ScalarTy = MF.getRegInfo().getType(Val).getScalarType();
+  return MF.getDenormalMode(getFltSemanticForLLT(ScalarTy));
+}
+
+/// Returns whether \p LHS matches m_GFabs(m_Reg(Src)) in \p MF's register info.
+template <>
+bool MachineFloatingPointPredicateUtils::lookThroughFAbs(
+    const MachineFunction &MF, Register LHS, Register &Src) {
+  return mi_match(LHS, MF.getRegInfo(), m_GFabs(m_Reg(Src)));
+}
+
+/// APFloat value of \p Val if it matches m_GFCst, else std::nullopt.
+template <>
+std::optional<APFloat> MachineFloatingPointPredicateUtils::matchConstantFloat(
+    const MachineFunction &MF, Register Val) {
+  const ConstantFP *FPConst;
+  const bool IsFPConstant = mi_match(Val, MF.getRegInfo(), m_GFCst(FPConst));
+  if (!IsFPConstant)
+    return std::nullopt;
+  return FPConst->getValueAPF();
+}
+
+} // namespace llvm
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 508aef63a3128..d90c22672a5ec 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -12,6 +12,7 @@
 
 #include "InstCombineInternal.h"
 #include "llvm/Analysis/CmpInstAnalysis.h"
+#include "llvm/Analysis/FloatingPointPredicateUtils.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/Intrinsics.h"
diff --git a/llvm/unittests/Analysis/ValueTrackingTest.cpp b/llvm/unittests/Analysis/ValueTrackingTest.cpp
index e1baa389bbc66..a5050542b8186 100644
--- a/llvm/unittests/Analysis/ValueTrackingTest.cpp
+++ b/llvm/unittests/Analysis/ValueTrackingTest.cpp
@@ -8,6 +8,7 @@
 
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/FloatingPointPredicateUtils.h"
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/Dominators.h"



More information about the llvm-commits mailing list