[llvm] [InstCombine] Simplify switch with selects (PR #84143)
Yingwei Zheng via llvm-commits
llvm-commits at lists.llvm.org
Wed Mar 6 00:59:43 PST 2024
https://github.com/dtcxzyw created https://github.com/llvm/llvm-project/pull/84143
An example from https://github.com/image-rs/image:
```
define void @test_ult_rhsc(i8 %x) {
%val = add nsw i8 %x, -2
%cmp = icmp ult i8 %val, 11
%cond = select i1 %cmp, i8 %val, i8 6
switch i8 %cond, label %bb1 [
i8 0, label %bb2
i8 10, label %bb3
]
bb1:
call void @func1()
unreachable
bb2:
call void @func2()
unreachable
bb3:
call void @func3()
unreachable
}
```
When `%cmp` evaluates to false, we can prove that the range of `%val` is [11, umax]. Neither case value (0 or 10) lies in that range, so `switch %val` goes to the default dest `%bb1`, exactly as `switch 6` does. Thus we can safely replace `%cond` with `%val`.
Alive2: https://alive2.llvm.org/ce/z/uSTj6w
This patch will benefit many Rust applications and some C/C++ applications (e.g., cvc5).
From 4e5e9223a51a2d5c2fea95b149d48ced12c895fe Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Wed, 6 Mar 2024 16:43:52 +0800
Subject: [PATCH 1/2] [InstCombine] Add pre-commit tests. NFC.
---
.../Transforms/InstCombine/switch-select.ll | 164 ++++++++++++++++++
1 file changed, 164 insertions(+)
create mode 100644 llvm/test/Transforms/InstCombine/switch-select.ll
diff --git a/llvm/test/Transforms/InstCombine/switch-select.ll b/llvm/test/Transforms/InstCombine/switch-select.ll
new file mode 100644
index 00000000000000..86375ab6520454
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/switch-select.ll
@@ -0,0 +1,164 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define void @test_ult_rhsc(i8 %x) {
+; CHECK-LABEL: define void @test_ult_rhsc(
+; CHECK-SAME: i8 [[X:%.*]]) {
+; CHECK-NEXT: [[VAL:%.*]] = add nsw i8 [[X]], -2
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[VAL]], 11
+; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 [[VAL]], i8 6
+; CHECK-NEXT: switch i8 [[COND]], label [[BB1:%.*]] [
+; CHECK-NEXT: i8 0, label [[BB2:%.*]]
+; CHECK-NEXT: i8 10, label [[BB3:%.*]]
+; CHECK-NEXT: ]
+; CHECK: bb1:
+; CHECK-NEXT: call void @func1()
+; CHECK-NEXT: unreachable
+; CHECK: bb2:
+; CHECK-NEXT: call void @func2()
+; CHECK-NEXT: unreachable
+; CHECK: bb3:
+; CHECK-NEXT: call void @func3()
+; CHECK-NEXT: unreachable
+;
+ %val = add nsw i8 %x, -2
+ %cmp = icmp ult i8 %val, 11
+ %cond = select i1 %cmp, i8 %val, i8 6
+ switch i8 %cond, label %bb1 [
+ i8 0, label %bb2
+ i8 10, label %bb3
+ ]
+
+bb1:
+ call void @func1()
+ unreachable
+bb2:
+ call void @func2()
+ unreachable
+bb3:
+ call void @func3()
+ unreachable
+}
+
+define void @test_eq_lhsc(i8 %x) {
+; CHECK-LABEL: define void @test_eq_lhsc(
+; CHECK-SAME: i8 [[X:%.*]]) {
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[X]], 4
+; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 6, i8 [[X]]
+; CHECK-NEXT: switch i8 [[COND]], label [[BB1:%.*]] [
+; CHECK-NEXT: i8 0, label [[BB2:%.*]]
+; CHECK-NEXT: i8 10, label [[BB3:%.*]]
+; CHECK-NEXT: ]
+; CHECK: bb1:
+; CHECK-NEXT: call void @func1()
+; CHECK-NEXT: unreachable
+; CHECK: bb2:
+; CHECK-NEXT: call void @func2()
+; CHECK-NEXT: unreachable
+; CHECK: bb3:
+; CHECK-NEXT: call void @func3()
+; CHECK-NEXT: unreachable
+;
+ %cmp = icmp eq i8 %x, 4
+ %cond = select i1 %cmp, i8 6, i8 %x
+ switch i8 %cond, label %bb1 [
+ i8 0, label %bb2
+ i8 10, label %bb3
+ ]
+
+bb1:
+ call void @func1()
+ unreachable
+bb2:
+ call void @func2()
+ unreachable
+bb3:
+ call void @func3()
+ unreachable
+}
+
+define void @test_ult_rhsc_invalid_cond(i8 %x, i8 %y) {
+; CHECK-LABEL: define void @test_ult_rhsc_invalid_cond(
+; CHECK-SAME: i8 [[X:%.*]], i8 [[Y:%.*]]) {
+; CHECK-NEXT: [[VAL:%.*]] = add nsw i8 [[X]], -2
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[Y]], 11
+; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 [[VAL]], i8 6
+; CHECK-NEXT: switch i8 [[COND]], label [[BB1:%.*]] [
+; CHECK-NEXT: i8 0, label [[BB2:%.*]]
+; CHECK-NEXT: i8 10, label [[BB3:%.*]]
+; CHECK-NEXT: i8 13, label [[BB3]]
+; CHECK-NEXT: ]
+; CHECK: bb1:
+; CHECK-NEXT: call void @func1()
+; CHECK-NEXT: unreachable
+; CHECK: bb2:
+; CHECK-NEXT: call void @func2()
+; CHECK-NEXT: unreachable
+; CHECK: bb3:
+; CHECK-NEXT: call void @func3()
+; CHECK-NEXT: unreachable
+;
+ %val = add nsw i8 %x, -2
+ %cmp = icmp ult i8 %y, 11
+ %cond = select i1 %cmp, i8 %val, i8 6
+ switch i8 %cond, label %bb1 [
+ i8 0, label %bb2
+ i8 10, label %bb3
+ i8 13, label %bb3
+ ]
+
+bb1:
+ call void @func1()
+ unreachable
+bb2:
+ call void @func2()
+ unreachable
+bb3:
+ call void @func3()
+ unreachable
+}
+
+define void @test_ult_rhsc_fail(i8 %x) {
+; CHECK-LABEL: define void @test_ult_rhsc_fail(
+; CHECK-SAME: i8 [[X:%.*]]) {
+; CHECK-NEXT: [[VAL:%.*]] = add nsw i8 [[X]], -2
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[VAL]], 11
+; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 [[VAL]], i8 6
+; CHECK-NEXT: switch i8 [[COND]], label [[BB1:%.*]] [
+; CHECK-NEXT: i8 0, label [[BB2:%.*]]
+; CHECK-NEXT: i8 10, label [[BB3:%.*]]
+; CHECK-NEXT: i8 13, label [[BB3]]
+; CHECK-NEXT: ]
+; CHECK: bb1:
+; CHECK-NEXT: call void @func1()
+; CHECK-NEXT: unreachable
+; CHECK: bb2:
+; CHECK-NEXT: call void @func2()
+; CHECK-NEXT: unreachable
+; CHECK: bb3:
+; CHECK-NEXT: call void @func3()
+; CHECK-NEXT: unreachable
+;
+ %val = add nsw i8 %x, -2
+ %cmp = icmp ult i8 %val, 11
+ %cond = select i1 %cmp, i8 %val, i8 6
+ switch i8 %cond, label %bb1 [
+ i8 0, label %bb2
+ i8 10, label %bb3
+ i8 13, label %bb3
+ ]
+
+bb1:
+ call void @func1()
+ unreachable
+bb2:
+ call void @func2()
+ unreachable
+bb3:
+ call void @func3()
+ unreachable
+}
+
+declare void @func1()
+declare void @func2()
+declare void @func3()
From f782c5e659ecda4666be0f3af10dce099e28c2d2 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Wed, 6 Mar 2024 16:44:45 +0800
Subject: [PATCH 2/2] [InstCombine] Simplify switch with selects
---
.../InstCombine/InstructionCombining.cpp | 41 +++++++++++++++++++
.../Transforms/InstCombine/switch-select.ll | 13 ++----
2 files changed, 45 insertions(+), 9 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 80ce0c9275b2cb..3d231f9fc7ff4b 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -3334,6 +3334,37 @@ Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) {
return nullptr;
}
+// Folds switch (select cond, X, C) or switch (select cond, C, X) into switch X
+// when C branches to the default dest and every case value lies in the range
+// of X for which the select yields X, so any other X also hits the default.
+static Value *simplifySwitchOnSelectUsingRanges(SwitchInst &SI,
+ SelectInst *Select,
+ unsigned CstOpIdx) {
+ auto *C = dyn_cast<ConstantInt>(Select->getOperand(CstOpIdx));
+ if (!C)
+ return nullptr;
+
+ BasicBlock *CstBB = SI.findCaseValue(C)->getCaseSuccessor();
+ if (CstBB != SI.getDefaultDest())
+ return nullptr;
+ Value *X = Select->getOperand(3 - CstOpIdx); // The other select arm.
+ ICmpInst::Predicate Pred;
+ const APInt *RHSC;
+ if (!match(Select->getCondition(),
+ m_ICmp(Pred, m_Specific(X), m_APInt(RHSC))))
+ return nullptr;
+ if (CstOpIdx == 1)
+ Pred = ICmpInst::getInversePredicate(Pred); // X is the false arm.
+
+ // Every case value must lie in the range where the select picks X.
+ ConstantRange CR = ConstantRange::makeExactICmpRegion(Pred, *RHSC);
+ for (auto Case : SI.cases())
+ if (!CR.contains(Case.getCaseValue()->getValue()))
+ return nullptr;
+
+ return X;
+}
+
Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
Value *Cond = SI.getCondition();
Value *Op0;
@@ -3407,6 +3438,16 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
}
}
+ // Fold switch(select cond, X, Y) into switch(X/Y) if possible
+ if (auto *Select = dyn_cast<SelectInst>(Cond)) {
+ if (Value *V =
+ simplifySwitchOnSelectUsingRanges(SI, Select, /*CstOpIdx=*/1))
+ return replaceOperand(SI, 0, V);
+ if (Value *V =
+ simplifySwitchOnSelectUsingRanges(SI, Select, /*CstOpIdx=*/2))
+ return replaceOperand(SI, 0, V);
+ }
+
KnownBits Known = computeKnownBits(Cond, 0, &SI);
unsigned LeadingKnownZeros = Known.countMinLeadingZeros();
unsigned LeadingKnownOnes = Known.countMinLeadingOnes();
diff --git a/llvm/test/Transforms/InstCombine/switch-select.ll b/llvm/test/Transforms/InstCombine/switch-select.ll
index 86375ab6520454..60757c5d22527f 100644
--- a/llvm/test/Transforms/InstCombine/switch-select.ll
+++ b/llvm/test/Transforms/InstCombine/switch-select.ll
@@ -4,12 +4,9 @@
define void @test_ult_rhsc(i8 %x) {
; CHECK-LABEL: define void @test_ult_rhsc(
; CHECK-SAME: i8 [[X:%.*]]) {
-; CHECK-NEXT: [[VAL:%.*]] = add nsw i8 [[X]], -2
-; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[VAL]], 11
-; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 [[VAL]], i8 6
-; CHECK-NEXT: switch i8 [[COND]], label [[BB1:%.*]] [
-; CHECK-NEXT: i8 0, label [[BB2:%.*]]
-; CHECK-NEXT: i8 10, label [[BB3:%.*]]
+; CHECK-NEXT: switch i8 [[X]], label [[BB1:%.*]] [
+; CHECK-NEXT: i8 2, label [[BB2:%.*]]
+; CHECK-NEXT: i8 12, label [[BB3:%.*]]
; CHECK-NEXT: ]
; CHECK: bb1:
; CHECK-NEXT: call void @func1()
@@ -43,9 +40,7 @@ bb3:
define void @test_eq_lhsc(i8 %x) {
; CHECK-LABEL: define void @test_eq_lhsc(
; CHECK-SAME: i8 [[X:%.*]]) {
-; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[X]], 4
-; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 6, i8 [[X]]
-; CHECK-NEXT: switch i8 [[COND]], label [[BB1:%.*]] [
+; CHECK-NEXT: switch i8 [[X]], label [[BB1:%.*]] [
; CHECK-NEXT: i8 0, label [[BB2:%.*]]
; CHECK-NEXT: i8 10, label [[BB3:%.*]]
; CHECK-NEXT: ]
More information about the llvm-commits
mailing list