[llvm-commits] [llvm] r70053 - in /llvm/trunk: lib/Transforms/Scalar/InstructionCombining.cpp test/Transforms/InstCombine/signed-comparison.ll

Dan Gohman gohman at apple.com
Sat Apr 25 10:12:48 PDT 2009


Author: djg
Date: Sat Apr 25 12:12:48 2009
New Revision: 70053

URL: http://llvm.org/viewvc/llvm-project?rev=70053&view=rev
Log:
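Add several more icmp simplifications: fold or simplify comparisons
using the known bits of both operands (not only a constant RHS),
including pointer operands when TargetData is available. Also transform
signed comparisons into unsigned ones when both operands are known to
have the same sign-bit value.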

Added:
    llvm/trunk/test/Transforms/InstCombine/signed-comparison.ll
Modified:
    llvm/trunk/lib/Transforms/Scalar/InstructionCombining.cpp

Modified: llvm/trunk/lib/Transforms/Scalar/InstructionCombining.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Scalar/InstructionCombining.cpp?rev=70053&r1=70052&r2=70053&view=diff

==============================================================================
--- llvm/trunk/lib/Transforms/Scalar/InstructionCombining.cpp (original)
+++ llvm/trunk/lib/Transforms/Scalar/InstructionCombining.cpp Sat Apr 25 12:12:48 2009
@@ -708,15 +708,13 @@
 // set of known zero and one bits, compute the maximum and minimum values that
 // could have the specified known zero and known one bits, returning them in
 // min/max.
-static void ComputeSignedMinMaxValuesFromKnownBits(const Type *Ty,
-                                                   const APInt& KnownZero,
+static void ComputeSignedMinMaxValuesFromKnownBits(const APInt& KnownZero,
                                                    const APInt& KnownOne,
                                                    APInt& Min, APInt& Max) {
-  uint32_t BitWidth = cast<IntegerType>(Ty)->getBitWidth();
-  assert(KnownZero.getBitWidth() == BitWidth && 
-         KnownOne.getBitWidth() == BitWidth &&
-         Min.getBitWidth() == BitWidth && Max.getBitWidth() == BitWidth &&
-         "Ty, KnownZero, KnownOne and Min, Max must have equal bitwidth.");
+  assert(KnownZero.getBitWidth() == KnownOne.getBitWidth() &&
+         KnownZero.getBitWidth() == Min.getBitWidth() &&
+         KnownZero.getBitWidth() == Max.getBitWidth() &&
+         "KnownZero, KnownOne and Min, Max must have equal bitwidth.");
   APInt UnknownBits = ~(KnownZero|KnownOne);
 
   // The minimum value is when all unknown bits are zeros, EXCEPT for the sign
@@ -724,9 +722,9 @@
   Min = KnownOne;
   Max = KnownOne|UnknownBits;
   
-  if (UnknownBits[BitWidth-1]) { // Sign bit is unknown
-    Min.set(BitWidth-1);
-    Max.clear(BitWidth-1);
+  if (UnknownBits.isNegative()) { // Sign bit is unknown
+    Min.set(Min.getBitWidth()-1);
+    Max.clear(Max.getBitWidth()-1);
   }
 }
 
@@ -734,14 +732,12 @@
 // a set of known zero and one bits, compute the maximum and minimum values that
 // could have the specified known zero and known one bits, returning them in
 // min/max.
-static void ComputeUnsignedMinMaxValuesFromKnownBits(const Type *Ty,
-                                                     const APInt &KnownZero,
+static void ComputeUnsignedMinMaxValuesFromKnownBits(const APInt &KnownZero,
                                                      const APInt &KnownOne,
                                                      APInt &Min, APInt &Max) {
-  uint32_t BitWidth = cast<IntegerType>(Ty)->getBitWidth(); BitWidth = BitWidth;
-  assert(KnownZero.getBitWidth() == BitWidth && 
-         KnownOne.getBitWidth() == BitWidth &&
-         Min.getBitWidth() == BitWidth && Max.getBitWidth() &&
+  assert(KnownZero.getBitWidth() == KnownOne.getBitWidth() &&
+         KnownZero.getBitWidth() == Min.getBitWidth() &&
+         KnownZero.getBitWidth() == Max.getBitWidth() &&
          "Ty, KnownZero, KnownOne and Min, Max must have equal bitwidth.");
   APInt UnknownBits = ~(KnownZero|KnownOne);
   
@@ -808,9 +804,13 @@
   assert(V != 0 && "Null pointer of Value???");
   assert(Depth <= 6 && "Limit Search Depth");
   uint32_t BitWidth = DemandedMask.getBitWidth();
-  const IntegerType *VTy = cast<IntegerType>(V->getType());
-  assert(VTy->getBitWidth() == BitWidth && 
-         KnownZero.getBitWidth() == BitWidth && 
+  const Type *VTy = V->getType();
+  assert((TD || !isa<PointerType>(VTy)) &&
+         "SimplifyDemandedBits needs to know bit widths!");
+  assert((!TD || TD->getTypeSizeInBits(VTy) == BitWidth) &&
+         (!isa<IntegerType>(VTy) ||
+          VTy->getPrimitiveSizeInBits() == BitWidth) &&
+         KnownZero.getBitWidth() == BitWidth &&
          KnownOne.getBitWidth() == BitWidth &&
          "Value *V, DemandedMask, KnownZero and KnownOne \
           must have same BitWidth");
@@ -820,7 +820,13 @@
     KnownZero = ~KnownOne & DemandedMask;
     return 0;
   }
-  
+  if (isa<ConstantPointerNull>(V)) {
+    // We know all of the bits for a constant!
+    KnownOne.clear();
+    KnownZero = DemandedMask;
+    return 0;
+  }
+
   KnownZero.clear();
   KnownOne.clear();
   if (DemandedMask == 0) {   // Not demanding any bits from V.
@@ -832,12 +838,15 @@
   if (Depth == 6)        // Limit search depth.
     return 0;
   
-  Instruction *I = dyn_cast<Instruction>(V);
-  if (!I) return 0;        // Only analyze instructions.
-  
   APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0);
   APInt &RHSKnownZero = KnownZero, &RHSKnownOne = KnownOne;
 
+  Instruction *I = dyn_cast<Instruction>(V);
+  if (!I) {
+    ComputeMaskedBits(V, DemandedMask, RHSKnownZero, RHSKnownOne, Depth);
+    return 0;        // Not an instruction: only compute its known bits.
+  }
+
   // If there are multiple uses of this value and we aren't at the root, then
   // we can't do any simplifications of the operands, because DemandedMask
   // only reflects the bits demanded by *one* of the users.
@@ -1399,8 +1408,12 @@
   
   // If the client is only demanding bits that we know, return the known
   // constant.
-  if ((DemandedMask & (RHSKnownZero|RHSKnownOne)) == DemandedMask)
-    return ConstantInt::get(RHSKnownOne);
+  if ((DemandedMask & (RHSKnownZero|RHSKnownOne)) == DemandedMask) {
+    Constant *C = ConstantInt::get(RHSKnownOne);
+    if (isa<PointerType>(V->getType()))
+      C = ConstantExpr::getIntToPtr(C, V->getType());
+    return C;
+  }
   return false;
 }
 
@@ -5831,6 +5844,14 @@
     }
   }
 
+  unsigned BitWidth = 0;
+  if (TD)
+    BitWidth = TD->getTypeSizeInBits(Ty);
+  else if (isa<IntegerType>(Ty))
+    BitWidth = Ty->getPrimitiveSizeInBits();
+
+  bool isSignBit = false;
+
   // See if we are doing a comparison with a constant.
   if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
     Value *A = 0, *B = 0;
@@ -5865,105 +5886,161 @@
       return new ICmpInst(ICmpInst::ICMP_SGT, Op0, SubOne(CI));
     }
     
-    // See if we can fold the comparison based on range information we can get
-    // by checking whether bits are known to be zero or one in the input.
-    uint32_t BitWidth = cast<IntegerType>(Ty)->getBitWidth();
-    APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
-    
     // If this comparison is a normal comparison, it demands all
     // bits, if it is a sign bit comparison, it only demands the sign bit.
     bool UnusedBit;
-    bool isSignBit = isSignBitCheck(I.getPredicate(), CI, UnusedBit);
-    
-    if (SimplifyDemandedBits(I.getOperandUse(0), 
+    isSignBit = isSignBitCheck(I.getPredicate(), CI, UnusedBit);
+  }
+
+  // See if we can fold the comparison based on range information we can get
+  // by checking whether bits are known to be zero or one in the input.
+  if (BitWidth != 0) {
+    APInt Op0KnownZero(BitWidth, 0), Op0KnownOne(BitWidth, 0);
+    APInt Op1KnownZero(BitWidth, 0), Op1KnownOne(BitWidth, 0);
+
+    if (SimplifyDemandedBits(I.getOperandUse(0),
                              isSignBit ? APInt::getSignBit(BitWidth)
                                        : APInt::getAllOnesValue(BitWidth),
-                             KnownZero, KnownOne, 0))
+                             Op0KnownZero, Op0KnownOne, 0))
       return &I;
-        
+    if (SimplifyDemandedBits(I.getOperandUse(1),
+                             APInt::getAllOnesValue(BitWidth),
+                             Op1KnownZero, Op1KnownOne, 0))
+      return &I;
+
     // Given the known and unknown bits, compute a range that the LHS could be
     // in.  Compute the Min, Max and RHS values based on the known bits. For the
     // EQ and NE we use unsigned values.
-    APInt Min(BitWidth, 0), Max(BitWidth, 0);
-    if (ICmpInst::isSignedPredicate(I.getPredicate()))
-      ComputeSignedMinMaxValuesFromKnownBits(Ty, KnownZero, KnownOne, Min, Max);
-    else
-      ComputeUnsignedMinMaxValuesFromKnownBits(Ty, KnownZero, KnownOne,Min,Max);
-    
+    APInt Op0Min(BitWidth, 0), Op0Max(BitWidth, 0);
+    APInt Op1Min(BitWidth, 0), Op1Max(BitWidth, 0);
+    if (ICmpInst::isSignedPredicate(I.getPredicate())) {
+      ComputeSignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne,
+                                             Op0Min, Op0Max);
+      ComputeSignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne,
+                                             Op1Min, Op1Max);
+    } else {
+      ComputeUnsignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne,
+                                               Op0Min, Op0Max);
+      ComputeUnsignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne,
+                                               Op1Min, Op1Max);
+    }
+
     // If Min and Max are known to be the same, then SimplifyDemandedBits
     // figured out that the LHS is a constant.  Just constant fold this now so
     // that code below can assume that Min != Max.
-    if (Min == Max)
-      return ReplaceInstUsesWith(I, ConstantExpr::getICmp(I.getPredicate(),
-                                                          ConstantInt::get(Min),
-                                                          CI));
-    
+    if (!isa<Constant>(Op0) && Op0Min == Op0Max)
+      return new ICmpInst(I.getPredicate(), ConstantInt::get(Op0Min), Op1);
+    if (!isa<Constant>(Op1) && Op1Min == Op1Max)
+      return new ICmpInst(I.getPredicate(), Op0, ConstantInt::get(Op1Min));
+
     // Based on the range information we know about the LHS, see if we can
     // simplify this comparison.  For example, (x&4) < 8  is always true.
-    const APInt &RHSVal = CI->getValue();
-    switch (I.getPredicate()) {  // LE/GE have been folded already.
+    switch (I.getPredicate()) {
     default: assert(0 && "Unknown icmp opcode!");
     case ICmpInst::ICMP_EQ:
-      if (Max.ult(RHSVal) || Min.ugt(RHSVal))
+      if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max))
         return ReplaceInstUsesWith(I, ConstantInt::getFalse());
       break;
     case ICmpInst::ICMP_NE:
-      if (Max.ult(RHSVal) || Min.ugt(RHSVal))
+      if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max))
         return ReplaceInstUsesWith(I, ConstantInt::getTrue());
       break;
     case ICmpInst::ICMP_ULT:
-      if (Max.ult(RHSVal))                    // A <u C -> true iff max(A) < C
+      if (Op0Max.ult(Op1Min))          // A <u B -> true if max(A) < min(B)
         return ReplaceInstUsesWith(I, ConstantInt::getTrue());
-      if (Min.uge(RHSVal))                    // A <u C -> false iff min(A) >= C
+      if (Op0Min.uge(Op1Max))          // A <u B -> false if min(A) >= max(B)
         return ReplaceInstUsesWith(I, ConstantInt::getFalse());
-      if (RHSVal == Max)                      // A <u MAX -> A != MAX
+      if (Op1Min == Op0Max)            // A <u B -> A != B if max(A) == min(B)
         return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
-      if (RHSVal == Min+1)                    // A <u MIN+1 -> A == MIN
-        return new ICmpInst(ICmpInst::ICMP_EQ, Op0, SubOne(CI));
-        
-      // (x <u 2147483648) -> (x >s -1)  -> true if sign bit clear
-      if (CI->isMinValue(true))
-        return new ICmpInst(ICmpInst::ICMP_SGT, Op0,
+      if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
+        if (Op1Max == Op0Min+1)        // A <u C -> A == C-1 if min(A)+1 == C
+          return new ICmpInst(ICmpInst::ICMP_EQ, Op0, SubOne(CI));
+
+        // (x <u 2147483648) -> (x >s -1)  -> true if sign bit clear
+        if (CI->isMinValue(true))
+          return new ICmpInst(ICmpInst::ICMP_SGT, Op0,
                             ConstantInt::getAllOnesValue(Op0->getType()));
+      }
       break;
     case ICmpInst::ICMP_UGT:
-      if (Min.ugt(RHSVal))                    // A >u C -> true iff min(A) > C
+      if (Op0Min.ugt(Op1Max))          // A >u B -> true if min(A) > max(B)
         return ReplaceInstUsesWith(I, ConstantInt::getTrue());
-      if (Max.ule(RHSVal))                    // A >u C -> false iff max(A) <= C
+      if (Op0Max.ule(Op1Min))          // A >u B -> false if max(A) <= min(B)
         return ReplaceInstUsesWith(I, ConstantInt::getFalse());
-        
-      if (RHSVal == Min)                      // A >u MIN -> A != MIN
+
+      if (Op1Max == Op0Min)            // A >u B -> A != B if min(A) == max(B)
         return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
-      if (RHSVal == Max-1)                    // A >u MAX-1 -> A == MAX
-        return new ICmpInst(ICmpInst::ICMP_EQ, Op0, AddOne(CI));
-      
-      // (x >u 2147483647) -> (x <s 0)  -> true if sign bit set
-      if (CI->isMaxValue(true))
-        return new ICmpInst(ICmpInst::ICMP_SLT, Op0,
-                            ConstantInt::getNullValue(Op0->getType()));
+      if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
+        if (Op1Min == Op0Max-1)        // A >u C -> A == C+1 if max(A)-1 == C
+          return new ICmpInst(ICmpInst::ICMP_EQ, Op0, AddOne(CI));
+
+        // (x >u 2147483647) -> (x <s 0)  -> true if sign bit set
+        if (CI->isMaxValue(true))
+          return new ICmpInst(ICmpInst::ICMP_SLT, Op0,
+                              ConstantInt::getNullValue(Op0->getType()));
+      }
       break;
     case ICmpInst::ICMP_SLT:
-      if (Max.slt(RHSVal))                    // A <s C -> true iff max(A) < C
+      if (Op0Max.slt(Op1Min))          // A <s B -> true if max(A) < min(B)
         return ReplaceInstUsesWith(I, ConstantInt::getTrue());
-      if (Min.sge(RHSVal))                    // A <s C -> false iff min(A) >= C
+      if (Op0Min.sge(Op1Max))          // A <s B -> false if min(A) >= max(B)
         return ReplaceInstUsesWith(I, ConstantInt::getFalse());
-      if (RHSVal == Max)                      // A <s MAX -> A != MAX
+      if (Op1Min == Op0Max)            // A <s B -> A != B if max(A) == min(B)
         return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
-      if (RHSVal == Min+1)                    // A <s MIN+1 -> A == MIN
-        return new ICmpInst(ICmpInst::ICMP_EQ, Op0, SubOne(CI));
+      if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
+        if (Op1Max == Op0Min+1)        // A <s C -> A == C-1 if min(A)+1 == C
+          return new ICmpInst(ICmpInst::ICMP_EQ, Op0, SubOne(CI));
+      }
       break;
-    case ICmpInst::ICMP_SGT: 
-      if (Min.sgt(RHSVal))                    // A >s C -> true iff min(A) > C
+    case ICmpInst::ICMP_SGT:
+      if (Op0Min.sgt(Op1Max))          // A >s B -> true if min(A) > max(B)
         return ReplaceInstUsesWith(I, ConstantInt::getTrue());
-      if (Max.sle(RHSVal))                    // A >s C -> false iff max(A) <= C
+      if (Op0Max.sle(Op1Min))          // A >s B -> false if max(A) <= min(B)
         return ReplaceInstUsesWith(I, ConstantInt::getFalse());
-        
-      if (RHSVal == Min)                      // A >s MIN -> A != MIN
+
+      if (Op1Max == Op0Min)            // A >s B -> A != B if min(A) == max(B)
         return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
-      if (RHSVal == Max-1)                    // A >s MAX-1 -> A == MAX
-        return new ICmpInst(ICmpInst::ICMP_EQ, Op0, AddOne(CI));
+      if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
+        if (Op1Min == Op0Max-1)        // A >s C -> A == C+1 if max(A)-1 == C
+          return new ICmpInst(ICmpInst::ICMP_EQ, Op0, AddOne(CI));
+      }
+      break;
+    case ICmpInst::ICMP_SGE:
+      assert(!isa<ConstantInt>(Op1) && "ICMP_SGE with ConstantInt not folded!");
+      if (Op0Min.sge(Op1Max))          // A >=s B -> true if min(A) >= max(B)
+        return ReplaceInstUsesWith(I, ConstantInt::getTrue());
+      if (Op0Max.slt(Op1Min))          // A >=s B -> false if max(A) < min(B)
+        return ReplaceInstUsesWith(I, ConstantInt::getFalse());
+      break;
+    case ICmpInst::ICMP_SLE:
+      assert(!isa<ConstantInt>(Op1) && "ICMP_SLE with ConstantInt not folded!");
+      if (Op0Max.sle(Op1Min))          // A <=s B -> true if max(A) <= min(B)
+        return ReplaceInstUsesWith(I, ConstantInt::getTrue());
+      if (Op0Min.sgt(Op1Max))          // A <=s B -> false if min(A) > max(B)
+        return ReplaceInstUsesWith(I, ConstantInt::getFalse());
+      break;
+    case ICmpInst::ICMP_UGE:
+      assert(!isa<ConstantInt>(Op1) && "ICMP_UGE with ConstantInt not folded!");
+      if (Op0Min.uge(Op1Max))          // A >=u B -> true if min(A) >= max(B)
+        return ReplaceInstUsesWith(I, ConstantInt::getTrue());
+      if (Op0Max.ult(Op1Min))          // A >=u B -> false if max(A) < min(B)
+        return ReplaceInstUsesWith(I, ConstantInt::getFalse());
+      break;
+    case ICmpInst::ICMP_ULE:
+      assert(!isa<ConstantInt>(Op1) && "ICMP_ULE with ConstantInt not folded!");
+      if (Op0Max.ule(Op1Min))          // A <=u B -> true if max(A) <= min(B)
+        return ReplaceInstUsesWith(I, ConstantInt::getTrue());
+      if (Op0Min.ugt(Op1Max))          // A <=u B -> false if min(A) > max(B)
+        return ReplaceInstUsesWith(I, ConstantInt::getFalse());
       break;
     }
+
+    // Turn a signed comparison into an unsigned one if both operands
+    // are known to have the same sign.
+    if (I.isSignedPredicate() &&
+        ((Op0KnownZero.isNegative() && Op1KnownZero.isNegative()) ||
+         (Op0KnownOne.isNegative() && Op1KnownOne.isNegative())))
+      return new ICmpInst(I.getUnsignedPredicate(), Op0, Op1);
   }
 
   // Test if the ICmpInst instruction is used exclusively by a select as

Added: llvm/trunk/test/Transforms/InstCombine/signed-comparison.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/InstCombine/signed-comparison.ll?rev=70053&view=auto

==============================================================================
--- llvm/trunk/test/Transforms/InstCombine/signed-comparison.ll (added)
+++ llvm/trunk/test/Transforms/InstCombine/signed-comparison.ll Sat Apr 25 12:12:48 2009
@@ -0,0 +1,28 @@
+; RUN: llvm-as < %s | opt -instcombine | llvm-dis > %t
+; RUN: not grep zext %t
+; RUN: not grep slt %t
+; RUN: grep {icmp ult} %t
+
+; Instcombine should convert the zext+slt into a simple ult: the zext clears the sign bit of %t5 and 500 is non-negative, so both operands have the same sign.
+
+define void @foo(double* %p) nounwind {
+entry:
+	br label %bb
+
+bb:
+	%indvar = phi i64 [ 0, %entry ], [ %indvar.next, %bb ]
+	%t0 = and i64 %indvar, 65535
+	%t1 = getelementptr double* %p, i64 %t0
+	%t2 = load double* %t1, align 8
+	%t3 = mul double %t2, 2.2
+	store double %t3, double* %t1, align 8
+	%i.04 = trunc i64 %indvar to i16
+	%t4 = add i16 %i.04, 1
+	%t5 = zext i16 %t4 to i32
+	%t6 = icmp slt i32 %t5, 500
+	%indvar.next = add i64 %indvar, 1
+	br i1 %t6, label %bb, label %return
+
+return:
+	ret void
+}





More information about the llvm-commits mailing list