[llvm] cf15515 - [AggressiveInstCombine] Add support for ICmp instr that feeds a select instr's condition operand.

Ayman Musa via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 12 05:01:46 PST 2020


Author: Ayman Musa
Date: 2020-02-12T15:01:27+02:00
New Revision: cf155150f992270c88e586ffc61973d2552b72e8

URL: https://github.com/llvm/llvm-project/commit/cf155150f992270c88e586ffc61973d2552b72e8
DIFF: https://github.com/llvm/llvm-project/commit/cf155150f992270c88e586ffc61973d2552b72e8.diff

LOG: [AggressiveInstCombine] Add support for ICmp instr that feeds a select instr's condition operand.

Added: 
    

Modified: 
    llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
index 0cfc6ef47d4b..3581a24851de 100644
--- a/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
+++ b/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
@@ -36,6 +36,24 @@ using namespace llvm;
 
 #define DEBUG_TYPE "aggressive-instcombine"
 
+// Returns true if V is a Constant, or an instruction whose opcode is ZExt or
+// SExt (a type extension). These are the only compare operands accepted by
+// the Select case of getRelevantOperands() when deciding whether the select's
+// compare condition is added to the expression DAG as well.
+static bool isConstOrExt(Value *V) {
+  // Constants of any kind qualify.
+  if (isa<Constant>(V))
+    return true;
+
+  // Otherwise only zero- and sign-extension instructions qualify; any other
+  // instruction, and any non-instruction value (e.g. an argument), does not.
+  auto *Inst = dyn_cast<Instruction>(V);
+  if (!Inst)
+    return false;
+  unsigned Opc = Inst->getOpcode();
+  return Opc == Instruction::ZExt || Opc == Instruction::SExt;
+}
+
 /// Given an instruction and a container, it fills all the relevant operands of
 /// that instruction, with respect to the Trunc expression dag optimizaton.
 static void getRelevantOperands(Instruction *I, SmallVectorImpl<Value *> &Ops) {
@@ -53,12 +71,20 @@ static void getRelevantOperands(Instruction *I, SmallVectorImpl<Value *> &Ops) {
   case Instruction::And:
   case Instruction::Or:
   case Instruction::Xor:
+  case Instruction::ICmp:
     Ops.push_back(I->getOperand(0));
     Ops.push_back(I->getOperand(1));
     break;
   case Instruction::Select:
+    Value *Op0 = I->getOperand(0);
     Ops.push_back(I->getOperand(1));
     Ops.push_back(I->getOperand(2));
+    // If the condition is a compare whose operands are each a zext/sext or a
+    // constant, add the compare as well, so it can be shrunk together with the
+    // select (whether each constant fits is checked when computing bit-widths).
+    if (CmpInst *C = dyn_cast<CmpInst>(Op0))
+      if (isConstOrExt(C->getOperand(0)) && isConstOrExt(C->getOperand(1)))
+        Ops.push_back(Op0);
     break;
   default:
     llvm_unreachable("Unreachable!");
@@ -119,7 +145,8 @@ bool TruncInstCombine::buildTruncExpressionDag() {
     case Instruction::And:
     case Instruction::Or:
     case Instruction::Xor:
-    case Instruction::Select: {
+    case Instruction::Select:
+    case Instruction::ICmp: {
       SmallVector<Value *, 2> Operands;
       getRelevantOperands(I, Operands);
       for (Value *Operand : Operands)
@@ -139,6 +166,21 @@ bool TruncInstCombine::buildTruncExpressionDag() {
   return true;
 }
 
+// Get the minimum number of bits needed to represent the given constant in a
+// compare of the given signedness without losing information.
+static unsigned getConstMinBitWidth(bool IsSigned, ConstantInt *C) {
+  const APInt &Val = C->getValue();
+  // Use the APInt directly: getSExtValue()/getZExtValue() assert when the
+  // value does not fit in 64 bits (possible for types wider than i64).
+  // Signed: a negative value needs its significant bits plus one sign bit
+  // (leading ones collapse to one); a non-negative value needs its active
+  // bits plus a zero sign bit. getMinSignedBits() covers both cases.
+  if (IsSigned)
+    return Val.getMinSignedBits();
+  // Unsigned: only the bits up to the most significant set bit are needed.
+  return Val.getActiveBits();
+}
+
 unsigned TruncInstCombine::getMinBitWidth() {
   SmallVector<Value *, 8> Worklist;
   SmallVector<Instruction *, 8> Stack;
@@ -180,6 +222,13 @@ unsigned TruncInstCombine::getMinBitWidth() {
         if (auto *IOp = dyn_cast<Instruction>(Operand))
           Info.MinBitWidth =
               std::max(Info.MinBitWidth, InstInfoMap[IOp].MinBitWidth);
+        else if (auto *C = dyn_cast<ConstantInt>(Operand)) {
+          // In case of Cmp instruction, make sure the constant can be truncated
+          // without losing information.
+          if (CmpInst *Cmp = dyn_cast<CmpInst>(I))
+            Info.MinBitWidth = std::max(
+                Info.MinBitWidth, getConstMinBitWidth(Cmp->isSigned(), C));
+        }
       continue;
     }
 
@@ -193,14 +242,27 @@ unsigned TruncInstCombine::getMinBitWidth() {
 
     for (auto *Operand : Operands)
       if (auto *IOp = dyn_cast<Instruction>(Operand)) {
-        // If we already calculated the minimum bit-width for this valid
-        // bit-width, or for a smaller valid bit-width, then just keep the
-        // answer we already calculated.
-        unsigned IOpBitwidth = InstInfoMap.lookup(IOp).ValidBitWidth;
-        if (IOpBitwidth >= ValidBitWidth)
-          continue;
-        InstInfoMap[IOp].ValidBitWidth = ValidBitWidth;
-        Worklist.push_back(IOp);
+        if (isa<CmpInst>(I)) {
+          // A Cmp instruction resets the valid-bits analysis for its operands:
+          // its result does not continue the same calculation chain, so each
+          // operand starts a new chain of its own.
+          switch (IOp->getOpcode()) {
+          case Instruction::SExt:
+          case Instruction::ZExt:
+            InstInfoMap[IOp].ValidBitWidth =
+                cast<CastInst>(IOp)->getSrcTy()->getScalarSizeInBits();
+            break;
+          }
+        } else {
+          // If we already calculated the minimum bit-width for this valid
+          // bit-width, or for a smaller valid bit-width, then just keep the
+          // answer we already calculated.
+          unsigned IOpBitwidth = InstInfoMap.lookup(IOp).ValidBitWidth;
+          if (IOpBitwidth >= ValidBitWidth)
+            continue;
+          InstInfoMap[IOp].ValidBitWidth = ValidBitWidth;
+          Worklist.push_back(IOp);
+        }
       }
   }
   unsigned MinBitWidth = InstInfoMap.lookup(cast<Instruction>(Src)).MinBitWidth;
@@ -363,6 +425,13 @@ void TruncInstCombine::ReduceExpressionDag(Type *SclTy) {
       Res = Builder.CreateSelect(Op0, LHS, RHS);
       break;
     }
+    case Instruction::ICmp: {
+      auto ICmp = cast<ICmpInst>(I);
+      Value *LHS = getReducedOperand(ICmp->getOperand(0), SclTy);
+      Value *RHS = getReducedOperand(ICmp->getOperand(1), SclTy);
+      Res = Builder.CreateICmp(ICmp->getPredicate(), LHS, RHS);
+      break;
+    }
     default:
       llvm_unreachable("Unhandled instruction");
     }


        


More information about the llvm-commits mailing list