[llvm] [SimplifyCFG] Eliminate dead edges of switches according to the domain of conditions (PR #165748)
Yingwei Zheng via llvm-commits
llvm-commits at lists.llvm.org
Thu Oct 30 10:19:42 PDT 2025
https://github.com/dtcxzyw created https://github.com/llvm/llvm-project/pull/165748
llvm-opt-benchmark: https://github.com/dtcxzyw/llvm-opt-benchmark/pull/2986
Closes https://github.com/llvm/llvm-project/issues/165179.
>From efe45cf2e1b879e19597a1133d785f8ee72fd540 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Tue, 28 Oct 2025 01:21:49 +0800
Subject: [PATCH] test
---
llvm/include/llvm/Analysis/ValueTracking.h | 8 +++
llvm/lib/Analysis/ValueTracking.cpp | 66 +++++++++++++++++++++
llvm/lib/Transforms/Utils/SimplifyCFG.cpp | 67 +++++++++++++---------
3 files changed, 114 insertions(+), 27 deletions(-)
diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index af218ba564081..ef5500b7bf0b6 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -1024,6 +1024,14 @@ findValuesAffectedByCondition(Value *Cond, bool IsAssume,
LLVM_ABI Value *stripNullTest(Value *V);
LLVM_ABI const Value *stripNullTest(const Value *V);
+/// Enumerates all possible values of \p V, inserting each one into
+/// \p Constants.
+///
+/// At most \p MaxCount distinct constants are collected. Returns true only if
+/// every value \p V can take has been added to \p Constants. Returns false if
+/// more than \p MaxCount values would be needed or the form of \p V is not
+/// understood; in that case \p Constants may be partially populated and its
+/// contents must not be relied upon.
+///
+/// The collected set may include undef/poison constants; callers decide how
+/// to treat them. \p Depth is the current recursion depth and should normally
+/// be left at its default.
+LLVM_ABI bool
+collectPossibleValues(const Value *V,
+                      SmallPtrSetImpl<const Constant *> &Constants,
+                      unsigned MaxCount, unsigned Depth = 0);
+
} // end namespace llvm
#endif // LLVM_ANALYSIS_VALUETRACKING_H
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 0a72076f51824..9cc7683fd9561 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -10405,3 +10405,69 @@ const Value *llvm::stripNullTest(const Value *V) {
Value *llvm::stripNullTest(Value *V) {
return const_cast<Value *>(stripNullTest(const_cast<const Value *>(V)));
}
+
+/// Enumerates all possible values of V into Constants, stopping once more
+/// than MaxCount distinct constants would be required. Returns true only if
+/// the collected set is complete; on false the set may be partially filled.
+bool llvm::collectPossibleValues(const Value *V,
+                                 SmallPtrSetImpl<const Constant *> &Constants,
+                                 unsigned MaxCount, unsigned Depth) {
+  // Record C as a possible value. A new constant is rejected once the set
+  // already holds MaxCount entries; `>=` keeps the bound effective even if the
+  // incoming set is already larger than MaxCount.
+  auto Add = [&](const Constant *C) -> bool {
+    if (Constants.contains(C))
+      return true;
+    if (Constants.size() >= MaxCount)
+      return false;
+    Constants.insert(C);
+    return true;
+  };
+  if (auto *C = dyn_cast<Constant>(V))
+    return Add(C);
+  if (Depth >= MaxAnalysisRecursionDepth)
+    return false;
+  ++Depth;
+  auto *Inst = dyn_cast<Instruction>(V);
+  if (!Inst)
+    return false;
+  Type *Ty = Inst->getType();
+  switch (Inst->getOpcode()) {
+  case Instruction::Select:
+    // A select yields either its true value (operand 1) or its false value
+    // (operand 2); operand 0 is the i1 condition and is never the result.
+    // Forward MaxCount and the incremented Depth so both the count bound and
+    // the recursion limit apply to nested selects.
+    return collectPossibleValues(Inst->getOperand(1), Constants, MaxCount,
+                                 Depth) &&
+           collectPossibleValues(Inst->getOperand(2), Constants, MaxCount,
+                                 Depth);
+  case Instruction::PHI:
+    // Only PHIs whose incoming values are all constants are handled; an
+    // incoming value that is the PHI itself adds no new value and is skipped.
+    for (Value *IncomingValue : cast<PHINode>(Inst)->incoming_values()) {
+      if (IncomingValue == Inst)
+        continue;
+      auto *C = dyn_cast<Constant>(IncomingValue);
+      if (!C || !Add(C))
+        return false;
+    }
+    return true;
+  case Instruction::And: {
+    // X & Pow2 is either zero or exactly the single set bit.
+    const APInt *Bit;
+    if (match(Inst->getOperand(1), m_Power2(Bit))) {
+      Constant *Zero = ConstantInt::getNullValue(Ty);
+      Constant *BitVal = ConstantInt::get(Ty, *Bit);
+      return Add(Zero) && Add(BitVal);
+    }
+    break;
+  }
+  case Instruction::ZExt:
+  case Instruction::SExt:
+    // Extending an i1 yields 0 or 1 (zext) / 0 or -1 (sext).
+    if (Inst->getOperand(0)->getType()->isIntegerTy(1)) {
+      Constant *Zero = ConstantInt::getNullValue(Ty);
+      Constant *One = ConstantInt::get(Ty, APInt(Ty->getScalarSizeInBits(),
+                                                 isa<ZExtInst>(Inst) ? 1 : -1,
+                                                 /*IsSigned=*/true));
+      return Add(Zero) && Add(One);
+    }
+    break;
+  default:
+    break;
+  }
+  return false;
+}
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index c537be5cba37c..60c4db50ac9d4 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -5990,6 +5990,8 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, DomTreeUpdater *DTU,
const DataLayout &DL) {
Value *Cond = SI->getCondition();
KnownBits Known = computeKnownBits(Cond, DL, AC, SI);
+ SmallPtrSet<const Constant *, 4> KnownValues;
+ bool IsKnownValuesValid = collectPossibleValues(Cond, KnownValues, 4);
// We can also eliminate cases by determining that their values are outside of
// the limited range of the condition based on how many significant (non-sign)
@@ -6009,15 +6011,18 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, DomTreeUpdater *DTU,
UniqueSuccessors.push_back(Successor);
++It->second;
}
- const APInt &CaseVal = Case.getCaseValue()->getValue();
+ ConstantInt *CaseC = Case.getCaseValue();
+ const APInt &CaseVal = CaseC->getValue();
if (Known.Zero.intersects(CaseVal) || !Known.One.isSubsetOf(CaseVal) ||
- (CaseVal.getSignificantBits() > MaxSignificantBitsInCond)) {
- DeadCases.push_back(Case.getCaseValue());
+ (CaseVal.getSignificantBits() > MaxSignificantBitsInCond) ||
+ (IsKnownValuesValid && !KnownValues.contains(CaseC))) {
+ DeadCases.push_back(CaseC);
if (DTU)
--NumPerSuccessorCases[Successor];
LLVM_DEBUG(dbgs() << "SimplifyCFG: switch case " << CaseVal
<< " is dead.\n");
- }
+ } else if (IsKnownValuesValid)
+ KnownValues.erase(CaseC);
}
// If we can prove that the cases must cover all possible values, the
@@ -6028,33 +6033,41 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, DomTreeUpdater *DTU,
const unsigned NumUnknownBits =
Known.getBitWidth() - (Known.Zero | Known.One).popcount();
assert(NumUnknownBits <= Known.getBitWidth());
- if (HasDefault && DeadCases.empty() &&
- NumUnknownBits < 64 /* avoid overflow */) {
- uint64_t AllNumCases = 1ULL << NumUnknownBits;
- if (SI->getNumCases() == AllNumCases) {
+ if (HasDefault && DeadCases.empty()) {
+ if (IsKnownValuesValid && all_of(KnownValues, IsaPred<UndefValue>)) {
createUnreachableSwitchDefault(SI, DTU);
return true;
}
- // When only one case value is missing, replace default with that case.
- // Eliminating the default branch will provide more opportunities for
- // optimization, such as lookup tables.
- if (SI->getNumCases() == AllNumCases - 1) {
- assert(NumUnknownBits > 1 && "Should be canonicalized to a branch");
- IntegerType *CondTy = cast<IntegerType>(Cond->getType());
- if (CondTy->getIntegerBitWidth() > 64 ||
- !DL.fitsInLegalInteger(CondTy->getIntegerBitWidth()))
- return false;
- uint64_t MissingCaseVal = 0;
- for (const auto &Case : SI->cases())
- MissingCaseVal ^= Case.getCaseValue()->getValue().getLimitedValue();
- auto *MissingCase =
- cast<ConstantInt>(ConstantInt::get(Cond->getType(), MissingCaseVal));
- SwitchInstProfUpdateWrapper SIW(*SI);
- SIW.addCase(MissingCase, SI->getDefaultDest(), SIW.getSuccessorWeight(0));
- createUnreachableSwitchDefault(SI, DTU, /*RemoveOrigDefaultBlock*/ false);
- SIW.setSuccessorWeight(0, 0);
- return true;
+ if (NumUnknownBits < 64 /* avoid overflow */) {
+ uint64_t AllNumCases = 1ULL << NumUnknownBits;
+ if (SI->getNumCases() == AllNumCases) {
+ createUnreachableSwitchDefault(SI, DTU);
+ return true;
+ }
+ // When only one case value is missing, replace default with that case.
+ // Eliminating the default branch will provide more opportunities for
+ // optimization, such as lookup tables.
+ if (SI->getNumCases() == AllNumCases - 1) {
+ assert(NumUnknownBits > 1 && "Should be canonicalized to a branch");
+ IntegerType *CondTy = cast<IntegerType>(Cond->getType());
+ if (CondTy->getIntegerBitWidth() > 64 ||
+ !DL.fitsInLegalInteger(CondTy->getIntegerBitWidth()))
+ return false;
+
+ uint64_t MissingCaseVal = 0;
+ for (const auto &Case : SI->cases())
+ MissingCaseVal ^= Case.getCaseValue()->getValue().getLimitedValue();
+ auto *MissingCase = cast<ConstantInt>(
+ ConstantInt::get(Cond->getType(), MissingCaseVal));
+ SwitchInstProfUpdateWrapper SIW(*SI);
+ SIW.addCase(MissingCase, SI->getDefaultDest(),
+ SIW.getSuccessorWeight(0));
+ createUnreachableSwitchDefault(SI, DTU,
+ /*RemoveOrigDefaultBlock*/ false);
+ SIW.setSuccessorWeight(0, 0);
+ return true;
+ }
}
}
More information about the llvm-commits
mailing list