[llvm] [SimplifyCFG] Ensure selects have not been constant folded in `foldSwitchToSelect` (PR #161153)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Sep 29 02:04:27 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Antonio Frighetto (antoniofrighetto)
<details>
<summary>Changes</summary>
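Make sure the selects actually exist — i.e. that `IRBuilder::CreateSelect` did not constant-fold them into a value that is not a `SelectInst` — before attaching branch weights to them.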
Fixes: https://github.com/llvm/llvm-project/issues/161137.
---
Full diff: https://github.com/llvm/llvm-project/pull/161153.diff
2 Files Affected:
- (modified) llvm/lib/Transforms/Utils/SimplifyCFG.cpp (+21-22)
- (modified) llvm/test/Transforms/SimplifyCFG/switch-to-select-two-case.ll (+19)
``````````diff
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 2d84b4ae1ba5c..216bdf4eb9efb 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -84,7 +84,6 @@
#include <cstdint>
#include <iterator>
#include <map>
-#include <numeric>
#include <optional>
#include <set>
#include <tuple>
@@ -6356,25 +6355,25 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
if (DefaultResult) {
Value *ValueCompare =
Builder.CreateICmpEQ(Condition, SecondCase, "switch.selectcmp");
- SelectInst *SelectValueInst = cast<SelectInst>(Builder.CreateSelect(
- ValueCompare, ResultVector[1].first, DefaultResult, "switch.select"));
- SelectValue = SelectValueInst;
- if (HasBranchWeights) {
+ SelectValue = Builder.CreateSelect(ValueCompare, ResultVector[1].first,
+ DefaultResult, "switch.select");
+ if (auto *SI = dyn_cast<SelectInst>(SelectValue);
+ SI && HasBranchWeights) {
// We start with 3 probabilities, where the numerator is the
// corresponding BranchWeights[i], and the denominator is the sum over
// BranchWeights. We want the probability and negative probability of
// Condition == SecondCase.
assert(BranchWeights.size() == 3);
- setBranchWeights(SelectValueInst, BranchWeights[2],
+ setBranchWeights(SI, BranchWeights[2],
BranchWeights[0] + BranchWeights[1],
/*IsExpected=*/false);
}
}
Value *ValueCompare =
Builder.CreateICmpEQ(Condition, FirstCase, "switch.selectcmp");
- SelectInst *Ret = cast<SelectInst>(Builder.CreateSelect(
- ValueCompare, ResultVector[0].first, SelectValue, "switch.select"));
- if (HasBranchWeights) {
+ Value *Ret = Builder.CreateSelect(ValueCompare, ResultVector[0].first,
+ SelectValue, "switch.select");
+ if (auto *SI = dyn_cast<SelectInst>(Ret); SI && HasBranchWeights) {
// We may have had a DefaultResult. Base the position of the first and
// second's branch weights accordingly. Also the proability that Condition
// != FirstCase needs to take that into account.
@@ -6382,7 +6381,7 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
size_t FirstCasePos = (Condition != nullptr);
size_t SecondCasePos = FirstCasePos + 1;
uint32_t DefaultCase = (Condition != nullptr) ? BranchWeights[0] : 0;
- setBranchWeights(Ret, BranchWeights[FirstCasePos],
+ setBranchWeights(SI, BranchWeights[FirstCasePos],
DefaultCase + BranchWeights[SecondCasePos],
/*IsExpected=*/false);
}
@@ -6422,13 +6421,13 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
Value *And = Builder.CreateAnd(Condition, AndMask);
Value *Cmp = Builder.CreateICmpEQ(
And, Constant::getIntegerValue(And->getType(), AndMask));
- SelectInst *Ret = cast<SelectInst>(
- Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult));
- if (HasBranchWeights) {
+ Value *Ret =
+ Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult);
+ if (auto *SI = dyn_cast<SelectInst>(Ret); SI && HasBranchWeights) {
// We know there's a Default case. We base the resulting branch
// weights off its probability.
assert(BranchWeights.size() >= 2);
- setBranchWeights(Ret, accumulate(drop_begin(BranchWeights), 0),
+ setBranchWeights(SI, accumulate(drop_begin(BranchWeights), 0),
BranchWeights[0], /*IsExpected=*/false);
}
return Ret;
@@ -6448,11 +6447,11 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
Value *And = Builder.CreateAnd(Condition, ~BitMask, "switch.and");
Value *Cmp = Builder.CreateICmpEQ(
And, Constant::getNullValue(And->getType()), "switch.selectcmp");
- SelectInst *Ret = cast<SelectInst>(
- Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult));
- if (HasBranchWeights) {
+ Value *Ret =
+ Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult);
+ if (auto *SI = dyn_cast<SelectInst>(Ret); SI && HasBranchWeights) {
assert(BranchWeights.size() >= 2);
- setBranchWeights(Ret, accumulate(drop_begin(BranchWeights), 0),
+ setBranchWeights(SI, accumulate(drop_begin(BranchWeights), 0),
BranchWeights[0], /*IsExpected=*/false);
}
return Ret;
@@ -6466,11 +6465,11 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
Value *Cmp2 = Builder.CreateICmpEQ(Condition, CaseValues[1],
"switch.selectcmp.case2");
Value *Cmp = Builder.CreateOr(Cmp1, Cmp2, "switch.selectcmp");
- SelectInst *Ret = cast<SelectInst>(
- Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult));
- if (HasBranchWeights) {
+ Value *Ret =
+ Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult);
+ if (auto *SI = dyn_cast<SelectInst>(Ret); SI && HasBranchWeights) {
assert(BranchWeights.size() >= 2);
- setBranchWeights(Ret, accumulate(drop_begin(BranchWeights), 0),
+ setBranchWeights(SI, accumulate(drop_begin(BranchWeights), 0),
BranchWeights[0], /*IsExpected=*/false);
}
return Ret;
diff --git a/llvm/test/Transforms/SimplifyCFG/switch-to-select-two-case.ll b/llvm/test/Transforms/SimplifyCFG/switch-to-select-two-case.ll
index 39703e9b53b6b..9d78b97c204a8 100644
--- a/llvm/test/Transforms/SimplifyCFG/switch-to-select-two-case.ll
+++ b/llvm/test/Transforms/SimplifyCFG/switch-to-select-two-case.ll
@@ -755,6 +755,25 @@ bb3:
ret i1 %phi
}
+define i32 @negative_constfold_select() {
+; CHECK-LABEL: @negative_constfold_select(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: ret i32 poison
+;
+entry:
+ switch i32 poison, label %default [
+ i32 0, label %bb
+ i32 2, label %bb
+ ]
+
+bb:
+ br label %default
+
+default:
+ %ret = phi i32 [ poison, %entry ], [ poison, %bb ]
+ ret i32 %ret
+}
+
!0 = !{!"function_entry_count", i64 1000}
!1 = !{!"branch_weights", i32 3, i32 5, i32 7}
!2 = !{!"branch_weights", i32 3, i32 5, i32 7, i32 11, i32 13}
``````````
</details>
https://github.com/llvm/llvm-project/pull/161153
More information about the llvm-commits mailing list