[llvm] [InstCombine] Simplify switch with selects (PR #84143)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 6 00:59:43 PST 2024


https://github.com/dtcxzyw created https://github.com/llvm/llvm-project/pull/84143

An example from https://github.com/image-rs/image:
```
define void @test_ult_rhsc(i8 %x) {
  %val = add nsw i8 %x, -2
  %cmp = icmp ult i8 %val, 11
  %cond = select i1 %cmp, i8 %val, i8 6
  switch i8 %cond, label %bb1 [
  i8 0, label %bb2
  i8 10, label %bb3
  ]

bb1:
  call void @func1()
  unreachable
bb2:
  call void @func2()
  unreachable
bb3:
  call void @func3()
  unreachable
}
```

When `%cmp` evaluates to false, we can prove that `%val` lies in the range [11, umax], which contains none of the case values (0 and 10). Thus we can safely replace `%cond` with `%val`, since in that case both `switch 6` and `switch %val` go to the default dest `%bb1`.

Alive2: https://alive2.llvm.org/ce/z/uSTj6w

This patch will benefit many Rust applications and some C/C++ applications (e.g., cvc5).


>From 4e5e9223a51a2d5c2fea95b149d48ced12c895fe Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Wed, 6 Mar 2024 16:43:52 +0800
Subject: [PATCH 1/2] [InstCombine] Add pre-commit tests. NFC.

---
 .../Transforms/InstCombine/switch-select.ll   | 164 ++++++++++++++++++
 1 file changed, 164 insertions(+)
 create mode 100644 llvm/test/Transforms/InstCombine/switch-select.ll

diff --git a/llvm/test/Transforms/InstCombine/switch-select.ll b/llvm/test/Transforms/InstCombine/switch-select.ll
new file mode 100644
index 00000000000000..86375ab6520454
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/switch-select.ll
@@ -0,0 +1,164 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define void @test_ult_rhsc(i8 %x) {
+; CHECK-LABEL: define void @test_ult_rhsc(
+; CHECK-SAME: i8 [[X:%.*]]) {
+; CHECK-NEXT:    [[VAL:%.*]] = add nsw i8 [[X]], -2
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i8 [[VAL]], 11
+; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i8 [[VAL]], i8 6
+; CHECK-NEXT:    switch i8 [[COND]], label [[BB1:%.*]] [
+; CHECK-NEXT:      i8 0, label [[BB2:%.*]]
+; CHECK-NEXT:      i8 10, label [[BB3:%.*]]
+; CHECK-NEXT:    ]
+; CHECK:       bb1:
+; CHECK-NEXT:    call void @func1()
+; CHECK-NEXT:    unreachable
+; CHECK:       bb2:
+; CHECK-NEXT:    call void @func2()
+; CHECK-NEXT:    unreachable
+; CHECK:       bb3:
+; CHECK-NEXT:    call void @func3()
+; CHECK-NEXT:    unreachable
+;
+  %val = add nsw i8 %x, -2
+  %cmp = icmp ult i8 %val, 11
+  %cond = select i1 %cmp, i8 %val, i8 6
+  switch i8 %cond, label %bb1 [
+  i8 0, label %bb2
+  i8 10, label %bb3
+  ]
+
+bb1:
+  call void @func1()
+  unreachable
+bb2:
+  call void @func2()
+  unreachable
+bb3:
+  call void @func3()
+  unreachable
+}
+
+define void @test_eq_lhsc(i8 %x) {
+; CHECK-LABEL: define void @test_eq_lhsc(
+; CHECK-SAME: i8 [[X:%.*]]) {
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[X]], 4
+; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i8 6, i8 [[X]]
+; CHECK-NEXT:    switch i8 [[COND]], label [[BB1:%.*]] [
+; CHECK-NEXT:      i8 0, label [[BB2:%.*]]
+; CHECK-NEXT:      i8 10, label [[BB3:%.*]]
+; CHECK-NEXT:    ]
+; CHECK:       bb1:
+; CHECK-NEXT:    call void @func1()
+; CHECK-NEXT:    unreachable
+; CHECK:       bb2:
+; CHECK-NEXT:    call void @func2()
+; CHECK-NEXT:    unreachable
+; CHECK:       bb3:
+; CHECK-NEXT:    call void @func3()
+; CHECK-NEXT:    unreachable
+;
+  %cmp = icmp eq i8 %x, 4
+  %cond = select i1 %cmp, i8 6, i8 %x
+  switch i8 %cond, label %bb1 [
+  i8 0, label %bb2
+  i8 10, label %bb3
+  ]
+
+bb1:
+  call void @func1()
+  unreachable
+bb2:
+  call void @func2()
+  unreachable
+bb3:
+  call void @func3()
+  unreachable
+}
+
+define void @test_ult_rhsc_invalid_cond(i8 %x, i8 %y) {
+; CHECK-LABEL: define void @test_ult_rhsc_invalid_cond(
+; CHECK-SAME: i8 [[X:%.*]], i8 [[Y:%.*]]) {
+; CHECK-NEXT:    [[VAL:%.*]] = add nsw i8 [[X]], -2
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i8 [[Y]], 11
+; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i8 [[VAL]], i8 6
+; CHECK-NEXT:    switch i8 [[COND]], label [[BB1:%.*]] [
+; CHECK-NEXT:      i8 0, label [[BB2:%.*]]
+; CHECK-NEXT:      i8 10, label [[BB3:%.*]]
+; CHECK-NEXT:      i8 13, label [[BB3]]
+; CHECK-NEXT:    ]
+; CHECK:       bb1:
+; CHECK-NEXT:    call void @func1()
+; CHECK-NEXT:    unreachable
+; CHECK:       bb2:
+; CHECK-NEXT:    call void @func2()
+; CHECK-NEXT:    unreachable
+; CHECK:       bb3:
+; CHECK-NEXT:    call void @func3()
+; CHECK-NEXT:    unreachable
+;
+  %val = add nsw i8 %x, -2
+  %cmp = icmp ult i8 %y, 11
+  %cond = select i1 %cmp, i8 %val, i8 6
+  switch i8 %cond, label %bb1 [
+  i8 0, label %bb2
+  i8 10, label %bb3
+  i8 13, label %bb3
+  ]
+
+bb1:
+  call void @func1()
+  unreachable
+bb2:
+  call void @func2()
+  unreachable
+bb3:
+  call void @func3()
+  unreachable
+}
+
+define void @test_ult_rhsc_fail(i8 %x) {
+; CHECK-LABEL: define void @test_ult_rhsc_fail(
+; CHECK-SAME: i8 [[X:%.*]]) {
+; CHECK-NEXT:    [[VAL:%.*]] = add nsw i8 [[X]], -2
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i8 [[VAL]], 11
+; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i8 [[VAL]], i8 6
+; CHECK-NEXT:    switch i8 [[COND]], label [[BB1:%.*]] [
+; CHECK-NEXT:      i8 0, label [[BB2:%.*]]
+; CHECK-NEXT:      i8 10, label [[BB3:%.*]]
+; CHECK-NEXT:      i8 13, label [[BB3]]
+; CHECK-NEXT:    ]
+; CHECK:       bb1:
+; CHECK-NEXT:    call void @func1()
+; CHECK-NEXT:    unreachable
+; CHECK:       bb2:
+; CHECK-NEXT:    call void @func2()
+; CHECK-NEXT:    unreachable
+; CHECK:       bb3:
+; CHECK-NEXT:    call void @func3()
+; CHECK-NEXT:    unreachable
+;
+  %val = add nsw i8 %x, -2
+  %cmp = icmp ult i8 %val, 11
+  %cond = select i1 %cmp, i8 %val, i8 6
+  switch i8 %cond, label %bb1 [
+  i8 0, label %bb2
+  i8 10, label %bb3
+  i8 13, label %bb3
+  ]
+
+bb1:
+  call void @func1()
+  unreachable
+bb2:
+  call void @func2()
+  unreachable
+bb3:
+  call void @func3()
+  unreachable
+}
+
+declare void @func1()
+declare void @func2()
+declare void @func3()

>From f782c5e659ecda4666be0f3af10dce099e28c2d2 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Wed, 6 Mar 2024 16:44:45 +0800
Subject: [PATCH 2/2] [InstCombine] Simplify switch with selects

---
 .../InstCombine/InstructionCombining.cpp      | 41 +++++++++++++++++++
 .../Transforms/InstCombine/switch-select.ll   | 13 ++----
 2 files changed, 45 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 80ce0c9275b2cb..3d231f9fc7ff4b 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -3334,6 +3334,37 @@ Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) {
   return nullptr;
 }
 
+// Folds switch (select Cond, X, C) or switch (select Cond, C, X) to switch X
+// when C reaches the default dest and every case value lies in the icmp region
+// where X is selected. CstOpIdx (1 or 2) is C's operand index; null if no fold.
+static Value *simplifySwitchOnSelectUsingRanges(SwitchInst &SI,
+                                                SelectInst *Select,
+                                                unsigned CstOpIdx) {
+  auto *C = dyn_cast<ConstantInt>(Select->getOperand(CstOpIdx));
+  if (!C)
+    return nullptr;
+
+  BasicBlock *CstBB = SI.findCaseValue(C)->getCaseSuccessor();
+  if (CstBB != SI.getDefaultDest())
+    return nullptr;
+  Value *X = Select->getOperand(3 - CstOpIdx);
+  ICmpInst::Predicate Pred;
+  const APInt *RHSC;
+  if (!match(Select->getCondition(),
+             m_ICmp(Pred, m_Specific(X), m_APInt(RHSC))))
+    return nullptr;
+  if (CstOpIdx == 1)
+    Pred = ICmpInst::getInversePredicate(Pred);
+
+  // X may replace the select only if every case value lies inside CR.
+  ConstantRange CR = ConstantRange::makeExactICmpRegion(Pred, *RHSC);
+  for (auto Case : SI.cases())
+    if (!CR.contains(Case.getCaseValue()->getValue()))
+      return nullptr;
+
+  return X;
+}
+
 Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
   Value *Cond = SI.getCondition();
   Value *Op0;
@@ -3407,6 +3438,16 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
     }
   }
 
+  // Fold switch(select cond, X, Y) into switch(X/Y) if possible
+  if (auto *Select = dyn_cast<SelectInst>(Cond)) {
+    if (Value *V =
+            simplifySwitchOnSelectUsingRanges(SI, Select, /*CstOpIdx=*/1))
+      return replaceOperand(SI, 0, V);
+    if (Value *V =
+            simplifySwitchOnSelectUsingRanges(SI, Select, /*CstOpIdx=*/2))
+      return replaceOperand(SI, 0, V);
+  }
+
   KnownBits Known = computeKnownBits(Cond, 0, &SI);
   unsigned LeadingKnownZeros = Known.countMinLeadingZeros();
   unsigned LeadingKnownOnes = Known.countMinLeadingOnes();
diff --git a/llvm/test/Transforms/InstCombine/switch-select.ll b/llvm/test/Transforms/InstCombine/switch-select.ll
index 86375ab6520454..60757c5d22527f 100644
--- a/llvm/test/Transforms/InstCombine/switch-select.ll
+++ b/llvm/test/Transforms/InstCombine/switch-select.ll
@@ -4,12 +4,9 @@
 define void @test_ult_rhsc(i8 %x) {
 ; CHECK-LABEL: define void @test_ult_rhsc(
 ; CHECK-SAME: i8 [[X:%.*]]) {
-; CHECK-NEXT:    [[VAL:%.*]] = add nsw i8 [[X]], -2
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i8 [[VAL]], 11
-; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i8 [[VAL]], i8 6
-; CHECK-NEXT:    switch i8 [[COND]], label [[BB1:%.*]] [
-; CHECK-NEXT:      i8 0, label [[BB2:%.*]]
-; CHECK-NEXT:      i8 10, label [[BB3:%.*]]
+; CHECK-NEXT:    switch i8 [[X]], label [[BB1:%.*]] [
+; CHECK-NEXT:      i8 2, label [[BB2:%.*]]
+; CHECK-NEXT:      i8 12, label [[BB3:%.*]]
 ; CHECK-NEXT:    ]
 ; CHECK:       bb1:
 ; CHECK-NEXT:    call void @func1()
@@ -43,9 +40,7 @@ bb3:
 define void @test_eq_lhsc(i8 %x) {
 ; CHECK-LABEL: define void @test_eq_lhsc(
 ; CHECK-SAME: i8 [[X:%.*]]) {
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[X]], 4
-; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i8 6, i8 [[X]]
-; CHECK-NEXT:    switch i8 [[COND]], label [[BB1:%.*]] [
+; CHECK-NEXT:    switch i8 [[X]], label [[BB1:%.*]] [
 ; CHECK-NEXT:      i8 0, label [[BB2:%.*]]
 ; CHECK-NEXT:      i8 10, label [[BB3:%.*]]
 ; CHECK-NEXT:    ]



More information about the llvm-commits mailing list