[llvm] [InstCombine] fold Select with a predicate consisting of icmp connected by And (PR #76363)

via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 25 08:26:07 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-llvm-analysis

Author: Chia (sun-jacobi)

<details>
<summary>Changes</summary>


This patch closes #76043.

---
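We extended the pre-existing `foldSelectWithBinaryOp` so that it supports the cases below, where `%B` can be any boolean value (the fold also looks through nested `and`/`or` chains, up to the recursion limit):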
```
%A = icmp eq %TV, %FV
%C = and %A, %B
%D = select %C, %TV, %FV
->
%FV
```
or 
```
%A = icmp ne %TV, %FV
%C = or %A, %B
%D = select %C, %TV, %FV
->
%TV
```
The Alive2 proof: https://alive2.llvm.org/ce/z/XLyhE-

--- 
For the updated test cases in `select-and-cmp.ll` and `select-or-cmp.ll`, we also provide an Alive2 proof: https://alive2.llvm.org/ce/z/krhtZy

---
Full diff: https://github.com/llvm/llvm-project/pull/76363.diff


3 Files Affected:

- (modified) llvm/lib/Analysis/InstructionSimplify.cpp (+46-27) 
- (modified) llvm/test/Transforms/InstSimplify/select-and-cmp.ll (+27-24) 
- (modified) llvm/test/Transforms/InstSimplify/select-or-cmp.ll (+28-36) 


``````````diff
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 5beac5547d65e0..ca2fc9ca173c9a 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -79,15 +79,56 @@ static Value *simplifyInstructionWithOperands(Instruction *I,
                                               const SimplifyQuery &SQ,
                                               unsigned MaxRecurse);
 
+/// Simplify `select Cond, TrueVal, FalseVal` to FalseVal (And) or TrueVal (Or)
+/// when Cond is `icmp ExpectedPred TrueVal, FalseVal`, or a tree of BinOpCode
+/// with such an icmp as a leaf (depth bounded by MaxRecurse); else nullptr.
+static Value *simplifySelectWithAndOR(Value *Cond, Value *TrueVal,
+                                      Value *FalseVal,
+                                      CmpInst::Predicate ExpectedPred,
+                                      BinaryOperator::BinaryOps BinOpCode,
+                                      unsigned int MaxRecurse) {
+  assert(((BinOpCode == BinaryOperator::And &&
+           ExpectedPred == ICmpInst::ICMP_EQ) ||
+          (BinOpCode == BinaryOperator::Or &&
+           ExpectedPred == ICmpInst::ICMP_NE)) &&
+         "BinOpCode must be And with ICMP_EQ or Or with ICMP_NE");
+
+  if (!MaxRecurse)
+    return nullptr;
+
+  // If Cond being true (And) / false (Or) implies TrueVal == FalseVal, the
+  // select always yields this operand.
+  Value *Simplified = BinOpCode == BinaryOperator::Or ? TrueVal : FalseVal;
+
+  CmpInst::Predicate Pred;
+  if (match(Cond, m_c_ICmp(Pred, m_Specific(TrueVal), m_Specific(FalseVal))) &&
+      Pred == ExpectedPred)
+    return Simplified;
+
+  // `and X, Y` true implies X true; `or X, Y` false implies X false (same for
+  // Y), so one operand implying TrueVal == FalseVal suffices.
+  Value *X, *Y;
+  if (match(Cond, m_c_BinOp(BinOpCode, m_Value(X), m_Value(Y)))) {
+    auto impliesEquality = [&](Value *V) {
+      return simplifySelectWithAndOR(V, TrueVal, FalseVal, ExpectedPred,
+                                     BinOpCode, MaxRecurse - 1) != nullptr;
+    };
+    if (impliesEquality(X) || impliesEquality(Y))
+      return Simplified;
+  }
+
+  return nullptr;
+}
+
 static Value *foldSelectWithBinaryOp(Value *Cond, Value *TrueVal,
-                                     Value *FalseVal) {
+                                     Value *FalseVal, unsigned int MaxRecurse) {
   BinaryOperator::BinaryOps BinOpCode;
   if (auto *BO = dyn_cast<BinaryOperator>(Cond))
     BinOpCode = BO->getOpcode();
   else
     return nullptr;
 
-  CmpInst::Predicate ExpectedPred, Pred1, Pred2;
+  CmpInst::Predicate ExpectedPred;
   if (BinOpCode == BinaryOperator::Or) {
     ExpectedPred = ICmpInst::ICMP_NE;
   } else if (BinOpCode == BinaryOperator::And) {
@@ -95,30 +136,8 @@ static Value *foldSelectWithBinaryOp(Value *Cond, Value *TrueVal,
   } else
     return nullptr;
 
-  // %A = icmp eq %TV, %FV
-  // %B = icmp eq %X, %Y (and one of these is a select operand)
-  // %C = and %A, %B
-  // %D = select %C, %TV, %FV
-  // -->
-  // %FV
-
-  // %A = icmp ne %TV, %FV
-  // %B = icmp ne %X, %Y (and one of these is a select operand)
-  // %C = or %A, %B
-  // %D = select %C, %TV, %FV
-  // -->
-  // %TV
-  Value *X, *Y;
-  if (!match(Cond, m_c_BinOp(m_c_ICmp(Pred1, m_Specific(TrueVal),
-                                      m_Specific(FalseVal)),
-                             m_ICmp(Pred2, m_Value(X), m_Value(Y)))) ||
-      Pred1 != Pred2 || Pred1 != ExpectedPred)
-    return nullptr;
-
-  if (X == TrueVal || X == FalseVal || Y == TrueVal || Y == FalseVal)
-    return BinOpCode == BinaryOperator::Or ? TrueVal : FalseVal;
-
-  return nullptr;
+  return simplifySelectWithAndOR(Cond, TrueVal, FalseVal, ExpectedPred,
+                                 BinOpCode, MaxRecurse);
 }
 
 /// For a boolean type or a vector of boolean type, return false or a vector
@@ -4906,7 +4925,7 @@ static Value *simplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal,
   if (Value *V = simplifySelectWithFCmp(Cond, TrueVal, FalseVal, Q))
     return V;
 
-  if (Value *V = foldSelectWithBinaryOp(Cond, TrueVal, FalseVal))
+  if (Value *V = foldSelectWithBinaryOp(Cond, TrueVal, FalseVal, MaxRecurse))
     return V;
 
   std::optional<bool> Imp = isImpliedByDomCondition(Cond, Q.CxtI, Q.DL);
diff --git a/llvm/test/Transforms/InstSimplify/select-and-cmp.ll b/llvm/test/Transforms/InstSimplify/select-and-cmp.ll
index 41a4ab96bd62cc..8a48618e217084 100644
--- a/llvm/test/Transforms/InstSimplify/select-and-cmp.ll
+++ b/llvm/test/Transforms/InstSimplify/select-and-cmp.ll
@@ -78,6 +78,28 @@ define i32 @select_and_inv_icmp_alt(i32 %x, i32 %y, i32 %z) {
   ret i32 %D
 }
 
+define i32 @select_and_icmp_ne(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: @select_and_icmp_ne(
+; CHECK-NEXT:    ret i32 [[X:%.*]]
+;
+  %A = icmp eq i32 %x, %z
+  %B = icmp ne i32 %y, %z
+  %C = and i1 %A, %B
+  %D = select i1 %C, i32 %z, i32 %x ; %C true implies %x == %z, so %D is always %x
+  ret i32 %D
+}
+
+define i32 @select_and_icmp_ne_alt(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: @select_and_icmp_ne_alt(
+; CHECK-NEXT:    ret i32 [[Z:%.*]]
+;
+  %A = icmp eq i32 %x, %z
+  %B = icmp ne i32 %y, %z
+  %C = and i1 %A, %B
+  %D = select i1 %C, i32 %x, i32 %z ; %C true implies %x == %z, so %D is always %z
+  ret i32 %D
+}
+
 define i32 @select_and_inv_icmp(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: @select_and_inv_icmp(
 ; CHECK-NEXT:    ret i32 [[X:%.*]]
@@ -115,21 +137,6 @@ define i32 @select_and_icmp_inv(i32 %x, i32 %y, i32 %z) {
 ; Negative tests
 define i32 @select_and_icmp_pred_bad_1(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: @select_and_icmp_pred_bad_1(
-; CHECK-NEXT:    [[A:%.*]] = icmp eq i32 [[X:%.*]], [[Z:%.*]]
-; CHECK-NEXT:    [[B:%.*]] = icmp ne i32 [[Y:%.*]], [[Z]]
-; CHECK-NEXT:    [[C:%.*]] = and i1 [[A]], [[B]]
-; CHECK-NEXT:    [[D:%.*]] = select i1 [[C]], i32 [[Z]], i32 [[X]]
-; CHECK-NEXT:    ret i32 [[D]]
-;
-  %A = icmp eq i32 %x, %z
-  %B = icmp ne i32 %y, %z
-  %C = and i1 %A, %B
-  %D = select i1 %C, i32 %z, i32 %x
-  ret i32 %D
-}
-
-define i32 @select_and_icmp_pred_bad_2(i32 %x, i32 %y, i32 %z) {
-; CHECK-LABEL: @select_and_icmp_pred_bad_2(
 ; CHECK-NEXT:    [[A:%.*]] = icmp ne i32 [[X:%.*]], [[Z:%.*]]
 ; CHECK-NEXT:    [[B:%.*]] = icmp eq i32 [[Y:%.*]], [[Z]]
 ; CHECK-NEXT:    [[C:%.*]] = and i1 [[A]], [[B]]
@@ -143,8 +150,8 @@ define i32 @select_and_icmp_pred_bad_2(i32 %x, i32 %y, i32 %z) {
   ret i32 %D
 }
 
-define i32 @select_and_icmp_pred_bad_3(i32 %x, i32 %y, i32 %z) {
-; CHECK-LABEL: @select_and_icmp_pred_bad_3(
+define i32 @select_and_icmp_pred_bad_2(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: @select_and_icmp_pred_bad_2(
 ; CHECK-NEXT:    [[A:%.*]] = icmp ne i32 [[X:%.*]], [[Z:%.*]]
 ; CHECK-NEXT:    [[B:%.*]] = icmp ne i32 [[Y:%.*]], [[Z]]
 ; CHECK-NEXT:    [[C:%.*]] = and i1 [[A]], [[B]]
@@ -158,8 +165,8 @@ define i32 @select_and_icmp_pred_bad_3(i32 %x, i32 %y, i32 %z) {
   ret i32 %D
 }
 
-define i32 @select_and_icmp_pred_bad_4(i32 %x, i32 %y, i32 %z) {
-; CHECK-LABEL: @select_and_icmp_pred_bad_4(
+define i32 @select_and_icmp_pred_bad_3(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: @select_and_icmp_pred_bad_3(
 ; CHECK-NEXT:    [[A:%.*]] = icmp eq i32 [[X:%.*]], [[Z:%.*]]
 ; CHECK-NEXT:    [[B:%.*]] = icmp eq i32 [[Y:%.*]], [[Z]]
 ; CHECK-NEXT:    [[C:%.*]] = or i1 [[A]], [[B]]
@@ -235,11 +242,7 @@ define i32 @select_and_icmp_bad_op_2(i32 %x, i32 %y, i32 %z, i32 %k) {
 
 define i32 @select_and_icmp_alt_bad_1(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: @select_and_icmp_alt_bad_1(
-; CHECK-NEXT:    [[A:%.*]] = icmp eq i32 [[X:%.*]], [[Z:%.*]]
-; CHECK-NEXT:    [[B:%.*]] = icmp ne i32 [[Y:%.*]], [[Z]]
-; CHECK-NEXT:    [[C:%.*]] = and i1 [[A]], [[B]]
-; CHECK-NEXT:    [[D:%.*]] = select i1 [[C]], i32 [[X]], i32 [[Z]]
-; CHECK-NEXT:    ret i32 [[D]]
+; CHECK-NEXT:    ret i32 [[Z:%.*]]
 ;
   %A = icmp eq i32 %x, %z
   %B = icmp ne i32 %y, %z
diff --git a/llvm/test/Transforms/InstSimplify/select-or-cmp.ll b/llvm/test/Transforms/InstSimplify/select-or-cmp.ll
index 0e410a9645f0d2..8b91ea03062695 100644
--- a/llvm/test/Transforms/InstSimplify/select-or-cmp.ll
+++ b/llvm/test/Transforms/InstSimplify/select-or-cmp.ll
@@ -78,6 +78,28 @@ define i32 @select_or_inv_icmp_alt(i32 %x, i32 %y, i32 %z) {
   ret i32 %D
 }
 
+define i32 @select_or_icmp_eq(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: @select_or_icmp_eq(
+; CHECK-NEXT:    ret i32 [[X:%.*]]
+;
+  %A = icmp ne i32 %x, %z
+  %B = icmp eq i32 %y, %z
+  %C = or i1 %A, %B
+  %D = select i1 %C, i32 %x, i32 %z ; %C false implies %x == %z, so %D is always %x
+  ret i32 %D
+}
+
+define i32 @select_or_icmp_eq_alt(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: @select_or_icmp_eq_alt(
+; CHECK-NEXT:    ret i32 [[Z:%.*]]
+;
+  %A = icmp ne i32 %x, %z
+  %B = icmp eq i32 %y, %z
+  %C = or i1 %A, %B
+  %D = select i1 %C, i32 %z, i32 %x ; %C false implies %x == %z, so %D is always %z
+  ret i32 %D
+}
+
 define <2 x i8> @select_or_icmp_alt_vec(<2 x i8> %x, <2 x i8> %y, <2 x i8> %z) {
 ; CHECK-LABEL: @select_or_icmp_alt_vec(
 ; CHECK-NEXT:    ret <2 x i8> [[X:%.*]]
@@ -129,21 +151,6 @@ define i32 @select_and_icmp_pred_bad_1(i32 %x, i32 %y, i32 %z) {
 
 define i32 @select_and_icmp_pred_bad_2(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: @select_and_icmp_pred_bad_2(
-; CHECK-NEXT:    [[A:%.*]] = icmp ne i32 [[X:%.*]], [[Z:%.*]]
-; CHECK-NEXT:    [[B:%.*]] = icmp eq i32 [[Y:%.*]], [[Z]]
-; CHECK-NEXT:    [[C:%.*]] = or i1 [[A]], [[B]]
-; CHECK-NEXT:    [[D:%.*]] = select i1 [[C]], i32 [[Z]], i32 [[X]]
-; CHECK-NEXT:    ret i32 [[D]]
-;
-  %A = icmp ne i32 %x, %z
-  %B = icmp eq i32 %y, %z
-  %C = or i1 %A, %B
-  %D = select i1 %C, i32 %z, i32 %x
-  ret i32 %D
-}
-
-define i32 @select_and_icmp_pred_bad_3(i32 %x, i32 %y, i32 %z) {
-; CHECK-LABEL: @select_and_icmp_pred_bad_3(
 ; CHECK-NEXT:    [[A:%.*]] = icmp eq i32 [[X:%.*]], [[Z:%.*]]
 ; CHECK-NEXT:    [[B:%.*]] = icmp eq i32 [[Y:%.*]], [[Z]]
 ; CHECK-NEXT:    [[C:%.*]] = or i1 [[A]], [[B]]
@@ -157,8 +164,8 @@ define i32 @select_and_icmp_pred_bad_3(i32 %x, i32 %y, i32 %z) {
   ret i32 %D
 }
 
-define i32 @select_and_icmp_pred_bad_4(i32 %x, i32 %y, i32 %z) {
-; CHECK-LABEL: @select_and_icmp_pred_bad_4(
+define i32 @select_and_icmp_pred_bad_3(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: @select_and_icmp_pred_bad_3(
 ; CHECK-NEXT:    [[A:%.*]] = icmp ne i32 [[X:%.*]], [[Z:%.*]]
 ; CHECK-NEXT:    [[B:%.*]] = icmp ne i32 [[Y:%.*]], [[Z]]
 ; CHECK-NEXT:    [[C:%.*]] = and i1 [[A]], [[B]]
@@ -250,21 +257,6 @@ define i32 @select_or_icmp_alt_bad_1(i32 %x, i32 %y, i32 %z) {
 
 define i32 @select_or_icmp_alt_bad_2(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: @select_or_icmp_alt_bad_2(
-; CHECK-NEXT:    [[A:%.*]] = icmp ne i32 [[X:%.*]], [[Z:%.*]]
-; CHECK-NEXT:    [[B:%.*]] = icmp eq i32 [[Y:%.*]], [[Z]]
-; CHECK-NEXT:    [[C:%.*]] = or i1 [[A]], [[B]]
-; CHECK-NEXT:    [[D:%.*]] = select i1 [[C]], i32 [[X]], i32 [[Z]]
-; CHECK-NEXT:    ret i32 [[D]]
-;
-  %A = icmp ne i32 %x, %z
-  %B = icmp eq i32 %y, %z
-  %C = or i1 %A, %B
-  %D = select i1 %C, i32 %x, i32 %z
-  ret i32 %D
-}
-
-define i32 @select_or_icmp_alt_bad_3(i32 %x, i32 %y, i32 %z) {
-; CHECK-LABEL: @select_or_icmp_alt_bad_3(
 ; CHECK-NEXT:    [[A:%.*]] = icmp eq i32 [[X:%.*]], [[Z:%.*]]
 ; CHECK-NEXT:    [[B:%.*]] = icmp eq i32 [[Y:%.*]], [[Z]]
 ; CHECK-NEXT:    [[C:%.*]] = or i1 [[A]], [[B]]
@@ -278,8 +270,8 @@ define i32 @select_or_icmp_alt_bad_3(i32 %x, i32 %y, i32 %z) {
   ret i32 %D
 }
 
-define i32 @select_or_icmp_alt_bad_4(i32 %x, i32 %y, i32 %z) {
-; CHECK-LABEL: @select_or_icmp_alt_bad_4(
+define i32 @select_or_icmp_alt_bad_3(i32 %x, i32 %y, i32 %z) {
+; CHECK-LABEL: @select_or_icmp_alt_bad_3(
 ; CHECK-NEXT:    [[A:%.*]] = icmp ne i32 [[X:%.*]], [[Z:%.*]]
 ; CHECK-NEXT:    [[B:%.*]] = icmp ne i32 [[Y:%.*]], [[Z]]
 ; CHECK-NEXT:    [[C:%.*]] = and i1 [[A]], [[B]]
@@ -293,8 +285,8 @@ define i32 @select_or_icmp_alt_bad_4(i32 %x, i32 %y, i32 %z) {
   ret i32 %D
 }
 
-define i32 @select_or_icmp_alt_bad_5(i32 %x, i32 %y, i32 %z, i32 %k) {
-; CHECK-LABEL: @select_or_icmp_alt_bad_5(
+define i32 @select_or_icmp_alt_bad_4(i32 %x, i32 %y, i32 %z, i32 %k) {
+; CHECK-LABEL: @select_or_icmp_alt_bad_4(
 ; CHECK-NEXT:    [[A:%.*]] = icmp ne i32 [[X:%.*]], [[K:%.*]]
 ; CHECK-NEXT:    [[B:%.*]] = icmp ne i32 [[Y:%.*]], [[Z:%.*]]
 ; CHECK-NEXT:    [[C:%.*]] = or i1 [[A]], [[B]]

``````````

</details>


https://github.com/llvm/llvm-project/pull/76363


More information about the llvm-commits mailing list