[llvm] PatternMatch: migrate to CmpPredicate (PR #118534)
Ramkumar Ramachandra via llvm-commits
llvm-commits at lists.llvm.org
Fri Dec 13 03:18:10 PST 2024
https://github.com/artagnon updated https://github.com/llvm/llvm-project/pull/118534
>From 2cd636a4eead1f42b0ebab3be563bee4b45da073 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Wed, 20 Nov 2024 16:49:37 +0000
Subject: [PATCH 1/2] PatternMatch: migrate to CmpPredicate
With the introduction of CmpPredicate in 51a895a (IR: introduce struct
with CmpInst::Predicate and samesign), PatternMatch is one of the first
key pieces of infrastructure that must be updated to match a CmpInst
respecting samesign information. Implement this change to Cmp-matchers.
This is a preparatory step in migrating the codebase over to
CmpPredicate. Since no functional changes are desired at this stage,
we have chosen not to migrate CmpPredicate::operator==(CmpPredicate)
calls to use CmpPredicate::getMatching(), as that would have visible
impact on tests that are not yet written: instead, we call
CmpPredicate::operator==(Predicate), preserving the old behavior, while
also inserting a few FIXME comments for follow-ups.
---
llvm/include/llvm/IR/CmpPredicate.h | 22 ++++
llvm/include/llvm/IR/PatternMatch.h | 108 +++++++++---------
llvm/lib/Analysis/IVDescriptors.cpp | 4 +-
llvm/lib/Analysis/InstructionSimplify.cpp | 16 +--
llvm/lib/Analysis/OverflowInstAnalysis.cpp | 2 +-
llvm/lib/Analysis/ValueTracking.cpp | 22 ++--
llvm/lib/CodeGen/CodeGenPrepare.cpp | 4 +-
llvm/lib/CodeGen/ExpandMemCmp.cpp | 2 +-
llvm/lib/IR/Instructions.cpp | 21 ++++
.../AArch64/AArch64TargetTransformInfo.cpp | 2 +-
.../Target/AMDGPU/AMDGPUCodeGenPrepare.cpp | 2 +-
.../AMDGPU/AMDGPUInstCombineIntrinsic.cpp | 2 +-
.../Hexagon/HexagonLoopIdiomRecognition.cpp | 4 +-
llvm/lib/Target/X86/X86ISelLowering.cpp | 2 +-
.../InstCombine/InstCombineAddSub.cpp | 8 +-
.../InstCombine/InstCombineAndOrXor.cpp | 24 ++--
.../InstCombine/InstCombineCompares.cpp | 26 ++---
.../InstCombine/InstCombineSelect.cpp | 54 ++++-----
.../InstCombine/InstCombineVectorOps.cpp | 4 +-
.../InstCombine/InstructionCombining.cpp | 6 +-
.../Transforms/Scalar/CallSiteSplitting.cpp | 4 +-
.../Scalar/ConstraintElimination.cpp | 8 +-
.../Scalar/DeadStoreElimination.cpp | 2 +-
llvm/lib/Transforms/Scalar/EarlyCSE.cpp | 6 +-
llvm/lib/Transforms/Scalar/GuardWidening.cpp | 2 +-
llvm/lib/Transforms/Scalar/JumpThreading.cpp | 4 +-
llvm/lib/Transforms/Scalar/LICM.cpp | 11 +-
llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp | 2 +-
.../Transforms/Scalar/LoopIdiomRecognize.cpp | 4 +-
.../Transforms/Scalar/SimpleLoopUnswitch.cpp | 10 +-
llvm/lib/Transforms/Utils/LoopPeel.cpp | 2 +-
.../Utils/ScalarEvolutionExpander.cpp | 2 +-
llvm/lib/Transforms/Utils/SimplifyIndVar.cpp | 8 +-
.../Transforms/Vectorize/SLPVectorizer.cpp | 16 +--
.../Transforms/Vectorize/VectorCombine.cpp | 10 +-
llvm/unittests/IR/PatternMatch.cpp | 4 +-
36 files changed, 237 insertions(+), 193 deletions(-)
diff --git a/llvm/include/llvm/IR/CmpPredicate.h b/llvm/include/llvm/IR/CmpPredicate.h
index 4b1be7beb2b663a..ae027305cb4199d 100644
--- a/llvm/include/llvm/IR/CmpPredicate.h
+++ b/llvm/include/llvm/IR/CmpPredicate.h
@@ -24,6 +24,9 @@ class CmpPredicate {
bool HasSameSign;
public:
+ /// Default constructor.
+ CmpPredicate() : Pred(CmpInst::BAD_ICMP_PREDICATE), HasSameSign(false) {}
+
/// Constructed implictly with a either Predicate and samesign information, or
/// just a Predicate, dropping samesign information.
CmpPredicate(CmpInst::Predicate Pred, bool HasSameSign = false)
@@ -52,11 +55,30 @@ class CmpPredicate {
/// An operator== on the underlying Predicate.
bool operator==(CmpInst::Predicate P) const { return Pred == P; }
+ bool operator!=(CmpInst::Predicate P) const { return Pred != P; }
/// There is no operator== defined on CmpPredicate. Use getMatching instead to
/// get the canonicalized matching CmpPredicate.
bool operator==(CmpPredicate) const = delete;
+ bool operator!=(CmpPredicate) const = delete;
+
+ /// TypeSwitch over the CmpInst and either do ICmpInst::getCmpPredicate() or
+ /// FCmpInst::getPredicate().
+ static CmpPredicate get(const CmpInst *Cmp);
+
+ /// Get the swapped predicate of a CmpPredicate, using
+ /// CmpInst::isIntPredicate().
+ static CmpPredicate getSwapped(CmpPredicate P);
+
+ /// Get the swapped predicate of a CmpInst.
+ static CmpPredicate getSwapped(const CmpInst *Cmp);
+
+ /// Provided to facilitate storing a CmpPredicate in data structures that
+ /// require hashing.
+ friend hash_code hash_value(const CmpPredicate &Arg); // NOLINT
};
+
+[[nodiscard]] hash_code hash_value(const CmpPredicate &Arg);
} // namespace llvm
#endif
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index fc4c0124d00b841..ed7e1eff6f0005b 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -688,14 +688,14 @@ inline api_pred_ty<is_lowbit_mask_or_zero> m_LowBitMaskOrZero(const APInt *&V) {
}
struct icmp_pred_with_threshold {
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
const APInt *Thr;
bool isValue(const APInt &C) { return ICmpInst::compare(C, *Thr, Pred); }
};
/// Match an integer or vector with every element comparing 'pred' (eg/ne/...)
/// to Threshold. For vectors, this includes constants with undefined elements.
inline cst_pred_ty<icmp_pred_with_threshold>
-m_SpecificInt_ICMP(ICmpInst::Predicate Predicate, const APInt &Threshold) {
+m_SpecificInt_ICMP(CmpPredicate Predicate, const APInt &Threshold) {
cst_pred_ty<icmp_pred_with_threshold> P;
P.Pred = Predicate;
P.Thr = &Threshold;
@@ -1557,16 +1557,16 @@ template <typename T> inline Exact_match<T> m_Exact(const T &SubPattern) {
// Matchers for CmpInst classes
//
-template <typename LHS_t, typename RHS_t, typename Class, typename PredicateTy,
+template <typename LHS_t, typename RHS_t, typename Class,
bool Commutable = false>
struct CmpClass_match {
- PredicateTy *Predicate;
+ CmpPredicate *Predicate;
LHS_t L;
RHS_t R;
// The evaluation order is always stable, regardless of Commutability.
// The LHS is always matched first.
- CmpClass_match(PredicateTy &Pred, const LHS_t &LHS, const RHS_t &RHS)
+ CmpClass_match(CmpPredicate &Pred, const LHS_t &LHS, const RHS_t &RHS)
: Predicate(&Pred), L(LHS), R(RHS) {}
CmpClass_match(const LHS_t &LHS, const RHS_t &RHS)
: Predicate(nullptr), L(LHS), R(RHS) {}
@@ -1575,12 +1575,13 @@ struct CmpClass_match {
if (auto *I = dyn_cast<Class>(V)) {
if (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) {
if (Predicate)
- *Predicate = I->getPredicate();
+ *Predicate = CmpPredicate::get(I);
return true;
- } else if (Commutable && L.match(I->getOperand(1)) &&
- R.match(I->getOperand(0))) {
+ }
+ if (Commutable && L.match(I->getOperand(1)) &&
+ R.match(I->getOperand(0))) {
if (Predicate)
- *Predicate = I->getSwappedPredicate();
+ *Predicate = CmpPredicate::getSwapped(I);
return true;
}
}
@@ -1589,60 +1590,58 @@ struct CmpClass_match {
};
template <typename LHS, typename RHS>
-inline CmpClass_match<LHS, RHS, CmpInst, CmpInst::Predicate>
-m_Cmp(CmpInst::Predicate &Pred, const LHS &L, const RHS &R) {
- return CmpClass_match<LHS, RHS, CmpInst, CmpInst::Predicate>(Pred, L, R);
+inline CmpClass_match<LHS, RHS, CmpInst> m_Cmp(CmpPredicate &Pred, const LHS &L,
+ const RHS &R) {
+ return CmpClass_match<LHS, RHS, CmpInst>(Pred, L, R);
}
template <typename LHS, typename RHS>
-inline CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate>
-m_ICmp(ICmpInst::Predicate &Pred, const LHS &L, const RHS &R) {
- return CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate>(Pred, L, R);
+inline CmpClass_match<LHS, RHS, ICmpInst> m_ICmp(CmpPredicate &Pred,
+ const LHS &L, const RHS &R) {
+ return CmpClass_match<LHS, RHS, ICmpInst>(Pred, L, R);
}
template <typename LHS, typename RHS>
-inline CmpClass_match<LHS, RHS, FCmpInst, FCmpInst::Predicate>
-m_FCmp(FCmpInst::Predicate &Pred, const LHS &L, const RHS &R) {
- return CmpClass_match<LHS, RHS, FCmpInst, FCmpInst::Predicate>(Pred, L, R);
+inline CmpClass_match<LHS, RHS, FCmpInst> m_FCmp(CmpPredicate &Pred,
+ const LHS &L, const RHS &R) {
+ return CmpClass_match<LHS, RHS, FCmpInst>(Pred, L, R);
}
template <typename LHS, typename RHS>
-inline CmpClass_match<LHS, RHS, CmpInst, CmpInst::Predicate>
-m_Cmp(const LHS &L, const RHS &R) {
- return CmpClass_match<LHS, RHS, CmpInst, CmpInst::Predicate>(L, R);
+inline CmpClass_match<LHS, RHS, CmpInst> m_Cmp(const LHS &L, const RHS &R) {
+ return CmpClass_match<LHS, RHS, CmpInst>(L, R);
}
template <typename LHS, typename RHS>
-inline CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate>
-m_ICmp(const LHS &L, const RHS &R) {
- return CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate>(L, R);
+inline CmpClass_match<LHS, RHS, ICmpInst> m_ICmp(const LHS &L, const RHS &R) {
+ return CmpClass_match<LHS, RHS, ICmpInst>(L, R);
}
template <typename LHS, typename RHS>
-inline CmpClass_match<LHS, RHS, FCmpInst, FCmpInst::Predicate>
-m_FCmp(const LHS &L, const RHS &R) {
- return CmpClass_match<LHS, RHS, FCmpInst, FCmpInst::Predicate>(L, R);
+inline CmpClass_match<LHS, RHS, FCmpInst> m_FCmp(const LHS &L, const RHS &R) {
+ return CmpClass_match<LHS, RHS, FCmpInst>(L, R);
}
// Same as CmpClass, but instead of saving Pred as out output variable, match a
// specific input pred for equality.
-template <typename LHS_t, typename RHS_t, typename Class, typename PredicateTy,
+template <typename LHS_t, typename RHS_t, typename Class,
bool Commutable = false>
struct SpecificCmpClass_match {
- const PredicateTy Predicate;
+ const CmpPredicate Predicate;
LHS_t L;
RHS_t R;
- SpecificCmpClass_match(PredicateTy Pred, const LHS_t &LHS, const RHS_t &RHS)
+ SpecificCmpClass_match(CmpPredicate Pred, const LHS_t &LHS, const RHS_t &RHS)
: Predicate(Pred), L(LHS), R(RHS) {}
template <typename OpTy> bool match(OpTy *V) {
if (auto *I = dyn_cast<Class>(V)) {
- if (I->getPredicate() == Predicate && L.match(I->getOperand(0)) &&
- R.match(I->getOperand(1)))
+ if (CmpPredicate::getMatching(CmpPredicate::get(I), Predicate) &&
+ L.match(I->getOperand(0)) && R.match(I->getOperand(1)))
return true;
if constexpr (Commutable) {
- if (I->getPredicate() == Class::getSwappedPredicate(Predicate) &&
+ if (CmpPredicate::getMatching(CmpPredicate::get(I),
+ CmpPredicate::getSwapped(Predicate)) &&
L.match(I->getOperand(1)) && R.match(I->getOperand(0)))
return true;
}
@@ -1653,31 +1652,27 @@ struct SpecificCmpClass_match {
};
template <typename LHS, typename RHS>
-inline SpecificCmpClass_match<LHS, RHS, CmpInst, CmpInst::Predicate>
-m_SpecificCmp(CmpInst::Predicate MatchPred, const LHS &L, const RHS &R) {
- return SpecificCmpClass_match<LHS, RHS, CmpInst, CmpInst::Predicate>(
- MatchPred, L, R);
+inline SpecificCmpClass_match<LHS, RHS, CmpInst>
+m_SpecificCmp(CmpPredicate MatchPred, const LHS &L, const RHS &R) {
+ return SpecificCmpClass_match<LHS, RHS, CmpInst>(MatchPred, L, R);
}
template <typename LHS, typename RHS>
-inline SpecificCmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate>
-m_SpecificICmp(ICmpInst::Predicate MatchPred, const LHS &L, const RHS &R) {
- return SpecificCmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate>(
- MatchPred, L, R);
+inline SpecificCmpClass_match<LHS, RHS, ICmpInst>
+m_SpecificICmp(CmpPredicate MatchPred, const LHS &L, const RHS &R) {
+ return SpecificCmpClass_match<LHS, RHS, ICmpInst>(MatchPred, L, R);
}
template <typename LHS, typename RHS>
-inline SpecificCmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate, true>
-m_c_SpecificICmp(ICmpInst::Predicate MatchPred, const LHS &L, const RHS &R) {
- return SpecificCmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate, true>(
- MatchPred, L, R);
+inline SpecificCmpClass_match<LHS, RHS, ICmpInst, true>
+m_c_SpecificICmp(CmpPredicate MatchPred, const LHS &L, const RHS &R) {
+ return SpecificCmpClass_match<LHS, RHS, ICmpInst, true>(MatchPred, L, R);
}
template <typename LHS, typename RHS>
-inline SpecificCmpClass_match<LHS, RHS, FCmpInst, FCmpInst::Predicate>
-m_SpecificFCmp(FCmpInst::Predicate MatchPred, const LHS &L, const RHS &R) {
- return SpecificCmpClass_match<LHS, RHS, FCmpInst, FCmpInst::Predicate>(
- MatchPred, L, R);
+inline SpecificCmpClass_match<LHS, RHS, FCmpInst>
+m_SpecificFCmp(CmpPredicate MatchPred, const LHS &L, const RHS &R) {
+ return SpecificCmpClass_match<LHS, RHS, FCmpInst>(MatchPred, L, R);
}
//===----------------------------------------------------------------------===//
@@ -2468,7 +2463,7 @@ struct UAddWithOverflow_match {
template <typename OpTy> bool match(OpTy *V) {
Value *ICmpLHS, *ICmpRHS;
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
if (!m_ICmp(Pred, m_Value(ICmpLHS), m_Value(ICmpRHS)).match(V))
return false;
@@ -2738,16 +2733,15 @@ inline AnyBinaryOp_match<LHS, RHS, true> m_c_BinOp(const LHS &L, const RHS &R) {
/// Matches an ICmp with a predicate over LHS and RHS in either order.
/// Swaps the predicate if operands are commuted.
template <typename LHS, typename RHS>
-inline CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate, true>
-m_c_ICmp(ICmpInst::Predicate &Pred, const LHS &L, const RHS &R) {
- return CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate, true>(Pred, L,
- R);
+inline CmpClass_match<LHS, RHS, ICmpInst, true>
+m_c_ICmp(CmpPredicate &Pred, const LHS &L, const RHS &R) {
+ return CmpClass_match<LHS, RHS, ICmpInst, true>(Pred, L, R);
}
template <typename LHS, typename RHS>
-inline CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate, true>
-m_c_ICmp(const LHS &L, const RHS &R) {
- return CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate, true>(L, R);
+inline CmpClass_match<LHS, RHS, ICmpInst, true> m_c_ICmp(const LHS &L,
+ const RHS &R) {
+ return CmpClass_match<LHS, RHS, ICmpInst, true>(L, R);
}
/// Matches a specific opcode with LHS and RHS in either order.
diff --git a/llvm/lib/Analysis/IVDescriptors.cpp b/llvm/lib/Analysis/IVDescriptors.cpp
index e1eb219cf977e19..9670ec7f043944d 100644
--- a/llvm/lib/Analysis/IVDescriptors.cpp
+++ b/llvm/lib/Analysis/IVDescriptors.cpp
@@ -628,7 +628,7 @@ RecurrenceDescriptor::isAnyOfPattern(Loop *Loop, PHINode *OrigPhi,
Instruction *I, InstDesc &Prev) {
// We must handle the select(cmp(),x,y) as a single instruction. Advance to
// the select.
- CmpInst::Predicate Pred;
+ CmpPredicate Pred;
if (match(I, m_OneUse(m_Cmp(Pred, m_Value(), m_Value())))) {
if (auto *Select = dyn_cast<SelectInst>(*I->user_begin()))
return InstDesc(Select, Prev.getRecKind());
@@ -668,7 +668,7 @@ RecurrenceDescriptor::isMinMaxPattern(Instruction *I, RecurKind Kind,
// We must handle the select(cmp()) as a single instruction. Advance to the
// select.
- CmpInst::Predicate Pred;
+ CmpPredicate Pred;
if (match(I, m_OneUse(m_Cmp(Pred, m_Value(), m_Value())))) {
if (auto *Select = dyn_cast<SelectInst>(*I->user_begin()))
return InstDesc(Select, Prev.getRecKind());
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 62edea38745b13d..3325cd972cf1eb3 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -1500,12 +1500,12 @@ static Value *simplifyUnsignedRangeCheck(ICmpInst *ZeroICmp,
const SimplifyQuery &Q) {
Value *X, *Y;
- ICmpInst::Predicate EqPred;
+ CmpPredicate EqPred;
if (!match(ZeroICmp, m_ICmp(EqPred, m_Value(Y), m_Zero())) ||
!ICmpInst::isEquality(EqPred))
return nullptr;
- ICmpInst::Predicate UnsignedPred;
+ CmpPredicate UnsignedPred;
Value *A, *B;
// Y = (A - B);
@@ -1644,7 +1644,7 @@ static Value *simplifyAndOrOfICmpsWithConstants(ICmpInst *Cmp0, ICmpInst *Cmp1,
static Value *simplifyAndOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1,
const InstrInfoQuery &IIQ) {
// (icmp (add V, C0), C1) & (icmp V, C0)
- ICmpInst::Predicate Pred0, Pred1;
+ CmpPredicate Pred0, Pred1;
const APInt *C0, *C1;
Value *V;
if (!match(Op0, m_ICmp(Pred0, m_Add(m_Value(V), m_APInt(C0)), m_APInt(C1))))
@@ -1691,7 +1691,7 @@ static Value *simplifyAndOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1,
/// Try to simplify and/or of icmp with ctpop intrinsic.
static Value *simplifyAndOrOfICmpsWithCtpop(ICmpInst *Cmp0, ICmpInst *Cmp1,
bool IsAnd) {
- ICmpInst::Predicate Pred0, Pred1;
+ CmpPredicate Pred0, Pred1;
Value *X;
const APInt *C;
if (!match(Cmp0, m_ICmp(Pred0, m_Intrinsic<Intrinsic::ctpop>(m_Value(X)),
@@ -1735,7 +1735,7 @@ static Value *simplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1,
static Value *simplifyOrOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1,
const InstrInfoQuery &IIQ) {
// (icmp (add V, C0), C1) | (icmp V, C0)
- ICmpInst::Predicate Pred0, Pred1;
+ CmpPredicate Pred0, Pred1;
const APInt *C0, *C1;
Value *V;
if (!match(Op0, m_ICmp(Pred0, m_Add(m_Value(V), m_APInt(C0)), m_APInt(C1))))
@@ -1891,7 +1891,7 @@ static Value *simplifyAndOrWithICmpEq(unsigned Opcode, Value *Op0, Value *Op1,
unsigned MaxRecurse) {
assert((Opcode == Instruction::And || Opcode == Instruction::Or) &&
"Must be and/or");
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
Value *A, *B;
if (!match(Op0, m_ICmp(Pred, m_Value(A), m_Value(B))) ||
!ICmpInst::isEquality(Pred))
@@ -4614,7 +4614,7 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
Value *FalseVal,
const SimplifyQuery &Q,
unsigned MaxRecurse) {
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
Value *CmpLHS, *CmpRHS;
if (!match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS))))
return nullptr;
@@ -4738,7 +4738,7 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
static Value *simplifySelectWithFCmp(Value *Cond, Value *T, Value *F,
const SimplifyQuery &Q,
unsigned MaxRecurse) {
- FCmpInst::Predicate Pred;
+ CmpPredicate Pred;
Value *CmpLHS, *CmpRHS;
if (!match(Cond, m_FCmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS))))
return nullptr;
diff --git a/llvm/lib/Analysis/OverflowInstAnalysis.cpp b/llvm/lib/Analysis/OverflowInstAnalysis.cpp
index 8bfd6642f76027a..40f71f4a8db46d4 100644
--- a/llvm/lib/Analysis/OverflowInstAnalysis.cpp
+++ b/llvm/lib/Analysis/OverflowInstAnalysis.cpp
@@ -20,7 +20,7 @@ using namespace llvm::PatternMatch;
bool llvm::isCheckForZeroAndMulWithOverflow(Value *Op0, Value *Op1, bool IsAnd,
Use *&Y) {
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
Value *X, *NotOp1;
int XIdx;
IntrinsicInst *II;
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index f2c6949e535d2a9..4b940557903daff 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -259,7 +259,7 @@ bool llvm::isOnlyUsedInZeroComparison(const Instruction *I) {
bool llvm::isOnlyUsedInZeroEqualityComparison(const Instruction *I) {
return !I->user_empty() && all_of(I->users(), [](const User *U) {
- ICmpInst::Predicate P;
+ CmpPredicate P;
return match(U, m_ICmp(P, m_Value(), m_Zero())) && ICmpInst::isEquality(P);
});
}
@@ -614,7 +614,7 @@ static bool isKnownNonZeroFromAssume(const Value *V, const SimplifyQuery &Q) {
// runtime of ~O(#assumes * #values).
Value *RHS;
- CmpInst::Predicate Pred;
+ CmpPredicate Pred;
auto m_V = m_CombineOr(m_Specific(V), m_PtrToInt(m_Specific(V)));
if (!match(I->getArgOperand(0), m_c_ICmp(Pred, m_V, m_Value(RHS))))
continue;
@@ -1602,7 +1602,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
// See if we can further use a conditional branch into the phi
// to help us determine the range of the value.
if (!Known2.isConstant()) {
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
const APInt *RHSC;
BasicBlock *TrueSucc, *FalseSucc;
// TODO: Use RHS Value and compute range from its known bits.
@@ -2255,7 +2255,7 @@ static bool isPowerOfTwoRecurrence(const PHINode *PN, bool OrZero,
static bool isImpliedToBeAPowerOfTwoFromCond(const Value *V, bool OrZero,
const Value *Cond,
bool CondIsTrue) {
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
const APInt *RHSC;
if (!match(Cond, m_ICmp(Pred, m_Intrinsic<Intrinsic::ctpop>(m_Specific(V)),
m_APInt(RHSC))))
@@ -2580,7 +2580,7 @@ static bool isKnownNonNullFromDominatingCondition(const Value *V,
// Consider only compare instructions uniquely controlling a branch
Value *RHS;
- CmpInst::Predicate Pred;
+ CmpPredicate Pred;
if (!match(U, m_c_ICmp(Pred, m_Specific(V), m_Value(RHS))))
continue;
@@ -3009,7 +3009,7 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
// The condition of the select dominates the true/false arm. Check if the
// condition implies that a given arm is non-zero.
Value *X;
- CmpInst::Predicate Pred;
+ CmpPredicate Pred;
if (!match(I->getOperand(0), m_c_ICmp(Pred, m_Specific(Op), m_Value(X))))
return false;
@@ -3037,7 +3037,7 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
return true;
RecQ.CxtI = PN->getIncomingBlock(U)->getTerminator();
// Check if the branch on the phi excludes zero.
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
Value *X;
BasicBlock *TrueSucc, *FalseSucc;
if (match(RecQ.CxtI,
@@ -4895,7 +4895,7 @@ static void computeKnownFPClassFromCond(const Value *V, Value *Cond,
KnownFromContext);
return;
}
- CmpInst::Predicate Pred;
+ CmpPredicate Pred;
Value *LHS;
uint64_t ClassVal = 0;
const APFloat *CRHS;
@@ -5135,7 +5135,7 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
FPClassTest MaskIfFalse = fcAllFlags;
uint64_t ClassVal = 0;
const Function *F = cast<Instruction>(Op)->getFunction();
- CmpInst::Predicate Pred;
+ CmpPredicate Pred;
Value *CmpLHS, *CmpRHS;
if (F && match(Cond, m_FCmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
// If the select filters out a value based on the class, it no longer
@@ -8571,7 +8571,7 @@ bool llvm::isKnownNegation(const Value *X, const Value *Y, bool NeedNSW,
bool llvm::isKnownInversion(const Value *X, const Value *Y) {
// Handle X = icmp pred A, B, Y = icmp pred A, C.
Value *A, *B, *C;
- ICmpInst::Predicate Pred1, Pred2;
+ CmpPredicate Pred1, Pred2;
if (!match(X, m_ICmp(Pred1, m_Value(A), m_Value(B))) ||
!match(Y, m_c_ICmp(Pred2, m_Specific(A), m_Value(C))))
return false;
@@ -10054,7 +10054,7 @@ void llvm::findValuesAffectedByCondition(
if (!Visited.insert(V).second)
continue;
- CmpInst::Predicate Pred;
+ CmpPredicate Pred;
Value *A, *B, *X;
if (IsAssume) {
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index 83c6ecd401039f9..5c712e4f007d392 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -1885,7 +1885,7 @@ static bool foldICmpWithDominatingICmp(CmpInst *Cmp,
return false;
Value *CmpOp0 = Cmp->getOperand(0), *CmpOp1 = Cmp->getOperand(1);
- ICmpInst::Predicate DomPred;
+ CmpPredicate DomPred;
if (!match(DomCond, m_ICmp(DomPred, m_Specific(CmpOp0), m_Specific(CmpOp1))))
return false;
if (DomPred != ICmpInst::ICMP_SGT && DomPred != ICmpInst::ICMP_SLT)
@@ -2155,7 +2155,7 @@ bool CodeGenPrepare::optimizeURem(Instruction *Rem) {
static bool adjustIsPower2Test(CmpInst *Cmp, const TargetLowering &TLI,
const TargetTransformInfo &TTI,
const DataLayout &DL) {
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
if (!match(Cmp, m_ICmp(Pred, m_Intrinsic<Intrinsic::ctpop>(), m_One())))
return false;
if (!ICmpInst::isEquality(Pred))
diff --git a/llvm/lib/CodeGen/ExpandMemCmp.cpp b/llvm/lib/CodeGen/ExpandMemCmp.cpp
index a1acb4ef3683805..f8ca7e370f6ef9c 100644
--- a/llvm/lib/CodeGen/ExpandMemCmp.cpp
+++ b/llvm/lib/CodeGen/ExpandMemCmp.cpp
@@ -668,7 +668,7 @@ Value *MemCmpExpansion::getMemCmpOneBlock() {
// We can generate more optimal code with a smaller number of operations
if (CI->hasOneUser()) {
auto *UI = cast<Instruction>(*CI->user_begin());
- ICmpInst::Predicate Pred = ICmpInst::Predicate::BAD_ICMP_PREDICATE;
+ CmpPredicate Pred = ICmpInst::Predicate::BAD_ICMP_PREDICATE;
uint64_t Shift;
bool NeedsZExt = false;
// This is a special case because instead of checking if the result is less
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 4f07a4c4dd017a6..d27863574dcf25a 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -16,6 +16,7 @@
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
@@ -3932,6 +3933,26 @@ std::optional<CmpPredicate> CmpPredicate::getMatching(CmpPredicate A,
return {};
}
+CmpPredicate CmpPredicate::get(const CmpInst *Cmp) {
+ return TypeSwitch<const CmpInst *, CmpPredicate>(Cmp)
+ .Case<ICmpInst>([](auto *ICI) { return ICI->getCmpPredicate(); })
+ .Case<FCmpInst>([](auto *FCI) { return FCI->getPredicate(); });
+}
+
+CmpPredicate CmpPredicate::getSwapped(CmpPredicate P) {
+ return CmpInst::isIntPredicate(P)
+ ? ICmpInst::getSwappedCmpPredicate(P)
+ : CmpPredicate{CmpInst::getSwappedPredicate(P)};
+}
+
+CmpPredicate CmpPredicate::getSwapped(const CmpInst *Cmp) {
+ return getSwapped(get(Cmp));
+}
+
+hash_code llvm::hash_value(const CmpPredicate &Arg) { // NOLINT
+ return hash_combine(Arg.Pred, Arg.HasSameSign);
+}
+
//===----------------------------------------------------------------------===//
// SwitchInst Implementation
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 283fe4c3caa6022..73896ed5033c126 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -3605,7 +3605,7 @@ InstructionCost AArch64TTIImpl::getCmpSelInstrCost(
// If VecPred is not set, check if we can get a predicate from the context
// instruction, if its type matches the requested ValTy.
if (VecPred == CmpInst::BAD_ICMP_PREDICATE && I && I->getType() == ValTy) {
- CmpInst::Predicate CurrentPred;
+ CmpPredicate CurrentPred;
if (match(I, m_Select(m_Cmp(CurrentPred, m_Value(), m_Value()), m_Value(),
m_Value())))
VecPred = CurrentPred;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp b/llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp
index 75e20c793016815..2523e369985b0eb 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp
@@ -1688,7 +1688,7 @@ bool AMDGPUCodeGenPrepareImpl::visitSelectInst(SelectInst &I) {
Value *TrueVal = I.getTrueValue();
Value *FalseVal = I.getFalseValue();
Value *CmpVal;
- FCmpInst::Predicate Pred;
+ CmpPredicate Pred;
if (ST.has16BitInsts() && needsPromotionToI32(I.getType())) {
if (UA.isUniform(&I))
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
index 41b33ac8a7eb4b6..8b1b398606583e9 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
@@ -960,7 +960,7 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
return &II;
}
- CmpInst::Predicate SrcPred;
+ CmpPredicate SrcPred;
Value *SrcLHS;
Value *SrcRHS;
diff --git a/llvm/lib/Target/Hexagon/HexagonLoopIdiomRecognition.cpp b/llvm/lib/Target/Hexagon/HexagonLoopIdiomRecognition.cpp
index 705e1f43851f7ad..46a8ab395d32bdc 100644
--- a/llvm/lib/Target/Hexagon/HexagonLoopIdiomRecognition.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonLoopIdiomRecognition.cpp
@@ -695,7 +695,7 @@ bool PolynomialMultiplyRecognize::matchLeftShift(SelectInst *SelI,
using namespace PatternMatch;
- CmpInst::Predicate P;
+ CmpPredicate P;
Value *A = nullptr, *B = nullptr, *C = nullptr;
if (!match(CondV, m_ICmp(P, m_And(m_Value(A), m_Value(B)), m_Value(C))) &&
@@ -810,7 +810,7 @@ bool PolynomialMultiplyRecognize::matchRightShift(SelectInst *SelI,
using namespace PatternMatch;
Value *C = nullptr;
- CmpInst::Predicate P;
+ CmpPredicate P;
bool TrueIfZero;
if (match(CondV, m_c_ICmp(P, m_Value(C), m_Zero()))) {
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 871de16d66b6c5c..83e4bcddaee1068 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -31477,7 +31477,7 @@ static bool shouldExpandCmpArithRMWInIR(AtomicRMWInst *AI) {
return false;
Value *Op = AI->getOperand(1);
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
Instruction *I = AI->user_back();
AtomicRMWInst::BinOp Opc = AI->getOperation();
if (Opc == AtomicRMWInst::Add) {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index fe0d88fcc6ee4b1..7a184a19d7c54ab 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1289,7 +1289,7 @@ static Instruction *foldAddToAshr(BinaryOperator &Add) {
// Note that, by the time we end up here, if possible, ugt has been
// canonicalized into eq.
const APInt *MaskC, *MaskCCmp;
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
if (!match(Add.getOperand(1),
m_SExt(m_ICmp(Pred, m_And(m_Specific(X), m_APInt(MaskC)),
m_APInt(MaskCCmp)))))
@@ -1382,7 +1382,7 @@ Instruction *InstCombinerImpl::
// `select` itself may be appropriately extended, look past that.
SkipExtInMagic(Select);
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
const APInt *Thr;
Value *SignExtendingValue, *Zero;
bool ShouldSignext;
@@ -1654,7 +1654,7 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
return replaceInstUsesWith(I, Constant::getNullValue(I.getType()));
// sext(A < B) + zext(A > B) => ucmp/scmp(A, B)
- ICmpInst::Predicate LTPred, GTPred;
+ CmpPredicate LTPred, GTPred;
if (match(&I,
m_c_Add(m_SExt(m_c_ICmp(LTPred, m_Value(A), m_Value(B))),
m_ZExt(m_c_ICmp(GTPred, m_Deferred(A), m_Deferred(B))))) &&
@@ -1841,7 +1841,7 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
// -->
// BW - ctlz(A - 1, false)
const APInt *XorC;
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
if (match(&I,
m_c_Add(
m_ZExt(m_ICmp(Pred, m_Intrinsic<Intrinsic::ctpop>(m_Value(A)),
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 314b1f0b43e3b59..dff9304be64ddb0 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -738,7 +738,7 @@ static Value *
foldAndOrOfICmpsWithPow2AndWithZero(InstCombiner::BuilderTy &Builder,
ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
const SimplifyQuery &Q) {
- CmpInst::Predicate Pred = IsAnd ? CmpInst::ICMP_NE : CmpInst::ICMP_EQ;
+ CmpPredicate Pred = IsAnd ? CmpInst::ICMP_NE : CmpInst::ICMP_EQ;
// Make sure we have right compares for our op.
if (LHS->getPredicate() != Pred || RHS->getPredicate() != Pred)
return nullptr;
@@ -875,7 +875,7 @@ static Value *foldSignedTruncationCheck(ICmpInst *ICmp0, ICmpInst *ICmp1,
// Try to match/decompose into: icmp eq (X & Mask), 0
auto tryToDecompose = [](ICmpInst *ICmp, Value *&X,
APInt &UnsetBitsMask) -> bool {
- CmpInst::Predicate Pred = ICmp->getPredicate();
+ CmpPredicate Pred = ICmp->getPredicate();
// Can it be decomposed into icmp eq (X & Mask), 0 ?
auto Res =
llvm::decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
@@ -944,7 +944,7 @@ static Value *foldSignedTruncationCheck(ICmpInst *ICmp0, ICmpInst *ICmp1,
static Value *foldIsPowerOf2OrZero(ICmpInst *Cmp0, ICmpInst *Cmp1, bool IsAnd,
InstCombiner::BuilderTy &Builder,
InstCombinerImpl &IC) {
- CmpInst::Predicate Pred0, Pred1;
+ CmpPredicate Pred0, Pred1;
Value *X;
if (!match(Cmp0, m_ICmp(Pred0, m_Intrinsic<Intrinsic::ctpop>(m_Value(X)),
m_SpecificInt(1))) ||
@@ -1117,12 +1117,12 @@ static Value *foldUnsignedUnderflowCheck(ICmpInst *ZeroICmp,
const SimplifyQuery &Q,
InstCombiner::BuilderTy &Builder) {
Value *ZeroCmpOp;
- ICmpInst::Predicate EqPred;
+ CmpPredicate EqPred;
if (!match(ZeroICmp, m_ICmp(EqPred, m_Value(ZeroCmpOp), m_Zero())) ||
!ICmpInst::isEquality(EqPred))
return nullptr;
- ICmpInst::Predicate UnsignedPred;
+ CmpPredicate UnsignedPred;
Value *A, *B;
if (match(UnsignedICmp,
@@ -1281,7 +1281,7 @@ static Value *foldAndOrOfICmpsWithConstEq(ICmpInst *Cmp0, ICmpInst *Cmp1,
const SimplifyQuery &Q) {
// Match an equality compare with a non-poison constant as Cmp0.
// Also, give up if the compare can be constant-folded to avoid looping.
- ICmpInst::Predicate Pred0;
+ CmpPredicate Pred0;
Value *X;
Constant *C;
if (!match(Cmp0, m_ICmp(Pred0, m_Value(X), m_Constant(C))) ||
@@ -1295,7 +1295,7 @@ static Value *foldAndOrOfICmpsWithConstEq(ICmpInst *Cmp0, ICmpInst *Cmp1,
// common operand as operand 1 (Pred1 is swapped if the common operand was
// operand 0).
Value *Y;
- ICmpInst::Predicate Pred1;
+ CmpPredicate Pred1;
if (!match(Cmp1, m_c_ICmp(Pred1, m_Value(Y), m_Specific(X))))
return nullptr;
@@ -1326,7 +1326,7 @@ static Value *foldAndOrOfICmpsWithConstEq(ICmpInst *Cmp0, ICmpInst *Cmp1,
Value *InstCombinerImpl::foldAndOrOfICmpsUsingRanges(ICmpInst *ICmp1,
ICmpInst *ICmp2,
bool IsAnd) {
- ICmpInst::Predicate Pred1, Pred2;
+ CmpPredicate Pred1, Pred2;
Value *V1, *V2;
const APInt *C1, *C2;
if (!match(ICmp1, m_ICmp(Pred1, m_Value(V1), m_APInt(C1))) ||
@@ -1348,12 +1348,12 @@ Value *InstCombinerImpl::foldAndOrOfICmpsUsingRanges(ICmpInst *ICmp1,
return nullptr;
ConstantRange CR1 = ConstantRange::makeExactICmpRegion(
- IsAnd ? ICmpInst::getInversePredicate(Pred1) : Pred1, *C1);
+ IsAnd ? ICmpInst::getInverseCmpPredicate(Pred1) : Pred1, *C1);
if (Offset1)
CR1 = CR1.subtract(*Offset1);
ConstantRange CR2 = ConstantRange::makeExactICmpRegion(
- IsAnd ? ICmpInst::getInversePredicate(Pred2) : Pred2, *C2);
+ IsAnd ? ICmpInst::getInverseCmpPredicate(Pred2) : Pred2, *C2);
if (Offset2)
CR2 = CR2.subtract(*Offset2);
@@ -3943,7 +3943,7 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(I))
return V;
- CmpInst::Predicate Pred;
+ CmpPredicate Pred;
Value *Mul, *Ov, *MulIsNotZero, *UMulWithOv;
// Check if the OR weakens the overflow condition for umul.with.overflow by
// treating any non-zero result as overflow. In that case, we overflow if both
@@ -4608,7 +4608,7 @@ Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) {
}
// not (cmp A, B) = !cmp A, B
- CmpInst::Predicate Pred;
+ CmpPredicate Pred;
if (match(NotOp, m_Cmp(Pred, m_Value(), m_Value())) &&
(NotOp->hasOneUse() ||
InstCombiner::canFreelyInvertAllUsersOf(cast<Instruction>(NotOp),
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 54053c4c9e28e8f..d6fdade25559feb 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1173,7 +1173,7 @@ Instruction *InstCombinerImpl::foldIRemByPowerOfTwoToBitTest(ICmpInst &I) {
// This fold is only valid for equality predicates.
if (!I.isEquality())
return nullptr;
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
Value *X, *Y, *Zero;
if (!match(&I, m_ICmp(Pred, m_OneUse(m_IRem(m_Value(X), m_Value(Y))),
m_CombineAnd(m_Zero(), m_Value(Zero)))))
@@ -1190,7 +1190,7 @@ Instruction *InstCombinerImpl::foldIRemByPowerOfTwoToBitTest(ICmpInst &I) {
/// by one-less-than-bitwidth into a sign test on the original value.
Instruction *InstCombinerImpl::foldSignBitTest(ICmpInst &I) {
Instruction *Val;
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
if (!I.isEquality() || !match(&I, m_ICmp(Pred, m_Instruction(Val), m_Zero())))
return nullptr;
@@ -1404,7 +1404,7 @@ Instruction *InstCombinerImpl::foldICmpWithDominatingICmp(ICmpInst &Cmp) {
};
for (BranchInst *BI : DC.conditionsFor(X)) {
- ICmpInst::Predicate DomPred;
+ CmpPredicate DomPred;
const APInt *DomC;
if (!match(BI->getCondition(),
m_ICmp(DomPred, m_Specific(X), m_APInt(DomC))))
@@ -1517,7 +1517,7 @@ Instruction *
InstCombinerImpl::foldICmpTruncWithTruncOrExt(ICmpInst &Cmp,
const SimplifyQuery &Q) {
Value *X, *Y;
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
bool YIsSExt = false;
// Try to match icmp (trunc X), (trunc Y)
if (match(&Cmp, m_ICmp(Pred, m_Trunc(m_Value(X)), m_Trunc(m_Value(Y))))) {
@@ -3249,7 +3249,7 @@ bool InstCombinerImpl::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS,
// i32 Equal,
// i32 (select i1 (a < b), i32 Less, i32 Greater)
// where Equal, Less and Greater are placeholders for any three constants.
- ICmpInst::Predicate PredA;
+ CmpPredicate PredA;
if (!match(SI->getCondition(), m_ICmp(PredA, m_Value(LHS), m_Value(RHS))) ||
!ICmpInst::isEquality(PredA))
return false;
@@ -3260,7 +3260,7 @@ bool InstCombinerImpl::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS,
std::swap(EqualVal, UnequalVal);
if (!match(EqualVal, m_ConstantInt(Equal)))
return false;
- ICmpInst::Predicate PredB;
+ CmpPredicate PredB;
Value *LHS2, *RHS2;
if (!match(UnequalVal, m_Select(m_ICmp(PredB, m_Value(LHS2), m_Value(RHS2)),
m_ConstantInt(Less), m_ConstantInt(Greater))))
@@ -4565,7 +4565,7 @@ static Value *foldICmpWithLowBitMaskedVal(CmpPredicate Pred, Value *Op0,
static Value *
foldICmpWithTruncSignExtendedVal(ICmpInst &I,
InstCombiner::BuilderTy &Builder) {
- ICmpInst::Predicate SrcPred;
+ CmpPredicate SrcPred;
Value *X;
const APInt *C0, *C1; // FIXME: non-splats, potentially with undef.
// We are ok with 'shl' having multiple uses, but 'ashr' must be one-use.
@@ -4811,7 +4811,7 @@ foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ,
/// Note that the comparison is commutative, while inverted (u>=, ==) predicate
/// will mean that we are looking for the opposite answer.
Value *InstCombinerImpl::foldMultiplicationOverflowCheck(ICmpInst &I) {
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
Value *X, *Y;
Instruction *Mul;
Instruction *Div;
@@ -4881,7 +4881,7 @@ Value *InstCombinerImpl::foldMultiplicationOverflowCheck(ICmpInst &I) {
static Instruction *foldICmpXNegX(ICmpInst &I,
InstCombiner::BuilderTy &Builder) {
- CmpInst::Predicate Pred;
+ CmpPredicate Pred;
Value *X;
if (match(&I, m_c_ICmp(Pred, m_NSWNeg(m_Value(X)), m_Deferred(X)))) {
@@ -6822,7 +6822,7 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
/// then try to reduce patterns based on that limit.
Instruction *InstCombinerImpl::foldICmpUsingBoolRange(ICmpInst &I) {
Value *X, *Y;
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
// X must be 0 and bool must be true for "ULT":
// X <u (zext i1 Y) --> (X == 0) & Y
@@ -6837,7 +6837,7 @@ Instruction *InstCombinerImpl::foldICmpUsingBoolRange(ICmpInst &I) {
return BinaryOperator::CreateOr(Builder.CreateIsNull(X), Y);
// icmp eq/ne X, (zext/sext (icmp eq/ne X, C))
- ICmpInst::Predicate Pred1, Pred2;
+ CmpPredicate Pred1, Pred2;
const APInt *C;
Instruction *ExtI;
if (match(&I, m_c_ICmp(Pred1, m_Value(X),
@@ -7107,7 +7107,7 @@ static Instruction *canonicalizeICmpBool(ICmpInst &I,
// (X l>> Y) == 0
static Instruction *foldICmpWithHighBitMask(ICmpInst &Cmp,
InstCombiner::BuilderTy &Builder) {
- ICmpInst::Predicate Pred, NewPred;
+ CmpPredicate Pred, NewPred;
Value *X, *Y;
if (match(&Cmp,
m_c_ICmp(Pred, m_OneUse(m_Shl(m_One(), m_Value(Y))), m_Value(X)))) {
@@ -7272,7 +7272,7 @@ static Instruction *foldReductionIdiom(ICmpInst &I,
const DataLayout &DL) {
if (I.getType()->isVectorTy())
return nullptr;
- ICmpInst::Predicate OuterPred, InnerPred;
+ CmpPredicate OuterPred, InnerPred;
Value *LHS, *RHS;
// Match lowering of @llvm.vector.reduce.and. Turn
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index c7a0c35d099cc4e..50dfb58cadb17bb 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -58,7 +58,7 @@ static Instruction *foldSelectBinOpIdentity(SelectInst &Sel,
// The select condition must be an equality compare with a constant operand.
Value *X;
Constant *C;
- CmpInst::Predicate Pred;
+ CmpPredicate Pred;
if (!match(Sel.getCondition(), m_Cmp(Pred, m_Value(X), m_Constant(C))))
return nullptr;
@@ -425,17 +425,19 @@ Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI,
// icmp with a common operand also can have the common operand
// pulled after the select.
- ICmpInst::Predicate TPred, FPred;
+ CmpPredicate TPred, FPred;
if (match(TI, m_ICmp(TPred, m_Value(), m_Value())) &&
match(FI, m_ICmp(FPred, m_Value(), m_Value()))) {
- if (TPred == FPred || TPred == CmpInst::getSwappedPredicate(FPred)) {
- bool Swapped = TPred != FPred;
+ // FIXME: Use CmpPredicate::getMatching here.
+ CmpInst::Predicate T = TPred, F = FPred;
+ if (T == F || T == ICmpInst::getSwappedCmpPredicate(F)) {
+ bool Swapped = T != F;
if (Value *MatchOp =
getCommonOp(TI, FI, ICmpInst::isEquality(TPred), Swapped)) {
Value *NewSel = Builder.CreateSelect(Cond, OtherOpT, OtherOpF,
SI.getName() + ".v", &SI);
return new ICmpInst(
- MatchIsOpZero ? TPred : CmpInst::getSwappedPredicate(TPred),
+ MatchIsOpZero ? TPred : ICmpInst::getSwappedCmpPredicate(TPred),
MatchOp, NewSel);
}
}
@@ -640,7 +642,7 @@ static Instruction *foldSelectICmpAndAnd(Type *SelType, const ICmpInst *Cmp,
static Value *foldSelectICmpAndZeroShl(const ICmpInst *Cmp, Value *TVal,
Value *FVal,
InstCombiner::BuilderTy &Builder) {
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
Value *AndVal;
if (!match(Cmp, m_ICmp(Pred, m_Value(AndVal), m_Zero())))
return nullptr;
@@ -867,7 +869,7 @@ static Instruction *foldSelectZeroOrMul(SelectInst &SI, InstCombinerImpl &IC) {
auto *TrueVal = SI.getTrueValue();
auto *FalseVal = SI.getFalseValue();
Value *X, *Y;
- ICmpInst::Predicate Predicate;
+ CmpPredicate Predicate;
// Assuming that constant compared with zero is not undef (but it may be
// a vector with some undef elements). Otherwise (when a constant is undef)
@@ -1527,7 +1529,7 @@ static Value *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0,
return nullptr;
Value *Cmp1;
- ICmpInst::Predicate Pred1;
+ CmpPredicate Pred1;
Constant *C2;
Value *ReplacementLow, *ReplacementHigh;
if (!match(Sel1, m_Select(m_Value(Cmp1), m_Value(ReplacementLow),
@@ -1636,7 +1638,7 @@ static Value *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0,
static Instruction *
tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp,
InstCombinerImpl &IC) {
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
Value *X;
Constant *C0;
if (!match(&Cmp, m_OneUse(m_ICmp(
@@ -1734,7 +1736,7 @@ static Value *foldSelectInstWithICmpConst(SelectInst &SI, ICmpInst *ICI,
InstCombiner::BuilderTy &Builder) {
const APInt *CmpC;
Value *V;
- CmpInst::Predicate Pred;
+ CmpPredicate Pred;
if (!match(ICI, m_ICmp(Pred, m_Value(V), m_APInt(CmpC))))
return nullptr;
@@ -1890,7 +1892,7 @@ static Value *foldSelectWithConstOpToBinOp(ICmpInst *Cmp, Value *TrueVal,
BinaryOperator *BOp;
Constant *C1, *C2, *C3;
Value *X;
- ICmpInst::Predicate Predicate;
+ CmpPredicate Predicate;
if (!match(Cmp, m_ICmp(Predicate, m_Value(X), m_Constant(C1))))
return nullptr;
@@ -2138,7 +2140,7 @@ foldOverflowingAddSubSelect(SelectInst &SI, InstCombiner::BuilderTy &Builder) {
auto IsSignedSaturateLimit = [&](Value *Limit, bool IsAdd) {
Type *Ty = Limit->getType();
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
Value *TrueVal, *FalseVal, *Op;
const APInt *C;
if (!match(Limit, m_Select(m_ICmp(Pred, m_Value(Op), m_APInt(C)),
@@ -2347,7 +2349,7 @@ static Instruction *foldSelectCmpBitcasts(SelectInst &Sel,
Value *TVal = Sel.getTrueValue();
Value *FVal = Sel.getFalseValue();
- CmpInst::Predicate Pred;
+ CmpPredicate Pred;
Value *A, *B;
if (!match(Cond, m_Cmp(Pred, m_Value(A), m_Value(B))))
return nullptr;
@@ -2552,7 +2554,7 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel,
Value *X;
const APInt *C;
bool IsTrueIfSignSet;
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
if (!match(Cond, m_OneUse(m_ICmp(Pred, m_ElementWiseBitCast(m_Value(X)),
m_APInt(C)))) ||
!isSignBitCheck(Pred, *C, IsTrueIfSignSet) || X->getType() != SelType)
@@ -2748,7 +2750,7 @@ static Instruction *foldSelectWithSRem(SelectInst &SI, InstCombinerImpl &IC,
Value *TrueVal = SI.getTrueValue();
Value *FalseVal = SI.getFalseValue();
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
Value *Op, *RemRes, *Remainder;
const APInt *C;
bool TrueIfSigned = false;
@@ -2807,7 +2809,7 @@ static Value *foldSelectWithFrozenICmp(SelectInst &Sel, InstCombiner::BuilderTy
// a = select c, x, y ;
// f(a, c) ; f(poison, 1) cannot happen, but if a is folded
// ; to y, this can happen.
- CmpInst::Predicate Pred;
+ CmpPredicate Pred;
if (FI->hasOneUse() &&
match(Cond, m_c_ICmp(Pred, m_Specific(TrueVal), m_Specific(FalseVal))) &&
(Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE)) {
@@ -2856,7 +2858,7 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI,
for (bool Swap : {false, true}) {
Value *TrueVal = SI.getTrueValue();
Value *X = SI.getFalseValue();
- CmpInst::Predicate Pred;
+ CmpPredicate Pred;
if (Swap)
std::swap(TrueVal, X);
@@ -2936,7 +2938,7 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI,
if (Swap)
std::swap(TrueVal, X);
- CmpInst::Predicate Pred;
+ CmpPredicate Pred;
const APInt *C;
bool TrueIfSigned;
if (!match(CondVal,
@@ -2980,7 +2982,7 @@ foldRoundUpIntegerWithPow2Alignment(SelectInst &SI,
Value *X = SI.getTrueValue();
Value *XBiasedHighBits = SI.getFalseValue();
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
Value *XLowBits;
if (!match(Cond, m_ICmp(Pred, m_Value(XLowBits), m_ZeroInt())) ||
!ICmpInst::isEquality(Pred))
@@ -3159,7 +3161,7 @@ static bool impliesPoisonOrCond(const Value *ValAssumedPoison, const Value *V,
Value *LHS = ICmp->getOperand(0);
const APInt *RHSC1;
const APInt *RHSC2;
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
if (ICmp->hasSameSign() &&
match(ICmp->getOperand(1), m_APIntForbidPoison(RHSC1)) &&
match(V, m_ICmp(Pred, m_Specific(LHS), m_APIntAllowPoison(RHSC2)))) {
@@ -3170,7 +3172,7 @@ static bool impliesPoisonOrCond(const Value *ValAssumedPoison, const Value *V,
APInt::getZero(BitWidth))
: ConstantRange(APInt::getZero(BitWidth),
APInt::getSignedMinValue(BitWidth));
- return CRX.icmp(Expected ? Pred : ICmpInst::getInversePredicate(Pred),
+ return CRX.icmp(Expected ? Pred : ICmpInst::getInverseCmpPredicate(Pred),
*RHSC2);
}
}
@@ -3539,7 +3541,7 @@ static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder,
Value *FalseVal = SI.getFalseValue();
Value *TrueVal = SI.getTrueValue();
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
const APInt *Cond1;
Value *Cond0, *Ctlz, *CtlzOp;
if (!match(SI.getCondition(), m_ICmp(Pred, m_Value(Cond0), m_APInt(Cond1))))
@@ -3590,7 +3592,7 @@ Instruction *InstCombinerImpl::foldSelectToCmp(SelectInst &SI) {
Value *TV = SI.getTrueValue();
Value *FV = SI.getFalseValue();
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
Value *LHS, *RHS;
if (!match(SI.getCondition(), m_ICmp(Pred, m_Value(LHS), m_Value(RHS))))
return nullptr;
@@ -3610,7 +3612,7 @@ Instruction *InstCombinerImpl::foldSelectToCmp(SelectInst &SI) {
bool IsSigned = ICmpInst::isSigned(Pred);
bool Replace = false;
- ICmpInst::Predicate ExtendedCmpPredicate;
+ CmpPredicate ExtendedCmpPredicate;
// (x < y) ? -1 : zext(x != y)
// (x < y) ? -1 : zext(x > y)
if (ICmpInst::isLT(Pred) && match(TV, m_AllOnes()) &&
@@ -3630,7 +3632,7 @@ Instruction *InstCombinerImpl::foldSelectToCmp(SelectInst &SI) {
Replace = true;
// (x == y) ? 0 : (x > y ? 1 : -1)
- ICmpInst::Predicate FalseBranchSelectPredicate;
+ CmpPredicate FalseBranchSelectPredicate;
const APInt *InnerTV, *InnerFV;
if (Pred == ICmpInst::ICMP_EQ && match(TV, m_Zero()) &&
match(FV, m_Select(m_c_ICmp(FalseBranchSelectPredicate, m_Specific(LHS),
@@ -3730,7 +3732,7 @@ static Value *foldSelectIntoAddConstant(SelectInst &SI,
Instruction *FAdd;
Constant *C;
Value *X, *Z;
- CmpInst::Predicate Pred;
+ CmpPredicate Pred;
// Note: OneUse check for `Cmp` is necessary because it makes sure that other
// InstCombine folds don't undo this transformation and cause an infinite
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index 09eafd09451b246..77f1763d1193969 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -86,7 +86,7 @@ static bool cheapToScalarize(Value *V, Value *EI) {
if (cheapToScalarize(V0, EI) || cheapToScalarize(V1, EI))
return true;
- CmpInst::Predicate UnusedPred;
+ CmpPredicate UnusedPred;
if (match(V, m_OneUse(m_Cmp(UnusedPred, m_Value(V0), m_Value(V1)))))
if (cheapToScalarize(V0, EI) || cheapToScalarize(V1, EI))
return true;
@@ -486,7 +486,7 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) {
}
Value *X, *Y;
- CmpInst::Predicate Pred;
+ CmpPredicate Pred;
if (match(SrcVec, m_Cmp(Pred, m_Value(X), m_Value(Y))) &&
cheapToScalarize(SrcVec, Index)) {
// extelt (cmp X, Y), Index --> cmp (extelt X, Index), (extelt Y, Index)
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 8f55e5b3cc28a26..4abf2317290413f 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -3478,7 +3478,7 @@ static Instruction *tryToMoveFreeBeforeNullTest(CallInst &FI,
// Validate the rest of constraint #1 by matching on the pred branch.
Instruction *TI = PredBB->getTerminator();
BasicBlock *TrueBB, *FalseBB;
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
if (!match(TI, m_Br(m_ICmp(Pred,
m_CombineOr(m_Specific(Op),
m_Specific(Op->stripPointerCasts())),
@@ -3759,7 +3759,7 @@ Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) {
return replaceOperand(BI, 0, ConstantInt::getFalse(Cond->getType()));
// Canonicalize, for example, fcmp_one -> fcmp_oeq.
- CmpInst::Predicate Pred;
+ CmpPredicate Pred;
if (match(Cond, m_OneUse(m_FCmp(Pred, m_Value(), m_Value()))) &&
!isCanonicalPredicate(Pred)) {
// Swap destinations and condition.
@@ -3820,7 +3820,7 @@ static Value *simplifySwitchOnSelectUsingRanges(SwitchInst &SI,
if (CstBB != SI.getDefaultDest())
return nullptr;
Value *X = Select->getOperand(3 - CstOpIdx);
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
const APInt *RHSC;
if (!match(Select->getCondition(),
m_ICmp(Pred, m_Specific(X), m_APInt(RHSC))))
diff --git a/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp b/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp
index b8571ba0748998f..bbc7a005b9ff4f0 100644
--- a/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp
+++ b/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp
@@ -132,7 +132,7 @@ static void recordCondition(CallBase &CB, BasicBlock *From, BasicBlock *To,
if (!BI || !BI->isConditional())
return;
- CmpInst::Predicate Pred;
+ CmpPredicate Pred;
Value *Cond = BI->getCondition();
if (!match(Cond, m_ICmp(Pred, m_Value(), m_Constant())))
return;
@@ -142,7 +142,7 @@ static void recordCondition(CallBase &CB, BasicBlock *From, BasicBlock *To,
if (isCondRelevantToAnyCallArgument(Cmp, CB))
Conditions.push_back({Cmp, From->getTerminator()->getSuccessor(0) == To
? Pred
- : Cmp->getInversePredicate()});
+ : Cmp->getInverseCmpPredicate()});
}
/// Record ICmp conditions relevant to any argument in CB following Pred's
diff --git a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
index e64fc153cf3d26f..589bfd05bb5d554 100644
--- a/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/ConstraintElimination.cpp
@@ -922,7 +922,7 @@ void State::addInfoForInductions(BasicBlock &BB) {
Value *A;
Value *B;
- CmpInst::Predicate Pred;
+ CmpPredicate Pred;
if (!match(BB.getTerminator(),
m_Br(m_ICmp(Pred, m_Value(A), m_Value(B)), m_Value(), m_Value())))
@@ -1089,7 +1089,7 @@ void State::addInfoFor(BasicBlock &BB) {
switch (ID) {
case Intrinsic::assume: {
Value *A, *B;
- CmpInst::Predicate Pred;
+ CmpPredicate Pred;
if (!match(I.getOperand(0), m_ICmp(Pred, m_Value(A), m_Value(B))))
break;
if (GuaranteedToExecute) {
@@ -1537,7 +1537,7 @@ static bool checkOrAndOpImpliedByOther(
while (!Worklist.empty()) {
Value *Val = Worklist.pop_back_val();
Value *LHS, *RHS;
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
if (match(Val, m_ICmp(Pred, m_Value(LHS), m_Value(RHS)))) {
// For OR, check if the negated condition implies CmpToCheck.
if (IsOr)
@@ -1833,7 +1833,7 @@ static bool eliminateConstraints(Function &F, DominatorTree &DT, LoopInfo &LI,
}
};
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
if (!CB.isConditionFact()) {
Value *X;
if (match(CB.Inst, m_Intrinsic<Intrinsic::abs>(m_Value(X)))) {
diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
index 09e8301b772d96b..4799640089fa9a6 100644
--- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -2053,7 +2053,7 @@ struct DSEState {
return false;
Instruction *ICmpL;
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
if (!match(BI->getCondition(),
m_c_ICmp(Pred,
m_CombineAnd(m_Load(m_Specific(StorePtr)),
diff --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
index cd4846e006031d4..682c5c3d8c63404 100644
--- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
+++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
@@ -192,7 +192,7 @@ static bool matchSelectWithOptionalNotCond(Value *V, Value *&Cond, Value *&A,
// mechanism that may remove flags to increase the likelihood of CSE.
Flavor = SPF_UNKNOWN;
- CmpInst::Predicate Pred;
+ CmpPredicate Pred;
if (!match(Cond, m_ICmp(Pred, m_Specific(A), m_Specific(B)))) {
// Check for commuted variants of min/max by swapping predicate.
@@ -279,7 +279,7 @@ static unsigned getHashValueImpl(SimpleValue Val) {
// Hash general selects to allow matching commuted true/false operands.
// If we do not have a compare as the condition, just hash in the condition.
- CmpInst::Predicate Pred;
+ CmpPredicate Pred;
Value *X, *Y;
if (!match(Cond, m_Cmp(Pred, m_Value(X), m_Value(Y))))
return hash_combine(Inst->getOpcode(), Cond, A, B);
@@ -451,7 +451,7 @@ static bool isEqualImpl(SimpleValue LHS, SimpleValue RHS) {
// this code, as we simplify the double-negation before hashing the second
// select (and so still succeed at CSEing them).
if (LHSA == RHSB && LHSB == RHSA) {
- CmpInst::Predicate PredL, PredR;
+ CmpPredicate PredL, PredR;
Value *X, *Y;
if (match(CondL, m_Cmp(PredL, m_Value(X), m_Value(Y))) &&
match(CondR, m_Cmp(PredR, m_Specific(X), m_Specific(Y))) &&
diff --git a/llvm/lib/Transforms/Scalar/GuardWidening.cpp b/llvm/lib/Transforms/Scalar/GuardWidening.cpp
index a8fda0c6ab9cbe2..2978b7990a6ebe1 100644
--- a/llvm/lib/Transforms/Scalar/GuardWidening.cpp
+++ b/llvm/lib/Transforms/Scalar/GuardWidening.cpp
@@ -727,7 +727,7 @@ GuardWideningImpl::mergeChecks(SmallVectorImpl<Value *> &ChecksToHoist,
// L >u C0 && L >u C1 -> L >u max(C0, C1)
ConstantInt *RHS0, *RHS1;
Value *LHS;
- ICmpInst::Predicate Pred0, Pred1;
+ CmpPredicate Pred0, Pred1;
// TODO: Support searching for pairs to merge from both whole lists of
// ChecksToHoist and ChecksToWiden.
if (ChecksToWiden.size() == 1 && ChecksToHoist.size() == 1 &&
diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
index 16110cd25bc61c5..300a564e222e163 100644
--- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
@@ -591,7 +591,7 @@ bool JumpThreadingPass::computeValueKnownInPredecessorsImpl(
// 'getPredicateOnEdge' method. This would be able to handle value
// inequalities better, for example if the compare is "X < 4" and "X < 3"
// is known true but "X < 4" itself is not available.
- CmpInst::Predicate Pred;
+ CmpPredicate Pred;
Value *Val;
Constant *Cst;
if (!PredCst && match(V, m_Cmp(Pred, m_Value(Val), m_Constant(Cst))))
@@ -2744,7 +2744,7 @@ bool JumpThreadingPass::duplicateCondBranchOnPHIIntoPred(
// Pred is a predecessor of BB with an unconditional branch to BB. SI is
// a Select instruction in Pred. BB has other predecessors and SI is used in
// a PHI node in BB. SI has no other use.
-// A new basic block, NewBB, is created and SI is converted to compare and
+// A new basic block, NewBB, is created and SI is converted to compare and
// conditional branch. SI is erased from parent.
void JumpThreadingPass::unfoldSelectInstr(BasicBlock *Pred, BasicBlock *BB,
SelectInst *SI, PHINode *SIUse,
diff --git a/llvm/lib/Transforms/Scalar/LICM.cpp b/llvm/lib/Transforms/Scalar/LICM.cpp
index 3ade32027289317..a5d5eecb1ebf823 100644
--- a/llvm/lib/Transforms/Scalar/LICM.cpp
+++ b/llvm/lib/Transforms/Scalar/LICM.cpp
@@ -2430,8 +2430,8 @@ static bool hoistMinMax(Instruction &I, Loop &L, ICFLoopSafetyInfo &SafetyInfo,
} else
return false;
- auto MatchICmpAgainstInvariant = [&](Value *C, ICmpInst::Predicate &P,
- Value *&LHS, Value *&RHS) {
+ auto MatchICmpAgainstInvariant = [&](Value *C, CmpPredicate &P, Value *&LHS,
+ Value *&RHS) {
if (!match(C, m_OneUse(m_ICmp(P, m_Value(LHS), m_Value(RHS)))))
return false;
if (!LHS->getType()->isIntegerTy())
@@ -2448,12 +2448,13 @@ static bool hoistMinMax(Instruction &I, Loop &L, ICFLoopSafetyInfo &SafetyInfo,
P = ICmpInst::getInversePredicate(P);
return true;
};
- ICmpInst::Predicate P1, P2;
+ CmpPredicate P1, P2;
Value *LHS1, *LHS2, *RHS1, *RHS2;
if (!MatchICmpAgainstInvariant(Cond1, P1, LHS1, RHS1) ||
!MatchICmpAgainstInvariant(Cond2, P2, LHS2, RHS2))
return false;
- if (P1 != P2 || LHS1 != LHS2)
+ // FIXME: Use CmpPredicate::getMatching here.
+ if (P1 != static_cast<CmpInst::Predicate>(P2) || LHS1 != LHS2)
return false;
// Everything is fine, we can do the transform.
@@ -2678,7 +2679,7 @@ static bool hoistAddSub(Instruction &I, Loop &L, ICFLoopSafetyInfo &SafetyInfo,
MemorySSAUpdater &MSSAU, AssumptionCache *AC,
DominatorTree *DT) {
using namespace PatternMatch;
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
Value *LHS, *RHS;
if (!match(&I, m_ICmp(Pred, m_Value(LHS), m_Value(RHS))))
return false;
diff --git a/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp b/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp
index ff077624802be2f..73f1942849ac2f6 100644
--- a/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp
@@ -32,7 +32,7 @@ struct ConditionInfo {
/// ICmp instruction with this condition
ICmpInst *ICmp = nullptr;
/// Preciate info
- ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
+ CmpPredicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
/// AddRec llvm value
Value *AddRecValue = nullptr;
/// Non PHI AddRec llvm value
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index 05cf638d3f09dff..ba1c2241aea9acd 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -2432,7 +2432,7 @@ static bool detectShiftUntilBitTestIdiom(Loop *CurLoop, Value *&BaseX,
// Step 1: Check if the loop backedge is in desirable form.
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
Value *CmpLHS, *CmpRHS;
BasicBlock *TrueBB, *FalseBB;
if (!match(LoopHeaderBB->getTerminator(),
@@ -2797,7 +2797,7 @@ static bool detectShiftUntilZeroIdiom(Loop *CurLoop, ScalarEvolution *SE,
// Step 1: Check if the loop backedge, condition is in desirable form.
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
BasicBlock *TrueBB, *FalseBB;
if (!match(LoopHeaderBB->getTerminator(),
m_Br(m_Instruction(ValShiftedIsZero), m_BasicBlock(TrueBB),
diff --git a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
index d8ef450eeb9a15f..0712ff77151e293 100644
--- a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
+++ b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
@@ -2990,9 +2990,11 @@ static bool collectUnswitchCandidates(
/// into its equivalent where `Pred` is something that we support for injected
/// invariants (so far it is limited to ult), LHS in canonicalized form is
/// non-invariant and RHS is an invariant.
-static void canonicalizeForInvariantConditionInjection(
- ICmpInst::Predicate &Pred, Value *&LHS, Value *&RHS, BasicBlock *&IfTrue,
- BasicBlock *&IfFalse, const Loop &L) {
+static void canonicalizeForInvariantConditionInjection(CmpPredicate &Pred,
+ Value *&LHS, Value *&RHS,
+ BasicBlock *&IfTrue,
+ BasicBlock *&IfFalse,
+ const Loop &L) {
if (!L.contains(IfTrue)) {
Pred = ICmpInst::getInversePredicate(Pred);
std::swap(IfTrue, IfFalse);
@@ -3235,7 +3237,7 @@ static bool collectUnswitchCandidatesWithInjections(
// other).
for (auto *DTN = DT.getNode(Latch); L.contains(DTN->getBlock());
DTN = DTN->getIDom()) {
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
Value *LHS = nullptr, *RHS = nullptr;
BasicBlock *IfTrue = nullptr, *IfFalse = nullptr;
auto *BB = DTN->getBlock();
diff --git a/llvm/lib/Transforms/Utils/LoopPeel.cpp b/llvm/lib/Transforms/Utils/LoopPeel.cpp
index 3cbde39b30b4e41..9a24c1b0d03de7e 100644
--- a/llvm/lib/Transforms/Utils/LoopPeel.cpp
+++ b/llvm/lib/Transforms/Utils/LoopPeel.cpp
@@ -378,7 +378,7 @@ static unsigned countToEliminateCompares(Loop &L, unsigned MaxPeelCount,
return;
}
- CmpInst::Predicate Pred;
+ CmpPredicate Pred;
if (!match(Condition, m_ICmp(Pred, m_Value(LeftVal), m_Value(RightVal))))
return;
diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index 791d528823972d7..0bc752a92340750 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -1816,7 +1816,7 @@ bool SCEVExpander::hasRelatedExistingExpansion(const SCEV *S,
// Look for suitable value in simple conditions at the loop exits.
for (BasicBlock *BB : ExitingBlocks) {
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
Instruction *LHS, *RHS;
if (!match(BB->getTerminator(),
diff --git a/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp b/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp
index 7fca1a6aa526054..f05d32d980e5a93 100644
--- a/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp
@@ -2164,16 +2164,14 @@ void WidenIV::calculatePostIncRange(Instruction *NarrowDef,
!NarrowDefRHS->isNonNegative())
return;
- auto UpdateRangeFromCondition = [&] (Value *Condition,
- bool TrueDest) {
- CmpInst::Predicate Pred;
+ auto UpdateRangeFromCondition = [&](Value *Condition, bool TrueDest) {
+ CmpPredicate Pred;
Value *CmpRHS;
if (!match(Condition, m_ICmp(Pred, m_Specific(NarrowDefLHS),
m_Value(CmpRHS))))
return;
- CmpInst::Predicate P =
- TrueDest ? Pred : CmpInst::getInversePredicate(Pred);
+ CmpPredicate P = TrueDest ? Pred : ICmpInst::getInverseCmpPredicate(Pred);
auto CmpRHSRange = SE->getSignedRange(SE->getSCEV(CmpRHS));
auto CmpConstrainedLHSRange =
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index dd1e53a05ebeb86..e2742fbd9b4f096 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -514,7 +514,7 @@ static bool isCommutative(Instruction *I) {
BO->uses(),
[](const Use &U) {
// Commutative, if icmp eq/ne sub, 0
- ICmpInst::Predicate Pred;
+ CmpPredicate Pred;
if (match(U.getUser(),
m_ICmp(Pred, m_Specific(U.get()), m_Zero())) &&
(Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE))
@@ -11463,7 +11463,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
case Instruction::FCmp:
case Instruction::ICmp:
case Instruction::Select: {
- CmpInst::Predicate VecPred, SwappedVecPred;
+ CmpPredicate VecPred, SwappedVecPred;
auto MatchCmp = m_Cmp(VecPred, m_Value(), m_Value());
if (match(VL0, m_Select(MatchCmp, m_Value(), m_Value())) ||
match(VL0, MatchCmp))
@@ -11477,13 +11477,15 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
return InstructionCost(TTI::TCC_Free);
auto *VI = cast<Instruction>(UniqueValues[Idx]);
- CmpInst::Predicate CurrentPred = ScalarTy->isFloatingPointTy()
- ? CmpInst::BAD_FCMP_PREDICATE
- : CmpInst::BAD_ICMP_PREDICATE;
+ CmpPredicate CurrentPred = ScalarTy->isFloatingPointTy()
+ ? CmpInst::BAD_FCMP_PREDICATE
+ : CmpInst::BAD_ICMP_PREDICATE;
auto MatchCmp = m_Cmp(CurrentPred, m_Value(), m_Value());
+ // FIXME: Use CmpPredicate::getMatching here.
if ((!match(VI, m_Select(MatchCmp, m_Value(), m_Value())) &&
!match(VI, MatchCmp)) ||
- (CurrentPred != VecPred && CurrentPred != SwappedVecPred))
+ (CurrentPred != static_cast<CmpInst::Predicate>(VecPred) &&
+ CurrentPred != static_cast<CmpInst::Predicate>(SwappedVecPred)))
VecPred = SwappedVecPred = ScalarTy->isFloatingPointTy()
? CmpInst::BAD_FCMP_PREDICATE
: CmpInst::BAD_ICMP_PREDICATE;
@@ -19391,7 +19393,7 @@ class HorizontalReduction {
// %3 = extractelement <2 x i32> %a, i32 0
// %4 = extractelement <2 x i32> %a, i32 1
// %select = select i1 %cond, i32 %3, i32 %4
- CmpInst::Predicate Pred;
+ CmpPredicate Pred;
Instruction *L1;
Instruction *L2;
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index ebbd05e6d47afcf..772e0aed9b6c02c 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -596,7 +596,7 @@ bool VectorCombine::foldExtractExtract(Instruction &I) {
return false;
Instruction *I0, *I1;
- CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
+ CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
if (!match(&I, m_Cmp(Pred, m_Instruction(I0), m_Instruction(I1))) &&
!match(&I, m_BinOp(m_Instruction(I0), m_Instruction(I1))))
return false;
@@ -922,7 +922,7 @@ bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
/// Match a vector binop or compare instruction with at least one inserted
/// scalar operand and convert to scalar binop/cmp followed by insertelement.
bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
- CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
+ CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
Value *Ins0, *Ins1;
if (!match(&I, m_BinOp(m_Value(Ins0), m_Value(Ins1))) &&
!match(&I, m_Cmp(Pred, m_Value(Ins0), m_Value(Ins1))))
@@ -1062,9 +1062,11 @@ bool VectorCombine::foldExtractedCmps(Instruction &I) {
Value *B0 = I.getOperand(0), *B1 = I.getOperand(1);
Instruction *I0, *I1;
Constant *C0, *C1;
- CmpInst::Predicate P0, P1;
+ CmpPredicate P0, P1;
+ // FIXME: Use CmpPredicate::getMatching here.
if (!match(B0, m_Cmp(P0, m_Instruction(I0), m_Constant(C0))) ||
- !match(B1, m_Cmp(P1, m_Instruction(I1), m_Constant(C1))) || P0 != P1)
+ !match(B1, m_Cmp(P1, m_Instruction(I1), m_Constant(C1))) ||
+ P0 != static_cast<CmpInst::Predicate>(P1))
return false;
// The compare operands must be extracts of the same vector with constant
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index 367ba6ab52a5965..47fde5782a13bce 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -2381,7 +2381,7 @@ TYPED_TEST(MutableConstTest, ICmp) {
ValueType MatchL;
ValueType MatchR;
- ICmpInst::Predicate MatchPred;
+ CmpPredicate MatchPred;
EXPECT_TRUE(m_ICmp(MatchPred, m_Value(MatchL), m_Value(MatchR))
.match((InstructionType)IRB.CreateICmp(Pred, L, R)));
@@ -2473,7 +2473,7 @@ TYPED_TEST(MutableConstTest, FCmp) {
ValueType MatchL;
ValueType MatchR;
- FCmpInst::Predicate MatchPred;
+ CmpPredicate MatchPred;
EXPECT_TRUE(m_FCmp(MatchPred, m_Value(MatchL), m_Value(MatchR))
.match((InstructionType)IRB.CreateFCmp(Pred, L, R)));
>From f46e64f8bf35881fdfb7bc15516c6a36ca0df1d6 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Fri, 13 Dec 2024 10:58:36 +0000
Subject: [PATCH 2/2] PatternMatch, CmpPredicate: address review
---
llvm/include/llvm/IR/CmpPredicate.h | 7 +++----
llvm/include/llvm/IR/PatternMatch.h | 2 +-
llvm/lib/IR/Instructions.cpp | 11 ++++-------
3 files changed, 8 insertions(+), 12 deletions(-)
diff --git a/llvm/include/llvm/IR/CmpPredicate.h b/llvm/include/llvm/IR/CmpPredicate.h
index ae027305cb4199d..ce78e4311f9f826 100644
--- a/llvm/include/llvm/IR/CmpPredicate.h
+++ b/llvm/include/llvm/IR/CmpPredicate.h
@@ -62,12 +62,11 @@ class CmpPredicate {
bool operator==(CmpPredicate) const = delete;
bool operator!=(CmpPredicate) const = delete;
- /// TypeSwitch over the CmpInst and either do ICmpInst::getCmpPredicate() or
- /// FCmpInst::getPredicate().
+ /// Return ICmpInst::getCmpPredicate() for an ICmpInst, and
+ /// CmpInst::getPredicate() for any other CmpInst.
static CmpPredicate get(const CmpInst *Cmp);
- /// Get the swapped predicate of a CmpPredicate, using
- /// CmpInst::isIntPredicate().
+ /// Get the swapped predicate of a CmpPredicate.
static CmpPredicate getSwapped(CmpPredicate P);
/// Get the swapped predicate of a CmpInst.
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index ed7e1eff6f0005b..cc0e8d598ff1eaf 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -695,7 +695,7 @@ struct icmp_pred_with_threshold {
/// Match an integer or vector with every element comparing 'pred' (eg/ne/...)
/// to Threshold. For vectors, this includes constants with undefined elements.
inline cst_pred_ty<icmp_pred_with_threshold>
-m_SpecificInt_ICMP(CmpPredicate Predicate, const APInt &Threshold) {
+m_SpecificInt_ICMP(ICmpInst::Predicate Predicate, const APInt &Threshold) {
cst_pred_ty<icmp_pred_with_threshold> P;
P.Pred = Predicate;
P.Thr = &Threshold;
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index d27863574dcf25a..d1da02c744f18c6 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -16,7 +16,6 @@
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
-#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
@@ -3934,15 +3933,13 @@ std::optional<CmpPredicate> CmpPredicate::getMatching(CmpPredicate A,
}
CmpPredicate CmpPredicate::get(const CmpInst *Cmp) {
- return TypeSwitch<const CmpInst *, CmpPredicate>(Cmp)
- .Case<ICmpInst>([](auto *ICI) { return ICI->getCmpPredicate(); })
- .Case<FCmpInst>([](auto *FCI) { return FCI->getPredicate(); });
+ if (auto *ICI = dyn_cast<ICmpInst>(Cmp))
+ return ICI->getCmpPredicate();
+ return Cmp->getPredicate();
}
CmpPredicate CmpPredicate::getSwapped(CmpPredicate P) {
- return CmpInst::isIntPredicate(P)
- ? ICmpInst::getSwappedCmpPredicate(P)
- : CmpPredicate{CmpInst::getSwappedPredicate(P)};
+ return {CmpInst::getSwappedPredicate(P), P.hasSameSign()};
}
CmpPredicate CmpPredicate::getSwapped(const CmpInst *Cmp) {
More information about the llvm-commits
mailing list