[llvm] 5fe1466 - [InstCombine] Simplify switch with selects (#84143)

via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 15 01:40:20 PDT 2024


Author: Yingwei Zheng
Date: 2024-04-15T16:40:16+08:00
New Revision: 5fe146672d2b1c9f257a6ee045e0bc13fed1e504

URL: https://github.com/llvm/llvm-project/commit/5fe146672d2b1c9f257a6ee045e0bc13fed1e504
DIFF: https://github.com/llvm/llvm-project/commit/5fe146672d2b1c9f257a6ee045e0bc13fed1e504.diff

LOG: [InstCombine] Simplify switch with selects (#84143)

An example from https://github.com/image-rs/image:
```
define void @test_ult_rhsc(i8 %x) {
  %val = add nsw i8 %x, -2
  %cmp = icmp ult i8 %val, 11
  %cond = select i1 %cmp, i8 %val, i8 6
  switch i8 %cond, label %bb1 [
  i8 0, label %bb2
  i8 10, label %bb3
  ]

bb1:
  call void @func1()
  unreachable
bb2:
  call void @func2()
  unreachable
bb3:
  call void @func3()
  unreachable
}
```

When `%cmp` evaluates to false, we can prove that the range of `%val` is
[11, umax], which contains none of the case values (0 and 10). Thus we can
safely replace `%cond` with `%val`, since in that situation both `switch 6`
and `switch %val` go to the default dest `%bb1`.

Alive2: https://alive2.llvm.org/ce/z/uSTj6w
Godbolt: https://godbolt.org/z/MGrG84bzr

This patch will benefit many Rust applications and some C/C++
applications (e.g., cvc5).

Added: 
    llvm/test/Transforms/InstCombine/switch-select.ll

Modified: 
    llvm/lib/Transforms/InstCombine/InstructionCombining.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 4cfafa7ac80060..4c00f2a0ea1761 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -3572,6 +3572,38 @@ Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) {
   return nullptr;
 }
 
+// Fold switch(select Cond, X, C) / switch(select Cond, C, X) to switch(X) when
+// C goes to the default dest and X matches no case whenever C is selected.
+// IsTrueArm picks the arm checked for C; returns X on success, else nullptr.
+static Value *simplifySwitchOnSelectUsingRanges(SwitchInst &SI,
+                                                SelectInst *Select,
+                                                bool IsTrueArm) {
+  unsigned CstOpIdx = IsTrueArm ? 1 : 2; // index of the arm tested as C
+  auto *C = dyn_cast<ConstantInt>(Select->getOperand(CstOpIdx));
+  if (!C)
+    return nullptr;
+
+  BasicBlock *CstBB = SI.findCaseValue(C)->getCaseSuccessor();
+  if (CstBB != SI.getDefaultDest())
+    return nullptr;
+  Value *X = Select->getOperand(3 - CstOpIdx); // the other select arm
+  ICmpInst::Predicate Pred;
+  const APInt *RHSC;
+  if (!match(Select->getCondition(),
+             m_ICmp(Pred, m_Specific(X), m_APInt(RHSC))))
+    return nullptr;
+  if (IsTrueArm)
+    Pred = ICmpInst::getInversePredicate(Pred);
+
+  // CR = values of X for which the select yields X; all cases must be inside.
+  ConstantRange CR = ConstantRange::makeExactICmpRegion(Pred, *RHSC);
+  for (auto Case : SI.cases())
+    if (!CR.contains(Case.getCaseValue()->getValue()))
+      return nullptr;
+
+  return X;
+}
+
 Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
   Value *Cond = SI.getCondition();
   Value *Op0;
@@ -3645,6 +3677,16 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
     }
   }
 
+  // Fold switch(select cond, X, Y) into switch(X/Y) if possible
+  if (auto *Select = dyn_cast<SelectInst>(Cond)) {
+    if (Value *V =
+            simplifySwitchOnSelectUsingRanges(SI, Select, /*IsTrueArm=*/true))
+      return replaceOperand(SI, 0, V);
+    if (Value *V =
+            simplifySwitchOnSelectUsingRanges(SI, Select, /*IsTrueArm=*/false))
+      return replaceOperand(SI, 0, V);
+  }
+
   KnownBits Known = computeKnownBits(Cond, 0, &SI);
   unsigned LeadingKnownZeros = Known.countMinLeadingZeros();
   unsigned LeadingKnownOnes = Known.countMinLeadingOnes();

diff  --git a/llvm/test/Transforms/InstCombine/switch-select.ll b/llvm/test/Transforms/InstCombine/switch-select.ll
new file mode 100644
index 00000000000000..60757c5d22527f
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/switch-select.ll
@@ -0,0 +1,159 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define void @test_ult_rhsc(i8 %x) {
+; CHECK-LABEL: define void @test_ult_rhsc(
+; CHECK-SAME: i8 [[X:%.*]]) {
+; CHECK-NEXT:    switch i8 [[X]], label [[BB1:%.*]] [
+; CHECK-NEXT:      i8 2, label [[BB2:%.*]]
+; CHECK-NEXT:      i8 12, label [[BB3:%.*]]
+; CHECK-NEXT:    ]
+; CHECK:       bb1:
+; CHECK-NEXT:    call void @func1()
+; CHECK-NEXT:    unreachable
+; CHECK:       bb2:
+; CHECK-NEXT:    call void @func2()
+; CHECK-NEXT:    unreachable
+; CHECK:       bb3:
+; CHECK-NEXT:    call void @func3()
+; CHECK-NEXT:    unreachable
+; Positive test: cases 0 and 10 lie in [0, 11), so the select is dropped.
+  %val = add nsw i8 %x, -2
+  %cmp = icmp ult i8 %val, 11
+  %cond = select i1 %cmp, i8 %val, i8 6
+  switch i8 %cond, label %bb1 [
+  i8 0, label %bb2
+  i8 10, label %bb3
+  ]
+
+bb1:
+  call void @func1()
+  unreachable
+bb2:
+  call void @func2()
+  unreachable
+bb3:
+  call void @func3()
+  unreachable
+}
+
+define void @test_eq_lhsc(i8 %x) {
+; CHECK-LABEL: define void @test_eq_lhsc(
+; CHECK-SAME: i8 [[X:%.*]]) {
+; CHECK-NEXT:    switch i8 [[X]], label [[BB1:%.*]] [
+; CHECK-NEXT:      i8 0, label [[BB2:%.*]]
+; CHECK-NEXT:      i8 10, label [[BB3:%.*]]
+; CHECK-NEXT:    ]
+; CHECK:       bb1:
+; CHECK-NEXT:    call void @func1()
+; CHECK-NEXT:    unreachable
+; CHECK:       bb2:
+; CHECK-NEXT:    call void @func2()
+; CHECK-NEXT:    unreachable
+; CHECK:       bb3:
+; CHECK-NEXT:    call void @func3()
+; CHECK-NEXT:    unreachable
+; Positive test: constant on the true arm; cases 0 and 10 both differ from 4.
+  %cmp = icmp eq i8 %x, 4
+  %cond = select i1 %cmp, i8 6, i8 %x
+  switch i8 %cond, label %bb1 [
+  i8 0, label %bb2
+  i8 10, label %bb3
+  ]
+
+bb1:
+  call void @func1()
+  unreachable
+bb2:
+  call void @func2()
+  unreachable
+bb3:
+  call void @func3()
+  unreachable
+}
+
+define void @test_ult_rhsc_invalid_cond(i8 %x, i8 %y) {
+; CHECK-LABEL: define void @test_ult_rhsc_invalid_cond(
+; CHECK-SAME: i8 [[X:%.*]], i8 [[Y:%.*]]) {
+; CHECK-NEXT:    [[VAL:%.*]] = add nsw i8 [[X]], -2
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i8 [[Y]], 11
+; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i8 [[VAL]], i8 6
+; CHECK-NEXT:    switch i8 [[COND]], label [[BB1:%.*]] [
+; CHECK-NEXT:      i8 0, label [[BB2:%.*]]
+; CHECK-NEXT:      i8 10, label [[BB3:%.*]]
+; CHECK-NEXT:      i8 13, label [[BB3]]
+; CHECK-NEXT:    ]
+; CHECK:       bb1:
+; CHECK-NEXT:    call void @func1()
+; CHECK-NEXT:    unreachable
+; CHECK:       bb2:
+; CHECK-NEXT:    call void @func2()
+; CHECK-NEXT:    unreachable
+; CHECK:       bb3:
+; CHECK-NEXT:    call void @func3()
+; CHECK-NEXT:    unreachable
+; Negative test: the compare is on %y, not the selected %val, so no fold.
+  %val = add nsw i8 %x, -2
+  %cmp = icmp ult i8 %y, 11
+  %cond = select i1 %cmp, i8 %val, i8 6
+  switch i8 %cond, label %bb1 [
+  i8 0, label %bb2
+  i8 10, label %bb3
+  i8 13, label %bb3
+  ]
+
+bb1:
+  call void @func1()
+  unreachable
+bb2:
+  call void @func2()
+  unreachable
+bb3:
+  call void @func3()
+  unreachable
+}
+
+define void @test_ult_rhsc_fail(i8 %x) {
+; CHECK-LABEL: define void @test_ult_rhsc_fail(
+; CHECK-SAME: i8 [[X:%.*]]) {
+; CHECK-NEXT:    [[VAL:%.*]] = add nsw i8 [[X]], -2
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i8 [[VAL]], 11
+; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i8 [[VAL]], i8 6
+; CHECK-NEXT:    switch i8 [[COND]], label [[BB1:%.*]] [
+; CHECK-NEXT:      i8 0, label [[BB2:%.*]]
+; CHECK-NEXT:      i8 10, label [[BB3:%.*]]
+; CHECK-NEXT:      i8 13, label [[BB3]]
+; CHECK-NEXT:    ]
+; CHECK:       bb1:
+; CHECK-NEXT:    call void @func1()
+; CHECK-NEXT:    unreachable
+; CHECK:       bb2:
+; CHECK-NEXT:    call void @func2()
+; CHECK-NEXT:    unreachable
+; CHECK:       bb3:
+; CHECK-NEXT:    call void @func3()
+; CHECK-NEXT:    unreachable
+; Negative test: case 13 lies outside [0, 11), so the select must stay.
+  %val = add nsw i8 %x, -2
+  %cmp = icmp ult i8 %val, 11
+  %cond = select i1 %cmp, i8 %val, i8 6
+  switch i8 %cond, label %bb1 [
+  i8 0, label %bb2
+  i8 10, label %bb3
+  i8 13, label %bb3
+  ]
+
+bb1:
+  call void @func1()
+  unreachable
+bb2:
+  call void @func2()
+  unreachable
+bb3:
+  call void @func3()
+  unreachable
+}
+
+declare void @func1()
+declare void @func2()
+declare void @func3()


        


More information about the llvm-commits mailing list