[llvm] 0c47363 - [InstCombine] Simplify nested selects with implied condition (#83739)

via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 4 22:11:40 PST 2024


Author: Yingwei Zheng
Date: 2024-03-05T14:11:37+08:00
New Revision: 0c4736338596d6e527e286b7b551af4bb8b63a55

URL: https://github.com/llvm/llvm-project/commit/0c4736338596d6e527e286b7b551af4bb8b63a55
DIFF: https://github.com/llvm/llvm-project/commit/0c4736338596d6e527e286b7b551af4bb8b63a55.diff

LOG: [InstCombine] Simplify nested selects with implied condition (#83739)

This patch does the following simplification:
```
sel1 = select cond1, X, Y
sel2 = select cond2, sel1, Z
-->
sel2 = select cond2, X, Z if cond2 implies cond1
sel2 = select cond2, Y, Z if cond2 implies !cond1
```
The same fold applies when `sel1` is the false operand of `sel2`, using
`!cond2` in place of `cond2`.
Alive2: https://alive2.llvm.org/ce/z/9A_arU

This cannot be done in CVP/SCCP because `cond2` must be guaranteed not to
be undef.

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
    llvm/test/Transforms/InstCombine/canonicalize-clamp-like-pattern-between-negative-and-positive-thresholds.ll
    llvm/test/Transforms/InstCombine/nested-select.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 71fa9b9ba41ebb..c47bc33df0706b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -2643,46 +2643,33 @@ static Value *foldSelectWithFrozenICmp(SelectInst &Sel, InstCombiner::BuilderTy
   return nullptr;
 }
 
+/// Given that \p CondVal is known to be \p CondIsTrue, return the operand of \p SI that is then always selected, or nullptr if isImpliedCondition cannot decide.
+static Value *simplifyNestedSelectsUsingImpliedCond(SelectInst &SI,
+                                                    Value *CondVal,
+                                                    bool CondIsTrue,
+                                                    const DataLayout &DL) {
+  Value *InnerCondVal = SI.getCondition();
+  Value *InnerTrueVal = SI.getTrueValue();
+  Value *InnerFalseVal = SI.getFalseValue();
+  assert(CondVal->getType() == InnerCondVal->getType() &&
+         "The type of inner condition must match with the outer."); // callers check the types before calling
+  if (auto Implied = isImpliedCondition(CondVal, InnerCondVal, DL, CondIsTrue)) // std::optional<bool>; empty when no implication is known
+    return *Implied ? InnerTrueVal : InnerFalseVal; // implied true -> true arm; implied false -> false arm
+  return nullptr;
+}
+
 Instruction *InstCombinerImpl::foldAndOrOfSelectUsingImpliedCond(Value *Op,
                                                                  SelectInst &SI,
                                                                  bool IsAnd) {
-  Value *CondVal = SI.getCondition();
-  Value *A = SI.getTrueValue();
-  Value *B = SI.getFalseValue();
-
   assert(Op->getType()->isIntOrIntVectorTy(1) &&
          "Op must be either i1 or vector of i1.");
-
-  std::optional<bool> Res = isImpliedCondition(Op, CondVal, DL, IsAnd);
-  if (!Res)
+  if (SI.getCondition()->getType() != Op->getType()) // skip mismatched types (e.g. an i1-conditioned select used as a vector operand); the helper asserts they match
     return nullptr;
-
-  Value *Zero = Constant::getNullValue(A->getType());
-  Value *One = Constant::getAllOnesValue(A->getType());
-
-  if (*Res == true) {
-    if (IsAnd)
-      // select op, (select cond, A, B), false => select op, A, false
-      // and    op, (select cond, A, B)        => select op, A, false
-      //   if op = true implies condval = true.
-      return SelectInst::Create(Op, A, Zero);
-    else
-      // select op, true, (select cond, A, B) => select op, true, A
-      // or     op, (select cond, A, B)       => select op, true, A
-      //   if op = false implies condval = true.
-      return SelectInst::Create(Op, One, A);
-  } else {
-    if (IsAnd)
-      // select op, (select cond, A, B), false => select op, B, false
-      // and    op, (select cond, A, B)        => select op, B, false
-      //   if op = true implies condval = false.
-      return SelectInst::Create(Op, B, Zero);
-    else
-      // select op, true, (select cond, A, B) => select op, true, B
-      // or     op, (select cond, A, B)       => select op, true, B
-      //   if op = false implies condval = false.
-      return SelectInst::Create(Op, One, B);
-  }
+  if (Value *V = simplifyNestedSelectsUsingImpliedCond(SI, Op, IsAnd, DL)) // SI only matters when Op is true (and) / false (or)
+    return SelectInst::Create(Op, // and: select Op, V, false; or: select Op, true, V
+                              IsAnd ? V : ConstantInt::getTrue(Op->getType()),
+                              IsAnd ? ConstantInt::getFalse(Op->getType()) : V);
+  return nullptr;
 }
 
 // Canonicalize select with fcmp to fabs(). -0.0 makes this tricky. We need
@@ -3138,11 +3125,6 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) {
       return replaceInstUsesWith(SI, Op1);
     }
 
-    if (auto *Op1SI = dyn_cast<SelectInst>(Op1))
-      if (auto *I = foldAndOrOfSelectUsingImpliedCond(CondVal, *Op1SI,
-                                                      /* IsAnd */ IsAnd))
-        return I;
-
     if (auto *ICmp0 = dyn_cast<ICmpInst>(CondVal))
       if (auto *ICmp1 = dyn_cast<ICmpInst>(Op1))
         if (auto *V = foldAndOrOfICmps(ICmp0, ICmp1, SI, IsAnd,
@@ -3643,12 +3625,12 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
 
   if (SelectInst *TrueSI = dyn_cast<SelectInst>(TrueVal)) {
     if (TrueSI->getCondition()->getType() == CondVal->getType()) {
-      // select(C, select(C, a, b), c) -> select(C, a, c)
-      if (TrueSI->getCondition() == CondVal) {
-        if (SI.getTrueValue() == TrueSI->getTrueValue())
-          return nullptr;
-        return replaceOperand(SI, 1, TrueSI->getTrueValue());
-      }
+      // Fold nested selects if the inner condition can be implied by the outer
+      // condition.
+      if (Value *V = simplifyNestedSelectsUsingImpliedCond(
+              *TrueSI, CondVal, /*CondIsTrue=*/true, DL))
+        return replaceOperand(SI, 1, V);
+
       // select(C0, select(C1, a, b), b) -> select(C0&C1, a, b)
       // We choose this as normal form to enable folding on the And and
       // shortening paths for the values (this helps getUnderlyingObjects() for
@@ -3663,12 +3645,12 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
   }
   if (SelectInst *FalseSI = dyn_cast<SelectInst>(FalseVal)) {
     if (FalseSI->getCondition()->getType() == CondVal->getType()) {
-      // select(C, a, select(C, b, c)) -> select(C, a, c)
-      if (FalseSI->getCondition() == CondVal) {
-        if (SI.getFalseValue() == FalseSI->getFalseValue())
-          return nullptr;
-        return replaceOperand(SI, 2, FalseSI->getFalseValue());
-      }
+      // Fold nested selects if the inner condition can be implied by the outer
+      // condition.
+      if (Value *V = simplifyNestedSelectsUsingImpliedCond(
+              *FalseSI, CondVal, /*CondIsTrue=*/false, DL))
+        return replaceOperand(SI, 2, V);
+
       // select(C0, a, select(C1, a, b)) -> select(C0|C1, a, b)
       if (FalseSI->getTrueValue() == TrueVal && FalseSI->hasOneUse()) {
         Value *Or = Builder.CreateLogicalOr(CondVal, FalseSI->getCondition());

diff  --git a/llvm/test/Transforms/InstCombine/canonicalize-clamp-like-pattern-between-negative-and-positive-thresholds.ll b/llvm/test/Transforms/InstCombine/canonicalize-clamp-like-pattern-between-negative-and-positive-thresholds.ll
index d03e22bc4c9fbf..b5ef1f466958d7 100644
--- a/llvm/test/Transforms/InstCombine/canonicalize-clamp-like-pattern-between-negative-and-positive-thresholds.ll
+++ b/llvm/test/Transforms/InstCombine/canonicalize-clamp-like-pattern-between-negative-and-positive-thresholds.ll
@@ -189,10 +189,8 @@ define i32 @n9_ult_slt_neg17(i32 %x, i32 %replacement_low, i32 %replacement_high
 ; Regression test for PR53252.
 define i32 @n10_ugt_slt(i32 %x, i32 %replacement_low, i32 %replacement_high) {
 ; CHECK-LABEL: @n10_ugt_slt(
-; CHECK-NEXT:    [[T0:%.*]] = icmp slt i32 [[X:%.*]], 0
-; CHECK-NEXT:    [[T1:%.*]] = select i1 [[T0]], i32 [[REPLACEMENT_LOW:%.*]], i32 [[REPLACEMENT_HIGH:%.*]]
-; CHECK-NEXT:    [[T2:%.*]] = icmp ugt i32 [[X]], 128
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[T2]], i32 [[X]], i32 [[T1]]
+; CHECK-NEXT:    [[T2:%.*]] = icmp ugt i32 [[X:%.*]], 128
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[T2]], i32 [[X]], i32 [[REPLACEMENT_HIGH:%.*]]
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %t0 = icmp slt i32 %x, 0
@@ -204,10 +202,8 @@ define i32 @n10_ugt_slt(i32 %x, i32 %replacement_low, i32 %replacement_high) {
 
 define i32 @n11_uge_slt(i32 %x, i32 %replacement_low, i32 %replacement_high) {
 ; CHECK-LABEL: @n11_uge_slt(
-; CHECK-NEXT:    [[T0:%.*]] = icmp slt i32 [[X:%.*]], 0
-; CHECK-NEXT:    [[T1:%.*]] = select i1 [[T0]], i32 [[REPLACEMENT_LOW:%.*]], i32 [[REPLACEMENT_HIGH:%.*]]
-; CHECK-NEXT:    [[T2:%.*]] = icmp ult i32 [[X]], 129
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[T2]], i32 [[T1]], i32 [[X]]
+; CHECK-NEXT:    [[T2:%.*]] = icmp ult i32 [[X:%.*]], 129
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[T2]], i32 [[REPLACEMENT_HIGH:%.*]], i32 [[X]]
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %t0 = icmp slt i32 %x, 0

diff  --git a/llvm/test/Transforms/InstCombine/nested-select.ll b/llvm/test/Transforms/InstCombine/nested-select.ll
index 42a0f81e7b85a2..d01dcf0793ade2 100644
--- a/llvm/test/Transforms/InstCombine/nested-select.ll
+++ b/llvm/test/Transforms/InstCombine/nested-select.ll
@@ -498,3 +498,94 @@ define i1 @orcond.111.inv.all.conds(i1 %inner.cond, i1 %alt.cond, i1 %inner.sel.
   %outer.sel = select i1 %not.outer.cond, i1 true, i1 %inner.sel
   ret i1 %outer.sel
 }
+
+define i8 @test_implied_true(i8 %x) { ; x < 0 implies x < 10, so %sel1 folds to its true arm (0)
+; CHECK-LABEL: @test_implied_true(
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp slt i8 [[X:%.*]], 0
+; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[CMP2]], i8 0, i8 20
+; CHECK-NEXT:    ret i8 [[SEL2]]
+;
+  %cmp1 = icmp slt i8 %x, 10
+  %cmp2 = icmp slt i8 %x, 0
+  %sel1 = select i1 %cmp1, i8 0, i8 5
+  %sel2 = select i1 %cmp2, i8 %sel1, i8 20
+  ret i8 %sel2
+}
+
+define <2 x i8> @test_implied_true_vec(<2 x i8> %x) { ; vector form of @test_implied_true: x < 0 implies x < 10 in every lane
+; CHECK-LABEL: @test_implied_true_vec(
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp slt <2 x i8> [[X:%.*]], zeroinitializer
+; CHECK-NEXT:    [[SEL2:%.*]] = select <2 x i1> [[CMP2]], <2 x i8> zeroinitializer, <2 x i8> <i8 20, i8 20>
+; CHECK-NEXT:    ret <2 x i8> [[SEL2]]
+;
+  %cmp1 = icmp slt <2 x i8> %x, <i8 10, i8 10>
+  %cmp2 = icmp slt <2 x i8> %x, zeroinitializer
+  %sel1 = select <2 x i1> %cmp1, <2 x i8> zeroinitializer, <2 x i8> <i8 5, i8 5>
+  %sel2 = select <2 x i1> %cmp2, <2 x i8> %sel1, <2 x i8> <i8 20, i8 20>
+  ret <2 x i8> %sel2
+}
+
+define i8 @test_implied_true_falseval(i8 %x) { ; %sel1 is the false arm: !(x > 0) implies x < 10, so %sel1 folds to 0
+; CHECK-LABEL: @test_implied_true_falseval(
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp sgt i8 [[X:%.*]], 0
+; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[CMP2]], i8 20, i8 0
+; CHECK-NEXT:    ret i8 [[SEL2]]
+;
+  %cmp1 = icmp slt i8 %x, 10
+  %cmp2 = icmp sgt i8 %x, 0
+  %sel1 = select i1 %cmp1, i8 0, i8 5
+  %sel2 = select i1 %cmp2, i8 20, i8 %sel1
+  ret i8 %sel2
+}
+
+define i8 @test_implied_false(i8 %x) { ; x < 0 implies !(x > 10), so %sel1 folds to its false arm (5)
+; CHECK-LABEL: @test_implied_false(
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp slt i8 [[X:%.*]], 0
+; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[CMP2]], i8 5, i8 20
+; CHECK-NEXT:    ret i8 [[SEL2]]
+;
+  %cmp1 = icmp sgt i8 %x, 10
+  %cmp2 = icmp slt i8 %x, 0
+  %sel1 = select i1 %cmp1, i8 0, i8 5
+  %sel2 = select i1 %cmp2, i8 %sel1, i8 20
+  ret i8 %sel2
+}
+
+; Negative tests
+
+define i8 @test_imply_fail(i8 %x) { ; x < 0 implies neither x < -10 nor its negation, so nothing folds
+; CHECK-LABEL: @test_imply_fail(
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp slt i8 [[X:%.*]], -10
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp slt i8 [[X]], 0
+; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[CMP1]], i8 0, i8 5
+; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[CMP2]], i8 [[SEL1]], i8 20
+; CHECK-NEXT:    ret i8 [[SEL2]]
+;
+  %cmp1 = icmp slt i8 %x, -10
+  %cmp2 = icmp slt i8 %x, 0
+  %sel1 = select i1 %cmp1, i8 0, i8 5
+  %sel2 = select i1 %cmp2, i8 %sel1, i8 20
+  ret i8 %sel2
+}
+
+define <2 x i8> @test_imply_type_mismatch(<2 x i8> %x, i8 %y) { ; inner <2 x i1> condition vs outer i1 condition: types differ, so nothing folds
+; CHECK-LABEL: @test_imply_type_mismatch(
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp slt <2 x i8> [[X:%.*]], <i8 10, i8 10>
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp slt i8 [[Y:%.*]], 0
+; CHECK-NEXT:    [[SEL1:%.*]] = select <2 x i1> [[CMP1]], <2 x i8> zeroinitializer, <2 x i8> <i8 5, i8 5>
+; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[CMP2]], <2 x i8> [[SEL1]], <2 x i8> <i8 20, i8 20>
+; CHECK-NEXT:    ret <2 x i8> [[SEL2]]
+;
+  %cmp1 = icmp slt <2 x i8> %x, <i8 10, i8 10>
+  %cmp2 = icmp slt i8 %y, 0
+  %sel1 = select <2 x i1> %cmp1, <2 x i8> zeroinitializer, <2 x i8> <i8 5, i8 5>
+  %sel2 = select i1 %cmp2, <2 x i8> %sel1, <2 x i8> <i8 20, i8 20>
+  ret <2 x i8> %sel2
+}
+
+define <4 x i1> @test_dont_crash(i1 %cond, <4 x i1> %a, <4 x i1> %b) { ; i1-conditioned select used as a <4 x i1> 'and' operand; no CHECK lines, so presumably this only guards against a crash
+entry:
+  %sel = select i1 %cond, <4 x i1> %a, <4 x i1> zeroinitializer
+  %and = and <4 x i1> %sel, %b
+  ret <4 x i1> %and
+}


        


More information about the llvm-commits mailing list