[llvm] b4c8cfc - [InstCombine] fold more icmp + select patterns by distributive laws

via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 7 07:55:57 PST 2022


Author: chenglin.bi
Date: 2022-12-07T23:55:49+08:00
New Revision: b4c8cfc7c211fae99c734ca5091c334238528443

URL: https://github.com/llvm/llvm-project/commit/b4c8cfc7c211fae99c734ca5091c334238528443
DIFF: https://github.com/llvm/llvm-project/commit/b4c8cfc7c211fae99c734ca5091c334238528443.diff

LOG: [InstCombine] fold more icmp + select patterns by distributive laws

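Follow-up to D139076: extend the fold from icmp eq/ne to the relational predicates (signed and unsigned gt/lt/ge/le), including pairs whose predicates are swapped forms of each other.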

Reviewed By: spatel

Differential Revision: https://reviews.llvm.org/D139253

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
    llvm/test/Transforms/InstCombine/select-bitext.ll
    llvm/test/Transforms/InstCombine/select-cmp.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 8d6fad3955d9..31659d581293 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -316,33 +316,42 @@ Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI,
 
   Value *OtherOpT, *OtherOpF;
   bool MatchIsOpZero;
-  auto getCommonOp = [&](Instruction *TI, Instruction *FI,
-                         bool Commute) -> Value * {
-    Value *CommonOp = nullptr;
-    if (TI->getOperand(0) == FI->getOperand(0)) {
-      CommonOp = TI->getOperand(0);
-      OtherOpT = TI->getOperand(1);
-      OtherOpF = FI->getOperand(1);
-      MatchIsOpZero = true;
-    } else if (TI->getOperand(1) == FI->getOperand(1)) {
-      CommonOp = TI->getOperand(1);
-      OtherOpT = TI->getOperand(0);
-      OtherOpF = FI->getOperand(0);
-      MatchIsOpZero = false;
-    } else if (!Commute) {
+  auto getCommonOp = [&](Instruction *TI, Instruction *FI, bool Commute,
+                         bool Swapped = false) -> Value * {
+    assert(!(Commute && Swapped) &&
+           "Commute and Swapped can't set at the same time");
+    if (!Swapped) {
+      if (TI->getOperand(0) == FI->getOperand(0)) {
+        OtherOpT = TI->getOperand(1);
+        OtherOpF = FI->getOperand(1);
+        MatchIsOpZero = true;
+        return TI->getOperand(0);
+      } else if (TI->getOperand(1) == FI->getOperand(1)) {
+        OtherOpT = TI->getOperand(0);
+        OtherOpF = FI->getOperand(0);
+        MatchIsOpZero = false;
+        return TI->getOperand(1);
+      }
+    }
+
+    if (!Commute && !Swapped)
       return nullptr;
-    } else if (TI->getOperand(0) == FI->getOperand(1)) {
-      CommonOp = TI->getOperand(0);
+
+    // If we are allowing commute or swap of operands, then
+    // allow a cross-operand match. In that case, MatchIsOpZero
+    // means that TI's operand 0 (FI's operand 1) is the common op.
+    if (TI->getOperand(0) == FI->getOperand(1)) {
       OtherOpT = TI->getOperand(1);
       OtherOpF = FI->getOperand(0);
       MatchIsOpZero = true;
+      return TI->getOperand(0);
     } else if (TI->getOperand(1) == FI->getOperand(0)) {
-      CommonOp = TI->getOperand(1);
       OtherOpT = TI->getOperand(0);
       OtherOpF = FI->getOperand(1);
-      MatchIsOpZero = true;
+      MatchIsOpZero = false;
+      return TI->getOperand(1);
     }
-    return CommonOp;
+    return nullptr;
   };
 
   if (TI->hasOneUse() || FI->hasOneUse()) {
@@ -379,16 +388,20 @@ Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI,
       }
     }
 
-    // icmp eq/ne with a common operand also can have the common operand
+    // icmp with a common operand also can have the common operand
     // pulled after the select.
     ICmpInst::Predicate TPred, FPred;
     if (match(TI, m_ICmp(TPred, m_Value(), m_Value())) &&
         match(FI, m_ICmp(FPred, m_Value(), m_Value()))) {
-      if (TPred == FPred && ICmpInst::isEquality(TPred)) {
-        if (Value *MatchOp = getCommonOp(TI, FI, true)) {
+      if (TPred == FPred || TPred == CmpInst::getSwappedPredicate(FPred)) {
+        bool Swapped = TPred != FPred;
+        if (Value *MatchOp =
+                getCommonOp(TI, FI, ICmpInst::isEquality(TPred), Swapped)) {
           Value *NewSel = Builder.CreateSelect(Cond, OtherOpT, OtherOpF,
                                                SI.getName() + ".v", &SI);
-          return new ICmpInst(TPred, NewSel, MatchOp);
+          return new ICmpInst(
+              MatchIsOpZero ? TPred : CmpInst::getSwappedPredicate(TPred),
+              MatchOp, NewSel);
         }
       }
     }

diff  --git a/llvm/test/Transforms/InstCombine/select-bitext.ll b/llvm/test/Transforms/InstCombine/select-bitext.ll
index 02379a494abb..75272c643926 100644
--- a/llvm/test/Transforms/InstCombine/select-bitext.ll
+++ b/llvm/test/Transforms/InstCombine/select-bitext.ll
@@ -485,10 +485,9 @@ define i32 @sel_zext_const_uses(i8 %a, i8 %x) {
 
 define i32 @test_op_op(i32 %a, i32 %b, i32 %c) {
 ; CHECK-LABEL: @test_op_op(
-; CHECK-NEXT:    [[CCA:%.*]] = icmp sgt i32 [[A:%.*]], 0
-; CHECK-NEXT:    [[CCB:%.*]] = icmp sgt i32 [[B:%.*]], 0
 ; CHECK-NEXT:    [[CCC:%.*]] = icmp sgt i32 [[C:%.*]], 0
-; CHECK-NEXT:    [[R_V:%.*]] = select i1 [[CCC]], i1 [[CCA]], i1 [[CCB]]
+; CHECK-NEXT:    [[R_V_V:%.*]] = select i1 [[CCC]], i32 [[A:%.*]], i32 [[B:%.*]]
+; CHECK-NEXT:    [[R_V:%.*]] = icmp sgt i32 [[R_V_V]], 0
 ; CHECK-NEXT:    [[R:%.*]] = sext i1 [[R_V]] to i32
 ; CHECK-NEXT:    ret i32 [[R]]
 ;

diff  --git a/llvm/test/Transforms/InstCombine/select-cmp.ll b/llvm/test/Transforms/InstCombine/select-cmp.ll
index a39f48dcbe72..711fac542179 100644
--- a/llvm/test/Transforms/InstCombine/select-cmp.ll
+++ b/llvm/test/Transforms/InstCombine/select-cmp.ll
@@ -124,9 +124,8 @@ define i1 @icmp_common_one_use_1(i1 %c, i8 %x, i8 %y, i8 %z) {
 
 define i1 @icmp_slt_common(i1 %c, i6 %x, i6 %y, i6 %z) {
 ; CHECK-LABEL: @icmp_slt_common(
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp slt i6 [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp slt i6 [[X]], [[Z:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[C:%.*]], i1 [[CMP1]], i1 [[CMP2]]
+; CHECK-NEXT:    [[R_V:%.*]] = select i1 [[C:%.*]], i6 [[Y:%.*]], i6 [[Z:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = icmp sgt i6 [[R_V]], [[X:%.*]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %cmp1 = icmp slt i6 %x, %y
@@ -137,9 +136,8 @@ define i1 @icmp_slt_common(i1 %c, i6 %x, i6 %y, i6 %z) {
 
 define i1 @icmp_sgt_common(i1 %c, i6 %x, i6 %y, i6 %z) {
 ; CHECK-LABEL: @icmp_sgt_common(
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp sgt i6 [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp sgt i6 [[X]], [[Z:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[C:%.*]], i1 [[CMP1]], i1 [[CMP2]]
+; CHECK-NEXT:    [[R_V:%.*]] = select i1 [[C:%.*]], i6 [[Y:%.*]], i6 [[Z:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = icmp slt i6 [[R_V]], [[X:%.*]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %cmp1 = icmp sgt i6 %x, %y
@@ -150,9 +148,8 @@ define i1 @icmp_sgt_common(i1 %c, i6 %x, i6 %y, i6 %z) {
 
 define i1 @icmp_sle_common(i1 %c, i6 %x, i6 %y, i6 %z) {
 ; CHECK-LABEL: @icmp_sle_common(
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp sle i6 [[Y:%.*]], [[X:%.*]]
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp sle i6 [[Z:%.*]], [[X]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[C:%.*]], i1 [[CMP1]], i1 [[CMP2]]
+; CHECK-NEXT:    [[R_V:%.*]] = select i1 [[C:%.*]], i6 [[Y:%.*]], i6 [[Z:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = icmp sle i6 [[R_V]], [[X:%.*]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %cmp1 = icmp sle i6 %y, %x
@@ -163,9 +160,8 @@ define i1 @icmp_sle_common(i1 %c, i6 %x, i6 %y, i6 %z) {
 
 define i1 @icmp_sge_common(i1 %c, i6 %x, i6 %y, i6 %z) {
 ; CHECK-LABEL: @icmp_sge_common(
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp sge i6 [[Y:%.*]], [[X:%.*]]
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp sge i6 [[Z:%.*]], [[X]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[C:%.*]], i1 [[CMP1]], i1 [[CMP2]]
+; CHECK-NEXT:    [[R_V:%.*]] = select i1 [[C:%.*]], i6 [[Y:%.*]], i6 [[Z:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = icmp sge i6 [[R_V]], [[X:%.*]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %cmp1 = icmp sge i6 %y, %x
@@ -176,9 +172,8 @@ define i1 @icmp_sge_common(i1 %c, i6 %x, i6 %y, i6 %z) {
 
 define i1 @icmp_slt_sgt_common(i1 %c, i6 %x, i6 %y, i6 %z) {
 ; CHECK-LABEL: @icmp_slt_sgt_common(
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp slt i6 [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp sgt i6 [[Z:%.*]], [[X]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[C:%.*]], i1 [[CMP1]], i1 [[CMP2]]
+; CHECK-NEXT:    [[R_V:%.*]] = select i1 [[C:%.*]], i6 [[Y:%.*]], i6 [[Z:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = icmp sgt i6 [[R_V]], [[X:%.*]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %cmp1 = icmp slt i6 %x, %y
@@ -189,9 +184,8 @@ define i1 @icmp_slt_sgt_common(i1 %c, i6 %x, i6 %y, i6 %z) {
 
 define i1 @icmp_sle_sge_common(i1 %c, i6 %x, i6 %y, i6 %z) {
 ; CHECK-LABEL: @icmp_sle_sge_common(
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp sle i6 [[Y:%.*]], [[X:%.*]]
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp sge i6 [[X]], [[Z:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[C:%.*]], i1 [[CMP1]], i1 [[CMP2]]
+; CHECK-NEXT:    [[R_V:%.*]] = select i1 [[C:%.*]], i6 [[Y:%.*]], i6 [[Z:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = icmp sle i6 [[R_V]], [[X:%.*]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %cmp1 = icmp sle i6 %y, %x
@@ -202,9 +196,8 @@ define i1 @icmp_sle_sge_common(i1 %c, i6 %x, i6 %y, i6 %z) {
 
 define i1 @icmp_ult_common(i1 %c, i6 %x, i6 %y, i6 %z) {
 ; CHECK-LABEL: @icmp_ult_common(
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp ult i6 [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp ult i6 [[X]], [[Z:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[C:%.*]], i1 [[CMP1]], i1 [[CMP2]]
+; CHECK-NEXT:    [[R_V:%.*]] = select i1 [[C:%.*]], i6 [[Y:%.*]], i6 [[Z:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = icmp ugt i6 [[R_V]], [[X:%.*]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %cmp1 = icmp ult i6 %x, %y
@@ -215,9 +208,8 @@ define i1 @icmp_ult_common(i1 %c, i6 %x, i6 %y, i6 %z) {
 
 define i1 @icmp_ule_common(i1 %c, i6 %x, i6 %y, i6 %z) {
 ; CHECK-LABEL: @icmp_ule_common(
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp ule i6 [[Y:%.*]], [[X:%.*]]
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp ule i6 [[Z:%.*]], [[X]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[C:%.*]], i1 [[CMP1]], i1 [[CMP2]]
+; CHECK-NEXT:    [[R_V:%.*]] = select i1 [[C:%.*]], i6 [[Y:%.*]], i6 [[Z:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = icmp ule i6 [[R_V]], [[X:%.*]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %cmp1 = icmp ule i6 %y, %x
@@ -228,9 +220,8 @@ define i1 @icmp_ule_common(i1 %c, i6 %x, i6 %y, i6 %z) {
 
 define i1 @icmp_ugt_common(i1 %c, i8 %x, i8 %y, i8 %z) {
 ; CHECK-LABEL: @icmp_ugt_common(
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp ugt i8 [[Y:%.*]], [[X:%.*]]
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp ugt i8 [[Z:%.*]], [[X]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[C:%.*]], i1 [[CMP1]], i1 [[CMP2]]
+; CHECK-NEXT:    [[R_V:%.*]] = select i1 [[C:%.*]], i8 [[Y:%.*]], i8 [[Z:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = icmp ugt i8 [[R_V]], [[X:%.*]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %cmp1 = icmp ugt i8 %y, %x
@@ -241,9 +232,8 @@ define i1 @icmp_ugt_common(i1 %c, i8 %x, i8 %y, i8 %z) {
 
 define i1 @icmp_uge_common(i1 %c, i6 %x, i6 %y, i6 %z) {
 ; CHECK-LABEL: @icmp_uge_common(
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp uge i6 [[Y:%.*]], [[X:%.*]]
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp uge i6 [[Z:%.*]], [[X]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[C:%.*]], i1 [[CMP1]], i1 [[CMP2]]
+; CHECK-NEXT:    [[R_V:%.*]] = select i1 [[C:%.*]], i6 [[Y:%.*]], i6 [[Z:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = icmp uge i6 [[R_V]], [[X:%.*]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %cmp1 = icmp uge i6 %y, %x
@@ -254,9 +244,8 @@ define i1 @icmp_uge_common(i1 %c, i6 %x, i6 %y, i6 %z) {
 
 define i1 @icmp_ult_ugt_common(i1 %c, i6 %x, i6 %y, i6 %z) {
 ; CHECK-LABEL: @icmp_ult_ugt_common(
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp ult i6 [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp ugt i6 [[Z:%.*]], [[X]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[C:%.*]], i1 [[CMP1]], i1 [[CMP2]]
+; CHECK-NEXT:    [[R_V:%.*]] = select i1 [[C:%.*]], i6 [[Y:%.*]], i6 [[Z:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = icmp ugt i6 [[R_V]], [[X:%.*]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %cmp1 = icmp ult i6 %x, %y
@@ -267,9 +256,8 @@ define i1 @icmp_ult_ugt_common(i1 %c, i6 %x, i6 %y, i6 %z) {
 
 define i1 @icmp_ule_uge_common(i1 %c, i6 %x, i6 %y, i6 %z) {
 ; CHECK-LABEL: @icmp_ule_uge_common(
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp ule i6 [[Y:%.*]], [[X:%.*]]
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp uge i6 [[X]], [[Z:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[C:%.*]], i1 [[CMP1]], i1 [[CMP2]]
+; CHECK-NEXT:    [[R_V:%.*]] = select i1 [[C:%.*]], i6 [[Y:%.*]], i6 [[Z:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = icmp ule i6 [[R_V]], [[X:%.*]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %cmp1 = icmp ule i6 %y, %x
@@ -293,7 +281,7 @@ define i1 @icmp_common_pred_different(i1 %c, i8 %x, i8 %y, i8 %z) {
   ret i1 %r
 }
 
-; negative test: two pred is not swap
+; negative test for non-equality: two pred is not swap
 
 define i1 @icmp_common_pred_not_swap(i1 %c, i8 %x, i8 %y, i8 %z) {
 ; CHECK-LABEL: @icmp_common_pred_not_swap(
@@ -308,7 +296,7 @@ define i1 @icmp_common_pred_not_swap(i1 %c, i8 %x, i8 %y, i8 %z) {
   ret i1 %r
 }
 
-; negative test: not commute pred
+; negative test for non-equality: not commute pred
 
 define i1 @icmp_common_pred_not_commute_pred(i1 %c, i8 %x, i8 %y, i8 %z) {
 ; CHECK-LABEL: @icmp_common_pred_not_commute_pred(


        


More information about the llvm-commits mailing list