[llvm-commits] [llvm] r125734 - in /llvm/trunk: lib/Transforms/InstCombine/InstCombineCompares.cpp test/Transforms/InstCombine/icmp.ll

Duncan Sands baldrick at free.fr
Wed Feb 16 23:46:37 PST 2011


Author: baldrick
Date: Thu Feb 17 01:46:37 2011
New Revision: 125734

URL: http://llvm.org/viewvc/llvm-project?rev=125734&view=rev
Log:
Transform "A + B >= A + C" into "B >= C" if the adds do not wrap.  Likewise for some
variations (some of these were already present so I unified the code).  Spotted by my
auto-simplifier as occurring a lot.

Modified:
    llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp
    llvm/trunk/test/Transforms/InstCombine/icmp.ll

Modified: llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp?rev=125734&r1=125733&r2=125734&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp (original)
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp Thu Feb 17 01:46:37 2011
@@ -698,13 +698,6 @@
   if (Pred == ICmpInst::ICMP_NE)
     return ReplaceInstUsesWith(ICI, ConstantInt::getTrue(X->getContext()));
 
-  // If this is an instruction (as opposed to constantexpr) get NUW/NSW info.
-  bool isNUW = false, isNSW = false;
-  if (BinaryOperator *Add = dyn_cast<BinaryOperator>(TheAdd)) {
-    isNUW = Add->hasNoUnsignedWrap();
-    isNSW = Add->hasNoSignedWrap();
-  }      
-  
   // From this point on, we know that (X+C <= X) --> (X+C < X) because C != 0,
   // so the values can never be equal.  Similiarly for all other "or equals"
   // operators.
@@ -713,10 +706,6 @@
   // (X+2) <u X        --> X >u (MAXUINT-2)        --> X > 253
   // (X+MAXUINT) <u X  --> X >u (MAXUINT-MAXUINT)  --> X != 0
   if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE) {
-    // If this is an NUW add, then this is always false.
-    if (isNUW)
-      return ReplaceInstUsesWith(ICI, ConstantInt::getFalse(X->getContext())); 
-    
     Value *R = 
       ConstantExpr::getSub(ConstantInt::getAllOnesValue(CI->getType()), CI);
     return new ICmpInst(ICmpInst::ICMP_UGT, X, R);
@@ -725,12 +714,8 @@
   // (X+1) >u X        --> X <u (0-1)        --> X != 255
   // (X+2) >u X        --> X <u (0-2)        --> X <u 254
   // (X+MAXUINT) >u X  --> X <u (0-MAXUINT)  --> X <u 1  --> X == 0
-  if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) {
-    // If this is an NUW add, then this is always true.
-    if (isNUW)
-      return ReplaceInstUsesWith(ICI, ConstantInt::getTrue(X->getContext())); 
+  if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE)
     return new ICmpInst(ICmpInst::ICMP_ULT, X, ConstantExpr::getNeg(CI));
-  }
   
   unsigned BitWidth = CI->getType()->getPrimitiveSizeInBits();
   ConstantInt *SMax = ConstantInt::get(X->getContext(),
@@ -742,16 +727,8 @@
   // (X+MINSINT) <s X  --> X >s (MAXSINT-MINSINT)    --> X >s -1
   // (X+ -2) <s X      --> X >s (MAXSINT- -2)        --> X >s 126
   // (X+ -1) <s X      --> X >s (MAXSINT- -1)        --> X != 127
-  if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) {
-    // If this is an NSW add, then we have two cases: if the constant is
-    // positive, then this is always false, if negative, this is always true.
-    if (isNSW) {
-      bool isTrue = CI->getValue().isNegative();
-      return ReplaceInstUsesWith(ICI, ConstantInt::get(ICI.getType(), isTrue));
-    }
-    
+  if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
     return new ICmpInst(ICmpInst::ICMP_SGT, X, ConstantExpr::getSub(SMax, CI));
-  }
   
   // (X+ 1) >s X       --> X <s (MAXSINT-(1-1))       --> X != 127
   // (X+ 2) >s X       --> X <s (MAXSINT-(2-1))       --> X <s 126
@@ -760,13 +737,6 @@
   // (X+ -2) >s X      --> X <s (MAXSINT-(-2-1))      --> X <s -126
   // (X+ -1) >s X      --> X <s (MAXSINT-(-1-1))      --> X == -128
   
-  // If this is an NSW add, then we have two cases: if the constant is
-  // positive, then this is always true, if negative, this is always false.
-  if (isNSW) {
-    bool isTrue = !CI->getValue().isNegative();
-    return ReplaceInstUsesWith(ICI, ConstantInt::get(ICI.getType(), isTrue));
-  }
-  
   assert(Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE);
   Constant *C = ConstantInt::get(X->getContext(), CI->getValue()-1);
   return new ICmpInst(ICmpInst::ICMP_SLT, X, ConstantExpr::getSub(SMax, C));
@@ -2263,60 +2233,115 @@
       if (Instruction *R = visitICmpInstWithCastAndCast(I))
         return R;
   }
-  
-  // See if it's the same type of instruction on the left and right.
-  if (BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0)) {
-    if (BinaryOperator *Op1I = dyn_cast<BinaryOperator>(Op1)) {
-      if (Op0I->getOpcode() == Op1I->getOpcode() && Op0I->hasOneUse() &&
-          Op1I->hasOneUse() && Op0I->getOperand(1) == Op1I->getOperand(1)) {
-        switch (Op0I->getOpcode()) {
-        default: break;
-        case Instruction::Add:
-        case Instruction::Sub:
-        case Instruction::Xor:
-          if (I.isEquality())    // a+x icmp eq/ne b+x --> a icmp b
-            return new ICmpInst(I.getPredicate(), Op0I->getOperand(0),
-                                Op1I->getOperand(0));
-          // icmp u/s (a ^ signbit), (b ^ signbit) --> icmp s/u a, b
-          if (ConstantInt *CI = dyn_cast<ConstantInt>(Op0I->getOperand(1))) {
-            if (CI->getValue().isSignBit()) {
-              ICmpInst::Predicate Pred = I.isSigned()
-                                             ? I.getUnsignedPredicate()
-                                             : I.getSignedPredicate();
-              return new ICmpInst(Pred, Op0I->getOperand(0),
-                                  Op1I->getOperand(0));
-            }
-            
-            if (CI->getValue().isMaxSignedValue()) {
-              ICmpInst::Predicate Pred = I.isSigned()
-                                             ? I.getUnsignedPredicate()
-                                             : I.getSignedPredicate();
-              Pred = I.getSwappedPredicate(Pred);
-              return new ICmpInst(Pred, Op0I->getOperand(0),
-                                  Op1I->getOperand(0));
-            }
+
+  // Special logic for binary operators.
+  BinaryOperator *BO0 = dyn_cast<BinaryOperator>(Op0);
+  BinaryOperator *BO1 = dyn_cast<BinaryOperator>(Op1);
+  if (BO0 || BO1) {
+    CmpInst::Predicate Pred = I.getPredicate();
+    bool NoOp0WrapProblem = false, NoOp1WrapProblem = false;
+    if (BO0 && isa<OverflowingBinaryOperator>(BO0))
+      NoOp0WrapProblem = ICmpInst::isEquality(Pred) ||
+        (CmpInst::isUnsigned(Pred) && BO0->hasNoUnsignedWrap()) ||
+        (CmpInst::isSigned(Pred) && BO0->hasNoSignedWrap());
+    if (BO1 && isa<OverflowingBinaryOperator>(BO1))
+      NoOp1WrapProblem = ICmpInst::isEquality(Pred) ||
+        (CmpInst::isUnsigned(Pred) && BO1->hasNoUnsignedWrap()) ||
+        (CmpInst::isSigned(Pred) && BO1->hasNoSignedWrap());
+
+    // Analyze the case when either Op0 or Op1 is an add instruction.
+    // Op0 = A + B (or A and B are null); Op1 = C + D (or C and D are null).
+    Value *A = 0, *B = 0, *C = 0, *D = 0;
+    if (BO0 && BO0->getOpcode() == Instruction::Add)
+      A = BO0->getOperand(0), B = BO0->getOperand(1);
+    if (BO1 && BO1->getOpcode() == Instruction::Add)
+      C = BO1->getOperand(0), D = BO1->getOperand(1);
+
+    // icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow.
+    if ((A == Op1 || B == Op1) && NoOp0WrapProblem)
+      return new ICmpInst(Pred, A == Op1 ? B : A,
+                          Constant::getNullValue(Op1->getType()));
+
+    // icmp X, (X+Y) -> icmp 0, Y for equalities or if there is no overflow.
+    if ((C == Op0 || D == Op0) && NoOp1WrapProblem)
+      return new ICmpInst(Pred, Constant::getNullValue(Op0->getType()),
+                          C == Op0 ? D : C);
+
+    // icmp (X+Y), (X+Z) -> icmp Y,Z for equalities or if there is no overflow.
+    if (A && C && (A == C || A == D || B == C || B == D) &&
+        NoOp0WrapProblem && NoOp1WrapProblem &&
+        // Try not to increase register pressure.
+        BO0->hasOneUse() && BO1->hasOneUse()) {
+      // Pick one shared X for both sides; if A == D and B == C, use X = A = D.
+      Value *Y = (A == C || A == D) ? B : A;
+      Value *Z = (A == C || (A != D && B == C)) ? D : C;
+      return new ICmpInst(Pred, Y, Z);
+    }
+
+    // Analyze the case when either Op0 or Op1 is a sub instruction.
+    // Op0 = A - B (or A and B are null); Op1 = C - D (or C and D are null).
+    A = 0; B = 0; C = 0; D = 0;
+    if (BO0 && BO0->getOpcode() == Instruction::Sub)
+      A = BO0->getOperand(0), B = BO0->getOperand(1);
+    if (BO1 && BO1->getOpcode() == Instruction::Sub)
+      C = BO1->getOperand(0), D = BO1->getOperand(1);
+
+    // icmp (Y-X), (Z-X) -> icmp Y,Z for equalities or if there is no overflow.
+    if (B && D && B == D && NoOp0WrapProblem && NoOp1WrapProblem &&
+        // Try not to increase register pressure.
+        BO0->hasOneUse() && BO1->hasOneUse())
+      return new ICmpInst(Pred, A, C);
+
+    if (BO0 && BO1 && BO0->getOpcode() == BO1->getOpcode() &&
+        BO0->hasOneUse() && BO1->hasOneUse() &&
+        BO0->getOperand(1) == BO1->getOperand(1)) {
+      switch (BO0->getOpcode()) {
+      default: break;
+      case Instruction::Add:
+      case Instruction::Sub:
+      case Instruction::Xor:
+        if (I.isEquality())    // a+x icmp eq/ne b+x --> a icmp b
+          return new ICmpInst(I.getPredicate(), BO0->getOperand(0),
+                              BO1->getOperand(0));
+        // icmp u/s (a ^ signbit), (b ^ signbit) --> icmp s/u a, b
+        if (ConstantInt *CI = dyn_cast<ConstantInt>(BO0->getOperand(1))) {
+          if (CI->getValue().isSignBit()) {
+            ICmpInst::Predicate Pred = I.isSigned()
+                                           ? I.getUnsignedPredicate()
+                                           : I.getSignedPredicate();
+            return new ICmpInst(Pred, BO0->getOperand(0),
+                                BO1->getOperand(0));
           }
-          break;
-        case Instruction::Mul:
-          if (!I.isEquality())
-            break;
-
-          if (ConstantInt *CI = dyn_cast<ConstantInt>(Op0I->getOperand(1))) {
-            // a * Cst icmp eq/ne b * Cst --> a & Mask icmp b & Mask
-            // Mask = -1 >> count-trailing-zeros(Cst).
-            if (!CI->isZero() && !CI->isOne()) {
-              const APInt &AP = CI->getValue();
-              ConstantInt *Mask = ConstantInt::get(I.getContext(), 
-                                      APInt::getLowBitsSet(AP.getBitWidth(),
-                                                           AP.getBitWidth() -
-                                                      AP.countTrailingZeros()));
-              Value *And1 = Builder->CreateAnd(Op0I->getOperand(0), Mask);
-              Value *And2 = Builder->CreateAnd(Op1I->getOperand(0), Mask);
-              return new ICmpInst(I.getPredicate(), And1, And2);
-            }
+          
+          if (CI->getValue().isMaxSignedValue()) {
+            ICmpInst::Predicate Pred = I.isSigned()
+                                           ? I.getUnsignedPredicate()
+                                           : I.getSignedPredicate();
+            Pred = I.getSwappedPredicate(Pred);
+            return new ICmpInst(Pred, BO0->getOperand(0),
+                                BO1->getOperand(0));
           }
+        }
+        break;
+      case Instruction::Mul:
+        if (!I.isEquality())
           break;
+
+        if (ConstantInt *CI = dyn_cast<ConstantInt>(BO0->getOperand(1))) {
+          // a * Cst icmp eq/ne b * Cst --> a & Mask icmp b & Mask
+          // Mask = -1 >> count-trailing-zeros(Cst).
+          if (!CI->isZero() && !CI->isOne()) {
+            const APInt &AP = CI->getValue();
+            ConstantInt *Mask = ConstantInt::get(I.getContext(), 
+                                    APInt::getLowBitsSet(AP.getBitWidth(),
+                                                         AP.getBitWidth() -
+                                                    AP.countTrailingZeros()));
+            Value *And1 = Builder->CreateAnd(BO0->getOperand(0), Mask);
+            Value *And2 = Builder->CreateAnd(BO1->getOperand(0), Mask);
+            return new ICmpInst(I.getPredicate(), And1, And2);
+          }
         }
+        break;
       }
     }
   }
@@ -2400,18 +2425,6 @@
       return new ICmpInst(I.getPredicate(), B,
                           Constant::getNullValue(B->getType()));
 
-    // (A+B) == A  ->  B == 0
-    if (match(Op0, m_Add(m_Specific(Op1), m_Value(B))) ||
-        match(Op0, m_Add(m_Value(B), m_Specific(Op1))))
-      return new ICmpInst(I.getPredicate(), B,
-                          Constant::getNullValue(B->getType()));
-
-    // A == (A+B)  ->  B == 0
-    if (match(Op1, m_Add(m_Specific(Op0), m_Value(B))) ||
-        match(Op1, m_Add(m_Value(B), m_Specific(Op0))))
-      return new ICmpInst(I.getPredicate(), B,
-                          Constant::getNullValue(B->getType()));
-
     // (X&Z) == (Y&Z) -> (X^Y) & Z == 0
     if (Op0->hasOneUse() && Op1->hasOneUse() &&
         match(Op0, m_And(m_Value(A), m_Value(B))) && 

Modified: llvm/trunk/test/Transforms/InstCombine/icmp.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/InstCombine/icmp.ll?rev=125734&r1=125733&r2=125734&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/InstCombine/icmp.ll (original)
+++ llvm/trunk/test/Transforms/InstCombine/icmp.ll Thu Feb 17 01:46:37 2011
@@ -234,3 +234,22 @@
   ret i1 %cmp
 }
 
+; CHECK: @test25
+; CHECK: %c = icmp sgt i32 %x, %y
+; CHECK: ret i1 %c
+define i1 @test25(i32 %x, i32 %y, i32 %z) {
+  %lhs = add nsw i32 %x, %z
+  %rhs = add nsw i32 %y, %z
+  %c = icmp sgt i32 %lhs, %rhs
+  ret i1 %c
+}
+
+; CHECK: @test26
+; CHECK: %c = icmp sgt i32 %x, %y
+; CHECK: ret i1 %c
+define i1 @test26(i32 %x, i32 %y, i32 %z) {
+  %lhs = sub nsw i32 %x, %z
+  %rhs = sub nsw i32 %y, %z
+  %c = icmp sgt i32 %lhs, %rhs
+  ret i1 %c
+}





More information about the llvm-commits mailing list