[llvm] [DFAJumpThreading] Unify getNextCaseSuccessor (PR #166422)

Hongyu Chen via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 5 00:58:26 PST 2025


https://github.com/XChy updated https://github.com/llvm/llvm-project/pull/166422

>From 05b5f2ea85b3ad8e7c825c632e3ce7b20f343bd2 Mon Sep 17 00:00:00 2001
From: XChy <xxs_chy at outlook.com>
Date: Wed, 5 Nov 2025 03:03:17 +0800
Subject: [PATCH] [DFAJumpThreading] Unify getNextCaseSuccessor

---
 .../Transforms/Scalar/DFAJumpThreading.cpp    | 88 ++++++++++---------
 .../DFAJumpThreading/dfa-unfold-select.ll     | 10 +--
 2 files changed, 50 insertions(+), 48 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp
index 66e45ecbde7df..1d906270fe046 100644
--- a/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp
@@ -575,7 +575,7 @@ struct AllSwitchPaths {
   AllSwitchPaths(const MainSwitch *MSwitch, OptimizationRemarkEmitter *ORE,
                  LoopInfo *LI, Loop *L)
       : Switch(MSwitch->getInstr()), SwitchBlock(Switch->getParent()), ORE(ORE),
-        LI(LI), SwitchOuterLoop(L) {}
+        DefaultDest(nullptr), LI(LI), SwitchOuterLoop(L) {}
 
   std::vector<ThreadingPath> &getThreadingPaths() { return TPaths; }
   unsigned getNumThreadingPaths() { return TPaths.size(); }
@@ -587,6 +587,30 @@ struct AllSwitchPaths {
     unifyTPaths();
   }
 
+  /// Returns the destination cached for case value \p NextState, or the cached
+  /// default destination when no case value matches.
+  BasicBlock *getNextCaseSuccessor(const APInt &NextState) {
+    // Fill the value => successor cache (and default) while it is still empty.
+    if (CaseValToDest.empty()) {
+      for (auto Case : Switch->cases()) {
+        APInt CaseVal = Case.getCaseValue()->getValue();
+        CaseValToDest[CaseVal] = Case.getCaseSuccessor();
+      }
+      DefaultDest = Switch->getDefaultDest();
+    }
+
+    auto SuccIt = CaseValToDest.find(NextState);
+    return SuccIt == CaseValToDest.end() ? DefaultDest : SuccIt->second;
+  }
+
+  void updateDefaultDest(BasicBlock *DefaultDest) {
+    this->DefaultDest = DefaultDest; // Parameter shadows the cached member.
+  }
+
+  void updateNextCase(const APInt &NextState, BasicBlock *NextCase) {
+    CaseValToDest[NextState] = NextCase; // Insert or overwrite cached entry.
+  }
+
 private:
   // Value: an instruction that defines a switch state;
   // Key: the parent basic block of that instruction.
@@ -818,22 +842,6 @@ struct AllSwitchPaths {
     TPaths = std::move(TempList);
   }
 
-  /// Fast helper to get the successor corresponding to a particular case value
-  /// for a switch statement.
-  BasicBlock *getNextCaseSuccessor(const APInt &NextState) {
-    // Precompute the value => successor mapping
-    if (CaseValToDest.empty()) {
-      for (auto Case : Switch->cases()) {
-        APInt CaseVal = Case.getCaseValue()->getValue();
-        CaseValToDest[CaseVal] = Case.getCaseSuccessor();
-      }
-    }
-
-    auto SuccIt = CaseValToDest.find(NextState);
-    return SuccIt == CaseValToDest.end() ? Switch->getDefaultDest()
-                                         : SuccIt->second;
-  }
-
   // Two states are equivalent if they have the same switch destination.
   // Unify the states in different threading path if the states are equivalent.
   void unifyTPaths() {
@@ -858,6 +866,7 @@ struct AllSwitchPaths {
   OptimizationRemarkEmitter *ORE;
   std::vector<ThreadingPath> TPaths;
   DenseMap<APInt, BasicBlock *> CaseValToDest;
+  BasicBlock *DefaultDest;
   LoopInfo *LI;
   Loop *SwitchOuterLoop;
 };
@@ -1159,24 +1168,6 @@ struct TransformDFA {
     SSAUpdate.RewriteAllUses(&DTU->getDomTree());
   }
 
-  /// Helper to get the successor corresponding to a particular case value for
-  /// a switch statement.
-  /// TODO: Unify it with SwitchPaths->getNextCaseSuccessor(SwitchInst *Switch)
-  /// by updating cached value => successor mapping during threading.
-  static BasicBlock *getNextCaseSuccessor(SwitchInst *Switch,
-                                          const APInt &NextState) {
-    BasicBlock *NextCase = nullptr;
-    for (auto Case : Switch->cases()) {
-      if (Case.getCaseValue()->getValue() == NextState) {
-        NextCase = Case.getCaseSuccessor();
-        break;
-      }
-    }
-    if (!NextCase)
-      NextCase = Switch->getDefaultDest();
-    return NextCase;
-  }
-
   /// Clones a basic block, and adds it to the CFG.
   ///
   /// This function also includes updating phi nodes in the successors of the
@@ -1231,8 +1222,7 @@ struct TransformDFA {
     // If BB is the last block in the path, we can simply update the one case
     // successor that will be reached.
     if (BB == SwitchPaths->getSwitchBlock()) {
-      SwitchInst *Switch = SwitchPaths->getSwitchInst();
-      BasicBlock *NextCase = getNextCaseSuccessor(Switch, NextState);
+      BasicBlock *NextCase = SwitchPaths->getNextCaseSuccessor(NextState);
       BlocksToUpdate.push_back(NextCase);
       BasicBlock *ClonedSucc = getClonedBB(NextCase, NextState, DuplicateMap);
       if (ClonedSucc)
@@ -1283,6 +1273,15 @@ struct TransformDFA {
       return;
 
     Instruction *PrevTerm = PrevBB->getTerminator();
+    // Update cached value => destination mapping.
+    if (PrevTerm == SwitchPaths->getSwitchInst()) {
+      for (auto Case : SwitchPaths->getSwitchInst()->cases())
+        if (Case.getCaseSuccessor() == OldBB)
+          SwitchPaths->updateNextCase(Case.getCaseValue()->getValue(), NewBB);
+      if (SwitchPaths->getSwitchInst()->getDefaultDest() == OldBB)
+        SwitchPaths->updateDefaultDest(NewBB);
+    }
+    // Replace actual successors.
     for (unsigned Idx = 0; Idx < PrevTerm->getNumSuccessors(); Idx++) {
       if (PrevTerm->getSuccessor(Idx) == OldBB) {
         OldBB->removePredecessor(PrevBB, /* KeepOneInputPHIs = */ true);
@@ -1341,17 +1340,20 @@ struct TransformDFA {
     // updated yet
     if (!isa<SwitchInst>(LastBlock->getTerminator()))
       return;
-    SwitchInst *Switch = cast<SwitchInst>(LastBlock->getTerminator());
-    BasicBlock *NextCase = getNextCaseSuccessor(Switch, NextState);
+    assert(BB->getTerminator() == SwitchPaths->getSwitchInst() &&
+           "Original last block must contain the threaded switch");
+    BasicBlock *NextCase = SwitchPaths->getNextCaseSuccessor(NextState);
 
     std::vector<DominatorTree::UpdateType> DTUpdates;
     SmallPtrSet<BasicBlock *, 4> SuccSet;
-    for (BasicBlock *Succ : successors(LastBlock)) {
-      if (Succ != NextCase && SuccSet.insert(Succ).second)
+    for (BasicBlock *Succ : successors(LastBlock))
+      if (SuccSet.insert(Succ).second && Succ != NextCase)
         DTUpdates.push_back({DominatorTree::Delete, LastBlock, Succ});
-    }
 
-    Switch->eraseFromParent();
+    if (!SuccSet.count(NextCase))
+      DTUpdates.push_back({DominatorTree::Insert, LastBlock, NextCase});
+
+    LastBlock->getTerminator()->eraseFromParent();
     BranchInst::Create(NextCase, LastBlock);
 
     DTU->applyUpdates(DTUpdates);
diff --git a/llvm/test/Transforms/DFAJumpThreading/dfa-unfold-select.ll b/llvm/test/Transforms/DFAJumpThreading/dfa-unfold-select.ll
index 95d3ffaa21b30..0c0b6b5184562 100644
--- a/llvm/test/Transforms/DFAJumpThreading/dfa-unfold-select.ll
+++ b/llvm/test/Transforms/DFAJumpThreading/dfa-unfold-select.ll
@@ -109,7 +109,7 @@ define i32 @test2(i32 %num) {
 ; CHECK:       for.body.jt3:
 ; CHECK-NEXT:    [[COUNT_JT3:%.*]] = phi i32 [ [[INC_JT3:%.*]], [[FOR_INC_JT3:%.*]] ]
 ; CHECK-NEXT:    [[STATE_JT3:%.*]] = phi i32 [ [[STATE_NEXT_JT3:%.*]], [[FOR_INC_JT3]] ]
-; CHECK-NEXT:    br label [[FOR_INC]]
+; CHECK-NEXT:    br label [[FOR_INC_JT1]]
 ; CHECK:       case1:
 ; CHECK-NEXT:    [[COUNT6:%.*]] = phi i32 [ [[COUNT_JT1]], [[FOR_BODY_JT1:%.*]] ], [ [[COUNT]], [[FOR_BODY]] ]
 ; CHECK-NEXT:    [[CMP_C1:%.*]] = icmp slt i32 [[COUNT6]], 50
@@ -156,8 +156,8 @@ define i32 @test2(i32 %num) {
 ; CHECK-NEXT:    [[DOTSI_UNFOLD_PHI4_JT2:%.*]] = phi i32 [ 2, [[STATE1_1_SI_UNFOLD_TRUE:%.*]] ]
 ; CHECK-NEXT:    br label [[FOR_INC_JT2]]
 ; CHECK:       for.inc:
-; CHECK-NEXT:    [[COUNT5:%.*]] = phi i32 [ [[COUNT_JT3]], [[FOR_BODY_JT3:%.*]] ], [ undef, [[STATE1_1_SI_UNFOLD_TRUE]] ], [ [[COUNT6]], [[STATE1_1_SI_UNFOLD_FALSE]] ], [ undef, [[STATE1_2_SI_UNFOLD_FALSE:%.*]] ], [ [[COUNT]], [[STATE2_2_SI_UNFOLD_FALSE]] ]
-; CHECK-NEXT:    [[STATE_NEXT]] = phi i32 [ [[STATE2_1_SI_UNFOLD_PHI]], [[STATE2_2_SI_UNFOLD_FALSE]] ], [ poison, [[STATE1_2_SI_UNFOLD_FALSE]] ], [ poison, [[STATE1_1_SI_UNFOLD_TRUE]] ], [ [[DOTSI_UNFOLD_PHI4]], [[STATE1_1_SI_UNFOLD_FALSE]] ], [ 1, [[FOR_BODY_JT3]] ]
+; CHECK-NEXT:    [[COUNT5:%.*]] = phi i32 [ undef, [[STATE1_1_SI_UNFOLD_TRUE]] ], [ [[COUNT6]], [[STATE1_1_SI_UNFOLD_FALSE]] ], [ undef, [[STATE1_2_SI_UNFOLD_FALSE:%.*]] ], [ [[COUNT]], [[STATE2_2_SI_UNFOLD_FALSE]] ]
+; CHECK-NEXT:    [[STATE_NEXT]] = phi i32 [ [[STATE2_1_SI_UNFOLD_PHI]], [[STATE2_2_SI_UNFOLD_FALSE]] ], [ poison, [[STATE1_2_SI_UNFOLD_FALSE]] ], [ poison, [[STATE1_1_SI_UNFOLD_TRUE]] ], [ [[DOTSI_UNFOLD_PHI4]], [[STATE1_1_SI_UNFOLD_FALSE]] ]
 ; CHECK-NEXT:    [[INC]] = add nsw i32 [[COUNT5]], 1
 ; CHECK-NEXT:    [[CMP_EXIT:%.*]] = icmp slt i32 [[INC]], [[NUM:%.*]]
 ; CHECK-NEXT:    br i1 [[CMP_EXIT]], label [[FOR_BODY]], label [[FOR_END:%.*]]
@@ -167,8 +167,8 @@ define i32 @test2(i32 %num) {
 ; CHECK-NEXT:    [[CMP_EXIT_JT2:%.*]] = icmp slt i32 [[INC_JT2]], [[NUM]]
 ; CHECK-NEXT:    br i1 [[CMP_EXIT_JT2]], label [[FOR_BODY_JT2:%.*]], label [[FOR_END]]
 ; CHECK:       for.inc.jt1:
-; CHECK-NEXT:    [[COUNT7:%.*]] = phi i32 [ [[COUNT6]], [[STATE1_1_SI_UNFOLD_TRUE_JT1]] ], [ [[COUNT]], [[STATE2_2_SI_UNFOLD_FALSE_JT1]] ], [ [[COUNT]], [[FOR_BODY]] ]
-; CHECK-NEXT:    [[STATE_NEXT_JT1]] = phi i32 [ 1, [[FOR_BODY]] ], [ [[STATE2_1_SI_UNFOLD_PHI_JT1]], [[STATE2_2_SI_UNFOLD_FALSE_JT1]] ], [ [[DOTSI_UNFOLD_PHI3_JT1]], [[STATE1_1_SI_UNFOLD_TRUE_JT1]] ]
+; CHECK-NEXT:    [[COUNT7:%.*]] = phi i32 [ [[COUNT_JT3]], [[FOR_BODY_JT3:%.*]] ], [ [[COUNT6]], [[STATE1_1_SI_UNFOLD_TRUE_JT1]] ], [ [[COUNT]], [[STATE2_2_SI_UNFOLD_FALSE_JT1]] ], [ [[COUNT]], [[FOR_BODY]] ]
+; CHECK-NEXT:    [[STATE_NEXT_JT1]] = phi i32 [ 1, [[FOR_BODY]] ], [ 1, [[FOR_BODY_JT3]] ], [ [[STATE2_1_SI_UNFOLD_PHI_JT1]], [[STATE2_2_SI_UNFOLD_FALSE_JT1]] ], [ [[DOTSI_UNFOLD_PHI3_JT1]], [[STATE1_1_SI_UNFOLD_TRUE_JT1]] ]
 ; CHECK-NEXT:    [[INC_JT1]] = add nsw i32 [[COUNT7]], 1
 ; CHECK-NEXT:    [[CMP_EXIT_JT1:%.*]] = icmp slt i32 [[INC_JT1]], [[NUM]]
 ; CHECK-NEXT:    br i1 [[CMP_EXIT_JT1]], label [[FOR_BODY_JT1]], label [[FOR_END]]



More information about the llvm-commits mailing list