[llvm-commits] [llvm] r58457 - in /llvm/trunk: include/llvm/Support/PatternMatch.h lib/Transforms/Scalar/InstructionCombining.cpp test/Transforms/InstCombine/logical-select.ll

Dan Gohman gohman at apple.com
Thu Oct 30 13:40:11 PDT 2008


Author: djg
Date: Thu Oct 30 15:40:10 2008
New Revision: 58457

URL: http://llvm.org/viewvc/llvm-project?rev=58457&view=rev
Log:
Canonicalize sext(i1) to i1?-1:0, and update various instcombine
optimizations accordingly.

Modified:
    llvm/trunk/include/llvm/Support/PatternMatch.h
    llvm/trunk/lib/Transforms/Scalar/InstructionCombining.cpp
    llvm/trunk/test/Transforms/InstCombine/logical-select.ll

Modified: llvm/trunk/include/llvm/Support/PatternMatch.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Support/PatternMatch.h?rev=58457&r1=58456&r2=58457&view=diff

==============================================================================
--- llvm/trunk/include/llvm/Support/PatternMatch.h (original)
+++ llvm/trunk/include/llvm/Support/PatternMatch.h Thu Oct 30 15:40:10 2008
@@ -51,6 +51,22 @@
 /// m_ConstantInt() - Match an arbitrary ConstantInt and ignore it.
 inline leaf_ty<ConstantInt> m_ConstantInt() { return leaf_ty<ConstantInt>(); }
 
+/// constantint_ty - Matcher for a ConstantInt whose value, read as a signed
+/// integer, equals Val.  Constants of any bit width may be tested.
+struct constantint_ty {
+  int64_t Val;
+  explicit constantint_ty(int64_t val) : Val(val) {}
+
+  template<typename ITy>
+  bool match(ITy *V) {
+    const ConstantInt *CI = dyn_cast<ConstantInt>(V);
+    if (!CI)
+      return false;
+    // getSExtValue() is only valid for values that fit in 64 signed bits.
+    // A wider value can never equal an int64_t, so reject it before asking.
+    const APInt &CIV = CI->getValue();
+    return CIV.getMinSignedBits() <= 64 && CIV.getSExtValue() == Val;
+  }
+};
+
+/// m_ConstantInt(int64_t) - Match a ConstantInt whose sign-extended value
+/// equals Val (so m_ConstantInt(-1) accepts an all-ones constant of any
+/// width), and ignore it.
+inline constantint_ty m_ConstantInt(int64_t Val) {
+  constantint_ty Matcher(Val);
+  return Matcher;
+}
+
 struct zero_ty {
   template<typename ITy>
   bool match(ITy *V) {
@@ -322,6 +338,36 @@
 }
 
 //===----------------------------------------------------------------------===//
+// Matchers for SelectInst classes
+//
+
+/// SelectClass_match - Matches a SelectInst whose condition, true value and
+/// false value are accepted by the sub-matchers C, L and R respectively.
+template<typename Cond_t, typename LHS_t, typename RHS_t>
+struct SelectClass_match {
+  Cond_t C;
+  LHS_t L;
+  RHS_t R;
+
+  SelectClass_match(const Cond_t &Cond, const LHS_t &LHS,
+                    const RHS_t &RHS)
+    : C(Cond), L(LHS), R(RHS) {}
+
+  template<typename OpTy>
+  bool match(OpTy *V) {
+    SelectInst *I = dyn_cast<SelectInst>(V);
+    if (!I)
+      return false;
+    // Operand 0 is the condition, 1 the true value, 2 the false value.
+    return C.match(I->getOperand(0)) &&
+           L.match(I->getOperand(1)) &&
+           R.match(I->getOperand(2));
+  }
+};
+
+/// m_Select - Match a select instruction whose condition, true value and
+/// false value are accepted by C, L and R respectively.  The declared return
+/// type lists the matcher types in the same order as the object built below,
+/// so distinct LHS/RHS matcher types are supported.
+template<typename Cond, typename LHS, typename RHS>
+inline SelectClass_match<Cond, LHS, RHS>
+m_Select(const Cond &C, const LHS &L, const RHS &R) {
+  return SelectClass_match<Cond, LHS, RHS>(C, L, R);
+}
+
+//===----------------------------------------------------------------------===//
 // Matchers for CastInst classes
 //
 

Modified: llvm/trunk/lib/Transforms/Scalar/InstructionCombining.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Scalar/InstructionCombining.cpp?rev=58457&r1=58456&r2=58457&view=diff

==============================================================================
--- llvm/trunk/lib/Transforms/Scalar/InstructionCombining.cpp (original)
+++ llvm/trunk/lib/Transforms/Scalar/InstructionCombining.cpp Thu Oct 30 15:40:10 2008
@@ -2012,6 +2012,14 @@
                                  KnownZero, KnownOne))
           return &I;
       }
+
+      // zext(i1) - 1  ->  select i1, 0, -1
+      if (ZExtInst *ZI = dyn_cast<ZExtInst>(LHS))
+        if (CI->isAllOnesValue() &&
+            ZI->getOperand(0)->getType() == Type::Int1Ty)
+          return SelectInst::Create(ZI->getOperand(0),
+                                    Constant::getNullValue(I.getType()),
+                                    ConstantInt::getAllOnesValue(I.getType()));
     }
 
     if (isa<PHINode>(LHS))
@@ -4338,24 +4346,55 @@
       }
     }
 
-    // (A & sext(C0)) | (B & ~sext(C0) ->  C0 ? A : B
-    if (isa<SExtInst>(C) &&
-        cast<User>(C)->getOperand(0)->getType() == Type::Int1Ty) {
+    // (A & (C0?-1:0)) | (B & ~(C0?-1:0)) ->  C0 ? A : B, and commuted variants
+    if (match(A, m_Select(m_Value(), m_ConstantInt(-1), m_ConstantInt(0)))) {
+      if (match(D, m_Not(m_Value(A))))
+        return SelectInst::Create(cast<User>(A)->getOperand(0), C, B);
+      if (match(B, m_Not(m_Value(A))))
+        return SelectInst::Create(cast<User>(A)->getOperand(0), C, D);
+    }
+    if (match(B, m_Select(m_Value(), m_ConstantInt(-1), m_ConstantInt(0)))) {
+      if (match(C, m_Not(m_Value(B))))
+        return SelectInst::Create(cast<User>(B)->getOperand(0), A, D);
+      if (match(A, m_Not(m_Value(B))))
+        return SelectInst::Create(cast<User>(B)->getOperand(0), C, D);
+    }
+    if (match(C, m_Select(m_Value(), m_ConstantInt(-1), m_ConstantInt(0)))) {
       if (match(D, m_Not(m_Value(C))))
         return SelectInst::Create(cast<User>(C)->getOperand(0), A, B);
-      // And commutes, try both ways.
       if (match(B, m_Not(m_Value(C))))
         return SelectInst::Create(cast<User>(C)->getOperand(0), A, D);
     }
-    // Or commutes, try both ways.
-    if (isa<SExtInst>(D) &&
-        cast<User>(D)->getOperand(0)->getType() == Type::Int1Ty) {
+    if (match(D, m_Select(m_Value(), m_ConstantInt(-1), m_ConstantInt(0)))) {
       if (match(C, m_Not(m_Value(D))))
         return SelectInst::Create(cast<User>(D)->getOperand(0), A, B);
-      // And commutes, try both ways.
       if (match(A, m_Not(m_Value(D))))
         return SelectInst::Create(cast<User>(D)->getOperand(0), C, B);
     }
+    if (match(A, m_Select(m_Value(), m_ConstantInt(0), m_ConstantInt(-1)))) {
+      if (match(D, m_Not(m_Value(A))))
+        return SelectInst::Create(cast<User>(A)->getOperand(0), B, C);
+      if (match(B, m_Not(m_Value(A))))
+        return SelectInst::Create(cast<User>(A)->getOperand(0), D, C);
+    }
+    if (match(B, m_Select(m_Value(), m_ConstantInt(0), m_ConstantInt(-1)))) {
+      if (match(C, m_Not(m_Value(B))))
+        return SelectInst::Create(cast<User>(B)->getOperand(0), D, A);
+      if (match(A, m_Not(m_Value(B))))
+        return SelectInst::Create(cast<User>(B)->getOperand(0), D, C);
+    }
+    if (match(C, m_Select(m_Value(), m_ConstantInt(0), m_ConstantInt(-1)))) {
+      if (match(D, m_Not(m_Value(C))))
+        return SelectInst::Create(cast<User>(C)->getOperand(0), B, A);
+      if (match(B, m_Not(m_Value(C))))
+        return SelectInst::Create(cast<User>(C)->getOperand(0), D, A);
+    }
+    if (match(D, m_Select(m_Value(), m_ConstantInt(0), m_ConstantInt(-1)))) {
+      if (match(C, m_Not(m_Value(D))))
+        return SelectInst::Create(cast<User>(D)->getOperand(0), B, A);
+      if (match(A, m_Not(m_Value(D))))
+        return SelectInst::Create(cast<User>(D)->getOperand(0), B, C);
+    }
   }
   
   // (X >> Z) | (Y >> Z)  -> (X|Y) >> Z  for all shifts.
@@ -7965,37 +8004,11 @@
   
   Value *Src = CI.getOperand(0);
   
-  // sext (x <s 0) -> ashr x, 31   -> all ones if signed
-  // sext (x >s -1) -> ashr x, 31  -> all ones if not signed
-  if (ICmpInst *ICI = dyn_cast<ICmpInst>(Src)) {
-    // If we are just checking for a icmp eq of a single bit and zext'ing it
-    // to an integer, then shift the bit to the appropriate place and then
-    // cast to integer to avoid the comparison.
-    if (ConstantInt *Op1C = dyn_cast<ConstantInt>(ICI->getOperand(1))) {
-      const APInt &Op1CV = Op1C->getValue();
-      
-      // sext (x <s  0) to i32 --> x>>s31      true if signbit set.
-      // sext (x >s -1) to i32 --> (x>>s31)^-1  true if signbit clear.
-      if ((ICI->getPredicate() == ICmpInst::ICMP_SLT && Op1CV == 0) ||
-          (ICI->getPredicate() == ICmpInst::ICMP_SGT &&Op1CV.isAllOnesValue())){
-        Value *In = ICI->getOperand(0);
-        Value *Sh = ConstantInt::get(In->getType(),
-                                     In->getType()->getPrimitiveSizeInBits()-1);
-        In = InsertNewInstBefore(BinaryOperator::CreateAShr(In, Sh,
-                                                        In->getName()+".lobit"),
-                                 CI);
-        if (In->getType() != CI.getType())
-          In = CastInst::CreateIntegerCast(In, CI.getType(),
-                                           true/*SExt*/, "tmp", &CI);
-        
-        if (ICI->getPredicate() == ICmpInst::ICMP_SGT)
-          In = InsertNewInstBefore(BinaryOperator::CreateNot(In,
-                                     In->getName()+".not"), CI);
-        
-        return ReplaceInstUsesWith(CI, In);
-      }
-    }
-  }
+  // Canonicalize sign-extend from i1 to a select.
+  if (Src->getType() == Type::Int1Ty)
+    return SelectInst::Create(Src,
+                              ConstantInt::getAllOnesValue(CI.getType()),
+                              Constant::getNullValue(CI.getType()));
 
   // See if the value being truncated is already sign extended.  If so, just
   // eliminate the trunc/sext pair.
@@ -8468,7 +8481,7 @@
   // can be adjusted to fit the min/max idiom. We may edit ICI in
   // place here, so make sure the select is the only user.
   if (ICI->hasOneUse())
-    if (ConstantInt *CI = dyn_cast<ConstantInt>(CmpRHS))
+    if (ConstantInt *CI = dyn_cast<ConstantInt>(CmpRHS)) {
       switch (Pred) {
       default: break;
       case ICmpInst::ICMP_ULT:
@@ -8513,6 +8526,44 @@
       }
       }
 
+      // (x <s 0) ? -1 : 0 -> ashr x, 31   -> all ones if signed
+      // (x >s -1) ? -1 : 0 -> ashr x, 31  -> all ones if not signed
+      CmpInst::Predicate Pred = ICI->getPredicate();
+      if (match(TrueVal, m_ConstantInt(0)) &&
+          match(FalseVal, m_ConstantInt(-1)))
+        Pred = CmpInst::getInversePredicate(Pred);
+      else if (!match(TrueVal, m_ConstantInt(-1)) ||
+               !match(FalseVal, m_ConstantInt(0)))
+        Pred = CmpInst::BAD_ICMP_PREDICATE;
+      if (Pred != CmpInst::BAD_ICMP_PREDICATE) {
+        // If we are just checking for a icmp eq of a single bit and zext'ing it
+        // to an integer, then shift the bit to the appropriate place and then
+        // cast to integer to avoid the comparison.
+        const APInt &Op1CV = CI->getValue();
+    
+        // sext (x <s  0) to i32 --> x>>s31      true if signbit set.
+        // sext (x >s -1) to i32 --> (x>>s31)^-1  true if signbit clear.
+        if ((Pred == ICmpInst::ICMP_SLT && Op1CV == 0) ||
+            (Pred == ICmpInst::ICMP_SGT &&Op1CV.isAllOnesValue())) {
+          Value *In = ICI->getOperand(0);
+          Value *Sh = ConstantInt::get(In->getType(),
+                                       In->getType()->getPrimitiveSizeInBits()-1);
+          In = InsertNewInstBefore(BinaryOperator::CreateAShr(In, Sh,
+                                                          In->getName()+".lobit"),
+                                   *ICI);
+          if (In->getType() != CI->getType())
+            In = CastInst::CreateIntegerCast(In, CI->getType(),
+                                             true/*SExt*/, "tmp", ICI);
+    
+          if (Pred == ICmpInst::ICMP_SGT)
+            In = InsertNewInstBefore(BinaryOperator::CreateNot(In,
+                                       In->getName()+".not"), *ICI);
+    
+          return ReplaceInstUsesWith(SI, In);
+        }
+      }
+    }
+
   if (CmpLHS == TrueVal && CmpRHS == FalseVal) {
     // Transform (X == Y) ? X : Y  -> Y
     if (Pred == ICmpInst::ICMP_EQ)

Modified: llvm/trunk/test/Transforms/InstCombine/logical-select.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/InstCombine/logical-select.ll?rev=58457&r1=58456&r2=58457&view=diff

==============================================================================
--- llvm/trunk/test/Transforms/InstCombine/logical-select.ll (original)
+++ llvm/trunk/test/Transforms/InstCombine/logical-select.ll Thu Oct 30 15:40:10 2008
@@ -1,4 +1,7 @@
-; RUN: llvm-as < %s | opt -instcombine | llvm-dis | grep select | count 2
+; RUN: llvm-as < %s | opt -instcombine | llvm-dis > %t
+; RUN: grep select %t | count 4
+; RUN: not grep and %t
+; RUN: not grep or %t
 
 define i32 @foo(i32 %a, i32 %b, i32 %c, i32 %d) nounwind {
   %e = icmp slt i32 %a, %b
@@ -18,3 +21,24 @@
   %j = or i32 %i, %g
   ret i32 %j
 }
+; (mask & c) | (~mask & d), with mask = (a <s b) ? -1 : 0, should fold to a
+; single select.
+define i32 @goo(i32 %a, i32 %b, i32 %c, i32 %d) nounwind {
+entry:
+  %lt = icmp slt i32 %a, %b
+  %mask = select i1 %lt, i32 -1, i32 0
+  %lhs = and i32 %mask, %c
+  %inv = xor i32 %mask, -1
+  %rhs = and i32 %inv, %d
+  %res = or i32 %lhs, %rhs
+  ret i32 %res
+}
+
+; NOTE(review): this body is identical to @goo apart from value names, so it
+; exercises no additional pattern; presumably a commuted operand order was
+; intended here -- confirm.
+define i32 @par(i32 %a, i32 %b, i32 %c, i32 %d) nounwind {
+entry:
+  %cond = icmp slt i32 %a, %b
+  %m = select i1 %cond, i32 -1, i32 0
+  %x = and i32 %m, %c
+  %nm = xor i32 %m, -1
+  %y = and i32 %nm, %d
+  %r = or i32 %x, %y
+  ret i32 %r
+}





More information about the llvm-commits mailing list