[llvm] [SimplifyCFG] When only one case value is missing, replace default with that case (PR #76669)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Jan 1 00:55:51 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Quentin Dian (DianQK)
<details>
<summary>Changes</summary>
Closes #73446.
---
Full diff: https://github.com/llvm/llvm-project/pull/76669.diff
2 Files Affected:
- (modified) llvm/lib/Transforms/Utils/SimplifyCFG.cpp (+30-7)
- (modified) llvm/test/Transforms/SimplifyCFG/switch-dead-default.ll (+8-6)
``````````diff
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 55e375670cc61e..a808861e97f3b7 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -5414,11 +5414,13 @@ static bool CasesAreContiguous(SmallVectorImpl<ConstantInt *> &Cases) {
}
static void createUnreachableSwitchDefault(SwitchInst *Switch,
- DomTreeUpdater *DTU) {
+ DomTreeUpdater *DTU,
+ bool RemoveOrigDefaultBlock = true) {
LLVM_DEBUG(dbgs() << "SimplifyCFG: switch default is dead.\n");
auto *BB = Switch->getParent();
auto *OrigDefaultBlock = Switch->getDefaultDest();
- OrigDefaultBlock->removePredecessor(BB);
+ if (RemoveOrigDefaultBlock)
+ OrigDefaultBlock->removePredecessor(BB);
BasicBlock *NewDefaultBlock = BasicBlock::Create(
BB->getContext(), BB->getName() + ".unreachabledefault", BB->getParent(),
OrigDefaultBlock);
@@ -5427,7 +5429,8 @@ static void createUnreachableSwitchDefault(SwitchInst *Switch,
if (DTU) {
SmallVector<DominatorTree::UpdateType, 2> Updates;
Updates.push_back({DominatorTree::Insert, BB, &*NewDefaultBlock});
- if (!is_contained(successors(BB), OrigDefaultBlock))
+ if (RemoveOrigDefaultBlock &&
+ !is_contained(successors(BB), OrigDefaultBlock))
Updates.push_back({DominatorTree::Delete, BB, &*OrigDefaultBlock});
DTU->applyUpdates(Updates);
}
@@ -5609,10 +5612,30 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, DomTreeUpdater *DTU,
Known.getBitWidth() - (Known.Zero | Known.One).popcount();
assert(NumUnknownBits <= Known.getBitWidth());
if (HasDefault && DeadCases.empty() &&
- NumUnknownBits < 64 /* avoid overflow */ &&
- SI->getNumCases() == (1ULL << NumUnknownBits)) {
- createUnreachableSwitchDefault(SI, DTU);
- return true;
+ NumUnknownBits < 64 /* avoid overflow */) {
+ uint64_t AllNumCases = 1ULL << NumUnknownBits;
+ if (SI->getNumCases() == AllNumCases) {
+ createUnreachableSwitchDefault(SI, DTU);
+ return true;
+ }
+ // When only one case value is missing, replace default with that case.
+ if (SI->getNumCases() == AllNumCases - 1) {
+ uint64_t MissingCaseVal = 0;
+ for (const auto &Case : SI->cases())
+ MissingCaseVal ^= Case.getCaseValue()->getValue().getLimitedValue();
+ for (uint64_t I = 0; I < AllNumCases; I++)
+ MissingCaseVal ^= I;
+ auto *MissingCase =
+ cast<ConstantInt>(ConstantInt::get(Cond->getType(), MissingCaseVal));
+ SI->addCase(MissingCase, SI->getDefaultDest());
+ SwitchInstProfUpdateWrapper SIW(*SI);
+ auto DefaultCaseWeight = SIW.getSuccessorWeight(0);
+ SIW.setSuccessorWeight(
+ SI->findCaseValue(MissingCase)->getSuccessorIndex(),
+ DefaultCaseWeight);
+ createUnreachableSwitchDefault(SI, DTU, false);
+ return true;
+ }
}
if (DeadCases.empty())
diff --git a/llvm/test/Transforms/SimplifyCFG/switch-dead-default.ll b/llvm/test/Transforms/SimplifyCFG/switch-dead-default.ll
index 1662bb99f27bcc..c0893a4893bfc6 100644
--- a/llvm/test/Transforms/SimplifyCFG/switch-dead-default.ll
+++ b/llvm/test/Transforms/SimplifyCFG/switch-dead-default.ll
@@ -77,14 +77,14 @@ default:
ret void
}
-; This one is a negative test - we know the value of the default,
-; but that's about it
+; We can replace the default branch with case 3 since it is the only case that is missing.
define void @test3(i2 %a) {
; CHECK-LABEL: @test3(
-; CHECK-NEXT: switch i2 [[A:%.*]], label [[DEFAULT:%.*]] [
-; CHECK-NEXT: i2 0, label [[CASE0:%.*]]
-; CHECK-NEXT: i2 1, label [[CASE1:%.*]]
-; CHECK-NEXT: i2 -2, label [[CASE2:%.*]]
+; CHECK-NEXT: switch i2 [[A:%.*]], label [[DOTUNREACHABLEDEFAULT:%.*]] [
+; CHECK-NEXT: i2 0, label [[CASE0:%.*]]
+; CHECK-NEXT: i2 1, label [[CASE1:%.*]]
+; CHECK-NEXT: i2 -2, label [[CASE2:%.*]]
+; CHECK-NEXT: i2 -1, label [[DEFAULT:%.*]]
; CHECK-NEXT: ]
; CHECK: common.ret:
; CHECK-NEXT: ret void
@@ -97,6 +97,8 @@ define void @test3(i2 %a) {
; CHECK: case2:
; CHECK-NEXT: call void @foo(i32 2)
; CHECK-NEXT: br label [[COMMON_RET]]
+; CHECK: .unreachabledefault:
+; CHECK-NEXT: unreachable
; CHECK: default:
; CHECK-NEXT: call void @foo(i32 3)
; CHECK-NEXT: br label [[COMMON_RET]]
``````````
</details>
https://github.com/llvm/llvm-project/pull/76669
More information about the llvm-commits mailing list