[llvm-commits] [llvm] r125266 - in /llvm/trunk: lib/Transforms/InstCombine/InstCombineCompares.cpp test/Transforms/InstCombine/exact.ll test/Transforms/InstCombine/nsw.ll

Chris Lattner sabre at nondot.org
Wed Feb 9 21:23:05 PST 2011


Author: lattner
Date: Wed Feb  9 23:23:05 2011
New Revision: 125266

URL: http://llvm.org/viewvc/llvm-project?rev=125266&view=rev
Log:
Enhance the "compare with shift" and "compare with div" 
optimizations to be much more aggressive in the face of
exact/nsw/nuw div and shifts.  For example, these two functions
(which are identical except that the first uses an 'exact' sdiv):

define i1 @sdiv_icmp4_exact(i64 %X) nounwind {
  %A = sdiv exact i64 %X, -5   ; X/-5 == 0 --> x == 0
  %B = icmp eq i64 %A, 0
  ret i1 %B
}

define i1 @sdiv_icmp4(i64 %X) nounwind {
  %A = sdiv i64 %X, -5   ; X/-5 == 0 --> -4 <= x <= 4
  %B = icmp eq i64 %A, 0
  ret i1 %B
}

compile down to:

define i1 @sdiv_icmp4_exact(i64 %X) nounwind {
  %1 = icmp eq i64 %X, 0
  ret i1 %1
}

define i1 @sdiv_icmp4(i64 %X) nounwind {
  %X.off = add i64 %X, 4
  %1 = icmp ult i64 %X.off, 9
  ret i1 %1
}

This pattern arises from code like:
  (ptr1-ptr2) == 42

where the pointers point to types larger than one byte, so the pointer
difference is computed with an exact sdiv by the element size.


Modified:
    llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp
    llvm/trunk/test/Transforms/InstCombine/exact.ll
    llvm/trunk/test/Transforms/InstCombine/nsw.ll

Modified: llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp?rev=125266&r1=125265&r2=125266&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp (original)
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp Wed Feb  9 23:23:05 2011
@@ -22,13 +22,17 @@
 using namespace llvm;
 using namespace PatternMatch;
 
+static ConstantInt *getOne(Constant *C) {
+  return ConstantInt::get(cast<IntegerType>(C->getType()), 1);
+}
+
 /// AddOne - Add one to a ConstantInt
 static Constant *AddOne(Constant *C) {
   return ConstantExpr::getAdd(C, ConstantInt::get(C->getType(), 1));
 }
 /// SubOne - Subtract one from a ConstantInt
-static Constant *SubOne(ConstantInt *C) {
-  return ConstantExpr::getSub(C,  ConstantInt::get(C->getType(), 1));
+static Constant *SubOne(Constant *C) {
+  return ConstantExpr::getSub(C, ConstantInt::get(C->getType(), 1));
 }
 
 static ConstantInt *ExtractElement(Constant *V, Constant *Idx) {
@@ -782,7 +786,7 @@
   // results than (x /s C1) <u C2 or (x /u C1) <s C2 or even
   // (x /u C1) <u C2.  Simply casting the operands and result won't 
   // work. :(  The if statement below tests that condition and bails 
-  // if it finds it. 
+  // if it finds it.
   bool DivIsSigned = DivI->getOpcode() == Instruction::SDiv;
   if (!ICI.isEquality() && DivIsSigned != ICI.isSigned())
     return 0;
@@ -809,6 +813,10 @@
   // Get the ICmp opcode
   ICmpInst::Predicate Pred = ICI.getPredicate();
 
+  /// If the division is known to be exact, then there is no remainder from the
+  /// divide, so the covered range size is one, otherwise it is the divisor.
+  ConstantInt *RangeSize = DivI->isExact() ? getOne(Prod) : DivRHS;
+  
   // Figure out the interval that is being checked.  For example, a comparison
   // like "X /u 5 == 0" is really checking that X is in the interval [0, 5). 
   // Compute this interval based on the constants involved and the signedness of
@@ -818,38 +826,43 @@
   // -1 if overflowed off the bottom end, or +1 if overflowed off the top end.
   int LoOverflow = 0, HiOverflow = 0;
   Constant *LoBound = 0, *HiBound = 0;
-  
+
   if (!DivIsSigned) {  // udiv
     // e.g. X/5 op 3  --> [15, 20)
     LoBound = Prod;
     HiOverflow = LoOverflow = ProdOV;
-    if (!HiOverflow)
-      HiOverflow = AddWithOverflow(HiBound, LoBound, DivRHS, false);
+    if (!HiOverflow) {
+      // If this is not an exact divide, then many values in the range collapse
+      // to the same result value.
+      HiOverflow = AddWithOverflow(HiBound, LoBound, RangeSize, false);
+    }
+    
   } else if (DivRHS->getValue().isStrictlyPositive()) { // Divisor is > 0.
     if (CmpRHSV == 0) {       // (X / pos) op 0
       // Can't overflow.  e.g.  X/2 op 0 --> [-1, 2)
-      LoBound = cast<ConstantInt>(ConstantExpr::getNeg(SubOne(DivRHS)));
-      HiBound = DivRHS;
+      LoBound = ConstantExpr::getNeg(SubOne(RangeSize));
+      HiBound = RangeSize;
     } else if (CmpRHSV.isStrictlyPositive()) {   // (X / pos) op pos
       LoBound = Prod;     // e.g.   X/5 op 3 --> [15, 20)
       HiOverflow = LoOverflow = ProdOV;
       if (!HiOverflow)
-        HiOverflow = AddWithOverflow(HiBound, Prod, DivRHS, true);
+        HiOverflow = AddWithOverflow(HiBound, Prod, RangeSize, true);
     } else {                       // (X / pos) op neg
       // e.g. X/5 op -3  --> [-15-4, -15+1) --> [-19, -14)
       HiBound = AddOne(Prod);
       LoOverflow = HiOverflow = ProdOV ? -1 : 0;
       if (!LoOverflow) {
-        ConstantInt* DivNeg =
-                         cast<ConstantInt>(ConstantExpr::getNeg(DivRHS));
+        ConstantInt *DivNeg =cast<ConstantInt>(ConstantExpr::getNeg(RangeSize));
         LoOverflow = AddWithOverflow(LoBound, HiBound, DivNeg, true) ? -1 : 0;
-       }
+      }
     }
   } else if (DivRHS->getValue().isNegative()) { // Divisor is < 0.
+    if (DivI->isExact())
+      RangeSize = cast<ConstantInt>(ConstantExpr::getNeg(RangeSize));
     if (CmpRHSV == 0) {       // (X / neg) op 0
       // e.g. X/-5 op 0  --> [-4, 5)
-      LoBound = AddOne(DivRHS);
-      HiBound = cast<ConstantInt>(ConstantExpr::getNeg(DivRHS));
+      LoBound = AddOne(RangeSize);
+      HiBound = cast<ConstantInt>(ConstantExpr::getNeg(RangeSize));
       if (HiBound == DivRHS) {     // -INTMIN = INTMIN
         HiOverflow = 1;            // [INTMIN+1, overflow)
         HiBound = 0;               // e.g. X/INTMIN = 0 --> X > INTMIN
@@ -859,12 +872,12 @@
       HiBound = AddOne(Prod);
       HiOverflow = LoOverflow = ProdOV ? -1 : 0;
       if (!LoOverflow)
-        LoOverflow = AddWithOverflow(LoBound, HiBound, DivRHS, true) ? -1 : 0;
+        LoOverflow = AddWithOverflow(LoBound, HiBound, RangeSize, true) ? -1:0;
     } else {                       // (X / neg) op neg
       LoBound = Prod;       // e.g. X/-5 op -3  --> [15, 20)
       LoOverflow = HiOverflow = ProdOV;
       if (!HiOverflow)
-        HiOverflow = SubWithOverflow(HiBound, Prod, DivRHS, true);
+        HiOverflow = SubWithOverflow(HiBound, Prod, RangeSize, true);
     }
     
     // Dividing by a negative swaps the condition.  LT <-> GT
@@ -883,9 +896,8 @@
     if (LoOverflow)
       return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT :
                           ICmpInst::ICMP_ULT, X, HiBound);
-    return ReplaceInstUsesWith(ICI,
-                               InsertRangeTest(X, LoBound, HiBound, DivIsSigned,
-                                               true));
+    return ReplaceInstUsesWith(ICI, InsertRangeTest(X, LoBound, HiBound,
+                                                    DivIsSigned, true));
   case ICmpInst::ICMP_NE:
     if (LoOverflow && HiOverflow)
       return ReplaceInstUsesWith(ICI, ConstantInt::getTrue(ICI.getContext()));
@@ -908,12 +920,11 @@
   case ICmpInst::ICMP_SGT:
     if (HiOverflow == +1)       // High bound greater than input range.
       return ReplaceInstUsesWith(ICI, ConstantInt::getFalse(ICI.getContext()));
-    else if (HiOverflow == -1)  // High bound less than input range.
+    if (HiOverflow == -1)       // High bound less than input range.
       return ReplaceInstUsesWith(ICI, ConstantInt::getTrue(ICI.getContext()));
     if (Pred == ICmpInst::ICMP_UGT)
       return new ICmpInst(ICmpInst::ICMP_UGE, X, HiBound);
-    else
-      return new ICmpInst(ICmpInst::ICMP_SGE, X, HiBound);
+    return new ICmpInst(ICmpInst::ICMP_SGE, X, HiBound);
   }
 }
 
@@ -1182,6 +1193,12 @@
         return ReplaceInstUsesWith(ICI, Cst);
       }
       
+      // If the shift is NUW, then it is just shifting out zeros, no need for an
+      // AND.
+      if (cast<BinaryOperator>(LHSI)->hasNoUnsignedWrap())
+        return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0),
+                            ConstantExpr::getLShr(RHS, ShAmt));
+      
       if (LHSI->hasOneUse()) {
         // Otherwise strength reduce the shift into an and.
         uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits);
@@ -1192,8 +1209,7 @@
         Value *And =
           Builder->CreateAnd(LHSI->getOperand(0),Mask, LHSI->getName()+".mask");
         return new ICmpInst(ICI.getPredicate(), And,
-                            ConstantInt::get(ICI.getContext(),
-                                             RHSV.lshr(ShAmtVal)));
+                            ConstantExpr::getLShr(RHS, ShAmt));
       }
     }
     
@@ -1222,10 +1238,9 @@
     // undefined shifts.  When the shift is visited it will be
     // simplified.
     uint32_t TypeBits = RHSV.getBitWidth();
-    if (ShAmt->uge(TypeBits))
-      break;
-    
     uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits);
+    if (ShAmtVal >= TypeBits)
+      break;
       
     // If we are comparing against bits always shifted out, the
     // comparison cannot succeed.
@@ -1245,13 +1260,10 @@
     // Otherwise, check to see if the bits shifted out are known to be zero.
     // If so, we can compare against the unshifted value:
     //  (X & 4) >> 1 == 2  --> (X & 4) == 4.
-    if (LHSI->hasOneUse() &&
-        MaskedValueIsZero(LHSI->getOperand(0), 
-                          APInt::getLowBitsSet(Comp.getBitWidth(), ShAmtVal))) {
+    if (LHSI->hasOneUse() && cast<BinaryOperator>(LHSI)->isExact())
       return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0),
                           ConstantExpr::getShl(RHS, ShAmt));
-    }
-      
+    
     if (LHSI->hasOneUse()) {
       // Otherwise strength reduce the shift into an and.
       APInt Val(APInt::getHighBitsSet(TypeBits, TypeBits - ShAmtVal));
@@ -1911,14 +1923,12 @@
         
         // If the LHS is 8 >>u x, and we know the result is a power of 2 like 1,
         // then turn "((8 >>u x)&1) == 0" into "x != 3".
-        ConstantInt *CI = 0;
+        const APInt *CI;
         if (Op0KnownZeroInverted == 1 &&
-            match(LHS, m_LShr(m_ConstantInt(CI), m_Value(X))) &&
-            CI->getValue().isPowerOf2()) {
-          unsigned CmpVal = CI->getValue().countTrailingZeros();
+            match(LHS, m_LShr(m_Power2(CI), m_Value(X))))
           return new ICmpInst(ICmpInst::ICMP_NE, X,
-                              ConstantInt::get(X->getType(), CmpVal));
-        }
+                              ConstantInt::get(X->getType(),
+                                               CI->countTrailingZeros()));
       }
         
       break;
@@ -1950,14 +1960,12 @@
         
         // If the LHS is 8 >>u x, and we know the result is a power of 2 like 1,
         // then turn "((8 >>u x)&1) != 0" into "x == 3".
-        ConstantInt *CI = 0;
+        const APInt *CI;
         if (Op0KnownZeroInverted == 1 &&
-            match(LHS, m_LShr(m_ConstantInt(CI), m_Value(X))) &&
-            CI->getValue().isPowerOf2()) {
-          unsigned CmpVal = CI->getValue().countTrailingZeros();
+            match(LHS, m_LShr(m_Power2(CI), m_Value(X))))
           return new ICmpInst(ICmpInst::ICMP_EQ, X,
-                              ConstantInt::get(X->getType(), CmpVal));
-        }
+                              ConstantInt::get(X->getType(),
+                                               CI->countTrailingZeros()));
       }
       
       break;

Modified: llvm/trunk/test/Transforms/InstCombine/exact.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/InstCombine/exact.ll?rev=125266&r1=125265&r2=125266&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/InstCombine/exact.ll (original)
+++ llvm/trunk/test/Transforms/InstCombine/exact.ll Wed Feb  9 23:23:05 2011
@@ -1,60 +1,119 @@
 ; RUN: opt < %s -instcombine -S | FileCheck %s
 
-; CHECK: define i32 @foo
+; CHECK: @sdiv1
 ; CHECK: sdiv i32 %x, 8
-define i32 @foo(i32 %x) {
+define i32 @sdiv1(i32 %x) {
   %y = sdiv i32 %x, 8
   ret i32 %y
 }
 
-; CHECK: define i32 @bar
-; CHECK: ashr i32 %x, 3
-define i32 @bar(i32 %x) {
-  %y = sdiv exact i32 %x, 8
-  ret i32 %y
-}
-
-; CHECK: i32 @a0
+; CHECK: @sdiv3
 ; CHECK: %y = srem i32 %x, 3
 ; CHECK: %z = sub i32 %x, %y
 ; CHECK: ret i32 %z
-define i32 @a0(i32 %x) {
+define i32 @sdiv3(i32 %x) {
   %y = sdiv i32 %x, 3
   %z = mul i32 %y, 3
   ret i32 %z
 }
 
-; CHECK: i32 @b0
+; CHECK: @sdiv4
 ; CHECK: ret i32 %x
-define i32 @b0(i32 %x) {
+define i32 @sdiv4(i32 %x) {
   %y = sdiv exact i32 %x, 3
   %z = mul i32 %y, 3
   ret i32 %z
 }
 
-; CHECK: i32 @a1
+; CHECK: i32 @sdiv5
 ; CHECK: %y = srem i32 %x, 3
 ; CHECK: %z = sub i32 %y, %x
 ; CHECK: ret i32 %z
-define i32 @a1(i32 %x) {
+define i32 @sdiv5(i32 %x) {
   %y = sdiv i32 %x, 3
   %z = mul i32 %y, -3
   ret i32 %z
 }
 
-; CHECK: i32 @b1
+; CHECK: @sdiv6
 ; CHECK: %z = sub i32 0, %x
 ; CHECK: ret i32 %z
-define i32 @b1(i32 %x) {
+define i32 @sdiv6(i32 %x) {
   %y = sdiv exact i32 %x, 3
   %z = mul i32 %y, -3
   ret i32 %z
 }
 
-; CHECK: i32 @b2
+; CHECK: @udiv1
 ; CHECK: ret i32 %x
-define i32 @b2(i32 %x, i32 %w) {
+define i32 @udiv1(i32 %x, i32 %w) {
   %y = udiv exact i32 %x, %w
   %z = mul i32 %y, %w
   ret i32 %z
 }
+
+; CHECK: @ashr_icmp
+; CHECK: %B = icmp eq i64 %X, 0
+; CHECK: ret i1 %B
+define i1 @ashr_icmp(i64 %X) nounwind {
+  %A = ashr exact i64 %X, 2   ; X/4
+  %B = icmp eq i64 %A, 0
+  ret i1 %B
+}
+
+; CHECK: @udiv_icmp1
+; CHECK: icmp ne i64 %X, 0
+define i1 @udiv_icmp1(i64 %X) nounwind {
+  %A = udiv exact i64 %X, 5   ; X/5
+  %B = icmp ne i64 %A, 0
+  ret i1 %B
+}
+
+; CHECK: @sdiv_icmp1
+; CHECK: icmp eq i64 %X, 0
+define i1 @sdiv_icmp1(i64 %X) nounwind {
+  %A = sdiv exact i64 %X, 5   ; X/5 == 0 --> x == 0
+  %B = icmp eq i64 %A, 0
+  ret i1 %B
+}
+
+; CHECK: @sdiv_icmp2
+; CHECK: icmp eq i64 %X, 5
+define i1 @sdiv_icmp2(i64 %X) nounwind {
+  %A = sdiv exact i64 %X, 5   ; X/5 == 1 --> x == 5
+  %B = icmp eq i64 %A, 1
+  ret i1 %B
+}
+
+; CHECK: @sdiv_icmp3
+; CHECK: icmp eq i64 %X, -5
+define i1 @sdiv_icmp3(i64 %X) nounwind {
+  %A = sdiv exact i64 %X, 5   ; X/5 == -1 --> x == -5
+  %B = icmp eq i64 %A, -1
+  ret i1 %B
+}
+
+; CHECK: @sdiv_icmp4
+; CHECK: icmp eq i64 %X, 0
+define i1 @sdiv_icmp4(i64 %X) nounwind {
+  %A = sdiv exact i64 %X, -5   ; X/-5 == 0 --> x == 0
+  %B = icmp eq i64 %A, 0
+  ret i1 %B
+}
+
+; CHECK: @sdiv_icmp5
+; CHECK: icmp eq i64 %X, -5
+define i1 @sdiv_icmp5(i64 %X) nounwind {
+  %A = sdiv exact i64 %X, -5   ; X/-5 == 1 --> x == -5
+  %B = icmp eq i64 %A, 1
+  ret i1 %B
+}
+
+; CHECK: @sdiv_icmp6
+; CHECK: icmp eq i64 %X, 5
+define i1 @sdiv_icmp6(i64 %X) nounwind {
+  %A = sdiv exact i64 %X, -5   ; X/-5 == -1 --> x == 5
+  %B = icmp eq i64 %A, -1
+  ret i1 %B
+}
+

Modified: llvm/trunk/test/Transforms/InstCombine/nsw.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/InstCombine/nsw.ll?rev=125266&r1=125265&r2=125266&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/InstCombine/nsw.ll (original)
+++ llvm/trunk/test/Transforms/InstCombine/nsw.ll Wed Feb  9 23:23:05 2011
@@ -1,20 +1,30 @@
 ; RUN: opt < %s -instcombine -S | FileCheck %s
 
-; CHECK: define i32 @foo
-; %y = sub i32 0, %x
-; %z = sdiv i32 %y, 337
-; ret i32 %y
-define i32 @foo(i32 %x) {
+; CHECK: @sub1
+; CHECK: %y = sub i32 0, %x
+; CHECK: %z = sdiv i32 %y, 337
+; CHECK: ret i32 %z
+define i32 @sub1(i32 %x) {
   %y = sub i32 0, %x
   %z = sdiv i32 %y, 337
-  ret i32 %y
+  ret i32 %z
 }
 
-; CHECK: define i32 @bar
-; %y = sdiv i32 %x, -337
-; ret i32 %y
-define i32 @bar(i32 %x) {
+; CHECK: @sub2
+; CHECK: %z = sdiv i32 %x, -337
+; CHECK: ret i32 %z
+define i32 @sub2(i32 %x) {
   %y = sub nsw i32 0, %x
   %z = sdiv i32 %y, 337
-  ret i32 %y
+  ret i32 %z
 }
+
+; CHECK: @shl_icmp
+; CHECK: %B = icmp eq i64 %X, 0
+; CHECK: ret i1 %B
+define i1 @shl_icmp(i64 %X) nounwind {
+  %A = shl nuw i64 %X, 2   ; X*4
+  %B = icmp eq i64 %A, 0
+  ret i1 %B
+}
+





More information about the llvm-commits mailing list