[llvm] IR: introduce struct with CmpInst::Predicate and samesign (PR #116867)

Ramkumar Ramachandra via llvm-commits llvm-commits at lists.llvm.org
Sat Nov 30 04:18:44 PST 2024


https://github.com/artagnon updated https://github.com/llvm/llvm-project/pull/116867

>From 1a3107561d3c0575e27fdfe36d7f5f9cda48467d Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Tue, 19 Nov 2024 15:25:05 +0000
Subject: [PATCH 1/3] IR: introduce CmpPredicate

Introduce CmpPredicate, an abstraction over a floating-point
predicate, and a pack of an integer predicate with samesign information,
in order to ease extending large portions of the codebase that take a
CmpInst::Predicate to respect the samesign flag.

We have chosen to demonstrate the utility of this new abstraction by
migrating ValueTracking, InstructionSimplify, and InstCombine from
CmpInst::Predicate to CmpPredicate. There should be no
functional changes, as we don't perform any extra optimizations with
samesign in this patch.

The design approach taken by this patch allows for unaudited callers of
APIs that take a CmpPredicate to silently drop the samesign
information; it does not pose a correctness issue, and allows us to
migrate the codebase piece-wise.
---
 .../llvm/Analysis/InstSimplifyFolder.h        |  1 +
 .../llvm/Analysis/InstructionSimplify.h       |  7 +-
 llvm/include/llvm/Analysis/ValueTracking.h    |  7 +-
 llvm/include/llvm/IR/CmpPredicate.h           | 56 +++++++++++
 llvm/include/llvm/IR/InstrTypes.h             | 18 ++++
 llvm/include/llvm/IR/Instructions.h           | 75 +++++++++++----
 .../Transforms/InstCombine/InstCombiner.h     |  9 +-
 llvm/lib/Analysis/InstructionSimplify.cpp     | 96 +++++++++----------
 llvm/lib/Analysis/ValueTracking.cpp           | 32 +++----
 llvm/lib/IR/Instructions.cpp                  | 11 +--
 .../InstCombine/InstCombineAndOrXor.cpp       |  2 +-
 .../InstCombine/InstCombineCompares.cpp       | 35 ++++---
 .../InstCombine/InstCombineInternal.h         | 17 ++--
 .../InstCombine/InstructionCombining.cpp      |  6 +-
 llvm/unittests/IR/InstructionsTest.cpp        | 22 +++++
 llvm/unittests/SandboxIR/SandboxIRTest.cpp    | 13 ++-
 16 files changed, 269 insertions(+), 138 deletions(-)
 create mode 100644 llvm/include/llvm/IR/CmpPredicate.h

diff --git a/llvm/include/llvm/Analysis/InstSimplifyFolder.h b/llvm/include/llvm/Analysis/InstSimplifyFolder.h
index 430c3edc2f0dc7..d4ae4dcc918cf3 100644
--- a/llvm/include/llvm/Analysis/InstSimplifyFolder.h
+++ b/llvm/include/llvm/Analysis/InstSimplifyFolder.h
@@ -22,6 +22,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/TargetFolder.h"
+#include "llvm/IR/CmpPredicate.h"
 #include "llvm/IR/IRBuilderFolder.h"
 #include "llvm/IR/Instruction.h"
 
diff --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h
index cf7d3e044188a6..fa291eeef198b9 100644
--- a/llvm/include/llvm/Analysis/InstructionSimplify.h
+++ b/llvm/include/llvm/Analysis/InstructionSimplify.h
@@ -44,6 +44,7 @@ class DataLayout;
 class DominatorTree;
 class Function;
 class Instruction;
+class CmpPredicate;
 class LoadInst;
 struct LoopStandardAnalysisResults;
 class Pass;
@@ -152,11 +153,11 @@ Value *simplifyOrInst(Value *LHS, Value *RHS, const SimplifyQuery &Q);
 Value *simplifyXorInst(Value *LHS, Value *RHS, const SimplifyQuery &Q);
 
 /// Given operands for an ICmpInst, fold the result or return null.
-Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
+Value *simplifyICmpInst(CmpPredicate Pred, Value *LHS, Value *RHS,
                         const SimplifyQuery &Q);
 
 /// Given operands for an FCmpInst, fold the result or return null.
-Value *simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
+Value *simplifyFCmpInst(CmpPredicate Predicate, Value *LHS, Value *RHS,
                         FastMathFlags FMF, const SimplifyQuery &Q);
 
 /// Given operands for a SelectInst, fold the result or return null.
@@ -200,7 +201,7 @@ Value *simplifyShuffleVectorInst(Value *Op0, Value *Op1, ArrayRef<int> Mask,
 //=== Helper functions for higher up the class hierarchy.
 
 /// Given operands for a CmpInst, fold the result or return null.
-Value *simplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
+Value *simplifyCmpInst(CmpPredicate Predicate, Value *LHS, Value *RHS,
                        const SimplifyQuery &Q);
 
 /// Given operand for a UnaryOperator, fold the result or return null.
diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index 2b0377903ac8e3..3bc81705fb814b 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -1245,8 +1245,7 @@ std::optional<bool> isImpliedCondition(const Value *LHS, const Value *RHS,
                                        const DataLayout &DL,
                                        bool LHSIsTrue = true,
                                        unsigned Depth = 0);
-std::optional<bool> isImpliedCondition(const Value *LHS,
-                                       CmpInst::Predicate RHSPred,
+std::optional<bool> isImpliedCondition(const Value *LHS, CmpPredicate RHSPred,
                                        const Value *RHSOp0, const Value *RHSOp1,
                                        const DataLayout &DL,
                                        bool LHSIsTrue = true,
@@ -1257,8 +1256,8 @@ std::optional<bool> isImpliedCondition(const Value *LHS,
 std::optional<bool> isImpliedByDomCondition(const Value *Cond,
                                             const Instruction *ContextI,
                                             const DataLayout &DL);
-std::optional<bool> isImpliedByDomCondition(CmpInst::Predicate Pred,
-                                            const Value *LHS, const Value *RHS,
+std::optional<bool> isImpliedByDomCondition(CmpPredicate Pred, const Value *LHS,
+                                            const Value *RHS,
                                             const Instruction *ContextI,
                                             const DataLayout &DL);
 
diff --git a/llvm/include/llvm/IR/CmpPredicate.h b/llvm/include/llvm/IR/CmpPredicate.h
new file mode 100644
index 00000000000000..99754fdf2b7686
--- /dev/null
+++ b/llvm/include/llvm/IR/CmpPredicate.h
@@ -0,0 +1,56 @@
+//===- CmpPredicate.h - CmpInst Predicate with samesign information -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// A CmpInst::Predicate with any samesign information (applicable to ICmpInst).
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_IR_CMPPREDICATE_H
+#define LLVM_IR_CMPPREDICATE_H
+
+#include "llvm/IR/InstrTypes.h"
+
+namespace llvm {
+/// An abstraction over a floating-point predicate, and a pack of an integer
+/// predicate with samesign information. Functions in ICmpInst construct and
+/// return this type in place of a Predicate. It is also implicitly constructed
+/// with a Predicate, dropping samesign information.
+class CmpPredicate {
+  CmpInst::Predicate Pred;
+  bool HasSameSign;
+
+public:
+  CmpPredicate(CmpInst::Predicate Pred, bool HasSameSign = false)
+      : Pred(Pred), HasSameSign(HasSameSign) {
+    assert(!HasSameSign || CmpInst::isIntPredicate(Pred));
+  }
+
+  inline operator CmpInst::Predicate() const { return Pred; }
+
+  inline bool hasSameSign() const { return HasSameSign; }
+
+  static std::optional<CmpPredicate> getMatching(CmpPredicate A,
+                                                 CmpPredicate B) {
+    if (A.Pred == B.Pred)
+      return A.HasSameSign == B.HasSameSign ? A : CmpPredicate(A.Pred);
+    if (A.HasSameSign &&
+        A.Pred == CmpInst::getFlippedSignednessPredicate(B.Pred))
+      return B.Pred;
+    if (B.HasSameSign &&
+        B.Pred == CmpInst::getFlippedSignednessPredicate(A.Pred))
+      return A.Pred;
+    return {};
+  }
+
+  inline bool operator==(CmpInst::Predicate P) const { return Pred == P; }
+
+  inline bool operator==(CmpPredicate) const = delete;
+};
+} // namespace llvm
+
+#endif
diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index e6332a16df7d5f..dd31f2e3d0a747 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -728,6 +728,24 @@ class CmpInst : public Instruction {
           InsertPosition InsertBefore = nullptr,
           Instruction *FlagsSource = nullptr);
 
+  /// Return the signed version of the predicate: variant that operates on
+  /// Predicate; used by the corresponding function in ICmpInst, to operate with
+  /// CmpPredicate.
+  static Predicate getSignedPredicate(Predicate Pred);
+
+  /// Return the unsigned version of the predicate: variant that operates on
+  /// Predicate; used by the corresponding function in ICmpInst, to operate with
+  /// CmpPredicate.
+  static Predicate getUnsignedPredicate(Predicate Pred);
+
+  /// Return the unsigned version of the signed predicate pred or the signed
+  /// version of the unsigned predicate pred: variant that operates on
+  /// Predicate; used by the corresponding function in ICmpInst, to operate
+  /// with CmpPredicate.
+  static Predicate getFlippedSignednessPredicate(Predicate Pred);
+
+  friend class CmpPredicate;
+
 public:
   // allocate space for exactly two operands
   void *operator new(size_t S) { return User::operator new(S, AllocMarker); }
diff --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h
index 605964af5d676c..4d3def15fad734 100644
--- a/llvm/include/llvm/IR/Instructions.h
+++ b/llvm/include/llvm/IR/Instructions.h
@@ -24,6 +24,7 @@
 #include "llvm/ADT/iterator.h"
 #include "llvm/ADT/iterator_range.h"
 #include "llvm/IR/CFG.h"
+#include "llvm/IR/CmpPredicate.h"
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/GEPNoWrapFlags.h"
@@ -1203,38 +1204,78 @@ class ICmpInst: public CmpInst {
 #endif
   }
 
+  /// @returns the predicate along with samesign information.
+  CmpPredicate getCmpPredicate() const {
+    return {getPredicate(), hasSameSign()};
+  }
+
+  /// @returns the inverse predicate along with samesign information: static
+  /// variant.
+  static CmpPredicate getInverseCmpPredicate(CmpPredicate Pred) {
+    return {getInversePredicate(Pred), Pred.hasSameSign()};
+  }
+
+  /// @returns the inverse predicate along with samesign information.
+  CmpPredicate getInverseCmpPredicate() const {
+    return getInverseCmpPredicate(getCmpPredicate());
+  }
+
+  /// @returns the swapped predicate along with samesign information: static
+  /// variant.
+  static CmpPredicate getSwappedCmpPredicate(CmpPredicate Pred) {
+    return {getSwappedPredicate(Pred), Pred.hasSameSign()};
+  }
+
+  /// @returns the swapped predicate along with samesign information.
+  CmpPredicate getSwappedCmpPredicate() const {
+    return getSwappedCmpPredicate(getCmpPredicate());
+  }
+
   /// For example, EQ->EQ, SLE->SLE, UGT->SGT, etc.
   /// @returns the predicate that would be the result if the operand were
   /// regarded as signed.
-  /// Return the signed version of the predicate.
-  Predicate getSignedPredicate() const {
-    return getSignedPredicate(getPredicate());
+  /// Return the signed version of the predicate along with samesign
+  /// information.
+  CmpPredicate getSignedPredicate() const {
+    return getSignedPredicate(getCmpPredicate());
   }
 
-  /// Return the signed version of the predicate: static variant.
-  static Predicate getSignedPredicate(Predicate pred);
+  /// Return the signed version of the predicate along with samesign
+  /// information: static variant.
+  static CmpPredicate getSignedPredicate(CmpPredicate Pred) {
+    return {CmpInst::getSignedPredicate(Pred), Pred.hasSameSign()};
+  }
 
   /// For example, EQ->EQ, SLE->ULE, UGT->UGT, etc.
   /// @returns the predicate that would be the result if the operand were
   /// regarded as unsigned.
-  /// Return the unsigned version of the predicate.
-  Predicate getUnsignedPredicate() const {
-    return getUnsignedPredicate(getPredicate());
+  /// Return the unsigned version of the predicate along with samesign
+  /// information.
+  CmpPredicate getUnsignedPredicate() const {
+    return getUnsignedPredicate(getCmpPredicate());
   }
 
-  /// Return the unsigned version of the predicate: static variant.
-  static Predicate getUnsignedPredicate(Predicate pred);
+  /// Return the unsigned version of the predicate along with samesign
+  /// information: static variant.
+  static CmpPredicate getUnsignedPredicate(CmpPredicate Pred) {
+    return {CmpInst::getUnsignedPredicate(Pred), Pred.hasSameSign()};
+  }
 
-  /// For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->Failed assert
+  /// For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->EQ
   /// @returns the unsigned version of the signed predicate pred or
-  ///          the signed version of the signed predicate pred.
-  static Predicate getFlippedSignednessPredicate(Predicate pred);
+  ///          the signed version of the unsigned predicate pred, along with
+  ///          samesign information.
+  /// Static variant.
+  static CmpPredicate getFlippedSignednessPredicate(CmpPredicate Pred) {
+    return {CmpInst::getFlippedSignednessPredicate(Pred), Pred.hasSameSign()};
+  }
 
-  /// For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->Failed assert
+  /// For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->EQ
   /// @returns the unsigned version of the signed predicate pred or
-  ///          the signed version of the signed predicate pred.
-  Predicate getFlippedSignednessPredicate() const {
-    return getFlippedSignednessPredicate(getPredicate());
+  ///          the signed version of the unsigned predicate pred, along with
+  ///          samesign information.
+  CmpPredicate getFlippedSignednessPredicate() const {
+    return getFlippedSignednessPredicate(getCmpPredicate());
   }
 
   void setSameSign(bool B = true) {
diff --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
index 3075b7ebae59e6..71592058e34563 100644
--- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
+++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
@@ -157,7 +157,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
   /// conditional branch or select to create a compare with a canonical
   /// (inverted) predicate which is then more likely to be matched with other
   /// values.
-  static bool isCanonicalPredicate(CmpInst::Predicate Pred) {
+  static bool isCanonicalPredicate(CmpPredicate Pred) {
     switch (Pred) {
     case CmpInst::ICMP_NE:
     case CmpInst::ICMP_ULE:
@@ -185,10 +185,9 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
   }
 
   std::optional<std::pair<
-      CmpInst::Predicate,
-      Constant *>> static getFlippedStrictnessPredicateAndConstant(CmpInst::
-                                                                       Predicate
-                                                                           Pred,
+      CmpPredicate,
+      Constant *>> static getFlippedStrictnessPredicateAndConstant(CmpPredicate
+                                                                       Pred,
                                                                    Constant *C);
 
   static bool shouldAvoidAbsorbingNotIntoSelect(const SelectInst &SI) {
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 01b0a089aab718..c5c8495db958c2 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -63,9 +63,9 @@ static Value *simplifyBinOp(unsigned, Value *, Value *, const SimplifyQuery &,
                             unsigned);
 static Value *simplifyBinOp(unsigned, Value *, Value *, const FastMathFlags &,
                             const SimplifyQuery &, unsigned);
-static Value *simplifyCmpInst(unsigned, Value *, Value *, const SimplifyQuery &,
-                              unsigned);
-static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
+static Value *simplifyCmpInst(CmpPredicate, Value *, Value *,
+                              const SimplifyQuery &, unsigned);
+static Value *simplifyICmpInst(CmpPredicate Predicate, Value *LHS, Value *RHS,
                                const SimplifyQuery &Q, unsigned MaxRecurse);
 static Value *simplifyOrInst(Value *, Value *, const SimplifyQuery &, unsigned);
 static Value *simplifyXorInst(Value *, Value *, const SimplifyQuery &,
@@ -132,8 +132,7 @@ static Constant *getFalse(Type *Ty) { return ConstantInt::getFalse(Ty); }
 static Constant *getTrue(Type *Ty) { return ConstantInt::getTrue(Ty); }
 
 /// isSameCompare - Is V equivalent to the comparison "LHS Pred RHS"?
-static bool isSameCompare(Value *V, CmpInst::Predicate Pred, Value *LHS,
-                          Value *RHS) {
+static bool isSameCompare(Value *V, CmpPredicate Pred, Value *LHS, Value *RHS) {
   CmpInst *Cmp = dyn_cast<CmpInst>(V);
   if (!Cmp)
     return false;
@@ -150,10 +149,9 @@ static bool isSameCompare(Value *V, CmpInst::Predicate Pred, Value *LHS,
 ///  %cmp = icmp sle i32 %sel, %rhs
 /// Compose new comparison by substituting %sel with either %tv or %fv
 /// and see if it simplifies.
-static Value *simplifyCmpSelCase(CmpInst::Predicate Pred, Value *LHS,
-                                 Value *RHS, Value *Cond,
-                                 const SimplifyQuery &Q, unsigned MaxRecurse,
-                                 Constant *TrueOrFalse) {
+static Value *simplifyCmpSelCase(CmpPredicate Pred, Value *LHS, Value *RHS,
+                                 Value *Cond, const SimplifyQuery &Q,
+                                 unsigned MaxRecurse, Constant *TrueOrFalse) {
   Value *SimplifiedCmp = simplifyCmpInst(Pred, LHS, RHS, Q, MaxRecurse);
   if (SimplifiedCmp == Cond) {
     // %cmp simplified to the select condition (%cond).
@@ -167,18 +165,16 @@ static Value *simplifyCmpSelCase(CmpInst::Predicate Pred, Value *LHS,
 }
 
 /// Simplify comparison with true branch of select
-static Value *simplifyCmpSelTrueCase(CmpInst::Predicate Pred, Value *LHS,
-                                     Value *RHS, Value *Cond,
-                                     const SimplifyQuery &Q,
+static Value *simplifyCmpSelTrueCase(CmpPredicate Pred, Value *LHS, Value *RHS,
+                                     Value *Cond, const SimplifyQuery &Q,
                                      unsigned MaxRecurse) {
   return simplifyCmpSelCase(Pred, LHS, RHS, Cond, Q, MaxRecurse,
                             getTrue(Cond->getType()));
 }
 
 /// Simplify comparison with false branch of select
-static Value *simplifyCmpSelFalseCase(CmpInst::Predicate Pred, Value *LHS,
-                                      Value *RHS, Value *Cond,
-                                      const SimplifyQuery &Q,
+static Value *simplifyCmpSelFalseCase(CmpPredicate Pred, Value *LHS, Value *RHS,
+                                      Value *Cond, const SimplifyQuery &Q,
                                       unsigned MaxRecurse) {
   return simplifyCmpSelCase(Pred, LHS, RHS, Cond, Q, MaxRecurse,
                             getFalse(Cond->getType()));
@@ -471,9 +467,8 @@ static Value *threadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS,
 /// We can simplify %cmp1 to true, because both branches of select are
 /// less than 3. We compose new comparison by substituting %tmp with both
 /// branches of select and see if it can be simplified.
-static Value *threadCmpOverSelect(CmpInst::Predicate Pred, Value *LHS,
-                                  Value *RHS, const SimplifyQuery &Q,
-                                  unsigned MaxRecurse) {
+static Value *threadCmpOverSelect(CmpPredicate Pred, Value *LHS, Value *RHS,
+                                  const SimplifyQuery &Q, unsigned MaxRecurse) {
   // Recursion is always used, so bail out at once if we already hit the limit.
   if (!MaxRecurse--)
     return nullptr;
@@ -564,7 +559,7 @@ static Value *threadBinOpOverPHI(Instruction::BinaryOps Opcode, Value *LHS,
 /// comparison by seeing whether comparing with all of the incoming phi values
 /// yields the same result every time. If so returns the common result,
 /// otherwise returns null.
-static Value *threadCmpOverPHI(CmpInst::Predicate Pred, Value *LHS, Value *RHS,
+static Value *threadCmpOverPHI(CmpPredicate Pred, Value *LHS, Value *RHS,
                                const SimplifyQuery &Q, unsigned MaxRecurse) {
   // Recursion is always used, so bail out at once if we already hit the limit.
   if (!MaxRecurse--)
@@ -1001,7 +996,7 @@ Value *llvm::simplifyMulInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW,
 /// Given a predicate and two operands, return true if the comparison is true.
 /// This is a helper for div/rem simplification where we return some other value
 /// when we can prove a relationship between the operands.
-static bool isICmpTrue(ICmpInst::Predicate Pred, Value *LHS, Value *RHS,
+static bool isICmpTrue(CmpPredicate Pred, Value *LHS, Value *RHS,
                        const SimplifyQuery &Q, unsigned MaxRecurse) {
   Value *V = simplifyICmpInst(Pred, LHS, RHS, Q, MaxRecurse);
   Constant *C = dyn_cast_or_null<Constant>(V);
@@ -2601,7 +2596,7 @@ static Type *getCompareTy(Value *Op) {
 /// Rummage around inside V looking for something equivalent to the comparison
 /// "LHS Pred RHS". Return such a value if found, otherwise return null.
 /// Helper function for analyzing max/min idioms.
-static Value *extractEquivalentCondition(Value *V, CmpInst::Predicate Pred,
+static Value *extractEquivalentCondition(Value *V, CmpPredicate Pred,
                                          Value *LHS, Value *RHS) {
   SelectInst *SI = dyn_cast<SelectInst>(V);
   if (!SI)
@@ -2710,8 +2705,8 @@ static bool haveNonOverlappingStorage(const Value *V1, const Value *V2) {
 // If the C and C++ standards are ever made sufficiently restrictive in this
 // area, it may be possible to update LLVM's semantics accordingly and reinstate
 // this optimization.
-static Constant *computePointerICmp(CmpInst::Predicate Pred, Value *LHS,
-                                    Value *RHS, const SimplifyQuery &Q) {
+static Constant *computePointerICmp(CmpPredicate Pred, Value *LHS, Value *RHS,
+                                    const SimplifyQuery &Q) {
   assert(LHS->getType() == RHS->getType() && "Must have same types");
   const DataLayout &DL = Q.DL;
   const TargetLibraryInfo *TLI = Q.TLI;
@@ -2859,8 +2854,8 @@ static Constant *computePointerICmp(CmpInst::Predicate Pred, Value *LHS,
 }
 
 /// Fold an icmp when its operands have i1 scalar type.
-static Value *simplifyICmpOfBools(CmpInst::Predicate Pred, Value *LHS,
-                                  Value *RHS, const SimplifyQuery &Q) {
+static Value *simplifyICmpOfBools(CmpPredicate Pred, Value *LHS, Value *RHS,
+                                  const SimplifyQuery &Q) {
   Type *ITy = getCompareTy(LHS); // The return type.
   Type *OpTy = LHS->getType();   // The operand type.
   if (!OpTy->isIntOrIntVectorTy(1))
@@ -2962,8 +2957,8 @@ static Value *simplifyICmpOfBools(CmpInst::Predicate Pred, Value *LHS,
 }
 
 /// Try hard to fold icmp with zero RHS because this is a common case.
-static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
-                                   Value *RHS, const SimplifyQuery &Q) {
+static Value *simplifyICmpWithZero(CmpPredicate Pred, Value *LHS, Value *RHS,
+                                   const SimplifyQuery &Q) {
   if (!match(RHS, m_Zero()))
     return nullptr;
 
@@ -3022,7 +3017,7 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
   return nullptr;
 }
 
-static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
+static Value *simplifyICmpWithConstant(CmpPredicate Pred, Value *LHS,
                                        Value *RHS, const InstrInfoQuery &IIQ) {
   Type *ITy = getCompareTy(RHS); // The return type.
 
@@ -3070,9 +3065,8 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
   return nullptr;
 }
 
-static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
-                                         BinaryOperator *LBO, Value *RHS,
-                                         const SimplifyQuery &Q,
+static Value *simplifyICmpWithBinOpOnLHS(CmpPredicate Pred, BinaryOperator *LBO,
+                                         Value *RHS, const SimplifyQuery &Q,
                                          unsigned MaxRecurse) {
   Type *ITy = getCompareTy(RHS); // The return type.
 
@@ -3227,8 +3221,8 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
 // *) C1 < C2 && C1 >= 0, or
 // *) C2 < C1 && C1 <= 0.
 //
-static bool trySimplifyICmpWithAdds(CmpInst::Predicate Pred, Value *LHS,
-                                    Value *RHS, const InstrInfoQuery &IIQ) {
+static bool trySimplifyICmpWithAdds(CmpPredicate Pred, Value *LHS, Value *RHS,
+                                    const InstrInfoQuery &IIQ) {
   // TODO: only support icmp slt for now.
   if (Pred != CmpInst::ICMP_SLT || !IIQ.UseInstrInfo)
     return false;
@@ -3252,8 +3246,8 @@ static bool trySimplifyICmpWithAdds(CmpInst::Predicate Pred, Value *LHS,
 /// TODO: A large part of this logic is duplicated in InstCombine's
 /// foldICmpBinOp(). We should be able to share that and avoid the code
 /// duplication.
-static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS,
-                                    Value *RHS, const SimplifyQuery &Q,
+static Value *simplifyICmpWithBinOp(CmpPredicate Pred, Value *LHS, Value *RHS,
+                                    const SimplifyQuery &Q,
                                     unsigned MaxRecurse) {
   BinaryOperator *LBO = dyn_cast<BinaryOperator>(LHS);
   BinaryOperator *RBO = dyn_cast<BinaryOperator>(RHS);
@@ -3486,8 +3480,8 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS,
 
 /// simplify integer comparisons where at least one operand of the compare
 /// matches an integer min/max idiom.
-static Value *simplifyICmpWithMinMax(CmpInst::Predicate Pred, Value *LHS,
-                                     Value *RHS, const SimplifyQuery &Q,
+static Value *simplifyICmpWithMinMax(CmpPredicate Pred, Value *LHS, Value *RHS,
+                                     const SimplifyQuery &Q,
                                      unsigned MaxRecurse) {
   Type *ITy = getCompareTy(LHS); // The return type.
   Value *A, *B;
@@ -3671,7 +3665,7 @@ static Value *simplifyICmpWithMinMax(CmpInst::Predicate Pred, Value *LHS,
   return nullptr;
 }
 
-static Value *simplifyICmpWithDominatingAssume(CmpInst::Predicate Predicate,
+static Value *simplifyICmpWithDominatingAssume(CmpPredicate Predicate,
                                                Value *LHS, Value *RHS,
                                                const SimplifyQuery &Q) {
   // Gracefully handle instructions that have not been inserted yet.
@@ -3694,8 +3688,8 @@ static Value *simplifyICmpWithDominatingAssume(CmpInst::Predicate Predicate,
   return nullptr;
 }
 
-static Value *simplifyICmpWithIntrinsicOnLHS(CmpInst::Predicate Pred,
-                                             Value *LHS, Value *RHS) {
+static Value *simplifyICmpWithIntrinsicOnLHS(CmpPredicate Pred, Value *LHS,
+                                             Value *RHS) {
   auto *II = dyn_cast<IntrinsicInst>(LHS);
   if (!II)
     return nullptr;
@@ -3757,9 +3751,8 @@ static std::optional<ConstantRange> getRange(Value *V,
 
 /// Given operands for an ICmpInst, see if we can fold the result.
 /// If not, this returns null.
-static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
+static Value *simplifyICmpInst(CmpPredicate Pred, Value *LHS, Value *RHS,
                                const SimplifyQuery &Q, unsigned MaxRecurse) {
-  CmpInst::Predicate Pred = (CmpInst::Predicate)Predicate;
   assert(CmpInst::isIntPredicate(Pred) && "Not an integer compare!");
 
   if (Constant *CLHS = dyn_cast<Constant>(LHS)) {
@@ -4066,17 +4059,16 @@ static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
   return nullptr;
 }
 
-Value *llvm::simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
+Value *llvm::simplifyICmpInst(CmpPredicate Predicate, Value *LHS, Value *RHS,
                               const SimplifyQuery &Q) {
   return ::simplifyICmpInst(Predicate, LHS, RHS, Q, RecursionLimit);
 }
 
 /// Given operands for an FCmpInst, see if we can fold the result.
 /// If not, this returns null.
-static Value *simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
+static Value *simplifyFCmpInst(CmpPredicate Pred, Value *LHS, Value *RHS,
                                FastMathFlags FMF, const SimplifyQuery &Q,
                                unsigned MaxRecurse) {
-  CmpInst::Predicate Pred = (CmpInst::Predicate)Predicate;
   assert(CmpInst::isFPPredicate(Pred) && "Not an FP compare!");
 
   if (Constant *CLHS = dyn_cast<Constant>(LHS)) {
@@ -4301,7 +4293,7 @@ static Value *simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
   return nullptr;
 }
 
-Value *llvm::simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
+Value *llvm::simplifyFCmpInst(CmpPredicate Predicate, Value *LHS, Value *RHS,
                               FastMathFlags FMF, const SimplifyQuery &Q) {
   return ::simplifyFCmpInst(Predicate, LHS, RHS, FMF, Q, RecursionLimit);
 }
@@ -4538,7 +4530,7 @@ static Value *simplifySelectBitTest(Value *TrueVal, Value *FalseVal, Value *X,
 }
 
 static Value *simplifyCmpSelOfMaxMin(Value *CmpLHS, Value *CmpRHS,
-                                     ICmpInst::Predicate Pred, Value *TVal,
+                                     CmpPredicate Pred, Value *TVal,
                                      Value *FVal) {
   // Canonicalize common cmp+sel operand as CmpLHS.
   if (CmpRHS == TVal || CmpRHS == FVal) {
@@ -4612,8 +4604,8 @@ static Value *simplifyCmpSelOfMaxMin(Value *CmpLHS, Value *CmpRHS,
 /// An alternative way to test if a bit is set or not uses sgt/slt instead of
 /// eq/ne.
 static Value *simplifySelectWithFakeICmpEq(Value *CmpLHS, Value *CmpRHS,
-                                           ICmpInst::Predicate Pred,
-                                           Value *TrueVal, Value *FalseVal) {
+                                           CmpPredicate Pred, Value *TrueVal,
+                                           Value *FalseVal) {
   if (auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred))
     return simplifySelectBitTest(TrueVal, FalseVal, Res->X, &Res->Mask,
                                  Res->Pred == ICmpInst::ICMP_EQ);
@@ -6123,14 +6115,14 @@ Value *llvm::simplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS,
 }
 
 /// Given operands for a CmpInst, see if we can fold the result.
-static Value *simplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
+static Value *simplifyCmpInst(CmpPredicate Predicate, Value *LHS, Value *RHS,
                               const SimplifyQuery &Q, unsigned MaxRecurse) {
-  if (CmpInst::isIntPredicate((CmpInst::Predicate)Predicate))
+  if (CmpInst::isIntPredicate(Predicate))
     return simplifyICmpInst(Predicate, LHS, RHS, Q, MaxRecurse);
   return simplifyFCmpInst(Predicate, LHS, RHS, FastMathFlags(), Q, MaxRecurse);
 }
 
-Value *llvm::simplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
+Value *llvm::simplifyCmpInst(CmpPredicate Predicate, Value *LHS, Value *RHS,
                              const SimplifyQuery &Q) {
   return ::simplifyCmpInst(Predicate, LHS, RHS, Q, RecursionLimit);
 }
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index c48068afc04816..490f4dc8056941 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -9104,7 +9104,7 @@ bool llvm::matchSimpleRecurrence(const BinaryOperator *I, PHINode *&P,
 }
 
 /// Return true if "icmp Pred LHS RHS" is always true.
-static bool isTruePredicate(CmpInst::Predicate Pred, const Value *LHS,
+static bool isTruePredicate(CmpPredicate Pred, const Value *LHS,
                             const Value *RHS) {
   if (ICmpInst::isTrueWhenEqual(Pred) && LHS == RHS)
     return true;
@@ -9186,8 +9186,8 @@ static bool isTruePredicate(CmpInst::Predicate Pred, const Value *LHS,
 /// Return true if "icmp Pred BLHS BRHS" is true whenever "icmp Pred
 /// ALHS ARHS" is true.  Otherwise, return std::nullopt.
 static std::optional<bool>
-isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS,
-                      const Value *ARHS, const Value *BLHS, const Value *BRHS) {
+isImpliedCondOperands(CmpPredicate Pred, const Value *ALHS, const Value *ARHS,
+                      const Value *BLHS, const Value *BRHS) {
   switch (Pred) {
   default:
     return std::nullopt;
@@ -9256,18 +9256,16 @@ static std::optional<bool> isImpliedCondCommonOperandWithCR(
 /// Return true if LHS implies RHS (expanded to its components as "R0 RPred R1")
 /// is true.  Return false if LHS implies RHS is false. Otherwise, return
 /// std::nullopt if we can't infer anything.
-static std::optional<bool> isImpliedCondICmps(const ICmpInst *LHS,
-                                              CmpInst::Predicate RPred,
-                                              const Value *R0, const Value *R1,
-                                              const DataLayout &DL,
-                                              bool LHSIsTrue) {
+static std::optional<bool>
+isImpliedCondICmps(const ICmpInst *LHS, CmpPredicate RPred, const Value *R0,
+                   const Value *R1, const DataLayout &DL, bool LHSIsTrue) {
   Value *L0 = LHS->getOperand(0);
   Value *L1 = LHS->getOperand(1);
 
   // The rest of the logic assumes the LHS condition is true.  If that's not the
   // case, invert the predicate to make it so.
-  CmpInst::Predicate LPred =
-      LHSIsTrue ? LHS->getPredicate() : LHS->getInversePredicate();
+  CmpPredicate LPred =
+      LHSIsTrue ? LHS->getCmpPredicate() : LHS->getInverseCmpPredicate();
 
   // We can have non-canonical operands, so try to normalize any common operand
   // to L0/R0.
@@ -9342,10 +9340,10 @@ static std::optional<bool> isImpliedCondICmps(const ICmpInst *LHS,
       (LPred == ICmpInst::ICMP_ULT || LPred == ICmpInst::ICMP_UGE) &&
       (RPred == ICmpInst::ICMP_ULT || RPred == ICmpInst::ICMP_UGE) &&
       match(L0, m_c_Add(m_Specific(L1), m_Specific(R1))))
-    return LPred == RPred;
+    return CmpPredicate::getMatching(LPred, RPred).has_value();
 
-  if (LPred == RPred)
-    return isImpliedCondOperands(LPred, L0, L1, R0, R1);
+  if (auto P = CmpPredicate::getMatching(LPred, RPred))
+    return isImpliedCondOperands(*P, L0, L1, R0, R1);
 
   return std::nullopt;
 }
@@ -9355,7 +9353,7 @@ static std::optional<bool> isImpliedCondICmps(const ICmpInst *LHS,
 /// expect the RHS to be an icmp and the LHS to be an 'and', 'or', or a 'select'
 /// instruction.
 static std::optional<bool>
-isImpliedCondAndOr(const Instruction *LHS, CmpInst::Predicate RHSPred,
+isImpliedCondAndOr(const Instruction *LHS, CmpPredicate RHSPred,
                    const Value *RHSOp0, const Value *RHSOp1,
                    const DataLayout &DL, bool LHSIsTrue, unsigned Depth) {
   // The LHS must be an 'or', 'and', or a 'select' instruction.
@@ -9385,7 +9383,7 @@ isImpliedCondAndOr(const Instruction *LHS, CmpInst::Predicate RHSPred,
 }
 
 std::optional<bool>
-llvm::isImpliedCondition(const Value *LHS, CmpInst::Predicate RHSPred,
+llvm::isImpliedCondition(const Value *LHS, CmpPredicate RHSPred,
                          const Value *RHSOp0, const Value *RHSOp1,
                          const DataLayout &DL, bool LHSIsTrue, unsigned Depth) {
   // Bail out when we hit the limit.
@@ -9439,7 +9437,7 @@ std::optional<bool> llvm::isImpliedCondition(const Value *LHS, const Value *RHS,
 
   if (const ICmpInst *RHSCmp = dyn_cast<ICmpInst>(RHS)) {
     if (auto Implied = isImpliedCondition(
-            LHS, RHSCmp->getPredicate(), RHSCmp->getOperand(0),
+            LHS, RHSCmp->getCmpPredicate(), RHSCmp->getOperand(0),
             RHSCmp->getOperand(1), DL, LHSIsTrue, Depth))
       return InvertRHS ? !*Implied : *Implied;
     return std::nullopt;
@@ -9516,7 +9514,7 @@ std::optional<bool> llvm::isImpliedByDomCondition(const Value *Cond,
   return std::nullopt;
 }
 
-std::optional<bool> llvm::isImpliedByDomCondition(CmpInst::Predicate Pred,
+std::optional<bool> llvm::isImpliedByDomCondition(CmpPredicate Pred,
                                                   const Value *LHS,
                                                   const Value *RHS,
                                                   const Instruction *ContextI,
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 065ce3a0172837..3c529203e32600 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -3581,7 +3581,7 @@ raw_ostream &llvm::operator<<(raw_ostream &OS, CmpInst::Predicate Pred) {
   return OS;
 }
 
-ICmpInst::Predicate ICmpInst::getSignedPredicate(Predicate pred) {
+ICmpInst::Predicate CmpInst::getSignedPredicate(Predicate pred) {
   switch (pred) {
     default: llvm_unreachable("Unknown icmp predicate!");
     case ICMP_EQ: case ICMP_NE:
@@ -3594,7 +3594,7 @@ ICmpInst::Predicate ICmpInst::getSignedPredicate(Predicate pred) {
   }
 }
 
-ICmpInst::Predicate ICmpInst::getUnsignedPredicate(Predicate pred) {
+ICmpInst::Predicate CmpInst::getUnsignedPredicate(Predicate pred) {
   switch (pred) {
     default: llvm_unreachable("Unknown icmp predicate!");
     case ICMP_EQ: case ICMP_NE:
@@ -3841,10 +3841,9 @@ std::optional<bool> ICmpInst::compare(const KnownBits &LHS,
   }
 }
 
-CmpInst::Predicate ICmpInst::getFlippedSignednessPredicate(Predicate pred) {
-  assert(CmpInst::isRelational(pred) &&
-         "Call only with non-equality predicates!");
-
+CmpInst::Predicate CmpInst::getFlippedSignednessPredicate(Predicate pred) {
+  if (CmpInst::isEquality(pred))
+    return pred;
   if (isSigned(pred))
     return getUnsignedPredicate(pred);
   if (isUnsigned(pred))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index b4033fc2a418a1..24a969f221e5bb 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -59,7 +59,7 @@ Value *InstCombinerImpl::insertRangeTest(Value *V, const APInt &Lo,
 
   // V >= Min && V <  Hi --> V <  Hi
   // V <  Min || V >= Hi --> V >= Hi
-  ICmpInst::Predicate Pred = Inside ? ICmpInst::ICMP_ULT : ICmpInst::ICMP_UGE;
+  CmpPredicate Pred = Inside ? ICmpInst::ICMP_ULT : ICmpInst::ICMP_UGE;
   if (isSigned ? Lo.isMinSignedValue() : Lo.isMinValue()) {
     Pred = isSigned ? ICmpInst::getSignedPredicate(Pred) : Pred;
     return Builder.CreateICmp(Pred, V, ConstantInt::get(Ty, Hi));
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index fed21db393ed22..419591ac56b207 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -631,7 +631,7 @@ static Value *rewriteGEPAsOffset(Value *Start, Value *Base, GEPNoWrapFlags NW,
 /// We can look through PHIs, GEPs and casts in order to determine a common base
 /// between GEPLHS and RHS.
 static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS,
-                                              ICmpInst::Predicate Cond,
+                                              CmpPredicate Cond,
                                               const DataLayout &DL,
                                               InstCombiner &IC) {
   // FIXME: Support vector of pointers.
@@ -675,8 +675,7 @@ static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS,
 /// Fold comparisons between a GEP instruction and something else. At this point
 /// we know that the GEP is on the LHS of the comparison.
 Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
-                                           ICmpInst::Predicate Cond,
-                                           Instruction &I) {
+                                           CmpPredicate Cond, Instruction &I) {
   // Don't transform signed compares of GEPs into index compares. Even if the
   // GEP is inbounds, the final add of the base pointer can have signed overflow
   // and would change the result of the icmp.
@@ -912,7 +911,7 @@ bool InstCombinerImpl::foldAllocaCmp(AllocaInst *Alloca) {
 
 /// Fold "icmp pred (X+C), X".
 Instruction *InstCombinerImpl::foldICmpAddOpConst(Value *X, const APInt &C,
-                                                  ICmpInst::Predicate Pred) {
+                                                  CmpPredicate Pred) {
   // From this point on, we know that (X+C <= X) --> (X+C < X) because C != 0,
   // so the values can never be equal.  Similarly for all other "or equals"
   // operators.
@@ -3949,8 +3948,8 @@ Instruction *InstCombinerImpl::foldICmpBinOpWithConstant(ICmpInst &Cmp,
 }
 
 static Instruction *
-foldICmpUSubSatOrUAddSatWithConstant(ICmpInst::Predicate Pred,
-                                     SaturatingInst *II, const APInt &C,
+foldICmpUSubSatOrUAddSatWithConstant(CmpPredicate Pred, SaturatingInst *II,
+                                     const APInt &C,
                                      InstCombiner::BuilderTy &Builder) {
   // This transform may end up producing more than one instruction for the
   // intrinsic, so limit it to one user of the intrinsic.
@@ -4034,7 +4033,7 @@ foldICmpUSubSatOrUAddSatWithConstant(ICmpInst::Predicate Pred,
 }
 
 static Instruction *
-foldICmpOfCmpIntrinsicWithConstant(ICmpInst::Predicate Pred, IntrinsicInst *I,
+foldICmpOfCmpIntrinsicWithConstant(CmpPredicate Pred, IntrinsicInst *I,
                                    const APInt &C,
                                    InstCombiner::BuilderTy &Builder) {
   std::optional<ICmpInst::Predicate> NewPredicate = std::nullopt;
@@ -4233,9 +4232,8 @@ Instruction *InstCombinerImpl::foldICmpInstWithConstantNotInt(ICmpInst &I) {
   return nullptr;
 }
 
-Instruction *InstCombinerImpl::foldSelectICmp(ICmpInst::Predicate Pred,
-                                              SelectInst *SI, Value *RHS,
-                                              const ICmpInst &I) {
+Instruction *InstCombinerImpl::foldSelectICmp(CmpPredicate Pred, SelectInst *SI,
+                                              Value *RHS, const ICmpInst &I) {
   // Try to fold the comparison into the select arms, which will cause the
   // select to be converted into a logical and/or.
   auto SimplifyOp = [&](Value *Op, bool SelectCondIsTrue) -> Value * {
@@ -4404,7 +4402,7 @@ static bool isMaskOrZero(const Value *V, bool Not, const SimplifyQuery &Q,
 /// The Mask can be a constant, too.
 /// For some predicates, the operands are commutative.
 /// For others, x can only be on a specific side.
-static Value *foldICmpWithLowBitMaskedVal(ICmpInst::Predicate Pred, Value *Op0,
+static Value *foldICmpWithLowBitMaskedVal(CmpPredicate Pred, Value *Op0,
                                           Value *Op1, const SimplifyQuery &Q,
                                           InstCombiner &IC) {
 
@@ -5515,8 +5513,7 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
 /// Fold icmp Pred min|max(X, Y), Z.
 Instruction *InstCombinerImpl::foldICmpWithMinMax(Instruction &I,
                                                   MinMaxIntrinsic *MinMax,
-                                                  Value *Z,
-                                                  ICmpInst::Predicate Pred) {
+                                                  Value *Z, CmpPredicate Pred) {
   Value *X = MinMax->getLHS();
   Value *Y = MinMax->getRHS();
   if (ICmpInst::isSigned(Pred) && !MinMax->isSigned())
@@ -6869,8 +6866,8 @@ Instruction *InstCombinerImpl::foldICmpUsingBoolRange(ICmpInst &I) {
   return nullptr;
 }
 
-std::optional<std::pair<CmpInst::Predicate, Constant *>>
-InstCombiner::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred,
+std::optional<std::pair<CmpPredicate, Constant *>>
+InstCombiner::getFlippedStrictnessPredicateAndConstant(CmpPredicate Pred,
                                                        Constant *C) {
   assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) &&
          "Only for relational integer predicates.");
@@ -7276,7 +7273,7 @@ static Instruction *foldReductionIdiom(ICmpInst &I,
 }
 
 // This helper will be called with icmp operands in both orders.
-Instruction *InstCombinerImpl::foldICmpCommutative(ICmpInst::Predicate Pred,
+Instruction *InstCombinerImpl::foldICmpCommutative(CmpPredicate Pred,
                                                    Value *Op0, Value *Op1,
                                                    ICmpInst &CxtI) {
   // Try to optimize 'icmp GEP, P' or 'icmp P, GEP'.
@@ -7404,7 +7401,7 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
     Changed = true;
   }
 
-  if (Value *V = simplifyICmpInst(I.getPredicate(), Op0, Op1, Q))
+  if (Value *V = simplifyICmpInst(I.getCmpPredicate(), Op0, Op1, Q))
     return replaceInstUsesWith(I, V);
 
   // Comparing -val or val with non-zero is the same as just comparing val
@@ -7511,10 +7508,10 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
   if (Instruction *Res = foldICmpInstWithConstantNotInt(I))
     return Res;
 
-  if (Instruction *Res = foldICmpCommutative(I.getPredicate(), Op0, Op1, I))
+  if (Instruction *Res = foldICmpCommutative(I.getCmpPredicate(), Op0, Op1, I))
     return Res;
   if (Instruction *Res =
-          foldICmpCommutative(I.getSwappedPredicate(), Op1, Op0, I))
+          foldICmpCommutative(I.getSwappedCmpPredicate(), Op1, Op0, I))
     return Res;
 
   if (I.isCommutative()) {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 0508ed48fc19c4..28474fec8238ee 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -652,10 +652,10 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   /// folded operation.
   void PHIArgMergedDebugLoc(Instruction *Inst, PHINode &PN);
 
-  Instruction *foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
-                           ICmpInst::Predicate Cond, Instruction &I);
-  Instruction *foldSelectICmp(ICmpInst::Predicate Pred, SelectInst *SI,
-                              Value *RHS, const ICmpInst &I);
+  Instruction *foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, CmpPredicate Cond,
+                           Instruction &I);
+  Instruction *foldSelectICmp(CmpPredicate Pred, SelectInst *SI, Value *RHS,
+                              const ICmpInst &I);
   bool foldAllocaCmp(AllocaInst *Alloca);
   Instruction *foldCmpLoadFromIndexedGlobal(LoadInst *LI,
                                             GetElementPtrInst *GEP,
@@ -663,8 +663,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
                                             ConstantInt *AndCst = nullptr);
   Instruction *foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI,
                                     Constant *RHSC);
-  Instruction *foldICmpAddOpConst(Value *X, const APInt &C,
-                                  ICmpInst::Predicate Pred);
+  Instruction *foldICmpAddOpConst(Value *X, const APInt &C, CmpPredicate Pred);
   Instruction *foldICmpWithCastOp(ICmpInst &ICmp);
   Instruction *foldICmpWithZextOrSext(ICmpInst &ICmp);
 
@@ -678,7 +677,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
                                                    const APInt &C);
   Instruction *foldICmpBinOp(ICmpInst &Cmp, const SimplifyQuery &SQ);
   Instruction *foldICmpWithMinMax(Instruction &I, MinMaxIntrinsic *MinMax,
-                                  Value *Z, ICmpInst::Predicate Pred);
+                                  Value *Z, CmpPredicate Pred);
   Instruction *foldICmpEquality(ICmpInst &Cmp);
   Instruction *foldIRemByPowerOfTwoToBitTest(ICmpInst &I);
   Instruction *foldSignBitTest(ICmpInst &I);
@@ -736,8 +735,8 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
                                                const APInt &C);
   Instruction *foldICmpBitCast(ICmpInst &Cmp);
   Instruction *foldICmpWithTrunc(ICmpInst &Cmp);
-  Instruction *foldICmpCommutative(ICmpInst::Predicate Pred, Value *Op0,
-                                   Value *Op1, ICmpInst &CxtI);
+  Instruction *foldICmpCommutative(CmpPredicate Pred, Value *Op0, Value *Op1,
+                                   ICmpInst &CxtI);
 
   // Helpers of visitSelectInst().
   Instruction *foldSelectOfBools(SelectInst &SI);
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 32f2a30afad48f..3325a1868ebde4 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -1753,9 +1753,9 @@ static Value *simplifyInstructionWithPHI(Instruction &I, PHINode *PN,
   if (TerminatorBI && TerminatorBI->isConditional() &&
       TerminatorBI->getSuccessor(0) != TerminatorBI->getSuccessor(1) && ICmp) {
     bool LHSIsTrue = TerminatorBI->getSuccessor(0) == PN->getParent();
-    std::optional<bool> ImpliedCond =
-        isImpliedCondition(TerminatorBI->getCondition(), ICmp->getPredicate(),
-                           Ops[0], Ops[1], DL, LHSIsTrue);
+    std::optional<bool> ImpliedCond = isImpliedCondition(
+        TerminatorBI->getCondition(), ICmp->getCmpPredicate(), Ops[0], Ops[1],
+        DL, LHSIsTrue);
     if (ImpliedCond)
       return ConstantInt::getBool(I.getType(), ImpliedCond.value());
   }
diff --git a/llvm/unittests/IR/InstructionsTest.cpp b/llvm/unittests/IR/InstructionsTest.cpp
index 0af812564c0267..b4dbc4ed435aad 100644
--- a/llvm/unittests/IR/InstructionsTest.cpp
+++ b/llvm/unittests/IR/InstructionsTest.cpp
@@ -1923,5 +1923,27 @@ TEST(InstructionsTest, AtomicSyncscope) {
   EXPECT_TRUE(LLVMIsAtomicSingleThread(CmpXchg));
 }
 
+TEST(InstructionsTest, CmpPredicate) {
+  CmpPredicate P0(CmpInst::ICMP_ULE, false), P1(CmpInst::ICMP_ULE, true),
+      P2(CmpInst::ICMP_SLE, false), P3(CmpInst::ICMP_SLT, false);
+  CmpPredicate Q0 = P0, Q1 = P1, Q2 = P2;
+  CmpInst::Predicate R0 = P0, R1 = P1, R2 = P2;
+
+  EXPECT_EQ(*CmpPredicate::getMatching(P0, P1), CmpInst::ICMP_ULE);
+  EXPECT_EQ(CmpPredicate::getMatching(P0, P1)->hasSameSign(), false);
+  EXPECT_EQ(*CmpPredicate::getMatching(P1, P1), CmpInst::ICMP_ULE);
+  EXPECT_EQ(CmpPredicate::getMatching(P1, P1)->hasSameSign(), true);
+  EXPECT_EQ(CmpPredicate::getMatching(P0, P2), std::nullopt);
+  EXPECT_EQ(*CmpPredicate::getMatching(P1, P2), CmpInst::ICMP_SLE);
+  EXPECT_EQ(CmpPredicate::getMatching(P1, P2)->hasSameSign(), false);
+  EXPECT_EQ(CmpPredicate::getMatching(P1, P3), std::nullopt);
+  EXPECT_FALSE(Q0.hasSameSign());
+  EXPECT_TRUE(Q1.hasSameSign());
+  EXPECT_FALSE(Q2.hasSameSign());
+  EXPECT_EQ(P0, R0);
+  EXPECT_EQ(P1, R1);
+  EXPECT_EQ(P2, R2);
+}
+
 } // end anonymous namespace
 } // end namespace llvm
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 874c32c2d4398f..db06e2fc42366f 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -5838,8 +5838,17 @@ define void @foo(i32 %i0, i32 %i1) {
     checkCommonPredicates(ICmp, LLVMICmp);
     EXPECT_EQ(ICmp->isSigned(), LLVMICmp->isSigned());
     EXPECT_EQ(ICmp->isUnsigned(), LLVMICmp->isUnsigned());
-    EXPECT_EQ(ICmp->getSignedPredicate(), LLVMICmp->getSignedPredicate());
-    EXPECT_EQ(ICmp->getUnsignedPredicate(), LLVMICmp->getUnsignedPredicate());
+    EXPECT_EQ(
+        static_cast<llvm::CmpInst::Predicate>(ICmp->getSignedPredicate()),
+        static_cast<llvm::CmpInst::Predicate>(LLVMICmp->getSignedPredicate()));
+    EXPECT_EQ(ICmp->getSignedPredicate().hasSameSign(),
+              LLVMICmp->getSignedPredicate().hasSameSign());
+    EXPECT_EQ(
+        static_cast<llvm::CmpInst::Predicate>(ICmp->getUnsignedPredicate()),
+        static_cast<llvm::CmpInst::Predicate>(
+            LLVMICmp->getUnsignedPredicate()));
+    EXPECT_EQ(ICmp->getUnsignedPredicate().hasSameSign(),
+              LLVMICmp->getUnsignedPredicate().hasSameSign());
   }
   auto *NewCmp =
       sandboxir::CmpInst::create(llvm::CmpInst::ICMP_ULE, F.getArg(0),

>From d017cfc4f1405cbc29a78d4b175da8313a49ce57 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Fri, 29 Nov 2024 14:14:37 +0000
Subject: [PATCH 2/3] CmpPredicate: address review

---
 llvm/include/llvm/IR/CmpPredicate.h           | 26 ++++-------
 llvm/include/llvm/IR/InstrTypes.h             | 18 --------
 llvm/include/llvm/IR/Instructions.h           | 46 +++++++------------
 llvm/lib/Analysis/ValueTracking.cpp           | 22 +++++----
 llvm/lib/IR/Instructions.cpp                  | 22 +++++++--
 .../InstCombine/InstCombineAndOrXor.cpp       |  2 +-
 llvm/unittests/SandboxIR/SandboxIRTest.cpp    | 13 +-----
 7 files changed, 59 insertions(+), 90 deletions(-)

diff --git a/llvm/include/llvm/IR/CmpPredicate.h b/llvm/include/llvm/IR/CmpPredicate.h
index 99754fdf2b7686..c234269f0084b2 100644
--- a/llvm/include/llvm/IR/CmpPredicate.h
+++ b/llvm/include/llvm/IR/CmpPredicate.h
@@ -17,9 +17,9 @@
 
 namespace llvm {
 /// An abstraction over a floating-point predicate, and a pack of an integer
-/// predicate with samesign information. Functions in ICmpInst construct and
-/// return this type in place of a Predicate. It is also implictly constructed
-/// with a Predicate, dropping samesign information.
+/// predicate with samesign information. Some functions in ICmpInst construct
+/// and return this type in place of a Predicate. It is also implicitly
+/// constructed with a Predicate, dropping samesign information.
 class CmpPredicate {
   CmpInst::Predicate Pred;
   bool HasSameSign;
@@ -30,26 +30,16 @@ class CmpPredicate {
     assert(!HasSameSign || CmpInst::isIntPredicate(Pred));
   }
 
-  inline operator CmpInst::Predicate() const { return Pred; }
+  operator CmpInst::Predicate() const { return Pred; }
 
-  inline bool hasSameSign() const { return HasSameSign; }
+  bool hasSameSign() const { return HasSameSign; }
 
   static std::optional<CmpPredicate> getMatching(CmpPredicate A,
-                                                 CmpPredicate B) {
-    if (A.Pred == B.Pred)
-      return A.HasSameSign == B.HasSameSign ? A : CmpPredicate(A.Pred);
-    if (A.HasSameSign &&
-        A.Pred == CmpInst::getFlippedSignednessPredicate(B.Pred))
-      return B.Pred;
-    if (B.HasSameSign &&
-        B.Pred == CmpInst::getFlippedSignednessPredicate(A.Pred))
-      return A.Pred;
-    return {};
-  }
+                                                 CmpPredicate B);
 
-  inline bool operator==(CmpInst::Predicate P) const { return Pred == P; }
+  bool operator==(CmpInst::Predicate P) const { return Pred == P; }
 
-  inline bool operator==(CmpPredicate) const = delete;
+  bool operator==(CmpPredicate) const = delete;
 };
 } // namespace llvm
 
diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index dd31f2e3d0a747..e6332a16df7d5f 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -728,24 +728,6 @@ class CmpInst : public Instruction {
           InsertPosition InsertBefore = nullptr,
           Instruction *FlagsSource = nullptr);
 
-  /// Return the signed version of the predicate: variant that operates on
-  /// Predicate; used by the corresponding function in ICmpInst, to operate with
-  /// CmpPredicate.
-  static Predicate getSignedPredicate(Predicate Pred);
-
-  /// Return the unsigned version of the predicate: variant that operates on
-  /// Predicate; used by the corresponding function in ICmpInst, to operate with
-  /// CmpPredicate.
-  static Predicate getUnsignedPredicate(Predicate Pred);
-
-  /// Return the unsigned version of the signed predicate pred or the signed
-  /// version of the signed predicate pred: variant that operates on Predicate;
-  /// used by the corresponding function in ICmpInst, to operate with
-  /// CmpPredicate.
-  static Predicate getFlippedSignednessPredicate(Predicate Pred);
-
-  friend class CmpPredicate;
-
 public:
   // allocate space for exactly two operands
   void *operator new(size_t S) { return User::operator new(S, AllocMarker); }
diff --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h
index 4d3def15fad734..a42bf6bca1b9fb 100644
--- a/llvm/include/llvm/IR/Instructions.h
+++ b/llvm/include/llvm/IR/Instructions.h
@@ -1226,56 +1226,44 @@ class ICmpInst: public CmpInst {
     return {getSwappedPredicate(Pred), Pred.hasSameSign()};
   }
 
-  /// @returns the swapped predicate along with samesign information.
-  CmpPredicate getSwappedCmpPredicate() const {
+  /// @returns the swapped predicate.
+  Predicate getSwappedCmpPredicate() const {
     return getSwappedPredicate(getCmpPredicate());
   }
 
   /// For example, EQ->EQ, SLE->SLE, UGT->SGT, etc.
   /// @returns the predicate that would be the result if the operand were
   /// regarded as signed.
-  /// Return the signed version of the predicate along with samesign
-  /// information.
-  CmpPredicate getSignedPredicate() const {
-    return getSignedPredicate(getCmpPredicate());
+  /// Return the signed version of the predicate.
+  Predicate getSignedPredicate() const {
+    return getSignedPredicate(getPredicate());
   }
 
-  /// Return the signed version of the predicate along with samesign
-  /// information: static variant.
-  static CmpPredicate getSignedPredicate(CmpPredicate Pred) {
-    return {CmpInst::getSignedPredicate(Pred), Pred.hasSameSign()};
-  }
+  /// Return the signed version of the predicate: static variant.
+  static Predicate getSignedPredicate(Predicate Pred);
 
   /// For example, EQ->EQ, SLE->ULE, UGT->UGT, etc.
   /// @returns the predicate that would be the result if the operand were
   /// regarded as unsigned.
-  /// Return the unsigned version of the predicate along with samesign
-  /// information.
-  CmpPredicate getUnsignedPredicate() const {
-    return getUnsignedPredicate(getCmpPredicate());
+  /// Return the unsigned version of the predicate.
+  Predicate getUnsignedPredicate() const {
+    return getUnsignedPredicate(getPredicate());
   }
 
-  /// Return the unsigned version of the predicate along with samesign
-  /// information: static variant.
-  static CmpPredicate getUnsignedPredicate(CmpPredicate Pred) {
-    return {CmpInst::getUnsignedPredicate(Pred), Pred.hasSameSign()};
-  }
+  /// Return the unsigned version of the predicate: static variant.
+  static Predicate getUnsignedPredicate(Predicate Pred);
 
   /// For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->EQ
   /// @returns the unsigned version of the signed predicate pred or
-  ///          the signed version of the signed predicate pred, along with
-  ///          samesign information.
+  ///          the signed version of the unsigned predicate pred.
   /// Static variant.
-  static CmpPredicate getFlippedSignednessPredicate(CmpPredicate Pred) {
-    return {CmpInst::getFlippedSignednessPredicate(Pred), Pred.hasSameSign()};
-  }
+  static Predicate getFlippedSignednessPredicate(Predicate Pred);
 
   /// For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->EQ
   /// @returns the unsigned version of the signed predicate pred or
-  ///          the signed version of the signed predicate pred, along with
-  ///          samesign information.
-  CmpPredicate getFlippedSignednessPredicate() const {
-    return getFlippedSignednessPredicate(getCmpPredicate());
+  ///          the signed version of the unsigned predicate pred.
+  Predicate getFlippedSignednessPredicate() const {
+    return getFlippedSignednessPredicate(getPredicate());
   }
 
   void setSameSign(bool B = true) {
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 490f4dc8056941..54b5305d254e48 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -9104,7 +9104,7 @@ bool llvm::matchSimpleRecurrence(const BinaryOperator *I, PHINode *&P,
 }
 
 /// Return true if "icmp Pred LHS RHS" is always true.
-static bool isTruePredicate(CmpPredicate Pred, const Value *LHS,
+static bool isTruePredicate(CmpInst::Predicate Pred, const Value *LHS,
                             const Value *RHS) {
   if (ICmpInst::isTrueWhenEqual(Pred) && LHS == RHS)
     return true;
@@ -9186,8 +9186,8 @@ static bool isTruePredicate(CmpPredicate Pred, const Value *LHS,
 /// Return true if "icmp Pred BLHS BRHS" is true whenever "icmp Pred
 /// ALHS ARHS" is true.  Otherwise, return std::nullopt.
 static std::optional<bool>
-isImpliedCondOperands(CmpPredicate Pred, const Value *ALHS, const Value *ARHS,
-                      const Value *BLHS, const Value *BRHS) {
+isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS,
+                      const Value *ARHS, const Value *BLHS, const Value *BRHS) {
   switch (Pred) {
   default:
     return std::nullopt;
@@ -9256,16 +9256,18 @@ static std::optional<bool> isImpliedCondCommonOperandWithCR(
 /// Return true if LHS implies RHS (expanded to its components as "R0 RPred R1")
 /// is true.  Return false if LHS implies RHS is false. Otherwise, return
 /// std::nullopt if we can't infer anything.
-static std::optional<bool>
-isImpliedCondICmps(const ICmpInst *LHS, CmpPredicate RPred, const Value *R0,
-                   const Value *R1, const DataLayout &DL, bool LHSIsTrue) {
+static std::optional<bool> isImpliedCondICmps(const ICmpInst *LHS,
+                                              CmpInst::Predicate RPred,
+                                              const Value *R0, const Value *R1,
+                                              const DataLayout &DL,
+                                              bool LHSIsTrue) {
   Value *L0 = LHS->getOperand(0);
   Value *L1 = LHS->getOperand(1);
 
   // The rest of the logic assumes the LHS condition is true.  If that's not the
   // case, invert the predicate to make it so.
-  CmpPredicate LPred =
-      LHSIsTrue ? LHS->getCmpPredicate() : LHS->getInverseCmpPredicate();
+  CmpInst::Predicate LPred =
+      LHSIsTrue ? LHS->getPredicate() : LHS->getInversePredicate();
 
   // We can have non-canonical operands, so try to normalize any common operand
   // to L0/R0.
@@ -9342,8 +9344,8 @@ isImpliedCondICmps(const ICmpInst *LHS, CmpPredicate RPred, const Value *R0,
       match(L0, m_c_Add(m_Specific(L1), m_Specific(R1))))
     return CmpPredicate::getMatching(LPred, RPred).has_value();
 
-  if (auto P = CmpPredicate::getMatching(LPred, RPred))
-    return isImpliedCondOperands(*P, L0, L1, R0, R1);
+  if (LPred == RPred)
+    return isImpliedCondOperands(LPred, L0, L1, R0, R1);
 
   return std::nullopt;
 }
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 3c529203e32600..a5810e0adb1f7e 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -3581,7 +3581,7 @@ raw_ostream &llvm::operator<<(raw_ostream &OS, CmpInst::Predicate Pred) {
   return OS;
 }
 
-ICmpInst::Predicate CmpInst::getSignedPredicate(Predicate pred) {
+ICmpInst::Predicate ICmpInst::getSignedPredicate(Predicate pred) {
   switch (pred) {
     default: llvm_unreachable("Unknown icmp predicate!");
     case ICMP_EQ: case ICMP_NE:
@@ -3594,7 +3594,7 @@ ICmpInst::Predicate CmpInst::getSignedPredicate(Predicate pred) {
   }
 }
 
-ICmpInst::Predicate CmpInst::getUnsignedPredicate(Predicate pred) {
+ICmpInst::Predicate ICmpInst::getUnsignedPredicate(Predicate pred) {
   switch (pred) {
     default: llvm_unreachable("Unknown icmp predicate!");
     case ICMP_EQ: case ICMP_NE:
@@ -3841,7 +3841,7 @@ std::optional<bool> ICmpInst::compare(const KnownBits &LHS,
   }
 }
 
-CmpInst::Predicate CmpInst::getFlippedSignednessPredicate(Predicate pred) {
+CmpInst::Predicate ICmpInst::getFlippedSignednessPredicate(Predicate pred) {
   if (CmpInst::isEquality(pred))
     return pred;
   if (isSigned(pred))
@@ -3915,6 +3915,22 @@ bool CmpInst::isImpliedFalseByMatchingCmp(Predicate Pred1, Predicate Pred2) {
   return isImpliedTrueByMatchingCmp(Pred1, getInversePredicate(Pred2));
 }
 
+//===----------------------------------------------------------------------===//
+//                       CmpPredicate Implementation
+//===----------------------------------------------------------------------===//
+std::optional<CmpPredicate> CmpPredicate::getMatching(CmpPredicate A,
+                                                      CmpPredicate B) {
+  if (A.Pred == B.Pred)
+    return A.HasSameSign == B.HasSameSign ? A : CmpPredicate(A.Pred);
+  if (A.HasSameSign &&
+      A.Pred == ICmpInst::getFlippedSignednessPredicate(B.Pred))
+    return B.Pred;
+  if (B.HasSameSign &&
+      B.Pred == ICmpInst::getFlippedSignednessPredicate(A.Pred))
+    return A.Pred;
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 //                        SwitchInst Implementation
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 24a969f221e5bb..b4033fc2a418a1 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -59,7 +59,7 @@ Value *InstCombinerImpl::insertRangeTest(Value *V, const APInt &Lo,
 
   // V >= Min && V <  Hi --> V <  Hi
   // V <  Min || V >= Hi --> V >= Hi
-  CmpPredicate Pred = Inside ? ICmpInst::ICMP_ULT : ICmpInst::ICMP_UGE;
+  ICmpInst::Predicate Pred = Inside ? ICmpInst::ICMP_ULT : ICmpInst::ICMP_UGE;
   if (isSigned ? Lo.isMinSignedValue() : Lo.isMinValue()) {
     Pred = isSigned ? ICmpInst::getSignedPredicate(Pred) : Pred;
     return Builder.CreateICmp(Pred, V, ConstantInt::get(Ty, Hi));
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index db06e2fc42366f..874c32c2d4398f 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -5838,17 +5838,8 @@ define void @foo(i32 %i0, i32 %i1) {
     checkCommonPredicates(ICmp, LLVMICmp);
     EXPECT_EQ(ICmp->isSigned(), LLVMICmp->isSigned());
     EXPECT_EQ(ICmp->isUnsigned(), LLVMICmp->isUnsigned());
-    EXPECT_EQ(
-        static_cast<llvm::CmpInst::Predicate>(ICmp->getSignedPredicate()),
-        static_cast<llvm::CmpInst::Predicate>(LLVMICmp->getSignedPredicate()));
-    EXPECT_EQ(ICmp->getSignedPredicate().hasSameSign(),
-              LLVMICmp->getSignedPredicate().hasSameSign());
-    EXPECT_EQ(
-        static_cast<llvm::CmpInst::Predicate>(ICmp->getUnsignedPredicate()),
-        static_cast<llvm::CmpInst::Predicate>(
-            LLVMICmp->getUnsignedPredicate()));
-    EXPECT_EQ(ICmp->getUnsignedPredicate().hasSameSign(),
-              LLVMICmp->getUnsignedPredicate().hasSameSign());
+    EXPECT_EQ(ICmp->getSignedPredicate(), LLVMICmp->getSignedPredicate());
+    EXPECT_EQ(ICmp->getUnsignedPredicate(), LLVMICmp->getUnsignedPredicate());
   }
   auto *NewCmp =
       sandboxir::CmpInst::create(llvm::CmpInst::ICMP_ULE, F.getArg(0),

>From e2aee89f35177ee597a08b06bedf88129fe16746 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Sat, 30 Nov 2024 12:18:06 +0000
Subject: [PATCH 3/3] CmpPredicate: add header comments

---
 llvm/include/llvm/IR/CmpPredicate.h | 20 ++++++++++++++++++--
 llvm/lib/IR/Instructions.cpp        |  1 +
 2 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/llvm/include/llvm/IR/CmpPredicate.h b/llvm/include/llvm/IR/CmpPredicate.h
index c234269f0084b2..ccaedaa3fde757 100644
--- a/llvm/include/llvm/IR/CmpPredicate.h
+++ b/llvm/include/llvm/IR/CmpPredicate.h
@@ -18,27 +18,43 @@
 namespace llvm {
 /// An abstraction over a floating-point predicate, and a pack of an integer
 /// predicate with samesign information. Some functions in ICmpInst construct
-/// and return this type in place of a Predicate. It is also implictly
-/// constructed with a Predicate, dropping samesign information.
+/// and return this type in place of a Predicate.
 class CmpPredicate {
   CmpInst::Predicate Pred;
   bool HasSameSign;
 
 public:
+  // Constructed implicitly with either a Predicate and samesign information, or
+  // just a Predicate, dropping samesign information.
   CmpPredicate(CmpInst::Predicate Pred, bool HasSameSign = false)
       : Pred(Pred), HasSameSign(HasSameSign) {
     assert(!HasSameSign || CmpInst::isIntPredicate(Pred));
   }
 
+  // Implicitly converts to the underlying Predicate, dropping samesign
+  // information.
   operator CmpInst::Predicate() const { return Pred; }
 
+  // Query samesign information, for optimizations.
   bool hasSameSign() const { return HasSameSign; }
 
+  // Compares two CmpPredicates taking samesign into account and returns the
+  // canonicalized CmpPredicate if they match. An alternative to operator==.
+  //
+  // For example,
+  //   samesign ult + samesign ult -> samesign ult
+  //   samesign ult + ult -> ult
+  //   samesign ult + slt -> slt
+  //   ult + ult -> ult
+  //   ult + slt -> std::nullopt
   static std::optional<CmpPredicate> getMatching(CmpPredicate A,
                                                  CmpPredicate B);
 
+  // An operator== on the underlying Predicate.
   bool operator==(CmpInst::Predicate P) const { return Pred == P; }
 
+  // There is no operator== defined on CmpPredicate. Use getMatching instead to
+  // get the canonicalized matching CmpPredicate.
   bool operator==(CmpPredicate) const = delete;
 };
 } // namespace llvm
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index a5810e0adb1f7e..4f07a4c4dd017a 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -3918,6 +3918,7 @@ bool CmpInst::isImpliedFalseByMatchingCmp(Predicate Pred1, Predicate Pred2) {
 //===----------------------------------------------------------------------===//
 //                       CmpPredicate Implementation
 //===----------------------------------------------------------------------===//
+
 std::optional<CmpPredicate> CmpPredicate::getMatching(CmpPredicate A,
                                                       CmpPredicate B) {
   if (A.Pred == B.Pred)



More information about the llvm-commits mailing list