[llvm] ecfd919 - [ConstantFolding] Use ICmpInst::Predicate instead of plain integer

Serge Pavlov via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 30 00:09:49 PST 2021


Author: Serge Pavlov
Date: 2021-12-30T14:31:44+07:00
New Revision: ecfd9196d5dde5699d7fe3bd411949a56e01bd8c

URL: https://github.com/llvm/llvm-project/commit/ecfd9196d5dde5699d7fe3bd411949a56e01bd8c
DIFF: https://github.com/llvm/llvm-project/commit/ecfd9196d5dde5699d7fe3bd411949a56e01bd8c.diff

LOG: [ConstantFolding] Use ICmpInst::Predicate instead of plain integer

The function `ConstantFoldCompareInstruction` uses `unsigned short` to
represent the compare predicate, although all users of the respective
include file also use the definition of CmpInst. This change replaces
the predicate argument type in this function with `ICmpInst::Predicate`,
which makes the code a bit clearer and simpler.

No functional changes.

Differential Revision: https://reviews.llvm.org/D116379

Added: 
    

Modified: 
    llvm/lib/IR/ConstantFold.cpp
    llvm/lib/IR/ConstantFold.h
    llvm/lib/IR/Constants.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index ae926f95cefe1..16b0880ce2f97 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -1704,7 +1704,7 @@ static ICmpInst::Predicate evaluateICmpRelation(Constant *V1, Constant *V2,
   return ICmpInst::BAD_ICMP_PREDICATE;
 }
 
-Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
+Constant *llvm::ConstantFoldCompareInstruction(CmpInst::Predicate Predicate,
                                                Constant *C1, Constant *C2) {
   Type *ResultTy;
   if (VectorType *VT = dyn_cast<VectorType>(C1->getType()))
@@ -1714,10 +1714,10 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
     ResultTy = Type::getInt1Ty(C1->getContext());
 
   // Fold FCMP_FALSE/FCMP_TRUE unconditionally.
-  if (pred == FCmpInst::FCMP_FALSE)
+  if (Predicate == FCmpInst::FCMP_FALSE)
     return Constant::getNullValue(ResultTy);
 
-  if (pred == FCmpInst::FCMP_TRUE)
+  if (Predicate == FCmpInst::FCMP_TRUE)
     return Constant::getAllOnesValue(ResultTy);
 
   // Handle some degenerate cases first
@@ -1725,7 +1725,6 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
     return PoisonValue::get(ResultTy);
 
   if (isa<UndefValue>(C1) || isa<UndefValue>(C2)) {
-    CmpInst::Predicate Predicate = CmpInst::Predicate(pred);
     bool isIntegerPredicate = ICmpInst::isIntPredicate(Predicate);
     // For EQ and NE, we can always pick a value for the undef to make the
     // predicate pass or fail, so we can return undef.
@@ -1750,9 +1749,9 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
       if (!isa<GlobalAlias>(GV) && !GV->hasExternalWeakLinkage() &&
           !NullPointerIsDefined(nullptr /* F */,
                                 GV->getType()->getAddressSpace())) {
-        if (pred == ICmpInst::ICMP_EQ)
+        if (Predicate == ICmpInst::ICMP_EQ)
           return ConstantInt::getFalse(C1->getContext());
-        else if (pred == ICmpInst::ICMP_NE)
+        else if (Predicate == ICmpInst::ICMP_NE)
           return ConstantInt::getTrue(C1->getContext());
       }
   // icmp eq/ne(GV,null) -> false/true
@@ -1762,9 +1761,9 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
       if (!isa<GlobalAlias>(GV) && !GV->hasExternalWeakLinkage() &&
           !NullPointerIsDefined(nullptr /* F */,
                                 GV->getType()->getAddressSpace())) {
-        if (pred == ICmpInst::ICMP_EQ)
+        if (Predicate == ICmpInst::ICMP_EQ)
           return ConstantInt::getFalse(C1->getContext());
-        else if (pred == ICmpInst::ICMP_NE)
+        else if (Predicate == ICmpInst::ICMP_NE)
           return ConstantInt::getTrue(C1->getContext());
       }
     }
@@ -1772,16 +1771,16 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
     // The caller is expected to commute the operands if the constant expression
     // is C2.
     // C1 >= 0 --> true
-    if (pred == ICmpInst::ICMP_UGE)
+    if (Predicate == ICmpInst::ICMP_UGE)
       return Constant::getAllOnesValue(ResultTy);
     // C1 < 0 --> false
-    if (pred == ICmpInst::ICMP_ULT)
+    if (Predicate == ICmpInst::ICMP_ULT)
       return Constant::getNullValue(ResultTy);
   }
 
   // If the comparison is a comparison between two i1's, simplify it.
   if (C1->getType()->isIntegerTy(1)) {
-    switch(pred) {
+    switch (Predicate) {
     case ICmpInst::ICMP_EQ:
       if (isa<ConstantInt>(C2))
         return ConstantExpr::getXor(C1, ConstantExpr::getNot(C2));
@@ -1796,12 +1795,10 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
   if (isa<ConstantInt>(C1) && isa<ConstantInt>(C2)) {
     const APInt &V1 = cast<ConstantInt>(C1)->getValue();
     const APInt &V2 = cast<ConstantInt>(C2)->getValue();
-    return ConstantInt::get(
-        ResultTy, ICmpInst::compare(V1, V2, (ICmpInst::Predicate)pred));
+    return ConstantInt::get(ResultTy, ICmpInst::compare(V1, V2, Predicate));
   } else if (isa<ConstantFP>(C1) && isa<ConstantFP>(C2)) {
     const APFloat &C1V = cast<ConstantFP>(C1)->getValueAPF();
     const APFloat &C2V = cast<ConstantFP>(C2)->getValueAPF();
-    CmpInst::Predicate Predicate = CmpInst::Predicate(pred);
     return ConstantInt::get(ResultTy, FCmpInst::compare(C1V, C2V, Predicate));
   } else if (auto *C1VTy = dyn_cast<VectorType>(C1->getType())) {
 
@@ -1810,7 +1807,7 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
       if (Constant *C2Splat = C2->getSplatValue())
         return ConstantVector::getSplat(
             C1VTy->getElementCount(),
-            ConstantExpr::getCompare(pred, C1Splat, C2Splat));
+            ConstantExpr::getCompare(Predicate, C1Splat, C2Splat));
 
     // Do not iterate on scalable vector. The number of elements is unknown at
     // compile-time.
@@ -1829,7 +1826,7 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
       Constant *C2E =
           ConstantExpr::getExtractElement(C2, ConstantInt::get(Ty, I));
 
-      ResElts.push_back(ConstantExpr::getCompare(pred, C1E, C2E));
+      ResElts.push_back(ConstantExpr::getCompare(Predicate, C1E, C2E));
     }
 
     return ConstantVector::get(ResElts);
@@ -1854,46 +1851,52 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
     case FCmpInst::BAD_FCMP_PREDICATE:
       break; // Couldn't determine anything about these constants.
     case FCmpInst::FCMP_OEQ: // We know that C1 == C2
-      Result = (pred == FCmpInst::FCMP_UEQ || pred == FCmpInst::FCMP_OEQ ||
-                pred == FCmpInst::FCMP_ULE || pred == FCmpInst::FCMP_OLE ||
-                pred == FCmpInst::FCMP_UGE || pred == FCmpInst::FCMP_OGE);
+      Result =
+          (Predicate == FCmpInst::FCMP_UEQ || Predicate == FCmpInst::FCMP_OEQ ||
+           Predicate == FCmpInst::FCMP_ULE || Predicate == FCmpInst::FCMP_OLE ||
+           Predicate == FCmpInst::FCMP_UGE || Predicate == FCmpInst::FCMP_OGE);
       break;
     case FCmpInst::FCMP_OLT: // We know that C1 < C2
-      Result = (pred == FCmpInst::FCMP_UNE || pred == FCmpInst::FCMP_ONE ||
-                pred == FCmpInst::FCMP_ULT || pred == FCmpInst::FCMP_OLT ||
-                pred == FCmpInst::FCMP_ULE || pred == FCmpInst::FCMP_OLE);
+      Result =
+          (Predicate == FCmpInst::FCMP_UNE || Predicate == FCmpInst::FCMP_ONE ||
+           Predicate == FCmpInst::FCMP_ULT || Predicate == FCmpInst::FCMP_OLT ||
+           Predicate == FCmpInst::FCMP_ULE || Predicate == FCmpInst::FCMP_OLE);
       break;
     case FCmpInst::FCMP_OGT: // We know that C1 > C2
-      Result = (pred == FCmpInst::FCMP_UNE || pred == FCmpInst::FCMP_ONE ||
-                pred == FCmpInst::FCMP_UGT || pred == FCmpInst::FCMP_OGT ||
-                pred == FCmpInst::FCMP_UGE || pred == FCmpInst::FCMP_OGE);
+      Result =
+          (Predicate == FCmpInst::FCMP_UNE || Predicate == FCmpInst::FCMP_ONE ||
+           Predicate == FCmpInst::FCMP_UGT || Predicate == FCmpInst::FCMP_OGT ||
+           Predicate == FCmpInst::FCMP_UGE || Predicate == FCmpInst::FCMP_OGE);
       break;
     case FCmpInst::FCMP_OLE: // We know that C1 <= C2
       // We can only partially decide this relation.
-      if (pred == FCmpInst::FCMP_UGT || pred == FCmpInst::FCMP_OGT)
+      if (Predicate == FCmpInst::FCMP_UGT || Predicate == FCmpInst::FCMP_OGT)
         Result = 0;
-      else if (pred == FCmpInst::FCMP_ULT || pred == FCmpInst::FCMP_OLT)
+      else if (Predicate == FCmpInst::FCMP_ULT ||
+               Predicate == FCmpInst::FCMP_OLT)
         Result = 1;
       break;
     case FCmpInst::FCMP_OGE: // We known that C1 >= C2
       // We can only partially decide this relation.
-      if (pred == FCmpInst::FCMP_ULT || pred == FCmpInst::FCMP_OLT)
+      if (Predicate == FCmpInst::FCMP_ULT || Predicate == FCmpInst::FCMP_OLT)
         Result = 0;
-      else if (pred == FCmpInst::FCMP_UGT || pred == FCmpInst::FCMP_OGT)
+      else if (Predicate == FCmpInst::FCMP_UGT ||
+               Predicate == FCmpInst::FCMP_OGT)
         Result = 1;
       break;
     case FCmpInst::FCMP_ONE: // We know that C1 != C2
       // We can only partially decide this relation.
-      if (pred == FCmpInst::FCMP_OEQ || pred == FCmpInst::FCMP_UEQ)
+      if (Predicate == FCmpInst::FCMP_OEQ || Predicate == FCmpInst::FCMP_UEQ)
         Result = 0;
-      else if (pred == FCmpInst::FCMP_ONE || pred == FCmpInst::FCMP_UNE)
+      else if (Predicate == FCmpInst::FCMP_ONE ||
+               Predicate == FCmpInst::FCMP_UNE)
         Result = 1;
       break;
     case FCmpInst::FCMP_UEQ: // We know that C1 == C2 || isUnordered(C1, C2).
       // We can only partially decide this relation.
-      if (pred == FCmpInst::FCMP_ONE)
+      if (Predicate == FCmpInst::FCMP_ONE)
         Result = 0;
-      else if (pred == FCmpInst::FCMP_UEQ)
+      else if (Predicate == FCmpInst::FCMP_UEQ)
         Result = 1;
       break;
     }
@@ -1905,67 +1908,84 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
   } else {
     // Evaluate the relation between the two constants, per the predicate.
     int Result = -1;  // -1 = unknown, 0 = known false, 1 = known true.
-    switch (evaluateICmpRelation(C1, C2,
-                                 CmpInst::isSigned((CmpInst::Predicate)pred))) {
+    switch (evaluateICmpRelation(C1, C2, CmpInst::isSigned(Predicate))) {
     default: llvm_unreachable("Unknown relational!");
     case ICmpInst::BAD_ICMP_PREDICATE:
       break;  // Couldn't determine anything about these constants.
     case ICmpInst::ICMP_EQ:   // We know the constants are equal!
       // If we know the constants are equal, we can decide the result of this
       // computation precisely.
-      Result = ICmpInst::isTrueWhenEqual((ICmpInst::Predicate)pred);
+      Result = ICmpInst::isTrueWhenEqual(Predicate);
       break;
     case ICmpInst::ICMP_ULT:
-      switch (pred) {
+      switch (Predicate) {
       case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_NE: case ICmpInst::ICMP_ULE:
         Result = 1; break;
       case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_EQ: case ICmpInst::ICMP_UGE:
         Result = 0; break;
+      default:
+        break;
       }
       break;
     case ICmpInst::ICMP_SLT:
-      switch (pred) {
+      switch (Predicate) {
       case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_NE: case ICmpInst::ICMP_SLE:
         Result = 1; break;
       case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_EQ: case ICmpInst::ICMP_SGE:
         Result = 0; break;
+      default:
+        break;
       }
       break;
     case ICmpInst::ICMP_UGT:
-      switch (pred) {
+      switch (Predicate) {
       case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_NE: case ICmpInst::ICMP_UGE:
         Result = 1; break;
       case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_EQ: case ICmpInst::ICMP_ULE:
         Result = 0; break;
+      default:
+        break;
       }
       break;
     case ICmpInst::ICMP_SGT:
-      switch (pred) {
+      switch (Predicate) {
       case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_NE: case ICmpInst::ICMP_SGE:
         Result = 1; break;
       case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_EQ: case ICmpInst::ICMP_SLE:
         Result = 0; break;
+      default:
+        break;
       }
       break;
     case ICmpInst::ICMP_ULE:
-      if (pred == ICmpInst::ICMP_UGT) Result = 0;
-      if (pred == ICmpInst::ICMP_ULT || pred == ICmpInst::ICMP_ULE) Result = 1;
+      if (Predicate == ICmpInst::ICMP_UGT)
+        Result = 0;
+      if (Predicate == ICmpInst::ICMP_ULT || Predicate == ICmpInst::ICMP_ULE)
+        Result = 1;
       break;
     case ICmpInst::ICMP_SLE:
-      if (pred == ICmpInst::ICMP_SGT) Result = 0;
-      if (pred == ICmpInst::ICMP_SLT || pred == ICmpInst::ICMP_SLE) Result = 1;
+      if (Predicate == ICmpInst::ICMP_SGT)
+        Result = 0;
+      if (Predicate == ICmpInst::ICMP_SLT || Predicate == ICmpInst::ICMP_SLE)
+        Result = 1;
       break;
     case ICmpInst::ICMP_UGE:
-      if (pred == ICmpInst::ICMP_ULT) Result = 0;
-      if (pred == ICmpInst::ICMP_UGT || pred == ICmpInst::ICMP_UGE) Result = 1;
+      if (Predicate == ICmpInst::ICMP_ULT)
+        Result = 0;
+      if (Predicate == ICmpInst::ICMP_UGT || Predicate == ICmpInst::ICMP_UGE)
+        Result = 1;
       break;
     case ICmpInst::ICMP_SGE:
-      if (pred == ICmpInst::ICMP_SLT) Result = 0;
-      if (pred == ICmpInst::ICMP_SGT || pred == ICmpInst::ICMP_SGE) Result = 1;
+      if (Predicate == ICmpInst::ICMP_SLT)
+        Result = 0;
+      if (Predicate == ICmpInst::ICMP_SGT || Predicate == ICmpInst::ICMP_SGE)
+        Result = 1;
       break;
     case ICmpInst::ICMP_NE:
-      if (pred == ICmpInst::ICMP_EQ) Result = 0;
-      if (pred == ICmpInst::ICMP_NE) Result = 1;
+      if (Predicate == ICmpInst::ICMP_EQ)
+        Result = 0;
+      if (Predicate == ICmpInst::ICMP_NE)
+        Result = 1;
       break;
     }
 
@@ -1983,16 +2003,16 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
           CE2->getType()->isVectorTy() == CE2Op0->getType()->isVectorTy() &&
           !CE2Op0->getType()->isFPOrFPVectorTy()) {
         Constant *Inverse = ConstantExpr::getBitCast(C1, CE2Op0->getType());
-        return ConstantExpr::getICmp(pred, Inverse, CE2Op0);
+        return ConstantExpr::getICmp(Predicate, Inverse, CE2Op0);
       }
     }
 
     // If the left hand side is an extension, try eliminating it.
     if (ConstantExpr *CE1 = dyn_cast<ConstantExpr>(C1)) {
       if ((CE1->getOpcode() == Instruction::SExt &&
-           ICmpInst::isSigned((ICmpInst::Predicate)pred)) ||
+           ICmpInst::isSigned(Predicate)) ||
           (CE1->getOpcode() == Instruction::ZExt &&
-           !ICmpInst::isSigned((ICmpInst::Predicate)pred))){
+           !ICmpInst::isSigned(Predicate))) {
         Constant *CE1Op0 = CE1->getOperand(0);
         Constant *CE1Inverse = ConstantExpr::getTrunc(CE1, CE1Op0->getType());
         if (CE1Inverse == CE1Op0) {
@@ -2000,7 +2020,7 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
           Constant *C2Inverse = ConstantExpr::getTrunc(C2, CE1Op0->getType());
           if (ConstantExpr::getCast(CE1->getOpcode(), C2Inverse,
                                     C2->getType()) == C2)
-            return ConstantExpr::getICmp(pred, CE1Inverse, C2Inverse);
+            return ConstantExpr::getICmp(Predicate, CE1Inverse, C2Inverse);
         }
       }
     }
@@ -2010,8 +2030,8 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
       // If C2 is a constant expr and C1 isn't, flip them around and fold the
       // other way if possible.
       // Also, if C1 is null and C2 isn't, flip them around.
-      pred = ICmpInst::getSwappedPredicate((ICmpInst::Predicate)pred);
-      return ConstantExpr::getICmp(pred, C2, C1);
+      Predicate = ICmpInst::getSwappedPredicate(Predicate);
+      return ConstantExpr::getICmp(Predicate, C2, C1);
     }
   }
   return nullptr;

diff  --git a/llvm/lib/IR/ConstantFold.h b/llvm/lib/IR/ConstantFold.h
index 0cdd5cf3cbce6..1aa44f4d21e58 100644
--- a/llvm/lib/IR/ConstantFold.h
+++ b/llvm/lib/IR/ConstantFold.h
@@ -19,6 +19,7 @@
 #define LLVM_LIB_IR_CONSTANTFOLD_H
 
 #include "llvm/ADT/Optional.h"
+#include "llvm/IR/InstrTypes.h"
 
 namespace llvm {
 template <typename T> class ArrayRef;
@@ -46,7 +47,7 @@ template <typename T> class ArrayRef;
   Constant *ConstantFoldUnaryInstruction(unsigned Opcode, Constant *V);
   Constant *ConstantFoldBinaryInstruction(unsigned Opcode, Constant *V1,
                                           Constant *V2);
-  Constant *ConstantFoldCompareInstruction(unsigned short predicate,
+  Constant *ConstantFoldCompareInstruction(CmpInst::Predicate Predicate,
                                            Constant *C1, Constant *C2);
   Constant *ConstantFoldGetElementPtr(Type *Ty, Constant *C, bool InBounds,
                                       Optional<unsigned> InRangeIndex,

diff  --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index 837be910f6d81..e753fc7a38710 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -2546,11 +2546,11 @@ Constant *ConstantExpr::getGetElementPtr(Type *Ty, Constant *C,
 
 Constant *ConstantExpr::getICmp(unsigned short pred, Constant *LHS,
                                 Constant *RHS, bool OnlyIfReduced) {
+  auto Predicate = static_cast<CmpInst::Predicate>(pred);
   assert(LHS->getType() == RHS->getType());
-  assert(CmpInst::isIntPredicate((CmpInst::Predicate)pred) &&
-         "Invalid ICmp Predicate");
+  assert(CmpInst::isIntPredicate(Predicate) && "Invalid ICmp Predicate");
 
-  if (Constant *FC = ConstantFoldCompareInstruction(pred, LHS, RHS))
+  if (Constant *FC = ConstantFoldCompareInstruction(Predicate, LHS, RHS))
     return FC;          // Fold a few common cases...
 
   if (OnlyIfReduced)
@@ -2559,7 +2559,7 @@ Constant *ConstantExpr::getICmp(unsigned short pred, Constant *LHS,
   // Look up the constant in the table first to ensure uniqueness
   Constant *ArgVec[] = { LHS, RHS };
   // Get the key type with both the opcode and predicate
-  const ConstantExprKeyType Key(Instruction::ICmp, ArgVec, pred);
+  const ConstantExprKeyType Key(Instruction::ICmp, ArgVec, Predicate);
 
   Type *ResultTy = Type::getInt1Ty(LHS->getContext());
   if (VectorType *VT = dyn_cast<VectorType>(LHS->getType()))
@@ -2571,11 +2571,11 @@ Constant *ConstantExpr::getICmp(unsigned short pred, Constant *LHS,
 
 Constant *ConstantExpr::getFCmp(unsigned short pred, Constant *LHS,
                                 Constant *RHS, bool OnlyIfReduced) {
+  auto Predicate = static_cast<CmpInst::Predicate>(pred);
   assert(LHS->getType() == RHS->getType());
-  assert(CmpInst::isFPPredicate((CmpInst::Predicate)pred) &&
-         "Invalid FCmp Predicate");
+  assert(CmpInst::isFPPredicate(Predicate) && "Invalid FCmp Predicate");
 
-  if (Constant *FC = ConstantFoldCompareInstruction(pred, LHS, RHS))
+  if (Constant *FC = ConstantFoldCompareInstruction(Predicate, LHS, RHS))
     return FC;          // Fold a few common cases...
 
   if (OnlyIfReduced)
@@ -2584,7 +2584,7 @@ Constant *ConstantExpr::getFCmp(unsigned short pred, Constant *LHS,
   // Look up the constant in the table first to ensure uniqueness
   Constant *ArgVec[] = { LHS, RHS };
   // Get the key type with both the opcode and predicate
-  const ConstantExprKeyType Key(Instruction::FCmp, ArgVec, pred);
+  const ConstantExprKeyType Key(Instruction::FCmp, ArgVec, Predicate);
 
   Type *ResultTy = Type::getInt1Ty(LHS->getContext());
   if (VectorType *VT = dyn_cast<VectorType>(LHS->getType()))


        


More information about the llvm-commits mailing list