[llvm] 87b2c76 - [Instcombine] fold logic ops to select

via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 4 20:04:43 PST 2023


Author: chenglin.bi
Date: 2023-01-05T12:04:35+08:00
New Revision: 87b2c760d0183246c27f9e1c34e1e1120e03449b

URL: https://github.com/llvm/llvm-project/commit/87b2c760d0183246c27f9e1c34e1e1120e03449b
DIFF: https://github.com/llvm/llvm-project/commit/87b2c760d0183246c27f9e1c34e1e1120e03449b.diff

LOG: [Instcombine] fold logic ops to select

(C & X) | ~(C | Y) -> C ? X : ~Y

https://alive2.llvm.org/ce/z/4yLh_i

Reviewed By: spatel

Differential Revision: https://reviews.llvm.org/D139080

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
    llvm/lib/Transforms/InstCombine/InstCombineInternal.h
    llvm/test/Transforms/InstCombine/logical-select.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 48d57e39427d0..014ad4e7f4f02 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2600,7 +2600,9 @@ static bool areInverseVectorBitmasks(Constant *C1, Constant *C2) {
 /// We have an expression of the form (A & C) | (B & D). If A is a scalar or
 /// vector composed of all-zeros or all-ones values and is the bitwise 'not' of
 /// B, it can be used as the condition operand of a select instruction.
-Value *InstCombinerImpl::getSelectCondition(Value *A, Value *B) {
+/// When ABIsTheSame is set, A and B must be the same value (used to match (A & C) | ~(A | D)).
+Value *InstCombinerImpl::getSelectCondition(Value *A, Value *B,
+                                            bool ABIsTheSame) {
   // We may have peeked through bitcasts in the caller.
   // Exit immediately if we don't have (vector) integer types.
   Type *Ty = A->getType();
@@ -2608,7 +2610,7 @@ Value *InstCombinerImpl::getSelectCondition(Value *A, Value *B) {
     return nullptr;
 
   // If A is the 'not' operand of B and has enough signbits, we have our answer.
-  if (match(B, m_Not(m_Specific(A)))) {
+  if (ABIsTheSame ? (A == B) : match(B, m_Not(m_Specific(A)))) {
     // If these are scalars or vectors of i1, A can be used directly.
     if (Ty->isIntOrIntVectorTy(1))
       return A;
@@ -2628,6 +2630,10 @@ Value *InstCombinerImpl::getSelectCondition(Value *A, Value *B) {
     return nullptr;
   }
 
+  // TODO: add support for the sext and constant cases.
+  if (ABIsTheSame)
+    return nullptr;
+
   // If both operands are constants, see if the constants are inverse bitmasks.
   Constant *AConst, *BConst;
   if (match(A, m_Constant(AConst)) && match(B, m_Constant(BConst)))
@@ -2676,14 +2682,17 @@ Value *InstCombinerImpl::getSelectCondition(Value *A, Value *B) {
 
 /// We have an expression of the form (A & C) | (B & D). Try to simplify this
 /// to "A' ? C : D", where A' is a boolean or vector of booleans.
+/// When InvertFalseVal is set to true, we try to match the pattern
+/// where we have peeked through a 'not' op and A and B are the same:
+/// (A & C) | ~(A | D) --> (A & C) | (~A & ~D) --> A' ? C : ~D
 Value *InstCombinerImpl::matchSelectFromAndOr(Value *A, Value *C, Value *B,
-                                              Value *D) {
+                                              Value *D, bool InvertFalseVal) {
   // The potential condition of the select may be bitcasted. In that case, look
   // through its bitcast and the corresponding bitcast of the 'not' condition.
   Type *OrigType = A->getType();
   A = peekThroughBitcast(A, true);
   B = peekThroughBitcast(B, true);
-  if (Value *Cond = getSelectCondition(A, B)) {
+  if (Value *Cond = getSelectCondition(A, B, InvertFalseVal)) {
     // ((bc Cond) & C) | ((bc ~Cond) & D) --> bc (select Cond, (bc C), (bc D))
     // If this is a vector, we may need to cast to match the condition's length.
     // The bitcasts will either all exist or all not exist. The builder will
@@ -2699,6 +2708,8 @@ Value *InstCombinerImpl::matchSelectFromAndOr(Value *A, Value *C, Value *B,
       SelTy = VectorType::get(EltTy, VecTy->getElementCount());
     }
     Value *BitcastC = Builder.CreateBitCast(C, SelTy);
+    if (InvertFalseVal)
+      D = Builder.CreateNot(D);
     Value *BitcastD = Builder.CreateBitCast(D, SelTy);
     Value *Select = Builder.CreateSelect(Cond, BitcastC, BitcastD);
     return Builder.CreateBitCast(Select, OrigType);
@@ -3087,6 +3098,20 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
     }
   }
 
+  if (match(Op0, m_And(m_Value(A), m_Value(C))) &&
+      match(Op1, m_Not(m_Or(m_Value(B), m_Value(D)))) &&
+      (Op0->hasOneUse() || Op1->hasOneUse())) {
+    // (Cond & C) | ~(Cond | D) -> Cond ? C : ~D
+    if (Value *V = matchSelectFromAndOr(A, C, B, D, true))
+      return replaceInstUsesWith(I, V);
+    if (Value *V = matchSelectFromAndOr(A, C, D, B, true))
+      return replaceInstUsesWith(I, V);
+    if (Value *V = matchSelectFromAndOr(C, A, B, D, true))
+      return replaceInstUsesWith(I, V);
+    if (Value *V = matchSelectFromAndOr(C, A, D, B, true))
+      return replaceInstUsesWith(I, V);
+  }
+
   // (A ^ B) | ((B ^ C) ^ A) -> (A ^ B) | C
   if (match(Op0, m_Xor(m_Value(A), m_Value(B))))
     if (match(Op1, m_Xor(m_Xor(m_Specific(B), m_Value(C)), m_Specific(A))))

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index bfbc31e10a80a..e99c8a0bfa3b7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -371,8 +371,9 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   Value *foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS,
                                        Instruction *CxtI, bool IsAnd,
                                        bool IsLogical = false);
-  Value *matchSelectFromAndOr(Value *A, Value *B, Value *C, Value *D);
-  Value *getSelectCondition(Value *A, Value *B);
+  Value *matchSelectFromAndOr(Value *A, Value *B, Value *C, Value *D,
+                              bool InvertFalseVal = false);
+  Value *getSelectCondition(Value *A, Value *B, bool ABIsTheSame);
 
   Instruction *foldExtractOfOverflowIntrinsic(ExtractValueInst &EV);
   Instruction *foldIntrinsicWithOverflowCommon(IntrinsicInst *II);

diff  --git a/llvm/test/Transforms/InstCombine/logical-select.ll b/llvm/test/Transforms/InstCombine/logical-select.ll
index 6707149593ed7..5537fdcca53ac 100644
--- a/llvm/test/Transforms/InstCombine/logical-select.ll
+++ b/llvm/test/Transforms/InstCombine/logical-select.ll
@@ -991,10 +991,8 @@ define <2 x i1> @xor_commute3(<2 x i1> %x, <2 x i1> %y) {
 
 define i1 @not_d_bools_commute00(i1 %c, i1 %x, i1 %y) {
 ; CHECK-LABEL: @not_d_bools_commute00(
-; CHECK-NEXT:    [[Y_C:%.*]] = or i1 [[C:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[AND2:%.*]] = xor i1 [[Y_C]], true
-; CHECK-NEXT:    [[AND1:%.*]] = and i1 [[C]], [[X:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = or i1 [[AND1]], [[AND2]]
+; CHECK-NEXT:    [[TMP1:%.*]] = xor i1 [[Y:%.*]], true
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[C:%.*]], i1 [[X:%.*]], i1 [[TMP1]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %y_c = or i1 %c, %y
@@ -1006,10 +1004,8 @@ define i1 @not_d_bools_commute00(i1 %c, i1 %x, i1 %y) {
 
 define i1 @not_d_bools_commute01(i1 %c, i1 %x, i1 %y) {
 ; CHECK-LABEL: @not_d_bools_commute01(
-; CHECK-NEXT:    [[Y_C:%.*]] = or i1 [[Y:%.*]], [[C:%.*]]
-; CHECK-NEXT:    [[AND2:%.*]] = xor i1 [[Y_C]], true
-; CHECK-NEXT:    [[AND1:%.*]] = and i1 [[C]], [[X:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = or i1 [[AND1]], [[AND2]]
+; CHECK-NEXT:    [[TMP1:%.*]] = xor i1 [[Y:%.*]], true
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[C:%.*]], i1 [[X:%.*]], i1 [[TMP1]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %y_c = or i1 %y, %c
@@ -1021,10 +1017,8 @@ define i1 @not_d_bools_commute01(i1 %c, i1 %x, i1 %y) {
 
 define i1 @not_d_bools_commute10(i1 %c, i1 %x, i1 %y) {
 ; CHECK-LABEL: @not_d_bools_commute10(
-; CHECK-NEXT:    [[Y_C:%.*]] = or i1 [[C:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[AND2:%.*]] = xor i1 [[Y_C]], true
-; CHECK-NEXT:    [[AND1:%.*]] = and i1 [[X:%.*]], [[C]]
-; CHECK-NEXT:    [[R:%.*]] = or i1 [[AND1]], [[AND2]]
+; CHECK-NEXT:    [[TMP1:%.*]] = xor i1 [[Y:%.*]], true
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[C:%.*]], i1 [[X:%.*]], i1 [[TMP1]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %y_c = or i1 %c, %y
@@ -1036,10 +1030,8 @@ define i1 @not_d_bools_commute10(i1 %c, i1 %x, i1 %y) {
 
 define i1 @not_d_bools_commute11(i1 %c, i1 %x, i1 %y) {
 ; CHECK-LABEL: @not_d_bools_commute11(
-; CHECK-NEXT:    [[Y_C:%.*]] = or i1 [[Y:%.*]], [[C:%.*]]
-; CHECK-NEXT:    [[AND2:%.*]] = xor i1 [[Y_C]], true
-; CHECK-NEXT:    [[AND1:%.*]] = and i1 [[X:%.*]], [[C]]
-; CHECK-NEXT:    [[R:%.*]] = or i1 [[AND1]], [[AND2]]
+; CHECK-NEXT:    [[TMP1:%.*]] = xor i1 [[Y:%.*]], true
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[C:%.*]], i1 [[X:%.*]], i1 [[TMP1]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %y_c = or i1 %y, %c
@@ -1051,10 +1043,8 @@ define i1 @not_d_bools_commute11(i1 %c, i1 %x, i1 %y) {
 
 define <2 x i1> @not_d_bools_vector(<2 x i1> %c, <2 x i1> %x, <2 x i1> %y) {
 ; CHECK-LABEL: @not_d_bools_vector(
-; CHECK-NEXT:    [[Y_C:%.*]] = or <2 x i1> [[Y:%.*]], [[C:%.*]]
-; CHECK-NEXT:    [[AND2:%.*]] = xor <2 x i1> [[Y_C]], <i1 true, i1 true>
-; CHECK-NEXT:    [[AND1:%.*]] = and <2 x i1> [[X:%.*]], [[C]]
-; CHECK-NEXT:    [[R:%.*]] = or <2 x i1> [[AND1]], [[AND2]]
+; CHECK-NEXT:    [[TMP1:%.*]] = xor <2 x i1> [[Y:%.*]], <i1 true, i1 true>
+; CHECK-NEXT:    [[R:%.*]] = select <2 x i1> [[C:%.*]], <2 x i1> [[X:%.*]], <2 x i1> [[TMP1]]
 ; CHECK-NEXT:    ret <2 x i1> [[R]]
 ;
   %y_c = or <2 x i1> %y, %c
@@ -1066,10 +1056,8 @@ define <2 x i1> @not_d_bools_vector(<2 x i1> %c, <2 x i1> %x, <2 x i1> %y) {
 
 define <2 x i1> @not_d_bools_vector_poison(<2 x i1> %c, <2 x i1> %x, <2 x i1> %y) {
 ; CHECK-LABEL: @not_d_bools_vector_poison(
-; CHECK-NEXT:    [[Y_C:%.*]] = or <2 x i1> [[Y:%.*]], [[C:%.*]]
-; CHECK-NEXT:    [[AND2:%.*]] = xor <2 x i1> [[Y_C]], <i1 poison, i1 true>
-; CHECK-NEXT:    [[AND1:%.*]] = and <2 x i1> [[X:%.*]], [[C]]
-; CHECK-NEXT:    [[R:%.*]] = or <2 x i1> [[AND1]], [[AND2]]
+; CHECK-NEXT:    [[TMP1:%.*]] = xor <2 x i1> [[Y:%.*]], <i1 true, i1 true>
+; CHECK-NEXT:    [[R:%.*]] = select <2 x i1> [[C:%.*]], <2 x i1> [[X:%.*]], <2 x i1> [[TMP1]]
 ; CHECK-NEXT:    ret <2 x i1> [[R]]
 ;
   %y_c = or <2 x i1> %y, %c
@@ -1081,11 +1069,9 @@ define <2 x i1> @not_d_bools_vector_poison(<2 x i1> %c, <2 x i1> %x, <2 x i1> %y
 
 define i32 @not_d_allSignBits(i32 %cond, i32 %tval, i32 %fval) {
 ; CHECK-LABEL: @not_d_allSignBits(
-; CHECK-NEXT:    [[BITMASK:%.*]] = ashr i32 [[COND:%.*]], 31
-; CHECK-NEXT:    [[A1:%.*]] = and i32 [[BITMASK]], [[TVAL:%.*]]
-; CHECK-NEXT:    [[OR:%.*]] = or i32 [[BITMASK]], [[FVAL:%.*]]
-; CHECK-NEXT:    [[A2:%.*]] = xor i32 [[OR]], -1
-; CHECK-NEXT:    [[SEL:%.*]] = or i32 [[A1]], [[A2]]
+; CHECK-NEXT:    [[TMP1:%.*]] = xor i32 [[FVAL:%.*]], -1
+; CHECK-NEXT:    [[DOTNOT2:%.*]] = icmp slt i32 [[COND:%.*]], 0
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[DOTNOT2]], i32 [[TVAL:%.*]], i32 [[TMP1]]
 ; CHECK-NEXT:    ret i32 [[SEL]]
 ;
   %bitmask = ashr i32 %cond, 31
@@ -1099,9 +1085,9 @@ define i32 @not_d_allSignBits(i32 %cond, i32 %tval, i32 %fval) {
 define i1 @not_d_bools_use2(i1 %c, i1 %x, i1 %y) {
 ; CHECK-LABEL: @not_d_bools_use2(
 ; CHECK-NEXT:    [[Y_C:%.*]] = or i1 [[C:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[AND2:%.*]] = xor i1 [[Y_C]], true
 ; CHECK-NEXT:    [[AND1:%.*]] = and i1 [[C]], [[X:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = or i1 [[AND1]], [[AND2]]
+; CHECK-NEXT:    [[TMP1:%.*]] = xor i1 [[Y]], true
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[C]], i1 [[X]], i1 [[TMP1]]
 ; CHECK-NEXT:    call void @use1(i1 [[AND1]])
 ; CHECK-NEXT:    call void @use1(i1 [[Y_C]])
 ; CHECK-NEXT:    ret i1 [[R]]
@@ -1115,6 +1101,8 @@ define i1 @not_d_bools_use2(i1 %c, i1 %x, i1 %y) {
   ret i1 %r
 }
 
+; negative test: neither op has one use
+
 define i1 @not_d_bools_negative_use2(i1 %c, i1 %x, i1 %y) {
 ; CHECK-LABEL: @not_d_bools_negative_use2(
 ; CHECK-NEXT:    [[Y_C:%.*]] = or i1 [[C:%.*]], [[Y:%.*]]


        


More information about the llvm-commits mailing list