[llvm] [InstCombine] Simplify switch with selects (PR #84143)
Yingwei Zheng via llvm-commits
llvm-commits at lists.llvm.org
Fri Mar 22 00:14:30 PDT 2024
https://github.com/dtcxzyw updated https://github.com/llvm/llvm-project/pull/84143
>From 0c4c262a1e48b6c023ed6b1eb19bf3a54c22651a Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Wed, 6 Mar 2024 16:43:52 +0800
Subject: [PATCH 1/3] [InstCombine] Add pre-commit tests. NFC.
---
.../Transforms/InstCombine/switch-select.ll | 164 ++++++++++++++++++
1 file changed, 164 insertions(+)
create mode 100644 llvm/test/Transforms/InstCombine/switch-select.ll
diff --git a/llvm/test/Transforms/InstCombine/switch-select.ll b/llvm/test/Transforms/InstCombine/switch-select.ll
new file mode 100644
index 00000000000000..86375ab6520454
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/switch-select.ll
@@ -0,0 +1,164 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define void @test_ult_rhsc(i8 %x) {
+; CHECK-LABEL: define void @test_ult_rhsc(
+; CHECK-SAME: i8 [[X:%.*]]) {
+; CHECK-NEXT: [[VAL:%.*]] = add nsw i8 [[X]], -2
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[VAL]], 11
+; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 [[VAL]], i8 6
+; CHECK-NEXT: switch i8 [[COND]], label [[BB1:%.*]] [
+; CHECK-NEXT: i8 0, label [[BB2:%.*]]
+; CHECK-NEXT: i8 10, label [[BB3:%.*]]
+; CHECK-NEXT: ]
+; CHECK: bb1:
+; CHECK-NEXT: call void @func1()
+; CHECK-NEXT: unreachable
+; CHECK: bb2:
+; CHECK-NEXT: call void @func2()
+; CHECK-NEXT: unreachable
+; CHECK: bb3:
+; CHECK-NEXT: call void @func3()
+; CHECK-NEXT: unreachable
+;
+ %val = add nsw i8 %x, -2
+ %cmp = icmp ult i8 %val, 11
+ %cond = select i1 %cmp, i8 %val, i8 6
+ switch i8 %cond, label %bb1 [
+ i8 0, label %bb2
+ i8 10, label %bb3
+ ]
+
+bb1:
+ call void @func1()
+ unreachable
+bb2:
+ call void @func2()
+ unreachable
+bb3:
+ call void @func3()
+ unreachable
+}
+
+define void @test_eq_lhsc(i8 %x) {
+; CHECK-LABEL: define void @test_eq_lhsc(
+; CHECK-SAME: i8 [[X:%.*]]) {
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[X]], 4
+; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 6, i8 [[X]]
+; CHECK-NEXT: switch i8 [[COND]], label [[BB1:%.*]] [
+; CHECK-NEXT: i8 0, label [[BB2:%.*]]
+; CHECK-NEXT: i8 10, label [[BB3:%.*]]
+; CHECK-NEXT: ]
+; CHECK: bb1:
+; CHECK-NEXT: call void @func1()
+; CHECK-NEXT: unreachable
+; CHECK: bb2:
+; CHECK-NEXT: call void @func2()
+; CHECK-NEXT: unreachable
+; CHECK: bb3:
+; CHECK-NEXT: call void @func3()
+; CHECK-NEXT: unreachable
+;
+ %cmp = icmp eq i8 %x, 4
+ %cond = select i1 %cmp, i8 6, i8 %x
+ switch i8 %cond, label %bb1 [
+ i8 0, label %bb2
+ i8 10, label %bb3
+ ]
+
+bb1:
+ call void @func1()
+ unreachable
+bb2:
+ call void @func2()
+ unreachable
+bb3:
+ call void @func3()
+ unreachable
+}
+
+define void @test_ult_rhsc_invalid_cond(i8 %x, i8 %y) {
+; CHECK-LABEL: define void @test_ult_rhsc_invalid_cond(
+; CHECK-SAME: i8 [[X:%.*]], i8 [[Y:%.*]]) {
+; CHECK-NEXT: [[VAL:%.*]] = add nsw i8 [[X]], -2
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[Y]], 11
+; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 [[VAL]], i8 6
+; CHECK-NEXT: switch i8 [[COND]], label [[BB1:%.*]] [
+; CHECK-NEXT: i8 0, label [[BB2:%.*]]
+; CHECK-NEXT: i8 10, label [[BB3:%.*]]
+; CHECK-NEXT: i8 13, label [[BB3]]
+; CHECK-NEXT: ]
+; CHECK: bb1:
+; CHECK-NEXT: call void @func1()
+; CHECK-NEXT: unreachable
+; CHECK: bb2:
+; CHECK-NEXT: call void @func2()
+; CHECK-NEXT: unreachable
+; CHECK: bb3:
+; CHECK-NEXT: call void @func3()
+; CHECK-NEXT: unreachable
+;
+ %val = add nsw i8 %x, -2
+ %cmp = icmp ult i8 %y, 11
+ %cond = select i1 %cmp, i8 %val, i8 6
+ switch i8 %cond, label %bb1 [
+ i8 0, label %bb2
+ i8 10, label %bb3
+ i8 13, label %bb3
+ ]
+
+bb1:
+ call void @func1()
+ unreachable
+bb2:
+ call void @func2()
+ unreachable
+bb3:
+ call void @func3()
+ unreachable
+}
+
+define void @test_ult_rhsc_fail(i8 %x) {
+; CHECK-LABEL: define void @test_ult_rhsc_fail(
+; CHECK-SAME: i8 [[X:%.*]]) {
+; CHECK-NEXT: [[VAL:%.*]] = add nsw i8 [[X]], -2
+; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[VAL]], 11
+; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 [[VAL]], i8 6
+; CHECK-NEXT: switch i8 [[COND]], label [[BB1:%.*]] [
+; CHECK-NEXT: i8 0, label [[BB2:%.*]]
+; CHECK-NEXT: i8 10, label [[BB3:%.*]]
+; CHECK-NEXT: i8 13, label [[BB3]]
+; CHECK-NEXT: ]
+; CHECK: bb1:
+; CHECK-NEXT: call void @func1()
+; CHECK-NEXT: unreachable
+; CHECK: bb2:
+; CHECK-NEXT: call void @func2()
+; CHECK-NEXT: unreachable
+; CHECK: bb3:
+; CHECK-NEXT: call void @func3()
+; CHECK-NEXT: unreachable
+;
+ %val = add nsw i8 %x, -2
+ %cmp = icmp ult i8 %val, 11
+ %cond = select i1 %cmp, i8 %val, i8 6
+ switch i8 %cond, label %bb1 [
+ i8 0, label %bb2
+ i8 10, label %bb3
+ i8 13, label %bb3
+ ]
+
+bb1:
+ call void @func1()
+ unreachable
+bb2:
+ call void @func2()
+ unreachable
+bb3:
+ call void @func3()
+ unreachable
+}
+
+declare void @func1()
+declare void @func2()
+declare void @func3()
>From 28ed09c4857def525f9a97bfcd86d1d07b318843 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Wed, 6 Mar 2024 16:44:45 +0800
Subject: [PATCH 2/3] [InstCombine] Simplify switch with selects
---
.../InstCombine/InstructionCombining.cpp | 41 +++++++++++++++++++
.../Transforms/InstCombine/switch-select.ll | 13 ++----
2 files changed, 45 insertions(+), 9 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 7c40fb4fc86082..bcc1a2ffd677f5 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -3572,6 +3572,37 @@ Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) {
return nullptr;
}
+// Replaces (switch (select cond, X, C)/(select cond, C, X)) with (switch X) if
+// we can prove that both (switch C) and (switch X) go to the default when cond
+// is false/true.
+static Value *simplifySwitchOnSelectUsingRanges(SwitchInst &SI,
+ SelectInst *Select,
+ unsigned CstOpIdx) {
+ auto *C = dyn_cast<ConstantInt>(Select->getOperand(CstOpIdx));
+ if (!C)
+ return nullptr;
+
+ BasicBlock *CstBB = SI.findCaseValue(C)->getCaseSuccessor();
+ if (CstBB != SI.getDefaultDest())
+ return nullptr;
+ Value *X = Select->getOperand(3 - CstOpIdx);
+ ICmpInst::Predicate Pred;
+ const APInt *RHSC;
+ if (!match(Select->getCondition(),
+ m_ICmp(Pred, m_Specific(X), m_APInt(RHSC))))
+ return nullptr;
+ if (CstOpIdx == 1)
+ Pred = ICmpInst::getInversePredicate(Pred);
+
+ // See whether we can replace the select with X
+ ConstantRange CR = ConstantRange::makeExactICmpRegion(Pred, *RHSC);
+ for (auto Case : SI.cases())
+ if (!CR.contains(Case.getCaseValue()->getValue()))
+ return nullptr;
+
+ return X;
+}
+
Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
Value *Cond = SI.getCondition();
Value *Op0;
@@ -3645,6 +3676,16 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
}
}
+ // Fold switch(select cond, X, Y) into switch(X/Y) if possible
+ if (auto *Select = dyn_cast<SelectInst>(Cond)) {
+ if (Value *V =
+ simplifySwitchOnSelectUsingRanges(SI, Select, /*CstOpIdx=*/1))
+ return replaceOperand(SI, 0, V);
+ if (Value *V =
+ simplifySwitchOnSelectUsingRanges(SI, Select, /*CstOpIdx=*/2))
+ return replaceOperand(SI, 0, V);
+ }
+
KnownBits Known = computeKnownBits(Cond, 0, &SI);
unsigned LeadingKnownZeros = Known.countMinLeadingZeros();
unsigned LeadingKnownOnes = Known.countMinLeadingOnes();
diff --git a/llvm/test/Transforms/InstCombine/switch-select.ll b/llvm/test/Transforms/InstCombine/switch-select.ll
index 86375ab6520454..60757c5d22527f 100644
--- a/llvm/test/Transforms/InstCombine/switch-select.ll
+++ b/llvm/test/Transforms/InstCombine/switch-select.ll
@@ -4,12 +4,9 @@
define void @test_ult_rhsc(i8 %x) {
; CHECK-LABEL: define void @test_ult_rhsc(
; CHECK-SAME: i8 [[X:%.*]]) {
-; CHECK-NEXT: [[VAL:%.*]] = add nsw i8 [[X]], -2
-; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[VAL]], 11
-; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 [[VAL]], i8 6
-; CHECK-NEXT: switch i8 [[COND]], label [[BB1:%.*]] [
-; CHECK-NEXT: i8 0, label [[BB2:%.*]]
-; CHECK-NEXT: i8 10, label [[BB3:%.*]]
+; CHECK-NEXT: switch i8 [[X]], label [[BB1:%.*]] [
+; CHECK-NEXT: i8 2, label [[BB2:%.*]]
+; CHECK-NEXT: i8 12, label [[BB3:%.*]]
; CHECK-NEXT: ]
; CHECK: bb1:
; CHECK-NEXT: call void @func1()
@@ -43,9 +40,7 @@ bb3:
define void @test_eq_lhsc(i8 %x) {
; CHECK-LABEL: define void @test_eq_lhsc(
; CHECK-SAME: i8 [[X:%.*]]) {
-; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[X]], 4
-; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i8 6, i8 [[X]]
-; CHECK-NEXT: switch i8 [[COND]], label [[BB1:%.*]] [
+; CHECK-NEXT: switch i8 [[X]], label [[BB1:%.*]] [
; CHECK-NEXT: i8 0, label [[BB2:%.*]]
; CHECK-NEXT: i8 10, label [[BB3:%.*]]
; CHECK-NEXT: ]
>From 1828c728260d6907354a39360b58302f83e197fe Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Fri, 22 Mar 2024 14:52:13 +0800
Subject: [PATCH 3/3] [InstCombine] Address comments.
---
llvm/lib/Transforms/InstCombine/InstructionCombining.cpp | 9 +++++----
1 file changed, 5 insertions(+), 4 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index bcc1a2ffd677f5..7432499373c906 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -3577,7 +3577,8 @@ Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) {
// is false/true.
static Value *simplifySwitchOnSelectUsingRanges(SwitchInst &SI,
SelectInst *Select,
- unsigned CstOpIdx) {
+ bool IsTrueArm) {
+ unsigned CstOpIdx = IsTrueArm ? 1 : 2;
auto *C = dyn_cast<ConstantInt>(Select->getOperand(CstOpIdx));
if (!C)
return nullptr;
@@ -3591,7 +3592,7 @@ static Value *simplifySwitchOnSelectUsingRanges(SwitchInst &SI,
if (!match(Select->getCondition(),
m_ICmp(Pred, m_Specific(X), m_APInt(RHSC))))
return nullptr;
- if (CstOpIdx == 1)
+ if (IsTrueArm)
Pred = ICmpInst::getInversePredicate(Pred);
// See whether we can replace the select with X
@@ -3679,10 +3680,10 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
// Fold switch(select cond, X, Y) into switch(X/Y) if possible
if (auto *Select = dyn_cast<SelectInst>(Cond)) {
if (Value *V =
- simplifySwitchOnSelectUsingRanges(SI, Select, /*CstOpIdx=*/1))
+ simplifySwitchOnSelectUsingRanges(SI, Select, /*IsTrueArm=*/true))
return replaceOperand(SI, 0, V);
if (Value *V =
- simplifySwitchOnSelectUsingRanges(SI, Select, /*CstOpIdx=*/2))
+ simplifySwitchOnSelectUsingRanges(SI, Select, /*IsTrueArm=*/false))
return replaceOperand(SI, 0, V);
}
More information about the llvm-commits mailing list