[llvm] [SimplifyCFG] Fold switch over ucmp/scmp to icmp and br (PR #105636)
Nikita Popov via llvm-commits
llvm-commits at lists.llvm.org
Thu Aug 22 07:05:21 PDT 2024
https://github.com/nikic updated https://github.com/llvm/llvm-project/pull/105636
>From 1620a6ee2ac01e9fe989faf645c13fa1a2b606d0 Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Thu, 22 Aug 2024 11:49:46 +0200
Subject: [PATCH 1/4] [SimplifyCFG] Fold switch over ucmp/scmp to icmp and br
If we switch over ucmp/scmp and have two switch cases going to
the same destination, we can convert the switch into icmp+br.
Fixes https://github.com/llvm/llvm-project/issues/105632.
---
llvm/lib/Transforms/Utils/SimplifyCFG.cpp | 56 ++++++++++++++++
.../Transforms/SimplifyCFG/switch-on-cmp.ll | 65 ++++++-------------
2 files changed, 75 insertions(+), 46 deletions(-)
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 00efd3c0eb72ec..d7295d89078b6c 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -7131,6 +7131,59 @@ static bool simplifySwitchOfPowersOfTwo(SwitchInst *SI, IRBuilder<> &Builder,
return true;
}
+/// Fold switch over ucmp/scmp intrinsic to br if two of the switch arms have
+/// the same destination.
+static bool simplifySwitchOfCmpIntrinsic(SwitchInst *SI,
+ IRBuilderBase &Builder) {
+ auto *Cmp = dyn_cast<CmpIntrinsic>(SI->getCondition());
+ if (!Cmp || !Cmp->hasOneUse() || SI->getNumCases() != 2)
+ return false;
+
+ // Find which of 1, 0 or -1 is missing.
+ SmallSet<int64_t, 3> Missing;
+ Missing.insert(1);
+ Missing.insert(0);
+ Missing.insert(-1);
+ BasicBlock *Succ = nullptr;
+ for (auto &Case : SI->cases()) {
+ std::optional<int64_t> Val = Case.getCaseValue()->getValue().trySExtValue();
+ if (!Val)
+ return false;
+ if (!Missing.erase(*Val))
+ return false;
+ if (Succ && Succ != Case.getCaseSuccessor())
+ return false;
+ Succ = Case.getCaseSuccessor();
+ }
+
+ // Determine predicate for the missing case.
+ ICmpInst::Predicate Pred;
+ assert(Missing.size() == 1 && "Should have one case left");
+ switch (*Missing.begin()) {
+ case 1:
+ Pred = ICmpInst::ICMP_UGT;
+ break;
+ case 0:
+ Pred = ICmpInst::ICMP_EQ;
+ break;
+ case -1:
+ Pred = ICmpInst::ICMP_ULT;
+ break;
+ }
+ if (Cmp->isSigned())
+ Pred = ICmpInst::getSignedPredicate(Pred);
+
+ // The dominator tree does not change, because it treats multi-edges like
+ // a single edge anyway.
+ Builder.SetInsertPoint(SI->getIterator());
+ Value *ICmp = Builder.CreateICmp(Pred, Cmp->getLHS(), Cmp->getRHS());
+ Builder.CreateCondBr(ICmp, SI->getDefaultDest(), Succ);
+ Succ->removePredecessor(SI->getParent());
+ SI->eraseFromParent();
+ Cmp->eraseFromParent();
+ return true;
+}
+
bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
BasicBlock *BB = SI->getParent();
@@ -7163,6 +7216,9 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
if (eliminateDeadSwitchCases(SI, DTU, Options.AC, DL))
return requestResimplify();
+ if (simplifySwitchOfCmpIntrinsic(SI, Builder))
+ return requestResimplify();
+
if (trySwitchToSelect(SI, Builder, DTU, DL, TTI))
return requestResimplify();
diff --git a/llvm/test/Transforms/SimplifyCFG/switch-on-cmp.ll b/llvm/test/Transforms/SimplifyCFG/switch-on-cmp.ll
index 1ce18533d156d0..a9b7a0b9e2f096 100644
--- a/llvm/test/Transforms/SimplifyCFG/switch-on-cmp.ll
+++ b/llvm/test/Transforms/SimplifyCFG/switch-on-cmp.ll
@@ -4,11 +4,8 @@
define void @ucmp_gt1(i32 %a, i32 %b) {
; CHECK-LABEL: define void @ucmp_gt1(
; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
-; CHECK-NEXT: [[RES:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A]], i32 [[B]])
-; CHECK-NEXT: switch i8 [[RES]], label %[[BB1:.*]] [
-; CHECK-NEXT: i8 -1, label %[[BB2:.*]]
-; CHECK-NEXT: i8 0, label %[[BB2]]
-; CHECK-NEXT: ]
+; CHECK-NEXT: [[TMP1:%.*]] = icmp ugt i32 [[A]], [[B]]
+; CHECK-NEXT: br i1 [[TMP1]], label %[[BB1:.*]], label %[[BB2:.*]]
; CHECK: [[BB1]]:
; CHECK-NEXT: call void @foo()
; CHECK-NEXT: br label %[[BB2]]
@@ -32,11 +29,8 @@ bb2:
define void @ucmp_gt2(i32 %a, i32 %b) {
; CHECK-LABEL: define void @ucmp_gt2(
; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
-; CHECK-NEXT: [[RES:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A]], i32 [[B]])
-; CHECK-NEXT: switch i8 [[RES]], label %[[BB1:.*]] [
-; CHECK-NEXT: i8 0, label %[[BB2:.*]]
-; CHECK-NEXT: i8 -1, label %[[BB2]]
-; CHECK-NEXT: ]
+; CHECK-NEXT: [[TMP1:%.*]] = icmp ugt i32 [[A]], [[B]]
+; CHECK-NEXT: br i1 [[TMP1]], label %[[BB1:.*]], label %[[BB2:.*]]
; CHECK: [[BB1]]:
; CHECK-NEXT: call void @foo()
; CHECK-NEXT: br label %[[BB2]]
@@ -60,11 +54,8 @@ bb2:
define void @ucmp_lt1(i32 %a, i32 %b) {
; CHECK-LABEL: define void @ucmp_lt1(
; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
-; CHECK-NEXT: [[RES:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A]], i32 [[B]])
-; CHECK-NEXT: switch i8 [[RES]], label %[[BB2:.*]] [
-; CHECK-NEXT: i8 1, label %[[BB1:.*]]
-; CHECK-NEXT: i8 0, label %[[BB1]]
-; CHECK-NEXT: ]
+; CHECK-NEXT: [[TMP1:%.*]] = icmp ult i32 [[A]], [[B]]
+; CHECK-NEXT: br i1 [[TMP1]], label %[[BB2:.*]], label %[[BB1:.*]]
; CHECK: [[BB1]]:
; CHECK-NEXT: call void @foo()
; CHECK-NEXT: br label %[[BB2]]
@@ -88,11 +79,8 @@ bb2:
define void @ucmp_lt2(i32 %a, i32 %b) {
; CHECK-LABEL: define void @ucmp_lt2(
; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
-; CHECK-NEXT: [[RES:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A]], i32 [[B]])
-; CHECK-NEXT: switch i8 [[RES]], label %[[BB2:.*]] [
-; CHECK-NEXT: i8 0, label %[[BB1:.*]]
-; CHECK-NEXT: i8 1, label %[[BB1]]
-; CHECK-NEXT: ]
+; CHECK-NEXT: [[TMP1:%.*]] = icmp ult i32 [[A]], [[B]]
+; CHECK-NEXT: br i1 [[TMP1]], label %[[BB2:.*]], label %[[BB1:.*]]
; CHECK: [[BB1]]:
; CHECK-NEXT: call void @foo()
; CHECK-NEXT: br label %[[BB2]]
@@ -116,11 +104,8 @@ bb2:
define void @ucmp_eq1(i32 %a, i32 %b) {
; CHECK-LABEL: define void @ucmp_eq1(
; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
-; CHECK-NEXT: [[RES:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A]], i32 [[B]])
-; CHECK-NEXT: switch i8 [[RES]], label %[[BB1:.*]] [
-; CHECK-NEXT: i8 -1, label %[[BB2:.*]]
-; CHECK-NEXT: i8 1, label %[[BB2]]
-; CHECK-NEXT: ]
+; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i32 [[A]], [[B]]
+; CHECK-NEXT: br i1 [[TMP1]], label %[[BB1:.*]], label %[[BB2:.*]]
; CHECK: [[BB1]]:
; CHECK-NEXT: call void @foo()
; CHECK-NEXT: br label %[[BB2]]
@@ -144,11 +129,8 @@ bb2:
define void @ucmp_eq2(i32 %a, i32 %b) {
; CHECK-LABEL: define void @ucmp_eq2(
; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
-; CHECK-NEXT: [[RES:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A]], i32 [[B]])
-; CHECK-NEXT: switch i8 [[RES]], label %[[BB1:.*]] [
-; CHECK-NEXT: i8 1, label %[[BB2:.*]]
-; CHECK-NEXT: i8 -1, label %[[BB2]]
-; CHECK-NEXT: ]
+; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i32 [[A]], [[B]]
+; CHECK-NEXT: br i1 [[TMP1]], label %[[BB1:.*]], label %[[BB2:.*]]
; CHECK: [[BB1]]:
; CHECK-NEXT: call void @foo()
; CHECK-NEXT: br label %[[BB2]]
@@ -172,11 +154,8 @@ bb2:
define void @scmp_gt1(i32 %a, i32 %b) {
; CHECK-LABEL: define void @scmp_gt1(
; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
-; CHECK-NEXT: [[RES:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[A]], i32 [[B]])
-; CHECK-NEXT: switch i8 [[RES]], label %[[BB1:.*]] [
-; CHECK-NEXT: i8 -1, label %[[BB2:.*]]
-; CHECK-NEXT: i8 0, label %[[BB2]]
-; CHECK-NEXT: ]
+; CHECK-NEXT: [[TMP1:%.*]] = icmp sgt i32 [[A]], [[B]]
+; CHECK-NEXT: br i1 [[TMP1]], label %[[BB1:.*]], label %[[BB2:.*]]
; CHECK: [[BB1]]:
; CHECK-NEXT: call void @foo()
; CHECK-NEXT: br label %[[BB2]]
@@ -200,11 +179,8 @@ bb2:
define void @scmp_gt2(i32 %a, i32 %b) {
; CHECK-LABEL: define void @scmp_gt2(
; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
-; CHECK-NEXT: [[RES:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[A]], i32 [[B]])
-; CHECK-NEXT: switch i8 [[RES]], label %[[BB1:.*]] [
-; CHECK-NEXT: i8 0, label %[[BB2:.*]]
-; CHECK-NEXT: i8 -1, label %[[BB2]]
-; CHECK-NEXT: ]
+; CHECK-NEXT: [[TMP1:%.*]] = icmp sgt i32 [[A]], [[B]]
+; CHECK-NEXT: br i1 [[TMP1]], label %[[BB1:.*]], label %[[BB2:.*]]
; CHECK: [[BB1]]:
; CHECK-NEXT: call void @foo()
; CHECK-NEXT: br label %[[BB2]]
@@ -259,16 +235,13 @@ define i32 @ucmp_gt_phi(i32 %a, i32 %b) {
; CHECK-LABEL: define i32 @ucmp_gt_phi(
; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*]]:
-; CHECK-NEXT: [[RES:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A]], i32 [[B]])
-; CHECK-NEXT: switch i8 [[RES]], label %[[BB1:.*]] [
-; CHECK-NEXT: i8 -1, label %[[BB2:.*]]
-; CHECK-NEXT: i8 0, label %[[BB2]]
-; CHECK-NEXT: ]
+; CHECK-NEXT: [[TMP0:%.*]] = icmp ugt i32 [[A]], [[B]]
+; CHECK-NEXT: br i1 [[TMP0]], label %[[BB1:.*]], label %[[BB2:.*]]
; CHECK: [[BB1]]:
; CHECK-NEXT: call void @foo()
; CHECK-NEXT: br label %[[BB2]]
; CHECK: [[BB2]]:
-; CHECK-NEXT: [[PHI:%.*]] = phi i32 [ 0, %[[BB1]] ], [ 1, %[[ENTRY]] ], [ 1, %[[ENTRY]] ]
+; CHECK-NEXT: [[PHI:%.*]] = phi i32 [ 0, %[[BB1]] ], [ 1, %[[ENTRY]] ]
; CHECK-NEXT: ret i32 [[PHI]]
;
entry:
>From 4d89ec6f7634a33a2ea59699e7d59f0a14ddc37a Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Thu, 22 Aug 2024 12:58:31 +0200
Subject: [PATCH 2/4] Support switch with unreachable default
---
llvm/lib/Transforms/Utils/SimplifyCFG.cpp | 90 ++++--
.../Transforms/SimplifyCFG/switch-on-cmp.ll | 268 ++++++++++++++++++
2 files changed, 333 insertions(+), 25 deletions(-)
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index d7295d89078b6c..98a07e79640214 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -7133,33 +7133,70 @@ static bool simplifySwitchOfPowersOfTwo(SwitchInst *SI, IRBuilder<> &Builder,
/// Fold switch over ucmp/scmp intrinsic to br if two of the switch arms have
/// the same destination.
-static bool simplifySwitchOfCmpIntrinsic(SwitchInst *SI,
- IRBuilderBase &Builder) {
+static bool simplifySwitchOfCmpIntrinsic(SwitchInst *SI, IRBuilderBase &Builder,
+ DomTreeUpdater *DTU) {
auto *Cmp = dyn_cast<CmpIntrinsic>(SI->getCondition());
- if (!Cmp || !Cmp->hasOneUse() || SI->getNumCases() != 2)
+ if (!Cmp || !Cmp->hasOneUse())
return false;
- // Find which of 1, 0 or -1 is missing.
- SmallSet<int64_t, 3> Missing;
- Missing.insert(1);
- Missing.insert(0);
- Missing.insert(-1);
- BasicBlock *Succ = nullptr;
- for (auto &Case : SI->cases()) {
- std::optional<int64_t> Val = Case.getCaseValue()->getValue().trySExtValue();
- if (!Val)
- return false;
- if (!Missing.erase(*Val))
- return false;
- if (Succ && Succ != Case.getCaseSuccessor())
- return false;
- Succ = Case.getCaseSuccessor();
+ // Normalize to [us]cmp == Res ? Succ : OtherSucc.
+ int64_t Res;
+ BasicBlock *Succ, *OtherSucc;
+ BasicBlock *Unreachable = nullptr;
+
+ if (SI->getNumCases() == 2) {
+ // Find which of 1, 0 or -1 is missing (handled by default dest).
+ SmallSet<int64_t, 3> Missing;
+ Missing.insert(1);
+ Missing.insert(0);
+ Missing.insert(-1);
+
+ Succ = SI->getDefaultDest();
+ OtherSucc = nullptr;
+ for (auto &Case : SI->cases()) {
+ std::optional<int64_t> Val =
+ Case.getCaseValue()->getValue().trySExtValue();
+ if (!Val)
+ return false;
+ if (!Missing.erase(*Val))
+ return false;
+ if (OtherSucc && OtherSucc != Case.getCaseSuccessor())
+ return false;
+ OtherSucc = Case.getCaseSuccessor();
+ }
+
+ assert(Missing.size() == 1 && "Should have one case left");
+ Res = *Missing.begin();
+ } else if (SI->getNumCases() == 3 && SI->defaultDestUndefined()) {
+ // Normalize so that Succ is taken once and OtherSucc twice.
+ Unreachable = SI->getDefaultDest();
+ Succ = OtherSucc = nullptr;
+ for (auto &Case : SI->cases()) {
+ BasicBlock *NewSucc = Case.getCaseSuccessor();
+ if (!OtherSucc || OtherSucc == NewSucc)
+ OtherSucc = NewSucc;
+ else if (!Succ)
+ Succ = NewSucc;
+ else if (Succ == NewSucc)
+ std::swap(Succ, OtherSucc);
+ else
+ return false;
+ }
+ for (auto &Case : SI->cases()) {
+ std::optional<int64_t> Val =
+ Case.getCaseValue()->getValue().trySExtValue();
+ if (!Val || (Val != 1 && Val != 0 && Val != -1))
+ return false;
+ if (Case.getCaseSuccessor() == Succ)
+ Res = *Val;
+ }
+ } else {
+ return false;
}
// Determine predicate for the missing case.
ICmpInst::Predicate Pred;
- assert(Missing.size() == 1 && "Should have one case left");
- switch (*Missing.begin()) {
+ switch (Res) {
case 1:
Pred = ICmpInst::ICMP_UGT;
break;
@@ -7173,14 +7210,17 @@ static bool simplifySwitchOfCmpIntrinsic(SwitchInst *SI,
if (Cmp->isSigned())
Pred = ICmpInst::getSignedPredicate(Pred);
- // The dominator tree does not change, because it treats multi-edges like
- // a single edge anyway.
+ BasicBlock *BB = SI->getParent();
Builder.SetInsertPoint(SI->getIterator());
Value *ICmp = Builder.CreateICmp(Pred, Cmp->getLHS(), Cmp->getRHS());
- Builder.CreateCondBr(ICmp, SI->getDefaultDest(), Succ);
- Succ->removePredecessor(SI->getParent());
+ Builder.CreateCondBr(ICmp, Succ, OtherSucc);
+ OtherSucc->removePredecessor(BB);
+ if (Unreachable)
+ Unreachable->removePredecessor(BB);
SI->eraseFromParent();
Cmp->eraseFromParent();
+ if (DTU && Unreachable)
+ DTU->applyUpdates({{DominatorTree::Delete, BB, Unreachable}});
return true;
}
@@ -7216,7 +7256,7 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
if (eliminateDeadSwitchCases(SI, DTU, Options.AC, DL))
return requestResimplify();
- if (simplifySwitchOfCmpIntrinsic(SI, Builder))
+ if (simplifySwitchOfCmpIntrinsic(SI, Builder, DTU))
return requestResimplify();
if (trySwitchToSelect(SI, Builder, DTU, DL, TTI))
diff --git a/llvm/test/Transforms/SimplifyCFG/switch-on-cmp.ll b/llvm/test/Transforms/SimplifyCFG/switch-on-cmp.ll
index a9b7a0b9e2f096..d43a1d0dd7e05a 100644
--- a/llvm/test/Transforms/SimplifyCFG/switch-on-cmp.ll
+++ b/llvm/test/Transforms/SimplifyCFG/switch-on-cmp.ll
@@ -353,5 +353,273 @@ bb2:
ret void
}
+define void @ucmp_gt_unreachable(i32 %a, i32 %b) {
+; CHECK-LABEL: define void @ucmp_gt_unreachable(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = icmp ugt i32 [[A]], [[B]]
+; CHECK-NEXT: br i1 [[TMP1]], label %[[BB1:.*]], label %[[BB2:.*]]
+; CHECK: [[BB1]]:
+; CHECK-NEXT: call void @foo()
+; CHECK-NEXT: br label %[[BB2]]
+; CHECK: [[BB2]]:
+; CHECK-NEXT: ret void
+;
+ %res = call i8 @llvm.ucmp.i8.i32(i32 %a, i32 %b)
+ switch i8 %res, label %unreachable [
+ i8 -1, label %bb2
+ i8 0, label %bb2
+ i8 1, label %bb1
+ ]
+
+bb1:
+ call void @foo()
+ br label %bb2
+
+bb2:
+ ret void
+
+unreachable:
+ unreachable
+}
+
+define void @ucmp_lt_unreachable(i32 %a, i32 %b) {
+; CHECK-LABEL: define void @ucmp_lt_unreachable(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = icmp ult i32 [[A]], [[B]]
+; CHECK-NEXT: br i1 [[TMP1]], label %[[BB1:.*]], label %[[BB2:.*]]
+; CHECK: [[BB1]]:
+; CHECK-NEXT: call void @foo()
+; CHECK-NEXT: br label %[[BB2]]
+; CHECK: [[BB2]]:
+; CHECK-NEXT: ret void
+;
+ %res = call i8 @llvm.ucmp.i8.i32(i32 %a, i32 %b)
+ switch i8 %res, label %unreachable [
+ i8 -1, label %bb1
+ i8 0, label %bb2
+ i8 1, label %bb2
+ ]
+
+bb1:
+ call void @foo()
+ br label %bb2
+
+bb2:
+ ret void
+
+unreachable:
+ unreachable
+}
+
+define void @ucmp_eq_unreachable(i32 %a, i32 %b) {
+; CHECK-LABEL: define void @ucmp_eq_unreachable(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i32 [[A]], [[B]]
+; CHECK-NEXT: br i1 [[TMP1]], label %[[BB1:.*]], label %[[BB2:.*]]
+; CHECK: [[BB1]]:
+; CHECK-NEXT: call void @foo()
+; CHECK-NEXT: br label %[[BB2]]
+; CHECK: [[BB2]]:
+; CHECK-NEXT: ret void
+;
+ %res = call i8 @llvm.ucmp.i8.i32(i32 %a, i32 %b)
+ switch i8 %res, label %unreachable [
+ i8 -1, label %bb2
+ i8 0, label %bb1
+ i8 1, label %bb2
+ ]
+
+bb1:
+ call void @foo()
+ br label %bb2
+
+bb2:
+ ret void
+
+unreachable:
+ unreachable
+}
+
+define void @ucmp_gt_unreachable_multi_edge(i8 %x, i32 %a, i32 %b) {
+; CHECK-LABEL: define void @ucmp_gt_unreachable_multi_edge(
+; CHECK-SAME: i8 [[X:%.*]], i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: switch i8 [[X]], label %[[UNREACHABLE:.*]] [
+; CHECK-NEXT: i8 0, label %[[SW:.*]]
+; CHECK-NEXT: i8 1, label %[[BB1:.*]]
+; CHECK-NEXT: ]
+; CHECK: [[SW]]:
+; CHECK-NEXT: [[TMP0:%.*]] = icmp ugt i32 [[A]], [[B]]
+; CHECK-NEXT: br i1 [[TMP0]], label %[[BB1]], label %[[BB2:.*]]
+; CHECK: [[BB1]]:
+; CHECK-NEXT: call void @foo()
+; CHECK-NEXT: br label %[[BB2]]
+; CHECK: [[BB2]]:
+; CHECK-NEXT: ret void
+; CHECK: [[UNREACHABLE]]:
+; CHECK-NEXT: unreachable
+;
+entry:
+ switch i8 %x, label %unreachable [
+ i8 0, label %sw
+ i8 1, label %bb1
+ ]
+
+sw:
+ %res = call i8 @llvm.ucmp.i8.i32(i32 %a, i32 %b)
+ switch i8 %res, label %unreachable [
+ i8 -1, label %bb2
+ i8 0, label %bb2
+ i8 1, label %bb1
+ ]
+
+bb1:
+ call void @foo()
+ br label %bb2
+
+bb2:
+ ret void
+
+unreachable:
+ %phi = phi i32 [ 0, %entry ], [ 1, %sw ]
+ unreachable
+}
+
+define void @ucmp_gt_unreachable_wrong_case(i32 %a, i32 %b) {
+; CHECK-LABEL: define void @ucmp_gt_unreachable_wrong_case(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT: [[RES:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A]], i32 [[B]])
+; CHECK-NEXT: switch i8 [[RES]], label %[[UNREACHABLE:.*]] [
+; CHECK-NEXT: i8 -2, label %[[BB2:.*]]
+; CHECK-NEXT: i8 0, label %[[BB2]]
+; CHECK-NEXT: i8 1, label %[[BB1:.*]]
+; CHECK-NEXT: ]
+; CHECK: [[BB1]]:
+; CHECK-NEXT: call void @foo()
+; CHECK-NEXT: br label %[[BB2]]
+; CHECK: [[BB2]]:
+; CHECK-NEXT: ret void
+; CHECK: [[UNREACHABLE]]:
+; CHECK-NEXT: unreachable
+;
+ %res = call i8 @llvm.ucmp.i8.i32(i32 %a, i32 %b)
+ switch i8 %res, label %unreachable [
+ i8 -2, label %bb2
+ i8 0, label %bb2
+ i8 1, label %bb1
+ ]
+
+bb1:
+ call void @foo()
+ br label %bb2
+
+bb2:
+ ret void
+
+unreachable:
+ unreachable
+}
+
+define void @ucmp_gt_unreachable_no_two_equal_cases(i32 %a, i32 %b) {
+; CHECK-LABEL: define void @ucmp_gt_unreachable_no_two_equal_cases(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT: [[RES:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A]], i32 [[B]])
+; CHECK-NEXT: switch i8 [[RES]], label %[[UNREACHABLE:.*]] [
+; CHECK-NEXT: i8 -1, label %[[BB3:.*]]
+; CHECK-NEXT: i8 0, label %[[BB2:.*]]
+; CHECK-NEXT: i8 1, label %[[BB1:.*]]
+; CHECK-NEXT: ]
+; CHECK: [[BB1]]:
+; CHECK-NEXT: call void @foo()
+; CHECK-NEXT: br label %[[BB2]]
+; CHECK: [[BB3]]:
+; CHECK-NEXT: call void @foo()
+; CHECK-NEXT: br label %[[BB2]]
+; CHECK: [[BB2]]:
+; CHECK-NEXT: ret void
+; CHECK: [[UNREACHABLE]]:
+; CHECK-NEXT: unreachable
+;
+ %res = call i8 @llvm.ucmp.i8.i32(i32 %a, i32 %b)
+ switch i8 %res, label %unreachable [
+ i8 -1, label %bb3
+ i8 0, label %bb2
+ i8 1, label %bb1
+ ]
+
+bb1:
+ call void @foo()
+ br label %bb2
+
+bb3:
+ call void @foo()
+ br label %bb2
+
+bb2:
+ ret void
+
+unreachable:
+ unreachable
+}
+
+define void @ucmp_gt_unreachable_three_equal_cases(i32 %a, i32 %b) {
+; CHECK-LABEL: define void @ucmp_gt_unreachable_three_equal_cases(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT: [[BB1:.*:]]
+; CHECK-NEXT: call void @foo()
+; CHECK-NEXT: ret void
+;
+ %res = call i8 @llvm.ucmp.i8.i32(i32 %a, i32 %b)
+ switch i8 %res, label %unreachable [
+ i8 -1, label %bb1
+ i8 0, label %bb1
+ i8 1, label %bb1
+ ]
+
+bb1:
+ call void @foo()
+ ret void
+
+unreachable:
+ unreachable
+}
+
+define void @ucmp_gt_unreachable_default_not_unreachable(i32 %a, i32 %b) {
+; CHECK-LABEL: define void @ucmp_gt_unreachable_default_not_unreachable(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT: [[RES:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A]], i32 [[B]])
+; CHECK-NEXT: switch i8 [[RES]], label %[[NOT_UNREACHABLE:.*]] [
+; CHECK-NEXT: i8 -1, label %[[BB2:.*]]
+; CHECK-NEXT: i8 0, label %[[BB2]]
+; CHECK-NEXT: i8 1, label %[[BB1:.*]]
+; CHECK-NEXT: ]
+; CHECK: [[BB1]]:
+; CHECK-NEXT: call void @foo()
+; CHECK-NEXT: br label %[[BB2]]
+; CHECK: [[BB2]]:
+; CHECK-NEXT: ret void
+; CHECK: [[NOT_UNREACHABLE]]:
+; CHECK-NEXT: call void @foo()
+; CHECK-NEXT: br label %[[BB2]]
+;
+ %res = call i8 @llvm.ucmp.i8.i32(i32 %a, i32 %b)
+ switch i8 %res, label %not.unreachable [
+ i8 -1, label %bb2
+ i8 0, label %bb2
+ i8 1, label %bb1
+ ]
+
+bb1:
+ call void @foo()
+ br label %bb2
+
+bb2:
+ ret void
+
+not.unreachable:
+ call void @foo()
+ br label %bb2
+}
+
declare void @use(i8)
declare void @foo()
>From 3c10169ee827b941ec074d3a7fc67f6ede47318b Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Thu, 22 Aug 2024 14:24:06 +0200
Subject: [PATCH 3/4] Preserve branch weights
---
llvm/lib/Transforms/Utils/SimplifyCFG.cpp | 28 +++++--
.../Transforms/SimplifyCFG/switch-on-cmp.ll | 83 +++++++++++++++++++
2 files changed, 106 insertions(+), 5 deletions(-)
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 98a07e79640214..714f5563447ec9 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -7139,9 +7139,15 @@ static bool simplifySwitchOfCmpIntrinsic(SwitchInst *SI, IRBuilderBase &Builder,
if (!Cmp || !Cmp->hasOneUse())
return false;
+ SmallVector<uint32_t, 4> Weights;
+ bool HasWeights = extractBranchWeights(getBranchWeightMDNode(*SI), Weights);
+ if (!HasWeights)
+ Weights.resize(4); // Avoid checking HasWeights everywhere.
+
// Normalize to [us]cmp == Res ? Succ : OtherSucc.
int64_t Res;
BasicBlock *Succ, *OtherSucc;
+ uint32_t SuccWeight = 0, OtherSuccWeight = 0;
BasicBlock *Unreachable = nullptr;
if (SI->getNumCases() == 2) {
@@ -7152,6 +7158,7 @@ static bool simplifySwitchOfCmpIntrinsic(SwitchInst *SI, IRBuilderBase &Builder,
Missing.insert(-1);
Succ = SI->getDefaultDest();
+ SuccWeight = Weights[0];
OtherSucc = nullptr;
for (auto &Case : SI->cases()) {
std::optional<int64_t> Val =
@@ -7163,6 +7170,7 @@ static bool simplifySwitchOfCmpIntrinsic(SwitchInst *SI, IRBuilderBase &Builder,
if (OtherSucc && OtherSucc != Case.getCaseSuccessor())
return false;
OtherSucc = Case.getCaseSuccessor();
+ OtherSuccWeight += Weights[Case.getSuccessorIndex()];
}
assert(Missing.size() == 1 && "Should have one case left");
@@ -7173,13 +7181,17 @@ static bool simplifySwitchOfCmpIntrinsic(SwitchInst *SI, IRBuilderBase &Builder,
Succ = OtherSucc = nullptr;
for (auto &Case : SI->cases()) {
BasicBlock *NewSucc = Case.getCaseSuccessor();
- if (!OtherSucc || OtherSucc == NewSucc)
+ uint32_t Weight = Weights[Case.getSuccessorIndex()];
+ if (!OtherSucc || OtherSucc == NewSucc) {
OtherSucc = NewSucc;
- else if (!Succ)
+ OtherSuccWeight += Weight;
+ } else if (!Succ) {
Succ = NewSucc;
- else if (Succ == NewSucc)
+ SuccWeight = Weight;
+ } else if (Succ == NewSucc) {
std::swap(Succ, OtherSucc);
- else
+ std::swap(SuccWeight, OtherSuccWeight);
+ } else
return false;
}
for (auto &Case : SI->cases()) {
@@ -7210,10 +7222,16 @@ static bool simplifySwitchOfCmpIntrinsic(SwitchInst *SI, IRBuilderBase &Builder,
if (Cmp->isSigned())
Pred = ICmpInst::getSignedPredicate(Pred);
+ MDNode *NewWeights = nullptr;
+ if (HasWeights)
+ NewWeights = MDBuilder(SI->getContext())
+ .createBranchWeights(SuccWeight, OtherSuccWeight);
+
BasicBlock *BB = SI->getParent();
Builder.SetInsertPoint(SI->getIterator());
Value *ICmp = Builder.CreateICmp(Pred, Cmp->getLHS(), Cmp->getRHS());
- Builder.CreateCondBr(ICmp, Succ, OtherSucc);
+ Builder.CreateCondBr(ICmp, Succ, OtherSucc, NewWeights,
+ SI->getMetadata(LLVMContext::MD_unpredictable));
OtherSucc->removePredecessor(BB);
if (Unreachable)
Unreachable->removePredecessor(BB);
diff --git a/llvm/test/Transforms/SimplifyCFG/switch-on-cmp.ll b/llvm/test/Transforms/SimplifyCFG/switch-on-cmp.ll
index d43a1d0dd7e05a..6230a319495dba 100644
--- a/llvm/test/Transforms/SimplifyCFG/switch-on-cmp.ll
+++ b/llvm/test/Transforms/SimplifyCFG/switch-on-cmp.ll
@@ -353,6 +353,56 @@ bb2:
ret void
}
+define void @ucmp_gt_unpredictable(i32 %a, i32 %b) {
+; CHECK-LABEL: define void @ucmp_gt_unpredictable(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = icmp ugt i32 [[A]], [[B]]
+; CHECK-NEXT: br i1 [[TMP1]], label %[[BB1:.*]], label %[[BB2:.*]], !unpredictable [[META0:![0-9]+]]
+; CHECK: [[BB1]]:
+; CHECK-NEXT: call void @foo()
+; CHECK-NEXT: br label %[[BB2]]
+; CHECK: [[BB2]]:
+; CHECK-NEXT: ret void
+;
+ %res = call i8 @llvm.ucmp.i8.i32(i32 %a, i32 %b)
+ switch i8 %res, label %bb1 [
+ i8 -1, label %bb2
+ i8 0, label %bb2
+ ], !unpredictable !{}
+
+bb1:
+ call void @foo()
+ br label %bb2
+
+bb2:
+ ret void
+}
+
+define void @ucmp_gt_weights(i32 %a, i32 %b) {
+; CHECK-LABEL: define void @ucmp_gt_weights(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = icmp ugt i32 [[A]], [[B]]
+; CHECK-NEXT: br i1 [[TMP1]], label %[[BB1:.*]], label %[[BB2:.*]], !prof [[PROF1:![0-9]+]]
+; CHECK: [[BB1]]:
+; CHECK-NEXT: call void @foo()
+; CHECK-NEXT: br label %[[BB2]]
+; CHECK: [[BB2]]:
+; CHECK-NEXT: ret void
+;
+ %res = call i8 @llvm.ucmp.i8.i32(i32 %a, i32 %b)
+ switch i8 %res, label %bb1 [
+ i8 -1, label %bb2
+ i8 0, label %bb2
+ ], !prof !{!"branch_weights", i32 5, i32 10, i32 20}
+
+bb1:
+ call void @foo()
+ br label %bb2
+
+bb2:
+ ret void
+}
+
define void @ucmp_gt_unreachable(i32 %a, i32 %b) {
; CHECK-LABEL: define void @ucmp_gt_unreachable(
; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
@@ -621,5 +671,38 @@ not.unreachable:
br label %bb2
}
+define void @ucmp_gt_unreachable_weights(i32 %a, i32 %b) {
+; CHECK-LABEL: define void @ucmp_gt_unreachable_weights(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[B:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = icmp ugt i32 [[A]], [[B]]
+; CHECK-NEXT: br i1 [[TMP1]], label %[[BB1:.*]], label %[[BB2:.*]], !prof [[PROF1]]
+; CHECK: [[BB1]]:
+; CHECK-NEXT: call void @foo()
+; CHECK-NEXT: br label %[[BB2]]
+; CHECK: [[BB2]]:
+; CHECK-NEXT: ret void
+;
+ %res = call i8 @llvm.ucmp.i8.i32(i32 %a, i32 %b)
+ switch i8 %res, label %unreachable [
+ i8 -1, label %bb2
+ i8 0, label %bb2
+ i8 1, label %bb1
+ ], !prof !{!"branch_weights", i32 0, i32 10, i32 20, i32 5}
+
+bb1:
+ call void @foo()
+ br label %bb2
+
+bb2:
+ ret void
+
+unreachable:
+ unreachable
+}
+
declare void @use(i8)
declare void @foo()
+;.
+; CHECK: [[META0]] = !{}
+; CHECK: [[PROF1]] = !{!"branch_weights", i32 5, i32 30}
+;.
>From 4a0eaadc27df8ef8fb822c0c529647641615ebbc Mon Sep 17 00:00:00 2001
From: Nikita Popov <github at npopov.com>
Date: Thu, 22 Aug 2024 16:05:12 +0200
Subject: [PATCH 4/4] Update llvm/lib/Transforms/Utils/SimplifyCFG.cpp
Co-authored-by: Yingwei Zheng <dtcxzyw at qq.com>
---
llvm/lib/Transforms/Utils/SimplifyCFG.cpp | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 714f5563447ec9..da4d57f808e9bf 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -7199,8 +7199,10 @@ static bool simplifySwitchOfCmpIntrinsic(SwitchInst *SI, IRBuilderBase &Builder,
Case.getCaseValue()->getValue().trySExtValue();
if (!Val || (Val != 1 && Val != 0 && Val != -1))
return false;
- if (Case.getCaseSuccessor() == Succ)
+ if (Case.getCaseSuccessor() == Succ) {
Res = *Val;
+ break;
+ }
}
} else {
return false;
More information about the llvm-commits
mailing list