[llvm] [SimplifyCFG] When only one case value is missing, replace default with that case (PR #76669)

Quentin Dian via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 1 04:38:15 PST 2024


https://github.com/DianQK updated https://github.com/llvm/llvm-project/pull/76669

>From 3316c46ae21db7da683297ec16e3ca77ecb76032 Mon Sep 17 00:00:00 2001
From: DianQK <dianqk at dianqk.net>
Date: Mon, 1 Jan 2024 15:46:57 +0800
Subject: [PATCH 1/2] [SimplifyCFG] When only one case value is missing,
 replace default with that case

---
 llvm/lib/Transforms/Utils/SimplifyCFG.cpp     | 37 +++++++++++++++----
 .../SimplifyCFG/switch-dead-default.ll        | 14 ++++---
 2 files changed, 38 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 55e375670cc61e..a808861e97f3b7 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -5414,11 +5414,13 @@ static bool CasesAreContiguous(SmallVectorImpl<ConstantInt *> &Cases) {
 }
 
 static void createUnreachableSwitchDefault(SwitchInst *Switch,
-                                           DomTreeUpdater *DTU) {
+                                           DomTreeUpdater *DTU,
+                                           bool RemoveOrigDefaultBlock = true) {
   LLVM_DEBUG(dbgs() << "SimplifyCFG: switch default is dead.\n");
   auto *BB = Switch->getParent();
   auto *OrigDefaultBlock = Switch->getDefaultDest();
-  OrigDefaultBlock->removePredecessor(BB);
+  if (RemoveOrigDefaultBlock)
+    OrigDefaultBlock->removePredecessor(BB);
   BasicBlock *NewDefaultBlock = BasicBlock::Create(
       BB->getContext(), BB->getName() + ".unreachabledefault", BB->getParent(),
       OrigDefaultBlock);
@@ -5427,7 +5429,8 @@ static void createUnreachableSwitchDefault(SwitchInst *Switch,
   if (DTU) {
     SmallVector<DominatorTree::UpdateType, 2> Updates;
     Updates.push_back({DominatorTree::Insert, BB, &*NewDefaultBlock});
-    if (!is_contained(successors(BB), OrigDefaultBlock))
+    if (RemoveOrigDefaultBlock &&
+        !is_contained(successors(BB), OrigDefaultBlock))
       Updates.push_back({DominatorTree::Delete, BB, &*OrigDefaultBlock});
     DTU->applyUpdates(Updates);
   }
@@ -5609,10 +5612,30 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, DomTreeUpdater *DTU,
       Known.getBitWidth() - (Known.Zero | Known.One).popcount();
   assert(NumUnknownBits <= Known.getBitWidth());
   if (HasDefault && DeadCases.empty() &&
-      NumUnknownBits < 64 /* avoid overflow */ &&
-      SI->getNumCases() == (1ULL << NumUnknownBits)) {
-    createUnreachableSwitchDefault(SI, DTU);
-    return true;
+      NumUnknownBits < 64 /* avoid overflow */) {
+    uint64_t AllNumCases = 1ULL << NumUnknownBits;
+    if (SI->getNumCases() == AllNumCases) {
+      createUnreachableSwitchDefault(SI, DTU);
+      return true;
+    }
+    // When only one case value is missing, replace default with that case.
+    if (SI->getNumCases() == AllNumCases - 1) {
+      uint64_t MissingCaseVal = 0;
+      for (const auto &Case : SI->cases())
+        MissingCaseVal ^= Case.getCaseValue()->getValue().getLimitedValue();
+      for (uint64_t I = 0; I < AllNumCases; I++)
+        MissingCaseVal ^= I;
+      auto *MissingCase =
+          cast<ConstantInt>(ConstantInt::get(Cond->getType(), MissingCaseVal));
+      SI->addCase(MissingCase, SI->getDefaultDest());
+      SwitchInstProfUpdateWrapper SIW(*SI);
+      auto DefaultCaseWeight = SIW.getSuccessorWeight(0);
+      SIW.setSuccessorWeight(
+          SI->findCaseValue(MissingCase)->getSuccessorIndex(),
+          DefaultCaseWeight);
+      createUnreachableSwitchDefault(SI, DTU, false);
+      return true;
+    }
   }
 
   if (DeadCases.empty())
diff --git a/llvm/test/Transforms/SimplifyCFG/switch-dead-default.ll b/llvm/test/Transforms/SimplifyCFG/switch-dead-default.ll
index 1662bb99f27bcc..c0893a4893bfc6 100644
--- a/llvm/test/Transforms/SimplifyCFG/switch-dead-default.ll
+++ b/llvm/test/Transforms/SimplifyCFG/switch-dead-default.ll
@@ -77,14 +77,14 @@ default:
   ret void
 }
 
-; This one is a negative test - we know the value of the default,
-; but that's about it
+; We can replace the default destination with an explicit case for the value 3 (printed as i2 -1), since it is the only value of %a that no existing case covers.
 define void @test3(i2 %a) {
 ; CHECK-LABEL: @test3(
-; CHECK-NEXT:    switch i2 [[A:%.*]], label [[DEFAULT:%.*]] [
-; CHECK-NEXT:    i2 0, label [[CASE0:%.*]]
-; CHECK-NEXT:    i2 1, label [[CASE1:%.*]]
-; CHECK-NEXT:    i2 -2, label [[CASE2:%.*]]
+; CHECK-NEXT:    switch i2 [[A:%.*]], label [[DOTUNREACHABLEDEFAULT:%.*]] [
+; CHECK-NEXT:      i2 0, label [[CASE0:%.*]]
+; CHECK-NEXT:      i2 1, label [[CASE1:%.*]]
+; CHECK-NEXT:      i2 -2, label [[CASE2:%.*]]
+; CHECK-NEXT:      i2 -1, label [[DEFAULT:%.*]]
 ; CHECK-NEXT:    ]
 ; CHECK:       common.ret:
 ; CHECK-NEXT:    ret void
@@ -97,6 +97,8 @@ define void @test3(i2 %a) {
 ; CHECK:       case2:
 ; CHECK-NEXT:    call void @foo(i32 2)
 ; CHECK-NEXT:    br label [[COMMON_RET]]
+; CHECK:       .unreachabledefault:
+; CHECK-NEXT:    unreachable
 ; CHECK:       default:
 ; CHECK-NEXT:    call void @foo(i32 3)
 ; CHECK-NEXT:    br label [[COMMON_RET]]

>From 44b251b9de2b6737a6213f41f1cf92ec692b2004 Mon Sep 17 00:00:00 2001
From: DianQK <dianqk at dianqk.net>
Date: Mon, 1 Jan 2024 20:05:36 +0800
Subject: [PATCH 2/2] fixup! [SimplifyCFG] When only one case value is missing,
 replace default with that case

---
 llvm/lib/Transforms/Utils/SimplifyCFG.cpp              | 10 +++-------
 .../test/Transforms/SimplifyCFG/switch-dead-default.ll |  5 +++--
 2 files changed, 6 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index a808861e97f3b7..f0b791e9bc110f 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -5620,20 +5620,16 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, DomTreeUpdater *DTU,
     }
     // When only one case value is missing, replace default with that case.
     if (SI->getNumCases() == AllNumCases - 1) {
+      assert(NumUnknownBits > 1 && "Should be canonicalized to a branch");
       uint64_t MissingCaseVal = 0;
       for (const auto &Case : SI->cases())
         MissingCaseVal ^= Case.getCaseValue()->getValue().getLimitedValue();
-      for (uint64_t I = 0; I < AllNumCases; I++)
-        MissingCaseVal ^= I;
       auto *MissingCase =
           cast<ConstantInt>(ConstantInt::get(Cond->getType(), MissingCaseVal));
-      SI->addCase(MissingCase, SI->getDefaultDest());
       SwitchInstProfUpdateWrapper SIW(*SI);
-      auto DefaultCaseWeight = SIW.getSuccessorWeight(0);
-      SIW.setSuccessorWeight(
-          SI->findCaseValue(MissingCase)->getSuccessorIndex(),
-          DefaultCaseWeight);
+      SIW.addCase(MissingCase, SI->getDefaultDest(), SIW.getSuccessorWeight(0));
       createUnreachableSwitchDefault(SI, DTU, false);
+      SIW.setSuccessorWeight(0, 0);
       return true;
     }
   }
diff --git a/llvm/test/Transforms/SimplifyCFG/switch-dead-default.ll b/llvm/test/Transforms/SimplifyCFG/switch-dead-default.ll
index c0893a4893bfc6..ff639bc4419e0a 100644
--- a/llvm/test/Transforms/SimplifyCFG/switch-dead-default.ll
+++ b/llvm/test/Transforms/SimplifyCFG/switch-dead-default.ll
@@ -85,7 +85,7 @@ define void @test3(i2 %a) {
 ; CHECK-NEXT:      i2 1, label [[CASE1:%.*]]
 ; CHECK-NEXT:      i2 -2, label [[CASE2:%.*]]
 ; CHECK-NEXT:      i2 -1, label [[DEFAULT:%.*]]
-; CHECK-NEXT:    ]
+; CHECK-NEXT:    ], !prof [[PROF0:![0-9]+]]
 ; CHECK:       common.ret:
 ; CHECK-NEXT:    ret void
 ; CHECK:       case0:
@@ -105,7 +105,7 @@ define void @test3(i2 %a) {
 ;
   switch i2 %a, label %default [i2 0, label %case0
   i2 1, label %case1
-  i2 2, label %case2]
+  i2 2, label %case2], !prof !0
 
 case0:
   call void @foo(i32 0)
@@ -262,3 +262,4 @@ default:
 
 declare void @llvm.assume(i1)
 
+!0 = !{!"branch_weights", i32 8, i32 4, i32 2, i32 1}



More information about the llvm-commits mailing list