[llvm] PatternMatch: migrate to CmpPredicate (PR #118534)

Ramkumar Ramachandra via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 13 03:18:10 PST 2024


https://github.com/artagnon updated https://github.com/llvm/llvm-project/pull/118534

>From 2cd636a4eead1f42b0ebab3be563bee4b45da073 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Wed, 20 Nov 2024 16:49:37 +0000
Subject: [PATCH 1/2] PatternMatch: migrate to CmpPredicate

With the introduction of CmpPredicate in 51a895a (IR: introduce struct
with CmpInst::Predicate and samesign), PatternMatch is one of the first
key pieces of infrastructure that must be updated to match a CmpInst
while respecting samesign information. Implement this change in the
Cmp matchers.

This is a preparatory step in migrating the codebase over to
CmpPredicate. Since no functional changes are desired at this stage,
we have chosen not to migrate CmpPredicate::operator==(CmpPredicate)
calls to use CmpPredicate::getMatching(), as that would have a visible
impact on tests that are not yet written. Instead, we call
CmpPredicate::operator==(Predicate), preserving the old behavior, while
also inserting a few FIXME comments for follow-ups.
---
 llvm/include/llvm/IR/CmpPredicate.h           |  22 ++++
 llvm/include/llvm/IR/PatternMatch.h           | 108 +++++++++---------
 llvm/lib/Analysis/IVDescriptors.cpp           |   4 +-
 llvm/lib/Analysis/InstructionSimplify.cpp     |  16 +--
 llvm/lib/Analysis/OverflowInstAnalysis.cpp    |   2 +-
 llvm/lib/Analysis/ValueTracking.cpp           |  22 ++--
 llvm/lib/CodeGen/CodeGenPrepare.cpp           |   4 +-
 llvm/lib/CodeGen/ExpandMemCmp.cpp             |   2 +-
 llvm/lib/IR/Instructions.cpp                  |  21 ++++
 .../AArch64/AArch64TargetTransformInfo.cpp    |   2 +-
 .../Target/AMDGPU/AMDGPUCodeGenPrepare.cpp    |   2 +-
 .../AMDGPU/AMDGPUInstCombineIntrinsic.cpp     |   2 +-
 .../Hexagon/HexagonLoopIdiomRecognition.cpp   |   4 +-
 llvm/lib/Target/X86/X86ISelLowering.cpp       |   2 +-
 .../InstCombine/InstCombineAddSub.cpp         |   8 +-
 .../InstCombine/InstCombineAndOrXor.cpp       |  24 ++--
 .../InstCombine/InstCombineCompares.cpp       |  26 ++---
 .../InstCombine/InstCombineSelect.cpp         |  54 ++++-----
 .../InstCombine/InstCombineVectorOps.cpp      |   4 +-
 .../InstCombine/InstructionCombining.cpp      |   6 +-
 .../Transforms/Scalar/CallSiteSplitting.cpp   |   4 +-
 .../Scalar/ConstraintElimination.cpp          |   8 +-
 .../Scalar/DeadStoreElimination.cpp           |   2 +-
 llvm/lib/Transforms/Scalar/EarlyCSE.cpp       |   6 +-
 llvm/lib/Transforms/Scalar/GuardWidening.cpp  |   2 +-
 llvm/lib/Transforms/Scalar/JumpThreading.cpp  |   4 +-
 llvm/lib/Transforms/Scalar/LICM.cpp           |  11 +-
 llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp |   2 +-
 .../Transforms/Scalar/LoopIdiomRecognize.cpp  |   4 +-
 .../Transforms/Scalar/SimpleLoopUnswitch.cpp  |  10 +-
 llvm/lib/Transforms/Utils/LoopPeel.cpp        |   2 +-
 .../Utils/ScalarEvolutionExpander.cpp         |   2 +-
 llvm/lib/Transforms/Utils/SimplifyIndVar.cpp  |   8 +-
 .../Transforms/Vectorize/SLPVectorizer.cpp    |  16 +--
 .../Transforms/Vectorize/VectorCombine.cpp    |  10 +-
 llvm/unittests/IR/PatternMatch.cpp            |   4 +-
 36 files changed, 237 insertions(+), 193 deletions(-)

diff --git a/llvm/include/llvm/IR/CmpPredicate.h b/llvm/include/llvm/IR/CmpPredicate.h
index 4b1be7beb2b663a..ae027305cb4199d 100644
--- a/llvm/include/llvm/IR/CmpPredicate.h
+++ b/llvm/include/llvm/IR/CmpPredicate.h
@@ -24,6 +24,9 @@ class CmpPredicate {
   bool HasSameSign;
 
 public:
+  /// Default constructor.
+  CmpPredicate() : Pred(CmpInst::BAD_ICMP_PREDICATE), HasSameSign(false) {}
+
   /// Constructed implictly with a either Predicate and samesign information, or
   /// just a Predicate, dropping samesign information.
   CmpPredicate(CmpInst::Predicate Pred, bool HasSameSign = false)
@@ -52,11 +55,30 @@ class CmpPredicate {
 
   /// An operator== on the underlying Predicate.
   bool operator==(CmpInst::Predicate P) const { return Pred == P; }
+  bool operator!=(CmpInst::Predicate P) const { return Pred != P; }
 
   /// There is no operator== defined on CmpPredicate. Use getMatching instead to
   /// get the canonicalized matching CmpPredicate.
   bool operator==(CmpPredicate) const = delete;
+  bool operator!=(CmpPredicate) const = delete;
+
+  /// TypeSwitch over the CmpInst and either do ICmpInst::getCmpPredicate() or
+  /// FCmpInst::getPredicate().
+  static CmpPredicate get(const CmpInst *Cmp);
+
+  /// Get the swapped predicate of a CmpPredicate, using
+  /// CmpInst::isIntPredicate().
+  static CmpPredicate getSwapped(CmpPredicate P);
+
+  /// Get the swapped predicate of a CmpInst.
+  static CmpPredicate getSwapped(const CmpInst *Cmp);
+
+  /// Provided to facilitate storing a CmpPredicate in data structures that
+  /// require hashing.
+  friend hash_code hash_value(const CmpPredicate &Arg); // NOLINT
 };
+
+[[nodiscard]] hash_code hash_value(const CmpPredicate &Arg);
 } // namespace llvm
 
 #endif
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index fc4c0124d00b841..ed7e1eff6f0005b 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -688,14 +688,14 @@ inline api_pred_ty<is_lowbit_mask_or_zero> m_LowBitMaskOrZero(const APInt *&V) {
 }
 
 struct icmp_pred_with_threshold {
-  ICmpInst::Predicate Pred;
+  CmpPredicate Pred;
   const APInt *Thr;
   bool isValue(const APInt &C) { return ICmpInst::compare(C, *Thr, Pred); }
 };
 /// Match an integer or vector with every element comparing 'pred' (eg/ne/...)
 /// to Threshold. For vectors, this includes constants with undefined elements.
 inline cst_pred_ty<icmp_pred_with_threshold>
-m_SpecificInt_ICMP(ICmpInst::Predicate Predicate, const APInt &Threshold) {
+m_SpecificInt_ICMP(CmpPredicate Predicate, const APInt &Threshold) {
   cst_pred_ty<icmp_pred_with_threshold> P;
   P.Pred = Predicate;
   P.Thr = &Threshold;
@@ -1557,16 +1557,16 @@ template <typename T> inline Exact_match<T> m_Exact(const T &SubPattern) {
 // Matchers for CmpInst classes
 //
 
-template <typename LHS_t, typename RHS_t, typename Class, typename PredicateTy,
+template <typename LHS_t, typename RHS_t, typename Class,
           bool Commutable = false>
 struct CmpClass_match {
-  PredicateTy *Predicate;
+  CmpPredicate *Predicate;
   LHS_t L;
   RHS_t R;
 
   // The evaluation order is always stable, regardless of Commutability.
   // The LHS is always matched first.
-  CmpClass_match(PredicateTy &Pred, const LHS_t &LHS, const RHS_t &RHS)
+  CmpClass_match(CmpPredicate &Pred, const LHS_t &LHS, const RHS_t &RHS)
       : Predicate(&Pred), L(LHS), R(RHS) {}
   CmpClass_match(const LHS_t &LHS, const RHS_t &RHS)
       : Predicate(nullptr), L(LHS), R(RHS) {}
@@ -1575,12 +1575,13 @@ struct CmpClass_match {
     if (auto *I = dyn_cast<Class>(V)) {
       if (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) {
         if (Predicate)
-          *Predicate = I->getPredicate();
+          *Predicate = CmpPredicate::get(I);
         return true;
-      } else if (Commutable && L.match(I->getOperand(1)) &&
-                 R.match(I->getOperand(0))) {
+      }
+      if (Commutable && L.match(I->getOperand(1)) &&
+          R.match(I->getOperand(0))) {
         if (Predicate)
-          *Predicate = I->getSwappedPredicate();
+          *Predicate = CmpPredicate::getSwapped(I);
         return true;
       }
     }
@@ -1589,60 +1590,58 @@ struct CmpClass_match {
 };
 
 template <typename LHS, typename RHS>
-inline CmpClass_match<LHS, RHS, CmpInst, CmpInst::Predicate>
-m_Cmp(CmpInst::Predicate &Pred, const LHS &L, const RHS &R) {
-  return CmpClass_match<LHS, RHS, CmpInst, CmpInst::Predicate>(Pred, L, R);
+inline CmpClass_match<LHS, RHS, CmpInst> m_Cmp(CmpPredicate &Pred, const LHS &L,
+                                               const RHS &R) {
+  return CmpClass_match<LHS, RHS, CmpInst>(Pred, L, R);
 }
 
 template <typename LHS, typename RHS>
-inline CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate>
-m_ICmp(ICmpInst::Predicate &Pred, const LHS &L, const RHS &R) {
-  return CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate>(Pred, L, R);
+inline CmpClass_match<LHS, RHS, ICmpInst> m_ICmp(CmpPredicate &Pred,
+                                                 const LHS &L, const RHS &R) {
+  return CmpClass_match<LHS, RHS, ICmpInst>(Pred, L, R);
 }
 
 template <typename LHS, typename RHS>
-inline CmpClass_match<LHS, RHS, FCmpInst, FCmpInst::Predicate>
-m_FCmp(FCmpInst::Predicate &Pred, const LHS &L, const RHS &R) {
-  return CmpClass_match<LHS, RHS, FCmpInst, FCmpInst::Predicate>(Pred, L, R);
+inline CmpClass_match<LHS, RHS, FCmpInst> m_FCmp(CmpPredicate &Pred,
+                                                 const LHS &L, const RHS &R) {
+  return CmpClass_match<LHS, RHS, FCmpInst>(Pred, L, R);
 }
 
 template <typename LHS, typename RHS>
-inline CmpClass_match<LHS, RHS, CmpInst, CmpInst::Predicate>
-m_Cmp(const LHS &L, const RHS &R) {
-  return CmpClass_match<LHS, RHS, CmpInst, CmpInst::Predicate>(L, R);
+inline CmpClass_match<LHS, RHS, CmpInst> m_Cmp(const LHS &L, const RHS &R) {
+  return CmpClass_match<LHS, RHS, CmpInst>(L, R);
 }
 
 template <typename LHS, typename RHS>
-inline CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate>
-m_ICmp(const LHS &L, const RHS &R) {
-  return CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate>(L, R);
+inline CmpClass_match<LHS, RHS, ICmpInst> m_ICmp(const LHS &L, const RHS &R) {
+  return CmpClass_match<LHS, RHS, ICmpInst>(L, R);
 }
 
 template <typename LHS, typename RHS>
-inline CmpClass_match<LHS, RHS, FCmpInst, FCmpInst::Predicate>
-m_FCmp(const LHS &L, const RHS &R) {
-  return CmpClass_match<LHS, RHS, FCmpInst, FCmpInst::Predicate>(L, R);
+inline CmpClass_match<LHS, RHS, FCmpInst> m_FCmp(const LHS &L, const RHS &R) {
+  return CmpClass_match<LHS, RHS, FCmpInst>(L, R);
 }
 
 // Same as CmpClass, but instead of saving Pred as out output variable, match a
 // specific input pred for equality.
-template <typename LHS_t, typename RHS_t, typename Class, typename PredicateTy,
+template <typename LHS_t, typename RHS_t, typename Class,
           bool Commutable = false>
 struct SpecificCmpClass_match {
-  const PredicateTy Predicate;
+  const CmpPredicate Predicate;
   LHS_t L;
   RHS_t R;
 
-  SpecificCmpClass_match(PredicateTy Pred, const LHS_t &LHS, const RHS_t &RHS)
+  SpecificCmpClass_match(CmpPredicate Pred, const LHS_t &LHS, const RHS_t &RHS)
       : Predicate(Pred), L(LHS), R(RHS) {}
 
   template <typename OpTy> bool match(OpTy *V) {
     if (auto *I = dyn_cast<Class>(V)) {
-      if (I->getPredicate() == Predicate && L.match(I->getOperand(0)) &&
-          R.match(I->getOperand(1)))
+      if (CmpPredicate::getMatching(CmpPredicate::get(I), Predicate) &&
+          L.match(I->getOperand(0)) && R.match(I->getOperand(1)))
         return true;
       if constexpr (Commutable) {
-        if (I->getPredicate() == Class::getSwappedPredicate(Predicate) &&
+        if (CmpPredicate::getMatching(CmpPredicate::get(I),
+                                      CmpPredicate::getSwapped(Predicate)) &&
             L.match(I->getOperand(1)) && R.match(I->getOperand(0)))
           return true;
       }
@@ -1653,31 +1652,27 @@ struct SpecificCmpClass_match {
 };
 
 template <typename LHS, typename RHS>
-inline SpecificCmpClass_match<LHS, RHS, CmpInst, CmpInst::Predicate>
-m_SpecificCmp(CmpInst::Predicate MatchPred, const LHS &L, const RHS &R) {
-  return SpecificCmpClass_match<LHS, RHS, CmpInst, CmpInst::Predicate>(
-      MatchPred, L, R);
+inline SpecificCmpClass_match<LHS, RHS, CmpInst>
+m_SpecificCmp(CmpPredicate MatchPred, const LHS &L, const RHS &R) {
+  return SpecificCmpClass_match<LHS, RHS, CmpInst>(MatchPred, L, R);
 }
 
 template <typename LHS, typename RHS>
-inline SpecificCmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate>
-m_SpecificICmp(ICmpInst::Predicate MatchPred, const LHS &L, const RHS &R) {
-  return SpecificCmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate>(
-      MatchPred, L, R);
+inline SpecificCmpClass_match<LHS, RHS, ICmpInst>
+m_SpecificICmp(CmpPredicate MatchPred, const LHS &L, const RHS &R) {
+  return SpecificCmpClass_match<LHS, RHS, ICmpInst>(MatchPred, L, R);
 }
 
 template <typename LHS, typename RHS>
-inline SpecificCmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate, true>
-m_c_SpecificICmp(ICmpInst::Predicate MatchPred, const LHS &L, const RHS &R) {
-  return SpecificCmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate, true>(
-      MatchPred, L, R);
+inline SpecificCmpClass_match<LHS, RHS, ICmpInst, true>
+m_c_SpecificICmp(CmpPredicate MatchPred, const LHS &L, const RHS &R) {
+  return SpecificCmpClass_match<LHS, RHS, ICmpInst, true>(MatchPred, L, R);
 }
 
 template <typename LHS, typename RHS>
-inline SpecificCmpClass_match<LHS, RHS, FCmpInst, FCmpInst::Predicate>
-m_SpecificFCmp(FCmpInst::Predicate MatchPred, const LHS &L, const RHS &R) {
-  return SpecificCmpClass_match<LHS, RHS, FCmpInst, FCmpInst::Predicate>(
-      MatchPred, L, R);
+inline SpecificCmpClass_match<LHS, RHS, FCmpInst>
+m_SpecificFCmp(CmpPredicate MatchPred, const LHS &L, const RHS &R) {
+  return SpecificCmpClass_match<LHS, RHS, FCmpInst>(MatchPred, L, R);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2468,7 +2463,7 @@ struct UAddWithOverflow_match {
 
   template <typename OpTy> bool match(OpTy *V) {
     Value *ICmpLHS, *ICmpRHS;
-    ICmpInst::Predicate Pred;
+    CmpPredicate Pred;
     if (!m_ICmp(Pred, m_Value(ICmpLHS), m_Value(ICmpRHS)).match(V))
       return false;
 
@@ -2738,16 +2733,15 @@ inline AnyBinaryOp_match<LHS, RHS, true> m_c_BinOp(const LHS &L, const RHS &R) {
 /// Matches an ICmp with a predicate over LHS and RHS in either order.
 /// Swaps the predicate if operands are commuted.
 template <typename LHS, typename RHS>
-inline CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate, true>
-m_c_ICmp(ICmpInst::Predicate &Pred, const LHS &L, const RHS &R) {
-  return CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate, true>(Pred, L,
-                                                                       R);
+inline CmpClass_match<LHS, RHS, ICmpInst, true>
+m_c_ICmp(CmpPredicate &Pred, const LHS &L, const RHS &R) {
+  return CmpClass_match<LHS, RHS, ICmpInst, true>(Pred, L, R);
 }
 
 template <typename LHS, typename RHS>
-inline CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate, true>
-m_c_ICmp(const LHS &L, const RHS &R) {
-  return CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate, true>(L, R);
+inline CmpClass_match<LHS, RHS, ICmpInst, true> m_c_ICmp(const LHS &L,
+                                                         const RHS &R) {
+  return CmpClass_match<LHS, RHS, ICmpInst, true>(L, R);
 }
 
 /// Matches a specific opcode with LHS and RHS in either order.
diff --git a/llvm/lib/Analysis/IVDescriptors.cpp b/llvm/lib/Analysis/IVDescriptors.cpp
index e1eb219cf977e19..9670ec7f043944d 100644
--- a/llvm/lib/Analysis/IVDescriptors.cpp
+++ b/llvm/lib/Analysis/IVDescriptors.cpp
@@ -628,7 +628,7 @@ RecurrenceDescriptor::isAnyOfPattern(Loop *Loop, PHINode *OrigPhi,
                                      Instruction *I, InstDesc &Prev) {
   // We must handle the select(cmp(),x,y) as a single instruction. Advance to
   // the select.
-  CmpInst::Predicate Pred;
+  CmpPredicate Pred;
   if (match(I, m_OneUse(m_Cmp(Pred, m_Value(), m_Value())))) {
     if (auto *Select = dyn_cast<SelectInst>(*I->user_begin()))
       return InstDesc(Select, Prev.getRecKind());
@@ -668,7 +668,7 @@ RecurrenceDescriptor::isMinMaxPattern(Instruction *I, RecurKind Kind,
 
   // We must handle the select(cmp()) as a single instruction. Advance to the
   // select.
-  CmpInst::Predicate Pred;
+  CmpPredicate Pred;
   if (match(I, m_OneUse(m_Cmp(Pred, m_Value(), m_Value())))) {
     if (auto *Select = dyn_cast<SelectInst>(*I->user_begin()))
       return InstDesc(Select, Prev.getRecKind());
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 62edea38745b13d..3325cd972cf1eb3 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -1500,12 +1500,12 @@ static Value *simplifyUnsignedRangeCheck(ICmpInst *ZeroICmp,
                                          const SimplifyQuery &Q) {
   Value *X, *Y;
 
-  ICmpInst::Predicate EqPred;
+  CmpPredicate EqPred;
   if (!match(ZeroICmp, m_ICmp(EqPred, m_Value(Y), m_Zero())) ||
       !ICmpInst::isEquality(EqPred))
     return nullptr;
 
-  ICmpInst::Predicate UnsignedPred;
+  CmpPredicate UnsignedPred;
 
   Value *A, *B;
   // Y = (A - B);
@@ -1644,7 +1644,7 @@ static Value *simplifyAndOrOfICmpsWithConstants(ICmpInst *Cmp0, ICmpInst *Cmp1,
 static Value *simplifyAndOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1,
                                         const InstrInfoQuery &IIQ) {
   // (icmp (add V, C0), C1) & (icmp V, C0)
-  ICmpInst::Predicate Pred0, Pred1;
+  CmpPredicate Pred0, Pred1;
   const APInt *C0, *C1;
   Value *V;
   if (!match(Op0, m_ICmp(Pred0, m_Add(m_Value(V), m_APInt(C0)), m_APInt(C1))))
@@ -1691,7 +1691,7 @@ static Value *simplifyAndOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1,
 /// Try to simplify and/or of icmp with ctpop intrinsic.
 static Value *simplifyAndOrOfICmpsWithCtpop(ICmpInst *Cmp0, ICmpInst *Cmp1,
                                             bool IsAnd) {
-  ICmpInst::Predicate Pred0, Pred1;
+  CmpPredicate Pred0, Pred1;
   Value *X;
   const APInt *C;
   if (!match(Cmp0, m_ICmp(Pred0, m_Intrinsic<Intrinsic::ctpop>(m_Value(X)),
@@ -1735,7 +1735,7 @@ static Value *simplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1,
 static Value *simplifyOrOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1,
                                        const InstrInfoQuery &IIQ) {
   // (icmp (add V, C0), C1) | (icmp V, C0)
-  ICmpInst::Predicate Pred0, Pred1;
+  CmpPredicate Pred0, Pred1;
   const APInt *C0, *C1;
   Value *V;
   if (!match(Op0, m_ICmp(Pred0, m_Add(m_Value(V), m_APInt(C0)), m_APInt(C1))))
@@ -1891,7 +1891,7 @@ static Value *simplifyAndOrWithICmpEq(unsigned Opcode, Value *Op0, Value *Op1,
                                       unsigned MaxRecurse) {
   assert((Opcode == Instruction::And || Opcode == Instruction::Or) &&
          "Must be and/or");
-  ICmpInst::Predicate Pred;
+  CmpPredicate Pred;
   Value *A, *B;
   if (!match(Op0, m_ICmp(Pred, m_Value(A), m_Value(B))) ||
       !ICmpInst::isEquality(Pred))
@@ -4614,7 +4614,7 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
                                          Value *FalseVal,
                                          const SimplifyQuery &Q,
                                          unsigned MaxRecurse) {
-  ICmpInst::Predicate Pred;
+  CmpPredicate Pred;
   Value *CmpLHS, *CmpRHS;
   if (!match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS))))
     return nullptr;
@@ -4738,7 +4738,7 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
 static Value *simplifySelectWithFCmp(Value *Cond, Value *T, Value *F,
                                      const SimplifyQuery &Q,
                                      unsigned MaxRecurse) {
-  FCmpInst::Predicate Pred;
+  CmpPredicate Pred;
   Value *CmpLHS, *CmpRHS;
   if (!match(Cond, m_FCmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS))))
     return nullptr;
diff --git a/llvm/lib/Analysis/OverflowInstAnalysis.cpp b/llvm/lib/Analysis/OverflowInstAnalysis.cpp
index 8bfd6642f76027a..40f71f4a8db46d4 100644
--- a/llvm/lib/Analysis/OverflowInstAnalysis.cpp
+++ b/llvm/lib/Analysis/OverflowInstAnalysis.cpp
@@ -20,7 +20,7 @@ using namespace llvm::PatternMatch;
 
 bool llvm::isCheckForZeroAndMulWithOverflow(Value *Op0, Value *Op1, bool IsAnd,
                                             Use *&Y) {
-  ICmpInst::Predicate Pred;
+  CmpPredicate Pred;
   Value *X, *NotOp1;
   int XIdx;
   IntrinsicInst *II;
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index f2c6949e535d2a9..4b940557903daff 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -259,7 +259,7 @@ bool llvm::isOnlyUsedInZeroComparison(const Instruction *I) {
 
 bool llvm::isOnlyUsedInZeroEqualityComparison(const Instruction *I) {
   return !I->user_empty() && all_of(I->users(), [](const User *U) {
-    ICmpInst::Predicate P;
+    CmpPredicate P;
     return match(U, m_ICmp(P, m_Value(), m_Zero())) && ICmpInst::isEquality(P);
   });
 }
@@ -614,7 +614,7 @@ static bool isKnownNonZeroFromAssume(const Value *V, const SimplifyQuery &Q) {
     // runtime of ~O(#assumes * #values).
 
     Value *RHS;
-    CmpInst::Predicate Pred;
+    CmpPredicate Pred;
     auto m_V = m_CombineOr(m_Specific(V), m_PtrToInt(m_Specific(V)));
     if (!match(I->getArgOperand(0), m_c_ICmp(Pred, m_V, m_Value(RHS))))
       continue;
@@ -1602,7 +1602,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
         // See if we can further use a conditional branch into the phi
         // to help us determine the range of the value.
         if (!Known2.isConstant()) {
-          ICmpInst::Predicate Pred;
+          CmpPredicate Pred;
           const APInt *RHSC;
           BasicBlock *TrueSucc, *FalseSucc;
           // TODO: Use RHS Value and compute range from its known bits.
@@ -2255,7 +2255,7 @@ static bool isPowerOfTwoRecurrence(const PHINode *PN, bool OrZero,
 static bool isImpliedToBeAPowerOfTwoFromCond(const Value *V, bool OrZero,
                                              const Value *Cond,
                                              bool CondIsTrue) {
-  ICmpInst::Predicate Pred;
+  CmpPredicate Pred;
   const APInt *RHSC;
   if (!match(Cond, m_ICmp(Pred, m_Intrinsic<Intrinsic::ctpop>(m_Specific(V)),
                           m_APInt(RHSC))))
@@ -2580,7 +2580,7 @@ static bool isKnownNonNullFromDominatingCondition(const Value *V,
 
     // Consider only compare instructions uniquely controlling a branch
     Value *RHS;
-    CmpInst::Predicate Pred;
+    CmpPredicate Pred;
     if (!match(U, m_c_ICmp(Pred, m_Specific(V), m_Value(RHS))))
       continue;
 
@@ -3009,7 +3009,7 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
       // The condition of the select dominates the true/false arm. Check if the
       // condition implies that a given arm is non-zero.
       Value *X;
-      CmpInst::Predicate Pred;
+      CmpPredicate Pred;
       if (!match(I->getOperand(0), m_c_ICmp(Pred, m_Specific(Op), m_Value(X))))
         return false;
 
@@ -3037,7 +3037,7 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
         return true;
       RecQ.CxtI = PN->getIncomingBlock(U)->getTerminator();
       // Check if the branch on the phi excludes zero.
-      ICmpInst::Predicate Pred;
+      CmpPredicate Pred;
       Value *X;
       BasicBlock *TrueSucc, *FalseSucc;
       if (match(RecQ.CxtI,
@@ -4895,7 +4895,7 @@ static void computeKnownFPClassFromCond(const Value *V, Value *Cond,
                                 KnownFromContext);
     return;
   }
-  CmpInst::Predicate Pred;
+  CmpPredicate Pred;
   Value *LHS;
   uint64_t ClassVal = 0;
   const APFloat *CRHS;
@@ -5135,7 +5135,7 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
     FPClassTest MaskIfFalse = fcAllFlags;
     uint64_t ClassVal = 0;
     const Function *F = cast<Instruction>(Op)->getFunction();
-    CmpInst::Predicate Pred;
+    CmpPredicate Pred;
     Value *CmpLHS, *CmpRHS;
     if (F && match(Cond, m_FCmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
       // If the select filters out a value based on the class, it no longer
@@ -8571,7 +8571,7 @@ bool llvm::isKnownNegation(const Value *X, const Value *Y, bool NeedNSW,
 bool llvm::isKnownInversion(const Value *X, const Value *Y) {
   // Handle X = icmp pred A, B, Y = icmp pred A, C.
   Value *A, *B, *C;
-  ICmpInst::Predicate Pred1, Pred2;
+  CmpPredicate Pred1, Pred2;
   if (!match(X, m_ICmp(Pred1, m_Value(A), m_Value(B))) ||
       !match(Y, m_c_ICmp(Pred2, m_Specific(A), m_Value(C))))
     return false;
@@ -10054,7 +10054,7 @@ void llvm::findValuesAffectedByCondition(
     if (!Visited.insert(V).second)
       continue;
 
-    CmpInst::Predicate Pred;
+    CmpPredicate Pred;
     Value *A, *B, *X;
 
     if (IsAssume) {
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index 83c6ecd401039f9..5c712e4f007d392 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -1885,7 +1885,7 @@ static bool foldICmpWithDominatingICmp(CmpInst *Cmp,
     return false;
 
   Value *CmpOp0 = Cmp->getOperand(0), *CmpOp1 = Cmp->getOperand(1);
-  ICmpInst::Predicate DomPred;
+  CmpPredicate DomPred;
   if (!match(DomCond, m_ICmp(DomPred, m_Specific(CmpOp0), m_Specific(CmpOp1))))
     return false;
   if (DomPred != ICmpInst::ICMP_SGT && DomPred != ICmpInst::ICMP_SLT)
@@ -2155,7 +2155,7 @@ bool CodeGenPrepare::optimizeURem(Instruction *Rem) {
 static bool adjustIsPower2Test(CmpInst *Cmp, const TargetLowering &TLI,
                                const TargetTransformInfo &TTI,
                                const DataLayout &DL) {
-  ICmpInst::Predicate Pred;
+  CmpPredicate Pred;
   if (!match(Cmp, m_ICmp(Pred, m_Intrinsic<Intrinsic::ctpop>(), m_One())))
     return false;
   if (!ICmpInst::isEquality(Pred))
diff --git a/llvm/lib/CodeGen/ExpandMemCmp.cpp b/llvm/lib/CodeGen/ExpandMemCmp.cpp
index a1acb4ef3683805..f8ca7e370f6ef9c 100644
--- a/llvm/lib/CodeGen/ExpandMemCmp.cpp
+++ b/llvm/lib/CodeGen/ExpandMemCmp.cpp
@@ -668,7 +668,7 @@ Value *MemCmpExpansion::getMemCmpOneBlock() {
   // We can generate more optimal code with a smaller number of operations
   if (CI->hasOneUser()) {
     auto *UI = cast<Instruction>(*CI->user_begin());
-    ICmpInst::Predicate Pred = ICmpInst::Predicate::BAD_ICMP_PREDICATE;
+    CmpPredicate Pred = ICmpInst::Predicate::BAD_ICMP_PREDICATE;
     uint64_t Shift;
     bool NeedsZExt = false;
     // This is a special case because instead of checking if the result is less
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 4f07a4c4dd017a6..d27863574dcf25a 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -16,6 +16,7 @@
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Twine.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constant.h"
@@ -3932,6 +3933,26 @@ std::optional<CmpPredicate> CmpPredicate::getMatching(CmpPredicate A,
   return {};
 }
 
+CmpPredicate CmpPredicate::get(const CmpInst *Cmp) {
+  return TypeSwitch<const CmpInst *, CmpPredicate>(Cmp)
+      .Case<ICmpInst>([](auto *ICI) { return ICI->getCmpPredicate(); })
+      .Case<FCmpInst>([](auto *FCI) { return FCI->getPredicate(); });
+}
+
+CmpPredicate CmpPredicate::getSwapped(CmpPredicate P) {
+  return CmpInst::isIntPredicate(P)
+             ? ICmpInst::getSwappedCmpPredicate(P)
+             : CmpPredicate{CmpInst::getSwappedPredicate(P)};
+}
+
+CmpPredicate CmpPredicate::getSwapped(const CmpInst *Cmp) {
+  return getSwapped(get(Cmp));
+}
+
+hash_code llvm::hash_value(const CmpPredicate &Arg) { // NOLINT
+  return hash_combine(Arg.Pred, Arg.HasSameSign);
+}
+
 //===----------------------------------------------------------------------===//
 //                        SwitchInst Implementation
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 283fe4c3caa6022..73896ed5033c126 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3605,7 +3605,7 @@ InstructionCost AArch64TTIImpl::getCmpSelInstrCost(
     // If VecPred is not set, check if we can get a predicate from the context
     // instruction, if its type matches the requested ValTy.
     if (VecPred == CmpInst::BAD_ICMP_PREDICATE && I && I->getType() == ValTy) {
-      CmpInst::Predicate CurrentPred;
+      CmpPredicate CurrentPred;
       if (match(I, m_Select(m_Cmp(CurrentPred, m_Value(), m_Value()), m_Value(),
                             m_Value())))
         VecPred = CurrentPred;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp b/llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp
index 75e20c793016815..2523e369985b0eb 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp
@@ -1688,7 +1688,7 @@ bool AMDGPUCodeGenPrepareImpl::visitSelectInst(SelectInst &I) {
   Value *TrueVal = I.getTrueValue();
   Value *FalseVal = I.getFalseValue();
   Value *CmpVal;
-  FCmpInst::Predicate Pred;
+  CmpPredicate Pred;
 
   if (ST.has16BitInsts() && needsPromotionToI32(I.getType())) {
     if (UA.isUniform(&I))
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
index 41b33ac8a7eb4b6..8b1b398606583e9 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
@@ -960,7 +960,7 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
       return &II;
     }
 
-    CmpInst::Predicate SrcPred;
+    CmpPredicate SrcPred;
     Value *SrcLHS;
     Value *SrcRHS;
 
diff --git a/llvm/lib/Target/Hexagon/HexagonLoopIdiomRecognition.cpp b/llvm/lib/Target/Hexagon/HexagonLoopIdiomRecognition.cpp
index 705e1f43851f7ad..46a8ab395d32bdc 100644
--- a/llvm/lib/Target/Hexagon/HexagonLoopIdiomRecognition.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonLoopIdiomRecognition.cpp
@@ -695,7 +695,7 @@ bool PolynomialMultiplyRecognize::matchLeftShift(SelectInst *SelI,
 
   using namespace PatternMatch;
 
-  CmpInst::Predicate P;
+  CmpPredicate P;
   Value *A = nullptr, *B = nullptr, *C = nullptr;
 
   if (!match(CondV, m_ICmp(P, m_And(m_Value(A), m_Value(B)), m_Value(C))) &&
@@ -810,7 +810,7 @@ bool PolynomialMultiplyRecognize::matchRightShift(SelectInst *SelI,
   using namespace PatternMatch;
 
   Value *C = nullptr;
-  CmpInst::Predicate P;
+  CmpPredicate P;
   bool TrueIfZero;
 
   if (match(CondV, m_c_ICmp(P, m_Value(C), m_Zero()))) {
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 871de16d66b6c5c..83e4bcddaee1068 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -31477,7 +31477,7 @@ static bool shouldExpandCmpArithRMWInIR(AtomicRMWInst *AI) {
     return false;
 
   Value *Op = AI->getOperand(1);
-  ICmpInst::Predicate Pred;
+  CmpPredicate Pred;
   Instruction *I = AI->user_back();
   AtomicRMWInst::BinOp Opc = AI->getOperation();
   if (Opc == AtomicRMWInst::Add) {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index fe0d88fcc6ee4b1..7a184a19d7c54ab 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1289,7 +1289,7 @@ static Instruction *foldAddToAshr(BinaryOperator &Add) {
   // Note that, by the time we end up here, if possible, ugt has been
   // canonicalized into eq.
   const APInt *MaskC, *MaskCCmp;
-  ICmpInst::Predicate Pred;
+  CmpPredicate Pred;
   if (!match(Add.getOperand(1),
              m_SExt(m_ICmp(Pred, m_And(m_Specific(X), m_APInt(MaskC)),
                            m_APInt(MaskCCmp)))))
@@ -1382,7 +1382,7 @@ Instruction *InstCombinerImpl::
   // `select` itself may be appropriately extended, look past that.
   SkipExtInMagic(Select);
 
-  ICmpInst::Predicate Pred;
+  CmpPredicate Pred;
   const APInt *Thr;
   Value *SignExtendingValue, *Zero;
   bool ShouldSignext;
@@ -1654,7 +1654,7 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
     return replaceInstUsesWith(I, Constant::getNullValue(I.getType()));
 
   // sext(A < B) + zext(A > B) => ucmp/scmp(A, B)
-  ICmpInst::Predicate LTPred, GTPred;
+  CmpPredicate LTPred, GTPred;
   if (match(&I,
             m_c_Add(m_SExt(m_c_ICmp(LTPred, m_Value(A), m_Value(B))),
                     m_ZExt(m_c_ICmp(GTPred, m_Deferred(A), m_Deferred(B))))) &&
@@ -1841,7 +1841,7 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
   // -->
   // BW - ctlz(A - 1, false)
   const APInt *XorC;
-  ICmpInst::Predicate Pred;
+  CmpPredicate Pred;
   if (match(&I,
             m_c_Add(
                 m_ZExt(m_ICmp(Pred, m_Intrinsic<Intrinsic::ctpop>(m_Value(A)),
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 314b1f0b43e3b59..dff9304be64ddb0 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -738,7 +738,7 @@ static Value *
 foldAndOrOfICmpsWithPow2AndWithZero(InstCombiner::BuilderTy &Builder,
                                     ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
                                     const SimplifyQuery &Q) {
-  CmpInst::Predicate Pred = IsAnd ? CmpInst::ICMP_NE : CmpInst::ICMP_EQ;
+  CmpPredicate Pred = IsAnd ? CmpInst::ICMP_NE : CmpInst::ICMP_EQ;
   // Make sure we have right compares for our op.
   if (LHS->getPredicate() != Pred || RHS->getPredicate() != Pred)
     return nullptr;
@@ -875,7 +875,7 @@ static Value *foldSignedTruncationCheck(ICmpInst *ICmp0, ICmpInst *ICmp1,
   // Try to match/decompose into:  icmp eq (X & Mask), 0
   auto tryToDecompose = [](ICmpInst *ICmp, Value *&X,
                            APInt &UnsetBitsMask) -> bool {
-    CmpInst::Predicate Pred = ICmp->getPredicate();
+    CmpPredicate Pred = ICmp->getPredicate();
     // Can it be decomposed into  icmp eq (X & Mask), 0  ?
     auto Res =
         llvm::decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
@@ -944,7 +944,7 @@ static Value *foldSignedTruncationCheck(ICmpInst *ICmp0, ICmpInst *ICmp1,
 static Value *foldIsPowerOf2OrZero(ICmpInst *Cmp0, ICmpInst *Cmp1, bool IsAnd,
                                    InstCombiner::BuilderTy &Builder,
                                    InstCombinerImpl &IC) {
-  CmpInst::Predicate Pred0, Pred1;
+  CmpPredicate Pred0, Pred1;
   Value *X;
   if (!match(Cmp0, m_ICmp(Pred0, m_Intrinsic<Intrinsic::ctpop>(m_Value(X)),
                           m_SpecificInt(1))) ||
@@ -1117,12 +1117,12 @@ static Value *foldUnsignedUnderflowCheck(ICmpInst *ZeroICmp,
                                          const SimplifyQuery &Q,
                                          InstCombiner::BuilderTy &Builder) {
   Value *ZeroCmpOp;
-  ICmpInst::Predicate EqPred;
+  CmpPredicate EqPred;
   if (!match(ZeroICmp, m_ICmp(EqPred, m_Value(ZeroCmpOp), m_Zero())) ||
       !ICmpInst::isEquality(EqPred))
     return nullptr;
 
-  ICmpInst::Predicate UnsignedPred;
+  CmpPredicate UnsignedPred;
 
   Value *A, *B;
   if (match(UnsignedICmp,
@@ -1281,7 +1281,7 @@ static Value *foldAndOrOfICmpsWithConstEq(ICmpInst *Cmp0, ICmpInst *Cmp1,
                                           const SimplifyQuery &Q) {
   // Match an equality compare with a non-poison constant as Cmp0.
   // Also, give up if the compare can be constant-folded to avoid looping.
-  ICmpInst::Predicate Pred0;
+  CmpPredicate Pred0;
   Value *X;
   Constant *C;
   if (!match(Cmp0, m_ICmp(Pred0, m_Value(X), m_Constant(C))) ||
@@ -1295,7 +1295,7 @@ static Value *foldAndOrOfICmpsWithConstEq(ICmpInst *Cmp0, ICmpInst *Cmp1,
   // common operand as operand 1 (Pred1 is swapped if the common operand was
   // operand 0).
   Value *Y;
-  ICmpInst::Predicate Pred1;
+  CmpPredicate Pred1;
   if (!match(Cmp1, m_c_ICmp(Pred1, m_Value(Y), m_Specific(X))))
     return nullptr;
 
@@ -1326,7 +1326,7 @@ static Value *foldAndOrOfICmpsWithConstEq(ICmpInst *Cmp0, ICmpInst *Cmp1,
 Value *InstCombinerImpl::foldAndOrOfICmpsUsingRanges(ICmpInst *ICmp1,
                                                      ICmpInst *ICmp2,
                                                      bool IsAnd) {
-  ICmpInst::Predicate Pred1, Pred2;
+  CmpPredicate Pred1, Pred2;
   Value *V1, *V2;
   const APInt *C1, *C2;
   if (!match(ICmp1, m_ICmp(Pred1, m_Value(V1), m_APInt(C1))) ||
@@ -1348,12 +1348,12 @@ Value *InstCombinerImpl::foldAndOrOfICmpsUsingRanges(ICmpInst *ICmp1,
     return nullptr;
 
   ConstantRange CR1 = ConstantRange::makeExactICmpRegion(
-      IsAnd ? ICmpInst::getInversePredicate(Pred1) : Pred1, *C1);
+      IsAnd ? ICmpInst::getInverseCmpPredicate(Pred1) : Pred1, *C1);
   if (Offset1)
     CR1 = CR1.subtract(*Offset1);
 
   ConstantRange CR2 = ConstantRange::makeExactICmpRegion(
-      IsAnd ? ICmpInst::getInversePredicate(Pred2) : Pred2, *C2);
+      IsAnd ? ICmpInst::getInverseCmpPredicate(Pred2) : Pred2, *C2);
   if (Offset2)
     CR2 = CR2.subtract(*Offset2);
 
@@ -3943,7 +3943,7 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
           canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(I))
     return V;
 
-  CmpInst::Predicate Pred;
+  CmpPredicate Pred;
   Value *Mul, *Ov, *MulIsNotZero, *UMulWithOv;
   // Check if the OR weakens the overflow condition for umul.with.overflow by
   // treating any non-zero result as overflow. In that case, we overflow if both
@@ -4608,7 +4608,7 @@ Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) {
   }
 
   // not (cmp A, B) = !cmp A, B
-  CmpInst::Predicate Pred;
+  CmpPredicate Pred;
   if (match(NotOp, m_Cmp(Pred, m_Value(), m_Value())) &&
       (NotOp->hasOneUse() ||
        InstCombiner::canFreelyInvertAllUsersOf(cast<Instruction>(NotOp),
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 54053c4c9e28e8f..d6fdade25559feb 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1173,7 +1173,7 @@ Instruction *InstCombinerImpl::foldIRemByPowerOfTwoToBitTest(ICmpInst &I) {
   // This fold is only valid for equality predicates.
   if (!I.isEquality())
     return nullptr;
-  ICmpInst::Predicate Pred;
+  CmpPredicate Pred;
   Value *X, *Y, *Zero;
   if (!match(&I, m_ICmp(Pred, m_OneUse(m_IRem(m_Value(X), m_Value(Y))),
                         m_CombineAnd(m_Zero(), m_Value(Zero)))))
@@ -1190,7 +1190,7 @@ Instruction *InstCombinerImpl::foldIRemByPowerOfTwoToBitTest(ICmpInst &I) {
 /// by one-less-than-bitwidth into a sign test on the original value.
 Instruction *InstCombinerImpl::foldSignBitTest(ICmpInst &I) {
   Instruction *Val;
-  ICmpInst::Predicate Pred;
+  CmpPredicate Pred;
   if (!I.isEquality() || !match(&I, m_ICmp(Pred, m_Instruction(Val), m_Zero())))
     return nullptr;
 
@@ -1404,7 +1404,7 @@ Instruction *InstCombinerImpl::foldICmpWithDominatingICmp(ICmpInst &Cmp) {
   };
 
   for (BranchInst *BI : DC.conditionsFor(X)) {
-    ICmpInst::Predicate DomPred;
+    CmpPredicate DomPred;
     const APInt *DomC;
     if (!match(BI->getCondition(),
                m_ICmp(DomPred, m_Specific(X), m_APInt(DomC))))
@@ -1517,7 +1517,7 @@ Instruction *
 InstCombinerImpl::foldICmpTruncWithTruncOrExt(ICmpInst &Cmp,
                                               const SimplifyQuery &Q) {
   Value *X, *Y;
-  ICmpInst::Predicate Pred;
+  CmpPredicate Pred;
   bool YIsSExt = false;
   // Try to match icmp (trunc X), (trunc Y)
   if (match(&Cmp, m_ICmp(Pred, m_Trunc(m_Value(X)), m_Trunc(m_Value(Y))))) {
@@ -3249,7 +3249,7 @@ bool InstCombinerImpl::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS,
   //        i32 Equal,
   //        i32 (select i1 (a < b), i32 Less, i32 Greater)
   // where Equal, Less and Greater are placeholders for any three constants.
-  ICmpInst::Predicate PredA;
+  CmpPredicate PredA;
   if (!match(SI->getCondition(), m_ICmp(PredA, m_Value(LHS), m_Value(RHS))) ||
       !ICmpInst::isEquality(PredA))
     return false;
@@ -3260,7 +3260,7 @@ bool InstCombinerImpl::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS,
     std::swap(EqualVal, UnequalVal);
   if (!match(EqualVal, m_ConstantInt(Equal)))
     return false;
-  ICmpInst::Predicate PredB;
+  CmpPredicate PredB;
   Value *LHS2, *RHS2;
   if (!match(UnequalVal, m_Select(m_ICmp(PredB, m_Value(LHS2), m_Value(RHS2)),
                                   m_ConstantInt(Less), m_ConstantInt(Greater))))
@@ -4565,7 +4565,7 @@ static Value *foldICmpWithLowBitMaskedVal(CmpPredicate Pred, Value *Op0,
 static Value *
 foldICmpWithTruncSignExtendedVal(ICmpInst &I,
                                  InstCombiner::BuilderTy &Builder) {
-  ICmpInst::Predicate SrcPred;
+  CmpPredicate SrcPred;
   Value *X;
   const APInt *C0, *C1; // FIXME: non-splats, potentially with undef.
   // We are ok with 'shl' having multiple uses, but 'ashr' must be one-use.
@@ -4811,7 +4811,7 @@ foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ,
 /// Note that the comparison is commutative, while inverted (u>=, ==) predicate
 /// will mean that we are looking for the opposite answer.
 Value *InstCombinerImpl::foldMultiplicationOverflowCheck(ICmpInst &I) {
-  ICmpInst::Predicate Pred;
+  CmpPredicate Pred;
   Value *X, *Y;
   Instruction *Mul;
   Instruction *Div;
@@ -4881,7 +4881,7 @@ Value *InstCombinerImpl::foldMultiplicationOverflowCheck(ICmpInst &I) {
 
 static Instruction *foldICmpXNegX(ICmpInst &I,
                                   InstCombiner::BuilderTy &Builder) {
-  CmpInst::Predicate Pred;
+  CmpPredicate Pred;
   Value *X;
   if (match(&I, m_c_ICmp(Pred, m_NSWNeg(m_Value(X)), m_Deferred(X)))) {
 
@@ -6822,7 +6822,7 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
 /// then try to reduce patterns based on that limit.
 Instruction *InstCombinerImpl::foldICmpUsingBoolRange(ICmpInst &I) {
   Value *X, *Y;
-  ICmpInst::Predicate Pred;
+  CmpPredicate Pred;
 
   // X must be 0 and bool must be true for "ULT":
   // X <u (zext i1 Y) --> (X == 0) & Y
@@ -6837,7 +6837,7 @@ Instruction *InstCombinerImpl::foldICmpUsingBoolRange(ICmpInst &I) {
     return BinaryOperator::CreateOr(Builder.CreateIsNull(X), Y);
 
   // icmp eq/ne X, (zext/sext (icmp eq/ne X, C))
-  ICmpInst::Predicate Pred1, Pred2;
+  CmpPredicate Pred1, Pred2;
   const APInt *C;
   Instruction *ExtI;
   if (match(&I, m_c_ICmp(Pred1, m_Value(X),
@@ -7107,7 +7107,7 @@ static Instruction *canonicalizeICmpBool(ICmpInst &I,
 //   (X l>> Y) == 0
 static Instruction *foldICmpWithHighBitMask(ICmpInst &Cmp,
                                             InstCombiner::BuilderTy &Builder) {
-  ICmpInst::Predicate Pred, NewPred;
+  CmpPredicate Pred, NewPred;
   Value *X, *Y;
   if (match(&Cmp,
             m_c_ICmp(Pred, m_OneUse(m_Shl(m_One(), m_Value(Y))), m_Value(X)))) {
@@ -7272,7 +7272,7 @@ static Instruction *foldReductionIdiom(ICmpInst &I,
                                        const DataLayout &DL) {
   if (I.getType()->isVectorTy())
     return nullptr;
-  ICmpInst::Predicate OuterPred, InnerPred;
+  CmpPredicate OuterPred, InnerPred;
   Value *LHS, *RHS;
 
   // Match lowering of @llvm.vector.reduce.and. Turn
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index c7a0c35d099cc4e..50dfb58cadb17bb 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -58,7 +58,7 @@ static Instruction *foldSelectBinOpIdentity(SelectInst &Sel,
   // The select condition must be an equality compare with a constant operand.
   Value *X;
   Constant *C;
-  CmpInst::Predicate Pred;
+  CmpPredicate Pred;
   if (!match(Sel.getCondition(), m_Cmp(Pred, m_Value(X), m_Constant(C))))
     return nullptr;
 
@@ -425,17 +425,19 @@ Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI,
 
     // icmp with a common operand also can have the common operand
     // pulled after the select.
-    ICmpInst::Predicate TPred, FPred;
+    CmpPredicate TPred, FPred;
     if (match(TI, m_ICmp(TPred, m_Value(), m_Value())) &&
         match(FI, m_ICmp(FPred, m_Value(), m_Value()))) {
-      if (TPred == FPred || TPred == CmpInst::getSwappedPredicate(FPred)) {
-        bool Swapped = TPred != FPred;
+      // FIXME: Use CmpPredicate::getMatching here.
+      CmpInst::Predicate T = TPred, F = FPred;
+      if (T == F || T == ICmpInst::getSwappedCmpPredicate(F)) {
+        bool Swapped = T != F;
         if (Value *MatchOp =
                 getCommonOp(TI, FI, ICmpInst::isEquality(TPred), Swapped)) {
           Value *NewSel = Builder.CreateSelect(Cond, OtherOpT, OtherOpF,
                                                SI.getName() + ".v", &SI);
           return new ICmpInst(
-              MatchIsOpZero ? TPred : CmpInst::getSwappedPredicate(TPred),
+              MatchIsOpZero ? TPred : ICmpInst::getSwappedCmpPredicate(TPred),
               MatchOp, NewSel);
         }
       }
@@ -640,7 +642,7 @@ static Instruction *foldSelectICmpAndAnd(Type *SelType, const ICmpInst *Cmp,
 static Value *foldSelectICmpAndZeroShl(const ICmpInst *Cmp, Value *TVal,
                                        Value *FVal,
                                        InstCombiner::BuilderTy &Builder) {
-  ICmpInst::Predicate Pred;
+  CmpPredicate Pred;
   Value *AndVal;
   if (!match(Cmp, m_ICmp(Pred, m_Value(AndVal), m_Zero())))
     return nullptr;
@@ -867,7 +869,7 @@ static Instruction *foldSelectZeroOrMul(SelectInst &SI, InstCombinerImpl &IC) {
   auto *TrueVal = SI.getTrueValue();
   auto *FalseVal = SI.getFalseValue();
   Value *X, *Y;
-  ICmpInst::Predicate Predicate;
+  CmpPredicate Predicate;
 
   // Assuming that constant compared with zero is not undef (but it may be
   // a vector with some undef elements). Otherwise (when a constant is undef)
@@ -1527,7 +1529,7 @@ static Value *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0,
     return nullptr;
 
   Value *Cmp1;
-  ICmpInst::Predicate Pred1;
+  CmpPredicate Pred1;
   Constant *C2;
   Value *ReplacementLow, *ReplacementHigh;
   if (!match(Sel1, m_Select(m_Value(Cmp1), m_Value(ReplacementLow),
@@ -1636,7 +1638,7 @@ static Value *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0,
 static Instruction *
 tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp,
                                          InstCombinerImpl &IC) {
-  ICmpInst::Predicate Pred;
+  CmpPredicate Pred;
   Value *X;
   Constant *C0;
   if (!match(&Cmp, m_OneUse(m_ICmp(
@@ -1734,7 +1736,7 @@ static Value *foldSelectInstWithICmpConst(SelectInst &SI, ICmpInst *ICI,
                                           InstCombiner::BuilderTy &Builder) {
   const APInt *CmpC;
   Value *V;
-  CmpInst::Predicate Pred;
+  CmpPredicate Pred;
   if (!match(ICI, m_ICmp(Pred, m_Value(V), m_APInt(CmpC))))
     return nullptr;
 
@@ -1890,7 +1892,7 @@ static Value *foldSelectWithConstOpToBinOp(ICmpInst *Cmp, Value *TrueVal,
   BinaryOperator *BOp;
   Constant *C1, *C2, *C3;
   Value *X;
-  ICmpInst::Predicate Predicate;
+  CmpPredicate Predicate;
 
   if (!match(Cmp, m_ICmp(Predicate, m_Value(X), m_Constant(C1))))
     return nullptr;
@@ -2138,7 +2140,7 @@ foldOverflowingAddSubSelect(SelectInst &SI, InstCombiner::BuilderTy &Builder) {
   auto IsSignedSaturateLimit = [&](Value *Limit, bool IsAdd) {
     Type *Ty = Limit->getType();
 
-    ICmpInst::Predicate Pred;
+    CmpPredicate Pred;
     Value *TrueVal, *FalseVal, *Op;
     const APInt *C;
     if (!match(Limit, m_Select(m_ICmp(Pred, m_Value(Op), m_APInt(C)),
@@ -2347,7 +2349,7 @@ static Instruction *foldSelectCmpBitcasts(SelectInst &Sel,
   Value *TVal = Sel.getTrueValue();
   Value *FVal = Sel.getFalseValue();
 
-  CmpInst::Predicate Pred;
+  CmpPredicate Pred;
   Value *A, *B;
   if (!match(Cond, m_Cmp(Pred, m_Value(A), m_Value(B))))
     return nullptr;
@@ -2552,7 +2554,7 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel,
   Value *X;
   const APInt *C;
   bool IsTrueIfSignSet;
-  ICmpInst::Predicate Pred;
+  CmpPredicate Pred;
   if (!match(Cond, m_OneUse(m_ICmp(Pred, m_ElementWiseBitCast(m_Value(X)),
                                    m_APInt(C)))) ||
       !isSignBitCheck(Pred, *C, IsTrueIfSignSet) || X->getType() != SelType)
@@ -2748,7 +2750,7 @@ static Instruction *foldSelectWithSRem(SelectInst &SI, InstCombinerImpl &IC,
   Value *TrueVal = SI.getTrueValue();
   Value *FalseVal = SI.getFalseValue();
 
-  ICmpInst::Predicate Pred;
+  CmpPredicate Pred;
   Value *Op, *RemRes, *Remainder;
   const APInt *C;
   bool TrueIfSigned = false;
@@ -2807,7 +2809,7 @@ static Value *foldSelectWithFrozenICmp(SelectInst &Sel, InstCombiner::BuilderTy
   //   a = select c, x, y   ;
   //   f(a, c)              ; f(poison, 1) cannot happen, but if a is folded
   //                        ; to y, this can happen.
-  CmpInst::Predicate Pred;
+  CmpPredicate Pred;
   if (FI->hasOneUse() &&
       match(Cond, m_c_ICmp(Pred, m_Specific(TrueVal), m_Specific(FalseVal))) &&
       (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE)) {
@@ -2856,7 +2858,7 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI,
   for (bool Swap : {false, true}) {
     Value *TrueVal = SI.getTrueValue();
     Value *X = SI.getFalseValue();
-    CmpInst::Predicate Pred;
+    CmpPredicate Pred;
 
     if (Swap)
       std::swap(TrueVal, X);
@@ -2936,7 +2938,7 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI,
     if (Swap)
       std::swap(TrueVal, X);
 
-    CmpInst::Predicate Pred;
+    CmpPredicate Pred;
     const APInt *C;
     bool TrueIfSigned;
     if (!match(CondVal,
@@ -2980,7 +2982,7 @@ foldRoundUpIntegerWithPow2Alignment(SelectInst &SI,
   Value *X = SI.getTrueValue();
   Value *XBiasedHighBits = SI.getFalseValue();
 
-  ICmpInst::Predicate Pred;
+  CmpPredicate Pred;
   Value *XLowBits;
   if (!match(Cond, m_ICmp(Pred, m_Value(XLowBits), m_ZeroInt())) ||
       !ICmpInst::isEquality(Pred))
@@ -3159,7 +3161,7 @@ static bool impliesPoisonOrCond(const Value *ValAssumedPoison, const Value *V,
     Value *LHS = ICmp->getOperand(0);
     const APInt *RHSC1;
     const APInt *RHSC2;
-    ICmpInst::Predicate Pred;
+    CmpPredicate Pred;
     if (ICmp->hasSameSign() &&
         match(ICmp->getOperand(1), m_APIntForbidPoison(RHSC1)) &&
         match(V, m_ICmp(Pred, m_Specific(LHS), m_APIntAllowPoison(RHSC2)))) {
@@ -3170,7 +3172,7 @@ static bool impliesPoisonOrCond(const Value *ValAssumedPoison, const Value *V,
                               APInt::getZero(BitWidth))
               : ConstantRange(APInt::getZero(BitWidth),
                               APInt::getSignedMinValue(BitWidth));
-      return CRX.icmp(Expected ? Pred : ICmpInst::getInversePredicate(Pred),
+      return CRX.icmp(Expected ? Pred : ICmpInst::getInverseCmpPredicate(Pred),
                       *RHSC2);
     }
   }
@@ -3539,7 +3541,7 @@ static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder,
 
   Value *FalseVal = SI.getFalseValue();
   Value *TrueVal = SI.getTrueValue();
-  ICmpInst::Predicate Pred;
+  CmpPredicate Pred;
   const APInt *Cond1;
   Value *Cond0, *Ctlz, *CtlzOp;
   if (!match(SI.getCondition(), m_ICmp(Pred, m_Value(Cond0), m_APInt(Cond1))))
@@ -3590,7 +3592,7 @@ Instruction *InstCombinerImpl::foldSelectToCmp(SelectInst &SI) {
   Value *TV = SI.getTrueValue();
   Value *FV = SI.getFalseValue();
 
-  ICmpInst::Predicate Pred;
+  CmpPredicate Pred;
   Value *LHS, *RHS;
   if (!match(SI.getCondition(), m_ICmp(Pred, m_Value(LHS), m_Value(RHS))))
     return nullptr;
@@ -3610,7 +3612,7 @@ Instruction *InstCombinerImpl::foldSelectToCmp(SelectInst &SI) {
   bool IsSigned = ICmpInst::isSigned(Pred);
 
   bool Replace = false;
-  ICmpInst::Predicate ExtendedCmpPredicate;
+  CmpPredicate ExtendedCmpPredicate;
   // (x < y) ? -1 : zext(x != y)
   // (x < y) ? -1 : zext(x > y)
   if (ICmpInst::isLT(Pred) && match(TV, m_AllOnes()) &&
@@ -3630,7 +3632,7 @@ Instruction *InstCombinerImpl::foldSelectToCmp(SelectInst &SI) {
     Replace = true;
 
   // (x == y) ? 0 : (x > y ? 1 : -1)
-  ICmpInst::Predicate FalseBranchSelectPredicate;
+  CmpPredicate FalseBranchSelectPredicate;
   const APInt *InnerTV, *InnerFV;
   if (Pred == ICmpInst::ICMP_EQ && match(TV, m_Zero()) &&
       match(FV, m_Select(m_c_ICmp(FalseBranchSelectPredicate, m_Specific(LHS),
@@ -3730,7 +3732,7 @@ static Value *foldSelectIntoAddConstant(SelectInst &SI,
   Instruction *FAdd;
   Constant *C;
   Value *X, *Z;
-  CmpInst::Predicate Pred;
+  CmpPredicate Pred;
 
   // Note: OneUse check for `Cmp` is necessary because it makes sure that other
   // InstCombine folds don't undo this transformation and cause an infinite
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index 09eafd09451b246..77f1763d1193969 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -86,7 +86,7 @@ static bool cheapToScalarize(Value *V, Value *EI) {
     if (cheapToScalarize(V0, EI) || cheapToScalarize(V1, EI))
       return true;
 
-  CmpInst::Predicate UnusedPred;
+  CmpPredicate UnusedPred;
   if (match(V, m_OneUse(m_Cmp(UnusedPred, m_Value(V0), m_Value(V1)))))
     if (cheapToScalarize(V0, EI) || cheapToScalarize(V1, EI))
       return true;
@@ -486,7 +486,7 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) {
   }
 
   Value *X, *Y;
-  CmpInst::Predicate Pred;
+  CmpPredicate Pred;
   if (match(SrcVec, m_Cmp(Pred, m_Value(X), m_Value(Y))) &&
       cheapToScalarize(SrcVec, Index)) {
     // extelt (cmp X, Y), Index --> cmp (extelt X, Index), (extelt Y, Index)
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 8f55e5b3cc28a26..4abf2317290413f 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -3478,7 +3478,7 @@ static Instruction *tryToMoveFreeBeforeNullTest(CallInst &FI,
   // Validate the rest of constraint #1 by matching on the pred branch.
   Instruction *TI = PredBB->getTerminator();
   BasicBlock *TrueBB, *FalseBB;
-  ICmpInst::Predicate Pred;
+  CmpPredicate Pred;
   if (!match(TI, m_Br(m_ICmp(Pred,
                              m_CombineOr(m_Specific(Op),
                                          m_Specific(Op->stripPointerCasts())),
@@ -3759,7 +3759,7 @@ Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) {
     return replaceOperand(BI, 0, ConstantInt::getFalse(Cond->getType()));
 
   // Canonicalize, for example, fcmp_one -> fcmp_oeq.
-  CmpInst::Predicate Pred;
+  CmpPredicate Pred;
   if (match(Cond, m_OneUse(m_FCmp(Pred, m_Value(), m_Value()))) &&
       !isCanonicalPredicate(Pred)) {
     // Swap destinations and condition.
@@ -3820,7 +3820,7 @@ static Value *simplifySwitchOnSelectUsingRanges(SwitchInst &SI,
   if (CstBB != SI.getDefaultDest())
     return nullptr;
   Value *X = Select->getOperand(3 - CstOpIdx);
-  ICmpInst::Predicate Pred;
+  CmpPredicate Pred;
   const APInt *RHSC;
   if (!match(Select->getCondition(),
              m_ICmp(Pred, m_Specific(X), m_APInt(RHSC))))
diff --git a/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp b/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp
index b8571ba0748998f..bbc7a005b9ff4f0 100644
--- a/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp
+++ b/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp
@@ -132,7 +132,7 @@ static void recordCondition(CallBase &CB, BasicBlock *From, BasicBlock *To,
   if (!BI || !BI->isConditional())
     return;
 
-  CmpInst::Predicate Pred;
+  CmpPredicate Pred;
   Value *Cond = BI->getCondition();
   if (!match(Cond, m_ICmp(Pred, m_Value(), m_Constant())))
     return;
@@ -142,7 +142,7 @@ static void recordCondition(CallBase &CB, BasicBlock *From, BasicBlock *To,
     if (isCondRelevantToAnyCallArgument(Cmp, CB))
       Conditions.push_back({Cmp, From->getTerminator()->getSuccessor(0) == To
                                      ? Pred
-                                     : Cmp->getInversePredicate()});
+                                     : Cmp->getInverseCmpPredicate()});
 }
 
 /// Record ICmp conditions relevant to any argument in CB following Pred's
diff --git a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
index e64fc153cf3d26f..589bfd05bb5d554 100644
--- a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
@@ -922,7 +922,7 @@ void State::addInfoForInductions(BasicBlock &BB) {
 
   Value *A;
   Value *B;
-  CmpInst::Predicate Pred;
+  CmpPredicate Pred;
 
   if (!match(BB.getTerminator(),
              m_Br(m_ICmp(Pred, m_Value(A), m_Value(B)), m_Value(), m_Value())))
@@ -1089,7 +1089,7 @@ void State::addInfoFor(BasicBlock &BB) {
     switch (ID) {
     case Intrinsic::assume: {
       Value *A, *B;
-      CmpInst::Predicate Pred;
+      CmpPredicate Pred;
       if (!match(I.getOperand(0), m_ICmp(Pred, m_Value(A), m_Value(B))))
         break;
       if (GuaranteedToExecute) {
@@ -1537,7 +1537,7 @@ static bool checkOrAndOpImpliedByOther(
   while (!Worklist.empty()) {
     Value *Val = Worklist.pop_back_val();
     Value *LHS, *RHS;
-    ICmpInst::Predicate Pred;
+    CmpPredicate Pred;
     if (match(Val, m_ICmp(Pred, m_Value(LHS), m_Value(RHS)))) {
       // For OR, check if the negated condition implies CmpToCheck.
       if (IsOr)
@@ -1833,7 +1833,7 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT, LoopInfo &LI,
       }
     };
 
-    ICmpInst::Predicate Pred;
+    CmpPredicate Pred;
     if (!CB.isConditionFact()) {
       Value *X;
       if (match(CB.Inst, m_Intrinsic<Intrinsic::abs>(m_Value(X)))) {
diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
index 09e8301b772d96b..4799640089fa9a6 100644
--- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -2053,7 +2053,7 @@ struct DSEState {
       return false;
 
     Instruction *ICmpL;
-    ICmpInst::Predicate Pred;
+    CmpPredicate Pred;
     if (!match(BI->getCondition(),
                m_c_ICmp(Pred,
                         m_CombineAnd(m_Load(m_Specific(StorePtr)),
diff --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
index cd4846e006031d4..682c5c3d8c63404 100644
--- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
+++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
@@ -192,7 +192,7 @@ static bool matchSelectWithOptionalNotCond(Value *V, Value *&Cond, Value *&A,
   // mechanism that may remove flags to increase the likelihood of CSE.
 
   Flavor = SPF_UNKNOWN;
-  CmpInst::Predicate Pred;
+  CmpPredicate Pred;
 
   if (!match(Cond, m_ICmp(Pred, m_Specific(A), m_Specific(B)))) {
     // Check for commuted variants of min/max by swapping predicate.
@@ -279,7 +279,7 @@ static unsigned getHashValueImpl(SimpleValue Val) {
     // Hash general selects to allow matching commuted true/false operands.
 
     // If we do not have a compare as the condition, just hash in the condition.
-    CmpInst::Predicate Pred;
+    CmpPredicate Pred;
     Value *X, *Y;
     if (!match(Cond, m_Cmp(Pred, m_Value(X), m_Value(Y))))
       return hash_combine(Inst->getOpcode(), Cond, A, B);
@@ -451,7 +451,7 @@ static bool isEqualImpl(SimpleValue LHS, SimpleValue RHS) {
     // this code, as we simplify the double-negation before hashing the second
     // select (and so still succeed at CSEing them).
     if (LHSA == RHSB && LHSB == RHSA) {
-      CmpInst::Predicate PredL, PredR;
+      CmpPredicate PredL, PredR;
       Value *X, *Y;
       if (match(CondL, m_Cmp(PredL, m_Value(X), m_Value(Y))) &&
           match(CondR, m_Cmp(PredR, m_Specific(X), m_Specific(Y))) &&
diff --git a/llvm/lib/Transforms/Scalar/GuardWidening.cpp b/llvm/lib/Transforms/Scalar/GuardWidening.cpp
index a8fda0c6ab9cbe2..2978b7990a6ebe1 100644
--- a/llvm/lib/Transforms/Scalar/GuardWidening.cpp
+++ b/llvm/lib/Transforms/Scalar/GuardWidening.cpp
@@ -727,7 +727,7 @@ GuardWideningImpl::mergeChecks(SmallVectorImpl<Value *> &ChecksToHoist,
     // L >u C0 && L >u C1  ->  L >u max(C0, C1)
     ConstantInt *RHS0, *RHS1;
     Value *LHS;
-    ICmpInst::Predicate Pred0, Pred1;
+    CmpPredicate Pred0, Pred1;
     // TODO: Support searching for pairs to merge from both whole lists of
     // ChecksToHoist and ChecksToWiden.
     if (ChecksToWiden.size() == 1 && ChecksToHoist.size() == 1 &&
diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
index 16110cd25bc61c5..300a564e222e163 100644
--- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
@@ -591,7 +591,7 @@ bool JumpThreadingPass::computeValueKnownInPredecessorsImpl(
       // 'getPredicateOnEdge' method. This would be able to handle value
       // inequalities better, for example if the compare is "X < 4" and "X < 3"
       // is known true but "X < 4" itself is not available.
-      CmpInst::Predicate Pred;
+      CmpPredicate Pred;
       Value *Val;
       Constant *Cst;
       if (!PredCst && match(V, m_Cmp(Pred, m_Value(Val), m_Constant(Cst))))
@@ -2744,7 +2744,7 @@ bool JumpThreadingPass::duplicateCondBranchOnPHIIntoPred(
 // Pred is a predecessor of BB with an unconditional branch to BB. SI is
 // a Select instruction in Pred. BB has other predecessors and SI is used in
 // a PHI node in BB. SI has no other use.
-// A new basic block, NewBB, is created and SI is converted to compare and 
+// A new basic block, NewBB, is created and SI is converted to compare and
 // conditional branch. SI is erased from parent.
 void JumpThreadingPass::unfoldSelectInstr(BasicBlock *Pred, BasicBlock *BB,
                                           SelectInst *SI, PHINode *SIUse,
diff --git a/llvm/lib/Transforms/Scalar/LICM.cpp b/llvm/lib/Transforms/Scalar/LICM.cpp
index 3ade32027289317..a5d5eecb1ebf823 100644
--- a/llvm/lib/Transforms/Scalar/LICM.cpp
+++ b/llvm/lib/Transforms/Scalar/LICM.cpp
@@ -2430,8 +2430,8 @@ static bool hoistMinMax(Instruction &I, Loop &L, ICFLoopSafetyInfo &SafetyInfo,
   } else
     return false;
 
-  auto MatchICmpAgainstInvariant = [&](Value *C, ICmpInst::Predicate &P,
-                                       Value *&LHS, Value *&RHS) {
+  auto MatchICmpAgainstInvariant = [&](Value *C, CmpPredicate &P, Value *&LHS,
+                                       Value *&RHS) {
     if (!match(C, m_OneUse(m_ICmp(P, m_Value(LHS), m_Value(RHS)))))
       return false;
     if (!LHS->getType()->isIntegerTy())
@@ -2448,12 +2448,13 @@ static bool hoistMinMax(Instruction &I, Loop &L, ICFLoopSafetyInfo &SafetyInfo,
       P = ICmpInst::getInversePredicate(P);
     return true;
   };
-  ICmpInst::Predicate P1, P2;
+  CmpPredicate P1, P2;
   Value *LHS1, *LHS2, *RHS1, *RHS2;
   if (!MatchICmpAgainstInvariant(Cond1, P1, LHS1, RHS1) ||
       !MatchICmpAgainstInvariant(Cond2, P2, LHS2, RHS2))
     return false;
-  if (P1 != P2 || LHS1 != LHS2)
+  // FIXME: Use CmpPredicate::getMatching here.
+  if (P1 != static_cast<CmpInst::Predicate>(P2) || LHS1 != LHS2)
     return false;
 
   // Everything is fine, we can do the transform.
@@ -2678,7 +2679,7 @@ static bool hoistAddSub(Instruction &I, Loop &L, ICFLoopSafetyInfo &SafetyInfo,
                         MemorySSAUpdater &MSSAU, AssumptionCache *AC,
                         DominatorTree *DT) {
   using namespace PatternMatch;
-  ICmpInst::Predicate Pred;
+  CmpPredicate Pred;
   Value *LHS, *RHS;
   if (!match(&I, m_ICmp(Pred, m_Value(LHS), m_Value(RHS))))
     return false;
diff --git a/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp b/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp
index ff077624802be2f..73f1942849ac2f6 100644
--- a/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp
@@ -32,7 +32,7 @@ struct ConditionInfo {
   /// ICmp instruction with this condition
   ICmpInst *ICmp = nullptr;
   /// Preciate info
-  ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
+  CmpPredicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
   /// AddRec llvm value
   Value *AddRecValue = nullptr;
   /// Non PHI AddRec llvm value
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index 05cf638d3f09dff..ba1c2241aea9acd 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -2432,7 +2432,7 @@ static bool detectShiftUntilBitTestIdiom(Loop *CurLoop, Value *&BaseX,
 
   // Step 1: Check if the loop backedge is in desirable form.
 
-  ICmpInst::Predicate Pred;
+  CmpPredicate Pred;
   Value *CmpLHS, *CmpRHS;
   BasicBlock *TrueBB, *FalseBB;
   if (!match(LoopHeaderBB->getTerminator(),
@@ -2797,7 +2797,7 @@ static bool detectShiftUntilZeroIdiom(Loop *CurLoop, ScalarEvolution *SE,
 
   // Step 1: Check if the loop backedge, condition is in desirable form.
 
-  ICmpInst::Predicate Pred;
+  CmpPredicate Pred;
   BasicBlock *TrueBB, *FalseBB;
   if (!match(LoopHeaderBB->getTerminator(),
              m_Br(m_Instruction(ValShiftedIsZero), m_BasicBlock(TrueBB),
diff --git a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
index d8ef450eeb9a15f..0712ff77151e293 100644
--- a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
+++ b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
@@ -2990,9 +2990,11 @@ static bool collectUnswitchCandidates(
 /// into its equivalent where `Pred` is something that we support for injected
 /// invariants (so far it is limited to ult), LHS in canonicalized form is
 /// non-invariant and RHS is an invariant.
-static void canonicalizeForInvariantConditionInjection(
-    ICmpInst::Predicate &Pred, Value *&LHS, Value *&RHS, BasicBlock *&IfTrue,
-    BasicBlock *&IfFalse, const Loop &L) {
+static void canonicalizeForInvariantConditionInjection(CmpPredicate &Pred,
+                                                       Value *&LHS, Value *&RHS,
+                                                       BasicBlock *&IfTrue,
+                                                       BasicBlock *&IfFalse,
+                                                       const Loop &L) {
   if (!L.contains(IfTrue)) {
     Pred = ICmpInst::getInversePredicate(Pred);
     std::swap(IfTrue, IfFalse);
@@ -3235,7 +3237,7 @@ static bool collectUnswitchCandidatesWithInjections(
   // other).
   for (auto *DTN = DT.getNode(Latch); L.contains(DTN->getBlock());
        DTN = DTN->getIDom()) {
-    ICmpInst::Predicate Pred;
+    CmpPredicate Pred;
     Value *LHS = nullptr, *RHS = nullptr;
     BasicBlock *IfTrue = nullptr, *IfFalse = nullptr;
     auto *BB = DTN->getBlock();
diff --git a/llvm/lib/Transforms/Utils/LoopPeel.cpp b/llvm/lib/Transforms/Utils/LoopPeel.cpp
index 3cbde39b30b4e41..9a24c1b0d03de7e 100644
--- a/llvm/lib/Transforms/Utils/LoopPeel.cpp
+++ b/llvm/lib/Transforms/Utils/LoopPeel.cpp
@@ -378,7 +378,7 @@ static unsigned countToEliminateCompares(Loop &L, unsigned MaxPeelCount,
       return;
     }
 
-    CmpInst::Predicate Pred;
+    CmpPredicate Pred;
     if (!match(Condition, m_ICmp(Pred, m_Value(LeftVal), m_Value(RightVal))))
       return;
 
diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index 791d528823972d7..0bc752a92340750 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -1816,7 +1816,7 @@ bool SCEVExpander::hasRelatedExistingExpansion(const SCEV *S,
 
   // Look for suitable value in simple conditions at the loop exits.
   for (BasicBlock *BB : ExitingBlocks) {
-    ICmpInst::Predicate Pred;
+    CmpPredicate Pred;
     Instruction *LHS, *RHS;
 
     if (!match(BB->getTerminator(),
diff --git a/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp b/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp
index 7fca1a6aa526054..f05d32d980e5a93 100644
--- a/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp
@@ -2164,16 +2164,14 @@ void WidenIV::calculatePostIncRange(Instruction *NarrowDef,
       !NarrowDefRHS->isNonNegative())
     return;
 
-  auto UpdateRangeFromCondition = [&] (Value *Condition,
-                                       bool TrueDest) {
-    CmpInst::Predicate Pred;
+  auto UpdateRangeFromCondition = [&](Value *Condition, bool TrueDest) {
+    CmpPredicate Pred;
     Value *CmpRHS;
     if (!match(Condition, m_ICmp(Pred, m_Specific(NarrowDefLHS),
                                  m_Value(CmpRHS))))
       return;
 
-    CmpInst::Predicate P =
-            TrueDest ? Pred : CmpInst::getInversePredicate(Pred);
+    CmpPredicate P = TrueDest ? Pred : ICmpInst::getInverseCmpPredicate(Pred);
 
     auto CmpRHSRange = SE->getSignedRange(SE->getSCEV(CmpRHS));
     auto CmpConstrainedLHSRange =
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index dd1e53a05ebeb86..e2742fbd9b4f096 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -514,7 +514,7 @@ static bool isCommutative(Instruction *I) {
                 BO->uses(),
                 [](const Use &U) {
                   // Commutative, if icmp eq/ne sub, 0
-                  ICmpInst::Predicate Pred;
+                  CmpPredicate Pred;
                   if (match(U.getUser(),
                             m_ICmp(Pred, m_Specific(U.get()), m_Zero())) &&
                       (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE))
@@ -11463,7 +11463,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
   case Instruction::FCmp:
   case Instruction::ICmp:
   case Instruction::Select: {
-    CmpInst::Predicate VecPred, SwappedVecPred;
+    CmpPredicate VecPred, SwappedVecPred;
     auto MatchCmp = m_Cmp(VecPred, m_Value(), m_Value());
     if (match(VL0, m_Select(MatchCmp, m_Value(), m_Value())) ||
         match(VL0, MatchCmp))
@@ -11477,13 +11477,15 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
         return InstructionCost(TTI::TCC_Free);
 
       auto *VI = cast<Instruction>(UniqueValues[Idx]);
-      CmpInst::Predicate CurrentPred = ScalarTy->isFloatingPointTy()
-                                           ? CmpInst::BAD_FCMP_PREDICATE
-                                           : CmpInst::BAD_ICMP_PREDICATE;
+      CmpPredicate CurrentPred = ScalarTy->isFloatingPointTy()
+                                     ? CmpInst::BAD_FCMP_PREDICATE
+                                     : CmpInst::BAD_ICMP_PREDICATE;
       auto MatchCmp = m_Cmp(CurrentPred, m_Value(), m_Value());
+      // FIXME: Use CmpPredicate::getMatching here.
       if ((!match(VI, m_Select(MatchCmp, m_Value(), m_Value())) &&
            !match(VI, MatchCmp)) ||
-          (CurrentPred != VecPred && CurrentPred != SwappedVecPred))
+          (CurrentPred != static_cast<CmpInst::Predicate>(VecPred) &&
+           CurrentPred != static_cast<CmpInst::Predicate>(SwappedVecPred)))
         VecPred = SwappedVecPred = ScalarTy->isFloatingPointTy()
                                        ? CmpInst::BAD_FCMP_PREDICATE
                                        : CmpInst::BAD_ICMP_PREDICATE;
@@ -19391,7 +19393,7 @@ class HorizontalReduction {
       // %3 = extractelement <2 x i32> %a, i32 0
       // %4 = extractelement <2 x i32> %a, i32 1
       // %select = select i1 %cond, i32 %3, i32 %4
-      CmpInst::Predicate Pred;
+      CmpPredicate Pred;
       Instruction *L1;
       Instruction *L2;
 
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index ebbd05e6d47afcf..772e0aed9b6c02c 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -596,7 +596,7 @@ bool VectorCombine::foldExtractExtract(Instruction &I) {
     return false;
 
   Instruction *I0, *I1;
-  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
+  CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
   if (!match(&I, m_Cmp(Pred, m_Instruction(I0), m_Instruction(I1))) &&
       !match(&I, m_BinOp(m_Instruction(I0), m_Instruction(I1))))
     return false;
@@ -922,7 +922,7 @@ bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
 /// Match a vector binop or compare instruction with at least one inserted
 /// scalar operand and convert to scalar binop/cmp followed by insertelement.
 bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
-  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
+  CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
   Value *Ins0, *Ins1;
   if (!match(&I, m_BinOp(m_Value(Ins0), m_Value(Ins1))) &&
       !match(&I, m_Cmp(Pred, m_Value(Ins0), m_Value(Ins1))))
@@ -1062,9 +1062,11 @@ bool VectorCombine::foldExtractedCmps(Instruction &I) {
   Value *B0 = I.getOperand(0), *B1 = I.getOperand(1);
   Instruction *I0, *I1;
   Constant *C0, *C1;
-  CmpInst::Predicate P0, P1;
+  CmpPredicate P0, P1;
+  // FIXME: Use CmpPredicate::getMatching here.
   if (!match(B0, m_Cmp(P0, m_Instruction(I0), m_Constant(C0))) ||
-      !match(B1, m_Cmp(P1, m_Instruction(I1), m_Constant(C1))) || P0 != P1)
+      !match(B1, m_Cmp(P1, m_Instruction(I1), m_Constant(C1))) ||
+      P0 != static_cast<CmpInst::Predicate>(P1))
     return false;
 
   // The compare operands must be extracts of the same vector with constant
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index 367ba6ab52a5965..47fde5782a13bce 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -2381,7 +2381,7 @@ TYPED_TEST(MutableConstTest, ICmp) {
 
   ValueType MatchL;
   ValueType MatchR;
-  ICmpInst::Predicate MatchPred;
+  CmpPredicate MatchPred;
 
   EXPECT_TRUE(m_ICmp(MatchPred, m_Value(MatchL), m_Value(MatchR))
                   .match((InstructionType)IRB.CreateICmp(Pred, L, R)));
@@ -2473,7 +2473,7 @@ TYPED_TEST(MutableConstTest, FCmp) {
 
   ValueType MatchL;
   ValueType MatchR;
-  FCmpInst::Predicate MatchPred;
+  CmpPredicate MatchPred;
 
   EXPECT_TRUE(m_FCmp(MatchPred, m_Value(MatchL), m_Value(MatchR))
                   .match((InstructionType)IRB.CreateFCmp(Pred, L, R)));

>From f46e64f8bf35881fdfb7bc15516c6a36ca0df1d6 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Fri, 13 Dec 2024 10:58:36 +0000
Subject: [PATCH 2/2] PatternMatch, CmpPredicate: address review

---
 llvm/include/llvm/IR/CmpPredicate.h |  7 +++----
 llvm/include/llvm/IR/PatternMatch.h |  2 +-
 llvm/lib/IR/Instructions.cpp        | 11 ++++-------
 3 files changed, 8 insertions(+), 12 deletions(-)

diff --git a/llvm/include/llvm/IR/CmpPredicate.h b/llvm/include/llvm/IR/CmpPredicate.h
index ae027305cb4199d..ce78e4311f9f826 100644
--- a/llvm/include/llvm/IR/CmpPredicate.h
+++ b/llvm/include/llvm/IR/CmpPredicate.h
@@ -62,12 +62,11 @@ class CmpPredicate {
   bool operator==(CmpPredicate) const = delete;
   bool operator!=(CmpPredicate) const = delete;
 
-  /// TypeSwitch over the CmpInst and either do ICmpInst::getCmpPredicate() or
-  /// FCmpInst::getPredicate().
+  /// Get the predicate of \p Cmp: ICmpInst::getCmpPredicate() for an ICmpInst,
+  /// and CmpInst::getPredicate() for any other CmpInst.
   static CmpPredicate get(const CmpInst *Cmp);
 
-  /// Get the swapped predicate of a CmpPredicate, using
-  /// CmpInst::isIntPredicate().
+  /// Get the swapped predicate of a CmpPredicate.
   static CmpPredicate getSwapped(CmpPredicate P);
 
   /// Get the swapped predicate of a CmpInst.
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index ed7e1eff6f0005b..cc0e8d598ff1eaf 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -695,7 +695,7 @@ struct icmp_pred_with_threshold {
 /// Match an integer or vector with every element comparing 'pred' (eg/ne/...)
 /// to Threshold. For vectors, this includes constants with undefined elements.
 inline cst_pred_ty<icmp_pred_with_threshold>
-m_SpecificInt_ICMP(CmpPredicate Predicate, const APInt &Threshold) {
+m_SpecificInt_ICMP(ICmpInst::Predicate Predicate, const APInt &Threshold) {
   cst_pred_ty<icmp_pred_with_threshold> P;
   P.Pred = Predicate;
   P.Thr = &Threshold;
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index d27863574dcf25a..d1da02c744f18c6 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -16,7 +16,6 @@
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Twine.h"
-#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constant.h"
@@ -3934,15 +3933,13 @@ std::optional<CmpPredicate> CmpPredicate::getMatching(CmpPredicate A,
 }
 
 CmpPredicate CmpPredicate::get(const CmpInst *Cmp) {
-  return TypeSwitch<const CmpInst *, CmpPredicate>(Cmp)
-      .Case<ICmpInst>([](auto *ICI) { return ICI->getCmpPredicate(); })
-      .Case<FCmpInst>([](auto *FCI) { return FCI->getPredicate(); });
+  if (auto *ICI = dyn_cast<ICmpInst>(Cmp))
+    return ICI->getCmpPredicate();
+  return Cmp->getPredicate();
 }
 
 CmpPredicate CmpPredicate::getSwapped(CmpPredicate P) {
-  return CmpInst::isIntPredicate(P)
-             ? ICmpInst::getSwappedCmpPredicate(P)
-             : CmpPredicate{CmpInst::getSwappedPredicate(P)};
+  return {CmpInst::getSwappedPredicate(P), P.hasSameSign()};
 }
 
 CmpPredicate CmpPredicate::getSwapped(const CmpInst *Cmp) {



More information about the llvm-commits mailing list