[llvm] 25043c8 - [NFCI] Introduce `ICmpInst::compare()` and use it where appropriate

Roman Lebedev via llvm-commits llvm-commits at lists.llvm.org
Sat Oct 30 07:50:32 PDT 2021


Author: Roman Lebedev
Date: 2021-10-30T17:50:06+03:00
New Revision: 25043c8276644e684f8d14cd4cadaa87a7e99b0e

URL: https://github.com/llvm/llvm-project/commit/25043c8276644e684f8d14cd4cadaa87a7e99b0e
DIFF: https://github.com/llvm/llvm-project/commit/25043c8276644e684f8d14cd4cadaa87a7e99b0e.diff

LOG: [NFCI] Introduce `ICmpInst::compare()` and use it where appropriate

As noted in https://reviews.llvm.org/D90924#inline-1076197,
this is apparently a pretty common pattern. Let's not repeat it yet again,
and instead keep it in one common place.

There may be some more places where it could be used,
but these are the most obvious ones.

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/Analysis.h
    llvm/include/llvm/IR/Instructions.h
    llvm/include/llvm/IR/PatternMatch.h
    llvm/lib/CodeGen/Analysis.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/lib/IR/ConstantFold.cpp
    llvm/lib/IR/Instructions.cpp
    llvm/lib/Transforms/IPO/AttributorAttributes.cpp
    llvm/unittests/IR/ConstantRangeTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/Analysis.h b/llvm/include/llvm/CodeGen/Analysis.h
index bdfb416d9bd96..60442326d6c79 100644
--- a/llvm/include/llvm/CodeGen/Analysis.h
+++ b/llvm/include/llvm/CodeGen/Analysis.h
@@ -104,9 +104,12 @@ ISD::CondCode getFCmpCodeWithoutNaN(ISD::CondCode CC);
 
 /// getICmpCondCode - Return the ISD condition code corresponding to
 /// the given LLVM IR integer condition code.
-///
 ISD::CondCode getICmpCondCode(ICmpInst::Predicate Pred);
 
+/// getICmpCondCode - Return the LLVM IR integer condition code
+/// corresponding to the given ISD integer condition code.
+ICmpInst::Predicate getICmpCondCode(ISD::CondCode Pred);
+
 /// Test if the given instruction is in a position to be optimized
 /// with a tail-call. This roughly means that it's in a block with
 /// a return and there's nothing that needs to be scheduled

diff  --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h
index aa8613fbdb889..3c7911d251364 100644
--- a/llvm/include/llvm/IR/Instructions.h
+++ b/llvm/include/llvm/IR/Instructions.h
@@ -1349,6 +1349,10 @@ class ICmpInst: public CmpInst {
     Op<0>().swap(Op<1>());
   }
 
+  /// Return result of `LHS Pred RHS` comparison.
+  static bool compare(const APInt &LHS, const APInt &RHS,
+                      ICmpInst::Predicate Pred);
+
   // Methods for support type inquiry through isa, cast, and dyn_cast:
   static bool classof(const Instruction *I) {
     return I->getOpcode() == Instruction::ICmp;

diff  --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 6096cbdf0e421..483e927489187 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -593,32 +593,7 @@ inline cst_pred_ty<is_lowbit_mask> m_LowBitMask() {
 struct icmp_pred_with_threshold {
   ICmpInst::Predicate Pred;
   const APInt *Thr;
-  bool isValue(const APInt &C) {
-    switch (Pred) {
-    case ICmpInst::Predicate::ICMP_EQ:
-      return C.eq(*Thr);
-    case ICmpInst::Predicate::ICMP_NE:
-      return C.ne(*Thr);
-    case ICmpInst::Predicate::ICMP_UGT:
-      return C.ugt(*Thr);
-    case ICmpInst::Predicate::ICMP_UGE:
-      return C.uge(*Thr);
-    case ICmpInst::Predicate::ICMP_ULT:
-      return C.ult(*Thr);
-    case ICmpInst::Predicate::ICMP_ULE:
-      return C.ule(*Thr);
-    case ICmpInst::Predicate::ICMP_SGT:
-      return C.sgt(*Thr);
-    case ICmpInst::Predicate::ICMP_SGE:
-      return C.sge(*Thr);
-    case ICmpInst::Predicate::ICMP_SLT:
-      return C.slt(*Thr);
-    case ICmpInst::Predicate::ICMP_SLE:
-      return C.sle(*Thr);
-    default:
-      llvm_unreachable("Unhandled ICmp predicate");
-    }
-  }
+  bool isValue(const APInt &C) { return ICmpInst::compare(C, *Thr, Pred); }
 };
 /// Match an integer or vector with every element comparing 'pred' (eq/ne/...)
 /// to Threshold. For vectors, this includes constants with undefined elements.

diff  --git a/llvm/lib/CodeGen/Analysis.cpp b/llvm/lib/CodeGen/Analysis.cpp
index c21db82db7cc6..7d8a73e12d3a5 100644
--- a/llvm/lib/CodeGen/Analysis.cpp
+++ b/llvm/lib/CodeGen/Analysis.cpp
@@ -221,9 +221,6 @@ ISD::CondCode llvm::getFCmpCodeWithoutNaN(ISD::CondCode CC) {
   }
 }
 
-/// getICmpCondCode - Return the ISD condition code corresponding to
-/// the given LLVM IR integer condition code.
-///
 ISD::CondCode llvm::getICmpCondCode(ICmpInst::Predicate Pred) {
   switch (Pred) {
   case ICmpInst::ICMP_EQ:  return ISD::SETEQ;
@@ -241,6 +238,33 @@ ISD::CondCode llvm::getICmpCondCode(ICmpInst::Predicate Pred) {
   }
 }
 
+ICmpInst::Predicate llvm::getICmpCondCode(ISD::CondCode Pred) {
+  switch (Pred) {
+  case ISD::SETEQ:
+    return ICmpInst::ICMP_EQ;
+  case ISD::SETNE:
+    return ICmpInst::ICMP_NE;
+  case ISD::SETLE:
+    return ICmpInst::ICMP_SLE;
+  case ISD::SETULE:
+    return ICmpInst::ICMP_ULE;
+  case ISD::SETGE:
+    return ICmpInst::ICMP_SGE;
+  case ISD::SETUGE:
+    return ICmpInst::ICMP_UGE;
+  case ISD::SETLT:
+    return ICmpInst::ICMP_SLT;
+  case ISD::SETULT:
+    return ICmpInst::ICMP_ULT;
+  case ISD::SETGT:
+    return ICmpInst::ICMP_SGT;
+  case ISD::SETUGT:
+    return ICmpInst::ICMP_UGT;
+  default:
+    llvm_unreachable("Invalid ISD integer condition code!");
+  }
+}
+
 static bool isNoopBitcast(Type *T1, Type *T2,
                           const TargetLoweringBase& TLI) {
   return T1 == T2 || (T1->isPointerTy() && T2->isPointerTy()) ||

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index b928fd3b78d4a..1c25f5f952917 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -28,6 +28,7 @@
 #include "llvm/Analysis/MemoryLocation.h"
 #include "llvm/Analysis/ProfileSummaryInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
+#include "llvm/CodeGen/Analysis.h"
 #include "llvm/CodeGen/FunctionLoweringInfo.h"
 #include "llvm/CodeGen/ISDOpcodes.h"
 #include "llvm/CodeGen/MachineBasicBlock.h"
@@ -2312,19 +2313,8 @@ SDValue SelectionDAG::FoldSetCC(EVT VT, SDValue N1, SDValue N2,
     if (ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1)) {
       const APInt &C1 = N1C->getAPIntValue();
 
-      switch (Cond) {
-      default: llvm_unreachable("Unknown integer setcc!");
-      case ISD::SETEQ:  return getBoolConstant(C1 == C2, dl, VT, OpVT);
-      case ISD::SETNE:  return getBoolConstant(C1 != C2, dl, VT, OpVT);
-      case ISD::SETULT: return getBoolConstant(C1.ult(C2), dl, VT, OpVT);
-      case ISD::SETUGT: return getBoolConstant(C1.ugt(C2), dl, VT, OpVT);
-      case ISD::SETULE: return getBoolConstant(C1.ule(C2), dl, VT, OpVT);
-      case ISD::SETUGE: return getBoolConstant(C1.uge(C2), dl, VT, OpVT);
-      case ISD::SETLT:  return getBoolConstant(C1.slt(C2), dl, VT, OpVT);
-      case ISD::SETGT:  return getBoolConstant(C1.sgt(C2), dl, VT, OpVT);
-      case ISD::SETLE:  return getBoolConstant(C1.sle(C2), dl, VT, OpVT);
-      case ISD::SETGE:  return getBoolConstant(C1.sge(C2), dl, VT, OpVT);
-      }
+      return getBoolConstant(ICmpInst::compare(C1, C2, getICmpCondCode(Cond)),
+                             dl, VT, OpVT);
     }
   }
 

diff  --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index e7357a8428fc2..7c49adb873dc1 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -1792,19 +1792,8 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
   if (isa<ConstantInt>(C1) && isa<ConstantInt>(C2)) {
     const APInt &V1 = cast<ConstantInt>(C1)->getValue();
     const APInt &V2 = cast<ConstantInt>(C2)->getValue();
-    switch (pred) {
-    default: llvm_unreachable("Invalid ICmp Predicate");
-    case ICmpInst::ICMP_EQ:  return ConstantInt::get(ResultTy, V1 == V2);
-    case ICmpInst::ICMP_NE:  return ConstantInt::get(ResultTy, V1 != V2);
-    case ICmpInst::ICMP_SLT: return ConstantInt::get(ResultTy, V1.slt(V2));
-    case ICmpInst::ICMP_SGT: return ConstantInt::get(ResultTy, V1.sgt(V2));
-    case ICmpInst::ICMP_SLE: return ConstantInt::get(ResultTy, V1.sle(V2));
-    case ICmpInst::ICMP_SGE: return ConstantInt::get(ResultTy, V1.sge(V2));
-    case ICmpInst::ICMP_ULT: return ConstantInt::get(ResultTy, V1.ult(V2));
-    case ICmpInst::ICMP_UGT: return ConstantInt::get(ResultTy, V1.ugt(V2));
-    case ICmpInst::ICMP_ULE: return ConstantInt::get(ResultTy, V1.ule(V2));
-    case ICmpInst::ICMP_UGE: return ConstantInt::get(ResultTy, V1.uge(V2));
-    }
+    return ConstantInt::get(
+        ResultTy, ICmpInst::compare(V1, V2, (ICmpInst::Predicate)pred));
   } else if (isa<ConstantFP>(C1) && isa<ConstantFP>(C2)) {
     const APFloat &C1V = cast<ConstantFP>(C1)->getValueAPF();
     const APFloat &C2V = cast<ConstantFP>(C2)->getValueAPF();

diff  --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 83be43261b3d7..1bc4ebc7ac162 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -4055,6 +4055,35 @@ bool CmpInst::isSigned(Predicate predicate) {
   }
 }
 
+bool ICmpInst::compare(const APInt &LHS, const APInt &RHS,
+                       ICmpInst::Predicate Pred) {
+  assert(ICmpInst::isIntPredicate(Pred) && "Only for integer predicates!");
+  switch (Pred) {
+  case ICmpInst::Predicate::ICMP_EQ:
+    return LHS.eq(RHS);
+  case ICmpInst::Predicate::ICMP_NE:
+    return LHS.ne(RHS);
+  case ICmpInst::Predicate::ICMP_UGT:
+    return LHS.ugt(RHS);
+  case ICmpInst::Predicate::ICMP_UGE:
+    return LHS.uge(RHS);
+  case ICmpInst::Predicate::ICMP_ULT:
+    return LHS.ult(RHS);
+  case ICmpInst::Predicate::ICMP_ULE:
+    return LHS.ule(RHS);
+  case ICmpInst::Predicate::ICMP_SGT:
+    return LHS.sgt(RHS);
+  case ICmpInst::Predicate::ICMP_SGE:
+    return LHS.sge(RHS);
+  case ICmpInst::Predicate::ICMP_SLT:
+    return LHS.slt(RHS);
+  case ICmpInst::Predicate::ICMP_SLE:
+    return LHS.sle(RHS);
+  default:
+    llvm_unreachable("Unexpected non-integer predicate.");
+  };
+}
+
 CmpInst::Predicate CmpInst::getFlippedSignednessPredicate(Predicate pred) {
   assert(CmpInst::isRelational(pred) &&
          "Call only with non-equality predicates!");

diff  --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
index 06f399129f3c8..badb118ec2a44 100644
--- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
+++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
@@ -8651,31 +8651,7 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl {
 
   static bool calculateICmpInst(const ICmpInst *ICI, const APInt &LHS,
                                 const APInt &RHS) {
-    ICmpInst::Predicate Pred = ICI->getPredicate();
-    switch (Pred) {
-    case ICmpInst::ICMP_UGT:
-      return LHS.ugt(RHS);
-    case ICmpInst::ICMP_SGT:
-      return LHS.sgt(RHS);
-    case ICmpInst::ICMP_EQ:
-      return LHS.eq(RHS);
-    case ICmpInst::ICMP_UGE:
-      return LHS.uge(RHS);
-    case ICmpInst::ICMP_SGE:
-      return LHS.sge(RHS);
-    case ICmpInst::ICMP_ULT:
-      return LHS.ult(RHS);
-    case ICmpInst::ICMP_SLT:
-      return LHS.slt(RHS);
-    case ICmpInst::ICMP_NE:
-      return LHS.ne(RHS);
-    case ICmpInst::ICMP_ULE:
-      return LHS.ule(RHS);
-    case ICmpInst::ICMP_SLE:
-      return LHS.sle(RHS);
-    default:
-      llvm_unreachable("Invalid ICmp predicate!");
-    }
+    return ICmpInst::compare(LHS, RHS, ICI->getPredicate());
   }
 
   static APInt calculateCastInst(const CastInst *CI, const APInt &Src,

diff  --git a/llvm/unittests/IR/ConstantRangeTest.cpp b/llvm/unittests/IR/ConstantRangeTest.cpp
index 21533652b11c2..5e6bc88d5efbd 100644
--- a/llvm/unittests/IR/ConstantRangeTest.cpp
+++ b/llvm/unittests/IR/ConstantRangeTest.cpp
@@ -1557,41 +1557,15 @@ TEST(ConstantRange, MakeSatisfyingICmpRegion) {
       ConstantRange(APInt(8, 4), APInt(8, -128)));
 }
 
-static bool icmp(CmpInst::Predicate Pred, const APInt &LHS, const APInt &RHS) {
-  switch (Pred) {
-  case CmpInst::Predicate::ICMP_EQ:
-    return LHS.eq(RHS);
-  case CmpInst::Predicate::ICMP_NE:
-    return LHS.ne(RHS);
-  case CmpInst::Predicate::ICMP_UGT:
-    return LHS.ugt(RHS);
-  case CmpInst::Predicate::ICMP_UGE:
-    return LHS.uge(RHS);
-  case CmpInst::Predicate::ICMP_ULT:
-    return LHS.ult(RHS);
-  case CmpInst::Predicate::ICMP_ULE:
-    return LHS.ule(RHS);
-  case CmpInst::Predicate::ICMP_SGT:
-    return LHS.sgt(RHS);
-  case CmpInst::Predicate::ICMP_SGE:
-    return LHS.sge(RHS);
-  case CmpInst::Predicate::ICMP_SLT:
-    return LHS.slt(RHS);
-  case CmpInst::Predicate::ICMP_SLE:
-    return LHS.sle(RHS);
-  default:
-    llvm_unreachable("Not an ICmp predicate!");
-  }
-}
-
 void ICmpTestImpl(CmpInst::Predicate Pred) {
   unsigned Bits = 4;
   EnumerateTwoConstantRanges(
       Bits, [&](const ConstantRange &CR1, const ConstantRange &CR2) {
         bool Exhaustive = true;
         ForeachNumInConstantRange(CR1, [&](const APInt &N1) {
-          ForeachNumInConstantRange(
-              CR2, [&](const APInt &N2) { Exhaustive &= icmp(Pred, N1, N2); });
+          ForeachNumInConstantRange(CR2, [&](const APInt &N2) {
+            Exhaustive &= ICmpInst::compare(N1, N2, Pred);
+          });
         });
         EXPECT_EQ(CR1.icmp(Pred, CR2), Exhaustive);
       });


        


More information about the llvm-commits mailing list