[llvm] r305548 - [InstCombine] Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2) if K1 and K2 are a 1-bit mask

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 15 22:10:37 PDT 2017


Author: ctopper
Date: Fri Jun 16 00:10:37 2017
New Revision: 305548

URL: http://llvm.org/viewvc/llvm-project?rev=305548&view=rev
Log:
[InstCombine] Fold (!iszero(A & K1) & !iszero(A & K2)) ->  (A & (K1 | K2)) == (K1 | K2) if K1 and K2 are a 1-bit mask

Summary: This is the De Morgan dual of the fold we already perform for the OR of iszero checks.

Reviewers: spatel

Reviewed By: spatel

Subscribers: llvm-commits

Differential Revision: https://reviews.llvm.org/D34244

Modified:
    llvm/trunk/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
    llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h
    llvm/trunk/test/Transforms/InstCombine/onehot_merge.ll

Modified: llvm/trunk/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp?rev=305548&r1=305547&r2=305548&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp (original)
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp Fri Jun 16 00:10:37 2017
@@ -763,8 +763,54 @@ foldAndOrOfEqualityCmpsWithConstants(ICm
   return nullptr;
 }
 
+// Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2)
+// Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2)
+Value *InstCombiner::foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS,
+                                                   bool JoinedByAnd,
+                                                   Instruction &CxtI) {
+  ICmpInst::Predicate Pred = LHS->getPredicate();
+  if (Pred != RHS->getPredicate())
+    return nullptr;
+  if (JoinedByAnd && Pred != ICmpInst::ICMP_NE)
+    return nullptr;
+  if (!JoinedByAnd && Pred != ICmpInst::ICMP_EQ)
+    return nullptr;
+
+  // TODO support vector splats
+  ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1));
+  ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS->getOperand(1));
+  if (!LHSC || !RHSC || !LHSC->isZero() || !RHSC->isZero())
+    return nullptr;
+
+  Value *A, *B, *C, *D;
+  if (match(LHS->getOperand(0), m_And(m_Value(A), m_Value(B))) &&
+      match(RHS->getOperand(0), m_And(m_Value(C), m_Value(D)))) {
+    if (A == D || B == D)
+      std::swap(C, D);
+    if (B == C)
+      std::swap(A, B);
+
+    if (A == C &&
+        isKnownToBeAPowerOfTwo(B, false, 0, &CxtI) &&
+        isKnownToBeAPowerOfTwo(D, false, 0, &CxtI)) {
+      Value *Mask = Builder->CreateOr(B, D);
+      Value *Masked = Builder->CreateAnd(A, Mask);
+      auto NewPred = JoinedByAnd ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
+      return Builder->CreateICmp(NewPred, Masked, Mask);
+    }
+  }
+
+  return nullptr;
+}
+
 /// Fold (icmp)&(icmp) if possible.
-Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) {
+Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS,
+                                    Instruction &CxtI) {
+  // Fold (!iszero(A & K1) & !iszero(A & K2)) ->  (A & (K1 | K2)) == (K1 | K2)
+  // if K1 and K2 are a one-bit mask.
+  if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, true, CxtI))
+    return V;
+
   ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
 
   // (icmp1 A, B) & (icmp2 A, B) --> (icmp3 A, B)
@@ -1127,7 +1173,7 @@ Instruction *InstCombiner::foldCastedBit
   ICmpInst *ICmp0 = dyn_cast<ICmpInst>(Cast0Src);
   ICmpInst *ICmp1 = dyn_cast<ICmpInst>(Cast1Src);
   if (ICmp0 && ICmp1) {
-    Value *Res = LogicOpc == Instruction::And ? foldAndOfICmps(ICmp0, ICmp1)
+    Value *Res = LogicOpc == Instruction::And ? foldAndOfICmps(ICmp0, ICmp1, I)
                                               : foldOrOfICmps(ICmp0, ICmp1, I);
     if (Res)
       return CastInst::Create(CastOpcode, Res, DestTy);
@@ -1426,7 +1472,7 @@ Instruction *InstCombiner::visitAnd(Bina
     ICmpInst *LHS = dyn_cast<ICmpInst>(Op0);
     ICmpInst *RHS = dyn_cast<ICmpInst>(Op1);
     if (LHS && RHS)
-      if (Value *Res = foldAndOfICmps(LHS, RHS))
+      if (Value *Res = foldAndOfICmps(LHS, RHS, I))
         return replaceInstUsesWith(I, Res);
 
     // TODO: Make this recursive; it's a little tricky because an arbitrary
@@ -1434,18 +1480,18 @@ Instruction *InstCombiner::visitAnd(Bina
     Value *X, *Y;
     if (LHS && match(Op1, m_OneUse(m_And(m_Value(X), m_Value(Y))))) {
       if (auto *Cmp = dyn_cast<ICmpInst>(X))
-        if (Value *Res = foldAndOfICmps(LHS, Cmp))
+        if (Value *Res = foldAndOfICmps(LHS, Cmp, I))
           return replaceInstUsesWith(I, Builder->CreateAnd(Res, Y));
       if (auto *Cmp = dyn_cast<ICmpInst>(Y))
-        if (Value *Res = foldAndOfICmps(LHS, Cmp))
+        if (Value *Res = foldAndOfICmps(LHS, Cmp, I))
           return replaceInstUsesWith(I, Builder->CreateAnd(Res, X));
     }
     if (RHS && match(Op0, m_OneUse(m_And(m_Value(X), m_Value(Y))))) {
       if (auto *Cmp = dyn_cast<ICmpInst>(X))
-        if (Value *Res = foldAndOfICmps(Cmp, RHS))
+        if (Value *Res = foldAndOfICmps(Cmp, RHS, I))
           return replaceInstUsesWith(I, Builder->CreateAnd(Res, Y));
       if (auto *Cmp = dyn_cast<ICmpInst>(Y))
-        if (Value *Res = foldAndOfICmps(Cmp, RHS))
+        if (Value *Res = foldAndOfICmps(Cmp, RHS, I))
           return replaceInstUsesWith(I, Builder->CreateAnd(Res, X));
     }
   }
@@ -1592,34 +1638,15 @@ static Value *matchSelectFromAndOr(Value
 /// Fold (icmp)|(icmp) if possible.
 Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
                                    Instruction &CxtI) {
-  ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
-
   // Fold (iszero(A & K1) | iszero(A & K2)) ->  (A & (K1 | K2)) != (K1 | K2)
   // if K1 and K2 are a one-bit mask.
-  ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1));
-  ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS->getOperand(1));
+  if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, false, CxtI))
+    return V;
 
-  // TODO support vector splats
-  if (LHS->getPredicate() == ICmpInst::ICMP_EQ && LHSC && LHSC->isZero() &&
-      RHS->getPredicate() == ICmpInst::ICMP_EQ && RHSC && RHSC->isZero()) {
+  ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
 
-    Value *A, *B, *C, *D;
-    if (match(LHS->getOperand(0), m_And(m_Value(A), m_Value(B))) &&
-        match(RHS->getOperand(0), m_And(m_Value(C), m_Value(D)))) {
-      if (A == D || B == D)
-        std::swap(C, D);
-      if (B == C)
-        std::swap(A, B);
-
-      if (A == C &&
-          isKnownToBeAPowerOfTwo(B, false, 0, &CxtI) &&
-          isKnownToBeAPowerOfTwo(D, false, 0, &CxtI)) {
-        Value *Mask = Builder->CreateOr(B, D);
-        Value *Masked = Builder->CreateAnd(A, Mask);
-        return Builder->CreateICmp(ICmpInst::ICMP_NE, Masked, Mask);
-      }
-    }
-  }
+  ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1));
+  ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS->getOperand(1));
 
   // Fold (icmp ult/ule (A + C1), C3) | (icmp ult/ule (A + C2), C3)
   //                   -->  (icmp ult/ule ((A & ~(C1 ^ C2)) + max(C1, C2)), C3)

Modified: llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h?rev=305548&r1=305547&r2=305548&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h (original)
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h Fri Jun 16 00:10:37 2017
@@ -447,12 +447,14 @@ private:
   Instruction::CastOps isEliminableCastPair(const CastInst *CI1,
                                             const CastInst *CI2);
 
-  Value *foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS);
+  Value *foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction &CxtI);
   Value *foldAndOfFCmps(FCmpInst *LHS, FCmpInst *RHS);
   Value *foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction &CxtI);
   Value *foldOrOfFCmps(FCmpInst *LHS, FCmpInst *RHS);
   Value *foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS);
 
+  Value *foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS,
+                                       bool JoinedByAnd, Instruction &CxtI);
 public:
   /// \brief Inserts an instruction \p New before instruction \p Old
   ///

Modified: llvm/trunk/test/Transforms/InstCombine/onehot_merge.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/InstCombine/onehot_merge.ll?rev=305548&r1=305547&r2=305548&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/InstCombine/onehot_merge.ll (original)
+++ llvm/trunk/test/Transforms/InstCombine/onehot_merge.ll Fri Jun 16 00:10:37 2017
@@ -73,12 +73,10 @@ define i1 @foo1_or(i32 %k, i32 %c1, i32
 ; CHECK-LABEL: @foo1_or(
 ; CHECK-NEXT:    [[TMP:%.*]] = shl i32 1, [[C1:%.*]]
 ; CHECK-NEXT:    [[TMP4:%.*]] = lshr i32 -2147483648, [[C2:%.*]]
-; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[TMP]], [[K:%.*]]
-; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i32 [[TMP1]], 0
-; CHECK-NEXT:    [[TMP5:%.*]] = and i32 [[TMP4]], [[K]]
-; CHECK-NEXT:    [[TMP6:%.*]] = icmp ne i32 [[TMP5]], 0
-; CHECK-NEXT:    [[OR:%.*]] = and i1 [[TMP2]], [[TMP6]]
-; CHECK-NEXT:    ret i1 [[OR]]
+; CHECK-NEXT:    [[TMP1:%.*]] = or i32 [[TMP]], [[TMP4]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[TMP1]], [[K:%.*]]
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp eq i32 [[TMP2]], [[TMP1]]
+; CHECK-NEXT:    ret i1 [[TMP3]]
 ;
   %tmp = shl i32 1, %c1
   %tmp4 = lshr i32 -2147483648, %c2
@@ -96,12 +94,10 @@ define i1 @foo1_or_commuted(i32 %k, i32
 ; CHECK-NEXT:    [[K2:%.*]] = mul i32 [[K:%.*]], [[K]]
 ; CHECK-NEXT:    [[TMP:%.*]] = shl i32 1, [[C1:%.*]]
 ; CHECK-NEXT:    [[TMP4:%.*]] = lshr i32 -2147483648, [[C2:%.*]]
-; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[K2]], [[TMP]]
-; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i32 [[TMP1]], 0
-; CHECK-NEXT:    [[TMP5:%.*]] = and i32 [[TMP4]], [[K2]]
-; CHECK-NEXT:    [[TMP6:%.*]] = icmp ne i32 [[TMP5]], 0
-; CHECK-NEXT:    [[OR:%.*]] = and i1 [[TMP2]], [[TMP6]]
-; CHECK-NEXT:    ret i1 [[OR]]
+; CHECK-NEXT:    [[TMP1:%.*]] = or i32 [[TMP]], [[TMP4]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[K2]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp eq i32 [[TMP2]], [[TMP1]]
+; CHECK-NEXT:    ret i1 [[TMP3]]
 ;
   %k2 = mul i32 %k, %k ; to trick the complexity sorting
   %tmp = shl i32 1, %c1




More information about the llvm-commits mailing list