[llvm] [InstCombine] Simplify switch with selects (PR #84143)

via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 6 01:00:10 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Yingwei Zheng (dtcxzyw)

<details>
<summary>Changes</summary>

An example from https://github.com/image-rs/image:
```
define void @test_ult_rhsc(i8 %x) {
  %val = add nsw i8 %x, -2
  %cmp = icmp ult i8 %val, 11
  %cond = select i1 %cmp, i8 %val, i8 6
  switch i8 %cond, label %bb1 [
  i8 0, label %bb2
  i8 10, label %bb3
  ]

bb1:
  call void @func1()
  unreachable
bb2:
  call void @func2()
  unreachable
bb3:
  call void @func3()
  unreachable
}
```

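When `%cmp` evaluates to false, `%val` must lie in the range [11, umax], which contains neither case value (0 or 10). Thus we can safely replace `%cond` with `%val`: on that path both `switch i8 6` and `switch i8 %val` go to the default destination `%bb1`, and when `%cmp` is true `%cond` already equals `%val`.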

Alive2: https://alive2.llvm.org/ce/z/uSTj6w

This patch will benefit many Rust applications and some C/C++ applications (e.g., cvc5).


---
Full diff: https://github.com/llvm/llvm-project/pull/84143.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/InstCombine/InstructionCombining.cpp (+41) 
- (added) llvm/test/Transforms/InstCombine/switch-select.ll (+159) 


``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 80ce0c9275b2cb..3d231f9fc7ff4b 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -3334,6 +3334,37 @@ Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) {
   return nullptr;
 }
 
+// Replaces (switch (select cond, X, C)/(select cond, C, X)) with (switch X) if
+// we can prove that both (switch C) and (switch X) go to the default when cond
+// is false/true.
+static Value *simplifySwitchOnSelectUsingRanges(SwitchInst &SI,
+                                                SelectInst *Select,
+                                                unsigned CstOpIdx) {
+  auto *C = dyn_cast<ConstantInt>(Select->getOperand(CstOpIdx));
+  if (!C)
+    return nullptr;
+
+  BasicBlock *CstBB = SI.findCaseValue(C)->getCaseSuccessor();
+  if (CstBB != SI.getDefaultDest())
+    return nullptr;
+  Value *X = Select->getOperand(3 - CstOpIdx);
+  ICmpInst::Predicate Pred;
+  const APInt *RHSC;
+  if (!match(Select->getCondition(),
+             m_ICmp(Pred, m_Specific(X), m_APInt(RHSC))))
+    return nullptr;
+  if (CstOpIdx == 1)
+    Pred = ICmpInst::getInversePredicate(Pred);
+
+  // See whether we can replace the select with X
+  ConstantRange CR = ConstantRange::makeExactICmpRegion(Pred, *RHSC);
+  for (auto Case : SI.cases())
+    if (!CR.contains(Case.getCaseValue()->getValue()))
+      return nullptr;
+
+  return X;
+}
+
 Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
   Value *Cond = SI.getCondition();
   Value *Op0;
@@ -3407,6 +3438,16 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
     }
   }
 
+  // Fold switch(select cond, X, Y) into switch(X/Y) if possible
+  if (auto *Select = dyn_cast<SelectInst>(Cond)) {
+    if (Value *V =
+            simplifySwitchOnSelectUsingRanges(SI, Select, /*CstOpIdx=*/1))
+      return replaceOperand(SI, 0, V);
+    if (Value *V =
+            simplifySwitchOnSelectUsingRanges(SI, Select, /*CstOpIdx=*/2))
+      return replaceOperand(SI, 0, V);
+  }
+
   KnownBits Known = computeKnownBits(Cond, 0, &SI);
   unsigned LeadingKnownZeros = Known.countMinLeadingZeros();
   unsigned LeadingKnownOnes = Known.countMinLeadingOnes();
diff --git a/llvm/test/Transforms/InstCombine/switch-select.ll b/llvm/test/Transforms/InstCombine/switch-select.ll
new file mode 100644
index 00000000000000..60757c5d22527f
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/switch-select.ll
@@ -0,0 +1,159 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define void @test_ult_rhsc(i8 %x) {
+; CHECK-LABEL: define void @test_ult_rhsc(
+; CHECK-SAME: i8 [[X:%.*]]) {
+; CHECK-NEXT:    switch i8 [[X]], label [[BB1:%.*]] [
+; CHECK-NEXT:      i8 2, label [[BB2:%.*]]
+; CHECK-NEXT:      i8 12, label [[BB3:%.*]]
+; CHECK-NEXT:    ]
+; CHECK:       bb1:
+; CHECK-NEXT:    call void @func1()
+; CHECK-NEXT:    unreachable
+; CHECK:       bb2:
+; CHECK-NEXT:    call void @func2()
+; CHECK-NEXT:    unreachable
+; CHECK:       bb3:
+; CHECK-NEXT:    call void @func3()
+; CHECK-NEXT:    unreachable
+;
+  %val = add nsw i8 %x, -2
+  %cmp = icmp ult i8 %val, 11
+  %cond = select i1 %cmp, i8 %val, i8 6
+  switch i8 %cond, label %bb1 [
+  i8 0, label %bb2
+  i8 10, label %bb3
+  ]
+
+bb1:
+  call void @func1()
+  unreachable
+bb2:
+  call void @func2()
+  unreachable
+bb3:
+  call void @func3()
+  unreachable
+}
+
+define void @test_eq_lhsc(i8 %x) {
+; CHECK-LABEL: define void @test_eq_lhsc(
+; CHECK-SAME: i8 [[X:%.*]]) {
+; CHECK-NEXT:    switch i8 [[X]], label [[BB1:%.*]] [
+; CHECK-NEXT:      i8 0, label [[BB2:%.*]]
+; CHECK-NEXT:      i8 10, label [[BB3:%.*]]
+; CHECK-NEXT:    ]
+; CHECK:       bb1:
+; CHECK-NEXT:    call void @func1()
+; CHECK-NEXT:    unreachable
+; CHECK:       bb2:
+; CHECK-NEXT:    call void @func2()
+; CHECK-NEXT:    unreachable
+; CHECK:       bb3:
+; CHECK-NEXT:    call void @func3()
+; CHECK-NEXT:    unreachable
+;
+  %cmp = icmp eq i8 %x, 4
+  %cond = select i1 %cmp, i8 6, i8 %x
+  switch i8 %cond, label %bb1 [
+  i8 0, label %bb2
+  i8 10, label %bb3
+  ]
+
+bb1:
+  call void @func1()
+  unreachable
+bb2:
+  call void @func2()
+  unreachable
+bb3:
+  call void @func3()
+  unreachable
+}
+
+define void @test_ult_rhsc_invalid_cond(i8 %x, i8 %y) {
+; CHECK-LABEL: define void @test_ult_rhsc_invalid_cond(
+; CHECK-SAME: i8 [[X:%.*]], i8 [[Y:%.*]]) {
+; CHECK-NEXT:    [[VAL:%.*]] = add nsw i8 [[X]], -2
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i8 [[Y]], 11
+; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i8 [[VAL]], i8 6
+; CHECK-NEXT:    switch i8 [[COND]], label [[BB1:%.*]] [
+; CHECK-NEXT:      i8 0, label [[BB2:%.*]]
+; CHECK-NEXT:      i8 10, label [[BB3:%.*]]
+; CHECK-NEXT:      i8 13, label [[BB3]]
+; CHECK-NEXT:    ]
+; CHECK:       bb1:
+; CHECK-NEXT:    call void @func1()
+; CHECK-NEXT:    unreachable
+; CHECK:       bb2:
+; CHECK-NEXT:    call void @func2()
+; CHECK-NEXT:    unreachable
+; CHECK:       bb3:
+; CHECK-NEXT:    call void @func3()
+; CHECK-NEXT:    unreachable
+;
+  %val = add nsw i8 %x, -2
+  %cmp = icmp ult i8 %y, 11
+  %cond = select i1 %cmp, i8 %val, i8 6
+  switch i8 %cond, label %bb1 [
+  i8 0, label %bb2
+  i8 10, label %bb3
+  i8 13, label %bb3
+  ]
+
+bb1:
+  call void @func1()
+  unreachable
+bb2:
+  call void @func2()
+  unreachable
+bb3:
+  call void @func3()
+  unreachable
+}
+
+define void @test_ult_rhsc_fail(i8 %x) {
+; CHECK-LABEL: define void @test_ult_rhsc_fail(
+; CHECK-SAME: i8 [[X:%.*]]) {
+; CHECK-NEXT:    [[VAL:%.*]] = add nsw i8 [[X]], -2
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i8 [[VAL]], 11
+; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i8 [[VAL]], i8 6
+; CHECK-NEXT:    switch i8 [[COND]], label [[BB1:%.*]] [
+; CHECK-NEXT:      i8 0, label [[BB2:%.*]]
+; CHECK-NEXT:      i8 10, label [[BB3:%.*]]
+; CHECK-NEXT:      i8 13, label [[BB3]]
+; CHECK-NEXT:    ]
+; CHECK:       bb1:
+; CHECK-NEXT:    call void @func1()
+; CHECK-NEXT:    unreachable
+; CHECK:       bb2:
+; CHECK-NEXT:    call void @func2()
+; CHECK-NEXT:    unreachable
+; CHECK:       bb3:
+; CHECK-NEXT:    call void @func3()
+; CHECK-NEXT:    unreachable
+;
+  %val = add nsw i8 %x, -2
+  %cmp = icmp ult i8 %val, 11
+  %cond = select i1 %cmp, i8 %val, i8 6
+  switch i8 %cond, label %bb1 [
+  i8 0, label %bb2
+  i8 10, label %bb3
+  i8 13, label %bb3
+  ]
+
+bb1:
+  call void @func1()
+  unreachable
+bb2:
+  call void @func2()
+  unreachable
+bb3:
+  call void @func3()
+  unreachable
+}
+
+declare void @func1()
+declare void @func2()
+declare void @func3()

``````````

</details>


https://github.com/llvm/llvm-project/pull/84143


More information about the llvm-commits mailing list