[llvm] [DFAJumpThreading] Unify getNextCaseSuccessor (PR #166422)
Hongyu Chen via llvm-commits
llvm-commits at lists.llvm.org
Wed Nov 5 00:58:26 PST 2025
https://github.com/XChy updated https://github.com/llvm/llvm-project/pull/166422
>From 05b5f2ea85b3ad8e7c825c632e3ce7b20f343bd2 Mon Sep 17 00:00:00 2001
From: XChy <xxs_chy at outlook.com>
Date: Wed, 5 Nov 2025 03:03:17 +0800
Subject: [PATCH] [DFAJumpThreading] Unify getNextCaseSuccessor
---
.../Transforms/Scalar/DFAJumpThreading.cpp | 88 ++++++++++---------
.../DFAJumpThreading/dfa-unfold-select.ll | 10 +--
2 files changed, 50 insertions(+), 48 deletions(-)
diff --git a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp
index 66e45ecbde7df..1d906270fe046 100644
--- a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp
@@ -575,7 +575,7 @@ struct AllSwitchPaths {
AllSwitchPaths(const MainSwitch *MSwitch, OptimizationRemarkEmitter *ORE,
LoopInfo *LI, Loop *L)
: Switch(MSwitch->getInstr()), SwitchBlock(Switch->getParent()), ORE(ORE),
- LI(LI), SwitchOuterLoop(L) {}
+ DefaultDest(nullptr), LI(LI), SwitchOuterLoop(L) {}
std::vector<ThreadingPath> &getThreadingPaths() { return TPaths; }
unsigned getNumThreadingPaths() { return TPaths.size(); }
@@ -587,6 +587,30 @@ struct AllSwitchPaths {
unifyTPaths();
}
+ /// Return the successor the threaded switch reaches for case value NextState,
+ /// or the cached DefaultDest when no case matches; backed by a lazy cache.
+ BasicBlock *getNextCaseSuccessor(const APInt &NextState) {
+ // Build the cache on first use. NOTE(review): emptiness is the "not built" flag, so a prior updateNextCase() would skip this and leave DefaultDest null; a case-less switch rebuilds on every call.
+ if (CaseValToDest.empty()) {
+ for (auto Case : Switch->cases()) {
+ APInt CaseVal = Case.getCaseValue()->getValue();
+ CaseValToDest[CaseVal] = Case.getCaseSuccessor();
+ }
+ DefaultDest = Switch->getDefaultDest();
+ }
+
+ auto SuccIt = CaseValToDest.find(NextState);
+ return SuccIt == CaseValToDest.end() ? DefaultDest : SuccIt->second;
+ }
+
+ void updateDefaultDest(BasicBlock *DefaultDest) { // Mirror a retargeted default edge in the cache.
+ this->DefaultDest = DefaultDest;
+ }
+
+ void updateNextCase(const APInt &NextState, BasicBlock *NextCase) { // Mirror a retargeted case edge in the cache.
+ CaseValToDest[NextState] = NextCase;
+ }
+
private:
// Value: an instruction that defines a switch state;
// Key: the parent basic block of that instruction.
@@ -818,22 +842,6 @@ struct AllSwitchPaths {
TPaths = std::move(TempList);
}
- /// Fast helper to get the successor corresponding to a particular case value
- /// for a switch statement.
- BasicBlock *getNextCaseSuccessor(const APInt &NextState) {
- // Precompute the value => successor mapping
- if (CaseValToDest.empty()) {
- for (auto Case : Switch->cases()) {
- APInt CaseVal = Case.getCaseValue()->getValue();
- CaseValToDest[CaseVal] = Case.getCaseSuccessor();
- }
- }
-
- auto SuccIt = CaseValToDest.find(NextState);
- return SuccIt == CaseValToDest.end() ? Switch->getDefaultDest()
- : SuccIt->second;
- }
-
// Two states are equivalent if they have the same switch destination.
// Unify the states in different threading path if the states are equivalent.
void unifyTPaths() {
@@ -858,6 +866,7 @@ struct AllSwitchPaths {
OptimizationRemarkEmitter *ORE;
std::vector<ThreadingPath> TPaths;
DenseMap<APInt, BasicBlock *> CaseValToDest;
+ BasicBlock *DefaultDest;
LoopInfo *LI;
Loop *SwitchOuterLoop;
};
@@ -1159,24 +1168,6 @@ struct TransformDFA {
SSAUpdate.RewriteAllUses(&DTU->getDomTree());
}
- /// Helper to get the successor corresponding to a particular case value for
- /// a switch statement.
- /// TODO: Unify it with SwitchPaths->getNextCaseSuccessor(SwitchInst *Switch)
- /// by updating cached value => successor mapping during threading.
- static BasicBlock *getNextCaseSuccessor(SwitchInst *Switch,
- const APInt &NextState) {
- BasicBlock *NextCase = nullptr;
- for (auto Case : Switch->cases()) {
- if (Case.getCaseValue()->getValue() == NextState) {
- NextCase = Case.getCaseSuccessor();
- break;
- }
- }
- if (!NextCase)
- NextCase = Switch->getDefaultDest();
- return NextCase;
- }
-
/// Clones a basic block, and adds it to the CFG.
///
/// This function also includes updating phi nodes in the successors of the
@@ -1231,8 +1222,7 @@ struct TransformDFA {
// If BB is the last block in the path, we can simply update the one case
// successor that will be reached.
if (BB == SwitchPaths->getSwitchBlock()) {
- SwitchInst *Switch = SwitchPaths->getSwitchInst();
- BasicBlock *NextCase = getNextCaseSuccessor(Switch, NextState);
+ BasicBlock *NextCase = SwitchPaths->getNextCaseSuccessor(NextState);
BlocksToUpdate.push_back(NextCase);
BasicBlock *ClonedSucc = getClonedBB(NextCase, NextState, DuplicateMap);
if (ClonedSucc)
@@ -1283,6 +1273,15 @@ struct TransformDFA {
return;
Instruction *PrevTerm = PrevBB->getTerminator();
+ // Update cached value => destination mapping.
+ if (PrevTerm == SwitchPaths->getSwitchInst()) {
+ for (auto Case : SwitchPaths->getSwitchInst()->cases())
+ if (Case.getCaseSuccessor() == OldBB)
+ SwitchPaths->updateNextCase(Case.getCaseValue()->getValue(), NewBB);
+ if (SwitchPaths->getSwitchInst()->getDefaultDest() == OldBB)
+ SwitchPaths->updateDefaultDest(NewBB);
+ }
+ // Replace actual successors.
for (unsigned Idx = 0; Idx < PrevTerm->getNumSuccessors(); Idx++) {
if (PrevTerm->getSuccessor(Idx) == OldBB) {
OldBB->removePredecessor(PrevBB, /* KeepOneInputPHIs = */ true);
@@ -1341,17 +1340,20 @@ struct TransformDFA {
// updated yet
if (!isa<SwitchInst>(LastBlock->getTerminator()))
return;
- SwitchInst *Switch = cast<SwitchInst>(LastBlock->getTerminator());
- BasicBlock *NextCase = getNextCaseSuccessor(Switch, NextState);
+ assert(BB->getTerminator() == SwitchPaths->getSwitchInst() &&
+ "Original last block must contain the threaded switch");
+ BasicBlock *NextCase = SwitchPaths->getNextCaseSuccessor(NextState);
std::vector<DominatorTree::UpdateType> DTUpdates;
SmallPtrSet<BasicBlock *, 4> SuccSet;
- for (BasicBlock *Succ : successors(LastBlock)) {
- if (Succ != NextCase && SuccSet.insert(Succ).second)
+ for (BasicBlock *Succ : successors(LastBlock))
+ if (SuccSet.insert(Succ).second && Succ != NextCase)
DTUpdates.push_back({DominatorTree::Delete, LastBlock, Succ});
- }
- Switch->eraseFromParent();
+ if (!SuccSet.count(NextCase))
+ DTUpdates.push_back({DominatorTree::Insert, LastBlock, NextCase});
+
+ LastBlock->getTerminator()->eraseFromParent();
BranchInst::Create(NextCase, LastBlock);
DTU->applyUpdates(DTUpdates);
diff --git a/llvm/test/Transforms/DFAJumpThreading/dfa-unfold-select.ll b/llvm/test/Transforms/DFAJumpThreading/dfa-unfold-select.ll
index 95d3ffaa21b30..0c0b6b5184562 100644
--- a/llvm/test/Transforms/DFAJumpThreading/dfa-unfold-select.ll
+++ b/llvm/test/Transforms/DFAJumpThreading/dfa-unfold-select.ll
@@ -109,7 +109,7 @@ define i32 @test2(i32 %num) {
; CHECK: for.body.jt3:
; CHECK-NEXT: [[COUNT_JT3:%.*]] = phi i32 [ [[INC_JT3:%.*]], [[FOR_INC_JT3:%.*]] ]
; CHECK-NEXT: [[STATE_JT3:%.*]] = phi i32 [ [[STATE_NEXT_JT3:%.*]], [[FOR_INC_JT3]] ]
-; CHECK-NEXT: br label [[FOR_INC]]
+; CHECK-NEXT: br label [[FOR_INC_JT1]]
; CHECK: case1:
; CHECK-NEXT: [[COUNT6:%.*]] = phi i32 [ [[COUNT_JT1]], [[FOR_BODY_JT1:%.*]] ], [ [[COUNT]], [[FOR_BODY]] ]
; CHECK-NEXT: [[CMP_C1:%.*]] = icmp slt i32 [[COUNT6]], 50
@@ -156,8 +156,8 @@ define i32 @test2(i32 %num) {
; CHECK-NEXT: [[DOTSI_UNFOLD_PHI4_JT2:%.*]] = phi i32 [ 2, [[STATE1_1_SI_UNFOLD_TRUE:%.*]] ]
; CHECK-NEXT: br label [[FOR_INC_JT2]]
; CHECK: for.inc:
-; CHECK-NEXT: [[COUNT5:%.*]] = phi i32 [ [[COUNT_JT3]], [[FOR_BODY_JT3:%.*]] ], [ undef, [[STATE1_1_SI_UNFOLD_TRUE]] ], [ [[COUNT6]], [[STATE1_1_SI_UNFOLD_FALSE]] ], [ undef, [[STATE1_2_SI_UNFOLD_FALSE:%.*]] ], [ [[COUNT]], [[STATE2_2_SI_UNFOLD_FALSE]] ]
-; CHECK-NEXT: [[STATE_NEXT]] = phi i32 [ [[STATE2_1_SI_UNFOLD_PHI]], [[STATE2_2_SI_UNFOLD_FALSE]] ], [ poison, [[STATE1_2_SI_UNFOLD_FALSE]] ], [ poison, [[STATE1_1_SI_UNFOLD_TRUE]] ], [ [[DOTSI_UNFOLD_PHI4]], [[STATE1_1_SI_UNFOLD_FALSE]] ], [ 1, [[FOR_BODY_JT3]] ]
+; CHECK-NEXT: [[COUNT5:%.*]] = phi i32 [ undef, [[STATE1_1_SI_UNFOLD_TRUE]] ], [ [[COUNT6]], [[STATE1_1_SI_UNFOLD_FALSE]] ], [ undef, [[STATE1_2_SI_UNFOLD_FALSE:%.*]] ], [ [[COUNT]], [[STATE2_2_SI_UNFOLD_FALSE]] ]
+; CHECK-NEXT: [[STATE_NEXT]] = phi i32 [ [[STATE2_1_SI_UNFOLD_PHI]], [[STATE2_2_SI_UNFOLD_FALSE]] ], [ poison, [[STATE1_2_SI_UNFOLD_FALSE]] ], [ poison, [[STATE1_1_SI_UNFOLD_TRUE]] ], [ [[DOTSI_UNFOLD_PHI4]], [[STATE1_1_SI_UNFOLD_FALSE]] ]
; CHECK-NEXT: [[INC]] = add nsw i32 [[COUNT5]], 1
; CHECK-NEXT: [[CMP_EXIT:%.*]] = icmp slt i32 [[INC]], [[NUM:%.*]]
; CHECK-NEXT: br i1 [[CMP_EXIT]], label [[FOR_BODY]], label [[FOR_END:%.*]]
@@ -167,8 +167,8 @@ define i32 @test2(i32 %num) {
; CHECK-NEXT: [[CMP_EXIT_JT2:%.*]] = icmp slt i32 [[INC_JT2]], [[NUM]]
; CHECK-NEXT: br i1 [[CMP_EXIT_JT2]], label [[FOR_BODY_JT2:%.*]], label [[FOR_END]]
; CHECK: for.inc.jt1:
-; CHECK-NEXT: [[COUNT7:%.*]] = phi i32 [ [[COUNT6]], [[STATE1_1_SI_UNFOLD_TRUE_JT1]] ], [ [[COUNT]], [[STATE2_2_SI_UNFOLD_FALSE_JT1]] ], [ [[COUNT]], [[FOR_BODY]] ]
-; CHECK-NEXT: [[STATE_NEXT_JT1]] = phi i32 [ 1, [[FOR_BODY]] ], [ [[STATE2_1_SI_UNFOLD_PHI_JT1]], [[STATE2_2_SI_UNFOLD_FALSE_JT1]] ], [ [[DOTSI_UNFOLD_PHI3_JT1]], [[STATE1_1_SI_UNFOLD_TRUE_JT1]] ]
+; CHECK-NEXT: [[COUNT7:%.*]] = phi i32 [ [[COUNT_JT3]], [[FOR_BODY_JT3:%.*]] ], [ [[COUNT6]], [[STATE1_1_SI_UNFOLD_TRUE_JT1]] ], [ [[COUNT]], [[STATE2_2_SI_UNFOLD_FALSE_JT1]] ], [ [[COUNT]], [[FOR_BODY]] ]
+; CHECK-NEXT: [[STATE_NEXT_JT1]] = phi i32 [ 1, [[FOR_BODY]] ], [ 1, [[FOR_BODY_JT3]] ], [ [[STATE2_1_SI_UNFOLD_PHI_JT1]], [[STATE2_2_SI_UNFOLD_FALSE_JT1]] ], [ [[DOTSI_UNFOLD_PHI3_JT1]], [[STATE1_1_SI_UNFOLD_TRUE_JT1]] ]
; CHECK-NEXT: [[INC_JT1]] = add nsw i32 [[COUNT7]], 1
; CHECK-NEXT: [[CMP_EXIT_JT1:%.*]] = icmp slt i32 [[INC_JT1]], [[NUM]]
; CHECK-NEXT: br i1 [[CMP_EXIT_JT1]], label [[FOR_BODY_JT1]], label [[FOR_END]]
More information about the llvm-commits
mailing list