[llvm] [PatternMatch] Use `m_SpecificCmp` matchers. NFC. (PR #100878)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Sun Jul 28 09:00:35 PDT 2024


https://github.com/dtcxzyw updated https://github.com/llvm/llvm-project/pull/100878

>From 861ffa4ec5f7bde5a194a7715593a1b5359eb581 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Sat, 27 Jul 2024 22:27:56 +0800
Subject: [PATCH 1/3] [PatternMatch] Use `m_SpecificCmp` matchers. NFC.

---
 llvm/lib/Analysis/InstructionSimplify.cpp     |  9 +--
 llvm/lib/Analysis/ValueTracking.cpp           | 11 ++--
 .../AMDGPU/AMDGPULowerKernelAttributes.cpp    |  6 +-
 llvm/lib/Target/X86/X86ISelLowering.cpp       | 42 +++++++------
 .../AggressiveInstCombine.cpp                 |  7 +--
 .../InstCombine/InstCombineAddSub.cpp         |  5 +-
 .../InstCombine/InstCombineAndOrXor.cpp       | 61 +++++++++----------
 .../InstCombine/InstCombineCalls.cpp          | 17 +++---
 .../InstCombine/InstCombineCompares.cpp       | 11 ++--
 .../InstCombine/InstCombineSelect.cpp         |  5 +-
 .../InstCombineSimplifyDemanded.cpp           |  3 +-
 .../InstCombine/InstructionCombining.cpp      |  6 +-
 .../Scalar/DeadStoreElimination.cpp           |  8 +--
 llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp |  3 +-
 .../Vectorize/LoopIdiomVectorize.cpp          | 12 ++--
 15 files changed, 99 insertions(+), 107 deletions(-)

diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 3a7ae577bb068..abf0735aa2c99 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -110,10 +110,11 @@ static Value *foldSelectWithBinaryOp(Value *Cond, Value *TrueVal,
   // -->
   // %TV
   Value *X, *Y;
-  if (!match(Cond, m_c_BinOp(m_c_ICmp(Pred1, m_Specific(TrueVal),
-                                      m_Specific(FalseVal)),
-                             m_ICmp(Pred2, m_Value(X), m_Value(Y)))) ||
-      Pred1 != Pred2 || Pred1 != ExpectedPred)
+  if (!match(
+          Cond,
+          m_c_BinOp(m_c_ICmp(Pred1, m_Specific(TrueVal), m_Specific(FalseVal)),
+                    m_SpecificICmp(ExpectedPred, m_Value(X), m_Value(Y)))) ||
+      Pred1 != ExpectedPred)
     return nullptr;
 
   if (X == TrueVal || X == FalseVal || Y == TrueVal || Y == FalseVal)
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index bfd26fadd237b..10884f8d32b86 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -254,8 +254,7 @@ bool llvm::haveNoCommonBitsSet(const WithCache<const Value *> &LHSCache,
 
 bool llvm::isOnlyUsedInZeroComparison(const Instruction *I) {
   return !I->user_empty() && all_of(I->users(), [](const User *U) {
-    ICmpInst::Predicate P;
-    return match(U, m_ICmp(P, m_Value(), m_Zero()));
+    return match(U, m_ICmp(m_Value(), m_Zero()));
   });
 }
 
@@ -2594,10 +2593,10 @@ static bool isNonZeroRecurrence(const PHINode *PN) {
 }
 
 static bool matchOpWithOpEqZero(Value *Op0, Value *Op1) {
-  ICmpInst::Predicate Pred;
-  return (match(Op0, m_ZExtOrSExt(m_ICmp(Pred, m_Specific(Op1), m_Zero()))) ||
-          match(Op1, m_ZExtOrSExt(m_ICmp(Pred, m_Specific(Op0), m_Zero())))) &&
-         Pred == ICmpInst::ICMP_EQ;
+  return match(Op0, m_ZExtOrSExt(m_SpecificICmp(ICmpInst::ICMP_EQ,
+                                                m_Specific(Op1), m_Zero()))) ||
+         match(Op1, m_ZExtOrSExt(m_SpecificICmp(ICmpInst::ICMP_EQ,
+                                                m_Specific(Op0), m_Zero())));
 }
 
 static bool isNonZeroAdd(const APInt &DemandedElts, unsigned Depth,
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULowerKernelAttributes.cpp b/llvm/lib/Target/AMDGPU/AMDGPULowerKernelAttributes.cpp
index e91d05954a1c9..e724c978c44b6 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULowerKernelAttributes.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULowerKernelAttributes.cpp
@@ -225,10 +225,8 @@ static bool processUse(CallInst *CI, bool IsV5OrAbove) {
                            : m_Intrinsic<Intrinsic::amdgcn_workgroup_id_z>());
 
       for (User *ICmp : BlockCount->users()) {
-        ICmpInst::Predicate Pred;
-        if (match(ICmp, m_ICmp(Pred, GroupIDIntrin, m_Specific(BlockCount)))) {
-          if (Pred != ICmpInst::ICMP_ULT)
-            continue;
+        if (match(ICmp, m_SpecificICmp(ICmpInst::ICMP_ULT, GroupIDIntrin,
+                                       m_Specific(BlockCount)))) {
           ICmp->replaceAllUsesWith(llvm::ConstantInt::getTrue(ICmp->getType()));
           MadeChange = true;
         }
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index ad59b13933a6a..b971afda4229a 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -3435,12 +3435,9 @@ X86TargetLowering::getJumpConditionMergingParams(Instruction::BinaryOps Opc,
   if (BaseCost >= 0 && Subtarget.hasCCMP())
     BaseCost += BrMergingCcmpBias;
   // a == b && a == c is a fast pattern on x86.
-  ICmpInst::Predicate Pred;
   if (BaseCost >= 0 && Opc == Instruction::And &&
-      match(Lhs, m_ICmp(Pred, m_Value(), m_Value())) &&
-      Pred == ICmpInst::ICMP_EQ &&
-      match(Rhs, m_ICmp(Pred, m_Value(), m_Value())) &&
-      Pred == ICmpInst::ICMP_EQ)
+      match(Lhs, m_SpecificICmp(ICmpInst::ICMP_EQ, m_Value(), m_Value())) &&
+      match(Rhs, m_SpecificICmp(ICmpInst::ICMP_EQ, m_Value(), m_Value())))
     BaseCost += 1;
   return {BaseCost, BrMergingLikelyBias.getValue(),
           BrMergingUnlikelyBias.getValue()};
@@ -30760,10 +30757,12 @@ static bool shouldExpandCmpArithRMWInIR(AtomicRMWInst *AI) {
     if (match(I, m_c_ICmp(Pred, m_Sub(m_ZeroInt(), m_Specific(Op)), m_Value())))
       return Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_NE;
     if (match(I, m_OneUse(m_c_Add(m_Specific(Op), m_Value())))) {
-      if (match(I->user_back(), m_ICmp(Pred, m_Value(), m_ZeroInt())))
-        return Pred == CmpInst::ICMP_SLT;
-      if (match(I->user_back(), m_ICmp(Pred, m_Value(), m_AllOnes())))
-        return Pred == CmpInst::ICMP_SGT;
+      if (match(I->user_back(),
+                m_SpecificICmp(CmpInst::ICMP_SLT, m_Value(), m_ZeroInt())))
+        return true;
+      if (match(I->user_back(),
+                m_SpecificICmp(CmpInst::ICMP_SGT, m_Value(), m_AllOnes())))
+        return true;
     }
     return false;
   }
@@ -30771,10 +30770,12 @@ static bool shouldExpandCmpArithRMWInIR(AtomicRMWInst *AI) {
     if (match(I, m_c_ICmp(Pred, m_Specific(Op), m_Value())))
       return Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_NE;
     if (match(I, m_OneUse(m_Sub(m_Value(), m_Specific(Op))))) {
-      if (match(I->user_back(), m_ICmp(Pred, m_Value(), m_ZeroInt())))
-        return Pred == CmpInst::ICMP_SLT;
-      if (match(I->user_back(), m_ICmp(Pred, m_Value(), m_AllOnes())))
-        return Pred == CmpInst::ICMP_SGT;
+      if (match(I->user_back(),
+                m_SpecificICmp(CmpInst::ICMP_SLT, m_Value(), m_ZeroInt())))
+        return true;
+      if (match(I->user_back(),
+                m_SpecificICmp(CmpInst::ICMP_SGT, m_Value(), m_AllOnes())))
+        return true;
     }
     return false;
   }
@@ -30785,18 +30786,21 @@ static bool shouldExpandCmpArithRMWInIR(AtomicRMWInst *AI) {
     if (match(I->user_back(), m_ICmp(Pred, m_Value(), m_ZeroInt())))
       return Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_NE ||
              Pred == CmpInst::ICMP_SLT;
-    if (match(I->user_back(), m_ICmp(Pred, m_Value(), m_AllOnes())))
-      return Pred == CmpInst::ICMP_SGT;
+    if (match(I->user_back(),
+              m_SpecificICmp(CmpInst::ICMP_SGT, m_Value(), m_AllOnes())))
+      return true;
     return false;
   }
   if (Opc == AtomicRMWInst::Xor) {
     if (match(I, m_c_ICmp(Pred, m_Specific(Op), m_Value())))
       return Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_NE;
     if (match(I, m_OneUse(m_c_Xor(m_Specific(Op), m_Value())))) {
-      if (match(I->user_back(), m_ICmp(Pred, m_Value(), m_ZeroInt())))
-        return Pred == CmpInst::ICMP_SLT;
-      if (match(I->user_back(), m_ICmp(Pred, m_Value(), m_AllOnes())))
-        return Pred == CmpInst::ICMP_SGT;
+      if (match(I->user_back(),
+                m_SpecificICmp(CmpInst::ICMP_SLT, m_Value(), m_ZeroInt())))
+        return true;
+      if (match(I->user_back(),
+                m_SpecificICmp(CmpInst::ICMP_SGT, m_Value(), m_AllOnes())))
+        return true;
     }
     return false;
   }
diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
index d5a38ec17a2a8..09ffc2d184f18 100644
--- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
+++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
@@ -135,15 +135,12 @@ static bool foldGuardedFunnelShift(Instruction &I, const DominatorTree &DT) {
   if (!DT.dominates(ShVal0, TermI) || !DT.dominates(ShVal1, TermI))
     return false;
 
-  ICmpInst::Predicate Pred;
   BasicBlock *PhiBB = Phi.getParent();
-  if (!match(TermI, m_Br(m_ICmp(Pred, m_Specific(ShAmt), m_ZeroInt()),
+  if (!match(TermI, m_Br(m_SpecificICmp(CmpInst::ICMP_EQ, m_Specific(ShAmt),
+                                        m_ZeroInt()),
                          m_SpecificBB(PhiBB), m_SpecificBB(FunnelBB))))
     return false;
 
-  if (Pred != CmpInst::ICMP_EQ)
-    return false;
-
   IRBuilder<> Builder(PhiBB, PhiBB->getFirstInsertionPt());
 
   if (ShVal0 == ShVal1)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 0a55f4762fdf0..4341ad99f7175 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1697,9 +1697,8 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
   ICmpInst::Predicate Pred;
   uint64_t BitWidth = Ty->getScalarSizeInBits();
   if (match(LHS, m_AShr(m_Value(A), m_SpecificIntAllowPoison(BitWidth - 1))) &&
-      match(RHS, m_OneUse(m_ZExt(
-                     m_OneUse(m_ICmp(Pred, m_Specific(A), m_ZeroInt()))))) &&
-      Pred == CmpInst::ICMP_SGT) {
+      match(RHS, m_OneUse(m_ZExt(m_OneUse(m_SpecificICmp(
+                     CmpInst::ICMP_SGT, m_Specific(A), m_ZeroInt())))))) {
     Value *NotZero = Builder.CreateIsNotNull(A, "isnotnull");
     Value *Zext = Builder.CreateZExt(NotZero, Ty, "isnotnull.zext");
     return BinaryOperator::CreateOr(LHS, Zext);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index f9caa4da44931..4ca12d5b92f18 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -818,11 +818,11 @@ static Value *foldSignedTruncationCheck(ICmpInst *ICmp0, ICmpInst *ICmp1,
   // Match  icmp ult (add %arg, C01), C1   (C1 == C01 << 1; powers of two)
   auto tryToMatchSignedTruncationCheck = [](ICmpInst *ICmp, Value *&X,
                                             APInt &SignBitMask) -> bool {
-    CmpInst::Predicate Pred;
     const APInt *I01, *I1; // powers of two; I1 == I01 << 1
-    if (!(match(ICmp,
-                m_ICmp(Pred, m_Add(m_Value(X), m_Power2(I01)), m_Power2(I1))) &&
-          Pred == ICmpInst::ICMP_ULT && I1->ugt(*I01) && I01->shl(1) == *I1))
+    if (!(match(ICmp, m_SpecificICmp(ICmpInst::ICMP_ULT,
+                                     m_Add(m_Value(X), m_Power2(I01)),
+                                     m_Power2(I1))) &&
+          I1->ugt(*I01) && I01->shl(1) == *I1))
       return false;
     // Which bit is the new sign bit as per the 'signed truncation' pattern?
     SignBitMask = *I01;
@@ -936,20 +936,21 @@ static Value *foldIsPowerOf2(ICmpInst *Cmp0, ICmpInst *Cmp1, bool JoinedByAnd,
     std::swap(Cmp0, Cmp1);
 
   // (X != 0) && (ctpop(X) u< 2) --> ctpop(X) == 1
-  CmpInst::Predicate Pred0, Pred1;
   Value *X;
-  if (JoinedByAnd && match(Cmp0, m_ICmp(Pred0, m_Value(X), m_ZeroInt())) &&
-      match(Cmp1, m_ICmp(Pred1, m_Intrinsic<Intrinsic::ctpop>(m_Specific(X)),
-                         m_SpecificInt(2))) &&
-      Pred0 == ICmpInst::ICMP_NE && Pred1 == ICmpInst::ICMP_ULT) {
+  if (JoinedByAnd &&
+      match(Cmp0, m_SpecificICmp(ICmpInst::ICMP_NE, m_Value(X), m_ZeroInt())) &&
+      match(Cmp1, m_SpecificICmp(ICmpInst::ICMP_ULT,
+                                 m_Intrinsic<Intrinsic::ctpop>(m_Specific(X)),
+                                 m_SpecificInt(2)))) {
     Value *CtPop = Cmp1->getOperand(0);
     return Builder.CreateICmpEQ(CtPop, ConstantInt::get(CtPop->getType(), 1));
   }
   // (X == 0) || (ctpop(X) u> 1) --> ctpop(X) != 1
-  if (!JoinedByAnd && match(Cmp0, m_ICmp(Pred0, m_Value(X), m_ZeroInt())) &&
-      match(Cmp1, m_ICmp(Pred1, m_Intrinsic<Intrinsic::ctpop>(m_Specific(X)),
-                         m_SpecificInt(1))) &&
-      Pred0 == ICmpInst::ICMP_EQ && Pred1 == ICmpInst::ICMP_UGT) {
+  if (!JoinedByAnd &&
+      match(Cmp0, m_SpecificICmp(ICmpInst::ICMP_EQ, m_Value(X), m_ZeroInt())) &&
+      match(Cmp1, m_SpecificICmp(ICmpInst::ICMP_UGT,
+                                 m_Intrinsic<Intrinsic::ctpop>(m_Specific(X)),
+                                 m_SpecificInt(1)))) {
     Value *CtPop = Cmp1->getOperand(0);
     return Builder.CreateICmpNE(CtPop, ConstantInt::get(CtPop->getType(), 1));
   }
@@ -1608,31 +1609,30 @@ static Instruction *reassociateFCmps(BinaryOperator &BO,
   // There are 4 commuted variants of the pattern. Canonicalize operands of this
   // logic op so an fcmp is operand 0 and a matching logic op is operand 1.
   Value *Op0 = BO.getOperand(0), *Op1 = BO.getOperand(1), *X;
-  FCmpInst::Predicate Pred;
-  if (match(Op1, m_FCmp(Pred, m_Value(), m_AnyZeroFP())))
+  if (match(Op1, m_FCmp(m_Value(), m_AnyZeroFP())))
     std::swap(Op0, Op1);
 
   // Match inner binop and the predicate for combining 2 NAN checks into 1.
   Value *BO10, *BO11;
   FCmpInst::Predicate NanPred = Opcode == Instruction::And ? FCmpInst::FCMP_ORD
                                                            : FCmpInst::FCMP_UNO;
-  if (!match(Op0, m_FCmp(Pred, m_Value(X), m_AnyZeroFP())) || Pred != NanPred ||
+  if (!match(Op0, m_SpecificFCmp(NanPred, m_Value(X), m_AnyZeroFP())) ||
       !match(Op1, m_BinOp(Opcode, m_Value(BO10), m_Value(BO11))))
     return nullptr;
 
   // The inner logic op must have a matching fcmp operand.
   Value *Y;
-  if (!match(BO10, m_FCmp(Pred, m_Value(Y), m_AnyZeroFP())) ||
-      Pred != NanPred || X->getType() != Y->getType())
+  if (!match(BO10, m_SpecificFCmp(NanPred, m_Value(Y), m_AnyZeroFP())) ||
+      X->getType() != Y->getType())
     std::swap(BO10, BO11);
 
-  if (!match(BO10, m_FCmp(Pred, m_Value(Y), m_AnyZeroFP())) ||
-      Pred != NanPred || X->getType() != Y->getType())
+  if (!match(BO10, m_SpecificFCmp(NanPred, m_Value(Y), m_AnyZeroFP())) ||
+      X->getType() != Y->getType())
     return nullptr;
 
   // and (fcmp ord X, 0), (and (fcmp ord Y, 0), Z) --> and (fcmp ord X, Y), Z
   // or  (fcmp uno X, 0), (or  (fcmp uno Y, 0), Z) --> or  (fcmp uno X, Y), Z
-  Value *NewFCmp = Builder.CreateFCmp(Pred, X, Y);
+  Value *NewFCmp = Builder.CreateFCmp(NanPred, X, Y);
   if (auto *NewFCmpInst = dyn_cast<FCmpInst>(NewFCmp)) {
     // Intersect FMF from the 2 source fcmps.
     NewFCmpInst->copyIRFlags(Op0);
@@ -1744,14 +1744,13 @@ Instruction *InstCombinerImpl::foldCastedBitwiseLogic(BinaryOperator &I) {
   //   -> zext(bitwise(A < 0, icmp))
   auto FoldBitwiseICmpZeroWithICmp = [&](Value *Op0,
                                          Value *Op1) -> Instruction * {
-    ICmpInst::Predicate Pred;
     Value *A;
     bool IsMatched =
         match(Op0,
               m_OneUse(m_LShr(
                   m_Value(A),
                   m_SpecificInt(Op0->getType()->getScalarSizeInBits() - 1)))) &&
-        match(Op1, m_OneUse(m_ZExt(m_ICmp(Pred, m_Value(), m_Value()))));
+        match(Op1, m_OneUse(m_ZExt(m_ICmp(m_Value(), m_Value()))));
 
     if (!IsMatched)
       return nullptr;
@@ -3878,14 +3877,14 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
   if (match(&I,
             m_c_Or(m_CombineAnd(m_ExtractValue<1>(m_Value(UMulWithOv)),
                                 m_Value(Ov)),
-                   m_CombineAnd(m_ICmp(Pred,
-                                       m_CombineAnd(m_ExtractValue<0>(
-                                                        m_Deferred(UMulWithOv)),
-                                                    m_Value(Mul)),
-                                       m_ZeroInt()),
-                                m_Value(MulIsNotZero)))) &&
-      (Ov->hasOneUse() || (MulIsNotZero->hasOneUse() && Mul->hasOneUse())) &&
-      Pred == CmpInst::ICMP_NE) {
+                   m_CombineAnd(
+                       m_SpecificICmp(ICmpInst::ICMP_NE,
+                                      m_CombineAnd(m_ExtractValue<0>(
+                                                       m_Deferred(UMulWithOv)),
+                                                   m_Value(Mul)),
+                                      m_ZeroInt()),
+                       m_Value(MulIsNotZero)))) &&
+      (Ov->hasOneUse() || (MulIsNotZero->hasOneUse() && Mul->hasOneUse()))) {
     Value *A, *B;
     if (match(UMulWithOv, m_Intrinsic<Intrinsic::umul_with_overflow>(
                               m_Value(A), m_Value(B)))) {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index f6c4b6e180937..9e690cd2c624f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -3031,10 +3031,10 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
 
     // assume( (load addr) != null ) -> add 'nonnull' metadata to load
     // (if assume is valid at the load)
-    CmpInst::Predicate Pred;
     Instruction *LHS;
-    if (match(IIOperand, m_ICmp(Pred, m_Instruction(LHS), m_Zero())) &&
-        Pred == ICmpInst::ICMP_NE && LHS->getOpcode() == Instruction::Load &&
+    if (match(IIOperand, m_SpecificICmp(ICmpInst::ICMP_NE, m_Instruction(LHS),
+                                        m_Zero())) &&
+        LHS->getOpcode() == Instruction::Load &&
         LHS->getType()->isPointerTy() &&
         isValidAssumeForContext(II, LHS, &DT)) {
       MDNode *MD = MDNode::get(II->getContext(), std::nullopt);
@@ -3073,8 +3073,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     // into
     // call void @llvm.assume(i1 true) [ "nonnull"(i32* %PTR) ]
     if (EnableKnowledgeRetention &&
-        match(IIOperand, m_Cmp(Pred, m_Value(A), m_Zero())) &&
-        Pred == CmpInst::ICMP_NE && A->getType()->isPointerTy()) {
+        match(IIOperand,
+              m_SpecificICmp(ICmpInst::ICMP_NE, m_Value(A), m_Zero())) &&
+        A->getType()->isPointerTy()) {
       if (auto *Replacement = buildAssumeFromKnowledge(
               {RetainedKnowledge{Attribute::NonNull, 0, A}}, Next, &AC, &DT)) {
 
@@ -3094,9 +3095,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     uint64_t AlignMask;
     if (EnableKnowledgeRetention &&
         match(IIOperand,
-              m_Cmp(Pred, m_And(m_Value(A), m_ConstantInt(AlignMask)),
-                    m_Zero())) &&
-        Pred == CmpInst::ICMP_EQ) {
+              m_SpecificICmp(ICmpInst::ICMP_EQ,
+                             m_And(m_Value(A), m_ConstantInt(AlignMask)),
+                             m_Zero()))) {
       if (isPowerOf2_64(AlignMask + 1)) {
         uint64_t Offset = 0;
         match(A, m_Add(m_Value(A), m_ConstantInt(Offset)));
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index abadf54a96767..3b6df2760ecc2 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -5772,8 +5772,7 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
   //    -> icmp eq/ne X, rotate-left(X)
   // We generally try to convert rotate-right -> rotate-left, this just
   // canonicalizes another case.
-  CmpInst::Predicate PredUnused = Pred;
-  if (match(&I, m_c_ICmp(PredUnused, m_Value(A),
+  if (match(&I, m_c_ICmp(m_Value(A),
                          m_OneUse(m_Intrinsic<Intrinsic::fshr>(
                              m_Deferred(A), m_Deferred(A), m_Value(B))))))
     return new ICmpInst(
@@ -5783,8 +5782,7 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
   // Canonicalize:
   // icmp eq/ne OneUse(A ^ Cst), B --> icmp eq/ne (A ^ B), Cst
   Constant *Cst;
-  if (match(&I, m_c_ICmp(PredUnused,
-                         m_OneUse(m_Xor(m_Value(A), m_ImmConstant(Cst))),
+  if (match(&I, m_c_ICmp(m_OneUse(m_Xor(m_Value(A), m_ImmConstant(Cst))),
                          m_CombineAnd(m_Value(B), m_Unless(m_ImmConstant())))))
     return new ICmpInst(Pred, Builder.CreateXor(A, B), Cst);
 
@@ -5795,13 +5793,12 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
                                 m_c_Xor(m_Value(B), m_Deferred(A))),
                     m_Sub(m_Value(B), m_Deferred(A)));
     std::optional<bool> IsZero = std::nullopt;
-    if (match(&I, m_c_ICmp(PredUnused, m_OneUse(m_c_And(m_Value(A), m_Matcher)),
+    if (match(&I, m_c_ICmp(m_OneUse(m_c_And(m_Value(A), m_Matcher)),
                            m_Deferred(A))))
       IsZero = false;
     // (icmp eq/ne (and (add/sub/xor X, P2), P2), 0)
     else if (match(&I,
-                   m_ICmp(PredUnused, m_OneUse(m_c_And(m_Value(A), m_Matcher)),
-                          m_Zero())))
+                   m_ICmp(m_OneUse(m_c_And(m_Value(A), m_Matcher)), m_Zero())))
       IsZero = true;
 
     if (IsZero && isKnownToBeAPowerOfTwo(A, /* OrZero */ true, /*Depth*/ 0, &I))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index aaf4ece3249a2..a22ee1de0ac21 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -2484,9 +2484,8 @@ static Instruction *foldSelectFunnelShift(SelectInst &Sel,
 
   // Finally, see if the select is filtering out a shift-by-zero.
   Value *Cond = Sel.getCondition();
-  ICmpInst::Predicate Pred;
-  if (!match(Cond, m_OneUse(m_ICmp(Pred, m_Specific(ShAmt), m_ZeroInt()))) ||
-      Pred != ICmpInst::ICMP_EQ)
+  if (!match(Cond, m_OneUse(m_SpecificICmp(ICmpInst::ICMP_EQ, m_Specific(ShAmt),
+                                           m_ZeroInt()))))
     return nullptr;
 
   // If this is not a rotate then the select was blocking poison from the
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index 8a6ec3076ac62..c494fec84c1e6 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -388,8 +388,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Instruction *I,
       // invert the transform that reduces set bits and infinite-loop.
       Value *X;
       const APInt *CmpC;
-      ICmpInst::Predicate Pred;
-      if (!match(I->getOperand(0), m_ICmp(Pred, m_Value(X), m_APInt(CmpC))) ||
+      if (!match(I->getOperand(0), m_ICmp(m_Value(X), m_APInt(CmpC))) ||
           isa<Constant>(X) || CmpC->getBitWidth() != SelC->getBitWidth())
         return ShrinkDemandedConstant(I, OpNo, DemandedMask);
 
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 0d8e7e92c5c8e..0fb8b639c97b9 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -1651,14 +1651,14 @@ static Constant *constantFoldOperationIntoSelectOperand(Instruction &I,
                                                         bool IsTrueArm) {
   SmallVector<Constant *> ConstOps;
   for (Value *Op : I.operands()) {
-    CmpInst::Predicate Pred;
     Constant *C = nullptr;
     if (Op == SI) {
       C = dyn_cast<Constant>(IsTrueArm ? SI->getTrueValue()
                                        : SI->getFalseValue());
     } else if (match(SI->getCondition(),
-                     m_ICmp(Pred, m_Specific(Op), m_Constant(C))) &&
-               Pred == (IsTrueArm ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE) &&
+                     m_SpecificICmp(IsTrueArm ? ICmpInst::ICMP_EQ
+                                              : ICmpInst::ICMP_NE,
+                                    m_Specific(Op), m_Constant(C))) &&
                isGuaranteedNotToBeUndefOrPoison(C)) {
       // Pass
     } else {
diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
index 931606c6f8fe1..992139a95a43d 100644
--- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -1889,12 +1889,12 @@ struct DSEState {
         return true;
       auto *Ptr = Memset->getArgOperand(0);
       auto *TI = MallocBB->getTerminator();
-      ICmpInst::Predicate Pred;
       BasicBlock *TrueBB, *FalseBB;
-      if (!match(TI, m_Br(m_ICmp(Pred, m_Specific(Ptr), m_Zero()), TrueBB,
-                          FalseBB)))
+      if (!match(TI, m_Br(m_SpecificICmp(ICmpInst::ICMP_EQ, m_Specific(Ptr),
+                                         m_Zero()),
+                          TrueBB, FalseBB)))
         return false;
-      if (Pred != ICmpInst::ICMP_EQ || MemsetBB != FalseBB)
+      if (MemsetBB != FalseBB)
         return false;
       return true;
     };
diff --git a/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp b/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp
index 6092cd1bc08be..ff077624802be 100644
--- a/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopBoundSplit.cpp
@@ -160,9 +160,8 @@ static bool isProcessableCondBI(const ScalarEvolution &SE,
                                 const BranchInst *BI) {
   BasicBlock *TrueSucc = nullptr;
   BasicBlock *FalseSucc = nullptr;
-  ICmpInst::Predicate Pred;
   Value *LHS, *RHS;
-  if (!match(BI, m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)),
+  if (!match(BI, m_Br(m_ICmp(m_Value(LHS), m_Value(RHS)),
                       m_BasicBlock(TrueSucc), m_BasicBlock(FalseSucc))))
     return false;
 
diff --git a/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp
index 64e04cae2773f..cb31e2a2ecaec 100644
--- a/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp
@@ -268,25 +268,25 @@ bool LoopIdiomVectorize::recognizeByteCompare() {
             return false;
 
   // Match the branch instruction for the header
-  ICmpInst::Predicate Pred;
   Value *MaxLen;
   BasicBlock *EndBB, *WhileBB;
   if (!match(Header->getTerminator(),
-             m_Br(m_ICmp(Pred, m_Specific(Index), m_Value(MaxLen)),
+             m_Br(m_SpecificICmp(ICmpInst::ICMP_EQ, m_Specific(Index),
+                                 m_Value(MaxLen)),
                   m_BasicBlock(EndBB), m_BasicBlock(WhileBB))) ||
-      Pred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(WhileBB))
+      !CurLoop->contains(WhileBB))
     return false;
 
   // WhileBB should contain the pattern of load & compare instructions. Match
   // the pattern and find the GEP instructions used by the loads.
-  ICmpInst::Predicate WhilePred;
   BasicBlock *FoundBB;
   BasicBlock *TrueBB;
   Value *LoadA, *LoadB;
   if (!match(WhileBB->getTerminator(),
-             m_Br(m_ICmp(WhilePred, m_Value(LoadA), m_Value(LoadB)),
+             m_Br(m_SpecificICmp(ICmpInst::ICMP_EQ, m_Value(LoadA),
+                                 m_Value(LoadB)),
                   m_BasicBlock(TrueBB), m_BasicBlock(FoundBB))) ||
-      WhilePred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(TrueBB))
+      !CurLoop->contains(TrueBB))
     return false;
 
   Value *A, *B;

>From 37591ead98dfa6eb32982308eddaaf40a04ed281 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Sun, 28 Jul 2024 23:41:10 +0800
Subject: [PATCH 2/3] [InstCombine] Address review comments. NFC.

---
 llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 4341ad99f7175..1377992130a02 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1694,7 +1694,6 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
 
   // Canonicalize signum variant that ends in add:
   // (A s>> (BW - 1)) + (zext (A s> 0)) --> (A s>> (BW - 1)) | (zext (A != 0))
-  ICmpInst::Predicate Pred;
   uint64_t BitWidth = Ty->getScalarSizeInBits();
   if (match(LHS, m_AShr(m_Value(A), m_SpecificIntAllowPoison(BitWidth - 1))) &&
       match(RHS, m_OneUse(m_ZExt(m_OneUse(m_SpecificICmp(
@@ -1710,12 +1709,12 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
     // (add X, (sext/zext (icmp eq X, C)))
     //    -> (select (icmp eq X, C), (add C, (sext/zext 1)), X)
     auto CondMatcher = m_CombineAnd(
-        m_Value(Cond), m_ICmp(Pred, m_Deferred(A), m_ImmConstant(C)));
+        m_Value(Cond), m_SpecificICmp(ICmpInst::ICMP_EQ, m_Deferred(A), m_ImmConstant(C)));
 
     if (match(&I,
               m_c_Add(m_Value(A),
                       m_CombineAnd(m_Value(Ext), m_ZExtOrSExt(CondMatcher)))) &&
-        Pred == ICmpInst::ICMP_EQ && Ext->hasOneUse()) {
+        Ext->hasOneUse()) {
       Value *Add = isa<ZExtInst>(Ext) ? InstCombiner::AddOne(C)
                                       : InstCombiner::SubOne(C);
       return replaceInstUsesWith(I, Builder.CreateSelect(Cond, Add, A));
@@ -1790,6 +1789,7 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
   // -->
   // BW - ctlz(A - 1, false)
   const APInt *XorC;
+  ICmpInst::Predicate Pred;
   if (match(&I,
             m_c_Add(
                 m_ZExt(m_ICmp(Pred, m_Intrinsic<Intrinsic::ctpop>(m_Value(A)),

>From 8649a831f9387b7540605e15595f5fbd73883f3e Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Sun, 28 Jul 2024 23:59:47 +0800
Subject: [PATCH 3/3] [PatternMatch] Add m_c_SpecificICmp matcher

---
 llvm/include/llvm/IR/PatternMatch.h           | 24 +++++++++++++++----
 llvm/lib/Analysis/InstructionSimplify.cpp     | 11 ++++-----
 .../InstCombine/InstCombineAddSub.cpp         |  3 ++-
 llvm/unittests/IR/PatternMatch.cpp            |  8 +++++++
 4 files changed, 35 insertions(+), 11 deletions(-)

diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index d9e27e087e705..9c70a60bf50c7 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -1616,7 +1616,8 @@ m_FCmp(const LHS &L, const RHS &R) {
 
 // Same as CmpClass, but instead of saving Pred as out output variable, match a
 // specific input pred for equality.
-template <typename LHS_t, typename RHS_t, typename Class, typename PredicateTy>
+template <typename LHS_t, typename RHS_t, typename Class, typename PredicateTy,
+          bool Commutable = false>
 struct SpecificCmpClass_match {
   const PredicateTy Predicate;
   LHS_t L;
@@ -1626,9 +1627,17 @@ struct SpecificCmpClass_match {
       : Predicate(Pred), L(LHS), R(RHS) {}
 
   template <typename OpTy> bool match(OpTy *V) {
-    if (auto *I = dyn_cast<Class>(V))
-      return I->getPredicate() == Predicate && L.match(I->getOperand(0)) &&
-             R.match(I->getOperand(1));
+    if (auto *I = dyn_cast<Class>(V)) {
+      if (I->getPredicate() == Predicate && L.match(I->getOperand(0)) &&
+          R.match(I->getOperand(1)))
+        return true;
+      if constexpr (Commutable) {
+        if (I->getPredicate() == Class::getSwappedPredicate(Predicate) &&
+            L.match(I->getOperand(1)) && R.match(I->getOperand(0)))
+          return true;
+      }
+    }
+
     return false;
   }
 };
@@ -1647,6 +1656,13 @@ m_SpecificICmp(ICmpInst::Predicate MatchPred, const LHS &L, const RHS &R) {
       MatchPred, L, R);
 }
 
+template <typename LHS, typename RHS>
+inline SpecificCmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate, true>
+m_c_SpecificICmp(ICmpInst::Predicate MatchPred, const LHS &L, const RHS &R) {
+  return SpecificCmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate, true>(
+      MatchPred, L, R);
+}
+
 template <typename LHS, typename RHS>
 inline SpecificCmpClass_match<LHS, RHS, FCmpInst, FCmpInst::Predicate>
 m_SpecificFCmp(FCmpInst::Predicate MatchPred, const LHS &L, const RHS &R) {
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index abf0735aa2c99..12a3193e63755 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -88,7 +88,7 @@ static Value *foldSelectWithBinaryOp(Value *Cond, Value *TrueVal,
   else
     return nullptr;
 
-  CmpInst::Predicate ExpectedPred, Pred1, Pred2;
+  CmpInst::Predicate ExpectedPred;
   if (BinOpCode == BinaryOperator::Or) {
     ExpectedPred = ICmpInst::ICMP_NE;
   } else if (BinOpCode == BinaryOperator::And) {
@@ -110,11 +110,10 @@ static Value *foldSelectWithBinaryOp(Value *Cond, Value *TrueVal,
   // -->
   // %TV
   Value *X, *Y;
-  if (!match(
-          Cond,
-          m_c_BinOp(m_c_ICmp(Pred1, m_Specific(TrueVal), m_Specific(FalseVal)),
-                    m_SpecificICmp(ExpectedPred, m_Value(X), m_Value(Y)))) ||
-      Pred1 != ExpectedPred)
+  if (!match(Cond,
+             m_c_BinOp(m_c_SpecificICmp(ExpectedPred, m_Specific(TrueVal),
+                                        m_Specific(FalseVal)),
+                       m_SpecificICmp(ExpectedPred, m_Value(X), m_Value(Y)))))
     return nullptr;
 
   if (X == TrueVal || X == FalseVal || Y == TrueVal || Y == FalseVal)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 1377992130a02..3bd086230cbec 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1709,7 +1709,8 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
     // (add X, (sext/zext (icmp eq X, C)))
     //    -> (select (icmp eq X, C), (add C, (sext/zext 1)), X)
     auto CondMatcher = m_CombineAnd(
-        m_Value(Cond), m_SpecificICmp(ICmpInst::ICMP_EQ, m_Deferred(A), m_ImmConstant(C)));
+        m_Value(Cond),
+        m_SpecificICmp(ICmpInst::ICMP_EQ, m_Deferred(A), m_ImmConstant(C)));
 
     if (match(&I,
               m_c_Add(m_Value(A),
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index 309fcc93996bc..379f97fb63139 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -2317,6 +2317,14 @@ TYPED_TEST(MutableConstTest, ICmp) {
   EXPECT_FALSE(m_SpecificCmp(ICmpInst::getInversePredicate(Pred),
                              m_Value(MatchL), m_Value(MatchR))
                    .match((InstructionType)IRB.CreateICmp(Pred, L, R)));
+
+  EXPECT_TRUE(m_c_SpecificICmp(Pred, m_Specific(L), m_Specific(R))
+                  .match((InstructionType)IRB.CreateICmp(Pred, L, R)));
+  EXPECT_TRUE(m_c_SpecificICmp(ICmpInst::getSwappedPredicate(Pred),
+                               m_Specific(R), m_Specific(L))
+                  .match((InstructionType)IRB.CreateICmp(Pred, L, R)));
+  EXPECT_FALSE(m_c_SpecificICmp(Pred, m_Specific(R), m_Specific(L))
+                   .match((InstructionType)IRB.CreateICmp(Pred, L, R)));
 }
 
 TYPED_TEST(MutableConstTest, FCmp) {



More information about the llvm-commits mailing list