[llvm] [SimplifyCFG] Eliminate dead edges of switches according to the domain of conditions (PR #165748)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 30 10:19:42 PDT 2025


https://github.com/dtcxzyw created https://github.com/llvm/llvm-project/pull/165748

llvm-opt-benchmark: https://github.com/dtcxzyw/llvm-opt-benchmark/pull/2986
Closes https://github.com/llvm/llvm-project/issues/165179.


>From efe45cf2e1b879e19597a1133d785f8ee72fd540 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Tue, 28 Oct 2025 01:21:49 +0800
Subject: [PATCH] [SimplifyCFG] Eliminate dead edges of switches according to the domain of conditions

---
 llvm/include/llvm/Analysis/ValueTracking.h |  8 +++
 llvm/lib/Analysis/ValueTracking.cpp        | 66 +++++++++++++++++++++
 llvm/lib/Transforms/Utils/SimplifyCFG.cpp  | 67 +++++++++++++---------
 3 files changed, 114 insertions(+), 27 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index af218ba564081..ef5500b7bf0b6 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -1024,6 +1024,14 @@ findValuesAffectedByCondition(Value *Cond, bool IsAssume,
 LLVM_ABI Value *stripNullTest(Value *V);
 LLVM_ABI const Value *stripNullTest(const Value *V);
 
+/// Enumerate into Constants every value V may take (at most MaxCount distinct,
+/// possibly including undef/poison); Depth tracks recursion. Returns true only
+/// if the enumeration is complete; on false, Constants must be ignored.
+LLVM_ABI bool
+collectPossibleValues(const Value *V,
+                      SmallPtrSetImpl<const Constant *> &Constants,
+                      unsigned MaxCount, unsigned Depth = 0);
+
 } // end namespace llvm
 
 #endif // LLVM_ANALYSIS_VALUETRACKING_H
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 0a72076f51824..9cc7683fd9561 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -10405,3 +10405,75 @@ const Value *llvm::stripNullTest(const Value *V) {
 Value *llvm::stripNullTest(Value *V) {
   return const_cast<Value *>(stripNullTest(const_cast<const Value *>(V)));
 }
+
+// Collect into Constants every value V may take. Fails (returns false) when V
+// is not understood or more than MaxCount distinct values would be needed; a
+// false return leaves Constants partially filled and must not be trusted.
+bool llvm::collectPossibleValues(const Value *V,
+                                 SmallPtrSetImpl<const Constant *> &Constants,
+                                 unsigned MaxCount, unsigned Depth) {
+  // Record C as a possible value. Only ConstantData (ConstantInt, ConstantFP,
+  // null, undef/poison, ...) is accepted: its value is fixed at compile time.
+  // Other constants such as `ptrtoint (ptr @g to i64)` are only resolved at
+  // link/load time and may equal any other constant, so a caller treating
+  // "not in the set" as "impossible" would miscompile; give up on them.
+  auto Add = [&](const Constant *C) -> bool {
+    if (!isa<ConstantData>(C))
+      return false;
+    if (Constants.contains(C))
+      return true;
+    if (Constants.size() == MaxCount)
+      return false;
+    Constants.insert(C);
+    return true;
+  };
+  if (auto *C = dyn_cast<Constant>(V))
+    return Add(C);
+  if (Depth++ >= MaxAnalysisRecursionDepth)
+    return false;
+  auto *Inst = dyn_cast<Instruction>(V);
+  if (!Inst)
+    return false;
+  Type *Ty = Inst->getType();
+  switch (Inst->getOpcode()) {
+  case Instruction::Select:
+    // A select produces its true (operand 1) or false (operand 2) value.
+    // Operand 0 is the condition, not a possible result. Forward MaxCount and
+    // the incremented Depth so both limits remain in force.
+    return collectPossibleValues(Inst->getOperand(1), Constants, MaxCount,
+                                 Depth) &&
+           collectPossibleValues(Inst->getOperand(2), Constants, MaxCount,
+                                 Depth);
+  case Instruction::PHI:
+    // Only constant incoming values are handled. A self-reference (the PHI
+    // carried unchanged around a loop) contributes no new value.
+    for (Value *IncomingValue : cast<PHINode>(Inst)->incoming_values()) {
+      if (IncomingValue == Inst)
+        continue;
+      auto *C = dyn_cast<Constant>(IncomingValue);
+      if (!C || !Add(C))
+        return false;
+    }
+    return true;
+  case Instruction::And: {
+    // X & (1 << K) is either 0 or 1 << K. Restricted to scalars: for vectors
+    // each lane is independent, so mixed results such as <0, 4> are possible.
+    const APInt *Bit;
+    if (Ty->isIntegerTy() && match(Inst->getOperand(1), m_Power2(Bit)))
+      return Add(Constant::getNullValue(Ty)) &&
+             Add(ConstantInt::get(Ty, *Bit));
+    break;
+  }
+  case Instruction::ZExt:
+  case Instruction::SExt:
+    // Extending a scalar i1 yields 0, or 1 (zext) / all-ones (sext).
+    if (Inst->getOperand(0)->getType()->isIntegerTy(1))
+      return Add(Constant::getNullValue(Ty)) &&
+             Add(isa<ZExtInst>(Inst) ? ConstantInt::get(Ty, 1)
+                                     : Constant::getAllOnesValue(Ty));
+    break;
+  default:
+    break;
+  }
+  return false;
+}
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index c537be5cba37c..60c4db50ac9d4 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -5990,6 +5990,8 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, DomTreeUpdater *DTU,
                                      const DataLayout &DL) {
   Value *Cond = SI->getCondition();
   KnownBits Known = computeKnownBits(Cond, DL, AC, SI);
+  SmallPtrSet<const Constant *, 4> KnownValues;
+  bool IsKnownValuesValid = collectPossibleValues(Cond, KnownValues, 4);
 
   // We can also eliminate cases by determining that their values are outside of
   // the limited range of the condition based on how many significant (non-sign)
@@ -6009,15 +6011,18 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, DomTreeUpdater *DTU,
         UniqueSuccessors.push_back(Successor);
       ++It->second;
     }
-    const APInt &CaseVal = Case.getCaseValue()->getValue();
+    ConstantInt *CaseC = Case.getCaseValue();
+    const APInt &CaseVal = CaseC->getValue();
     if (Known.Zero.intersects(CaseVal) || !Known.One.isSubsetOf(CaseVal) ||
-        (CaseVal.getSignificantBits() > MaxSignificantBitsInCond)) {
-      DeadCases.push_back(Case.getCaseValue());
+        (CaseVal.getSignificantBits() > MaxSignificantBitsInCond) ||
+        (IsKnownValuesValid && !KnownValues.contains(CaseC))) {
+      DeadCases.push_back(CaseC);
       if (DTU)
         --NumPerSuccessorCases[Successor];
       LLVM_DEBUG(dbgs() << "SimplifyCFG: switch case " << CaseVal
                         << " is dead.\n");
-    }
+    } else if (IsKnownValuesValid)
+      KnownValues.erase(CaseC);
   }
 
   // If we can prove that the cases must cover all possible values, the
@@ -6028,33 +6033,41 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, DomTreeUpdater *DTU,
   const unsigned NumUnknownBits =
       Known.getBitWidth() - (Known.Zero | Known.One).popcount();
   assert(NumUnknownBits <= Known.getBitWidth());
-  if (HasDefault && DeadCases.empty() &&
-      NumUnknownBits < 64 /* avoid overflow */) {
-    uint64_t AllNumCases = 1ULL << NumUnknownBits;
-    if (SI->getNumCases() == AllNumCases) {
+  if (HasDefault && DeadCases.empty()) {
+    if (IsKnownValuesValid && all_of(KnownValues, IsaPred<UndefValue>)) {
       createUnreachableSwitchDefault(SI, DTU);
       return true;
     }
-    // When only one case value is missing, replace default with that case.
-    // Eliminating the default branch will provide more opportunities for
-    // optimization, such as lookup tables.
-    if (SI->getNumCases() == AllNumCases - 1) {
-      assert(NumUnknownBits > 1 && "Should be canonicalized to a branch");
-      IntegerType *CondTy = cast<IntegerType>(Cond->getType());
-      if (CondTy->getIntegerBitWidth() > 64 ||
-          !DL.fitsInLegalInteger(CondTy->getIntegerBitWidth()))
-        return false;
 
-      uint64_t MissingCaseVal = 0;
-      for (const auto &Case : SI->cases())
-        MissingCaseVal ^= Case.getCaseValue()->getValue().getLimitedValue();
-      auto *MissingCase =
-          cast<ConstantInt>(ConstantInt::get(Cond->getType(), MissingCaseVal));
-      SwitchInstProfUpdateWrapper SIW(*SI);
-      SIW.addCase(MissingCase, SI->getDefaultDest(), SIW.getSuccessorWeight(0));
-      createUnreachableSwitchDefault(SI, DTU, /*RemoveOrigDefaultBlock*/ false);
-      SIW.setSuccessorWeight(0, 0);
-      return true;
+    if (NumUnknownBits < 64 /* avoid overflow */) {
+      uint64_t AllNumCases = 1ULL << NumUnknownBits;
+      if (SI->getNumCases() == AllNumCases) {
+        createUnreachableSwitchDefault(SI, DTU);
+        return true;
+      }
+      // When only one case value is missing, replace default with that case.
+      // Eliminating the default branch will provide more opportunities for
+      // optimization, such as lookup tables.
+      if (SI->getNumCases() == AllNumCases - 1) {
+        assert(NumUnknownBits > 1 && "Should be canonicalized to a branch");
+        IntegerType *CondTy = cast<IntegerType>(Cond->getType());
+        if (CondTy->getIntegerBitWidth() > 64 ||
+            !DL.fitsInLegalInteger(CondTy->getIntegerBitWidth()))
+          return false;
+
+        uint64_t MissingCaseVal = 0;
+        for (const auto &Case : SI->cases())
+          MissingCaseVal ^= Case.getCaseValue()->getValue().getLimitedValue();
+        auto *MissingCase = cast<ConstantInt>(
+            ConstantInt::get(Cond->getType(), MissingCaseVal));
+        SwitchInstProfUpdateWrapper SIW(*SI);
+        SIW.addCase(MissingCase, SI->getDefaultDest(),
+                    SIW.getSuccessorWeight(0));
+        createUnreachableSwitchDefault(SI, DTU,
+                                       /*RemoveOrigDefaultBlock*/ false);
+        SIW.setSuccessorWeight(0, 0);
+        return true;
+      }
     }
   }
 



More information about the llvm-commits mailing list