[llvm] c75a0f5 - [InstCombine] Optimize compares with multiple selects as operands

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Fri May 26 07:05:40 PDT 2023


Author: Tejas Joshi
Date: 2023-05-26T16:05:32+02:00
New Revision: c75a0f5a9a368b6ca3ec0a696f2a934e8dd0e5bb

URL: https://github.com/llvm/llvm-project/commit/c75a0f5a9a368b6ca3ec0a696f2a934e8dd0e5bb
DIFF: https://github.com/llvm/llvm-project/commit/c75a0f5a9a368b6ca3ec0a696f2a934e8dd0e5bb.diff

LOG: [InstCombine] Optimize compares with multiple selects as operands

When an icmp compares two select instructions that share the same
condition, check whether comparing one pair of corresponding arms (the two
true values or the two false values) simplifies. If it does, only the other
pair of arms needs to be compared, and a select on the shared condition
picks the appropriate result. For example:

    %tmp1 = select i1 %cmp, i32 %y, i32 %x
    %tmp2 = select i1 %cmp, i32 %z, i32 %x
    %cmp2 = icmp slt i32 %tmp2, %tmp1

When %cmp is false, both selects yield %x, so the icmp evaluates to false
(%x slt %x). When %cmp is true, the result depends only on comparing the
true values of the selects, i.e. %z slt %y. Thus, transform this into:

    %newcmp = icmp slt i32 %z, %y
    %sel = select i1 %cmp, i1 %newcmp, i1 false

Differential Revision: https://reviews.llvm.org/D150360

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
    llvm/test/Transforms/InstCombine/icmp-with-selects.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index e1a80737c913..462e65d8e7d6 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -6577,6 +6577,37 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
     if (Instruction *NI = foldSelectICmp(I.getSwappedPredicate(), SI, Op0, I))
       return NI;
 
+  // In case of a comparison with two select instructions having the same
+  // condition, check whether one of the resulting branches can be simplified.
+  // If so, just compare the other branch and select the appropriate result.
+  // For example:
+  //   %tmp1 = select i1 %cmp, i32 %y, i32 %x
+  //   %tmp2 = select i1 %cmp, i32 %z, i32 %x
+  //   %cmp2 = icmp slt i32 %tmp2, %tmp1
+  // The icmp will result false for the false value of selects and the result
+  // will depend upon the comparison of true values of selects if %cmp is
+  // true. Thus, transform this into:
+  //   %cmp = icmp slt i32 %y, %z
+  //   %sel = select i1 %cond, i1 %cmp, i1 false
+  // This handles similar cases to transform.
+  {
+    Value *Cond, *A, *B, *C, *D;
+    if (match(Op0, m_Select(m_Value(Cond), m_Value(A), m_Value(B))) &&
+        match(Op1, m_Select(m_Specific(Cond), m_Value(C), m_Value(D))) &&
+        (Op0->hasOneUse() || Op1->hasOneUse())) {
+      // Check whether comparison of TrueValues can be simplified
+      if (Value *Res = simplifyICmpInst(Pred, A, C, SQ)) {
+        Value *NewICMP = Builder.CreateICmp(Pred, B, D);
+        return SelectInst::Create(Cond, Res, NewICMP);
+      }
+      // Check whether comparison of FalseValues can be simplified
+      if (Value *Res = simplifyICmpInst(Pred, B, D, SQ)) {
+        Value *NewICMP = Builder.CreateICmp(Pred, A, C);
+        return SelectInst::Create(Cond, NewICMP, Res);
+      }
+    }
+  }
+
   // Try to optimize equality comparisons against alloca-based pointers.
   if (Op0->getType()->isPointerTy() && I.isEquality()) {
     assert(Op1->getType()->isPointerTy() && "Comparing pointer with non-pointer?");

diff  --git a/llvm/test/Transforms/InstCombine/icmp-with-selects.ll b/llvm/test/Transforms/InstCombine/icmp-with-selects.ll
index 540eccea3b22..9ee7c78379f5 100644
--- a/llvm/test/Transforms/InstCombine/icmp-with-selects.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-with-selects.ll
@@ -7,10 +7,7 @@ define i1 @both_sides_fold_slt(i32 %param, i1 %cond) {
 ; CHECK-LABEL: define i1 @both_sides_fold_slt
 ; CHECK-SAME: (i32 [[PARAM:%.*]], i1 [[COND:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[COND1:%.*]] = select i1 [[COND]], i32 1, i32 [[PARAM]]
-; CHECK-NEXT:    [[COND6:%.*]] = select i1 [[COND]], i32 9, i32 [[PARAM]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[COND6]], [[COND1]]
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    ret i1 false
 ;
 entry:
   %cond1 = select i1 %cond, i32 1, i32 %param
@@ -23,10 +20,8 @@ define i1 @both_sides_fold_eq(i32 %param, i1 %cond) {
 ; CHECK-LABEL: define i1 @both_sides_fold_eq
 ; CHECK-SAME: (i32 [[PARAM:%.*]], i1 [[COND:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[COND1:%.*]] = select i1 [[COND]], i32 1, i32 [[PARAM]]
-; CHECK-NEXT:    [[COND6:%.*]] = select i1 [[COND]], i32 9, i32 [[PARAM]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[COND6]], [[COND1]]
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    [[NOT_COND:%.*]] = xor i1 [[COND]], true
+; CHECK-NEXT:    ret i1 [[NOT_COND]]
 ;
 entry:
   %cond1 = select i1 %cond, i32 1, i32 %param
@@ -39,9 +34,8 @@ define i1 @one_side_fold_slt(i32 %val1, i32 %val2, i32 %param, i1 %cond) {
 ; CHECK-LABEL: define i1 @one_side_fold_slt
 ; CHECK-SAME: (i32 [[VAL1:%.*]], i32 [[VAL2:%.*]], i32 [[PARAM:%.*]], i1 [[COND:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[COND1:%.*]] = select i1 [[COND]], i32 [[VAL1]], i32 [[PARAM]]
-; CHECK-NEXT:    [[COND6:%.*]] = select i1 [[COND]], i32 [[VAL2]], i32 [[PARAM]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[COND6]], [[COND1]]
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp slt i32 [[VAL2]], [[VAL1]]
+; CHECK-NEXT:    [[CMP:%.*]] = select i1 [[COND]], i1 [[TMP0]], i1 false
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
 entry:
@@ -55,9 +49,9 @@ define i1 @one_side_fold_sgt(i32 %val1, i32 %val2, i32 %param, i1 %cond) {
 ; CHECK-LABEL: define i1 @one_side_fold_sgt
 ; CHECK-SAME: (i32 [[VAL1:%.*]], i32 [[VAL2:%.*]], i32 [[PARAM:%.*]], i1 [[COND:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[COND1:%.*]] = select i1 [[COND]], i32 [[PARAM]], i32 [[VAL1]]
-; CHECK-NEXT:    [[COND6:%.*]] = select i1 [[COND]], i32 [[PARAM]], i32 [[VAL2]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i32 [[COND6]], [[COND1]]
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp sgt i32 [[VAL2]], [[VAL1]]
+; CHECK-NEXT:    [[NOT_COND:%.*]] = xor i1 [[COND]], true
+; CHECK-NEXT:    [[CMP:%.*]] = select i1 [[NOT_COND]], i1 [[TMP0]], i1 false
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
 entry:
@@ -71,9 +65,9 @@ define i1 @one_side_fold_eq(i32 %val1, i32 %val2, i32 %param, i1 %cond) {
 ; CHECK-LABEL: define i1 @one_side_fold_eq
 ; CHECK-SAME: (i32 [[VAL1:%.*]], i32 [[VAL2:%.*]], i32 [[PARAM:%.*]], i1 [[COND:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[COND1:%.*]] = select i1 [[COND]], i32 [[VAL1]], i32 [[PARAM]]
-; CHECK-NEXT:    [[COND6:%.*]] = select i1 [[COND]], i32 [[VAL2]], i32 [[PARAM]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[COND6]], [[COND1]]
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp eq i32 [[VAL2]], [[VAL1]]
+; CHECK-NEXT:    [[NOT_COND:%.*]] = xor i1 [[COND]], true
+; CHECK-NEXT:    [[CMP:%.*]] = select i1 [[NOT_COND]], i1 true, i1 [[TMP0]]
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
 entry:
@@ -120,9 +114,9 @@ define i1 @one_select_mult_use(i32 %val1, i32 %val2, i32 %param, i1 %cond) {
 ; CHECK-SAME: (i32 [[VAL1:%.*]], i32 [[VAL2:%.*]], i32 [[PARAM:%.*]], i1 [[COND:%.*]]) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[COND1:%.*]] = select i1 [[COND]], i32 [[VAL1]], i32 [[PARAM]]
-; CHECK-NEXT:    [[COND6:%.*]] = select i1 [[COND]], i32 [[VAL2]], i32 [[PARAM]]
 ; CHECK-NEXT:    call void @use(i32 [[COND1]])
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[COND6]], [[COND1]]
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp slt i32 [[VAL2]], [[VAL1]]
+; CHECK-NEXT:    [[CMP:%.*]] = select i1 [[COND]], i1 [[TMP0]], i1 false
 ; CHECK-NEXT:    ret i1 [[CMP]]
 ;
 entry:
@@ -155,9 +149,8 @@ define <4 x i1> @fold_vector_ops(<4 x i32> %val1, <4 x i32> %val2, <4 x i32> %pa
 ; CHECK-LABEL: define <4 x i1> @fold_vector_ops
 ; CHECK-SAME: (<4 x i32> [[VAL1:%.*]], <4 x i32> [[VAL2:%.*]], <4 x i32> [[PARAM:%.*]], i1 [[COND:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[COND1:%.*]] = select i1 [[COND]], <4 x i32> [[VAL1]], <4 x i32> [[PARAM]]
-; CHECK-NEXT:    [[COND6:%.*]] = select i1 [[COND]], <4 x i32> [[VAL2]], <4 x i32> [[PARAM]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <4 x i32> [[COND6]], [[COND1]]
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp eq <4 x i32> [[VAL2]], [[VAL1]]
+; CHECK-NEXT:    [[CMP:%.*]] = select i1 [[COND]], <4 x i1> [[TMP0]], <4 x i1> <i1 true, i1 true, i1 true, i1 true>
 ; CHECK-NEXT:    ret <4 x i1> [[CMP]]
 ;
 entry:
@@ -171,9 +164,8 @@ define <8 x i1> @fold_vector_cond_ops(<8 x i32> %val1, <8 x i32> %val2, <8 x i32
 ; CHECK-LABEL: define <8 x i1> @fold_vector_cond_ops
 ; CHECK-SAME: (<8 x i32> [[VAL1:%.*]], <8 x i32> [[VAL2:%.*]], <8 x i32> [[PARAM:%.*]], <8 x i1> [[COND:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[COND1:%.*]] = select <8 x i1> [[COND]], <8 x i32> [[VAL1]], <8 x i32> [[PARAM]]
-; CHECK-NEXT:    [[COND6:%.*]] = select <8 x i1> [[COND]], <8 x i32> [[VAL2]], <8 x i32> [[PARAM]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt <8 x i32> [[COND6]], [[COND1]]
+; CHECK-NEXT:    [[TMP0:%.*]] = icmp sgt <8 x i32> [[VAL2]], [[VAL1]]
+; CHECK-NEXT:    [[CMP:%.*]] = select <8 x i1> [[COND]], <8 x i1> [[TMP0]], <8 x i1> zeroinitializer
 ; CHECK-NEXT:    ret <8 x i1> [[CMP]]
 ;
 entry:


        


More information about the llvm-commits mailing list