[llvm] [SimplifyCFG] Set `MD_prof` for `select` used for certain conditional simplifications (PR #154426)

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 27 10:13:13 PDT 2025


https://github.com/mtrofin updated https://github.com/llvm/llvm-project/pull/154426

>From 4fea73e3f0a3be82ee922952447309edb6716c92 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Tue, 19 Aug 2025 14:35:01 -0700
Subject: [PATCH] [SimplifyCFG] Set `MD_prof` for `select` used for certain
 conditional simplifications

---
 llvm/lib/Transforms/Utils/SimplifyCFG.cpp     | 27 ++++++++++++++-
 .../SimplifyCFG/branch-fold-threshold.ll      | 29 +++++++++++-----
 .../Transforms/SimplifyCFG/branch-fold.ll     | 18 +++++++---
 .../SimplifyCFG/preserve-branchweights.ll     | 34 +++++++++----------
 4 files changed, 76 insertions(+), 32 deletions(-)

diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index ef110a6922f05..e3ccafc56c6ed 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -203,6 +203,8 @@ static cl::opt<unsigned> MaxJumpThreadingLiveBlocks(
     cl::desc("Limit number of blocks a define in a threaded block is allowed "
              "to be live in"));
 
+extern cl::opt<bool> ProfcheckDisableMetadataFixes;
+
 STATISTIC(NumBitMaps, "Number of switch instructions turned into bitmaps");
 STATISTIC(NumLinearMaps,
           "Number of switch instructions turned into linear mapping");
@@ -330,6 +332,13 @@ class SimplifyCFGOpt {
   }
 };
 
+bool isSelectInRoleOfConjunctionOrDisjunction(const SelectInst *SI) {
+  // Logical or: `select c, true, x`; logical and: `select c, x, false`.
+  const auto *TrueC = dyn_cast<ConstantInt>(SI->getTrueValue());
+  const auto *FalseC = dyn_cast<ConstantInt>(SI->getFalseValue());
+  return (TrueC && TrueC->isOne()) || (FalseC && FalseC->isNullValue());
+}
+
 } // end anonymous namespace
 
 /// Return true if all the PHI nodes in the basic block \p BB
@@ -4024,6 +4033,7 @@ static bool performBranchToCommonDestFolding(BranchInst *BI, BranchInst *PBI,
 
   // Try to update branch weights.
   uint64_t PredTrueWeight, PredFalseWeight, SuccTrueWeight, SuccFalseWeight;
+  SmallVector<uint32_t, 2> MDWeights;
   if (extractPredSuccWeights(PBI, BI, PredTrueWeight, PredFalseWeight,
                              SuccTrueWeight, SuccFalseWeight)) {
     SmallVector<uint64_t, 8> NewWeights;
@@ -4054,7 +4064,7 @@ static bool performBranchToCommonDestFolding(BranchInst *BI, BranchInst *PBI,
     // Halve the weights if any of them cannot fit in an uint32_t
     fitWeights(NewWeights);
 
-    SmallVector<uint32_t, 8> MDWeights(NewWeights.begin(), NewWeights.end());
+    append_range(MDWeights, NewWeights);
     setBranchWeights(PBI, MDWeights[0], MDWeights[1], /*IsExpected=*/false);
 
     // TODO: If BB is reachable from all paths through PredBlock, then we
@@ -4091,6 +4101,13 @@ static bool performBranchToCommonDestFolding(BranchInst *BI, BranchInst *PBI,
   Value *BICond = VMap[BI->getCondition()];
   PBI->setCondition(
       createLogicalOp(Builder, Opc, PBI->getCondition(), BICond, "or.cond"));
+  if (!ProfcheckDisableMetadataFixes)
+    if (auto *SI = dyn_cast<SelectInst>(PBI->getCondition()))
+      if (!MDWeights.empty()) {
+        assert(isSelectInRoleOfConjunctionOrDisjunction(SI));
+        setBranchWeights(SI, MDWeights[0], MDWeights[1],
+                         /*IsExpected=*/false);
+      }
 
   ++NumFoldBranchToCommonDest;
   return true;
@@ -4789,6 +4806,14 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI,
     fitWeights(NewWeights);
 
     setBranchWeights(PBI, NewWeights[0], NewWeights[1], /*IsExpected=*/false);
+    // Cond may be a select whose true value is the constant "true" (logical or)
+    // or whose false value is "false" (logical and); see createLogicalOp.
+    if (!ProfcheckDisableMetadataFixes)
+      if (auto *SI = dyn_cast<SelectInst>(Cond)) {
+        assert(isSelectInRoleOfConjunctionOrDisjunction(SI));
+        setBranchWeights(SI, NewWeights[0], NewWeights[1],
+                         /*IsExpected=*/false);
+      }
   }
 
   // OtherDest may have phi nodes.  If so, add an entry from PBI's
diff --git a/llvm/test/Transforms/SimplifyCFG/branch-fold-threshold.ll b/llvm/test/Transforms/SimplifyCFG/branch-fold-threshold.ll
index 4384847ce156b..71ad069fb8d06 100644
--- a/llvm/test/Transforms/SimplifyCFG/branch-fold-threshold.ll
+++ b/llvm/test/Transforms/SimplifyCFG/branch-fold-threshold.ll
@@ -1,4 +1,4 @@
-; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals all --version 5
 ; RUN: opt %s -passes=simplifycfg -simplifycfg-require-and-preserve-domtree=1 -S | FileCheck %s --check-prefixes=NORMAL,BASELINE
 ; RUN: opt %s -passes=simplifycfg -simplifycfg-require-and-preserve-domtree=1 -S -bonus-inst-threshold=2 | FileCheck %s --check-prefixes=NORMAL,AGGRESSIVE
 ; RUN: opt %s -passes=simplifycfg -simplifycfg-require-and-preserve-domtree=1 -S -bonus-inst-threshold=4 | FileCheck %s --check-prefixes=WAYAGGRESSIVE
@@ -11,12 +11,12 @@ define i32 @foo(i32 %a, i32 %b, i32 %c, i32 %d, ptr %input) {
 ; BASELINE-SAME: i32 [[A:%.*]], i32 [[B:%.*]], i32 [[C:%.*]], i32 [[D:%.*]], ptr [[INPUT:%.*]]) {
 ; BASELINE-NEXT:  [[ENTRY:.*]]:
 ; BASELINE-NEXT:    [[CMP:%.*]] = icmp sgt i32 [[D]], 3
-; BASELINE-NEXT:    br i1 [[CMP]], label %[[COND_END:.*]], label %[[LOR_LHS_FALSE:.*]]
+; BASELINE-NEXT:    br i1 [[CMP]], label %[[COND_END:.*]], label %[[LOR_LHS_FALSE:.*]], !prof [[PROF0:![0-9]+]]
 ; BASELINE:       [[LOR_LHS_FALSE]]:
 ; BASELINE-NEXT:    [[MUL:%.*]] = shl i32 [[C]], 1
 ; BASELINE-NEXT:    [[ADD:%.*]] = add nsw i32 [[MUL]], [[A]]
 ; BASELINE-NEXT:    [[CMP1:%.*]] = icmp slt i32 [[ADD]], [[B]]
-; BASELINE-NEXT:    br i1 [[CMP1]], label %[[COND_FALSE:.*]], label %[[COND_END]]
+; BASELINE-NEXT:    br i1 [[CMP1]], label %[[COND_FALSE:.*]], label %[[COND_END]], !prof [[PROF1:![0-9]+]]
 ; BASELINE:       [[COND_FALSE]]:
 ; BASELINE-NEXT:    [[TMP0:%.*]] = load i32, ptr [[INPUT]], align 4
 ; BASELINE-NEXT:    br label %[[COND_END]]
@@ -31,8 +31,8 @@ define i32 @foo(i32 %a, i32 %b, i32 %c, i32 %d, ptr %input) {
 ; AGGRESSIVE-NEXT:    [[MUL:%.*]] = shl i32 [[C]], 1
 ; AGGRESSIVE-NEXT:    [[ADD:%.*]] = add nsw i32 [[MUL]], [[A]]
 ; AGGRESSIVE-NEXT:    [[CMP1:%.*]] = icmp slt i32 [[ADD]], [[B]]
-; AGGRESSIVE-NEXT:    [[OR_COND:%.*]] = select i1 [[CMP]], i1 [[CMP1]], i1 false
-; AGGRESSIVE-NEXT:    br i1 [[OR_COND]], label %[[COND_FALSE:.*]], label %[[COND_END:.*]]
+; AGGRESSIVE-NEXT:    [[OR_COND:%.*]] = select i1 [[CMP]], i1 [[CMP1]], i1 false, !prof [[PROF0:![0-9]+]]
+; AGGRESSIVE-NEXT:    br i1 [[OR_COND]], label %[[COND_FALSE:.*]], label %[[COND_END:.*]], !prof [[PROF0]]
 ; AGGRESSIVE:       [[COND_FALSE]]:
 ; AGGRESSIVE-NEXT:    [[TMP0:%.*]] = load i32, ptr [[INPUT]], align 4
 ; AGGRESSIVE-NEXT:    br label %[[COND_END]]
@@ -47,8 +47,8 @@ define i32 @foo(i32 %a, i32 %b, i32 %c, i32 %d, ptr %input) {
 ; WAYAGGRESSIVE-NEXT:    [[MUL:%.*]] = shl i32 [[C]], 1
 ; WAYAGGRESSIVE-NEXT:    [[ADD:%.*]] = add nsw i32 [[MUL]], [[A]]
 ; WAYAGGRESSIVE-NEXT:    [[CMP1:%.*]] = icmp slt i32 [[ADD]], [[B]]
-; WAYAGGRESSIVE-NEXT:    [[OR_COND:%.*]] = select i1 [[CMP]], i1 [[CMP1]], i1 false
-; WAYAGGRESSIVE-NEXT:    br i1 [[OR_COND]], label %[[COND_FALSE:.*]], label %[[COND_END:.*]]
+; WAYAGGRESSIVE-NEXT:    [[OR_COND:%.*]] = select i1 [[CMP]], i1 [[CMP1]], i1 false, !prof [[PROF0:![0-9]+]]
+; WAYAGGRESSIVE-NEXT:    br i1 [[OR_COND]], label %[[COND_FALSE:.*]], label %[[COND_END:.*]], !prof [[PROF0]]
 ; WAYAGGRESSIVE:       [[COND_FALSE]]:
 ; WAYAGGRESSIVE-NEXT:    [[TMP0:%.*]] = load i32, ptr [[INPUT]], align 4
 ; WAYAGGRESSIVE-NEXT:    br label %[[COND_END]]
@@ -58,13 +58,13 @@ define i32 @foo(i32 %a, i32 %b, i32 %c, i32 %d, ptr %input) {
 ;
 entry:
   %cmp = icmp sgt i32 %d, 3
-  br i1 %cmp, label %cond.end, label %lor.lhs.false
+  br i1 %cmp, label %cond.end, label %lor.lhs.false, !prof !0
 
 lor.lhs.false:
   %mul = shl i32 %c, 1
   %add = add nsw i32 %mul, %a
   %cmp1 = icmp slt i32 %add, %b
-  br i1 %cmp1, label %cond.false, label %cond.end
+  br i1 %cmp1, label %cond.false, label %cond.end, !prof !1
 
 cond.false:
   %0 = load i32, ptr %input, align 4
@@ -160,3 +160,14 @@ cond.end:
   %cond = phi i32 [ %0, %cond.false ], [ 0, %lor.lhs.false ],[ 0, %pred_a ],[ 0, %pred_b ]
   ret i32 %cond
 }
+
+!0 = !{!"branch_weights", i32 7, i32 11}
+!1 = !{!"branch_weights", i32 13, i32 5}
+;.
+; BASELINE: [[PROF0]] = !{!"branch_weights", i32 7, i32 11}
+; BASELINE: [[PROF1]] = !{!"branch_weights", i32 13, i32 5}
+;.
+; AGGRESSIVE: [[PROF0]] = !{!"branch_weights", i32 143, i32 181}
+;.
+; WAYAGGRESSIVE: [[PROF0]] = !{!"branch_weights", i32 143, i32 181}
+;.
diff --git a/llvm/test/Transforms/SimplifyCFG/branch-fold.ll b/llvm/test/Transforms/SimplifyCFG/branch-fold.ll
index 2f5fb4f33013d..b51398dd6222d 100644
--- a/llvm/test/Transforms/SimplifyCFG/branch-fold.ll
+++ b/llvm/test/Transforms/SimplifyCFG/branch-fold.ll
@@ -1,12 +1,12 @@
-; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals
 ; RUN: opt < %s -passes=simplifycfg -simplifycfg-require-and-preserve-domtree=1 -S | FileCheck %s
 
 define void @test(ptr %P, ptr %Q, i1 %A, i1 %B) {
 ; CHECK-LABEL: @test(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[A_NOT:%.*]] = xor i1 [[A:%.*]], true
-; CHECK-NEXT:    [[BRMERGE:%.*]] = select i1 [[A_NOT]], i1 true, i1 [[B:%.*]]
-; CHECK-NEXT:    br i1 [[BRMERGE]], label [[B:%.*]], label [[COMMON_RET:%.*]]
+; CHECK-NEXT:    [[BRMERGE:%.*]] = select i1 [[A_NOT]], i1 true, i1 [[B:%.*]], !prof [[PROF0:![0-9]+]]
+; CHECK-NEXT:    br i1 [[BRMERGE]], label [[B:%.*]], label [[COMMON_RET:%.*]], !prof [[PROF0]]
 ; CHECK:       common.ret:
 ; CHECK-NEXT:    ret void
 ; CHECK:       b:
@@ -15,9 +15,9 @@ define void @test(ptr %P, ptr %Q, i1 %A, i1 %B) {
 ;
 
 entry:
-  br i1 %A, label %a, label %b
+  br i1 %A, label %a, label %b, !prof !0
 a:
-  br i1 %B, label %b, label %c
+  br i1 %B, label %b, label %c, !prof !1
 b:
   store i32 123, ptr %P
   ret void
@@ -146,3 +146,11 @@ Succ:
 }
 
 declare void @dummy()
+
+!0 = !{!"branch_weights", i32 3, i32 7}
+!1 = !{!"branch_weights", i32 11, i32 4}
+;.
+; CHECK: attributes #[[ATTR0:[0-9]+]] = { nounwind ssp memory(read) uwtable }
+;.
+; CHECK: [[PROF0]] = !{!"branch_weights", i32 138, i32 12}
+;.
diff --git a/llvm/test/Transforms/SimplifyCFG/preserve-branchweights.ll b/llvm/test/Transforms/SimplifyCFG/preserve-branchweights.ll
index 0f78e236b4248..7d1153d5c0a0e 100644
--- a/llvm/test/Transforms/SimplifyCFG/preserve-branchweights.ll
+++ b/llvm/test/Transforms/SimplifyCFG/preserve-branchweights.ll
@@ -268,8 +268,8 @@ define void @test7(i1 %a, i1 %b) {
 ; CHECK-LABEL: @test7(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[C:%.*]] = or i1 [[B:%.*]], false
-; CHECK-NEXT:    [[BRMERGE:%.*]] = select i1 [[A:%.*]], i1 true, i1 [[C]]
-; CHECK-NEXT:    br i1 [[BRMERGE]], label [[Y:%.*]], label [[Z:%.*]], !prof [[PROF6:![0-9]+]]
+; CHECK-NEXT:    [[BRMERGE:%.*]] = select i1 [[A:%.*]], i1 true, i1 [[C]], !prof [[PROF6:![0-9]+]]
+; CHECK-NEXT:    br i1 [[BRMERGE]], label [[Y:%.*]], label [[Z:%.*]], !prof [[PROF6]]
 ; CHECK:       common.ret:
 ; CHECK-NEXT:    ret void
 ; CHECK:       Y:
@@ -557,9 +557,9 @@ return:
 define i32 @SimplifyCondBranchToCondBranch(i1 %cmpa, i1 %cmpb) {
 ; CHECK-LABEL: @SimplifyCondBranchToCondBranch(
 ; CHECK-NEXT:  block1:
-; CHECK-NEXT:    [[BRMERGE:%.*]] = select i1 [[CMPA:%.*]], i1 true, i1 [[CMPB:%.*]]
-; CHECK-NEXT:    [[DOTMUX:%.*]] = select i1 [[CMPA]], i32 0, i32 2, !prof [[PROF13:![0-9]+]]
-; CHECK-NEXT:    [[OUTVAL:%.*]] = select i1 [[BRMERGE]], i32 [[DOTMUX]], i32 1, !prof [[PROF14:![0-9]+]]
+; CHECK-NEXT:    [[BRMERGE:%.*]] = select i1 [[CMPA:%.*]], i1 true, i1 [[CMPB:%.*]], !prof [[PROF13:![0-9]+]]
+; CHECK-NEXT:    [[DOTMUX:%.*]] = select i1 [[CMPA]], i32 0, i32 2, !prof [[PROF14:![0-9]+]]
+; CHECK-NEXT:    [[OUTVAL:%.*]] = select i1 [[BRMERGE]], i32 [[DOTMUX]], i32 1, !prof [[PROF13]]
 ; CHECK-NEXT:    ret i32 [[OUTVAL]]
 ;
 block1:
@@ -584,9 +584,9 @@ define i32 @SimplifyCondBranchToCondBranchSwap(i1 %cmpa, i1 %cmpb) {
 ; CHECK-NEXT:  block1:
 ; CHECK-NEXT:    [[CMPA_NOT:%.*]] = xor i1 [[CMPA:%.*]], true
 ; CHECK-NEXT:    [[CMPB_NOT:%.*]] = xor i1 [[CMPB:%.*]], true
-; CHECK-NEXT:    [[BRMERGE:%.*]] = select i1 [[CMPA_NOT]], i1 true, i1 [[CMPB_NOT]]
-; CHECK-NEXT:    [[DOTMUX:%.*]] = select i1 [[CMPA_NOT]], i32 0, i32 2, !prof [[PROF15:![0-9]+]]
-; CHECK-NEXT:    [[OUTVAL:%.*]] = select i1 [[BRMERGE]], i32 [[DOTMUX]], i32 1, !prof [[PROF16:![0-9]+]]
+; CHECK-NEXT:    [[BRMERGE:%.*]] = select i1 [[CMPA_NOT]], i1 true, i1 [[CMPB_NOT]], !prof [[PROF15:![0-9]+]]
+; CHECK-NEXT:    [[DOTMUX:%.*]] = select i1 [[CMPA_NOT]], i32 0, i32 2, !prof [[PROF16:![0-9]+]]
+; CHECK-NEXT:    [[OUTVAL:%.*]] = select i1 [[BRMERGE]], i32 [[DOTMUX]], i32 1, !prof [[PROF15]]
 ; CHECK-NEXT:    ret i32 [[OUTVAL]]
 ;
 block1:
@@ -609,9 +609,9 @@ define i32 @SimplifyCondBranchToCondBranchSwapMissingWeight(i1 %cmpa, i1 %cmpb)
 ; CHECK-NEXT:  block1:
 ; CHECK-NEXT:    [[CMPA_NOT:%.*]] = xor i1 [[CMPA:%.*]], true
 ; CHECK-NEXT:    [[CMPB_NOT:%.*]] = xor i1 [[CMPB:%.*]], true
-; CHECK-NEXT:    [[BRMERGE:%.*]] = select i1 [[CMPA_NOT]], i1 true, i1 [[CMPB_NOT]]
-; CHECK-NEXT:    [[DOTMUX:%.*]] = select i1 [[CMPA_NOT]], i32 0, i32 2, !prof [[PROF17:![0-9]+]]
-; CHECK-NEXT:    [[OUTVAL:%.*]] = select i1 [[BRMERGE]], i32 [[DOTMUX]], i32 1, !prof [[PROF18:![0-9]+]]
+; CHECK-NEXT:    [[BRMERGE:%.*]] = select i1 [[CMPA_NOT]], i1 true, i1 [[CMPB_NOT]], !prof [[PROF17:![0-9]+]]
+; CHECK-NEXT:    [[DOTMUX:%.*]] = select i1 [[CMPA_NOT]], i32 0, i32 2, !prof [[PROF18:![0-9]+]]
+; CHECK-NEXT:    [[OUTVAL:%.*]] = select i1 [[BRMERGE]], i32 [[DOTMUX]], i32 1, !prof [[PROF17]]
 ; CHECK-NEXT:    ret i32 [[OUTVAL]]
 ;
 block1:
@@ -1114,12 +1114,12 @@ exit:
 ; CHECK: [[PROF10]] = !{!"branch_weights", i32 8, i32 33}
 ; CHECK: [[PROF11]] = !{!"branch_weights", i32 112017436, i32 -735157296}
 ; CHECK: [[PROF12]] = !{!"branch_weights", i32 3, i32 5}
-; CHECK: [[PROF13]] = !{!"branch_weights", i32 22, i32 12}
-; CHECK: [[PROF14]] = !{!"branch_weights", i32 34, i32 21}
-; CHECK: [[PROF15]] = !{!"branch_weights", i32 33, i32 14}
-; CHECK: [[PROF16]] = !{!"branch_weights", i32 47, i32 8}
-; CHECK: [[PROF17]] = !{!"branch_weights", i32 6, i32 2}
-; CHECK: [[PROF18]] = !{!"branch_weights", i32 8, i32 2}
+; CHECK: [[PROF13]] = !{!"branch_weights", i32 34, i32 21}
+; CHECK: [[PROF14]] = !{!"branch_weights", i32 22, i32 12}
+; CHECK: [[PROF15]] = !{!"branch_weights", i32 47, i32 8}
+; CHECK: [[PROF16]] = !{!"branch_weights", i32 33, i32 14}
+; CHECK: [[PROF17]] = !{!"branch_weights", i32 8, i32 2}
+; CHECK: [[PROF18]] = !{!"branch_weights", i32 6, i32 2}
 ; CHECK: [[PROF19]] = !{!"branch_weights", i32 99, i32 1}
 ; CHECK: [[PROF20]] = !{!"branch_weights", i32 1, i32 99}
 ; CHECK: [[PROF21]] = !{!"branch_weights", i32 199, i32 1}



More information about the llvm-commits mailing list