[llvm] [SimplifyCFG] Ensure selects have not been constant folded in `foldSwitchToSelect` (PR #161153)

via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 29 02:04:27 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Antonio Frighetto (antoniofrighetto)

<details>
<summary>Changes</summary>

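`IRBuilder::CreateSelect` may constant-fold its operands (for example, when everything is `poison`) and return a value that is not a `SelectInst`. In that case the previous `cast<SelectInst>` would fail. The fix uses `dyn_cast<SelectInst>` and attaches branch weights only when a select instruction was actually created.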

Fixes: https://github.com/llvm/llvm-project/issues/161137.

---
Full diff: https://github.com/llvm/llvm-project/pull/161153.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/Utils/SimplifyCFG.cpp (+21-22) 
- (modified) llvm/test/Transforms/SimplifyCFG/switch-to-select-two-case.ll (+19) 


``````````diff
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 2d84b4ae1ba5c..216bdf4eb9efb 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -84,7 +84,6 @@
 #include <cstdint>
 #include <iterator>
 #include <map>
-#include <numeric>
 #include <optional>
 #include <set>
 #include <tuple>
@@ -6356,25 +6355,25 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
     if (DefaultResult) {
       Value *ValueCompare =
           Builder.CreateICmpEQ(Condition, SecondCase, "switch.selectcmp");
-      SelectInst *SelectValueInst = cast<SelectInst>(Builder.CreateSelect(
-          ValueCompare, ResultVector[1].first, DefaultResult, "switch.select"));
-      SelectValue = SelectValueInst;
-      if (HasBranchWeights) {
+      SelectValue = Builder.CreateSelect(ValueCompare, ResultVector[1].first,
+                                         DefaultResult, "switch.select");
+      if (auto *SI = dyn_cast<SelectInst>(SelectValue);
+          SI && HasBranchWeights) {
         // We start with 3 probabilities, where the numerator is the
         // corresponding BranchWeights[i], and the denominator is the sum over
         // BranchWeights. We want the probability and negative probability of
         // Condition == SecondCase.
         assert(BranchWeights.size() == 3);
-        setBranchWeights(SelectValueInst, BranchWeights[2],
+        setBranchWeights(SI, BranchWeights[2],
                          BranchWeights[0] + BranchWeights[1],
                          /*IsExpected=*/false);
       }
     }
     Value *ValueCompare =
         Builder.CreateICmpEQ(Condition, FirstCase, "switch.selectcmp");
-    SelectInst *Ret = cast<SelectInst>(Builder.CreateSelect(
-        ValueCompare, ResultVector[0].first, SelectValue, "switch.select"));
-    if (HasBranchWeights) {
+    Value *Ret = Builder.CreateSelect(ValueCompare, ResultVector[0].first,
+                                      SelectValue, "switch.select");
+    if (auto *SI = dyn_cast<SelectInst>(Ret); SI && HasBranchWeights) {
       // We may have had a DefaultResult. Base the position of the first and
       // second's branch weights accordingly. Also the proability that Condition
       // != FirstCase needs to take that into account.
@@ -6382,7 +6381,7 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
       size_t FirstCasePos = (Condition != nullptr);
       size_t SecondCasePos = FirstCasePos + 1;
       uint32_t DefaultCase = (Condition != nullptr) ? BranchWeights[0] : 0;
-      setBranchWeights(Ret, BranchWeights[FirstCasePos],
+      setBranchWeights(SI, BranchWeights[FirstCasePos],
                        DefaultCase + BranchWeights[SecondCasePos],
                        /*IsExpected=*/false);
     }
@@ -6422,13 +6421,13 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
           Value *And = Builder.CreateAnd(Condition, AndMask);
           Value *Cmp = Builder.CreateICmpEQ(
               And, Constant::getIntegerValue(And->getType(), AndMask));
-          SelectInst *Ret = cast<SelectInst>(
-              Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult));
-          if (HasBranchWeights) {
+          Value *Ret =
+              Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult);
+          if (auto *SI = dyn_cast<SelectInst>(Ret); SI && HasBranchWeights) {
             // We know there's a Default case. We base the resulting branch
             // weights off its probability.
             assert(BranchWeights.size() >= 2);
-            setBranchWeights(Ret, accumulate(drop_begin(BranchWeights), 0),
+            setBranchWeights(SI, accumulate(drop_begin(BranchWeights), 0),
                              BranchWeights[0], /*IsExpected=*/false);
           }
           return Ret;
@@ -6448,11 +6447,11 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
         Value *And = Builder.CreateAnd(Condition, ~BitMask, "switch.and");
         Value *Cmp = Builder.CreateICmpEQ(
             And, Constant::getNullValue(And->getType()), "switch.selectcmp");
-        SelectInst *Ret = cast<SelectInst>(
-            Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult));
-        if (HasBranchWeights) {
+        Value *Ret =
+            Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult);
+        if (auto *SI = dyn_cast<SelectInst>(Ret); SI && HasBranchWeights) {
           assert(BranchWeights.size() >= 2);
-          setBranchWeights(Ret, accumulate(drop_begin(BranchWeights), 0),
+          setBranchWeights(SI, accumulate(drop_begin(BranchWeights), 0),
                            BranchWeights[0], /*IsExpected=*/false);
         }
         return Ret;
@@ -6466,11 +6465,11 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
       Value *Cmp2 = Builder.CreateICmpEQ(Condition, CaseValues[1],
                                          "switch.selectcmp.case2");
       Value *Cmp = Builder.CreateOr(Cmp1, Cmp2, "switch.selectcmp");
-      SelectInst *Ret = cast<SelectInst>(
-          Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult));
-      if (HasBranchWeights) {
+      Value *Ret =
+          Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult);
+      if (auto *SI = dyn_cast<SelectInst>(Ret); SI && HasBranchWeights) {
         assert(BranchWeights.size() >= 2);
-        setBranchWeights(Ret, accumulate(drop_begin(BranchWeights), 0),
+        setBranchWeights(SI, accumulate(drop_begin(BranchWeights), 0),
                          BranchWeights[0], /*IsExpected=*/false);
       }
       return Ret;
diff --git a/llvm/test/Transforms/SimplifyCFG/switch-to-select-two-case.ll b/llvm/test/Transforms/SimplifyCFG/switch-to-select-two-case.ll
index 39703e9b53b6b..9d78b97c204a8 100644
--- a/llvm/test/Transforms/SimplifyCFG/switch-to-select-two-case.ll
+++ b/llvm/test/Transforms/SimplifyCFG/switch-to-select-two-case.ll
@@ -755,6 +755,25 @@ bb3:
   ret i1 %phi
 }
 
+define i32 @negative_constfold_select() {
+; CHECK-LABEL: @negative_constfold_select(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    ret i32 poison
+;
+entry:
+  switch i32 poison, label %default [
+  i32 0, label %bb
+  i32 2, label %bb
+  ]
+
+bb:
+  br label %default
+
+default:
+  %ret = phi i32 [ poison, %entry ], [ poison, %bb ]
+  ret i32 %ret
+}
+
 !0 = !{!"function_entry_count", i64 1000}
 !1 = !{!"branch_weights", i32 3, i32 5, i32 7}
 !2 = !{!"branch_weights", i32 3, i32 5, i32 7, i32 11, i32 13}

``````````

</details>


https://github.com/llvm/llvm-project/pull/161153


More information about the llvm-commits mailing list