[llvm] [SimplifyCFG] Prevent merging cbranch to cbranch if the branch probability from the first to second is too low. (PR #69375)

Valery Pykhtin via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 27 04:47:59 PDT 2023


https://github.com/vpykhtin updated https://github.com/llvm/llvm-project/pull/69375

>From fc1457d4856b3f4ae0a080a451669e344b0b1354 Mon Sep 17 00:00:00 2001
From: Valery Pykhtin <valery.pykhtin at gmail.com>
Date: Tue, 17 Oct 2023 20:50:48 +0200
Subject: [PATCH 1/4] [SimplifyCFG] Add test for preventing merging cbranch to
 cbranch if the branch probability from the first to second is too low.

---
 .../SimplifyCFG/branch-cond-dont-merge.ll     | 59 +++++++++++++++++++
 1 file changed, 59 insertions(+)
 create mode 100644 llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll

diff --git a/llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll b/llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll
new file mode 100644
index 000000000000000..b62bb825ab30b62
--- /dev/null
+++ b/llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll
@@ -0,0 +1,59 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=simplifycfg -simplifycfg-cbranch-to-cbranch-weight-ratio=100 -S | FileCheck %s
+
+declare void @bar()
+declare i1 @uniform_result(i1 %c)
+
+define void @dont_merge_cbranches1(i32 %V) {
+; CHECK-LABEL: @dont_merge_cbranches1(
+; CHECK-NEXT:    [[DIVERGENT_COND:%.*]] = icmp ne i32 [[V:%.*]], 0
+; CHECK-NEXT:    [[UNIFORM_COND:%.*]] = call i1 @uniform_result(i1 [[DIVERGENT_COND]])
+; CHECK-NEXT:    [[UNIFORM_COND_NOT:%.*]] = xor i1 [[UNIFORM_COND]], true
+; CHECK-NEXT:    [[DIVERGENT_COND_NOT:%.*]] = xor i1 [[DIVERGENT_COND]], true
+; CHECK-NEXT:    [[BRMERGE:%.*]] = select i1 [[UNIFORM_COND_NOT]], i1 true, i1 [[DIVERGENT_COND_NOT]]
+; CHECK-NEXT:    br i1 [[BRMERGE]], label [[EXIT:%.*]], label [[BB3:%.*]], !prof [[PROF0:![0-9]+]]
+; CHECK:       bb3:
+; CHECK-NEXT:    call void @bar()
+; CHECK-NEXT:    br label [[EXIT]]
+; CHECK:       exit:
+; CHECK-NEXT:    ret void
+;
+  %divergent_cond = icmp ne i32 %V, 0
+  %uniform_cond = call i1 @uniform_result(i1 %divergent_cond)
+  br i1 %uniform_cond, label %bb2, label %exit, !prof !0
+bb2:
+  br i1 %divergent_cond, label %bb3, label %exit
+bb3:
+  call void @bar( )
+  br label %exit
+exit:
+  ret void
+}
+
+define void @dont_merge_cbranches2(i32 %V) {
+; CHECK-LABEL: @dont_merge_cbranches2(
+; CHECK-NEXT:    [[DIVERGENT_COND:%.*]] = icmp ne i32 [[V:%.*]], 0
+; CHECK-NEXT:    [[UNIFORM_COND:%.*]] = call i1 @uniform_result(i1 [[DIVERGENT_COND]])
+; CHECK-NEXT:    [[DIVERGENT_COND_NOT:%.*]] = xor i1 [[DIVERGENT_COND]], true
+; CHECK-NEXT:    [[BRMERGE:%.*]] = select i1 [[UNIFORM_COND]], i1 true, i1 [[DIVERGENT_COND_NOT]]
+; CHECK-NEXT:    br i1 [[BRMERGE]], label [[EXIT:%.*]], label [[BB3:%.*]], !prof [[PROF0]]
+; CHECK:       bb3:
+; CHECK-NEXT:    call void @bar()
+; CHECK-NEXT:    br label [[EXIT]]
+; CHECK:       exit:
+; CHECK-NEXT:    ret void
+;
+  %divergent_cond = icmp ne i32 %V, 0
+  %uniform_cond = call i1 @uniform_result(i1 %divergent_cond)
+  br i1 %uniform_cond, label %exit, label %bb2, !prof !1
+bb2:
+  br i1 %divergent_cond, label %bb3, label %exit
+bb3:
+  call void @bar( )
+  br label %exit
+exit:
+  ret void
+}
+
+!0 = !{!"branch_weights", i32 1, i32 1000}
+!1 = !{!"branch_weights", i32 1000, i32 1}

>From 4becda1f7b123460164ee2e3ff666b6906cb0db9 Mon Sep 17 00:00:00 2001
From: Valery Pykhtin <valery.pykhtin at gmail.com>
Date: Tue, 17 Oct 2023 20:53:04 +0200
Subject: [PATCH 2/4] [SimplifyCFG] Prevent merging cbranch to cbranch if the
 branch probability from the first to second is too low.

---
 llvm/lib/Transforms/Utils/SimplifyCFG.cpp         | 15 +++++++++++++++
 .../SimplifyCFG/branch-cond-dont-merge.ll         | 13 ++++++-------
 2 files changed, 21 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 35fead111aa9666..eafd63bb4257bbb 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -180,6 +180,11 @@ static cl::opt<unsigned> MaxSwitchCasesPerResult(
     "max-switch-cases-per-result", cl::Hidden, cl::init(16),
     cl::desc("Limit cases to analyze when converting a switch to select"));
 
+static cl::opt<unsigned> CondBranchToCondBranchWeightRatio(
+    "simplifycfg-cbranch-to-cbranch-weight-ratio", cl::Hidden, cl::init(10000),
+    cl::desc("Don't merge conditional branches if the branch probability from "
+             "the first to second is below of the reciprocal of this value"));
+
 STATISTIC(NumBitMaps, "Number of switch instructions turned into bitmaps");
 STATISTIC(NumLinearMaps,
           "Number of switch instructions turned into linear mapping");
@@ -4347,6 +4352,16 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI,
   if (PBI->getSuccessor(PBIOp) == BB)
     return false;
 
+  // If predecessor's branch probability to BB is too low don't merge branches.
+  SmallVector<uint32_t, 2> PredWeights;
+  if (extractBranchWeights(*PBI, PredWeights)) {
+    auto BIWeight = PredWeights[PBIOp ^ 1];
+    auto CommonWeight = PredWeights[PBIOp];
+    if (BIWeight &&
+        (CommonWeight / BIWeight > CondBranchToCondBranchWeightRatio))
+      return false;
+  }
+
   // Do not perform this transformation if it would require
   // insertion of a large number of select instructions. For targets
   // without predication/cmovs, this is a big pessimization.
diff --git a/llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll b/llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll
index b62bb825ab30b62..6dcdfee21932f12 100644
--- a/llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll
+++ b/llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll
@@ -8,10 +8,9 @@ define void @dont_merge_cbranches1(i32 %V) {
 ; CHECK-LABEL: @dont_merge_cbranches1(
 ; CHECK-NEXT:    [[DIVERGENT_COND:%.*]] = icmp ne i32 [[V:%.*]], 0
 ; CHECK-NEXT:    [[UNIFORM_COND:%.*]] = call i1 @uniform_result(i1 [[DIVERGENT_COND]])
-; CHECK-NEXT:    [[UNIFORM_COND_NOT:%.*]] = xor i1 [[UNIFORM_COND]], true
-; CHECK-NEXT:    [[DIVERGENT_COND_NOT:%.*]] = xor i1 [[DIVERGENT_COND]], true
-; CHECK-NEXT:    [[BRMERGE:%.*]] = select i1 [[UNIFORM_COND_NOT]], i1 true, i1 [[DIVERGENT_COND_NOT]]
-; CHECK-NEXT:    br i1 [[BRMERGE]], label [[EXIT:%.*]], label [[BB3:%.*]], !prof [[PROF0:![0-9]+]]
+; CHECK-NEXT:    br i1 [[UNIFORM_COND]], label [[BB2:%.*]], label [[EXIT:%.*]], !prof [[PROF0:![0-9]+]]
+; CHECK:       bb2:
+; CHECK-NEXT:    br i1 [[DIVERGENT_COND]], label [[BB3:%.*]], label [[EXIT]]
 ; CHECK:       bb3:
 ; CHECK-NEXT:    call void @bar()
 ; CHECK-NEXT:    br label [[EXIT]]
@@ -34,9 +33,9 @@ define void @dont_merge_cbranches2(i32 %V) {
 ; CHECK-LABEL: @dont_merge_cbranches2(
 ; CHECK-NEXT:    [[DIVERGENT_COND:%.*]] = icmp ne i32 [[V:%.*]], 0
 ; CHECK-NEXT:    [[UNIFORM_COND:%.*]] = call i1 @uniform_result(i1 [[DIVERGENT_COND]])
-; CHECK-NEXT:    [[DIVERGENT_COND_NOT:%.*]] = xor i1 [[DIVERGENT_COND]], true
-; CHECK-NEXT:    [[BRMERGE:%.*]] = select i1 [[UNIFORM_COND]], i1 true, i1 [[DIVERGENT_COND_NOT]]
-; CHECK-NEXT:    br i1 [[BRMERGE]], label [[EXIT:%.*]], label [[BB3:%.*]], !prof [[PROF0]]
+; CHECK-NEXT:    br i1 [[UNIFORM_COND]], label [[EXIT:%.*]], label [[BB2:%.*]], !prof [[PROF1:![0-9]+]]
+; CHECK:       bb2:
+; CHECK-NEXT:    br i1 [[DIVERGENT_COND]], label [[BB3:%.*]], label [[EXIT]]
 ; CHECK:       bb3:
 ; CHECK-NEXT:    call void @bar()
 ; CHECK-NEXT:    br label [[EXIT]]

>From 6fed106982dff26ac629dc5d2ca7ffaa19e24e23 Mon Sep 17 00:00:00 2001
From: Valery Pykhtin <valery.pykhtin at gmail.com>
Date: Sat, 21 Oct 2023 23:33:53 +0200
Subject: [PATCH 3/4] Replaced threshold value with
 TTI->getPredictableBranchThreshold()

---
 llvm/lib/Transforms/Utils/SimplifyCFG.cpp     | 18 ++++++------
 .../SimplifyCFG/branch-cond-dont-merge.ll     | 28 ++++++++++++++++++-
 2 files changed, 35 insertions(+), 11 deletions(-)

diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index eafd63bb4257bbb..1a8d7e6307e8de9 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -180,11 +180,6 @@ static cl::opt<unsigned> MaxSwitchCasesPerResult(
     "max-switch-cases-per-result", cl::Hidden, cl::init(16),
     cl::desc("Limit cases to analyze when converting a switch to select"));
 
-static cl::opt<unsigned> CondBranchToCondBranchWeightRatio(
-    "simplifycfg-cbranch-to-cbranch-weight-ratio", cl::Hidden, cl::init(10000),
-    cl::desc("Don't merge conditional branches if the branch probability from "
-             "the first to second is below of the reciprocal of this value"));
-
 STATISTIC(NumBitMaps, "Number of switch instructions turned into bitmaps");
 STATISTIC(NumLinearMaps,
           "Number of switch instructions turned into linear mapping");
@@ -4354,11 +4349,14 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI,
 
   // If predecessor's branch probability to BB is too low don't merge branches.
   SmallVector<uint32_t, 2> PredWeights;
-  if (extractBranchWeights(*PBI, PredWeights)) {
-    auto BIWeight = PredWeights[PBIOp ^ 1];
-    auto CommonWeight = PredWeights[PBIOp];
-    if (BIWeight &&
-        (CommonWeight / BIWeight > CondBranchToCondBranchWeightRatio))
+  if (extractBranchWeights(*PBI, PredWeights) &&
+      (PredWeights[0] + PredWeights[1]) != 0) {
+
+    BranchProbability CommonDestProb = BranchProbability::getBranchProbability(
+        PredWeights[PBIOp], PredWeights[0] + PredWeights[1]);
+
+    BranchProbability Likely = TTI.getPredictableBranchThreshold();
+    if (CommonDestProb >= Likely)
       return false;
   }
 
diff --git a/llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll b/llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll
index 6dcdfee21932f12..5c21f163826ee25 100644
--- a/llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll
+++ b/llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
-; RUN: opt < %s -passes=simplifycfg -simplifycfg-cbranch-to-cbranch-weight-ratio=100 -S | FileCheck %s
+; RUN: opt < %s -passes=simplifycfg -S | FileCheck %s
 
 declare void @bar()
 declare i1 @uniform_result(i1 %c)
@@ -54,5 +54,31 @@ exit:
   ret void
 }
 
+define void @merge_cbranches(i32 %V) {
+; CHECK-LABEL: @merge_cbranches(
+; CHECK-NEXT:    [[DIVERGENT_COND:%.*]] = icmp ne i32 [[V:%.*]], 0
+; CHECK-NEXT:    [[UNIFORM_COND:%.*]] = call i1 @uniform_result(i1 [[DIVERGENT_COND]])
+; CHECK-NEXT:    [[DIVERGENT_COND_NOT:%.*]] = xor i1 [[DIVERGENT_COND]], true
+; CHECK-NEXT:    [[BRMERGE:%.*]] = select i1 [[UNIFORM_COND]], i1 true, i1 [[DIVERGENT_COND_NOT]]
+; CHECK-NEXT:    br i1 [[BRMERGE]], label [[EXIT:%.*]], label [[BB3:%.*]], !prof [[PROF2:![0-9]+]]
+; CHECK:       bb3:
+; CHECK-NEXT:    call void @bar()
+; CHECK-NEXT:    br label [[EXIT]]
+; CHECK:       exit:
+; CHECK-NEXT:    ret void
+;
+  %divergent_cond = icmp ne i32 %V, 0
+  %uniform_cond = call i1 @uniform_result(i1 %divergent_cond)
+  br i1 %uniform_cond, label %exit, label %bb2, !prof !2
+bb2:
+  br i1 %divergent_cond, label %bb3, label %exit
+bb3:
+  call void @bar( )
+  br label %exit
+exit:
+  ret void
+}
+
 !0 = !{!"branch_weights", i32 1, i32 1000}
 !1 = !{!"branch_weights", i32 1000, i32 1}
+!2 = !{!"branch_weights", i32 3, i32 2}

>From 4a18667728f4b6076017479b41a29c21f2925a6e Mon Sep 17 00:00:00 2001
From: Valery Pykhtin <valery.pykhtin at gmail.com>
Date: Fri, 27 Oct 2023 13:47:42 +0200
Subject: [PATCH 4/4] Add MD_unpredictable check; sum branch weights in 64 bits

---
 llvm/lib/Transforms/Utils/SimplifyCFG.cpp     |  9 ++++---
 .../SimplifyCFG/branch-cond-dont-merge.ll     | 27 +++++++++++++++++++
 2 files changed, 33 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 1a8d7e6307e8de9..adb14a9085aebf5 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -4349,11 +4349,14 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI,
 
   // If predecessor's branch probability to BB is too low don't merge branches.
   SmallVector<uint32_t, 2> PredWeights;
-  if (extractBranchWeights(*PBI, PredWeights) &&
-      (PredWeights[0] + PredWeights[1]) != 0) {
+  // Widen to 64 bits before summing: two uint32_t weights can overflow.
+  if (!PBI->getMetadata(LLVMContext::MD_unpredictable) &&
+      extractBranchWeights(*PBI, PredWeights) &&
+      (static_cast<uint64_t>(PredWeights[0]) + PredWeights[1]) != 0) {
 
     BranchProbability CommonDestProb = BranchProbability::getBranchProbability(
-        PredWeights[PBIOp], PredWeights[0] + PredWeights[1]);
+        PredWeights[PBIOp],
+        static_cast<uint64_t>(PredWeights[0]) + PredWeights[1]);
 
     BranchProbability Likely = TTI.getPredictableBranchThreshold();
     if (CommonDestProb >= Likely)
diff --git a/llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll b/llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll
--- a/llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll
+++ b/llvm/test/Transforms/SimplifyCFG/branch-cond-dont-merge.ll
@@ -79,6 +79,33 @@ exit:
   ret void
 }
 
+; The weights sum to 2^32, which wraps to 0 in 32-bit unsigned arithmetic.
+define void @dont_merge_cbranches_weights_overflow(i32 %V) {
+; CHECK-LABEL: @dont_merge_cbranches_weights_overflow(
+; CHECK-NEXT:    [[DIVERGENT_COND:%.*]] = icmp ne i32 [[V:%.*]], 0
+; CHECK-NEXT:    [[UNIFORM_COND:%.*]] = call i1 @uniform_result(i1 [[DIVERGENT_COND]])
+; CHECK-NEXT:    br i1 [[UNIFORM_COND]], label [[BB2:%.*]], label [[EXIT:%.*]], !prof [[PROF3:![0-9]+]]
+; CHECK:       bb2:
+; CHECK-NEXT:    br i1 [[DIVERGENT_COND]], label [[BB3:%.*]], label [[EXIT]]
+; CHECK:       bb3:
+; CHECK-NEXT:    call void @bar()
+; CHECK-NEXT:    br label [[EXIT]]
+; CHECK:       exit:
+; CHECK-NEXT:    ret void
+;
+  %divergent_cond = icmp ne i32 %V, 0
+  %uniform_cond = call i1 @uniform_result(i1 %divergent_cond)
+  br i1 %uniform_cond, label %bb2, label %exit, !prof !3
+bb2:
+  br i1 %divergent_cond, label %bb3, label %exit
+bb3:
+  call void @bar( )
+  br label %exit
+exit:
+  ret void
+}
+
 !0 = !{!"branch_weights", i32 1, i32 1000}
 !1 = !{!"branch_weights", i32 1000, i32 1}
 !2 = !{!"branch_weights", i32 3, i32 2}
+!3 = !{!"branch_weights", i32 1, i32 4294967295}



More information about the llvm-commits mailing list